1
0
mirror of https://github.com/bitwarden/server.git synced 2025-04-10 15:48:13 -05:00

Merge branch 'main' into jmccannon/ac/pm-16811-scim-invite-optimization

# Conflicts:
#	src/Core/Constants.cs
This commit is contained in:
jrmccannon 2025-02-17 16:35:02 -06:00
commit 649e8b5c0a
No known key found for this signature in database
GPG Key ID: CF03F3DB01CE96A6
114 changed files with 12740 additions and 610 deletions

View File

@ -17,6 +17,7 @@ on:
- "src/Infrastructure.Dapper/**" # Changes to SQL Server Dapper Repository Layer
- "src/Infrastructure.EntityFramework/**" # Changes to Entity Framework Repository Layer
- "test/Infrastructure.IntegrationTest/**" # Any changes to the tests
- "src/**/Entities/**/*.cs" # Database entity definitions
pull_request:
paths:
- ".github/workflows/test-database.yml" # This file
@ -28,6 +29,7 @@ on:
- "src/Infrastructure.Dapper/**" # Changes to SQL Server Dapper Repository Layer
- "src/Infrastructure.EntityFramework/**" # Changes to Entity Framework Repository Layer
- "test/Infrastructure.IntegrationTest/**" # Any changes to the tests
- "src/**/Entities/**/*.cs" # Database entity definitions
jobs:
check-test-secrets:

View File

@ -1,8 +1,10 @@
using Bit.Core.AdminConsole.OrganizationFeatures.Groups.Interfaces;
using Bit.Core;
using Bit.Core.AdminConsole.OrganizationFeatures.Groups.Interfaces;
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Enums;
using Bit.Core.Exceptions;
using Bit.Core.Repositories;
using Bit.Core.Services;
using Bit.Scim.Groups.Interfaces;
using Bit.Scim.Models;
using Bit.Scim.Utilities;
@ -22,9 +24,10 @@ public class GroupsController : Controller
private readonly IGetGroupsListQuery _getGroupsListQuery;
private readonly IDeleteGroupCommand _deleteGroupCommand;
private readonly IPatchGroupCommand _patchGroupCommand;
private readonly IPatchGroupCommandvNext _patchGroupCommandvNext;
private readonly IPostGroupCommand _postGroupCommand;
private readonly IPutGroupCommand _putGroupCommand;
private readonly ILogger<GroupsController> _logger;
private readonly IFeatureService _featureService;
public GroupsController(
IGroupRepository groupRepository,
@ -32,18 +35,21 @@ public class GroupsController : Controller
IGetGroupsListQuery getGroupsListQuery,
IDeleteGroupCommand deleteGroupCommand,
IPatchGroupCommand patchGroupCommand,
IPatchGroupCommandvNext patchGroupCommandvNext,
IPostGroupCommand postGroupCommand,
IPutGroupCommand putGroupCommand,
ILogger<GroupsController> logger)
IFeatureService featureService
)
{
_groupRepository = groupRepository;
_organizationRepository = organizationRepository;
_getGroupsListQuery = getGroupsListQuery;
_deleteGroupCommand = deleteGroupCommand;
_patchGroupCommand = patchGroupCommand;
_patchGroupCommandvNext = patchGroupCommandvNext;
_postGroupCommand = postGroupCommand;
_putGroupCommand = putGroupCommand;
_logger = logger;
_featureService = featureService;
}
[HttpGet("{id}")]
@ -97,8 +103,21 @@ public class GroupsController : Controller
[HttpPatch("{id}")]
public async Task<IActionResult> Patch(Guid organizationId, Guid id, [FromBody] ScimPatchModel model)
{
if (_featureService.IsEnabled(FeatureFlagKeys.ShortcutDuplicatePatchRequests))
{
var group = await _groupRepository.GetByIdAsync(id);
if (group == null || group.OrganizationId != organizationId)
{
throw new NotFoundException("Group not found.");
}
await _patchGroupCommandvNext.PatchGroupAsync(group, model);
return new NoContentResult();
}
var organization = await _organizationRepository.GetByIdAsync(organizationId);
await _patchGroupCommand.PatchGroupAsync(organization, id, model);
return new NoContentResult();
}

View File

@ -0,0 +1,9 @@
using Bit.Core.AdminConsole.Entities;
using Bit.Scim.Models;
namespace Bit.Scim.Groups.Interfaces;
public interface IPatchGroupCommandvNext
{
Task PatchGroupAsync(Group group, ScimPatchModel model);
}

View File

@ -0,0 +1,170 @@
using System.Text.Json;
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.OrganizationFeatures.Groups.Interfaces;
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.AdminConsole.Services;
using Bit.Core.Enums;
using Bit.Core.Exceptions;
using Bit.Core.Repositories;
using Bit.Scim.Groups.Interfaces;
using Bit.Scim.Models;
using Bit.Scim.Utilities;
namespace Bit.Scim.Groups;
public class PatchGroupCommandvNext : IPatchGroupCommandvNext
{
private readonly IGroupRepository _groupRepository;
private readonly IGroupService _groupService;
private readonly IUpdateGroupCommand _updateGroupCommand;
private readonly ILogger<PatchGroupCommandvNext> _logger;
private readonly IOrganizationRepository _organizationRepository;
public PatchGroupCommandvNext(
IGroupRepository groupRepository,
IGroupService groupService,
IUpdateGroupCommand updateGroupCommand,
ILogger<PatchGroupCommandvNext> logger,
IOrganizationRepository organizationRepository)
{
_groupRepository = groupRepository;
_groupService = groupService;
_updateGroupCommand = updateGroupCommand;
_logger = logger;
_organizationRepository = organizationRepository;
}
public async Task PatchGroupAsync(Group group, ScimPatchModel model)
{
foreach (var operation in model.Operations)
{
await HandleOperationAsync(group, operation);
}
}
private async Task HandleOperationAsync(Group group, ScimPatchModel.OperationModel operation)
{
switch (operation.Op?.ToLowerInvariant())
{
// Replace a list of members
case PatchOps.Replace when operation.Path?.ToLowerInvariant() == PatchPaths.Members:
{
var ids = GetOperationValueIds(operation.Value);
await _groupRepository.UpdateUsersAsync(group.Id, ids);
break;
}
// Replace group name from path
case PatchOps.Replace when operation.Path?.ToLowerInvariant() == PatchPaths.DisplayName:
{
group.Name = operation.Value.GetString();
var organization = await _organizationRepository.GetByIdAsync(group.OrganizationId);
if (organization == null)
{
throw new NotFoundException();
}
await _updateGroupCommand.UpdateGroupAsync(group, organization, EventSystemUser.SCIM);
break;
}
// Replace group name from value object
case PatchOps.Replace when
string.IsNullOrWhiteSpace(operation.Path) &&
operation.Value.TryGetProperty("displayName", out var displayNameProperty):
{
group.Name = displayNameProperty.GetString();
var organization = await _organizationRepository.GetByIdAsync(group.OrganizationId);
if (organization == null)
{
throw new NotFoundException();
}
await _updateGroupCommand.UpdateGroupAsync(group, organization, EventSystemUser.SCIM);
break;
}
// Add a single member
case PatchOps.Add when
!string.IsNullOrWhiteSpace(operation.Path) &&
operation.Path.StartsWith("members[value eq ", StringComparison.OrdinalIgnoreCase) &&
TryGetOperationPathId(operation.Path, out var addId):
{
await AddMembersAsync(group, [addId]);
break;
}
// Add a list of members
case PatchOps.Add when
operation.Path?.ToLowerInvariant() == PatchPaths.Members:
{
await AddMembersAsync(group, GetOperationValueIds(operation.Value));
break;
}
// Remove a single member
case PatchOps.Remove when
!string.IsNullOrWhiteSpace(operation.Path) &&
operation.Path.StartsWith("members[value eq ", StringComparison.OrdinalIgnoreCase) &&
TryGetOperationPathId(operation.Path, out var removeId):
{
await _groupService.DeleteUserAsync(group, removeId, EventSystemUser.SCIM);
break;
}
// Remove a list of members
case PatchOps.Remove when
operation.Path?.ToLowerInvariant() == PatchPaths.Members:
{
var orgUserIds = (await _groupRepository.GetManyUserIdsByIdAsync(group.Id)).ToHashSet();
foreach (var v in GetOperationValueIds(operation.Value))
{
orgUserIds.Remove(v);
}
await _groupRepository.UpdateUsersAsync(group.Id, orgUserIds);
break;
}
default:
{
_logger.LogWarning("Group patch operation not handled: {OperationOp}:{OperationPath}", operation.Op, operation.Path);
break;
}
}
}
private async Task AddMembersAsync(Group group, HashSet<Guid> usersToAdd)
{
// Azure Entra ID is known to send redundant "add" requests for each existing member every time any member
// is removed. To avoid excessive load on the database, we check against the high availability replica and
// return early if they already exist.
var groupMembers = await _groupRepository.GetManyUserIdsByIdAsync(group.Id, useReadOnlyReplica: true);
if (usersToAdd.IsSubsetOf(groupMembers))
{
_logger.LogDebug("Ignoring duplicate SCIM request to add members {Members} to group {Group}", usersToAdd, group.Id);
return;
}
await _groupRepository.AddGroupUsersByIdAsync(group.Id, usersToAdd);
}
private static HashSet<Guid> GetOperationValueIds(JsonElement objArray)
{
var ids = new HashSet<Guid>();
foreach (var obj in objArray.EnumerateArray())
{
if (obj.TryGetProperty("value", out var valueProperty))
{
if (valueProperty.TryGetGuid(out var guid))
{
ids.Add(guid);
}
}
}
return ids;
}
private static bool TryGetOperationPathId(string path, out Guid pathId)
{
// Parse Guid from string like: members[value eq "{GUID}"}]
return Guid.TryParse(path.Substring(18).Replace("\"]", string.Empty), out pathId);
}
}

View File

