1
0
mirror of https://github.com/bitwarden/server.git synced 2025-04-05 05:00:19 -05:00

[SM-918] Enforce project maximums on import (#3253)

* Refactor MaxProjectsQuery for multiple adds

* Update unit tests

* Add max project enforcement to imports
This commit is contained in:
Thomas Avery 2023-09-07 17:51:35 -05:00 committed by GitHub
parent 2aaef3cf64
commit 4b482f0a34
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 53 additions and 20 deletions

View File

@ -20,7 +20,7 @@ public class MaxProjectsQuery : IMaxProjectsQuery
_projectRepository = projectRepository; _projectRepository = projectRepository;
} }
public async Task<(short? max, bool? atMax)> GetByOrgIdAsync(Guid organizationId) public async Task<(short? max, bool? overMax)> GetByOrgIdAsync(Guid organizationId, int projectsToAdd)
{ {
var org = await _organizationRepository.GetByIdAsync(organizationId); var org = await _organizationRepository.GetByIdAsync(organizationId);
if (org == null) if (org == null)
@ -37,7 +37,7 @@ public class MaxProjectsQuery : IMaxProjectsQuery
if (plan.Type == PlanType.Free) if (plan.Type == PlanType.Free)
{ {
var projects = await _projectRepository.GetProjectCountByOrganizationIdAsync(organizationId); var projects = await _projectRepository.GetProjectCountByOrganizationIdAsync(organizationId);
return projects >= plan.MaxProjects ? (plan.MaxProjects, true) : (plan.MaxProjects, false); return projects + projectsToAdd > plan.MaxProjects ? (plan.MaxProjects, true) : (plan.MaxProjects, false);
} }
return (null, null); return (null, null);

View File

@ -22,7 +22,7 @@ public class MaxProjectsQueryTests
{ {
sutProvider.GetDependency<IOrganizationRepository>().GetByIdAsync(default).ReturnsNull(); sutProvider.GetDependency<IOrganizationRepository>().GetByIdAsync(default).ReturnsNull();
await Assert.ThrowsAsync<NotFoundException>(async () => await sutProvider.Sut.GetByOrgIdAsync(organizationId)); await Assert.ThrowsAsync<NotFoundException>(async () => await sutProvider.Sut.GetByOrgIdAsync(organizationId, 1));
await sutProvider.GetDependency<IProjectRepository>().DidNotReceiveWithAnyArgs() await sutProvider.GetDependency<IProjectRepository>().DidNotReceiveWithAnyArgs()
.GetProjectCountByOrganizationIdAsync(organizationId); .GetProjectCountByOrganizationIdAsync(organizationId);
@ -43,7 +43,7 @@ public class MaxProjectsQueryTests
sutProvider.GetDependency<IOrganizationRepository>().GetByIdAsync(organization.Id).Returns(organization); sutProvider.GetDependency<IOrganizationRepository>().GetByIdAsync(organization.Id).Returns(organization);
await Assert.ThrowsAsync<BadRequestException>( await Assert.ThrowsAsync<BadRequestException>(
async () => await sutProvider.Sut.GetByOrgIdAsync(organization.Id)); async () => await sutProvider.Sut.GetByOrgIdAsync(organization.Id, 1));
await sutProvider.GetDependency<IProjectRepository>().DidNotReceiveWithAnyArgs() await sutProvider.GetDependency<IProjectRepository>().DidNotReceiveWithAnyArgs()
.GetProjectCountByOrganizationIdAsync(organization.Id); .GetProjectCountByOrganizationIdAsync(organization.Id);
@ -60,7 +60,7 @@ public class MaxProjectsQueryTests
organization.PlanType = planType; organization.PlanType = planType;
sutProvider.GetDependency<IOrganizationRepository>().GetByIdAsync(organization.Id).Returns(organization); sutProvider.GetDependency<IOrganizationRepository>().GetByIdAsync(organization.Id).Returns(organization);
var (limit, overLimit) = await sutProvider.Sut.GetByOrgIdAsync(organization.Id); var (limit, overLimit) = await sutProvider.Sut.GetByOrgIdAsync(organization.Id, 1);
Assert.Null(limit); Assert.Null(limit);
Assert.Null(overLimit); Assert.Null(overLimit);
@ -70,13 +70,31 @@ public class MaxProjectsQueryTests
} }
[Theory] [Theory]
[BitAutoData(PlanType.Free, 0, false)] [BitAutoData(PlanType.Free, 0, 1, false)]
[BitAutoData(PlanType.Free, 1, false)] [BitAutoData(PlanType.Free, 1, 1, false)]
[BitAutoData(PlanType.Free, 2, false)] [BitAutoData(PlanType.Free, 2, 1, false)]
[BitAutoData(PlanType.Free, 3, true)] [BitAutoData(PlanType.Free, 3, 1, true)]
[BitAutoData(PlanType.Free, 4, true)] [BitAutoData(PlanType.Free, 4, 1, true)]
[BitAutoData(PlanType.Free, 40, true)] [BitAutoData(PlanType.Free, 40, 1, true)]
public async Task GetByOrgIdAsync_SmFreePlan_Success(PlanType planType, int projects, bool shouldBeAtMax, [BitAutoData(PlanType.Free, 0, 2, false)]
[BitAutoData(PlanType.Free, 1, 2, false)]
[BitAutoData(PlanType.Free, 2, 2, true)]
[BitAutoData(PlanType.Free, 3, 2, true)]
[BitAutoData(PlanType.Free, 4, 2, true)]
[BitAutoData(PlanType.Free, 40, 2, true)]
[BitAutoData(PlanType.Free, 0, 3, false)]
[BitAutoData(PlanType.Free, 1, 3, true)]
[BitAutoData(PlanType.Free, 2, 3, true)]
[BitAutoData(PlanType.Free, 3, 3, true)]
[BitAutoData(PlanType.Free, 4, 3, true)]
[BitAutoData(PlanType.Free, 40, 3, true)]
[BitAutoData(PlanType.Free, 0, 4, true)]
[BitAutoData(PlanType.Free, 1, 4, true)]
[BitAutoData(PlanType.Free, 2, 4, true)]
[BitAutoData(PlanType.Free, 3, 4, true)]
[BitAutoData(PlanType.Free, 4, 4, true)]
[BitAutoData(PlanType.Free, 40, 4, true)]
public async Task GetByOrgIdAsync_SmFreePlan__Success(PlanType planType, int projects, int projectsToAdd, bool expectedOverMax,
SutProvider<MaxProjectsQuery> sutProvider, Organization organization) SutProvider<MaxProjectsQuery> sutProvider, Organization organization)
{ {
organization.PlanType = planType; organization.PlanType = planType;
@ -84,12 +102,12 @@ public class MaxProjectsQueryTests
sutProvider.GetDependency<IProjectRepository>().GetProjectCountByOrganizationIdAsync(organization.Id) sutProvider.GetDependency<IProjectRepository>().GetProjectCountByOrganizationIdAsync(organization.Id)
.Returns(projects); .Returns(projects);
var (max, atMax) = await sutProvider.Sut.GetByOrgIdAsync(organization.Id); var (max, overMax) = await sutProvider.Sut.GetByOrgIdAsync(organization.Id, projectsToAdd);
Assert.NotNull(max); Assert.NotNull(max);
Assert.NotNull(atMax); Assert.NotNull(overMax);
Assert.Equal(3, max.Value); Assert.Equal(3, max.Value);
Assert.Equal(shouldBeAtMax, atMax); Assert.Equal(expectedOverMax, overMax);
await sutProvider.GetDependency<IProjectRepository>().Received(1) await sutProvider.GetDependency<IProjectRepository>().Received(1)
.GetProjectCountByOrganizationIdAsync(organization.Id); .GetProjectCountByOrganizationIdAsync(organization.Id);

View File

@ -79,8 +79,8 @@ public class ProjectsController : Controller
throw new NotFoundException(); throw new NotFoundException();
} }
var (max, atMax) = await _maxProjectsQuery.GetByOrgIdAsync(organizationId); var (max, overMax) = await _maxProjectsQuery.GetByOrgIdAsync(organizationId, 1);
if (atMax != null && atMax.Value) if (overMax != null && overMax.Value)
{ {
throw new BadRequestException($"You have reached the maximum number of projects ({max}) for this plan."); throw new BadRequestException($"You have reached the maximum number of projects ({max}) for this plan.");
} }

View File

@ -4,6 +4,7 @@ using Bit.Core.Context;
using Bit.Core.Enums; using Bit.Core.Enums;
using Bit.Core.Exceptions; using Bit.Core.Exceptions;
using Bit.Core.SecretsManager.Commands.Porting.Interfaces; using Bit.Core.SecretsManager.Commands.Porting.Interfaces;
using Bit.Core.SecretsManager.Queries.Projects.Interfaces;
using Bit.Core.SecretsManager.Repositories; using Bit.Core.SecretsManager.Repositories;
using Bit.Core.Services; using Bit.Core.Services;
using Bit.Core.Utilities; using Bit.Core.Utilities;
@ -19,14 +20,18 @@ public class SecretsManagerPortingController : Controller
private readonly ISecretRepository _secretRepository; private readonly ISecretRepository _secretRepository;
private readonly IProjectRepository _projectRepository; private readonly IProjectRepository _projectRepository;
private readonly IUserService _userService; private readonly IUserService _userService;
private readonly IMaxProjectsQuery _maxProjectsQuery;
private readonly IImportCommand _importCommand; private readonly IImportCommand _importCommand;
private readonly ICurrentContext _currentContext; private readonly ICurrentContext _currentContext;
public SecretsManagerPortingController(ISecretRepository secretRepository, IProjectRepository projectRepository, IUserService userService, IImportCommand importCommand, ICurrentContext currentContext) public SecretsManagerPortingController(ISecretRepository secretRepository, IProjectRepository projectRepository,
IUserService userService, IMaxProjectsQuery maxProjectsQuery, IImportCommand importCommand,
ICurrentContext currentContext)
{ {
_secretRepository = secretRepository; _secretRepository = secretRepository;
_projectRepository = projectRepository; _projectRepository = projectRepository;
_userService = userService; _userService = userService;
_maxProjectsQuery = maxProjectsQuery;
_importCommand = importCommand; _importCommand = importCommand;
_currentContext = currentContext; _currentContext = currentContext;
} }
@ -69,6 +74,16 @@ public class SecretsManagerPortingController : Controller
throw new BadRequestException("A secret can only be in one project at a time."); throw new BadRequestException("A secret can only be in one project at a time.");
} }
var projectsToAdd = importRequest.Projects?.Count();
if (projectsToAdd is > 0)
{
var (max, overMax) = await _maxProjectsQuery.GetByOrgIdAsync(organizationId, projectsToAdd.Value);
if (overMax != null && overMax.Value)
{
throw new BadRequestException($"The maximum number of projects for this plan is ({max}).");
}
}
await _importCommand.ImportAsync(organizationId, importRequest.ToSMImport()); await _importCommand.ImportAsync(organizationId, importRequest.ToSMImport());
} }
} }