@ -1,11 +1,8 @@
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Enums;
using Bit.Core.AdminConsole.OrganizationFeatures.Groups.Interfaces;
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Enums;
using Bit.Core.Exceptions;
using Bit.Core.Repositories;
using Bit.Scim.Context;
using Bit.Scim.Groups.Interfaces;
using Bit.Scim.Models;
@ -14,17 +11,13 @@ namespace Bit.Scim.Groups;
public class PostGroupCommand : IPostGroupCommand
{
private readonly IGroupRepository _groupRepository;
private readonly IScimContext _scimContext;
private readonly ICreateGroupCommand _createGroupCommand;
public PostGroupCommand(
IGroupRepository groupRepository,
IOrganizationRepository organizationRepository,
IScimContext scimContext,
ICreateGroupCommand createGroupCommand)
{
_groupRepository = groupRepository;
_scimContext = scimContext;
_createGroupCommand = createGroupCommand;
}
@ -50,11 +43,6 @@ public class PostGroupCommand : IPostGroupCommand
private async Task UpdateGroupMembersAsync(Group group, ScimGroupRequestModel model)
{
if (_scimContext.RequestScimProvider != ScimProviderType.Okta)
{
return;
}
if (model.Members == null)
{
return;

View File

@ -1,10 +1,8 @@
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Enums;
using Bit.Core.AdminConsole.OrganizationFeatures.Groups.Interfaces;
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Enums;
using Bit.Core.Exceptions;
using Bit.Scim.Context;
using Bit.Scim.Groups.Interfaces;
using Bit.Scim.Models;
@ -13,16 +11,13 @@ namespace Bit.Scim.Groups;
public class PutGroupCommand : IPutGroupCommand
{
private readonly IGroupRepository _groupRepository;
private readonly IScimContext _scimContext;
private readonly IUpdateGroupCommand _updateGroupCommand;
public PutGroupCommand(
IGroupRepository groupRepository,
IScimContext scimContext,
IUpdateGroupCommand updateGroupCommand)
{
_groupRepository = groupRepository;
_scimContext = scimContext;
_updateGroupCommand = updateGroupCommand;
}
@ -43,12 +38,6 @@ public class PutGroupCommand : IPutGroupCommand
private async Task UpdateGroupMembersAsync(Group group, ScimGroupRequestModel model)
{
if (_scimContext.RequestScimProvider != ScimProviderType.Okta &&
_scimContext.RequestScimProvider != ScimProviderType.Ping)
{
return;
}
if (model.Members == null)
{
return;

View File

@ -7,3 +7,16 @@ public static class ScimConstants
public const string Scim2SchemaUser = "urn:ietf:params:scim:schemas:core:2.0:User";
public const string Scim2SchemaGroup = "urn:ietf:params:scim:schemas:core:2.0:Group";
}
public static class PatchOps
{
public const string Replace = "replace";
public const string Add = "add";
public const string Remove = "remove";
}
public static class PatchPaths
{
public const string Members = "members";
public const string DisplayName = "displayname";
}

View File

@ -10,6 +10,7 @@ public static class ScimServiceCollectionExtensions
public static void AddScimGroupCommands(this IServiceCollection services)
{
services.AddScoped<IPatchGroupCommand, PatchGroupCommand>();
services.AddScoped<IPatchGroupCommandvNext, PatchGroupCommandvNext>();
services.AddScoped<IPostGroupCommand, PostGroupCommand>();
services.AddScoped<IPutGroupCommand, PutGroupCommand>();
}

View File

@ -0,0 +1,237 @@
using System.Text.Json;
using Bit.Scim.IntegrationTest.Factories;
using Bit.Scim.Models;
using Bit.Scim.Utilities;
using Bit.Test.Common.Helpers;
using Xunit;
namespace Bit.Scim.IntegrationTest.Controllers.v2;
public class GroupsControllerPatchTests : IClassFixture<ScimApplicationFactory>, IAsyncLifetime
{
private readonly ScimApplicationFactory _factory;
public GroupsControllerPatchTests(ScimApplicationFactory factory)
{
_factory = factory;
}
public Task InitializeAsync()
{
var databaseContext = _factory.GetDatabaseContext();
_factory.ReinitializeDbForTests(databaseContext);
return Task.CompletedTask;
}
Task IAsyncLifetime.DisposeAsync() => Task.CompletedTask;
[Fact]
public async Task Patch_ReplaceDisplayName_Success()
{
var organizationId = ScimApplicationFactory.TestOrganizationId1;
var groupId = ScimApplicationFactory.TestGroupId1;
var newDisplayName = "Patch Display Name";
var inputModel = new ScimPatchModel
{
Operations = new List<ScimPatchModel.OperationModel>()
{
new ScimPatchModel.OperationModel
{
Op = "replace",
Value = JsonDocument.Parse($"{{\"displayName\":\"{newDisplayName}\"}}").RootElement
}
},
Schemas = new List<string>() { ScimConstants.Scim2SchemaGroup }
};
var context = await _factory.GroupsPatchAsync(organizationId, groupId, inputModel);
Assert.Equal(StatusCodes.Status204NoContent, context.Response.StatusCode);
var databaseContext = _factory.GetDatabaseContext();
var group = databaseContext.Groups.FirstOrDefault(g => g.Id == groupId);
Assert.Equal(newDisplayName, group.Name);
Assert.Equal(ScimApplicationFactory.InitialGroupUsersCount, databaseContext.GroupUsers.Count());
Assert.True(databaseContext.GroupUsers.Any(gu => gu.OrganizationUserId == ScimApplicationFactory.TestOrganizationUserId1));
Assert.True(databaseContext.GroupUsers.Any(gu => gu.OrganizationUserId == ScimApplicationFactory.TestOrganizationUserId4));
}
[Fact]
public async Task Patch_ReplaceMembers_Success()
{
var organizationId = ScimApplicationFactory.TestOrganizationId1;
var groupId = ScimApplicationFactory.TestGroupId1;
var inputModel = new ScimPatchModel
{
Operations = new List<ScimPatchModel.OperationModel>()
{
new ScimPatchModel.OperationModel
{
Op = "replace",
Path = "members",
Value = JsonDocument.Parse($"[{{\"value\":\"{ScimApplicationFactory.TestOrganizationUserId2}\"}}]").RootElement
}
},
Schemas = new List<string>() { ScimConstants.Scim2SchemaGroup }
};
var context = await _factory.GroupsPatchAsync(organizationId, groupId, inputModel);
Assert.Equal(StatusCodes.Status204NoContent, context.Response.StatusCode);
var databaseContext = _factory.GetDatabaseContext();
Assert.Single(databaseContext.GroupUsers);
Assert.Equal(ScimApplicationFactory.InitialGroupUsersCount - 1, databaseContext.GroupUsers.Count());
var groupUser = databaseContext.GroupUsers.FirstOrDefault();
Assert.Equal(ScimApplicationFactory.TestOrganizationUserId2, groupUser.OrganizationUserId);
}
[Fact]
public async Task Patch_AddSingleMember_Success()
{
var organizationId = ScimApplicationFactory.TestOrganizationId1;
var groupId = ScimApplicationFactory.TestGroupId1;
var inputModel = new ScimPatchModel
{
Operations = new List<ScimPatchModel.OperationModel>()
{
new ScimPatchModel.OperationModel
{
Op = "add",
Path = $"members[value eq \"{ScimApplicationFactory.TestOrganizationUserId2}\"]",
Value = JsonDocument.Parse("{}").RootElement
}
},
Schemas = new List<string>() { ScimConstants.Scim2SchemaGroup }
};
var context = await _factory.GroupsPatchAsync(organizationId, groupId, inputModel);
Assert.Equal(StatusCodes.Status204NoContent, context.Response.StatusCode);
var databaseContext = _factory.GetDatabaseContext();
Assert.Equal(ScimApplicationFactory.InitialGroupUsersCount + 1, databaseContext.GroupUsers.Count());
Assert.True(databaseContext.GroupUsers.Any(gu => gu.GroupId == groupId && gu.OrganizationUserId == ScimApplicationFactory.TestOrganizationUserId1));
Assert.True(databaseContext.GroupUsers.Any(gu => gu.GroupId == groupId && gu.OrganizationUserId == ScimApplicationFactory.TestOrganizationUserId2));
Assert.True(databaseContext.GroupUsers.Any(gu => gu.GroupId == groupId && gu.OrganizationUserId == ScimApplicationFactory.TestOrganizationUserId4));
}
[Fact]
public async Task Patch_AddListMembers_Success()
{
var organizationId = ScimApplicationFactory.TestOrganizationId1;
var groupId = ScimApplicationFactory.TestGroupId2;
var inputModel = new ScimPatchModel
{
Operations = new List<ScimPatchModel.OperationModel>()
{
new ScimPatchModel.OperationModel
{
Op = "add",
Path = "members",
Value = JsonDocument.Parse($"[{{\"value\":\"{ScimApplicationFactory.TestOrganizationUserId2}\"}},{{\"value\":\"{ScimApplicationFactory.TestOrganizationUserId3}\"}}]").RootElement
}
},
Schemas = new List<string>() { ScimConstants.Scim2SchemaGroup }
};
var context = await _factory.GroupsPatchAsync(organizationId, groupId, inputModel);
Assert.Equal(StatusCodes.Status204NoContent, context.Response.StatusCode);
var databaseContext = _factory.GetDatabaseContext();
Assert.True(databaseContext.GroupUsers.Any(gu => gu.GroupId == groupId && gu.OrganizationUserId == ScimApplicationFactory.TestOrganizationUserId2));
Assert.True(databaseContext.GroupUsers.Any(gu => gu.GroupId == groupId && gu.OrganizationUserId == ScimApplicationFactory.TestOrganizationUserId3));
}
[Fact]
public async Task Patch_RemoveSingleMember_ReplaceDisplayName_Success()
{
var organizationId = ScimApplicationFactory.TestOrganizationId1;
var groupId = ScimApplicationFactory.TestGroupId1;
var newDisplayName = "Patch Display Name";
var inputModel = new ScimPatchModel
{
Operations = new List<ScimPatchModel.OperationModel>()
{
new ScimPatchModel.OperationModel
{
Op = "remove",
Path = $"members[value eq \"{ScimApplicationFactory.TestOrganizationUserId1}\"]",
Value = JsonDocument.Parse("{}").RootElement
},
new ScimPatchModel.OperationModel
{
Op = "replace",
Value = JsonDocument.Parse($"{{\"displayName\":\"{newDisplayName}\"}}").RootElement
}
},
Schemas = new List<string>() { ScimConstants.Scim2SchemaGroup }
};
var context = await _factory.GroupsPatchAsync(organizationId, groupId, inputModel);
Assert.Equal(StatusCodes.Status204NoContent, context.Response.StatusCode);
var databaseContext = _factory.GetDatabaseContext();
Assert.Equal(ScimApplicationFactory.InitialGroupUsersCount - 1, databaseContext.GroupUsers.Count());
Assert.Equal(ScimApplicationFactory.InitialGroupCount, databaseContext.Groups.Count());
var group = databaseContext.Groups.FirstOrDefault(g => g.Id == groupId);
Assert.Equal(newDisplayName, group.Name);
}
[Fact]
public async Task Patch_RemoveListMembers_Success()
{
var organizationId = ScimApplicationFactory.TestOrganizationId1;
var groupId = ScimApplicationFactory.TestGroupId1;
var inputModel = new ScimPatchModel
{
Operations = new List<ScimPatchModel.OperationModel>()
{
new ScimPatchModel.OperationModel
{
Op = "remove",
Path = "members",
Value = JsonDocument.Parse($"[{{\"value\":\"{ScimApplicationFactory.TestOrganizationUserId1}\"}}, {{\"value\":\"{ScimApplicationFactory.TestOrganizationUserId4}\"}}]").RootElement
}
},
Schemas = new List<string>() { ScimConstants.Scim2SchemaGroup }
};
var context = await _factory.GroupsPatchAsync(organizationId, groupId, inputModel);
Assert.Equal(StatusCodes.Status204NoContent, context.Response.StatusCode);
var databaseContext = _factory.GetDatabaseContext();
Assert.Empty(databaseContext.GroupUsers);
}
[Fact]
public async Task Patch_NotFound()
{
var organizationId = ScimApplicationFactory.TestOrganizationId1;
var groupId = Guid.NewGuid();
var inputModel = new Models.ScimPatchModel
{
Operations = new List<ScimPatchModel.OperationModel>(),
Schemas = new List<string>() { ScimConstants.Scim2SchemaGroup }
};
var expectedResponse = new ScimErrorResponseModel
{
Status = StatusCodes.Status404NotFound,
Detail = "Group not found.",
Schemas = new List<string> { ScimConstants.Scim2SchemaError }
};
var context = await _factory.GroupsPatchAsync(organizationId, groupId, inputModel);
Assert.Equal(StatusCodes.Status404NotFound, context.Response.StatusCode);
var responseModel = JsonSerializer.Deserialize<ScimErrorResponseModel>(context.Response.Body, new JsonSerializerOptions { PropertyNamingPolicy = JsonNamingPolicy.CamelCase });
AssertHelper.AssertPropertyEqual(expectedResponse, responseModel);
}
}

View File

@ -0,0 +1,251 @@
using System.Text.Json;
using Bit.Core;
using Bit.Core.AdminConsole.Entities;
using Bit.Core.Services;
using Bit.Scim.Groups.Interfaces;
using Bit.Scim.IntegrationTest.Factories;
using Bit.Scim.Models;
using Bit.Scim.Utilities;
using Bit.Test.Common.Helpers;
using NSubstitute;
using NSubstitute.ExceptionExtensions;
using Xunit;
namespace Bit.Scim.IntegrationTest.Controllers.v2;
public class GroupsControllerPatchTestsvNext : IClassFixture<ScimApplicationFactory>, IAsyncLifetime
{
private readonly ScimApplicationFactory _factory;
public GroupsControllerPatchTestsvNext(ScimApplicationFactory factory)
{
_factory = factory;
// Enable the feature flag for new PatchGroupsCommand and stub out the old command to be safe
_factory.SubstituteService((IFeatureService featureService)
=> featureService.IsEnabled(FeatureFlagKeys.ShortcutDuplicatePatchRequests).Returns(true));
_factory.SubstituteService((IPatchGroupCommand patchGroupCommand)
=> patchGroupCommand.PatchGroupAsync(Arg.Any<Organization>(), Arg.Any<Guid>(), Arg.Any<ScimPatchModel>())
.ThrowsAsync(new Exception("This test suite should be testing the vNext command, but the existing command was called.")));
}
public Task InitializeAsync()
{
var databaseContext = _factory.GetDatabaseContext();
_factory.ReinitializeDbForTests(databaseContext);
return Task.CompletedTask;
}
Task IAsyncLifetime.DisposeAsync() => Task.CompletedTask;
[Fact]
public async Task Patch_ReplaceDisplayName_Success()
{
var organizationId = ScimApplicationFactory.TestOrganizationId1;
var groupId = ScimApplicationFactory.TestGroupId1;
var newDisplayName = "Patch Display Name";
var inputModel = new ScimPatchModel
{
Operations = new List<ScimPatchModel.OperationModel>()
{
new ScimPatchModel.OperationModel
{
Op = "replace",
Value = JsonDocument.Parse($"{{\"displayName\":\"{newDisplayName}\"}}").RootElement
}
},
Schemas = new List<string>() { ScimConstants.Scim2SchemaGroup }
};
var context = await _factory.GroupsPatchAsync(organizationId, groupId, inputModel);
Assert.Equal(StatusCodes.Status204NoContent, context.Response.StatusCode);
var databaseContext = _factory.GetDatabaseContext();
var group = databaseContext.Groups.FirstOrDefault(g => g.Id == groupId);
Assert.Equal(newDisplayName, group.Name);
Assert.Equal(ScimApplicationFactory.InitialGroupUsersCount, databaseContext.GroupUsers.Count());
Assert.True(databaseContext.GroupUsers.Any(gu => gu.OrganizationUserId == ScimApplicationFactory.TestOrganizationUserId1));
Assert.True(databaseContext.GroupUsers.Any(gu => gu.OrganizationUserId == ScimApplicationFactory.TestOrganizationUserId4));
}
[Fact]
public async Task Patch_ReplaceMembers_Success()
{
var organizationId = ScimApplicationFactory.TestOrganizationId1;
var groupId = ScimApplicationFactory.TestGroupId1;
var inputModel = new ScimPatchModel
{
Operations = new List<ScimPatchModel.OperationModel>()
{
new ScimPatchModel.OperationModel
{
Op = "replace",
Path = "members",
Value = JsonDocument.Parse($"[{{\"value\":\"{ScimApplicationFactory.TestOrganizationUserId2}\"}}]").RootElement
}
},
Schemas = new List<string>() { ScimConstants.Scim2SchemaGroup }
};
var context = await _factory.GroupsPatchAsync(organizationId, groupId, inputModel);
Assert.Equal(StatusCodes.Status204NoContent, context.Response.StatusCode);
var databaseContext = _factory.GetDatabaseContext();
Assert.Single(databaseContext.GroupUsers);
Assert.Equal(ScimApplicationFactory.InitialGroupUsersCount - 1, databaseContext.GroupUsers.Count());
var groupUser = databaseContext.GroupUsers.FirstOrDefault();
Assert.Equal(ScimApplicationFactory.TestOrganizationUserId2, groupUser.OrganizationUserId);
}
[Fact]
public async Task Patch_AddSingleMember_Success()
{
var organizationId = ScimApplicationFactory.TestOrganizationId1;
var groupId = ScimApplicationFactory.TestGroupId1;
var inputModel = new ScimPatchModel
{
Operations = new List<ScimPatchModel.OperationModel>()
{
new ScimPatchModel.OperationModel
{
Op = "add",
Path = $"members[value eq \"{ScimApplicationFactory.TestOrganizationUserId2}\"]",
Value = JsonDocument.Parse("{}").RootElement
}
},
Schemas = new List<string>() { ScimConstants.Scim2SchemaGroup }
};
var context = await _factory.GroupsPatchAsync(organizationId, groupId, inputModel);
Assert.Equal(StatusCodes.Status204NoContent, context.Response.StatusCode);
var databaseContext = _factory.GetDatabaseContext();
Assert.Equal(ScimApplicationFactory.InitialGroupUsersCount + 1, databaseContext.GroupUsers.Count());
Assert.True(databaseContext.GroupUsers.Any(gu => gu.GroupId == groupId && gu.OrganizationUserId == ScimApplicationFactory.TestOrganizationUserId1));
Assert.True(databaseContext.GroupUsers.Any(gu => gu.GroupId == groupId && gu.OrganizationUserId == ScimApplicationFactory.TestOrganizationUserId2));
Assert.True(databaseContext.GroupUsers.Any(gu => gu.GroupId == groupId && gu.OrganizationUserId == ScimApplicationFactory.TestOrganizationUserId4));
}
[Fact]
public async Task Patch_AddListMembers_Success()
{
var organizationId = ScimApplicationFactory.TestOrganizationId1;
var groupId = ScimApplicationFactory.TestGroupId2;
var inputModel = new ScimPatchModel
{
Operations = new List<ScimPatchModel.OperationModel>()
{
new ScimPatchModel.OperationModel
{
Op = "add",
Path = "members",
Value = JsonDocument.Parse($"[{{\"value\":\"{ScimApplicationFactory.TestOrganizationUserId2}\"}},{{\"value\":\"{ScimApplicationFactory.TestOrganizationUserId3}\"}}]").RootElement
}
},
Schemas = new List<string>() { ScimConstants.Scim2SchemaGroup }
};
var context = await _factory.GroupsPatchAsync(organizationId, groupId, inputModel);
Assert.Equal(StatusCodes.Status204NoContent, context.Response.StatusCode);
var databaseContext = _factory.GetDatabaseContext();
Assert.True(databaseContext.GroupUsers.Any(gu => gu.GroupId == groupId && gu.OrganizationUserId == ScimApplicationFactory.TestOrganizationUserId2));
Assert.True(databaseContext.GroupUsers.Any(gu => gu.GroupId == groupId && gu.OrganizationUserId == ScimApplicationFactory.TestOrganizationUserId3));
}
[Fact]
public async Task Patch_RemoveSingleMember_ReplaceDisplayName_Success()
{
var organizationId = ScimApplicationFactory.TestOrganizationId1;
var groupId = ScimApplicationFactory.TestGroupId1;
var newDisplayName = "Patch Display Name";
var inputModel = new ScimPatchModel
{
Operations = new List<ScimPatchModel.OperationModel>()
{
new ScimPatchModel.OperationModel
{
Op = "remove",
Path = $"members[value eq \"{ScimApplicationFactory.TestOrganizationUserId1}\"]",
Value = JsonDocument.Parse("{}").RootElement
},
new ScimPatchModel.OperationModel
{
Op = "replace",
Value = JsonDocument.Parse($"{{\"displayName\":\"{newDisplayName}\"}}").RootElement
}
},
Schemas = new List<string>() { ScimConstants.Scim2SchemaGroup }
};
var context = await _factory.GroupsPatchAsync(organizationId, groupId, inputModel);
Assert.Equal(StatusCodes.Status204NoContent, context.Response.StatusCode);
var databaseContext = _factory.GetDatabaseContext();
Assert.Equal(ScimApplicationFactory.InitialGroupUsersCount - 1, databaseContext.GroupUsers.Count());
Assert.Equal(ScimApplicationFactory.InitialGroupCount, databaseContext.Groups.Count());
var group = databaseContext.Groups.FirstOrDefault(g => g.Id == groupId);
Assert.Equal(newDisplayName, group.Name);
}
[Fact]
public async Task Patch_RemoveListMembers_Success()
{
var organizationId = ScimApplicationFactory.TestOrganizationId1;
var groupId = ScimApplicationFactory.TestGroupId1;
var inputModel = new ScimPatchModel
{
Operations = new List<ScimPatchModel.OperationModel>()
{
new ScimPatchModel.OperationModel
{
Op = "remove",
Path = "members",
Value = JsonDocument.Parse($"[{{\"value\":\"{ScimApplicationFactory.TestOrganizationUserId1}\"}}, {{\"value\":\"{ScimApplicationFactory.TestOrganizationUserId4}\"}}]").RootElement
}
},
Schemas = new List<string>() { ScimConstants.Scim2SchemaGroup }
};
var context = await _factory.GroupsPatchAsync(organizationId, groupId, inputModel);
Assert.Equal(StatusCodes.Status204NoContent, context.Response.StatusCode);
var databaseContext = _factory.GetDatabaseContext();
Assert.Empty(databaseContext.GroupUsers);
}
[Fact]
public async Task Patch_NotFound()
{
var organizationId = ScimApplicationFactory.TestOrganizationId1;
var groupId = Guid.NewGuid();
var inputModel = new Models.ScimPatchModel
{
Operations = new List<ScimPatchModel.OperationModel>(),
Schemas = new List<string>() { ScimConstants.Scim2SchemaGroup }
};
var expectedResponse = new ScimErrorResponseModel
{
Status = StatusCodes.Status404NotFound,
Detail = "Group not found.",
Schemas = new List<string> { ScimConstants.Scim2SchemaError }
};
var context = await _factory.GroupsPatchAsync(organizationId, groupId, inputModel);
Assert.Equal(StatusCodes.Status404NotFound, context.Response.StatusCode);
var responseModel = JsonSerializer.Deserialize<ScimErrorResponseModel>(context.Response.Body, new JsonSerializerOptions { PropertyNamingPolicy = JsonNamingPolicy.CamelCase });
AssertHelper.AssertPropertyEqual(expectedResponse, responseModel);
}
}

View File

@ -9,9 +9,6 @@ namespace Bit.Scim.IntegrationTest.Controllers.v2;
public class GroupsControllerTests : IClassFixture<ScimApplicationFactory>, IAsyncLifetime
{
private const int _initialGroupCount = 3;
private const int _initialGroupUsersCount = 2;
private readonly ScimApplicationFactory _factory;
public GroupsControllerTests(ScimApplicationFactory factory)
@ -237,10 +234,10 @@ public class GroupsControllerTests : IClassFixture<ScimApplicationFactory>, IAsy
AssertHelper.AssertPropertyEqual(expectedResponse, responseModel, "Id");
var databaseContext = _factory.GetDatabaseContext();
Assert.Equal(_initialGroupCount + 1, databaseContext.Groups.Count());
Assert.Equal(ScimApplicationFactory.InitialGroupCount + 1, databaseContext.Groups.Count());
Assert.True(databaseContext.Groups.Any(g => g.Name == displayName && g.ExternalId == externalId));
Assert.Equal(_initialGroupUsersCount + 1, databaseContext.GroupUsers.Count());
Assert.Equal(ScimApplicationFactory.InitialGroupUsersCount + 1, databaseContext.GroupUsers.Count());
Assert.True(databaseContext.GroupUsers.Any(gu => gu.GroupId == responseModel.Id && gu.OrganizationUserId == ScimApplicationFactory.TestOrganizationUserId1));
}
@ -281,7 +278,7 @@ public class GroupsControllerTests : IClassFixture<ScimApplicationFactory>, IAsy
Assert.Equal(StatusCodes.Status409Conflict, context.Response.StatusCode);
var databaseContext = _factory.GetDatabaseContext();
Assert.Equal(_initialGroupCount, databaseContext.Groups.Count());
Assert.Equal(ScimApplicationFactory.InitialGroupCount, databaseContext.Groups.Count());
Assert.False(databaseContext.Groups.Any(g => g.Name == "New Group"));
}
@ -354,216 +351,6 @@ public class GroupsControllerTests : IClassFixture<ScimApplicationFactory>, IAsy
AssertHelper.AssertPropertyEqual(expectedResponse, responseModel);
}
[Fact]
public async Task Patch_ReplaceDisplayName_Success()
{
var organizationId = ScimApplicationFactory.TestOrganizationId1;
var groupId = ScimApplicationFactory.TestGroupId1;
var newDisplayName = "Patch Display Name";
var inputModel = new ScimPatchModel
{
Operations = new List<ScimPatchModel.OperationModel>()
{
new ScimPatchModel.OperationModel
{
Op = "replace",
Value = JsonDocument.Parse($"{{\"displayName\":\"{newDisplayName}\"}}").RootElement
}
},
Schemas = new List<string>() { ScimConstants.Scim2SchemaGroup }
};
var context = await _factory.GroupsPatchAsync(organizationId, groupId, inputModel);
Assert.Equal(StatusCodes.Status204NoContent, context.Response.StatusCode);
var databaseContext = _factory.GetDatabaseContext();
var group = databaseContext.Groups.FirstOrDefault(g => g.Id == groupId);
Assert.Equal(newDisplayName, group.Name);
Assert.Equal(_initialGroupUsersCount, databaseContext.GroupUsers.Count());
Assert.True(databaseContext.GroupUsers.Any(gu => gu.OrganizationUserId == ScimApplicationFactory.TestOrganizationUserId1));
Assert.True(databaseContext.GroupUsers.Any(gu => gu.OrganizationUserId == ScimApplicationFactory.TestOrganizationUserId4));
}
[Fact]
public async Task Patch_ReplaceMembers_Success()
{
var organizationId = ScimApplicationFactory.TestOrganizationId1;
var groupId = ScimApplicationFactory.TestGroupId1;
var inputModel = new ScimPatchModel
{
Operations = new List<ScimPatchModel.OperationModel>()
{
new ScimPatchModel.OperationModel
{
Op = "replace",
Path = "members",
Value = JsonDocument.Parse($"[{{\"value\":\"{ScimApplicationFactory.TestOrganizationUserId2}\"}}]").RootElement
}
},
Schemas = new List<string>() { ScimConstants.Scim2SchemaGroup }
};
var context = await _factory.GroupsPatchAsync(organizationId, groupId, inputModel);
Assert.Equal(StatusCodes.Status204NoContent, context.Response.StatusCode);
var databaseContext = _factory.GetDatabaseContext();
Assert.Single(databaseContext.GroupUsers);
Assert.Equal(_initialGroupUsersCount - 1, databaseContext.GroupUsers.Count());
var groupUser = databaseContext.GroupUsers.FirstOrDefault();
Assert.Equal(ScimApplicationFactory.TestOrganizationUserId2, groupUser.OrganizationUserId);
}
[Fact]
public async Task Patch_AddSingleMember_Success()
{
var organizationId = ScimApplicationFactory.TestOrganizationId1;
var groupId = ScimApplicationFactory.TestGroupId1;
var inputModel = new ScimPatchModel
{
Operations = new List<ScimPatchModel.OperationModel>()
{
new ScimPatchModel.OperationModel
{
Op = "add",
Path = $"members[value eq \"{ScimApplicationFactory.TestOrganizationUserId2}\"]",
Value = JsonDocument.Parse("{}").RootElement
}
},
Schemas = new List<string>() { ScimConstants.Scim2SchemaGroup }
};
var context = await _factory.GroupsPatchAsync(organizationId, groupId, inputModel);
Assert.Equal(StatusCodes.Status204NoContent, context.Response.StatusCode);
var databaseContext = _factory.GetDatabaseContext();
Assert.Equal(_initialGroupUsersCount + 1, databaseContext.GroupUsers.Count());
Assert.True(databaseContext.GroupUsers.Any(gu => gu.GroupId == groupId && gu.OrganizationUserId == ScimApplicationFactory.TestOrganizationUserId1));
Assert.True(databaseContext.GroupUsers.Any(gu => gu.GroupId == groupId && gu.OrganizationUserId == ScimApplicationFactory.TestOrganizationUserId2));
Assert.True(databaseContext.GroupUsers.Any(gu => gu.GroupId == groupId && gu.OrganizationUserId == ScimApplicationFactory.TestOrganizationUserId4));
}
[Fact]
public async Task Patch_AddListMembers_Success()
{
var organizationId = ScimApplicationFactory.TestOrganizationId1;
var groupId = ScimApplicationFactory.TestGroupId2;
var inputModel = new ScimPatchModel
{
Operations = new List<ScimPatchModel.OperationModel>()
{
new ScimPatchModel.OperationModel
{
Op = "add",
Path = "members",
Value = JsonDocument.Parse($"[{{\"value\":\"{ScimApplicationFactory.TestOrganizationUserId2}\"}},{{\"value\":\"{ScimApplicationFactory.TestOrganizationUserId3}\"}}]").RootElement
}
},
Schemas = new List<string>() { ScimConstants.Scim2SchemaGroup }
};
var context = await _factory.GroupsPatchAsync(organizationId, groupId, inputModel);
Assert.Equal(StatusCodes.Status204NoContent, context.Response.StatusCode);
var databaseContext = _factory.GetDatabaseContext();
Assert.True(databaseContext.GroupUsers.Any(gu => gu.GroupId == groupId && gu.OrganizationUserId == ScimApplicationFactory.TestOrganizationUserId2));
Assert.True(databaseContext.GroupUsers.Any(gu => gu.GroupId == groupId && gu.OrganizationUserId == ScimApplicationFactory.TestOrganizationUserId3));
}
[Fact]
public async Task Patch_RemoveSingleMember_ReplaceDisplayName_Success()
{
var organizationId = ScimApplicationFactory.TestOrganizationId1;
var groupId = ScimApplicationFactory.TestGroupId1;
var newDisplayName = "Patch Display Name";
var inputModel = new ScimPatchModel
{
Operations = new List<ScimPatchModel.OperationModel>()
{
new ScimPatchModel.OperationModel
{
Op = "remove",
Path = $"members[value eq \"{ScimApplicationFactory.TestOrganizationUserId1}\"]",
Value = JsonDocument.Parse("{}").RootElement
},
new ScimPatchModel.OperationModel
{
Op = "replace",
Value = JsonDocument.Parse($"{{\"displayName\":\"{newDisplayName}\"}}").RootElement
}
},
Schemas = new List<string>() { ScimConstants.Scim2SchemaGroup }
};
var context = await _factory.GroupsPatchAsync(organizationId, groupId, inputModel);
Assert.Equal(StatusCodes.Status204NoContent, context.Response.StatusCode);
var databaseContext = _factory.GetDatabaseContext();
Assert.Equal(_initialGroupUsersCount - 1, databaseContext.GroupUsers.Count());
Assert.Equal(_initialGroupCount, databaseContext.Groups.Count());
var group = databaseContext.Groups.FirstOrDefault(g => g.Id == groupId);
Assert.Equal(newDisplayName, group.Name);
}
[Fact]
public async Task Patch_RemoveListMembers_Success()
{
var organizationId = ScimApplicationFactory.TestOrganizationId1;
var groupId = ScimApplicationFactory.TestGroupId1;
var inputModel = new ScimPatchModel
{
Operations = new List<ScimPatchModel.OperationModel>()
{
new ScimPatchModel.OperationModel
{
Op = "remove",
Path = "members",
Value = JsonDocument.Parse($"[{{\"value\":\"{ScimApplicationFactory.TestOrganizationUserId1}\"}}, {{\"value\":\"{ScimApplicationFactory.TestOrganizationUserId4}\"}}]").RootElement
}
},
Schemas = new List<string>() { ScimConstants.Scim2SchemaGroup }
};
var context = await _factory.GroupsPatchAsync(organizationId, groupId, inputModel);
Assert.Equal(StatusCodes.Status204NoContent, context.Response.StatusCode);
var databaseContext = _factory.GetDatabaseContext();
Assert.Empty(databaseContext.GroupUsers);
}
[Fact]
public async Task Patch_NotFound()
{
var organizationId = ScimApplicationFactory.TestOrganizationId1;
var groupId = Guid.NewGuid();
var inputModel = new Models.ScimPatchModel
{
Operations = new List<ScimPatchModel.OperationModel>(),
Schemas = new List<string>() { ScimConstants.Scim2SchemaGroup }
};
var expectedResponse = new ScimErrorResponseModel
{
Status = StatusCodes.Status404NotFound,
Detail = "Group not found.",
Schemas = new List<string> { ScimConstants.Scim2SchemaError }
};
var context = await _factory.GroupsPatchAsync(organizationId, groupId, inputModel);
Assert.Equal(StatusCodes.Status404NotFound, context.Response.StatusCode);
var responseModel = JsonSerializer.Deserialize<ScimErrorResponseModel>(context.Response.Body, new JsonSerializerOptions { PropertyNamingPolicy = JsonNamingPolicy.CamelCase });
AssertHelper.AssertPropertyEqual(expectedResponse, responseModel);
}
[Fact]
public async Task Delete_Success()
{
@ -575,7 +362,7 @@ public class GroupsControllerTests : IClassFixture<ScimApplicationFactory>, IAsy
Assert.Equal(StatusCodes.Status204NoContent, context.Response.StatusCode);
var databaseContext = _factory.GetDatabaseContext();
Assert.Equal(_initialGroupCount - 1, databaseContext.Groups.Count());
Assert.Equal(ScimApplicationFactory.InitialGroupCount - 1, databaseContext.Groups.Count());
Assert.True(databaseContext.Groups.FirstOrDefault(g => g.Id == groupId) == null);
}

View File

@ -9,8 +9,6 @@ using Bit.Infrastructure.EntityFramework.Repositories;
using Bit.IntegrationTestCommon.Factories;
using Bit.Scim.Models;
using Microsoft.AspNetCore.Authentication;
using Microsoft.AspNetCore.Mvc.Testing;
using Microsoft.AspNetCore.TestHost;
using Microsoft.Extensions.Options;
using Microsoft.Net.Http.Headers;
@ -18,7 +16,8 @@ namespace Bit.Scim.IntegrationTest.Factories;
public class ScimApplicationFactory : WebApplicationFactoryBase<Startup>
{
public readonly new TestServer Server;
public const int InitialGroupCount = 3;
public const int InitialGroupUsersCount = 2;
public static readonly Guid TestUserId1 = Guid.Parse("2e8173db-8e8d-4de1-ac38-91b15c6d8dcb");
public static readonly Guid TestUserId2 = Guid.Parse("b57846fc-0e94-4c93-9de5-9d0389eeadfb");
@ -33,32 +32,29 @@ public class ScimApplicationFactory : WebApplicationFactoryBase<Startup>
public static readonly Guid TestOrganizationUserId3 = Guid.Parse("be2f9045-e2b6-4173-ad44-4c69c3ea8140");
public static readonly Guid TestOrganizationUserId4 = Guid.Parse("1f5689b7-e96e-4840-b0b1-eb3d5b5fd514");
public ScimApplicationFactory()
protected override void ConfigureWebHost(IWebHostBuilder builder)
{
WebApplicationFactory<Startup> webApplicationFactory = WithWebHostBuilder(builder =>
base.ConfigureWebHost(builder);
builder.ConfigureServices(services =>
{
builder.ConfigureServices(services =>
services
.AddAuthentication("Test")
.AddScheme<AuthenticationSchemeOptions, TestAuthHandler>("Test", options => { });
// Override to bypass SCIM authorization
services.AddAuthorization(config =>
{
services
.AddAuthentication("Test")
.AddScheme<AuthenticationSchemeOptions, TestAuthHandler>("Test", options => { });
// Override to bypass SCIM authorization
services.AddAuthorization(config =>
config.AddPolicy("Scim", policy =>
{
config.AddPolicy("Scim", policy =>
{
policy.RequireAssertion(a => true);
});
policy.RequireAssertion(a => true);
});
var mailService = services.First(sd => sd.ServiceType == typeof(IMailService));
services.Remove(mailService);
services.AddSingleton<IMailService, NoopMailService>();
});
});
Server = webApplicationFactory.Server;
var mailService = services.First(sd => sd.ServiceType == typeof(IMailService));
services.Remove(mailService);
services.AddSingleton<IMailService, NoopMailService>();
});
}
public async Task<HttpContext> GroupsGetAsync(Guid organizationId, Guid id)

View File

@ -0,0 +1,381 @@
using System.Text.Json;
using AutoFixture;
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.OrganizationFeatures.Groups.Interfaces;
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.AdminConsole.Services;
using Bit.Core.Enums;
using Bit.Core.Repositories;
using Bit.Scim.Groups;
using Bit.Scim.Models;
using Bit.Scim.Utilities;
using Bit.Test.Common.AutoFixture;
using Bit.Test.Common.AutoFixture.Attributes;
using NSubstitute;
using Xunit;
namespace Bit.Scim.Test.Groups;
[SutProviderCustomize]
public class PatchGroupCommandvNextTests
{
[Theory]
[BitAutoData]
public async Task PatchGroup_ReplaceListMembers_Success(SutProvider<PatchGroupCommandvNext> sutProvider,
Organization organization, Group group, IEnumerable<Guid> userIds)
{
group.OrganizationId = organization.Id;
var scimPatchModel = new ScimPatchModel
{
Operations = new List<ScimPatchModel.OperationModel>
{
new()
{
Op = "replace",
Path = "members",
Value = JsonDocument.Parse(JsonSerializer.Serialize(userIds.Select(uid => new { value = uid }).ToArray())).RootElement
}
},
Schemas = new List<string> { ScimConstants.Scim2SchemaUser }
};
await sutProvider.Sut.PatchGroupAsync(group, scimPatchModel);
await sutProvider.GetDependency<IGroupRepository>().Received(1).UpdateUsersAsync(
group.Id,
Arg.Is<IEnumerable<Guid>>(arg =>
arg.Count() == userIds.Count() &&
arg.ToHashSet().SetEquals(userIds)));
}
[Theory]
[BitAutoData]
public async Task PatchGroup_ReplaceDisplayNameFromPath_Success(
SutProvider<PatchGroupCommandvNext> sutProvider, Organization organization, Group group, string displayName)
{
group.OrganizationId = organization.Id;
sutProvider.GetDependency<IOrganizationRepository>()
.GetByIdAsync(organization.Id)
.Returns(organization);
var scimPatchModel = new ScimPatchModel
{
Operations = new List<ScimPatchModel.OperationModel>
{
new()
{
Op = "replace",
Path = "displayname",
Value = JsonDocument.Parse($"\"{displayName}\"").RootElement
}
},
Schemas = new List<string> { ScimConstants.Scim2SchemaUser }
};
await sutProvider.Sut.PatchGroupAsync(group, scimPatchModel);
await sutProvider.GetDependency<IUpdateGroupCommand>().Received(1).UpdateGroupAsync(group, organization, EventSystemUser.SCIM);
Assert.Equal(displayName, group.Name);
}
[Theory]
[BitAutoData]
public async Task PatchGroup_ReplaceDisplayNameFromValueObject_Success(SutProvider<PatchGroupCommandvNext> sutProvider, Organization organization, Group group, string displayName)
{
group.OrganizationId = organization.Id;
sutProvider.GetDependency<IOrganizationRepository>()
.GetByIdAsync(organization.Id)
.Returns(organization);
var scimPatchModel = new ScimPatchModel
{
Operations = new List<ScimPatchModel.OperationModel>
{
new()
{
Op = "replace",
Value = JsonDocument.Parse($"{{\"displayName\":\"{displayName}\"}}").RootElement
}
},
Schemas = new List<string> { ScimConstants.Scim2SchemaUser }
};
await sutProvider.Sut.PatchGroupAsync(group, scimPatchModel);
await sutProvider.GetDependency<IUpdateGroupCommand>().Received(1).UpdateGroupAsync(group, organization, EventSystemUser.SCIM);
Assert.Equal(displayName, group.Name);
}
[Theory]
[BitAutoData]
public async Task PatchGroup_AddSingleMember_Success(SutProvider<PatchGroupCommandvNext> sutProvider, Organization organization, Group group, ICollection<Guid> existingMembers, Guid userId)
{
group.OrganizationId = organization.Id;
sutProvider.GetDependency<IGroupRepository>()
.GetManyUserIdsByIdAsync(group.Id, true)
.Returns(existingMembers);
var scimPatchModel = new ScimPatchModel
{
Operations = new List<ScimPatchModel.OperationModel>
{
new()
{
Op = "add",
Path = $"members[value eq \"{userId}\"]",
}
},
Schemas = new List<string> { ScimConstants.Scim2SchemaUser }
};
await sutProvider.Sut.PatchGroupAsync(group, scimPatchModel);
await sutProvider.GetDependency<IGroupRepository>().Received(1).AddGroupUsersByIdAsync(
group.Id,
Arg.Is<IEnumerable<Guid>>(arg => arg.Single() == userId));
}
[Theory]
[BitAutoData]
public async Task PatchGroup_AddSingleMember_ReturnsEarlyIfAlreadyInGroup(
SutProvider<PatchGroupCommandvNext> sutProvider,
Organization organization,
Group group,
ICollection<Guid> existingMembers)
{
// User being added is already in group
var userId = existingMembers.First();
group.OrganizationId = organization.Id;
sutProvider.GetDependency<IGroupRepository>()
.GetManyUserIdsByIdAsync(group.Id, true)
.Returns(existingMembers);
var scimPatchModel = new ScimPatchModel
{
Operations = new List<ScimPatchModel.OperationModel>
{
new()
{
Op = "add",
Path = $"members[value eq \"{userId}\"]",
}
},
Schemas = new List<string> { ScimConstants.Scim2SchemaUser }
};
await sutProvider.Sut.PatchGroupAsync(group, scimPatchModel);
await sutProvider.GetDependency<IGroupRepository>()
.DidNotReceiveWithAnyArgs()
.AddGroupUsersByIdAsync(default, default);
}
[Theory]
[BitAutoData]
public async Task PatchGroup_AddListMembers_Success(SutProvider<PatchGroupCommandvNext> sutProvider, Organization organization, Group group, ICollection<Guid> existingMembers, ICollection<Guid> userIds)
{
group.OrganizationId = organization.Id;
sutProvider.GetDependency<IGroupRepository>()
.GetManyUserIdsByIdAsync(group.Id, true)
.Returns(existingMembers);
var scimPatchModel = new ScimPatchModel
{
Operations = new List<ScimPatchModel.OperationModel>
{
new()
{
Op = "add",
Path = $"members",
Value = JsonDocument.Parse(JsonSerializer.Serialize(userIds.Select(uid => new { value = uid }).ToArray())).RootElement
}
},
Schemas = new List<string> { ScimConstants.Scim2SchemaUser }
};
await sutProvider.Sut.PatchGroupAsync(group, scimPatchModel);
await sutProvider.GetDependency<IGroupRepository>().Received(1).AddGroupUsersByIdAsync(
group.Id,
Arg.Is<IEnumerable<Guid>>(arg =>
arg.Count() == userIds.Count &&
arg.ToHashSet().SetEquals(userIds)));
}
[Theory]
[BitAutoData]
public async Task PatchGroup_AddListMembers_IgnoresDuplicatesInRequest(
SutProvider<PatchGroupCommandvNext> sutProvider, Organization organization, Group group,
ICollection<Guid> existingMembers)
{
// Create 3 userIds
var fixture = new Fixture { RepeatCount = 3 };
var userIds = fixture.CreateMany<Guid>().ToList();
// Copy the list and add a duplicate
var userIdsWithDuplicate = userIds.Append(userIds.First()).ToList();
Assert.Equal(4, userIdsWithDuplicate.Count);
group.OrganizationId = organization.Id;
sutProvider.GetDependency<IGroupRepository>()
.GetManyUserIdsByIdAsync(group.Id, true)
.Returns(existingMembers);
var scimPatchModel = new ScimPatchModel
{
Operations = new List<ScimPatchModel.OperationModel>
{
new()
{
Op = "add",
Path = $"members",
Value = JsonDocument.Parse(JsonSerializer
.Serialize(userIdsWithDuplicate
.Select(uid => new { value = uid })
.ToArray())).RootElement
}
},
Schemas = new List<string> { ScimConstants.Scim2SchemaUser }
};
await sutProvider.Sut.PatchGroupAsync(group, scimPatchModel);
await sutProvider.GetDependency<IGroupRepository>().Received(1).AddGroupUsersByIdAsync(
group.Id,
Arg.Is<IEnumerable<Guid>>(arg =>
arg.Count() == 3 &&
arg.ToHashSet().SetEquals(userIds)));
}
[Theory]
[BitAutoData]
public async Task PatchGroup_AddListMembers_SuccessIfOnlySomeUsersAreInGroup(
SutProvider<PatchGroupCommandvNext> sutProvider,
Organization organization, Group group,
ICollection<Guid> existingMembers,
ICollection<Guid> userIds)
{
// A user is already in the group, but some still need to be added
userIds.Add(existingMembers.First());
group.OrganizationId = organization.Id;
sutProvider.GetDependency<IGroupRepository>()
.GetManyUserIdsByIdAsync(group.Id, true)
.Returns(existingMembers);
var scimPatchModel = new ScimPatchModel
{
Operations = new List<ScimPatchModel.OperationModel>
{
new()
{
Op = "add",
Path = $"members",
Value = JsonDocument.Parse(JsonSerializer.Serialize(userIds.Select(uid => new { value = uid }).ToArray())).RootElement
}
},
Schemas = new List<string> { ScimConstants.Scim2SchemaUser }
};
await sutProvider.Sut.PatchGroupAsync(group, scimPatchModel);
await sutProvider.GetDependency<IGroupRepository>()
.Received(1)
.AddGroupUsersByIdAsync(
group.Id,
Arg.Is<IEnumerable<Guid>>(arg =>
arg.Count() == userIds.Count &&
arg.ToHashSet().SetEquals(userIds)));
}
[Theory]
[BitAutoData]
public async Task PatchGroup_RemoveSingleMember_Success(SutProvider<PatchGroupCommandvNext> sutProvider, Organization organization, Group group, Guid userId)
{
group.OrganizationId = organization.Id;
var scimPatchModel = new Models.ScimPatchModel
{
Operations = new List<ScimPatchModel.OperationModel>
{
new ScimPatchModel.OperationModel
{
Op = "remove",
Path = $"members[value eq \"{userId}\"]",
}
},
Schemas = new List<string> { ScimConstants.Scim2SchemaUser }
};
await sutProvider.Sut.PatchGroupAsync(group, scimPatchModel);
await sutProvider.GetDependency<IGroupService>().Received(1).DeleteUserAsync(group, userId, EventSystemUser.SCIM);
}
[Theory]
[BitAutoData]
public async Task PatchGroup_RemoveListMembers_Success(SutProvider<PatchGroupCommandvNext> sutProvider,
Organization organization, Group group, ICollection<Guid> existingMembers)
{
List<Guid> usersToRemove = [existingMembers.First(), existingMembers.Skip(1).First()];
group.OrganizationId = organization.Id;
sutProvider.GetDependency<IGroupRepository>()
.GetManyUserIdsByIdAsync(group.Id)
.Returns(existingMembers);
var scimPatchModel = new Models.ScimPatchModel
{
Operations = new List<ScimPatchModel.OperationModel>
{
new()
{
Op = "remove",
Path = $"members",
Value = JsonDocument.Parse(JsonSerializer.Serialize(usersToRemove.Select(uid => new { value = uid }).ToArray())).RootElement
}
},
Schemas = new List<string> { ScimConstants.Scim2SchemaUser }
};
await sutProvider.Sut.PatchGroupAsync(group, scimPatchModel);
var expectedRemainingUsers = existingMembers.Skip(2).ToList();
await sutProvider.GetDependency<IGroupRepository>()
.Received(1)
.UpdateUsersAsync(
group.Id,
Arg.Is<IEnumerable<Guid>>(arg =>
arg.Count() == expectedRemainingUsers.Count &&
arg.ToHashSet().SetEquals(expectedRemainingUsers)));
}
[Theory]
[BitAutoData]
public async Task PatchGroup_NoAction_Success(
SutProvider<PatchGroupCommandvNext> sutProvider, Organization organization, Group group)
{
group.OrganizationId = organization.Id;
var scimPatchModel = new Models.ScimPatchModel
{
Operations = new List<ScimPatchModel.OperationModel>(),
Schemas = new List<string> { ScimConstants.Scim2SchemaUser }
};
await sutProvider.Sut.PatchGroupAsync(group, scimPatchModel);
await sutProvider.GetDependency<IGroupRepository>().DidNotReceiveWithAnyArgs().UpdateUsersAsync(default, default);
await sutProvider.GetDependency<IGroupRepository>().DidNotReceiveWithAnyArgs().GetManyUserIdsByIdAsync(default);
await sutProvider.GetDependency<IUpdateGroupCommand>().DidNotReceiveWithAnyArgs().UpdateGroupAsync(default, default);
await sutProvider.GetDependency<IGroupService>().DidNotReceiveWithAnyArgs().DeleteUserAsync(default, default);
}
}

View File

@ -1,10 +1,8 @@
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Enums;
using Bit.Core.AdminConsole.OrganizationFeatures.Groups.Interfaces;
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Enums;
using Bit.Core.Exceptions;
using Bit.Scim.Context;
using Bit.Scim.Groups;
using Bit.Scim.Models;
using Bit.Scim.Utilities;
@ -73,10 +71,6 @@ public class PostGroupCommandTests
.GetManyByOrganizationIdAsync(organization.Id)
.Returns(groups);
sutProvider.GetDependency<IScimContext>()
.RequestScimProvider
.Returns(ScimProviderType.Okta);
var group = await sutProvider.Sut.PostGroupAsync(organization, scimGroupRequestModel);
await sutProvider.GetDependency<ICreateGroupCommand>().Received(1).CreateGroupAsync(group, organization, EventSystemUser.SCIM, null);

View File

@ -1,10 +1,8 @@
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Enums;
using Bit.Core.AdminConsole.OrganizationFeatures.Groups.Interfaces;
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Enums;
using Bit.Core.Exceptions;
using Bit.Scim.Context;
using Bit.Scim.Groups;
using Bit.Scim.Models;
using Bit.Scim.Utilities;
@ -62,10 +60,6 @@ public class PutGroupCommandTests
.GetByIdAsync(group.Id)
.Returns(group);
sutProvider.GetDependency<IScimContext>()
.RequestScimProvider
.Returns(ScimProviderType.Okta);
var inputModel = new ScimGroupRequestModel
{
DisplayName = displayName,

View File

@ -16,7 +16,6 @@
</ItemGroup>
<ItemGroup>
<Folder Include="Billing\Controllers\" />
<Folder Include="Billing\Models\" />
</ItemGroup>
<Choose>

View File

@ -235,7 +235,8 @@ public class ProvidersController : Controller
var users = await _providerUserRepository.GetManyDetailsByProviderAsync(id);
var providerOrganizations = await _providerOrganizationRepository.GetManyDetailsByProviderAsync(id);
return View(new ProviderViewModel(provider, users, providerOrganizations));
var providerPlans = await _providerPlanRepository.GetByProviderId(id);
return View(new ProviderViewModel(provider, users, providerOrganizations, providerPlans.ToList()));
}
[SelfHosted(NotSelfHostedOnly = true)]

View File

@ -19,7 +19,7 @@ public class ProviderEditModel : ProviderViewModel, IValidatableObject
IEnumerable<ProviderOrganizationOrganizationDetails> organizations,
IReadOnlyCollection<ProviderPlan> providerPlans,
string gatewayCustomerUrl = null,
string gatewaySubscriptionUrl = null) : base(provider, providerUsers, organizations)
string gatewaySubscriptionUrl = null) : base(provider, providerUsers, organizations, providerPlans)
{
Name = provider.DisplayName();
BusinessName = provider.DisplayBusinessName();

View File

@ -1,6 +1,9 @@
using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Admin.Billing.Models;
using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.AdminConsole.Enums.Provider;
using Bit.Core.AdminConsole.Models.Data.Provider;
using Bit.Core.Billing.Entities;
using Bit.Core.Billing.Enums;
namespace Bit.Admin.AdminConsole.Models;
@ -8,17 +11,57 @@ public class ProviderViewModel
{
public ProviderViewModel() { }
public ProviderViewModel(Provider provider, IEnumerable<ProviderUserUserDetails> providerUsers, IEnumerable<ProviderOrganizationOrganizationDetails> organizations)
public ProviderViewModel(
Provider provider,
IEnumerable<ProviderUserUserDetails> providerUsers,
IEnumerable<ProviderOrganizationOrganizationDetails> organizations,
IReadOnlyCollection<ProviderPlan> providerPlans)
{
Provider = provider;
UserCount = providerUsers.Count();
ProviderAdmins = providerUsers.Where(u => u.Type == ProviderUserType.ProviderAdmin);
ProviderOrganizations = organizations.Where(o => o.ProviderId == provider.Id);
if (Provider.Type == ProviderType.Msp)
{
var usedTeamsSeats = ProviderOrganizations.Where(po => po.PlanType == PlanType.TeamsMonthly)
.Sum(po => po.OccupiedSeats) ?? 0;
var teamsProviderPlan = providerPlans.FirstOrDefault(plan => plan.PlanType == PlanType.TeamsMonthly);
if (teamsProviderPlan != null && teamsProviderPlan.IsConfigured())
{
ProviderPlanViewModels.Add(new ProviderPlanViewModel("Teams (Monthly) Subscription", teamsProviderPlan, usedTeamsSeats));
}
var usedEnterpriseSeats = ProviderOrganizations.Where(po => po.PlanType == PlanType.EnterpriseMonthly)
.Sum(po => po.OccupiedSeats) ?? 0;
var enterpriseProviderPlan = providerPlans.FirstOrDefault(plan => plan.PlanType == PlanType.EnterpriseMonthly);
if (enterpriseProviderPlan != null && enterpriseProviderPlan.IsConfigured())
{
ProviderPlanViewModels.Add(new ProviderPlanViewModel("Enterprise (Monthly) Subscription", enterpriseProviderPlan, usedEnterpriseSeats));
}
}
else if (Provider.Type == ProviderType.MultiOrganizationEnterprise)
{
var usedEnterpriseSeats = ProviderOrganizations.Where(po => po.PlanType == PlanType.EnterpriseMonthly)
.Sum(po => po.OccupiedSeats).GetValueOrDefault(0);
var enterpriseProviderPlan = providerPlans.FirstOrDefault();
if (enterpriseProviderPlan != null && enterpriseProviderPlan.IsConfigured())
{
var planLabel = enterpriseProviderPlan.PlanType switch
{
PlanType.EnterpriseMonthly => "Enterprise (Monthly) Subscription",
PlanType.EnterpriseAnnually => "Enterprise (Annually) Subscription",
_ => string.Empty
};
ProviderPlanViewModels.Add(new ProviderPlanViewModel(planLabel, enterpriseProviderPlan, usedEnterpriseSeats));
}
}
}
public int UserCount { get; set; }
public Provider Provider { get; set; }
public IEnumerable<ProviderUserUserDetails> ProviderAdmins { get; set; }
public IEnumerable<ProviderOrganizationOrganizationDetails> ProviderOrganizations { get; set; }
public List<ProviderPlanViewModel> ProviderPlanViewModels { get; set; } = [];
}

View File

@ -17,6 +17,10 @@
<h2>Provider Information</h2>
@await Html.PartialAsync("_ViewInformation", Model)
@if (Model.ProviderPlanViewModels.Any())
{
@await Html.PartialAsync("~/Billing/Views/Providers/ProviderPlans.cshtml", Model.ProviderPlanViewModels)
}
@await Html.PartialAsync("Admins", Model)
<form method="post" id="edit-form">
<div asp-validation-summary="All" class="alert alert-danger"></div>

View File

@ -7,5 +7,9 @@
<h2>Information</h2>
@await Html.PartialAsync("_ViewInformation", Model)
@if (Model.ProviderPlanViewModels.Any())
{
@await Html.PartialAsync("ProviderPlans", Model.ProviderPlanViewModels)
}
@await Html.PartialAsync("Admins", Model)
@await Html.PartialAsync("Organizations", Model)

View File

@ -0,0 +1,26 @@
using Bit.Core.Billing.Entities;
namespace Bit.Admin.Billing.Models;
public class ProviderPlanViewModel
{
public string Name { get; set; }
public int PurchasedSeats { get; set; }
public int AssignedSeats { get; set; }
public int UsedSeats { get; set; }
public int RemainingSeats { get; set; }
public ProviderPlanViewModel(
string name,
ProviderPlan providerPlan,
int usedSeats)
{
var purchasedSeats = (providerPlan.SeatMinimum ?? 0) + (providerPlan.PurchasedSeats ?? 0);
Name = name;
PurchasedSeats = purchasedSeats;
AssignedSeats = providerPlan.AllocatedSeats ?? 0;
UsedSeats = usedSeats;
RemainingSeats = purchasedSeats - AssignedSeats;
}
}

View File

@ -0,0 +1,18 @@
@model List<Bit.Admin.Billing.Models.ProviderPlanViewModel>
@foreach (var plan in Model)
{
<h2>@plan.Name</h2>
<dl class="row">
<dt class="col-sm-4 col-lg-3">Purchased Seats</dt>
<dd class="col-sm-8 col-lg-9">@plan.PurchasedSeats</dd>
<dt class="col-sm-4 col-lg-3">Assigned Seats</dt>
<dd class="col-sm-8 col-lg-9">@plan.AssignedSeats</dd>
<dt class="col-sm-4 col-lg-3">Used Seats</dt>
<dd class="col-sm-8 col-lg-9">@plan.UsedSeats</dd>
<dt class="col-sm-4 col-lg-3">Remaining Seats</dt>
<dd class="col-sm-8 col-lg-9">@plan.RemainingSeats</dd>
</dl>
}

View File

@ -304,7 +304,7 @@ public class TwoFactorController : Controller
if (user != null)
{
// check if 2FA email is from passwordless
// Check if 2FA email is from Passwordless.
if (!string.IsNullOrEmpty(requestModel.AuthRequestAccessCode))
{
if (await _verifyAuthRequestCommand
@ -317,17 +317,14 @@ public class TwoFactorController : Controller
}
else if (!string.IsNullOrEmpty(requestModel.SsoEmail2FaSessionToken))
{
if (this.ValidateSsoEmail2FaToken(requestModel.SsoEmail2FaSessionToken, user))
if (ValidateSsoEmail2FaToken(requestModel.SsoEmail2FaSessionToken, user))
{
await _userService.SendTwoFactorEmailAsync(user);
return;
}
else
{
await this.ThrowDelayedBadRequestExceptionAsync(
"Cannot send two-factor email: a valid, non-expired SSO Email 2FA Session token is required to send 2FA emails.",
2000);
}
await ThrowDelayedBadRequestExceptionAsync(
"Cannot send two-factor email: a valid, non-expired SSO Email 2FA Session token is required to send 2FA emails.");
}
else if (await _userService.VerifySecretAsync(user, requestModel.Secret))
{
@ -336,8 +333,7 @@ public class TwoFactorController : Controller
}
}
await this.ThrowDelayedBadRequestExceptionAsync(
"Cannot send two-factor email.", 2000);
await ThrowDelayedBadRequestExceptionAsync("Cannot send two-factor email.");
}
[HttpPut("email")]
@ -374,7 +370,7 @@ public class TwoFactorController : Controller
public async Task<TwoFactorProviderResponseModel> PutOrganizationDisable(string id,
[FromBody] TwoFactorProviderRequestModel model)
{
var user = await CheckAsync(model, false);
await CheckAsync(model, false);
var orgIdGuid = new Guid(id);
if (!await _currentContext.ManagePolicies(orgIdGuid))
@ -401,6 +397,10 @@ public class TwoFactorController : Controller
return response;
}
/// <summary>
/// To be removed when the feature flag pm-17128-recovery-code-login is removed PM-18175.
/// </summary>
[Obsolete("Two Factor recovery is handled in the TwoFactorAuthenticationValidator.")]
[HttpPost("recover")]
[AllowAnonymous]
public async Task PostRecover([FromBody] TwoFactorRecoveryRequestModel model)
@ -463,10 +463,8 @@ public class TwoFactorController : Controller
await Task.Delay(2000);
throw new BadRequestException(name, $"{name} is invalid.");
}
else
{
await Task.Delay(500);
}
await Task.Delay(500);
}
private bool ValidateSsoEmail2FaToken(string ssoEmail2FaSessionToken, User user)

View File

@ -96,12 +96,6 @@ public class ImportCiphersController : Controller
return true;
}
//Users allowed to import if they CanCreate Collections
if (!(await _authorizationService.AuthorizeAsync(User, collections, BulkCollectionOperations.Create)).Succeeded)
{
return false;
}
//Calling Repository instead of Service as we want to get all the collections, regardless of permission
//Permissions check will be done later on AuthorizationService
var orgCollectionIds =
@ -118,6 +112,12 @@ public class ImportCiphersController : Controller
return false;
};
//Users allowed to import if they CanCreate Collections
if (!(await _authorizationService.AuthorizeAsync(User, collections, BulkCollectionOperations.Create)).Succeeded)
{
return false;
}
return true;
}
}

View File

@ -1,4 +1,5 @@
using Bit.Billing.Constants;
using Bit.Core.AdminConsole.OrganizationFeatures.Organizations.Interfaces;
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Billing.Enums;
using Bit.Core.Context;
@ -17,7 +18,6 @@ public class PaymentSucceededHandler : IPaymentSucceededHandler
{
private readonly ILogger<PaymentSucceededHandler> _logger;
private readonly IStripeEventService _stripeEventService;
private readonly IOrganizationService _organizationService;
private readonly IUserService _userService;
private readonly IStripeFacade _stripeFacade;
private readonly IProviderRepository _providerRepository;
@ -27,6 +27,7 @@ public class PaymentSucceededHandler : IPaymentSucceededHandler
private readonly IUserRepository _userRepository;
private readonly IStripeEventUtilityService _stripeEventUtilityService;
private readonly IPushNotificationService _pushNotificationService;
private readonly IOrganizationEnableCommand _organizationEnableCommand;
public PaymentSucceededHandler(
ILogger<PaymentSucceededHandler> logger,
@ -39,8 +40,8 @@ public class PaymentSucceededHandler : IPaymentSucceededHandler
IUserRepository userRepository,
IStripeEventUtilityService stripeEventUtilityService,
IUserService userService,
IOrganizationService organizationService,
IPushNotificationService pushNotificationService)
IPushNotificationService pushNotificationService,
IOrganizationEnableCommand organizationEnableCommand)
{
_logger = logger;
_stripeEventService = stripeEventService;
@ -52,8 +53,8 @@ public class PaymentSucceededHandler : IPaymentSucceededHandler
_userRepository = userRepository;
_stripeEventUtilityService = stripeEventUtilityService;
_userService = userService;
_organizationService = organizationService;
_pushNotificationService = pushNotificationService;
_organizationEnableCommand = organizationEnableCommand;
}
/// <summary>
@ -142,7 +143,7 @@ public class PaymentSucceededHandler : IPaymentSucceededHandler
return;
}
await _organizationService.EnableAsync(organizationId.Value, subscription.CurrentPeriodEnd);
await _organizationEnableCommand.EnableAsync(organizationId.Value, subscription.CurrentPeriodEnd);
var organization = await _organizationRepository.GetByIdAsync(organizationId.Value);
await _pushNotificationService.PushSyncOrganizationStatusAsync(organization);

View File

@ -1,6 +1,7 @@
using Bit.Billing.Constants;
using Bit.Billing.Jobs;
using Bit.Core;
using Bit.Core.AdminConsole.OrganizationFeatures.Organizations.Interfaces;
using Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces;
using Bit.Core.Platform.Push;
using Bit.Core.Repositories;
@ -24,6 +25,7 @@ public class SubscriptionUpdatedHandler : ISubscriptionUpdatedHandler
private readonly IOrganizationRepository _organizationRepository;
private readonly ISchedulerFactory _schedulerFactory;
private readonly IFeatureService _featureService;
private readonly IOrganizationEnableCommand _organizationEnableCommand;
public SubscriptionUpdatedHandler(
IStripeEventService stripeEventService,
@ -35,7 +37,8 @@ public class SubscriptionUpdatedHandler : ISubscriptionUpdatedHandler
IPushNotificationService pushNotificationService,
IOrganizationRepository organizationRepository,
ISchedulerFactory schedulerFactory,
IFeatureService featureService)
IFeatureService featureService,
IOrganizationEnableCommand organizationEnableCommand)
{
_stripeEventService = stripeEventService;
_stripeEventUtilityService = stripeEventUtilityService;
@ -47,6 +50,7 @@ public class SubscriptionUpdatedHandler : ISubscriptionUpdatedHandler
_organizationRepository = organizationRepository;
_schedulerFactory = schedulerFactory;
_featureService = featureService;
_organizationEnableCommand = organizationEnableCommand;
}
/// <summary>
@ -90,7 +94,7 @@ public class SubscriptionUpdatedHandler : ISubscriptionUpdatedHandler
}
case StripeSubscriptionStatus.Active when organizationId.HasValue:
{
await _organizationService.EnableAsync(organizationId.Value);
await _organizationEnableCommand.EnableAsync(organizationId.Value);
var organization = await _organizationRepository.GetByIdAsync(organizationId.Value);
await _pushNotificationService.PushSyncOrganizationStatusAsync(organization);
break;

View File

@ -0,0 +1,39 @@
#nullable enable
using Bit.Core.AdminConsole.Enums;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements;
using Bit.Core.Enums;
using Bit.Core.Models.Data;
using Bit.Core.Utilities;
namespace Bit.Core.AdminConsole.Models.Data.Organizations.Policies;
/// <summary>
/// Represents an OrganizationUser and a Policy which *may* be enforced against them.
/// You may assume that the Policy is enabled and that the organization's plan supports policies.
/// This is consumed by <see cref="IPolicyRequirement"/> to create requirements for specific policy types.
/// </summary>
public class PolicyDetails
{
public Guid OrganizationUserId { get; set; }
public Guid OrganizationId { get; set; }
public PolicyType PolicyType { get; set; }
public string? PolicyData { get; set; }
public OrganizationUserType OrganizationUserType { get; set; }
public OrganizationUserStatusType OrganizationUserStatus { get; set; }
/// <summary>
/// Custom permissions for the organization user, if any. Use <see cref="GetOrganizationUserCustomPermissions"/>
/// to deserialize.
/// </summary>
public string? OrganizationUserPermissionsData { get; set; }
/// <summary>
/// True if the user is also a ProviderUser for the organization, false otherwise.
/// </summary>
public bool IsProvider { get; set; }
public T GetDataModel<T>() where T : IPolicyDataModel, new()
=> CoreHelpers.LoadClassFromJsonData<T>(PolicyData);
public Permissions GetOrganizationUserCustomPermissions()
=> CoreHelpers.LoadClassFromJsonData<Permissions>(OrganizationUserPermissionsData);
}

View File

@ -1,5 +1,6 @@
using System.Net;
using System.Text.Json.Serialization;
using Bit.Core.Billing.Enums;
using Bit.Core.Enums;
using Bit.Core.Utilities;
@ -23,6 +24,7 @@ public class ProviderOrganizationOrganizationDetails
public int? OccupiedSeats { get; set; }
public int? Seats { get; set; }
public string Plan { get; set; }
public PlanType PlanType { get; set; }
public OrganizationStatusType Status { get; set; }
/// <summary>

View File

@ -0,0 +1,11 @@
namespace Bit.Core.AdminConsole.OrganizationFeatures.Organizations.Interfaces;
public interface IOrganizationEnableCommand
{
/// <summary>
/// Enables an organization that is currently disabled and has a gateway configured.
/// </summary>
/// <param name="organizationId">The unique identifier of the organization to enable.</param>
/// <param name="expirationDate">When provided, sets the date the organization's subscription will expire. If not provided, no expiration date will be set.</param>
Task EnableAsync(Guid organizationId, DateTime? expirationDate = null);
}

View File

@ -0,0 +1,39 @@
using Bit.Core.AdminConsole.OrganizationFeatures.Organizations.Interfaces;
using Bit.Core.Repositories;
using Bit.Core.Services;
namespace Bit.Core.AdminConsole.OrganizationFeatures.Organizations;
public class OrganizationEnableCommand : IOrganizationEnableCommand
{
private readonly IApplicationCacheService _applicationCacheService;
private readonly IOrganizationRepository _organizationRepository;
public OrganizationEnableCommand(
IApplicationCacheService applicationCacheService,
IOrganizationRepository organizationRepository)
{
_applicationCacheService = applicationCacheService;
_organizationRepository = organizationRepository;
}
public async Task EnableAsync(Guid organizationId, DateTime? expirationDate = null)
{
var organization = await _organizationRepository.GetByIdAsync(organizationId);
if (organization is null || organization.Enabled || expirationDate is not null && organization.Gateway is null)
{
return;
}
organization.Enabled = true;
if (expirationDate is not null && organization.Gateway is not null)
{
organization.ExpirationDate = expirationDate;
organization.RevisionDate = DateTime.UtcNow;
}
await _organizationRepository.ReplaceAsync(organization);
await _applicationCacheService.UpsertOrganizationAbilityAsync(organization);
}
}

View File

@ -0,0 +1,18 @@
#nullable enable
using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements;
namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies;
public interface IPolicyRequirementQuery
{
/// <summary>
/// Get a policy requirement for a specific user.
/// The policy requirement represents how one or more policy types should be enforced against the user.
/// It will always return a value even if there are no policies that should be enforced.
/// This should be used for all policy checks.
/// </summary>
/// <param name="userId">The user that you need to enforce the policy against.</param>
/// <typeparam name="T">The IPolicyRequirement that corresponds to the policy you want to enforce.</typeparam>
Task<T> GetAsync<T>(Guid userId) where T : IPolicyRequirement;
}

View File

@ -0,0 +1,28 @@
#nullable enable
using Bit.Core.AdminConsole.Models.Data.Organizations.Policies;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements;
using Bit.Core.AdminConsole.Repositories;
namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies.Implementations;
public class PolicyRequirementQuery(
IPolicyRepository policyRepository,
IEnumerable<RequirementFactory<IPolicyRequirement>> factories)
: IPolicyRequirementQuery
{
public async Task<T> GetAsync<T>(Guid userId) where T : IPolicyRequirement
{
var factory = factories.OfType<RequirementFactory<T>>().SingleOrDefault();
if (factory is null)
{
throw new NotImplementedException("No Policy Requirement found for " + typeof(T));
}
return factory(await GetPolicyDetails(userId));
}
private Task<IEnumerable<PolicyDetails>> GetPolicyDetails(Guid userId) =>
policyRepository.GetPolicyDetailsByUserId(userId);
}

View File

@ -0,0 +1,24 @@
#nullable enable
using Bit.Core.AdminConsole.Models.Data.Organizations.Policies;
namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements;
/// <summary>
/// Represents the business requirements of how one or more enterprise policies will be enforced against a user.
/// The implementation of this interface will depend on how the policies are enforced in the relevant domain.
/// </summary>
public interface IPolicyRequirement;
/// <summary>
/// A factory function that takes a sequence of <see cref="PolicyDetails"/> and transforms them into a single
/// <see cref="IPolicyRequirement"/> for consumption by the relevant domain. This will receive *all* policy types
/// that may be enforced against a user; when implementing this delegate, you must filter out irrelevant policy types
/// as well as policies that should not be enforced against a user (e.g. due to the user's role or status).
/// </summary>
/// <remarks>
/// See <see cref="PolicyRequirementHelpers"/> for extension methods to handle common requirements when implementing
/// this delegate.
/// </remarks>
public delegate T RequirementFactory<out T>(IEnumerable<PolicyDetails> policyDetails)
where T : IPolicyRequirement;

View File

@ -0,0 +1,41 @@
using Bit.Core.AdminConsole.Enums;
using Bit.Core.AdminConsole.Models.Data.Organizations.Policies;
using Bit.Core.Enums;
namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements;
public static class PolicyRequirementHelpers
{
/// <summary>
/// Filters the PolicyDetails by PolicyType. This is generally required to only get the PolicyDetails that your
/// IPolicyRequirement relates to.
/// </summary>
public static IEnumerable<PolicyDetails> GetPolicyType(
this IEnumerable<PolicyDetails> policyDetails,
PolicyType type)
=> policyDetails.Where(x => x.PolicyType == type);
/// <summary>
/// Filters the PolicyDetails to remove the specified user roles. This can be used to exempt
/// owners and admins from policy enforcement.
/// </summary>
public static IEnumerable<PolicyDetails> ExemptRoles(
this IEnumerable<PolicyDetails> policyDetails,
IEnumerable<OrganizationUserType> roles)
=> policyDetails.Where(x => !roles.Contains(x.OrganizationUserType));
/// <summary>
/// Filters the PolicyDetails to remove organization users who are also provider users for the organization.
/// This can be used to exempt provider users from policy enforcement.
/// </summary>
public static IEnumerable<PolicyDetails> ExemptProviders(this IEnumerable<PolicyDetails> policyDetails)
=> policyDetails.Where(x => !x.IsProvider);
/// <summary>
/// Filters the PolicyDetails to remove the specified organization user statuses. For example, this can be used
/// to exempt users in the invited and revoked statuses from policy enforcement.
/// </summary>
public static IEnumerable<PolicyDetails> ExemptStatus(
this IEnumerable<PolicyDetails> policyDetails, IEnumerable<OrganizationUserStatusType> status)
=> policyDetails.Where(x => !status.Contains(x.OrganizationUserStatus));
}

View File

@ -1,4 +1,5 @@
using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Implementations;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyValidators;
using Bit.Core.AdminConsole.Services;
using Bit.Core.AdminConsole.Services.Implementations;
@ -12,7 +13,14 @@ public static class PolicyServiceCollectionExtensions
{
services.AddScoped<IPolicyService, PolicyService>();
services.AddScoped<ISavePolicyCommand, SavePolicyCommand>();
services.AddScoped<IPolicyRequirementQuery, PolicyRequirementQuery>();
services.AddPolicyValidators();
services.AddPolicyRequirements();
}
private static void AddPolicyValidators(this IServiceCollection services)
{
services.AddScoped<IPolicyValidator, TwoFactorAuthenticationPolicyValidator>();
services.AddScoped<IPolicyValidator, SingleOrgPolicyValidator>();
services.AddScoped<IPolicyValidator, RequireSsoPolicyValidator>();
@ -20,4 +28,34 @@ public static class PolicyServiceCollectionExtensions
services.AddScoped<IPolicyValidator, MaximumVaultTimeoutPolicyValidator>();
services.AddScoped<IPolicyValidator, FreeFamiliesForEnterprisePolicyValidator>();
}
private static void AddPolicyRequirements(this IServiceCollection services)
{
// Register policy requirement factories here
}
/// <summary>
/// Used to register simple policy requirements where its factory method implements CreateRequirement.
/// This MUST be used rather than calling AddScoped directly, because it will ensure the factory method has
/// the correct type to be injected and then identified by <see cref="PolicyRequirementQuery"/> at runtime.
/// </summary>
/// <typeparam name="T">The specific PolicyRequirement being registered.</typeparam>
private static void AddPolicyRequirement<T>(this IServiceCollection serviceCollection, RequirementFactory<T> factory)
where T : class, IPolicyRequirement
=> serviceCollection.AddPolicyRequirement(_ => factory);
/// <summary>
/// Used to register policy requirements where you need to access additional dependencies (usually to return a
/// curried factory method).
/// This MUST be used rather than calling AddScoped directly, because it will ensure the factory method has
/// the correct type to be injected and then identified by <see cref="PolicyRequirementQuery"/> at runtime.
/// </summary>
/// <typeparam name="T">
/// A callback that takes IServiceProvider and returns a RequirementFactory for
/// your policy requirement.
/// </typeparam>
private static void AddPolicyRequirement<T>(this IServiceCollection serviceCollection,
Func<IServiceProvider, RequirementFactory<T>> factory)
where T : class, IPolicyRequirement
=> serviceCollection.AddScoped<RequirementFactory<IPolicyRequirement>>(factory);
}