View File

@ -2,5 +2,5 @@
public interface IMaxProjectsQuery public interface IMaxProjectsQuery
{ {
Task<(short? max, bool? atMax)> GetByOrgIdAsync(Guid organizationId); Task<(short? max, bool? overMax)> GetByOrgIdAsync(Guid organizationId, int projectsToAdd);
} }

View File

@ -132,7 +132,7 @@ public class ProjectsControllerTests
.AuthorizeAsync(Arg.Any<ClaimsPrincipal>(), data.ToProject(orgId), .AuthorizeAsync(Arg.Any<ClaimsPrincipal>(), data.ToProject(orgId),
Arg.Any<IEnumerable<IAuthorizationRequirement>>()).ReturnsForAnyArgs(AuthorizationResult.Success()); Arg.Any<IEnumerable<IAuthorizationRequirement>>()).ReturnsForAnyArgs(AuthorizationResult.Success());
sutProvider.GetDependency<IUserService>().GetProperUserId(default).ReturnsForAnyArgs(Guid.NewGuid()); sutProvider.GetDependency<IUserService>().GetProperUserId(default).ReturnsForAnyArgs(Guid.NewGuid());
sutProvider.GetDependency<IMaxProjectsQuery>().GetByOrgIdAsync(orgId).Returns(((short)3, true)); sutProvider.GetDependency<IMaxProjectsQuery>().GetByOrgIdAsync(orgId, 1).Returns(((short)3, true));
await Assert.ThrowsAsync<BadRequestException>(() => sutProvider.Sut.CreateAsync(orgId, data)); await Assert.ThrowsAsync<BadRequestException>(() => sutProvider.Sut.CreateAsync(orgId, data));