View File

@ -14,11 +14,29 @@ public interface IGroupRepository : IRepository<Group, Guid>
Guid organizationId);
Task<ICollection<Group>> GetManyByManyIds(IEnumerable<Guid> groupIds);
Task<ICollection<Guid>> GetManyIdsByUserIdAsync(Guid organizationUserId);
Task<ICollection<Guid>> GetManyUserIdsByIdAsync(Guid id);
/// <summary>
/// Query all OrganizationUserIds who are a member of the specified group.
/// </summary>
/// <param name="id">The group id.</param>
/// <param name="useReadOnlyReplica">
/// Whether to use the high-availability database replica. This is for paths with high traffic where immediate data
/// consistency is not required. You generally do not want this.
/// </param>
/// <returns></returns>
Task<ICollection<Guid>> GetManyUserIdsByIdAsync(Guid id, bool useReadOnlyReplica = false);
Task<ICollection<GroupUser>> GetManyGroupUsersByOrganizationIdAsync(Guid organizationId);
Task CreateAsync(Group obj, IEnumerable<CollectionAccessSelection> collections);
Task ReplaceAsync(Group obj, IEnumerable<CollectionAccessSelection> collections);
Task DeleteUserAsync(Guid groupId, Guid organizationUserId);
/// <summary>
/// Update a group's members. Replaces all members currently in the group.
/// Ignores members that do not belong to the same organization as the group.
/// </summary>
Task UpdateUsersAsync(Guid groupId, IEnumerable<Guid> organizationUserIds);
/// <summary>
/// Add members to a group. Gracefully ignores members that are already in the group,
/// duplicate organizationUserIds, and organizationUsers who are not part of the organization.
/// </summary>
Task AddGroupUsersByIdAsync(Guid groupId, IEnumerable<Guid> organizationUserIds);
Task DeleteManyAsync(IEnumerable<Guid> groupIds);
}

View File

@ -1,5 +1,7 @@
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Enums;
using Bit.Core.AdminConsole.Models.Data.Organizations.Policies;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies;
using Bit.Core.Repositories;
#nullable enable
@ -8,7 +10,25 @@ namespace Bit.Core.AdminConsole.Repositories;
public interface IPolicyRepository : IRepository<Policy, Guid>
{
/// <summary>
/// Gets all policies of a given type for an organization.
/// </summary>
/// <remarks>
/// WARNING: do not use this to enforce policies against a user! It returns raw data and does not take into account
/// various business rules. Use <see cref="IPolicyRequirementQuery"/> instead.
/// </remarks>
Task<Policy?> GetByOrganizationIdTypeAsync(Guid organizationId, PolicyType type);
Task<ICollection<Policy>> GetManyByOrganizationIdAsync(Guid organizationId);
Task<ICollection<Policy>> GetManyByUserIdAsync(Guid userId);
/// <summary>
/// Gets all PolicyDetails for a user for all policy types.
/// </summary>
/// <remarks>
/// Each PolicyDetail represents an OrganizationUser and a Policy which *may* be enforced
/// against them. It only returns PolicyDetails for policies that are enabled and where the organization's plan
/// supports policies. It also excludes "revoked invited" users who are not subject to policy enforcement.
/// This is consumed by <see cref="IPolicyRequirementQuery"/> to create requirements for specific policy types.
/// You probably do not want to call it directly.
/// </remarks>
Task<IEnumerable<PolicyDetails>> GetPolicyDetailsByUserId(Guid userId);
}

View File

@ -5,4 +5,6 @@ namespace Bit.Core.Services;
public interface IEventMessageHandler
{
Task HandleEventAsync(EventMessage eventMessage);
Task HandleManyEventsAsync(IEnumerable<EventMessage> eventMessages);
}

View File

@ -28,10 +28,8 @@ public interface IOrganizationService
/// </summary>
Task<(Organization organization, OrganizationUser organizationUser)> SignUpAsync(OrganizationLicense license, User owner,
string ownerKey, string collectionName, string publicKey, string privateKey);
Task EnableAsync(Guid organizationId, DateTime? expirationDate);
Task DisableAsync(Guid organizationId, DateTime? expirationDate);
Task UpdateExpirationDateAsync(Guid organizationId, DateTime? expirationDate);
Task EnableAsync(Guid organizationId);
Task UpdateAsync(Organization organization, bool updateBilling = false, EventType eventType = EventType.Organization_Updated);
Task UpdateTwoFactorProviderAsync(Organization organization, TwoFactorProviderType type);
Task DisableTwoFactorProviderAsync(Organization organization, TwoFactorProviderType type);

View File

@ -1,4 +1,5 @@
using System.Text.Json;
using System.Text;
using System.Text.Json;
using Azure.Messaging.ServiceBus;
using Bit.Core.Models.Data;
using Bit.Core.Settings;
@ -29,9 +30,20 @@ public class AzureServiceBusEventListenerService : EventLoggingListenerService
{
try
{
var eventMessage = JsonSerializer.Deserialize<EventMessage>(args.Message.Body.ToString());
using var jsonDocument = JsonDocument.Parse(Encoding.UTF8.GetString(args.Message.Body));
var root = jsonDocument.RootElement;
await _handler.HandleEventAsync(eventMessage);
if (root.ValueKind == JsonValueKind.Array)
{
var eventMessages = root.Deserialize<IEnumerable<EventMessage>>();
await _handler.HandleManyEventsAsync(eventMessages);
}
else if (root.ValueKind == JsonValueKind.Object)
{
var eventMessage = root.Deserialize<EventMessage>();
await _handler.HandleEventAsync(eventMessage);
}
await args.CompleteMessageAsync(args.Message);
}
catch (Exception exception)

View File

@ -29,10 +29,12 @@ public class AzureServiceBusEventWriteService : IEventWriteService, IAsyncDispos
public async Task CreateManyAsync(IEnumerable<IEvent> events)
{
foreach (var e in events)
var message = new ServiceBusMessage(JsonSerializer.SerializeToUtf8Bytes(events))
{
await CreateAsync(e);
}
ContentType = "application/json"
};
await _sender.SendMessageAsync(message);
}
public async ValueTask DisposeAsync()

View File

@ -11,4 +11,9 @@ public class AzureTableStorageEventHandler(
{
return eventWriteService.CreateManyAsync(EventTableEntity.IndexEvent(eventMessage));
}
public Task HandleManyEventsAsync(IEnumerable<EventMessage> eventMessages)
{
return eventWriteService.CreateManyAsync(eventMessages.SelectMany(EventTableEntity.IndexEvent));
}
}

View File

@ -11,4 +11,9 @@ public class EventRepositoryHandler(
{
return eventWriteService.CreateAsync(eventMessage);
}
public Task HandleManyEventsAsync(IEnumerable<EventMessage> eventMessages)
{
return eventWriteService.CreateManyAsync(eventMessages);
}
}

View File

@ -686,18 +686,6 @@ public class OrganizationService : IOrganizationService
}
}
public async Task EnableAsync(Guid organizationId, DateTime? expirationDate)
{
var org = await GetOrgById(organizationId);
if (org != null && !org.Enabled && org.Gateway.HasValue)
{
org.Enabled = true;
org.ExpirationDate = expirationDate;
org.RevisionDate = DateTime.UtcNow;
await ReplaceAndUpdateCacheAsync(org);
}
}
public async Task DisableAsync(Guid organizationId, DateTime? expirationDate)
{
var org = await GetOrgById(organizationId);
@ -723,16 +711,6 @@ public class OrganizationService : IOrganizationService
}
}
public async Task EnableAsync(Guid organizationId)
{
var org = await GetOrgById(organizationId);
if (org != null && !org.Enabled)
{
org.Enabled = true;
await ReplaceAndUpdateCacheAsync(org);
}
}
public async Task UpdateAsync(Organization organization, bool updateBilling = false, EventType eventType = EventType.Organization_Updated)
{
if (organization.Id == default(Guid))

View File

@ -1,4 +1,5 @@
using System.Text.Json;
using System.Text;
using System.Text.Json;
using Bit.Core.Models.Data;
using Bit.Core.Settings;
using Microsoft.Extensions.Logging;
@ -62,8 +63,20 @@ public class RabbitMqEventListenerService : EventLoggingListenerService
{
try
{
var eventMessage = JsonSerializer.Deserialize<EventMessage>(eventArgs.Body.Span);
await _handler.HandleEventAsync(eventMessage);
using var jsonDocument = JsonDocument.Parse(Encoding.UTF8.GetString(eventArgs.Body.Span));
var root = jsonDocument.RootElement;
if (root.ValueKind == JsonValueKind.Array)
{
var eventMessages = root.Deserialize<IEnumerable<EventMessage>>();
await _handler.HandleManyEventsAsync(eventMessages);
}
else if (root.ValueKind == JsonValueKind.Object)
{
var eventMessage = root.Deserialize<EventMessage>();
await _handler.HandleEventAsync(eventMessage);
}
}
catch (Exception ex)
{

View File

@ -41,12 +41,9 @@ public class RabbitMqEventWriteService : IEventWriteService, IAsyncDisposable
using var channel = await connection.CreateChannelAsync();
await channel.ExchangeDeclareAsync(exchange: _exchangeName, type: ExchangeType.Fanout, durable: true);
foreach (var e in events)
{
var body = JsonSerializer.SerializeToUtf8Bytes(e);
var body = JsonSerializer.SerializeToUtf8Bytes(events);
await channel.BasicPublishAsync(exchange: _exchangeName, routingKey: string.Empty, body: body);
}
await channel.BasicPublishAsync(exchange: _exchangeName, routingKey: string.Empty, body: body);
}
public async ValueTask DisposeAsync()

View File

@ -4,25 +4,27 @@ using Bit.Core.Settings;
namespace Bit.Core.Services;
public class WebhookEventHandler : IEventMessageHandler
public class WebhookEventHandler(
IHttpClientFactory httpClientFactory,
GlobalSettings globalSettings)
: IEventMessageHandler
{
private readonly HttpClient _httpClient;
private readonly string _webhookUrl;
private readonly HttpClient _httpClient = httpClientFactory.CreateClient(HttpClientName);
private readonly string _webhookUrl = globalSettings.EventLogging.WebhookUrl;
public const string HttpClientName = "WebhookEventHandlerHttpClient";
public WebhookEventHandler(
IHttpClientFactory httpClientFactory,
GlobalSettings globalSettings)
{
_httpClient = httpClientFactory.CreateClient(HttpClientName);
_webhookUrl = globalSettings.EventLogging.WebhookUrl;
}
public async Task HandleEventAsync(EventMessage eventMessage)
{
var content = JsonContent.Create(eventMessage);
var response = await _httpClient.PostAsync(_webhookUrl, content);
response.EnsureSuccessStatusCode();
}
public async Task HandleManyEventsAsync(IEnumerable<EventMessage> eventMessages)
{
var content = JsonContent.Create(eventMessages);
var response = await _httpClient.PostAsync(_webhookUrl, content);
response.EnsureSuccessStatusCode();
}
}

View File

@ -10,4 +10,5 @@ public enum TwoFactorProviderType : byte
Remember = 5,
OrganizationDuo = 6,
WebAuthn = 7,
RecoveryCode = 8,
}

View File

@ -92,32 +92,7 @@ public class PremiumUserBillingService(
* If the customer was previously set up with credit, which does not require a billing location,
* we need to update the customer on the fly before we start the subscription.
*/
if (customerSetup is
{
TokenizedPaymentSource.Type: PaymentMethodType.Credit,
TaxInformation: { Country: not null and not "", PostalCode: not null and not "" }
})
{
var options = new CustomerUpdateOptions
{
Address = new AddressOptions
{
Line1 = customerSetup.TaxInformation.Line1,
Line2 = customerSetup.TaxInformation.Line2,
City = customerSetup.TaxInformation.City,
PostalCode = customerSetup.TaxInformation.PostalCode,
State = customerSetup.TaxInformation.State,
Country = customerSetup.TaxInformation.Country,
},
Expand = ["tax"],
Tax = new CustomerTaxOptions
{
ValidateLocation = StripeConstants.ValidateTaxLocationTiming.Immediately
}
};
customer = await stripeAdapter.CustomerUpdateAsync(customer.Id, options);
}
customer = await ReconcileBillingLocationAsync(customer, customerSetup.TaxInformation);
var subscription = await CreateSubscriptionAsync(user.Id, customer, storage);
@ -167,6 +142,11 @@ public class PremiumUserBillingService(
User user,
CustomerSetup customerSetup)
{
/*
* Creating a Customer via the adding of a payment method or the purchasing of a subscription requires
* an actual payment source. The only time this is not the case is when the Customer is created when the
* User purchases credit.
*/
if (customerSetup.TokenizedPaymentSource is not
{
Type: PaymentMethodType.BankAccount or PaymentMethodType.Card or PaymentMethodType.PayPal,
@ -367,4 +347,34 @@ public class PremiumUserBillingService(
return subscription;
}
private async Task<Customer> ReconcileBillingLocationAsync(
Customer customer,
TaxInformation taxInformation)
{
if (customer is { Address: { Country: not null and not "", PostalCode: not null and not "" } })
{
return customer;
}
var options = new CustomerUpdateOptions
{
Address = new AddressOptions
{
Line1 = taxInformation.Line1,
Line2 = taxInformation.Line2,
City = taxInformation.City,
PostalCode = taxInformation.PostalCode,
State = taxInformation.State,
Country = taxInformation.Country,
},
Expand = ["tax"],
Tax = new CustomerTaxOptions
{
ValidateLocation = StripeConstants.ValidateTaxLocationTiming.Immediately
}
};
return await stripeAdapter.CustomerUpdateAsync(customer.Id, options);
}
}

View File

@ -108,7 +108,9 @@ public static class FeatureFlagKeys
public const string IntegrationPage = "pm-14505-admin-console-integration-page";
public const string DeviceApprovalRequestAdminNotifications = "pm-15637-device-approval-request-admin-notifications";
public const string LimitItemDeletion = "pm-15493-restrict-item-deletion-to-can-manage-permission";
public const string ShortcutDuplicatePatchRequests = "pm-16812-shortcut-duplicate-patch-requests";
public const string PushSyncOrgKeysOnRevokeRestore = "pm-17168-push-sync-org-keys-on-revoke-restore";
public const string PolicyRequirements = "pm-14439-policy-requirements";
public const string ScimInviteUserOptimization = "pm-16811-optimize-invite-user-flow-to-fail-fast";
/* Tools Team */
@ -171,6 +173,7 @@ public static class FeatureFlagKeys
public const string EnablePMAuthenticatorSync = "enable-pm-bwa-sync";
public const string P15179_AddExistingOrgsFromProviderPortal = "pm-15179-add-existing-orgs-from-provider-portal";
public const string AndroidMutualTls = "mutual-tls";
public const string RecoveryCodeLogin = "pm-17128-recovery-code-login";
public const string PM3503_MobileAnonAddySelfHostAlias = "anon-addy-self-host-alias";
public static List<string> GetAllKeys()

View File

@ -42,7 +42,7 @@
<PackageReference Include="DnsClient" Version="1.8.0" />
<PackageReference Include="Fido2.AspNet" Version="3.0.1" />
<PackageReference Include="Handlebars.Net" Version="2.1.6" />
<PackageReference Include="MailKit" Version="4.9.0" />
<PackageReference Include="MailKit" Version="4.10.0" />
<PackageReference Include="Microsoft.AspNetCore.Authentication.JwtBearer" Version="8.0.10" />
<PackageReference Include="Microsoft.Azure.Cosmos" Version="3.46.1" />
<PackageReference Include="Microsoft.Azure.NotificationHubs" Version="4.2.0" />

View File

@ -29,4 +29,5 @@ public enum PushType : byte
SyncOrganizationCollectionSettingChanged = 19,
SyncNotification = 20,
SyncNotificationStatus = 21
}

View File

@ -1,11 +1,12 @@
using Bit.Core.Enums;
#nullable enable
using Bit.Core.Enums;
using Bit.Core.NotificationCenter.Enums;
namespace Bit.Core.Models;
public class PushNotificationData<T>
{
public PushNotificationData(PushType type, T payload, string contextId)
public PushNotificationData(PushType type, T payload, string? contextId)
{
Type = type;
Payload = payload;
@ -14,7 +15,7 @@ public class PushNotificationData<T>
public PushType Type { get; set; }
public T Payload { get; set; }
public string ContextId { get; set; }
public string? ContextId { get; set; }
}
public class SyncCipherPushNotification
@ -22,7 +23,7 @@ public class SyncCipherPushNotification
public Guid Id { get; set; }
public Guid? UserId { get; set; }
public Guid? OrganizationId { get; set; }
public IEnumerable<Guid> CollectionIds { get; set; }
public IEnumerable<Guid>? CollectionIds { get; set; }
public DateTime RevisionDate { get; set; }
}
@ -46,7 +47,6 @@ public class SyncSendPushNotification
public DateTime RevisionDate { get; set; }
}
#nullable enable
public class NotificationPushNotification
{
public Guid Id { get; set; }
@ -59,8 +59,9 @@ public class NotificationPushNotification
public string? Body { get; set; }
public DateTime CreationDate { get; set; }
public DateTime RevisionDate { get; set; }
public DateTime? ReadDate { get; set; }
public DateTime? DeletedDate { get; set; }
}
#nullable disable
public class AuthRequestPushNotification
{

View File

@ -5,6 +5,7 @@ using Bit.Core.NotificationCenter.Authorization;
using Bit.Core.NotificationCenter.Commands.Interfaces;
using Bit.Core.NotificationCenter.Entities;
using Bit.Core.NotificationCenter.Repositories;
using Bit.Core.Platform.Push;
using Bit.Core.Utilities;
using Microsoft.AspNetCore.Authorization;
@ -16,16 +17,19 @@ public class CreateNotificationStatusCommand : ICreateNotificationStatusCommand
private readonly IAuthorizationService _authorizationService;
private readonly INotificationRepository _notificationRepository;
private readonly INotificationStatusRepository _notificationStatusRepository;
private readonly IPushNotificationService _pushNotificationService;
public CreateNotificationStatusCommand(ICurrentContext currentContext,
IAuthorizationService authorizationService,
INotificationRepository notificationRepository,
INotificationStatusRepository notificationStatusRepository)
INotificationStatusRepository notificationStatusRepository,
IPushNotificationService pushNotificationService)
{
_currentContext = currentContext;
_authorizationService = authorizationService;
_notificationRepository = notificationRepository;
_notificationStatusRepository = notificationStatusRepository;
_pushNotificationService = pushNotificationService;
}
public async Task<NotificationStatus> CreateAsync(NotificationStatus notificationStatus)
@ -42,6 +46,10 @@ public class CreateNotificationStatusCommand : ICreateNotificationStatusCommand
await _authorizationService.AuthorizeOrThrowAsync(_currentContext.HttpContext.User, notificationStatus,
NotificationStatusOperations.Create);
return await _notificationStatusRepository.CreateAsync(notificationStatus);
var newNotificationStatus = await _notificationStatusRepository.CreateAsync(notificationStatus);
await _pushNotificationService.PushNotificationStatusAsync(notification, newNotificationStatus);
return newNotificationStatus;
}
}

View File

@ -5,6 +5,7 @@ using Bit.Core.NotificationCenter.Authorization;
using Bit.Core.NotificationCenter.Commands.Interfaces;
using Bit.Core.NotificationCenter.Entities;
using Bit.Core.NotificationCenter.Repositories;
using Bit.Core.Platform.Push;
using Bit.Core.Utilities;
using Microsoft.AspNetCore.Authorization;
@ -16,16 +17,19 @@ public class MarkNotificationDeletedCommand : IMarkNotificationDeletedCommand
private readonly IAuthorizationService _authorizationService;
private readonly INotificationRepository _notificationRepository;
private readonly INotificationStatusRepository _notificationStatusRepository;
private readonly IPushNotificationService _pushNotificationService;
public MarkNotificationDeletedCommand(ICurrentContext currentContext,
IAuthorizationService authorizationService,
INotificationRepository notificationRepository,
INotificationStatusRepository notificationStatusRepository)
INotificationStatusRepository notificationStatusRepository,
IPushNotificationService pushNotificationService)
{
_currentContext = currentContext;
_authorizationService = authorizationService;
_notificationRepository = notificationRepository;
_notificationStatusRepository = notificationStatusRepository;
_pushNotificationService = pushNotificationService;
}
public async Task MarkDeletedAsync(Guid notificationId)
@ -59,7 +63,9 @@ public class MarkNotificationDeletedCommand : IMarkNotificationDeletedCommand
await _authorizationService.AuthorizeOrThrowAsync(_currentContext.HttpContext.User, notificationStatus,
NotificationStatusOperations.Create);
await _notificationStatusRepository.CreateAsync(notificationStatus);
var newNotificationStatus = await _notificationStatusRepository.CreateAsync(notificationStatus);
await _pushNotificationService.PushNotificationStatusAsync(notification, newNotificationStatus);
}
else
{
@ -69,6 +75,8 @@ public class MarkNotificationDeletedCommand : IMarkNotificationDeletedCommand
notificationStatus.DeletedDate = DateTime.UtcNow;
await _notificationStatusRepository.UpdateAsync(notificationStatus);
await _pushNotificationService.PushNotificationStatusAsync(notification, notificationStatus);
}
}
}

View File

@ -5,6 +5,7 @@ using Bit.Core.NotificationCenter.Authorization;
using Bit.Core.NotificationCenter.Commands.Interfaces;
using Bit.Core.NotificationCenter.Entities;
using Bit.Core.NotificationCenter.Repositories;
using Bit.Core.Platform.Push;
using Bit.Core.Utilities;
using Microsoft.AspNetCore.Authorization;
@ -16,16 +17,19 @@ public class MarkNotificationReadCommand : IMarkNotificationReadCommand
private readonly IAuthorizationService _authorizationService;
private readonly INotificationRepository _notificationRepository;
private readonly INotificationStatusRepository _notificationStatusRepository;
private readonly IPushNotificationService _pushNotificationService;
public MarkNotificationReadCommand(ICurrentContext currentContext,
IAuthorizationService authorizationService,
INotificationRepository notificationRepository,
INotificationStatusRepository notificationStatusRepository)
INotificationStatusRepository notificationStatusRepository,
IPushNotificationService pushNotificationService)
{
_currentContext = currentContext;
_authorizationService = authorizationService;
_notificationRepository = notificationRepository;
_notificationStatusRepository = notificationStatusRepository;
_pushNotificationService = pushNotificationService;
}
public async Task MarkReadAsync(Guid notificationId)
@ -59,7 +63,9 @@ public class MarkNotificationReadCommand : IMarkNotificationReadCommand
await _authorizationService.AuthorizeOrThrowAsync(_currentContext.HttpContext.User, notificationStatus,
NotificationStatusOperations.Create);
await _notificationStatusRepository.CreateAsync(notificationStatus);
var newNotificationStatus = await _notificationStatusRepository.CreateAsync(notificationStatus);
await _pushNotificationService.PushNotificationStatusAsync(notification, newNotificationStatus);
}
else
{
@ -69,6 +75,8 @@ public class MarkNotificationReadCommand : IMarkNotificationReadCommand
notificationStatus.ReadDate = DateTime.UtcNow;
await _notificationStatusRepository.UpdateAsync(notificationStatus);
await _pushNotificationService.PushNotificationStatusAsync(notification, notificationStatus);
}
}
}

View File

@ -5,6 +5,7 @@ using Bit.Core.NotificationCenter.Authorization;
using Bit.Core.NotificationCenter.Commands.Interfaces;
using Bit.Core.NotificationCenter.Entities;
using Bit.Core.NotificationCenter.Repositories;
using Bit.Core.Platform.Push;
using Bit.Core.Utilities;
using Microsoft.AspNetCore.Authorization;
@ -15,14 +16,17 @@ public class UpdateNotificationCommand : IUpdateNotificationCommand
private readonly ICurrentContext _currentContext;
private readonly IAuthorizationService _authorizationService;
private readonly INotificationRepository _notificationRepository;
private readonly IPushNotificationService _pushNotificationService;
public UpdateNotificationCommand(ICurrentContext currentContext,
IAuthorizationService authorizationService,
INotificationRepository notificationRepository)
INotificationRepository notificationRepository,
IPushNotificationService pushNotificationService)
{
_currentContext = currentContext;
_authorizationService = authorizationService;
_notificationRepository = notificationRepository;
_pushNotificationService = pushNotificationService;
}
public async Task UpdateAsync(Notification notificationToUpdate)
@ -43,5 +47,7 @@ public class UpdateNotificationCommand : IUpdateNotificationCommand
notification.RevisionDate = DateTime.UtcNow;
await _notificationRepository.ReplaceAsync(notification);
await _pushNotificationService.PushNotificationAsync(notification);
}
}

View File

@ -1,4 +1,5 @@
using System.Text.Json;
#nullable enable
using System.Text.Json;
using System.Text.RegularExpressions;
using Bit.Core.AdminConsole.Entities;
using Bit.Core.Auth.Entities;
@ -6,6 +7,7 @@ using Bit.Core.Context;
using Bit.Core.Enums;
using Bit.Core.Models;
using Bit.Core.Models.Data;
using Bit.Core.NotificationCenter.Entities;
using Bit.Core.Platform.Push;
using Bit.Core.Repositories;
using Bit.Core.Tools.Entities;
@ -51,7 +53,7 @@ public class NotificationHubPushNotificationService : IPushNotificationService
await PushCipherAsync(cipher, PushType.SyncLoginDelete, null);
}
private async Task PushCipherAsync(Cipher cipher, PushType type, IEnumerable<Guid> collectionIds)
private async Task PushCipherAsync(Cipher cipher, PushType type, IEnumerable<Guid>? collectionIds)
{
if (cipher.OrganizationId.HasValue)
{
@ -209,6 +211,36 @@ public class NotificationHubPushNotificationService : IPushNotificationService
}
}
public async Task PushNotificationStatusAsync(Notification notification, NotificationStatus notificationStatus)
{
var message = new NotificationPushNotification
{
Id = notification.Id,
Priority = notification.Priority,
Global = notification.Global,
ClientType = notification.ClientType,
UserId = notification.UserId,
OrganizationId = notification.OrganizationId,
Title = notification.Title,
Body = notification.Body,
CreationDate = notification.CreationDate,
RevisionDate = notification.RevisionDate,
ReadDate = notificationStatus.ReadDate,
DeletedDate = notificationStatus.DeletedDate
};
if (notification.UserId.HasValue)
{
await SendPayloadToUserAsync(notification.UserId.Value, PushType.SyncNotificationStatus, message, true,
notification.ClientType);
}
else if (notification.OrganizationId.HasValue)
{
await SendPayloadToOrganizationAsync(notification.OrganizationId.Value, PushType.SyncNotificationStatus, message,
true, notification.ClientType);
}
}
private async Task PushAuthRequestAsync(AuthRequest authRequest, PushType type)
{
var message = new AuthRequestPushNotification { Id = authRequest.Id, UserId = authRequest.UserId };
@ -230,8 +262,8 @@ public class NotificationHubPushNotificationService : IPushNotificationService
GetContextIdentifier(excludeCurrentContext), clientType: clientType);
}
public async Task SendPayloadToUserAsync(string userId, PushType type, object payload, string identifier,
string deviceId = null, ClientType? clientType = null)
public async Task SendPayloadToUserAsync(string userId, PushType type, object payload, string? identifier,
string? deviceId = null, ClientType? clientType = null)
{
var tag = BuildTag($"template:payload_userId:{SanitizeTagInput(userId)}", identifier, clientType);
await SendPayloadAsync(tag, type, payload);
@ -241,8 +273,8 @@ public class NotificationHubPushNotificationService : IPushNotificationService
}
}
public async Task SendPayloadToOrganizationAsync(string orgId, PushType type, object payload, string identifier,
string deviceId = null, ClientType? clientType = null)
public async Task SendPayloadToOrganizationAsync(string orgId, PushType type, object payload, string? identifier,
string? deviceId = null, ClientType? clientType = null)
{
var tag = BuildTag($"template:payload && organizationId:{SanitizeTagInput(orgId)}", identifier, clientType);
await SendPayloadAsync(tag, type, payload);
@ -277,7 +309,7 @@ public class NotificationHubPushNotificationService : IPushNotificationService
false
);
private string GetContextIdentifier(bool excludeCurrentContext)
private string? GetContextIdentifier(bool excludeCurrentContext)
{
if (!excludeCurrentContext)
{
@ -285,11 +317,11 @@ public class NotificationHubPushNotificationService : IPushNotificationService
}
var currentContext =
_httpContextAccessor?.HttpContext?.RequestServices.GetService(typeof(ICurrentContext)) as ICurrentContext;
_httpContextAccessor.HttpContext?.RequestServices.GetService(typeof(ICurrentContext)) as ICurrentContext;
return currentContext?.DeviceIdentifier;
}
private string BuildTag(string tag, string identifier, ClientType? clientType)
private string BuildTag(string tag, string? identifier, ClientType? clientType)
{
if (!string.IsNullOrWhiteSpace(identifier))
{

View File

@ -56,6 +56,7 @@ public static class OrganizationServiceCollectionExtensions
services.AddOrganizationDomainCommandsQueries();
services.AddOrganizationSignUpCommands();
services.AddOrganizationDeleteCommands();
services.AddOrganizationEnableCommands();
services.AddOrganizationAuthCommands();
services.AddOrganizationUserCommands();
services.AddOrganizationUserCommandsQueries();
@ -71,6 +72,9 @@ public static class OrganizationServiceCollectionExtensions
services.AddScoped<IOrganizationInitiateDeleteCommand, OrganizationInitiateDeleteCommand>();
}
private static void AddOrganizationEnableCommands(this IServiceCollection services) =>
services.AddScoped<IOrganizationEnableCommand, OrganizationEnableCommand>();
private static void AddOrganizationConnectionCommands(this IServiceCollection services)
{
services.AddScoped<ICreateOrganizationConnectionCommand, CreateOrganizationConnectionCommand>();

View File

@ -1,4 +1,5 @@
using System.Text.Json;
#nullable enable
using System.Text.Json;
using Azure.Storage.Queues;
using Bit.Core.AdminConsole.Entities;
using Bit.Core.Auth.Entities;
@ -42,7 +43,7 @@ public class AzureQueuePushNotificationService : IPushNotificationService
await PushCipherAsync(cipher, PushType.SyncLoginDelete, null);
}
private async Task PushCipherAsync(Cipher cipher, PushType type, IEnumerable<Guid> collectionIds)
private async Task PushCipherAsync(Cipher cipher, PushType type, IEnumerable<Guid>? collectionIds)
{
if (cipher.OrganizationId.HasValue)
{
@ -184,6 +185,27 @@ public class AzureQueuePushNotificationService : IPushNotificationService
await SendMessageAsync(PushType.SyncNotification, message, true);
}
public async Task PushNotificationStatusAsync(Notification notification, NotificationStatus notificationStatus)
{
var message = new NotificationPushNotification
{
Id = notification.Id,
Priority = notification.Priority,
Global = notification.Global,
ClientType = notification.ClientType,
UserId = notification.UserId,
OrganizationId = notification.OrganizationId,
Title = notification.Title,
Body = notification.Body,
CreationDate = notification.CreationDate,
RevisionDate = notification.RevisionDate,
ReadDate = notificationStatus.ReadDate,
DeletedDate = notificationStatus.DeletedDate
};
await SendMessageAsync(PushType.SyncNotificationStatus, message, true);
}
private async Task PushSendAsync(Send send, PushType type)
{
if (send.UserId.HasValue)
@ -207,7 +229,7 @@ public class AzureQueuePushNotificationService : IPushNotificationService
await _queueClient.SendMessageAsync(message);
}
private string GetContextIdentifier(bool excludeCurrentContext)
private string? GetContextIdentifier(bool excludeCurrentContext)
{
if (!excludeCurrentContext)
{
@ -219,15 +241,15 @@ public class AzureQueuePushNotificationService : IPushNotificationService
return currentContext?.DeviceIdentifier;
}
public Task SendPayloadToUserAsync(string userId, PushType type, object payload, string identifier,
string deviceId = null, ClientType? clientType = null)
public Task SendPayloadToUserAsync(string userId, PushType type, object payload, string? identifier,
string? deviceId = null, ClientType? clientType = null)
{
// Noop
return Task.FromResult(0);
}
public Task SendPayloadToOrganizationAsync(string orgId, PushType type, object payload, string identifier,
string deviceId = null, ClientType? clientType = null)
public Task SendPayloadToOrganizationAsync(string orgId, PushType type, object payload, string? identifier,
string? deviceId = null, ClientType? clientType = null)
{
// Noop
return Task.FromResult(0);

View File

@ -1,4 +1,5 @@
using Bit.Core.AdminConsole.Entities;
#nullable enable
using Bit.Core.AdminConsole.Entities;
using Bit.Core.Auth.Entities;
using Bit.Core.Enums;
using Bit.Core.NotificationCenter.Entities;
@ -25,12 +26,13 @@ public interface IPushNotificationService
Task PushSyncSendUpdateAsync(Send send);
Task PushSyncSendDeleteAsync(Send send);
Task PushNotificationAsync(Notification notification);
Task PushNotificationStatusAsync(Notification notification, NotificationStatus notificationStatus);
Task PushAuthRequestAsync(AuthRequest authRequest);
Task PushAuthRequestResponseAsync(AuthRequest authRequest);
Task PushSyncOrganizationStatusAsync(Organization organization);
Task PushSyncOrganizationCollectionManagementSettingsAsync(Organization organization);
Task SendPayloadToUserAsync(string userId, PushType type, object payload, string identifier,
string deviceId = null, ClientType? clientType = null);
Task SendPayloadToOrganizationAsync(string orgId, PushType type, object payload, string identifier,
string deviceId = null, ClientType? clientType = null);
Task SendPayloadToUserAsync(string userId, PushType type, object payload, string? identifier,
string? deviceId = null, ClientType? clientType = null);
Task SendPayloadToOrganizationAsync(string orgId, PushType type, object payload, string? identifier,
string? deviceId = null, ClientType? clientType = null);
}

View File

@ -1,4 +1,5 @@
using Bit.Core.AdminConsole.Entities;
#nullable enable
using Bit.Core.AdminConsole.Entities;
using Bit.Core.Auth.Entities;
using Bit.Core.Enums;
using Bit.Core.NotificationCenter.Entities;
@ -24,7 +25,7 @@ public class MultiServicePushNotificationService : IPushNotificationService
_logger = logger;
_logger.LogInformation("Hub services: {Services}", _services.Count());
globalSettings?.NotificationHubPool?.NotificationHubs?.ForEach(hub =>
globalSettings.NotificationHubPool?.NotificationHubs?.ForEach(hub =>
{
_logger.LogInformation("HubName: {HubName}, EnableSendTracing: {EnableSendTracing}, RegistrationStartDate: {RegistrationStartDate}, RegistrationEndDate: {RegistrationEndDate}", hub.HubName, hub.EnableSendTracing, hub.RegistrationStartDate, hub.RegistrationEndDate);
});
@ -150,15 +151,21 @@ public class MultiServicePushNotificationService : IPushNotificationService
return Task.CompletedTask;
}
public Task SendPayloadToUserAsync(string userId, PushType type, object payload, string identifier,
string deviceId = null, ClientType? clientType = null)
public Task PushNotificationStatusAsync(Notification notification, NotificationStatus notificationStatus)
{
PushToServices((s) => s.PushNotificationStatusAsync(notification, notificationStatus));
return Task.CompletedTask;
}
public Task SendPayloadToUserAsync(string userId, PushType type, object payload, string? identifier,
string? deviceId = null, ClientType? clientType = null)
{
PushToServices((s) => s.SendPayloadToUserAsync(userId, type, payload, identifier, deviceId, clientType));
return Task.FromResult(0);
}
public Task SendPayloadToOrganizationAsync(string orgId, PushType type, object payload, string identifier,
string deviceId = null, ClientType? clientType = null)
public Task SendPayloadToOrganizationAsync(string orgId, PushType type, object payload, string? identifier,
string? deviceId = null, ClientType? clientType = null)
{
PushToServices((s) => s.SendPayloadToOrganizationAsync(orgId, type, payload, identifier, deviceId, clientType));
return Task.FromResult(0);
@ -166,12 +173,16 @@ public class MultiServicePushNotificationService : IPushNotificationService
private void PushToServices(Func<IPushNotificationService, Task> pushFunc)
{
if (_services != null)
if (!_services.Any())
{
foreach (var service in _services)
{
pushFunc(service);
}
_logger.LogWarning("No services found to push notification");
return;
}
foreach (var service in _services)
{
_logger.LogDebug("Pushing notification to service {ServiceName}", service.GetType().Name);
pushFunc(service);
}
}
}

View File

@ -1,4 +1,5 @@
using Bit.Core.AdminConsole.Entities;
#nullable enable
using Bit.Core.AdminConsole.Entities;
using Bit.Core.Auth.Entities;
using Bit.Core.Enums;
using Bit.Core.NotificationCenter.Entities;
@ -84,8 +85,8 @@ public class NoopPushNotificationService : IPushNotificationService
return Task.FromResult(0);
}
public Task SendPayloadToOrganizationAsync(string orgId, PushType type, object payload, string identifier,
string deviceId = null, ClientType? clientType = null)
public Task SendPayloadToOrganizationAsync(string orgId, PushType type, object payload, string? identifier,
string? deviceId = null, ClientType? clientType = null)
{
return Task.FromResult(0);
}
@ -107,11 +108,14 @@ public class NoopPushNotificationService : IPushNotificationService
return Task.FromResult(0);
}
public Task SendPayloadToUserAsync(string userId, PushType type, object payload, string identifier,
string deviceId = null, ClientType? clientType = null)
public Task SendPayloadToUserAsync(string userId, PushType type, object payload, string? identifier,
string? deviceId = null, ClientType? clientType = null)
{
return Task.FromResult(0);
}
public Task PushNotificationAsync(Notification notification) => Task.CompletedTask;
public Task PushNotificationStatusAsync(Notification notification, NotificationStatus notificationStatus) =>
Task.CompletedTask;
}

View File

@ -1,4 +1,5 @@
using Bit.Core.AdminConsole.Entities;
#nullable enable
using Bit.Core.AdminConsole.Entities;
using Bit.Core.Auth.Entities;
using Bit.Core.Context;
using Bit.Core.Enums;
@ -16,7 +17,6 @@ namespace Bit.Core.Platform.Push;
public class NotificationsApiPushNotificationService : BaseIdentityClientService, IPushNotificationService
{
private readonly GlobalSettings _globalSettings;
private readonly IHttpContextAccessor _httpContextAccessor;
public NotificationsApiPushNotificationService(
@ -33,7 +33,6 @@ public class NotificationsApiPushNotificationService : BaseIdentityClientService
globalSettings.InternalIdentityKey,
logger)
{
_globalSettings = globalSettings;
_httpContextAccessor = httpContextAccessor;
}
@ -52,7 +51,7 @@ public class NotificationsApiPushNotificationService : BaseIdentityClientService
await PushCipherAsync(cipher, PushType.SyncLoginDelete, null);
}
private async Task PushCipherAsync(Cipher cipher, PushType type, IEnumerable<Guid> collectionIds)
private async Task PushCipherAsync(Cipher cipher, PushType type, IEnumerable<Guid>? collectionIds)
{
if (cipher.OrganizationId.HasValue)
{
@ -203,6 +202,27 @@ public class NotificationsApiPushNotificationService : BaseIdentityClientService
await SendMessageAsync(PushType.SyncNotification, message, true);
}
public async Task PushNotificationStatusAsync(Notification notification, NotificationStatus notificationStatus)
{
var message = new NotificationPushNotification
{
Id = notification.Id,
Priority = notification.Priority,
Global = notification.Global,
ClientType = notification.ClientType,
UserId = notification.UserId,
OrganizationId = notification.OrganizationId,
Title = notification.Title,
Body = notification.Body,
CreationDate = notification.CreationDate,
RevisionDate = notification.RevisionDate,
ReadDate = notificationStatus.ReadDate,
DeletedDate = notificationStatus.DeletedDate
};
await SendMessageAsync(PushType.SyncNotificationStatus, message, true);
}
private async Task PushSendAsync(Send send, PushType type)
{
if (send.UserId.HasValue)
@ -225,7 +245,7 @@ public class NotificationsApiPushNotificationService : BaseIdentityClientService
await SendAsync(HttpMethod.Post, "send", request);
}
private string GetContextIdentifier(bool excludeCurrentContext)
private string? GetContextIdentifier(bool excludeCurrentContext)
{
if (!excludeCurrentContext)
{
@ -233,19 +253,19 @@ public class NotificationsApiPushNotificationService : BaseIdentityClientService
}
var currentContext =
_httpContextAccessor?.HttpContext?.RequestServices.GetService(typeof(ICurrentContext)) as ICurrentContext;
_httpContextAccessor.HttpContext?.RequestServices.GetService(typeof(ICurrentContext)) as ICurrentContext;
return currentContext?.DeviceIdentifier;
}
public Task SendPayloadToUserAsync(string userId, PushType type, object payload, string identifier,
string deviceId = null, ClientType? clientType = null)
public Task SendPayloadToUserAsync(string userId, PushType type, object payload, string? identifier,
string? deviceId = null, ClientType? clientType = null)
{
// Noop
return Task.FromResult(0);
}
public Task SendPayloadToOrganizationAsync(string orgId, PushType type, object payload, string identifier,
string deviceId = null, ClientType? clientType = null)
public Task SendPayloadToOrganizationAsync(string orgId, PushType type, object payload, string? identifier,
string? deviceId = null, ClientType? clientType = null)
{
// Noop
return Task.FromResult(0);

View File

@ -1,4 +1,5 @@
using Bit.Core.AdminConsole.Entities;
#nullable enable
using Bit.Core.AdminConsole.Entities;
using Bit.Core.Auth.Entities;
using Bit.Core.Context;
using Bit.Core.Enums;
@ -55,7 +56,7 @@ public class RelayPushNotificationService : BaseIdentityClientService, IPushNoti
await PushCipherAsync(cipher, PushType.SyncLoginDelete, null);
}
private async Task PushCipherAsync(Cipher cipher, PushType type, IEnumerable<Guid> collectionIds)
private async Task PushCipherAsync(Cipher cipher, PushType type, IEnumerable<Guid>? collectionIds)
{
if (cipher.OrganizationId.HasValue)
{
@ -219,6 +220,36 @@ public class RelayPushNotificationService : BaseIdentityClientService, IPushNoti
}
}
public async Task PushNotificationStatusAsync(Notification notification, NotificationStatus notificationStatus)
{
var message = new NotificationPushNotification
{
Id = notification.Id,
Priority = notification.Priority,
Global = notification.Global,
ClientType = notification.ClientType,
UserId = notification.UserId,
OrganizationId = notification.OrganizationId,
Title = notification.Title,
Body = notification.Body,
CreationDate = notification.CreationDate,
RevisionDate = notification.RevisionDate,
ReadDate = notificationStatus.ReadDate,
DeletedDate = notificationStatus.DeletedDate
};
if (notification.UserId.HasValue)
{
await SendPayloadToUserAsync(notification.UserId.Value, PushType.SyncNotificationStatus, message, true,
notification.ClientType);
}
else if (notification.OrganizationId.HasValue)
{
await SendPayloadToOrganizationAsync(notification.OrganizationId.Value, PushType.SyncNotificationStatus, message,
true, notification.ClientType);
}
}
public async Task PushSyncOrganizationStatusAsync(Organization organization)
{
var message = new OrganizationStatusPushNotification
@ -277,7 +308,7 @@ public class RelayPushNotificationService : BaseIdentityClientService, IPushNoti
private async Task AddCurrentContextAsync(PushSendRequestModel request, bool addIdentifier)
{
var currentContext =
_httpContextAccessor?.HttpContext?.RequestServices.GetService(typeof(ICurrentContext)) as ICurrentContext;
_httpContextAccessor.HttpContext?.RequestServices.GetService(typeof(ICurrentContext)) as ICurrentContext;
if (!string.IsNullOrWhiteSpace(currentContext?.DeviceIdentifier))
{
var device = await _deviceRepository.GetByIdentifierAsync(currentContext.DeviceIdentifier);
@ -293,14 +324,14 @@ public class RelayPushNotificationService : BaseIdentityClientService, IPushNoti
}
}
public Task SendPayloadToUserAsync(string userId, PushType type, object payload, string identifier,
string deviceId = null, ClientType? clientType = null)
public Task SendPayloadToUserAsync(string userId, PushType type, object payload, string? identifier,
string? deviceId = null, ClientType? clientType = null)
{
throw new NotImplementedException();
}
public Task SendPayloadToOrganizationAsync(string orgId, PushType type, object payload, string identifier,
string deviceId = null, ClientType? clientType = null)
public Task SendPayloadToOrganizationAsync(string orgId, PushType type, object payload, string? identifier,
string? deviceId = null, ClientType? clientType = null)
{
throw new NotImplementedException();
}

View File

@ -22,7 +22,6 @@ public interface IUserService
Task<IdentityResult> CreateUserAsync(User user, string masterPasswordHash);
Task SendMasterPasswordHintAsync(string email);
Task SendTwoFactorEmailAsync(User user);
Task<bool> VerifyTwoFactorEmailAsync(User user, string token);
Task<CredentialCreateOptions> StartWebAuthnRegistrationAsync(User user);
Task<bool> DeleteWebAuthnKeyAsync(User user, int id);
Task<bool> CompleteWebAuthRegistrationAsync(User user, int value, string name, AuthenticatorAttestationRawResponse attestationResponse);
@ -41,8 +40,6 @@ public interface IUserService
Task<IdentityResult> RefreshSecurityStampAsync(User user, string masterPasswordHash);
Task UpdateTwoFactorProviderAsync(User user, TwoFactorProviderType type, bool setEnabled = true, bool logEvent = true);
Task DisableTwoFactorProviderAsync(User user, TwoFactorProviderType type);
Task<bool> RecoverTwoFactorAsync(string email, string masterPassword, string recoveryCode);
Task<string> GenerateUserTokenAsync(User user, string tokenProvider, string purpose);
Task<IdentityResult> DeleteAsync(User user);
Task<IdentityResult> DeleteAsync(User user, string token);
Task SendDeleteConfirmationAsync(string email);
@ -55,9 +52,7 @@ public interface IUserService
Task CancelPremiumAsync(User user, bool? endOfPeriod = null);
Task ReinstatePremiumAsync(User user);
Task EnablePremiumAsync(Guid userId, DateTime? expirationDate);
Task EnablePremiumAsync(User user, DateTime? expirationDate);
Task DisablePremiumAsync(Guid userId, DateTime? expirationDate);
Task DisablePremiumAsync(User user, DateTime? expirationDate);
Task UpdatePremiumExpirationAsync(Guid userId, DateTime? expirationDate);
Task<UserLicense> GenerateLicenseAsync(User user, SubscriptionInfo subscriptionInfo = null,
int? version = null);
@ -91,9 +86,26 @@ public interface IUserService
void SetTwoFactorProvider(User user, TwoFactorProviderType type, bool setEnabled = true);
[Obsolete("To be removed when the feature flag pm-17128-recovery-code-login is removed PM-18175.")]
Task<bool> RecoverTwoFactorAsync(string email, string masterPassword, string recoveryCode);
/// <summary>
/// Returns true if the user is a legacy user. Legacy users use their master key as their encryption key.
/// We force these users to the web to migrate their encryption scheme.
/// This method is used by the TwoFactorAuthenticationValidator to recover two
/// factor for a user. This allows users to be logged in after a successful recovery
/// attempt.
///
/// This method logs the event, sends an email to the user, and removes two factor
/// providers on the user account. This means that a user will have to accomplish
/// new device verification on their account on new logins, if it is enabled for their user.
/// </summary>
/// <param name="recoveryCode">recovery code associated with the user logging in</param>
/// <param name="user">The user to refresh the 2FA and Recovery Code on.</param>
/// <returns>true if the recovery code is valid; false otherwise</returns>
Task<bool> RecoverTwoFactorAsync(User user, string recoveryCode);
/// <summary>
/// Returns true if the user is a legacy user. Legacy users use their master key as their
/// encryption key. We force these users to the web to migrate their encryption scheme.
/// </summary>
Task<bool> IsLegacyUser(string userId);
@ -101,7 +113,8 @@ public interface IUserService
/// Indicates if the user is managed by any organization.
/// </summary>
/// <remarks>
/// A user is considered managed by an organization if their email domain matches one of the verified domains of that organization, and the user is a member of it.
/// A user is considered managed by an organization if their email domain matches one of the
/// verified domains of that organization, and the user is a member of it.
/// The organization must be enabled and able to have verified domains.
/// </remarks>
/// <returns>

View File

@ -1852,7 +1852,6 @@ public class StripePaymentService : IPaymentService
Enabled = true,
},
Currency = "usd",
Discounts = new List<InvoiceDiscountOptions>(),
SubscriptionDetails = new InvoiceSubscriptionDetailsOptions
{
Items =
@ -1903,29 +1902,23 @@ public class StripePaymentService : IPaymentService
];
}
if (gatewayCustomerId != null)
if (!string.IsNullOrWhiteSpace(gatewayCustomerId))
{
var gatewayCustomer = await _stripeAdapter.CustomerGetAsync(gatewayCustomerId);
if (gatewayCustomer.Discount != null)
{
options.Discounts.Add(new InvoiceDiscountOptions
{
Discount = gatewayCustomer.Discount.Id
});
options.Coupon = gatewayCustomer.Discount.Coupon.Id;
}
}
if (gatewaySubscriptionId != null)
if (!string.IsNullOrWhiteSpace(gatewaySubscriptionId))
{
var gatewaySubscription = await _stripeAdapter.SubscriptionGetAsync(gatewaySubscriptionId);
if (gatewaySubscription?.Discount != null)
{
var gatewaySubscription = await _stripeAdapter.SubscriptionGetAsync(gatewaySubscriptionId);
if (gatewaySubscription?.Discount != null)
{
options.Discounts.Add(new InvoiceDiscountOptions
{
Discount = gatewaySubscription.Discount.Id
});
}
options.Coupon ??= gatewaySubscription.Discount.Coupon.Id;
}
}
@ -1976,7 +1969,6 @@ public class StripePaymentService : IPaymentService
Enabled = true,
},
Currency = "usd",
Discounts = new List<InvoiceDiscountOptions>(),
SubscriptionDetails = new InvoiceSubscriptionDetailsOptions
{
Items =
@ -2069,7 +2061,7 @@ public class StripePaymentService : IPaymentService
if (gatewayCustomer.Discount != null)
{
options.Discounts.Add(new InvoiceDiscountOptions { Discount = gatewayCustomer.Discount.Id });
options.Coupon = gatewayCustomer.Discount.Coupon.Id;
}
}
@ -2079,10 +2071,7 @@ public class StripePaymentService : IPaymentService
if (gatewaySubscription?.Discount != null)
{
options.Discounts.Add(new InvoiceDiscountOptions
{
Discount = gatewaySubscription.Discount.Id
});
options.Coupon ??= gatewaySubscription.Discount.Coupon.Id;
}
}

View File

@ -315,7 +315,7 @@ public class UserService : UserManager<User>, IUserService, IDisposable
return;
}
var token = await base.GenerateUserTokenAsync(user, TokenOptions.DefaultProvider, "DeleteAccount");
var token = await GenerateUserTokenAsync(user, TokenOptions.DefaultProvider, "DeleteAccount");
await _mailService.SendVerifyDeleteEmailAsync(user.Email, user.Id, token);
}
@ -868,6 +868,10 @@ public class UserService : UserManager<User>, IUserService, IDisposable
}
}
/// <summary>
/// To be removed when the feature flag pm-17128-recovery-code-login is removed PM-18175.
/// </summary>
[Obsolete("Two Factor recovery is handled in the TwoFactorAuthenticationValidator.")]
public async Task<bool> RecoverTwoFactorAsync(string email, string secret, string recoveryCode)
{
var user = await _userRepository.GetByEmailAsync(email);
@ -897,6 +901,25 @@ public class UserService : UserManager<User>, IUserService, IDisposable
return true;
}
public async Task<bool> RecoverTwoFactorAsync(User user, string recoveryCode)
{
if (!CoreHelpers.FixedTimeEquals(
user.TwoFactorRecoveryCode,
recoveryCode.Replace(" ", string.Empty).Trim().ToLower()))
{
return false;
}
user.TwoFactorProviders = null;
user.TwoFactorRecoveryCode = CoreHelpers.SecureRandomString(32, upper: false, special: false);
await SaveUserAsync(user);
await _mailService.SendRecoverTwoFactorEmail(user.Email, DateTime.UtcNow, _currentContext.IpAddress);
await _eventService.LogUserEventAsync(user.Id, EventType.User_Recovered2fa);
await CheckPoliciesOnTwoFactorRemovalAsync(user);
return true;
}
public async Task<Tuple<bool, string>> SignUpPremiumAsync(User user, string paymentToken,
PaymentMethodType paymentMethodType, short additionalStorageGb, UserLicense license,
TaxInfo taxInfo)
@ -1081,7 +1104,7 @@ public class UserService : UserManager<User>, IUserService, IDisposable
await EnablePremiumAsync(user, expirationDate);
}
public async Task EnablePremiumAsync(User user, DateTime? expirationDate)
private async Task EnablePremiumAsync(User user, DateTime? expirationDate)
{
if (user != null && !user.Premium && user.Gateway.HasValue)
{
@ -1098,7 +1121,7 @@ public class UserService : UserManager<User>, IUserService, IDisposable
await DisablePremiumAsync(user, expirationDate);
}
public async Task DisablePremiumAsync(User user, DateTime? expirationDate)
private async Task DisablePremiumAsync(User user, DateTime? expirationDate)
{
if (user != null && user.Premium)
{

View File

@ -1,4 +1,5 @@
using System.Globalization;
using Bit.Core.AdminConsole.Services.Implementations;
using Bit.Core.Context;
using Bit.Core.IdentityServer;
using Bit.Core.Services;
@ -63,11 +64,29 @@ public class Startup
services.AddScoped<IEventService, EventService>();
if (!globalSettings.SelfHosted && CoreHelpers.SettingHasValue(globalSettings.Events.ConnectionString))
{
services.AddSingleton<IEventWriteService, AzureQueueEventWriteService>();
if (CoreHelpers.SettingHasValue(globalSettings.EventLogging.AzureServiceBus.ConnectionString) &&
CoreHelpers.SettingHasValue(globalSettings.EventLogging.AzureServiceBus.TopicName))
{
services.AddSingleton<IEventWriteService, AzureServiceBusEventWriteService>();
}
else
{
services.AddSingleton<IEventWriteService, AzureQueueEventWriteService>();
}
}
else
{
services.AddSingleton<IEventWriteService, RepositoryEventWriteService>();
if (CoreHelpers.SettingHasValue(globalSettings.EventLogging.RabbitMq.HostName) &&
CoreHelpers.SettingHasValue(globalSettings.EventLogging.RabbitMq.Username) &&
CoreHelpers.SettingHasValue(globalSettings.EventLogging.RabbitMq.Password) &&
CoreHelpers.SettingHasValue(globalSettings.EventLogging.RabbitMq.ExchangeName))
{
services.AddSingleton<IEventWriteService, RabbitMqEventWriteService>();
}
else
{
services.AddSingleton<IEventWriteService, RepositoryEventWriteService>();
}
}
services.AddOptionality();

View File

@ -77,7 +77,7 @@ public abstract class BaseRequestValidator<T> where T : class
protected async Task ValidateAsync(T context, ValidatedTokenRequest request,
CustomValidatorRequestContext validatorContext)
{
// 1. we need to check if the user is a bot and if their master password hash is correct
// 1. We need to check if the user is a bot and if their master password hash is correct.
var isBot = validatorContext.CaptchaResponse?.IsBot ?? false;
var valid = await ValidateContextAsync(context, validatorContext);
var user = validatorContext.User;
@ -99,7 +99,7 @@ public abstract class BaseRequestValidator<T> where T : class
return;
}
// 2. Does this user belong to an organization that requires SSO
// 2. Decide if this user belongs to an organization that requires SSO.
validatorContext.SsoRequired = await RequireSsoLoginAsync(user, request.GrantType);
if (validatorContext.SsoRequired)
{
@ -111,17 +111,22 @@ public abstract class BaseRequestValidator<T> where T : class
return;
}
// 3. Check if 2FA is required
(validatorContext.TwoFactorRequired, var twoFactorOrganization) = await _twoFactorAuthenticationValidator.RequiresTwoFactorAsync(user, request);
// This flag is used to determine if the user wants a rememberMe token sent when authentication is successful
// 3. Check if 2FA is required.
(validatorContext.TwoFactorRequired, var twoFactorOrganization) =
await _twoFactorAuthenticationValidator.RequiresTwoFactorAsync(user, request);
// This flag is used to determine if the user wants a rememberMe token sent when
// authentication is successful.
var returnRememberMeToken = false;
if (validatorContext.TwoFactorRequired)
{
var twoFactorToken = request.Raw["TwoFactorToken"]?.ToString();
var twoFactorProvider = request.Raw["TwoFactorProvider"]?.ToString();
var twoFactorToken = request.Raw["TwoFactorToken"];
var twoFactorProvider = request.Raw["TwoFactorProvider"];
var validTwoFactorRequest = !string.IsNullOrWhiteSpace(twoFactorToken) &&
!string.IsNullOrWhiteSpace(twoFactorProvider);
// response for 2FA required and not provided state
// 3a. Response for 2FA required and not provided state.
if (!validTwoFactorRequest ||
!Enum.TryParse(twoFactorProvider, out TwoFactorProviderType twoFactorProviderType))
{
@ -133,26 +138,27 @@ public abstract class BaseRequestValidator<T> where T : class
return;
}
// Include Master Password Policy in 2FA response
resultDict.Add("MasterPasswordPolicy", await GetMasterPasswordPolicy(user));
// Include Master Password Policy in 2FA response.
resultDict.Add("MasterPasswordPolicy", await GetMasterPasswordPolicyAsync(user));
SetTwoFactorResult(context, resultDict);
return;
}
var twoFactorTokenValid = await _twoFactorAuthenticationValidator
.VerifyTwoFactor(user, twoFactorOrganization, twoFactorProviderType, twoFactorToken);
var twoFactorTokenValid =
await _twoFactorAuthenticationValidator
.VerifyTwoFactorAsync(user, twoFactorOrganization, twoFactorProviderType, twoFactorToken);
// response for 2FA required but request is not valid or remember token expired state
// 3b. Response for 2FA required but request is not valid or remember token expired state.
if (!twoFactorTokenValid)
{
// The remember me token has expired
// The remember me token has expired.
if (twoFactorProviderType == TwoFactorProviderType.Remember)
{
var resultDict = await _twoFactorAuthenticationValidator
.BuildTwoFactorResultAsync(user, twoFactorOrganization);
// Include Master Password Policy in 2FA response
resultDict.Add("MasterPasswordPolicy", await GetMasterPasswordPolicy(user));
resultDict.Add("MasterPasswordPolicy", await GetMasterPasswordPolicyAsync(user));
SetTwoFactorResult(context, resultDict);
}
else
@ -163,17 +169,19 @@ public abstract class BaseRequestValidator<T> where T : class
return;
}
// When the two factor authentication is successful, we can check if the user wants a rememberMe token
var twoFactorRemember = request.Raw["TwoFactorRemember"]?.ToString() == "1";
if (twoFactorRemember // Check if the user wants a rememberMe token
&& twoFactorTokenValid // Make sure two factor authentication was successful
&& twoFactorProviderType != TwoFactorProviderType.Remember) // if the two factor auth was rememberMe do not send another token
// 3c. When the 2FA authentication is successful, we can check if the user wants a
// rememberMe token.
var twoFactorRemember = request.Raw["TwoFactorRemember"] == "1";
// Check if the user wants a rememberMe token.
if (twoFactorRemember
// if the 2FA auth was rememberMe do not send another token.
&& twoFactorProviderType != TwoFactorProviderType.Remember)
{
returnRememberMeToken = true;
}
}
// 4. Check if the user is logging in from a new device
// 4. Check if the user is logging in from a new device.
var deviceValid = await _deviceValidator.ValidateRequestDeviceAsync(request, validatorContext);
if (!deviceValid)
{
@ -182,7 +190,7 @@ public abstract class BaseRequestValidator<T> where T : class
return;
}
// 5. Force legacy users to the web for migration
// 5. Force legacy users to the web for migration.
if (UserService.IsLegacyUser(user) && request.ClientId != "web")
{
await FailAuthForLegacyUserAsync(user, context);
@ -224,7 +232,7 @@ public abstract class BaseRequestValidator<T> where T : class
customResponse.Add("Key", user.Key);
}
customResponse.Add("MasterPasswordPolicy", await GetMasterPasswordPolicy(user));
customResponse.Add("MasterPasswordPolicy", await GetMasterPasswordPolicyAsync(user));
customResponse.Add("ForcePasswordReset", user.ForcePasswordReset);
customResponse.Add("ResetMasterPassword", string.IsNullOrWhiteSpace(user.MasterPassword));
customResponse.Add("Kdf", (byte)user.Kdf);
@ -403,7 +411,7 @@ public abstract class BaseRequestValidator<T> where T : class
return unknownDevice && failedLoginCeiling > 0 && failedLoginCount == failedLoginCeiling;
}
private async Task<MasterPasswordPolicyResponseModel> GetMasterPasswordPolicy(User user)
private async Task<MasterPasswordPolicyResponseModel> GetMasterPasswordPolicyAsync(User user)
{
// Check current context/cache to see if user is in any organizations, avoids extra DB call if not
var orgs = (await CurrentContext.OrganizationMembershipAsync(_organizationUserRepository, user.Id))

View File

@ -1,4 +1,5 @@
using System.Text.Json;
using Bit.Core;
using Bit.Core.AdminConsole.Entities;
using Bit.Core.Auth.Enums;
using Bit.Core.Auth.Identity.TokenProviders;
@ -44,7 +45,7 @@ public interface ITwoFactorAuthenticationValidator
/// <param name="twoFactorProviderType">Two Factor Provider to use to verify the token</param>
/// <param name="token">secret passed from the user and consumed by the two-factor provider's verify method</param>
/// <returns>boolean</returns>
Task<bool> VerifyTwoFactor(User user, Organization organization, TwoFactorProviderType twoFactorProviderType, string token);
Task<bool> VerifyTwoFactorAsync(User user, Organization organization, TwoFactorProviderType twoFactorProviderType, string token);
}
public class TwoFactorAuthenticationValidator(
@ -139,7 +140,7 @@ public class TwoFactorAuthenticationValidator(
return twoFactorResultDict;
}
public async Task<bool> VerifyTwoFactor(
public async Task<bool> VerifyTwoFactorAsync(
User user,
Organization organization,
TwoFactorProviderType type,
@ -154,24 +155,39 @@ public class TwoFactorAuthenticationValidator(
return false;
}
switch (type)
if (_featureService.IsEnabled(FeatureFlagKeys.RecoveryCodeLogin))
{
case TwoFactorProviderType.Authenticator:
case TwoFactorProviderType.Email:
case TwoFactorProviderType.Duo:
case TwoFactorProviderType.YubiKey:
case TwoFactorProviderType.WebAuthn:
case TwoFactorProviderType.Remember:
if (type != TwoFactorProviderType.Remember &&
!await _userService.TwoFactorProviderIsEnabledAsync(type, user))
{
return false;
}
return await _userManager.VerifyTwoFactorTokenAsync(user,
CoreHelpers.CustomProviderName(type), token);
default:
return false;
if (type is TwoFactorProviderType.RecoveryCode)
{
return await _userService.RecoverTwoFactorAsync(user, token);
}
}
// These cases we want to always return false, U2f is deprecated and OrganizationDuo
// uses a different flow than the other two factor providers, it follows the same
// structure of a UserTokenProvider but has it's logic ran outside the usual token
// provider flow. See IOrganizationDuoUniversalTokenProvider.cs
if (type is TwoFactorProviderType.U2f or TwoFactorProviderType.OrganizationDuo)
{
return false;
}
// Now we are concerning the rest of the Two Factor Provider Types
// The intent of this check is to make sure that the user is using a 2FA provider that
// is enabled and allowed by their premium status. The exception for Remember
// is because it is a "special" 2FA type that isn't ever explicitly
// enabled by a user, so we can't check the user's 2FA providers to see if they're
// enabled. We just have to check if the token is valid.
if (type != TwoFactorProviderType.Remember &&
!await _userService.TwoFactorProviderIsEnabledAsync(type, user))
{
return false;
}
// Finally, verify the token based on the provider type.
return await _userManager.VerifyTwoFactorTokenAsync(
user, CoreHelpers.CustomProviderName(type), token);
}
private async Task<List<KeyValuePair<TwoFactorProviderType, TwoFactorProvider>>> GetEnabledTwoFactorProvidersAsync(

View File

@ -109,9 +109,13 @@ public class GroupRepository : Repository<Group, Guid>, IGroupRepository
}
}
public async Task<ICollection<Guid>> GetManyUserIdsByIdAsync(Guid id)
public async Task<ICollection<Guid>> GetManyUserIdsByIdAsync(Guid id, bool useReadOnlyReplica = false)
{
using (var connection = new SqlConnection(ConnectionString))
var connectionString = useReadOnlyReplica
? ReadOnlyConnectionString
: ConnectionString;
using (var connection = new SqlConnection(connectionString))
{
var results = await connection.QueryAsync<Guid>(
$"[{Schema}].[GroupUser_ReadOrganizationUserIdsByGroupId]",
@ -186,6 +190,17 @@ public class GroupRepository : Repository<Group, Guid>, IGroupRepository
}
}
public async Task AddGroupUsersByIdAsync(Guid groupId, IEnumerable<Guid> organizationUserIds)
{
using (var connection = new SqlConnection(ConnectionString))
{
var results = await connection.ExecuteAsync(
"[dbo].[GroupUser_AddUsers]",
new { GroupId = groupId, OrganizationUserIds = organizationUserIds.ToGuidIdArrayTVP() },
commandType: CommandType.StoredProcedure);
}
}
public async Task DeleteManyAsync(IEnumerable<Guid> groupIds)
{
using (var connection = new SqlConnection(ConnectionString))

View File

@ -1,6 +1,7 @@
using System.Data;
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Enums;
using Bit.Core.AdminConsole.Models.Data.Organizations.Policies;
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Settings;
using Bit.Infrastructure.Dapper.Repositories;
@ -59,4 +60,17 @@ public class PolicyRepository : Repository<Policy, Guid>, IPolicyRepository
return results.ToList();
}
}
public async Task<IEnumerable<PolicyDetails>> GetPolicyDetailsByUserId(Guid userId)
{
using (var connection = new SqlConnection(ConnectionString))
{
var results = await connection.QueryAsync<PolicyDetails>(
$"[{Schema}].[PolicyDetails_ReadByUserId]",
new { UserId = userId },
commandType: CommandType.StoredProcedure);
return results.ToList();
}
}
}

View File

@ -163,8 +163,10 @@ public class GroupRepository : Repository<AdminConsoleEntities.Group, Group, Gui
}
}
public async Task<ICollection<Guid>> GetManyUserIdsByIdAsync(Guid id)
public async Task<ICollection<Guid>> GetManyUserIdsByIdAsync(Guid id, bool useReadOnlyReplica = false)
{
// EF is only used for self-hosted so read-only replica parameter is ignored
using (var scope = ServiceScopeFactory.CreateScope())
{
var dbContext = GetDatabaseContext(scope);
@ -255,6 +257,29 @@ public class GroupRepository : Repository<AdminConsoleEntities.Group, Group, Gui
}
}
public async Task AddGroupUsersByIdAsync(Guid groupId, IEnumerable<Guid> organizationUserIds)
{
using (var scope = ServiceScopeFactory.CreateScope())
{
var dbContext = GetDatabaseContext(scope);
var orgId = (await dbContext.Groups.FindAsync(groupId)).OrganizationId;
var insert = from ou in dbContext.OrganizationUsers
where organizationUserIds.Contains(ou.Id) &&
ou.OrganizationId == orgId &&
!dbContext.GroupUsers.Any(gu => gu.GroupId == groupId && ou.Id == gu.OrganizationUserId)
select new GroupUser
{
GroupId = groupId,
OrganizationUserId = ou.Id,
};
await dbContext.AddRangeAsync(insert);
await dbContext.SaveChangesAsync();
await dbContext.UserBumpAccountRevisionDateByOrganizationIdAsync(orgId);
await dbContext.SaveChangesAsync();
}
}
public async Task DeleteManyAsync(IEnumerable<Guid> groupIds)
{
using (var scope = ServiceScopeFactory.CreateScope())

View File

@ -1,6 +1,8 @@
using AutoMapper;
using Bit.Core.AdminConsole.Enums;
using Bit.Core.AdminConsole.Models.Data.Organizations.Policies;
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Enums;
using Bit.Infrastructure.EntityFramework.AdminConsole.Models;
using Bit.Infrastructure.EntityFramework.AdminConsole.Repositories.Queries;
using Bit.Infrastructure.EntityFramework.Repositories;
@ -50,4 +52,43 @@ public class PolicyRepository : Repository<AdminConsoleEntities.Policy, Policy,
return Mapper.Map<List<AdminConsoleEntities.Policy>>(results);
}
}
public async Task<IEnumerable<PolicyDetails>> GetPolicyDetailsByUserId(Guid userId)
{
using var scope = ServiceScopeFactory.CreateScope();
var dbContext = GetDatabaseContext(scope);
var providerOrganizations = from pu in dbContext.ProviderUsers
where pu.UserId == userId
join po in dbContext.ProviderOrganizations
on pu.ProviderId equals po.ProviderId
select po;
var query = from p in dbContext.Policies
join ou in dbContext.OrganizationUsers
on p.OrganizationId equals ou.OrganizationId
join o in dbContext.Organizations
on p.OrganizationId equals o.Id
where
p.Enabled &&
o.Enabled &&
o.UsePolicies &&
(
(ou.Status != OrganizationUserStatusType.Invited && ou.UserId == userId) ||
// Invited orgUsers do not have a UserId associated with them, so we have to match up their email
(ou.Status == OrganizationUserStatusType.Invited && ou.Email == dbContext.Users.Find(userId).Email)
)
select new PolicyDetails
{
OrganizationUserId = ou.Id,
OrganizationId = p.OrganizationId,
PolicyType = p.Type,
PolicyData = p.Data,
OrganizationUserType = ou.Type,
OrganizationUserStatus = ou.Status,
OrganizationUserPermissionsData = ou.Permissions,
IsProvider = providerOrganizations.Any(po => po.OrganizationId == p.OrganizationId)
};
return await query.ToListAsync();
}
}

View File

@ -35,6 +35,7 @@ public class ProviderOrganizationOrganizationDetailsReadByProviderIdQuery : IQue
OccupiedSeats = x.o.OrganizationUsers.Count(ou => ou.Status >= 0),
Seats = x.o.Seats,
Plan = x.o.Plan,
PlanType = x.o.PlanType,
Status = x.o.Status
});
}

View File

@ -104,6 +104,7 @@ public static class HubHelpers
.SendAsync("ReceiveMessage", organizationCollectionSettingsChangedNotification, cancellationToken);
break;
case PushType.SyncNotification:
case PushType.SyncNotificationStatus:
var syncNotification =
JsonSerializer.Deserialize<PushNotificationData<NotificationPushNotification>>(
notificationJson, _deserializerOptions);

View File

@ -0,0 +1,39 @@
CREATE PROCEDURE [dbo].[GroupUser_AddUsers]
@GroupId UNIQUEIDENTIFIER,
@OrganizationUserIds AS [dbo].[GuidIdArray] READONLY
AS
BEGIN
SET NOCOUNT ON
DECLARE @OrgId UNIQUEIDENTIFIER = (
SELECT TOP 1
[OrganizationId]
FROM
[dbo].[Group]
WHERE
[Id] = @GroupId
)
-- Insert
INSERT INTO
[dbo].[GroupUser] (GroupId, OrganizationUserId)
SELECT DISTINCT
@GroupId,
[Source].[Id]
FROM
@OrganizationUserIds AS [Source]
INNER JOIN
[dbo].[OrganizationUser] OU ON [Source].[Id] = OU.[Id] AND OU.[OrganizationId] = @OrgId
WHERE
NOT EXISTS (
SELECT
1
FROM
[dbo].[GroupUser]
WHERE
[GroupId] = @GroupId
AND [OrganizationUserId] = [Source].[Id]
)
EXEC [dbo].[User_BumpAccountRevisionDateByOrganizationId] @OrgId
END

View File

@ -0,0 +1,43 @@
CREATE PROCEDURE [dbo].[PolicyDetails_ReadByUserId]
@UserId UNIQUEIDENTIFIER
AS
BEGIN
SET NOCOUNT ON
SELECT
OU.[Id] AS OrganizationUserId,
P.[OrganizationId],
P.[Type] AS PolicyType,
P.[Data] AS PolicyData,
OU.[Type] AS OrganizationUserType,
OU.[Status] AS OrganizationUserStatus,
OU.[Permissions] AS OrganizationUserPermissionsData,
CASE WHEN EXISTS (
SELECT 1
FROM [dbo].[ProviderUserView] PU
INNER JOIN [dbo].[ProviderOrganizationView] PO ON PO.[ProviderId] = PU.[ProviderId]
WHERE PU.[UserId] = OU.[UserId] AND PO.[OrganizationId] = P.[OrganizationId]
) THEN 1 ELSE 0 END AS IsProvider
FROM [dbo].[PolicyView] P
INNER JOIN [dbo].[OrganizationUserView] OU
ON P.[OrganizationId] = OU.[OrganizationId]
INNER JOIN [dbo].[OrganizationView] O
ON P.[OrganizationId] = O.[Id]
WHERE
P.Enabled = 1
AND O.Enabled = 1
AND O.UsePolicies = 1
AND (
-- OrgUsers who have accepted their invite and are linked to a UserId
-- (Note: this excludes "invited but revoked" users who don't have an OU.UserId yet,
-- but those users will go through policy enforcement later as part of accepting their invite after being restored.
-- This is an intentionally unhandled edge case for now.)
(OU.[Status] != 0 AND OU.[UserId] = @UserId)
-- 'Invited' OrgUsers are not linked to a UserId yet, so we have to look up their email
OR EXISTS (
SELECT 1
FROM [dbo].[UserView] U
WHERE U.[Id] = @UserId AND OU.[Email] = U.[Email] AND OU.[Status] = 0
)
)
END

View File

@ -13,6 +13,7 @@ SELECT
(SELECT COUNT(1) FROM [dbo].[OrganizationUser] OU WHERE OU.OrganizationId = PO.OrganizationId AND OU.Status >= 0) OccupiedSeats,
O.[Seats],
O.[Plan],
O.[PlanType],
O.[Status]
FROM
[dbo].[ProviderOrganization] PO

View File

@ -0,0 +1,147 @@
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.OrganizationFeatures.Organizations;
using Bit.Core.Enums;
using Bit.Core.Repositories;
using Bit.Core.Services;
using Bit.Test.Common.AutoFixture;
using Bit.Test.Common.AutoFixture.Attributes;
using NSubstitute;
using Xunit;
namespace Bit.Core.Test.AdminConsole.OrganizationFeatures.Organizations;
[SutProviderCustomize]
public class OrganizationEnableCommandTests
{
[Theory, BitAutoData]
public async Task EnableAsync_WhenOrganizationDoesNotExist_DoesNothing(
Guid organizationId,
SutProvider<OrganizationEnableCommand> sutProvider)
{
sutProvider.GetDependency<IOrganizationRepository>()
.GetByIdAsync(organizationId)
.Returns((Organization)null);
await sutProvider.Sut.EnableAsync(organizationId);
await sutProvider.GetDependency<IOrganizationRepository>()
.DidNotReceive()
.ReplaceAsync(Arg.Any<Organization>());
await sutProvider.GetDependency<IApplicationCacheService>()
.DidNotReceive()
.UpsertOrganizationAbilityAsync(Arg.Any<Organization>());
}
[Theory, BitAutoData]
public async Task EnableAsync_WhenOrganizationAlreadyEnabled_DoesNothing(
Organization organization,
SutProvider<OrganizationEnableCommand> sutProvider)
{
organization.Enabled = true;
sutProvider.GetDependency<IOrganizationRepository>()
.GetByIdAsync(organization.Id)
.Returns(organization);
await sutProvider.Sut.EnableAsync(organization.Id);
await sutProvider.GetDependency<IOrganizationRepository>()
.DidNotReceive()
.ReplaceAsync(Arg.Any<Organization>());
await sutProvider.GetDependency<IApplicationCacheService>()
.DidNotReceive()
.UpsertOrganizationAbilityAsync(Arg.Any<Organization>());
}
[Theory, BitAutoData]
public async Task EnableAsync_WhenOrganizationDisabled_EnablesAndSaves(
Organization organization,
SutProvider<OrganizationEnableCommand> sutProvider)
{
organization.Enabled = false;
sutProvider.GetDependency<IOrganizationRepository>()
.GetByIdAsync(organization.Id)
.Returns(organization);
await sutProvider.Sut.EnableAsync(organization.Id);
Assert.True(organization.Enabled);
await sutProvider.GetDependency<IOrganizationRepository>()
.Received(1)
.ReplaceAsync(organization);
await sutProvider.GetDependency<IApplicationCacheService>()
.Received(1)
.UpsertOrganizationAbilityAsync(organization);
}
[Theory, BitAutoData]
public async Task EnableAsync_WithExpiration_WhenOrganizationHasNoGateway_DoesNothing(
Organization organization,
DateTime expirationDate,
SutProvider<OrganizationEnableCommand> sutProvider)
{
organization.Enabled = false;
organization.Gateway = null;
sutProvider.GetDependency<IOrganizationRepository>()
.GetByIdAsync(organization.Id)
.Returns(organization);
await sutProvider.Sut.EnableAsync(organization.Id, expirationDate);
await sutProvider.GetDependency<IOrganizationRepository>()
.DidNotReceive()
.ReplaceAsync(Arg.Any<Organization>());
await sutProvider.GetDependency<IApplicationCacheService>()
.DidNotReceive()
.UpsertOrganizationAbilityAsync(Arg.Any<Organization>());
}
[Theory, BitAutoData]
public async Task EnableAsync_WithExpiration_WhenValid_EnablesAndSetsExpiration(
Organization organization,
DateTime expirationDate,
SutProvider<OrganizationEnableCommand> sutProvider)
{
organization.Enabled = false;
organization.Gateway = GatewayType.Stripe;
organization.RevisionDate = DateTime.UtcNow.AddDays(-1);
var originalRevisionDate = organization.RevisionDate;
sutProvider.GetDependency<IOrganizationRepository>()
.GetByIdAsync(organization.Id)
.Returns(organization);
await sutProvider.Sut.EnableAsync(organization.Id, expirationDate);
Assert.True(organization.Enabled);
Assert.Equal(expirationDate, organization.ExpirationDate);
Assert.True(organization.RevisionDate > originalRevisionDate);
await sutProvider.GetDependency<IOrganizationRepository>()
.Received(1)
.ReplaceAsync(organization);
await sutProvider.GetDependency<IApplicationCacheService>()
.Received(1)
.UpsertOrganizationAbilityAsync(organization);
}
[Theory, BitAutoData]
public async Task EnableAsync_WithoutExpiration_DoesNotUpdateRevisionDate(
Organization organization,
SutProvider<OrganizationEnableCommand> sutProvider)
{
organization.Enabled = false;
var originalRevisionDate = organization.RevisionDate;
sutProvider.GetDependency<IOrganizationRepository>()
.GetByIdAsync(organization.Id)
.Returns(organization);
await sutProvider.Sut.EnableAsync(organization.Id);
Assert.True(organization.Enabled);
Assert.Equal(originalRevisionDate, organization.RevisionDate);
await sutProvider.GetDependency<IOrganizationRepository>()
.Received(1)
.ReplaceAsync(organization);
await sutProvider.GetDependency<IApplicationCacheService>()
.Received(1)
.UpsertOrganizationAbilityAsync(organization);
}
}

View File

@ -0,0 +1,60 @@
using Bit.Core.AdminConsole.Models.Data.Organizations.Policies;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Implementations;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements;
using Bit.Core.AdminConsole.Repositories;
using Bit.Test.Common.AutoFixture.Attributes;
using NSubstitute;
using Xunit;
namespace Bit.Core.Test.AdminConsole.OrganizationFeatures.Policies;
[SutProviderCustomize]
public class PolicyRequirementQueryTests
{
/// <summary>
/// Tests that the query correctly registers, retrieves and instantiates arbitrary IPolicyRequirements
/// according to their provided CreateRequirement delegate.
/// </summary>
[Theory, BitAutoData]
public async Task GetAsync_Works(Guid userId, Guid organizationId)
{
var policyRepository = Substitute.For<IPolicyRepository>();
var factories = new List<RequirementFactory<IPolicyRequirement>>
{
// In prod this cast is handled when the CreateRequirement delegate is registered in DI
(RequirementFactory<TestPolicyRequirement>)TestPolicyRequirement.Create
};
var sut = new PolicyRequirementQuery(policyRepository, factories);
policyRepository.GetPolicyDetailsByUserId(userId).Returns([
new PolicyDetails
{
OrganizationId = organizationId
}
]);
var requirement = await sut.GetAsync<TestPolicyRequirement>(userId);
Assert.Equal(organizationId, requirement.OrganizationId);
}
[Theory, BitAutoData]
public async Task GetAsync_ThrowsIfNoRequirementRegistered(Guid userId)
{
var policyRepository = Substitute.For<IPolicyRepository>();
var sut = new PolicyRequirementQuery(policyRepository, []);
var exception = await Assert.ThrowsAsync<NotImplementedException>(()
=> sut.GetAsync<TestPolicyRequirement>(userId));
Assert.Contains("No Policy Requirement found", exception.Message);
}
/// <summary>
/// Intentionally simplified PolicyRequirement that just holds the Policy.OrganizationId for us to assert against.
/// </summary>
private class TestPolicyRequirement : IPolicyRequirement
{
public Guid OrganizationId { get; init; }
public static TestPolicyRequirement Create(IEnumerable<PolicyDetails> policyDetails)
=> new() { OrganizationId = policyDetails.Single().OrganizationId };
}
}

View File

@ -21,4 +21,15 @@ public class EventRepositoryHandlerTests
Arg.Is(AssertHelper.AssertPropertyEqual<IEvent>(eventMessage))
);
}
[Theory, BitAutoData]
public async Task HandleManyEventAsync_WritesEventsToIEventWriteService(
IEnumerable<EventMessage> eventMessages,
SutProvider<EventRepositoryHandler> sutProvider)
{
await sutProvider.Sut.HandleManyEventsAsync(eventMessages);
await sutProvider.GetDependency<IEventWriteService>().Received(1).CreateManyAsync(
Arg.Is(AssertHelper.AssertPropertyEqual<IEvent>(eventMessages))
);
}
}

View File

@ -44,10 +44,9 @@ public class WebhookEventHandlerTests
}
[Theory, BitAutoData]
public async Task HandleEventAsync_PostsEventsToUrl(EventMessage eventMessage)
public async Task HandleEventAsync_PostsEventToUrl(EventMessage eventMessage)
{
var sutProvider = GetSutProvider();
var content = JsonContent.Create(eventMessage);
await sutProvider.Sut.HandleEventAsync(eventMessage);
sutProvider.GetDependency<IHttpClientFactory>().Received(1).CreateClient(
@ -63,4 +62,24 @@ public class WebhookEventHandlerTests
Assert.Equal(_webhookUrl, request.RequestUri.ToString());
AssertHelper.AssertPropertyEqual(eventMessage, returned, new[] { "IdempotencyId" });
}
[Theory, BitAutoData]
public async Task HandleEventManyAsync_PostsEventsToUrl(IEnumerable<EventMessage> eventMessages)
{
var sutProvider = GetSutProvider();
await sutProvider.Sut.HandleManyEventsAsync(eventMessages);
sutProvider.GetDependency<IHttpClientFactory>().Received(1).CreateClient(
Arg.Is(AssertHelper.AssertPropertyEqual<string>(WebhookEventHandler.HttpClientName))
);
Assert.Single(_handler.CapturedRequests);
var request = _handler.CapturedRequests[0];
Assert.NotNull(request);
var returned = request.Content.ReadFromJsonAsAsyncEnumerable<EventMessage>();
Assert.Equal(HttpMethod.Post, request.Method);
Assert.Equal(_webhookUrl, request.RequestUri.ToString());
AssertHelper.AssertPropertyEqual(eventMessages, returned, new[] { "IdempotencyId" });
}
}

View File

@ -41,6 +41,12 @@ public class CreateNotificationCommandTest
Setup(sutProvider, notification, authorized: false);
await Assert.ThrowsAsync<NotFoundException>(() => sutProvider.Sut.CreateAsync(notification));
await sutProvider.GetDependency<IPushNotificationService>()
.Received(0)
.PushNotificationAsync(Arg.Any<Notification>());
await sutProvider.GetDependency<IPushNotificationService>()
.Received(0)
.PushNotificationStatusAsync(Arg.Any<Notification>(), Arg.Any<NotificationStatus>());
}
[Theory]
@ -59,5 +65,8 @@ public class CreateNotificationCommandTest
await sutProvider.GetDependency<IPushNotificationService>()
.Received(1)
.PushNotificationAsync(newNotification);
await sutProvider.GetDependency<IPushNotificationService>()
.Received(0)
.PushNotificationStatusAsync(Arg.Any<Notification>(), Arg.Any<NotificationStatus>());
}
}

View File

@ -5,6 +5,7 @@ using Bit.Core.NotificationCenter.Authorization;
using Bit.Core.NotificationCenter.Commands;
using Bit.Core.NotificationCenter.Entities;
using Bit.Core.NotificationCenter.Repositories;
using Bit.Core.Platform.Push;
using Bit.Core.Test.NotificationCenter.AutoFixture;
using Bit.Test.Common.AutoFixture;
using Bit.Test.Common.AutoFixture.Attributes;
@ -50,6 +51,12 @@ public class CreateNotificationStatusCommandTest
Setup(sutProvider, notification: null, notificationStatus, true, true);
await Assert.ThrowsAsync<NotFoundException>(() => sutProvider.Sut.CreateAsync(notificationStatus));
await sutProvider.GetDependency<IPushNotificationService>()
.Received(0)
.PushNotificationStatusAsync(Arg.Any<Notification>(), Arg.Any<NotificationStatus>());
await sutProvider.GetDependency<IPushNotificationService>()
.Received(0)
.PushNotificationAsync(Arg.Any<Notification>());
}
[Theory]
@ -61,6 +68,12 @@ public class CreateNotificationStatusCommandTest
Setup(sutProvider, notification, notificationStatus, authorizedNotification: false, true);
await Assert.ThrowsAsync<NotFoundException>(() => sutProvider.Sut.CreateAsync(notificationStatus));
await sutProvider.GetDependency<IPushNotificationService>()
.Received(0)
.PushNotificationStatusAsync(Arg.Any<Notification>(), Arg.Any<NotificationStatus>());
await sutProvider.GetDependency<IPushNotificationService>()
.Received(0)
.PushNotificationAsync(Arg.Any<Notification>());
}
[Theory]
@ -72,6 +85,12 @@ public class CreateNotificationStatusCommandTest
Setup(sutProvider, notification, notificationStatus, true, authorizedCreate: false);
await Assert.ThrowsAsync<NotFoundException>(() => sutProvider.Sut.CreateAsync(notificationStatus));
await sutProvider.GetDependency<IPushNotificationService>()
.Received(0)
.PushNotificationStatusAsync(Arg.Any<Notification>(), Arg.Any<NotificationStatus>());
await sutProvider.GetDependency<IPushNotificationService>()
.Received(0)
.PushNotificationAsync(Arg.Any<Notification>());
}
[Theory]
@ -85,5 +104,11 @@ public class CreateNotificationStatusCommandTest
var newNotificationStatus = await sutProvider.Sut.CreateAsync(notificationStatus);
Assert.Equal(notificationStatus, newNotificationStatus);
await sutProvider.GetDependency<IPushNotificationService>()
.Received(1)
.PushNotificationStatusAsync(notification, notificationStatus);
await sutProvider.GetDependency<IPushNotificationService>()
.Received(0)
.PushNotificationAsync(Arg.Any<Notification>());
}
}

View File

@ -6,6 +6,7 @@ using Bit.Core.NotificationCenter.Authorization;
using Bit.Core.NotificationCenter.Commands;
using Bit.Core.NotificationCenter.Entities;
using Bit.Core.NotificationCenter.Repositories;
using Bit.Core.Platform.Push;
using Bit.Core.Test.NotificationCenter.AutoFixture;
using Bit.Test.Common.AutoFixture;
using Bit.Test.Common.AutoFixture.Attributes;
@ -63,6 +64,12 @@ public class MarkNotificationDeletedCommandTest
Setup(sutProvider, notificationId, userId: null, notification, notificationStatus, true, true, true);
await Assert.ThrowsAsync<NotFoundException>(() => sutProvider.Sut.MarkDeletedAsync(notificationId));
await sutProvider.GetDependency<IPushNotificationService>()
.Received(0)
.PushNotificationStatusAsync(Arg.Any<Notification>(), Arg.Any<NotificationStatus>());
await sutProvider.GetDependency<IPushNotificationService>()
.Received(0)
.PushNotificationAsync(Arg.Any<Notification>());
}
[Theory]
@ -74,6 +81,12 @@ public class MarkNotificationDeletedCommandTest
Setup(sutProvider, notificationId, userId, notification: null, notificationStatus, true, true, true);
await Assert.ThrowsAsync<NotFoundException>(() => sutProvider.Sut.MarkDeletedAsync(notificationId));
await sutProvider.GetDependency<IPushNotificationService>()
.Received(0)
.PushNotificationStatusAsync(Arg.Any<Notification>(), Arg.Any<NotificationStatus>());
await sutProvider.GetDependency<IPushNotificationService>()
.Received(0)
.PushNotificationAsync(Arg.Any<Notification>());
}
[Theory]
@ -86,6 +99,12 @@ public class MarkNotificationDeletedCommandTest
true, true);
await Assert.ThrowsAsync<NotFoundException>(() => sutProvider.Sut.MarkDeletedAsync(notificationId));
await sutProvider.GetDependency<IPushNotificationService>()
.Received(0)
.PushNotificationStatusAsync(Arg.Any<Notification>(), Arg.Any<NotificationStatus>());
await sutProvider.GetDependency<IPushNotificationService>()
.Received(0)
.PushNotificationAsync(Arg.Any<Notification>());
}
[Theory]
@ -98,6 +117,12 @@ public class MarkNotificationDeletedCommandTest
authorizedCreate: false, true);
await Assert.ThrowsAsync<NotFoundException>(() => sutProvider.Sut.MarkDeletedAsync(notificationId));
await sutProvider.GetDependency<IPushNotificationService>()
.Received(0)
.PushNotificationStatusAsync(Arg.Any<Notification>(), Arg.Any<NotificationStatus>());
await sutProvider.GetDependency<IPushNotificationService>()
.Received(0)
.PushNotificationAsync(Arg.Any<Notification>());
}
[Theory]
@ -110,6 +135,12 @@ public class MarkNotificationDeletedCommandTest
authorizedUpdate: false);
await Assert.ThrowsAsync<NotFoundException>(() => sutProvider.Sut.MarkDeletedAsync(notificationId));
await sutProvider.GetDependency<IPushNotificationService>()
.Received(0)
.PushNotificationStatusAsync(Arg.Any<Notification>(), Arg.Any<NotificationStatus>());
await sutProvider.GetDependency<IPushNotificationService>()
.Received(0)
.PushNotificationAsync(Arg.Any<Notification>());
}
[Theory]
@ -119,13 +150,25 @@ public class MarkNotificationDeletedCommandTest
Guid notificationId, Guid userId, Notification notification)
{
Setup(sutProvider, notificationId, userId, notification, notificationStatus: null, true, true, true);
var expectedNotificationStatus = new NotificationStatus
{
NotificationId = notificationId,
UserId = userId,
ReadDate = null,
DeletedDate = DateTime.UtcNow
};
await sutProvider.Sut.MarkDeletedAsync(notificationId);
await sutProvider.GetDependency<INotificationStatusRepository>().Received(1)
.CreateAsync(Arg.Is<NotificationStatus>(ns =>
ns.NotificationId == notificationId && ns.UserId == userId && !ns.ReadDate.HasValue &&
ns.DeletedDate.HasValue && DateTime.UtcNow - ns.DeletedDate.Value < TimeSpan.FromMinutes(1)));
.CreateAsync(Arg.Do<NotificationStatus>(ns => AssertNotificationStatus(expectedNotificationStatus, ns)));
await sutProvider.GetDependency<IPushNotificationService>()
.Received(1)
.PushNotificationStatusAsync(notification,
Arg.Do<NotificationStatus>(ns => AssertNotificationStatus(expectedNotificationStatus, ns)));
await sutProvider.GetDependency<IPushNotificationService>()
.Received(0)
.PushNotificationAsync(Arg.Any<Notification>());
}
[Theory]
@ -134,18 +177,30 @@ public class MarkNotificationDeletedCommandTest
SutProvider<MarkNotificationDeletedCommand> sutProvider,
Guid notificationId, Guid userId, Notification notification, NotificationStatus notificationStatus)
{
var deletedDate = notificationStatus.DeletedDate;
Setup(sutProvider, notificationId, userId, notification, notificationStatus, true, true, true);
await sutProvider.Sut.MarkDeletedAsync(notificationId);
await sutProvider.GetDependency<INotificationStatusRepository>().Received(1)
.UpdateAsync(Arg.Is<NotificationStatus>(ns =>
ns.Equals(notificationStatus) &&
ns.NotificationId == notificationStatus.NotificationId && ns.UserId == notificationStatus.UserId &&
ns.ReadDate == notificationStatus.ReadDate && ns.DeletedDate != deletedDate &&
ns.DeletedDate.HasValue &&
DateTime.UtcNow - ns.DeletedDate.Value < TimeSpan.FromMinutes(1)));
.UpdateAsync(Arg.Do<NotificationStatus>(ns => AssertNotificationStatus(notificationStatus, ns)));
await sutProvider.GetDependency<IPushNotificationService>()
.Received(1)
.PushNotificationStatusAsync(notification,
Arg.Do<NotificationStatus>(ns => AssertNotificationStatus(notificationStatus, ns)));
await sutProvider.GetDependency<IPushNotificationService>()
.Received(0)
.PushNotificationAsync(Arg.Any<Notification>());
}
private static void AssertNotificationStatus(NotificationStatus expectedNotificationStatus,
NotificationStatus? actualNotificationStatus)
{
Assert.NotNull(actualNotificationStatus);
Assert.Equal(expectedNotificationStatus.NotificationId, actualNotificationStatus.NotificationId);
Assert.Equal(expectedNotificationStatus.UserId, actualNotificationStatus.UserId);
Assert.Equal(expectedNotificationStatus.ReadDate, actualNotificationStatus.ReadDate);
Assert.NotEqual(expectedNotificationStatus.DeletedDate, actualNotificationStatus.DeletedDate);
Assert.NotNull(actualNotificationStatus.DeletedDate);
Assert.Equal(DateTime.UtcNow, actualNotificationStatus.DeletedDate.Value, TimeSpan.FromMinutes(1));
}
}

View File

@ -6,6 +6,7 @@ using Bit.Core.NotificationCenter.Authorization;
using Bit.Core.NotificationCenter.Commands;
using Bit.Core.NotificationCenter.Entities;
using Bit.Core.NotificationCenter.Repositories;
using Bit.Core.Platform.Push;
using Bit.Core.Test.NotificationCenter.AutoFixture;
using Bit.Test.Common.AutoFixture;
using Bit.Test.Common.AutoFixture.Attributes;
@ -63,6 +64,12 @@ public class MarkNotificationReadCommandTest
Setup(sutProvider, notificationId, userId: null, notification, notificationStatus, true, true, true);
await Assert.ThrowsAsync<NotFoundException>(() => sutProvider.Sut.MarkReadAsync(notificationId));
await sutProvider.GetDependency<IPushNotificationService>()
.Received(0)
.PushNotificationStatusAsync(Arg.Any<Notification>(), Arg.Any<NotificationStatus>());
await sutProvider.GetDependency<IPushNotificationService>()
.Received(0)
.PushNotificationAsync(Arg.Any<Notification>());
}
[Theory]
@ -74,6 +81,12 @@ public class MarkNotificationReadCommandTest
Setup(sutProvider, notificationId, userId, notification: null, notificationStatus, true, true, true);
await Assert.ThrowsAsync<NotFoundException>(() => sutProvider.Sut.MarkReadAsync(notificationId));
await sutProvider.GetDependency<IPushNotificationService>()
.Received(0)
.PushNotificationStatusAsync(Arg.Any<Notification>(), Arg.Any<NotificationStatus>());
await sutProvider.GetDependency<IPushNotificationService>()
.Received(0)
.PushNotificationAsync(Arg.Any<Notification>());
}
[Theory]
@ -86,6 +99,12 @@ public class MarkNotificationReadCommandTest
true, true);
await Assert.ThrowsAsync<NotFoundException>(() => sutProvider.Sut.MarkReadAsync(notificationId));
await sutProvider.GetDependency<IPushNotificationService>()
.Received(0)
.PushNotificationStatusAsync(Arg.Any<Notification>(), Arg.Any<NotificationStatus>());
await sutProvider.GetDependency<IPushNotificationService>()
.Received(0)
.PushNotificationAsync(Arg.Any<Notification>());
}
[Theory]
@ -98,6 +117,12 @@ public class MarkNotificationReadCommandTest
authorizedCreate: false, true);
await Assert.ThrowsAsync<NotFoundException>(() => sutProvider.Sut.MarkReadAsync(notificationId));
await sutProvider.GetDependency<IPushNotificationService>()
.Received(0)
.PushNotificationStatusAsync(Arg.Any<Notification>(), Arg.Any<NotificationStatus>());
await sutProvider.GetDependency<IPushNotificationService>()
.Received(0)
.PushNotificationAsync(Arg.Any<Notification>());
}
[Theory]
@ -110,6 +135,12 @@ public class MarkNotificationReadCommandTest
authorizedUpdate: false);
await Assert.ThrowsAsync<NotFoundException>(() => sutProvider.Sut.MarkReadAsync(notificationId));
await sutProvider.GetDependency<IPushNotificationService>()
.Received(0)
.PushNotificationStatusAsync(Arg.Any<Notification>(), Arg.Any<NotificationStatus>());
await sutProvider.GetDependency<IPushNotificationService>()
.Received(0)
.PushNotificationAsync(Arg.Any<Notification>());
}
[Theory]
@ -119,13 +150,25 @@ public class MarkNotificationReadCommandTest
Guid notificationId, Guid userId, Notification notification)
{
Setup(sutProvider, notificationId, userId, notification, notificationStatus: null, true, true, true);
var expectedNotificationStatus = new NotificationStatus
{
NotificationId = notificationId,
UserId = userId,
ReadDate = DateTime.UtcNow,
DeletedDate = null
};
await sutProvider.Sut.MarkReadAsync(notificationId);
await sutProvider.GetDependency<INotificationStatusRepository>().Received(1)
.CreateAsync(Arg.Is<NotificationStatus>(ns =>
ns.NotificationId == notificationId && ns.UserId == userId && !ns.DeletedDate.HasValue &&
ns.ReadDate.HasValue && DateTime.UtcNow - ns.ReadDate.Value < TimeSpan.FromMinutes(1)));
.CreateAsync(Arg.Do<NotificationStatus>(ns => AssertNotificationStatus(expectedNotificationStatus, ns)));
await sutProvider.GetDependency<IPushNotificationService>()
.Received(1)
.PushNotificationStatusAsync(notification,
Arg.Do<NotificationStatus>(ns => AssertNotificationStatus(expectedNotificationStatus, ns)));
await sutProvider.GetDependency<IPushNotificationService>()
.Received(0)
.PushNotificationAsync(Arg.Any<Notification>());
}
[Theory]
@ -134,18 +177,30 @@ public class MarkNotificationReadCommandTest
SutProvider<MarkNotificationReadCommand> sutProvider,
Guid notificationId, Guid userId, Notification notification, NotificationStatus notificationStatus)
{
var readDate = notificationStatus.ReadDate;
Setup(sutProvider, notificationId, userId, notification, notificationStatus, true, true, true);
await sutProvider.Sut.MarkReadAsync(notificationId);
await sutProvider.GetDependency<INotificationStatusRepository>().Received(1)
.UpdateAsync(Arg.Is<NotificationStatus>(ns =>
ns.Equals(notificationStatus) &&
ns.NotificationId == notificationStatus.NotificationId && ns.UserId == notificationStatus.UserId &&
ns.DeletedDate == notificationStatus.DeletedDate && ns.ReadDate != readDate &&
ns.ReadDate.HasValue &&
DateTime.UtcNow - ns.ReadDate.Value < TimeSpan.FromMinutes(1)));
.UpdateAsync(Arg.Do<NotificationStatus>(ns => AssertNotificationStatus(notificationStatus, ns)));
await sutProvider.GetDependency<IPushNotificationService>()
.Received(1)
.PushNotificationStatusAsync(notification,
Arg.Do<NotificationStatus>(ns => AssertNotificationStatus(notificationStatus, ns)));
await sutProvider.GetDependency<IPushNotificationService>()
.Received(0)
.PushNotificationAsync(Arg.Any<Notification>());
}
private static void AssertNotificationStatus(NotificationStatus expectedNotificationStatus,
NotificationStatus? actualNotificationStatus)
{
Assert.NotNull(actualNotificationStatus);
Assert.Equal(expectedNotificationStatus.NotificationId, actualNotificationStatus.NotificationId);
Assert.Equal(expectedNotificationStatus.UserId, actualNotificationStatus.UserId);
Assert.NotEqual(expectedNotificationStatus.ReadDate, actualNotificationStatus.ReadDate);
Assert.NotNull(actualNotificationStatus.ReadDate);
Assert.Equal(DateTime.UtcNow, actualNotificationStatus.ReadDate.Value, TimeSpan.FromMinutes(1));
Assert.Equal(expectedNotificationStatus.DeletedDate, actualNotificationStatus.DeletedDate);
}
}

View File

@ -7,6 +7,7 @@ using Bit.Core.NotificationCenter.Commands;
using Bit.Core.NotificationCenter.Entities;
using Bit.Core.NotificationCenter.Enums;
using Bit.Core.NotificationCenter.Repositories;
using Bit.Core.Platform.Push;
using Bit.Core.Test.NotificationCenter.AutoFixture;
using Bit.Core.Utilities;
using Bit.Test.Common.AutoFixture;
@ -45,6 +46,12 @@ public class UpdateNotificationCommandTest
Setup(sutProvider, notification.Id, notification: null, true);
await Assert.ThrowsAsync<NotFoundException>(() => sutProvider.Sut.UpdateAsync(notification));
await sutProvider.GetDependency<IPushNotificationService>()
.Received(0)
.PushNotificationAsync(Arg.Any<Notification>());
await sutProvider.GetDependency<IPushNotificationService>()
.Received(0)
.PushNotificationStatusAsync(Arg.Any<Notification>(), Arg.Any<NotificationStatus>());
}
[Theory]
@ -56,6 +63,12 @@ public class UpdateNotificationCommandTest
Setup(sutProvider, notification.Id, notification, authorized: false);
await Assert.ThrowsAsync<NotFoundException>(() => sutProvider.Sut.UpdateAsync(notification));
await sutProvider.GetDependency<IPushNotificationService>()
.Received(0)
.PushNotificationAsync(Arg.Any<Notification>());
await sutProvider.GetDependency<IPushNotificationService>()
.Received(0)
.PushNotificationStatusAsync(Arg.Any<Notification>(), Arg.Any<NotificationStatus>());
}
[Theory]
@ -91,5 +104,11 @@ public class UpdateNotificationCommandTest
n.Priority == notificationToUpdate.Priority && n.ClientType == notificationToUpdate.ClientType &&
n.Title == notificationToUpdate.Title && n.Body == notificationToUpdate.Body &&
DateTime.UtcNow - n.RevisionDate < TimeSpan.FromMinutes(1)));
await sutProvider.GetDependency<IPushNotificationService>()
.Received(1)
.PushNotificationAsync(notification);
await sutProvider.GetDependency<IPushNotificationService>()
.Received(0)
.PushNotificationStatusAsync(Arg.Any<Notification>(), Arg.Any<NotificationStatus>());
}
}

View File

@ -15,12 +15,13 @@ using Xunit;
namespace Bit.Core.Test.NotificationHub;
[SutProviderCustomize]
[NotificationStatusCustomize]
public class NotificationHubPushNotificationServiceTests
{
[Theory]
[BitAutoData]
[NotificationCustomize]
public async void PushNotificationAsync_Global_NotSent(
public async Task PushNotificationAsync_Global_NotSent(
SutProvider<NotificationHubPushNotificationService> sutProvider, Notification notification)
{
await sutProvider.Sut.PushNotificationAsync(notification);
@ -39,7 +40,7 @@ public class NotificationHubPushNotificationServiceTests
[BitAutoData(false)]
[BitAutoData(true)]
[NotificationCustomize(false)]
public async void PushNotificationAsync_UserIdProvidedClientTypeAll_SentToUser(
public async Task PushNotificationAsync_UserIdProvidedClientTypeAll_SentToUser(
bool organizationIdNull, SutProvider<NotificationHubPushNotificationService> sutProvider,
Notification notification)
{
@ -49,11 +50,12 @@ public class NotificationHubPushNotificationServiceTests
}
notification.ClientType = ClientType.All;
var expectedSyncNotification = ToSyncNotificationPushNotification(notification);
var expectedNotification = ToNotificationPushNotification(notification, null);
await sutProvider.Sut.PushNotificationAsync(notification);
await AssertSendTemplateNotificationAsync(sutProvider, PushType.SyncNotification, expectedSyncNotification,
await AssertSendTemplateNotificationAsync(sutProvider, PushType.SyncNotification,
expectedNotification,
$"(template:payload_userId:{notification.UserId})");
await sutProvider.GetDependency<IInstallationDeviceRepository>()
.Received(0)
@ -61,30 +63,46 @@ public class NotificationHubPushNotificationServiceTests
}
[Theory]
[BitAutoData(false, ClientType.Browser)]
[BitAutoData(false, ClientType.Desktop)]
[BitAutoData(false, ClientType.Web)]
[BitAutoData(false, ClientType.Mobile)]
[BitAutoData(true, ClientType.Browser)]
[BitAutoData(true, ClientType.Desktop)]
[BitAutoData(true, ClientType.Web)]
[BitAutoData(true, ClientType.Mobile)]
[BitAutoData(ClientType.Browser)]
[BitAutoData(ClientType.Desktop)]
[BitAutoData(ClientType.Web)]
[BitAutoData(ClientType.Mobile)]
[NotificationCustomize(false)]
public async void PushNotificationAsync_UserIdProvidedClientTypeNotAll_SentToUser(bool organizationIdNull,
public async Task PushNotificationAsync_UserIdProvidedOrganizationIdNullClientTypeNotAll_SentToUser(
ClientType clientType, SutProvider<NotificationHubPushNotificationService> sutProvider,
Notification notification)
{
if (organizationIdNull)
{
notification.OrganizationId = null;
}
notification.OrganizationId = null;
notification.ClientType = clientType;
var expectedSyncNotification = ToSyncNotificationPushNotification(notification);
var expectedNotification = ToNotificationPushNotification(notification, null);
await sutProvider.Sut.PushNotificationAsync(notification);
await AssertSendTemplateNotificationAsync(sutProvider, PushType.SyncNotification, expectedSyncNotification,
await AssertSendTemplateNotificationAsync(sutProvider, PushType.SyncNotification,
expectedNotification,
$"(template:payload_userId:{notification.UserId} && clientType:{clientType})");
await sutProvider.GetDependency<IInstallationDeviceRepository>()
.Received(0)
.UpsertAsync(Arg.Any<InstallationDeviceEntity>());
}
[Theory]
[BitAutoData(ClientType.Browser)]
[BitAutoData(ClientType.Desktop)]
[BitAutoData(ClientType.Web)]
[BitAutoData(ClientType.Mobile)]
[NotificationCustomize(false)]
public async Task PushNotificationAsync_UserIdProvidedOrganizationIdProvidedClientTypeNotAll_SentToUser(
ClientType clientType, SutProvider<NotificationHubPushNotificationService> sutProvider,
Notification notification)
{
notification.ClientType = clientType;
var expectedNotification = ToNotificationPushNotification(notification, null);
await sutProvider.Sut.PushNotificationAsync(notification);
await AssertSendTemplateNotificationAsync(sutProvider, PushType.SyncNotification,
expectedNotification,
$"(template:payload_userId:{notification.UserId} && clientType:{clientType})");
await sutProvider.GetDependency<IInstallationDeviceRepository>()
.Received(0)
@ -94,16 +112,17 @@ public class NotificationHubPushNotificationServiceTests
[Theory]
[BitAutoData]
[NotificationCustomize(false)]
public async void PushNotificationAsync_UserIdNullOrganizationIdProvidedClientTypeAll_SentToOrganization(
public async Task PushNotificationAsync_UserIdNullOrganizationIdProvidedClientTypeAll_SentToOrganization(
SutProvider<NotificationHubPushNotificationService> sutProvider, Notification notification)
{
notification.UserId = null;
notification.ClientType = ClientType.All;
var expectedSyncNotification = ToSyncNotificationPushNotification(notification);
var expectedNotification = ToNotificationPushNotification(notification, null);
await sutProvider.Sut.PushNotificationAsync(notification);
await AssertSendTemplateNotificationAsync(sutProvider, PushType.SyncNotification, expectedSyncNotification,
await AssertSendTemplateNotificationAsync(sutProvider, PushType.SyncNotification,
expectedNotification,
$"(template:payload && organizationId:{notification.OrganizationId})");
await sutProvider.GetDependency<IInstallationDeviceRepository>()
.Received(0)
@ -116,18 +135,156 @@ public class NotificationHubPushNotificationServiceTests
[BitAutoData(ClientType.Web)]
[BitAutoData(ClientType.Mobile)]
[NotificationCustomize(false)]
public async void PushNotificationAsync_UserIdNullOrganizationIdProvidedClientTypeNotAll_SentToOrganization(
public async Task PushNotificationAsync_UserIdNullOrganizationIdProvidedClientTypeNotAll_SentToOrganization(
ClientType clientType, SutProvider<NotificationHubPushNotificationService> sutProvider,
Notification notification)
{
notification.UserId = null;
notification.ClientType = clientType;
var expectedSyncNotification = ToSyncNotificationPushNotification(notification);
var expectedNotification = ToNotificationPushNotification(notification, null);
await sutProvider.Sut.PushNotificationAsync(notification);
await AssertSendTemplateNotificationAsync(sutProvider, PushType.SyncNotification, expectedSyncNotification,
await AssertSendTemplateNotificationAsync(sutProvider, PushType.SyncNotification,
expectedNotification,
$"(template:payload && organizationId:{notification.OrganizationId} && clientType:{clientType})");
await sutProvider.GetDependency<IInstallationDeviceRepository>()
.Received(0)
.UpsertAsync(Arg.Any<InstallationDeviceEntity>());
}
[Theory]
[BitAutoData]
[NotificationCustomize]
public async Task PushNotificationStatusAsync_Global_NotSent(
SutProvider<NotificationHubPushNotificationService> sutProvider, Notification notification,
NotificationStatus notificationStatus)
{
await sutProvider.Sut.PushNotificationStatusAsync(notification, notificationStatus);
await sutProvider.GetDependency<INotificationHubPool>()
.Received(0)
.AllClients
.Received(0)
.SendTemplateNotificationAsync(Arg.Any<IDictionary<string, string>>(), Arg.Any<string>());
await sutProvider.GetDependency<IInstallationDeviceRepository>()
.Received(0)
.UpsertAsync(Arg.Any<InstallationDeviceEntity>());
}
[Theory]
[BitAutoData(false)]
[BitAutoData(true)]
[NotificationCustomize(false)]
public async Task PushNotificationStatusAsync_UserIdProvidedClientTypeAll_SentToUser(
bool organizationIdNull, SutProvider<NotificationHubPushNotificationService> sutProvider,
Notification notification, NotificationStatus notificationStatus)
{
if (organizationIdNull)
{
notification.OrganizationId = null;
}
notification.ClientType = ClientType.All;
var expectedNotification = ToNotificationPushNotification(notification, notificationStatus);
await sutProvider.Sut.PushNotificationStatusAsync(notification, notificationStatus);
await AssertSendTemplateNotificationAsync(sutProvider, PushType.SyncNotificationStatus,
expectedNotification,
$"(template:payload_userId:{notification.UserId})");
await sutProvider.GetDependency<IInstallationDeviceRepository>()
.Received(0)
.UpsertAsync(Arg.Any<InstallationDeviceEntity>());
}
[Theory]
[BitAutoData(ClientType.Browser)]
[BitAutoData(ClientType.Desktop)]
[BitAutoData(ClientType.Web)]
[BitAutoData(ClientType.Mobile)]
[NotificationCustomize(false)]
public async Task PushNotificationStatusAsync_UserIdProvidedOrganizationIdNullClientTypeNotAll_SentToUser(
ClientType clientType, SutProvider<NotificationHubPushNotificationService> sutProvider,
Notification notification, NotificationStatus notificationStatus)
{
notification.OrganizationId = null;
notification.ClientType = clientType;
var expectedNotification = ToNotificationPushNotification(notification, notificationStatus);
await sutProvider.Sut.PushNotificationStatusAsync(notification, notificationStatus);
await AssertSendTemplateNotificationAsync(sutProvider, PushType.SyncNotificationStatus,
expectedNotification,
$"(template:payload_userId:{notification.UserId} && clientType:{clientType})");
await sutProvider.GetDependency<IInstallationDeviceRepository>()
.Received(0)
.UpsertAsync(Arg.Any<InstallationDeviceEntity>());
}
[Theory]
[BitAutoData(ClientType.Browser)]
[BitAutoData(ClientType.Desktop)]
[BitAutoData(ClientType.Web)]
[BitAutoData(ClientType.Mobile)]
[NotificationCustomize(false)]
public async Task PushNotificationStatusAsync_UserIdProvidedOrganizationIdProvidedClientTypeNotAll_SentToUser(
ClientType clientType, SutProvider<NotificationHubPushNotificationService> sutProvider,
Notification notification, NotificationStatus notificationStatus)
{
notification.ClientType = clientType;
var expectedNotification = ToNotificationPushNotification(notification, notificationStatus);
await sutProvider.Sut.PushNotificationStatusAsync(notification, notificationStatus);
await AssertSendTemplateNotificationAsync(sutProvider, PushType.SyncNotificationStatus,
expectedNotification,
$"(template:payload_userId:{notification.UserId} && clientType:{clientType})");
await sutProvider.GetDependency<IInstallationDeviceRepository>()
.Received(0)
.UpsertAsync(Arg.Any<InstallationDeviceEntity>());
}
[Theory]
[BitAutoData]
[NotificationCustomize(false)]
public async Task PushNotificationStatusAsync_UserIdNullOrganizationIdProvidedClientTypeAll_SentToOrganization(
SutProvider<NotificationHubPushNotificationService> sutProvider, Notification notification,
NotificationStatus notificationStatus)
{
notification.UserId = null;
notification.ClientType = ClientType.All;
var expectedNotification = ToNotificationPushNotification(notification, notificationStatus);
await sutProvider.Sut.PushNotificationStatusAsync(notification, notificationStatus);
await AssertSendTemplateNotificationAsync(sutProvider, PushType.SyncNotificationStatus,
expectedNotification,
$"(template:payload && organizationId:{notification.OrganizationId})");
await sutProvider.GetDependency<IInstallationDeviceRepository>()
.Received(0)
.UpsertAsync(Arg.Any<InstallationDeviceEntity>());
}
[Theory]
[BitAutoData(ClientType.Browser)]
[BitAutoData(ClientType.Desktop)]
[BitAutoData(ClientType.Web)]
[BitAutoData(ClientType.Mobile)]
[NotificationCustomize(false)]
public async Task
PushNotificationStatusAsync_UserIdNullOrganizationIdProvidedClientTypeNotAll_SentToOrganization(
ClientType clientType, SutProvider<NotificationHubPushNotificationService> sutProvider,
Notification notification, NotificationStatus notificationStatus)
{
notification.UserId = null;
notification.ClientType = clientType;
var expectedNotification = ToNotificationPushNotification(notification, notificationStatus);
await sutProvider.Sut.PushNotificationStatusAsync(notification, notificationStatus);
await AssertSendTemplateNotificationAsync(sutProvider, PushType.SyncNotificationStatus,
expectedNotification,
$"(template:payload && organizationId:{notification.OrganizationId} && clientType:{clientType})");
await sutProvider.GetDependency<IInstallationDeviceRepository>()
.Received(0)
@ -137,7 +294,7 @@ public class NotificationHubPushNotificationServiceTests
[Theory]
[BitAutoData([null])]
[BitAutoData(ClientType.All)]
public async void SendPayloadToUserAsync_ClientTypeNullOrAll_SentToUser(ClientType? clientType,
public async Task SendPayloadToUserAsync_ClientTypeNullOrAll_SentToUser(ClientType? clientType,
SutProvider<NotificationHubPushNotificationService> sutProvider, Guid userId, PushType pushType, string payload,
string identifier)
{
@ -156,7 +313,7 @@ public class NotificationHubPushNotificationServiceTests
[BitAutoData(ClientType.Desktop)]
[BitAutoData(ClientType.Mobile)]
[BitAutoData(ClientType.Web)]
public async void SendPayloadToUserAsync_ClientTypeExplicit_SentToUserAndClientType(ClientType clientType,
public async Task SendPayloadToUserAsync_ClientTypeExplicit_SentToUserAndClientType(ClientType clientType,
SutProvider<NotificationHubPushNotificationService> sutProvider, Guid userId, PushType pushType, string payload,
string identifier)
{
@ -173,7 +330,7 @@ public class NotificationHubPushNotificationServiceTests
[Theory]
[BitAutoData([null])]
[BitAutoData(ClientType.All)]
public async void SendPayloadToOrganizationAsync_ClientTypeNullOrAll_SentToOrganization(ClientType? clientType,
public async Task SendPayloadToOrganizationAsync_ClientTypeNullOrAll_SentToOrganization(ClientType? clientType,
SutProvider<NotificationHubPushNotificationService> sutProvider, Guid organizationId, PushType pushType,
string payload, string identifier)
{
@ -192,7 +349,7 @@ public class NotificationHubPushNotificationServiceTests
[BitAutoData(ClientType.Desktop)]
[BitAutoData(ClientType.Mobile)]
[BitAutoData(ClientType.Web)]
public async void SendPayloadToOrganizationAsync_ClientTypeExplicit_SentToOrganizationAndClientType(
public async Task SendPayloadToOrganizationAsync_ClientTypeExplicit_SentToOrganizationAndClientType(
ClientType clientType, SutProvider<NotificationHubPushNotificationService> sutProvider, Guid organizationId,
PushType pushType, string payload, string identifier)
{
@ -206,7 +363,8 @@ public class NotificationHubPushNotificationServiceTests
.UpsertAsync(Arg.Any<InstallationDeviceEntity>());
}
private static NotificationPushNotification ToSyncNotificationPushNotification(Notification notification) =>
private static NotificationPushNotification ToNotificationPushNotification(Notification notification,
NotificationStatus? notificationStatus) =>
new()
{
Id = notification.Id,
@ -218,7 +376,9 @@ public class NotificationHubPushNotificationServiceTests
Title = notification.Title,
Body = notification.Body,
CreationDate = notification.CreationDate,
RevisionDate = notification.RevisionDate
RevisionDate = notification.RevisionDate,
ReadDate = notificationStatus?.ReadDate,
DeletedDate = notificationStatus?.DeletedDate
};
private static async Task AssertSendTemplateNotificationAsync(

View File

@ -24,7 +24,7 @@ public class AzureQueuePushNotificationServiceTests
[BitAutoData]
[NotificationCustomize]
[CurrentContextCustomize]
public async void PushNotificationAsync_Notification_Sent(
public async Task PushNotificationAsync_Notification_Sent(
SutProvider<AzureQueuePushNotificationService> sutProvider, Notification notification, Guid deviceIdentifier,
ICurrentContext currentContext)
{
@ -36,7 +36,30 @@ public class AzureQueuePushNotificationServiceTests
await sutProvider.GetDependency<QueueClient>().Received(1)
.SendMessageAsync(Arg.Is<string>(message =>
MatchMessage(PushType.SyncNotification, message, new SyncNotificationEquals(notification),
MatchMessage(PushType.SyncNotification, message,
new NotificationPushNotificationEquals(notification, null),
deviceIdentifier.ToString())));
}
[Theory]
[BitAutoData]
[NotificationCustomize]
[NotificationStatusCustomize]
[CurrentContextCustomize]
public async Task PushNotificationStatusAsync_Notification_Sent(
SutProvider<AzureQueuePushNotificationService> sutProvider, Notification notification, Guid deviceIdentifier,
ICurrentContext currentContext, NotificationStatus notificationStatus)
{
currentContext.DeviceIdentifier.Returns(deviceIdentifier.ToString());
sutProvider.GetDependency<IHttpContextAccessor>().HttpContext!.RequestServices
.GetService(Arg.Any<Type>()).Returns(currentContext);
await sutProvider.Sut.PushNotificationStatusAsync(notification, notificationStatus);
await sutProvider.GetDependency<QueueClient>().Received(1)
.SendMessageAsync(Arg.Is<string>(message =>
MatchMessage(PushType.SyncNotificationStatus, message,
new NotificationPushNotificationEquals(notification, notificationStatus),
deviceIdentifier.ToString())));
}
@ -50,7 +73,8 @@ public class AzureQueuePushNotificationServiceTests
pushNotificationData.ContextId == contextId;
}
private class SyncNotificationEquals(Notification notification) : IEquatable<NotificationPushNotification>
private class NotificationPushNotificationEquals(Notification notification, NotificationStatus? notificationStatus)
: IEquatable<NotificationPushNotification>
{
public bool Equals(NotificationPushNotification? other)
{
@ -66,7 +90,9 @@ public class AzureQueuePushNotificationServiceTests
other.Title == notification.Title &&
other.Body == notification.Body &&
other.CreationDate == notification.CreationDate &&
other.RevisionDate == notification.RevisionDate;
other.RevisionDate == notification.RevisionDate &&
other.ReadDate == notificationStatus?.ReadDate &&
other.DeletedDate == notificationStatus?.DeletedDate;
}
}
}

View File

@ -26,6 +26,22 @@ public class MultiServicePushNotificationServiceTests
.PushNotificationAsync(notification);
}
[Theory]
[BitAutoData]
[NotificationCustomize]
[NotificationStatusCustomize]
public async Task PushNotificationStatusAsync_Notification_Sent(
SutProvider<MultiServicePushNotificationService> sutProvider, Notification notification,
NotificationStatus notificationStatus)
{
await sutProvider.Sut.PushNotificationStatusAsync(notification, notificationStatus);
await sutProvider.GetDependency<IEnumerable<IPushNotificationService>>()
.First()
.Received(1)
.PushNotificationStatusAsync(notification, notificationStatus);
}
[Theory]
[BitAutoData([null, null])]
[BitAutoData(ClientType.All, null)]

View File

@ -730,6 +730,46 @@ public class UserServiceTests
.RemoveAsync(Arg.Any<string>());
}
[Theory, BitAutoData]
public async Task RecoverTwoFactorAsync_CorrectCode_ReturnsTrueAndProcessesPolicies(
User user, SutProvider<UserService> sutProvider)
{
// Arrange
var recoveryCode = "1234";
user.TwoFactorRecoveryCode = recoveryCode;
// Act
var response = await sutProvider.Sut.RecoverTwoFactorAsync(user, recoveryCode);
// Assert
Assert.True(response);
Assert.Null(user.TwoFactorProviders);
// Make sure a new code was generated for the user
Assert.NotEqual(recoveryCode, user.TwoFactorRecoveryCode);
await sutProvider.GetDependency<IMailService>()
.Received(1)
.SendRecoverTwoFactorEmail(Arg.Any<string>(), Arg.Any<DateTime>(), Arg.Any<string>());
await sutProvider.GetDependency<IEventService>()
.Received(1)
.LogUserEventAsync(user.Id, EventType.User_Recovered2fa);
}
[Theory, BitAutoData]
public async Task RecoverTwoFactorAsync_IncorrectCode_ReturnsFalse(
User user, SutProvider<UserService> sutProvider)
{
// Arrange
var recoveryCode = "1234";
user.TwoFactorRecoveryCode = "4567";
// Act
var response = await sutProvider.Sut.RecoverTwoFactorAsync(user, recoveryCode);
// Assert
Assert.False(response);
Assert.NotNull(user.TwoFactorProviders);
}
private static void SetupUserAndDevice(User user,
bool shouldHavePassword)
{

View File

@ -105,7 +105,7 @@ public class BaseRequestValidatorTests
// Assert
await _eventService.Received(1)
.LogUserEventAsync(context.CustomValidatorRequestContext.User.Id,
Core.Enums.EventType.User_FailedLogIn);
EventType.User_FailedLogIn);
Assert.True(context.GrantResult.IsError);
Assert.Equal("Username or password is incorrect. Try again.", errorResponse.Message);
}

View File

@ -1,4 +1,5 @@
using Bit.Core.AdminConsole.Entities;
using Bit.Core;
using Bit.Core.AdminConsole.Entities;
using Bit.Core.Auth.Enums;
using Bit.Core.Auth.Identity.TokenProviders;
using Bit.Core.Auth.Models.Business.Tokenables;
@ -328,7 +329,7 @@ public class TwoFactorAuthenticationValidatorTests
_userManager.TWO_FACTOR_PROVIDERS = ["email"];
// Act
var result = await _sut.VerifyTwoFactor(
var result = await _sut.VerifyTwoFactorAsync(
user, null, TwoFactorProviderType.U2f, token);
// Assert
@ -348,7 +349,7 @@ public class TwoFactorAuthenticationValidatorTests
_userManager.TWO_FACTOR_PROVIDERS = ["email"];
// Act
var result = await _sut.VerifyTwoFactor(
var result = await _sut.VerifyTwoFactorAsync(
user, null, TwoFactorProviderType.Email, token);
// Assert
@ -368,7 +369,7 @@ public class TwoFactorAuthenticationValidatorTests
_userManager.TWO_FACTOR_PROVIDERS = ["OrganizationDuo"];
// Act
var result = await _sut.VerifyTwoFactor(
var result = await _sut.VerifyTwoFactorAsync(
user, null, TwoFactorProviderType.OrganizationDuo, token);
// Assert
@ -394,7 +395,7 @@ public class TwoFactorAuthenticationValidatorTests
_userManager.TWO_FACTOR_TOKEN_VERIFIED = true;
// Act
var result = await _sut.VerifyTwoFactor(user, null, providerType, token);
var result = await _sut.VerifyTwoFactorAsync(user, null, providerType, token);
// Assert
Assert.True(result);
@ -419,7 +420,7 @@ public class TwoFactorAuthenticationValidatorTests
_userManager.TWO_FACTOR_TOKEN_VERIFIED = false;
// Act
var result = await _sut.VerifyTwoFactor(user, null, providerType, token);
var result = await _sut.VerifyTwoFactorAsync(user, null, providerType, token);
// Assert
Assert.False(result);
@ -445,13 +446,56 @@ public class TwoFactorAuthenticationValidatorTests
organization.Enabled = true;
// Act
var result = await _sut.VerifyTwoFactor(
var result = await _sut.VerifyTwoFactorAsync(
user, organization, providerType, token);
// Assert
Assert.True(result);
}
[Theory]
[BitAutoData(TwoFactorProviderType.RecoveryCode)]
public async void VerifyTwoFactorAsync_RecoveryCode_ValidToken_ReturnsTrue(
TwoFactorProviderType providerType,
User user,
Organization organization)
{
var token = "1234";
user.TwoFactorRecoveryCode = token;
_userService.RecoverTwoFactorAsync(Arg.Is(user), Arg.Is(token)).Returns(true);
_featureService.IsEnabled(FeatureFlagKeys.RecoveryCodeLogin).Returns(true);
// Act
var result = await _sut.VerifyTwoFactorAsync(
user, organization, providerType, token);
// Assert
Assert.True(result);
}
[Theory]
[BitAutoData(TwoFactorProviderType.RecoveryCode)]
public async void VerifyTwoFactorAsync_RecoveryCode_InvalidToken_ReturnsFalse(
TwoFactorProviderType providerType,
User user,
Organization organization)
{
// Arrange
var token = "1234";
user.TwoFactorRecoveryCode = token;
_userService.RecoverTwoFactorAsync(Arg.Is(user), Arg.Is(token)).Returns(false);
_featureService.IsEnabled(FeatureFlagKeys.RecoveryCodeLogin).Returns(true);
// Act
var result = await _sut.VerifyTwoFactorAsync(
user, organization, providerType, token);
// Assert
Assert.False(result);
}
private static UserManagerTestWrapper<User> SubstituteUserManager()
{
return new UserManagerTestWrapper<User>(

View File

@ -0,0 +1,57 @@
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Entities;
using Bit.Core.Enums;
using Bit.Core.Repositories;
namespace Bit.Infrastructure.IntegrationTest.AdminConsole;
/// <summary>
/// A set of extension methods used to arrange simple test data.
/// This should only be used for basic, repetitive data arrangement, not for anything complex or for
/// the repository method under test.
/// </summary>
public static class OrganizationTestHelpers
{
public static Task<User> CreateTestUserAsync(this IUserRepository userRepository, string identifier = "test")
{
var id = Guid.NewGuid();
return userRepository.CreateAsync(new User
{
Id = id,
Name = $"{identifier}-{id}",
Email = $"{id}@example.com",
ApiKey = "TEST",
SecurityStamp = "stamp",
});
}
public static Task<Organization> CreateTestOrganizationAsync(this IOrganizationRepository organizationRepository,
string identifier = "test")
=> organizationRepository.CreateAsync(new Organization
{
Name = $"{identifier}-{Guid.NewGuid()}",
BillingEmail = "billing@example.com", // TODO: EF does not enforce this being NOT NULL
Plan = "Test", // TODO: EF does not enforce this being NOT NULl
});
public static Task<OrganizationUser> CreateTestOrganizationUserAsync(
this IOrganizationUserRepository organizationUserRepository,
Organization organization,
User user)
=> organizationUserRepository.CreateAsync(new OrganizationUser
{
OrganizationId = organization.Id,
UserId = user.Id,
Status = OrganizationUserStatusType.Confirmed,
Type = OrganizationUserType.Owner
});
public static Task<Group> CreateTestGroupAsync(
this IGroupRepository groupRepository,
Organization organization,
string identifier = "test")
=> groupRepository.CreateAsync(
new Group { OrganizationId = organization.Id, Name = $"{identifier} {Guid.NewGuid()}" }
);
}

View File

@ -0,0 +1,129 @@
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Repositories;
using Xunit;
namespace Bit.Infrastructure.IntegrationTest.AdminConsole.Repositories;
public class GroupRepositoryTests
{
[DatabaseTheory, DatabaseData]
public async Task AddGroupUsersByIdAsync_CreatesGroupUsers(
IGroupRepository groupRepository,
IUserRepository userRepository,
IOrganizationUserRepository organizationUserRepository,
IOrganizationRepository organizationRepository)
{
// Arrange
var user1 = await userRepository.CreateTestUserAsync("user1");
var user2 = await userRepository.CreateTestUserAsync("user2");
var user3 = await userRepository.CreateTestUserAsync("user3");
var org = await organizationRepository.CreateTestOrganizationAsync();
var orgUser1 = await organizationUserRepository.CreateTestOrganizationUserAsync(org, user1);
var orgUser2 = await organizationUserRepository.CreateTestOrganizationUserAsync(org, user2);
var orgUser3 = await organizationUserRepository.CreateTestOrganizationUserAsync(org, user3);
var orgUserIds = new List<Guid>([orgUser1.Id, orgUser2.Id, orgUser3.Id]);
var group = await groupRepository.CreateTestGroupAsync(org);
// Act
await groupRepository.AddGroupUsersByIdAsync(group.Id, orgUserIds);
// Assert
var actual = await groupRepository.GetManyUserIdsByIdAsync(group.Id);
Assert.Equal(orgUserIds!.Order(), actual.Order());
}
[DatabaseTheory, DatabaseData]
public async Task AddGroupUsersByIdAsync_IgnoresExistingGroupUsers(
IGroupRepository groupRepository,
IUserRepository userRepository,
IOrganizationUserRepository organizationUserRepository,
IOrganizationRepository organizationRepository)
{
// Arrange
var user1 = await userRepository.CreateTestUserAsync("user1");
var user2 = await userRepository.CreateTestUserAsync("user2");
var user3 = await userRepository.CreateTestUserAsync("user3");
var org = await organizationRepository.CreateTestOrganizationAsync();
var orgUser1 = await organizationUserRepository.CreateTestOrganizationUserAsync(org, user1);
var orgUser2 = await organizationUserRepository.CreateTestOrganizationUserAsync(org, user2);
var orgUser3 = await organizationUserRepository.CreateTestOrganizationUserAsync(org, user3);
var orgUserIds = new List<Guid>([orgUser1.Id, orgUser2.Id, orgUser3.Id]);
var group = await groupRepository.CreateTestGroupAsync(org);
// Add user 2 to the group already, make sure this is executed correctly before proceeding
await groupRepository.UpdateUsersAsync(group.Id, [orgUser2.Id]);
var existingUsers = await groupRepository.GetManyUserIdsByIdAsync(group.Id);
Assert.Equal([orgUser2.Id], existingUsers);
// Act
await groupRepository.AddGroupUsersByIdAsync(group.Id, orgUserIds);
// Assert - group should contain all users
var actual = await groupRepository.GetManyUserIdsByIdAsync(group.Id);
Assert.Equal(orgUserIds!.Order(), actual.Order());
}
[DatabaseTheory, DatabaseData]
public async Task AddGroupUsersByIdAsync_IgnoresUsersNotInOrganization(
IGroupRepository groupRepository,
IUserRepository userRepository,
IOrganizationUserRepository organizationUserRepository,
IOrganizationRepository organizationRepository)
{
// Arrange
var user1 = await userRepository.CreateTestUserAsync("user1");
var user2 = await userRepository.CreateTestUserAsync("user2");
var user3 = await userRepository.CreateTestUserAsync("user3");
var org = await organizationRepository.CreateTestOrganizationAsync();
var orgUser1 = await organizationUserRepository.CreateTestOrganizationUserAsync(org, user1);
var orgUser2 = await organizationUserRepository.CreateTestOrganizationUserAsync(org, user2);
// User3 belongs to a different org
var otherOrg = await organizationRepository.CreateTestOrganizationAsync();
var orgUser3 = await organizationUserRepository.CreateTestOrganizationUserAsync(otherOrg, user3);
var orgUserIds = new List<Guid>([orgUser1.Id, orgUser2.Id, orgUser3.Id]);
var group = await groupRepository.CreateTestGroupAsync(org);
// Act
await groupRepository.AddGroupUsersByIdAsync(group.Id, orgUserIds);
// Assert
var actual = await groupRepository.GetManyUserIdsByIdAsync(group.Id);
Assert.Equal(2, actual.Count);
Assert.Contains(orgUser1.Id, actual);
Assert.Contains(orgUser2.Id, actual);
Assert.DoesNotContain(orgUser3.Id, actual);
}
[DatabaseTheory, DatabaseData]
public async Task AddGroupUsersByIdAsync_IgnoresDuplicateUsers(
IGroupRepository groupRepository,
IUserRepository userRepository,
IOrganizationUserRepository organizationUserRepository,
IOrganizationRepository organizationRepository)
{
// Arrange
var user1 = await userRepository.CreateTestUserAsync("user1");
var user2 = await userRepository.CreateTestUserAsync("user2");
var org = await organizationRepository.CreateTestOrganizationAsync();
var orgUser1 = await organizationUserRepository.CreateTestOrganizationUserAsync(org, user1);
var orgUser2 = await organizationUserRepository.CreateTestOrganizationUserAsync(org, user2);
var orgUserIds = new List<Guid>([orgUser1.Id, orgUser2.Id, orgUser2.Id]); // duplicate orgUser2
var group = await groupRepository.CreateTestGroupAsync(org);
// Act
await groupRepository.AddGroupUsersByIdAsync(group.Id, orgUserIds);
// Assert
var actual = await groupRepository.GetManyUserIdsByIdAsync(group.Id);
Assert.Equal(2, actual.Count);
Assert.Contains(orgUser1.Id, actual);
Assert.Contains(orgUser2.Id, actual);
}
}

View File

@ -3,7 +3,7 @@ using Bit.Core.Entities;
using Bit.Core.Repositories;
using Xunit;
namespace Bit.Infrastructure.IntegrationTest.Repositories;
namespace Bit.Infrastructure.IntegrationTest.AdminConsole.Repositories;
public class OrganizationDomainRepositoryTests
{

View File

@ -4,7 +4,7 @@ using Bit.Core.Enums;
using Bit.Core.Repositories;
using Xunit;
namespace Bit.Infrastructure.IntegrationTest.Repositories;
namespace Bit.Infrastructure.IntegrationTest.AdminConsole.Repositories;
public class OrganizationRepositoryTests
{

Some files were not shown because too many files have changed in this diff Show More