From 3206b9aee592149559a68dde681aa314910db944 Mon Sep 17 00:00:00 2001 From: Cy Okeke Date: Tue, 20 May 2025 12:20:59 +0100 Subject: [PATCH] Resolve the comment regarding abstraction --- .../src/Sso/Controllers/AccountController.cs | 2 +- .../InviteOrganizationUsersCommand.cs | 3 +- .../v1/RestoreOrganizationUserCommand.cs | 8 +-- .../IOrganizationUserRepository.cs | 2 +- .../Implementations/OrganizationService.cs | 19 +++-- .../OrganizationSeatCounts.cs | 8 +++ .../CreateSponsorshipCommand.cs | 4 +- .../UpgradeOrganizationPlanCommand.cs | 8 +-- .../OrganizationUserRepository.cs | 6 +- .../OrganizationUserRepository.cs | 24 ++++++- .../InviteOrganizationUserCommandTests.cs | 6 +- .../CreateSponsorshipCommandTests.cs | 25 +++++-- ...erReadOccupiedSeatCountForSponsorships.sql | 72 +++++++++++++++++++ 13 files changed, 152 insertions(+), 35 deletions(-) create mode 100644 src/Core/Models/Data/Organizations/OrganizationUsers/OrganizationSeatCounts.cs create mode 100644 util/Migrator/DbScripts/2025-05-20_00_UpdateOrgUserReadOccupiedSeatCountForSponsorships.sql diff --git a/bitwarden_license/src/Sso/Controllers/AccountController.cs b/bitwarden_license/src/Sso/Controllers/AccountController.cs index f41d2d3c65..774b4a4bc0 100644 --- a/bitwarden_license/src/Sso/Controllers/AccountController.cs +++ b/bitwarden_license/src/Sso/Controllers/AccountController.cs @@ -501,7 +501,7 @@ public class AccountController : Controller { var occupiedSeats = await _organizationUserRepository.GetOccupiedSeatCountByOrganizationIdAsync(organization.Id); var initialSeatCount = organization.Seats.Value; - var availableSeats = initialSeatCount - occupiedSeats; + var availableSeats = initialSeatCount - occupiedSeats.Total; if (availableSeats < 1) { try diff --git a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/InviteOrganizationUsersCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/InviteOrganizationUsersCommand.cs index 072bc5fc05..22ca906143 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/InviteOrganizationUsersCommand.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/InviteOrganizationUsersCommand.cs @@ -87,13 +87,14 @@ public class InviteOrganizationUsersCommand(IEventService eventService, new InviteOrganizationUsersResponse(request.InviteOrganization.OrganizationId))); } + var seatCounts = await organizationUserRepository.GetOccupiedSeatCountByOrganizationIdAsync(request.InviteOrganization.OrganizationId); var validationResult = await inviteUsersValidator.ValidateAsync(new InviteOrganizationUsersValidationRequest { Invites = invitesToSend.ToArray(), InviteOrganization = request.InviteOrganization, PerformedBy = request.PerformedBy, PerformedAt = request.PerformedAt, - OccupiedPmSeats = await organizationUserRepository.GetOccupiedSeatCountByOrganizationIdAsync(request.InviteOrganization.OrganizationId), + OccupiedPmSeats = seatCounts.Total, OccupiedSmSeats = await organizationUserRepository.GetOccupiedSmSeatCountByOrganizationIdAsync(request.InviteOrganization.OrganizationId) }); diff --git a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RestoreUser/v1/RestoreOrganizationUserCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RestoreUser/v1/RestoreOrganizationUserCommand.cs index 74165a5a71..54a7ce3be8 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RestoreUser/v1/RestoreOrganizationUserCommand.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RestoreUser/v1/RestoreOrganizationUserCommand.cs @@ -66,8 +66,8 @@ public class RestoreOrganizationUserCommand( } var organization = await organizationRepository.GetByIdAsync(organizationUser.OrganizationId); - var occupiedSeats = await organizationUserRepository.GetOccupiedSeatCountByOrganizationIdAsync(organization.Id); - var availableSeats = organization.Seats.GetValueOrDefault(0) - occupiedSeats; + var seatCounts = await organizationUserRepository.GetOccupiedSeatCountByOrganizationIdAsync(organization.Id); + var availableSeats = organization.Seats.GetValueOrDefault(0) - seatCounts.Total; if (availableSeats < 1) { @@ -159,8 +159,8 @@ public class RestoreOrganizationUserCommand( } var organization = await organizationRepository.GetByIdAsync(organizationId); - var occupiedSeats = await organizationUserRepository.GetOccupiedSeatCountByOrganizationIdAsync(organization.Id); - var availableSeats = organization.Seats.GetValueOrDefault(0) - occupiedSeats; + var seatCounts = await organizationUserRepository.GetOccupiedSeatCountByOrganizationIdAsync(organization.Id); + var availableSeats = organization.Seats.GetValueOrDefault(0) - seatCounts.Total; var newSeatsRequired = organizationUserIds.Count() - availableSeats; await organizationService.AutoAddSeatsAsync(organization, newSeatsRequired); diff --git a/src/Core/AdminConsole/Repositories/IOrganizationUserRepository.cs b/src/Core/AdminConsole/Repositories/IOrganizationUserRepository.cs index 9692de897c..bcd562203f 100644 --- a/src/Core/AdminConsole/Repositories/IOrganizationUserRepository.cs +++ b/src/Core/AdminConsole/Repositories/IOrganizationUserRepository.cs @@ -27,7 +27,7 @@ public interface IOrganizationUserRepository : IRepository /// The ID of the organization to get the occupied seat count for. /// The number of occupied seats for the organization. - Task GetOccupiedSeatCountByOrganizationIdAsync(Guid organizationId); + Task GetOccupiedSeatCountByOrganizationIdAsync(Guid organizationId); Task> SelectKnownEmailsAsync(Guid organizationId, IEnumerable emails, bool onlyRegisteredUsers); Task GetByOrganizationAsync(Guid organizationId, Guid userId); Task>> GetByIdWithCollectionsAsync(Guid id); diff --git a/src/Core/AdminConsole/Services/Implementations/OrganizationService.cs b/src/Core/AdminConsole/Services/Implementations/OrganizationService.cs index 320fbdfa01..6b647f2f50 100644 --- a/src/Core/AdminConsole/Services/Implementations/OrganizationService.cs +++ b/src/Core/AdminConsole/Services/Implementations/OrganizationService.cs @@ -342,15 +342,12 @@ public class OrganizationService : IOrganizationService if (!organization.Seats.HasValue || organization.Seats.Value > newSeatTotal) { - var totalConsumedSeats = await _organizationUserRepository.GetOccupiedSeatCountByOrganizationIdAsync(organization.Id); - var organizationUsers = await _organizationUserRepository.GetManyByOrganizationAsync(organization.Id, null); - var organizationUserOccupiedSeats = organizationUsers.Where(user => user.Status >= 0).Count(); - var sponsoredFamiliesOccupiedSeats = totalConsumedSeats - organizationUserOccupiedSeats; + var seatCounts = await _organizationUserRepository.GetOccupiedSeatCountByOrganizationIdAsync(organization.Id); - if (totalConsumedSeats > newSeatTotal) + if (seatCounts.Total > newSeatTotal) { - throw new BadRequestException($"Your organization has {organizationUserOccupiedSeats} members and {sponsoredFamiliesOccupiedSeats} sponsored families. " + - $"To decrease the seat count below {totalConsumedSeats}, you must remove members or sponsorships."); + throw new BadRequestException($"Your organization has {seatCounts.Users} members and {seatCounts.Sponsored} sponsored families. " + + $"To decrease the seat count below {seatCounts.Total}, you must remove members or sponsorships."); } } @@ -846,8 +843,8 @@ public class OrganizationService : IOrganizationService var newSeatsRequired = 0; if (organization.Seats.HasValue) { - var occupiedSeats = await _organizationUserRepository.GetOccupiedSeatCountByOrganizationIdAsync(organization.Id); - var availableSeats = organization.Seats.Value - occupiedSeats; + var seatCounts = await _organizationUserRepository.GetOccupiedSeatCountByOrganizationIdAsync(organization.Id); + var availableSeats = organization.Seats.Value - seatCounts.Total; newSeatsRequired = invites.Sum(i => i.invite.Emails.Count()) - existingEmails.Count() - availableSeats; } @@ -1303,8 +1300,8 @@ public class OrganizationService : IOrganizationService var enoughSeatsAvailable = true; if (organization.Seats.HasValue) { - var occupiedSeats = await _organizationUserRepository.GetOccupiedSeatCountByOrganizationIdAsync(organization.Id); - seatsAvailable = organization.Seats.Value - occupiedSeats; + var seatCounts = await _organizationUserRepository.GetOccupiedSeatCountByOrganizationIdAsync(organization.Id); + seatsAvailable = organization.Seats.Value - seatCounts.Total; enoughSeatsAvailable = seatsAvailable >= usersToAdd.Count; } diff --git a/src/Core/Models/Data/Organizations/OrganizationUsers/OrganizationSeatCounts.cs b/src/Core/Models/Data/Organizations/OrganizationUsers/OrganizationSeatCounts.cs new file mode 100644 index 0000000000..6b9f615f64 --- /dev/null +++ b/src/Core/Models/Data/Organizations/OrganizationUsers/OrganizationSeatCounts.cs @@ -0,0 +1,8 @@ +namespace Bit.Core.Models.Data.Organizations.OrganizationUsers; + +public class OrganizationSeatCounts +{ + public int Users { get; set; } + public int Sponsored { get; set; } + public int Total => Users + Sponsored; +} diff --git a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/CreateSponsorshipCommand.cs b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/CreateSponsorshipCommand.cs index b15cbea240..083f41c01c 100644 --- a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/CreateSponsorshipCommand.cs +++ b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/CreateSponsorshipCommand.cs @@ -89,8 +89,8 @@ public class CreateSponsorshipCommand( if (isAdminInitiated && sponsoringOrganization.Seats.HasValue) { - var occupiedSeats = await organizationUserRepository.GetOccupiedSeatCountByOrganizationIdAsync(sponsoringOrganization.Id); - var availableSeats = sponsoringOrganization.Seats.Value - occupiedSeats; + var seatCounts = await organizationUserRepository.GetOccupiedSeatCountByOrganizationIdAsync(sponsoringOrganization.Id); + var availableSeats = sponsoringOrganization.Seats.Value - seatCounts.Total; if (availableSeats <= 0) { diff --git a/src/Core/OrganizationFeatures/OrganizationSubscriptions/UpgradeOrganizationPlanCommand.cs b/src/Core/OrganizationFeatures/OrganizationSubscriptions/UpgradeOrganizationPlanCommand.cs index cb37e478f7..716efa2656 100644 --- a/src/Core/OrganizationFeatures/OrganizationSubscriptions/UpgradeOrganizationPlanCommand.cs +++ b/src/Core/OrganizationFeatures/OrganizationSubscriptions/UpgradeOrganizationPlanCommand.cs @@ -117,12 +117,12 @@ public class UpgradeOrganizationPlanCommand : IUpgradeOrganizationPlanCommand (newPlan.PasswordManager.HasAdditionalSeatsOption ? upgrade.AdditionalSeats : 0)); if (!organization.Seats.HasValue || organization.Seats.Value > updatedPasswordManagerSeats) { - var occupiedSeats = + var seatCounts = await _organizationUserRepository.GetOccupiedSeatCountByOrganizationIdAsync(organization.Id); - if (occupiedSeats > updatedPasswordManagerSeats) + if (seatCounts.Total > updatedPasswordManagerSeats) { - throw new BadRequestException($"Your organization currently has {occupiedSeats} seats filled. " + - $"Your new plan only has ({updatedPasswordManagerSeats}) seats. Remove some users."); + throw new BadRequestException($"Your organization has {seatCounts.Users} members and {seatCounts.Sponsored} sponsored families. " + + $"To decrease the seat count below {seatCounts.Total}, you must remove members or sponsorships."); } } diff --git a/src/Infrastructure.Dapper/AdminConsole/Repositories/OrganizationUserRepository.cs b/src/Infrastructure.Dapper/AdminConsole/Repositories/OrganizationUserRepository.cs index 8968d1d243..bfd5e2170d 100644 --- a/src/Infrastructure.Dapper/AdminConsole/Repositories/OrganizationUserRepository.cs +++ b/src/Infrastructure.Dapper/AdminConsole/Repositories/OrganizationUserRepository.cs @@ -88,16 +88,16 @@ public class OrganizationUserRepository : Repository, IO } } - public async Task GetOccupiedSeatCountByOrganizationIdAsync(Guid organizationId) + public async Task GetOccupiedSeatCountByOrganizationIdAsync(Guid organizationId) { using (var connection = new SqlConnection(ConnectionString)) { - var result = await connection.ExecuteScalarAsync( + var result = await connection.QueryFirstOrDefaultAsync( "[dbo].[OrganizationUser_ReadOccupiedSeatCountByOrganizationId]", new { OrganizationId = organizationId }, commandType: CommandType.StoredProcedure); - return result; + return result ?? new OrganizationSeatCounts(); } } diff --git a/src/Infrastructure.EntityFramework/AdminConsole/Repositories/OrganizationUserRepository.cs b/src/Infrastructure.EntityFramework/AdminConsole/Repositories/OrganizationUserRepository.cs index 10d92357fe..82b6e6f776 100644 --- a/src/Infrastructure.EntityFramework/AdminConsole/Repositories/OrganizationUserRepository.cs +++ b/src/Infrastructure.EntityFramework/AdminConsole/Repositories/OrganizationUserRepository.cs @@ -227,10 +227,28 @@ public class OrganizationUserRepository : Repository GetOccupiedSeatCountByOrganizationIdAsync(Guid organizationId) + public async Task GetOccupiedSeatCountByOrganizationIdAsync(Guid organizationId) { - var query = new OrganizationUserReadOccupiedSeatCountByOrganizationIdQuery(organizationId); - return await GetCountFromQuery(query); + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var users = await dbContext.OrganizationUsers + .Where(ou => ou.OrganizationId == organizationId && ou.Status >= 0) + .CountAsync(); + + var sponsored = await dbContext.OrganizationSponsorships + .Where(os => os.SponsoringOrganizationId == organizationId && + os.IsAdminInitiated && + (os.ToDelete == false || (os.ToDelete == true && os.ValidUntil != null && os.ValidUntil > DateTime.UtcNow)) && + (os.SponsoredOrganizationId == null || (os.SponsoredOrganizationId != null && (os.ValidUntil == null || os.ValidUntil > DateTime.UtcNow)))) + .CountAsync(); + + return new OrganizationSeatCounts + { + Users = users, + Sponsored = sponsored + }; + } } public async Task GetCountByOrganizationIdAsync(Guid organizationId) diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/InviteOrganizationUserCommandTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/InviteOrganizationUserCommandTests.cs index e54e4aa99b..aa567b2951 100644 --- a/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/InviteOrganizationUserCommandTests.cs +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/InviteOrganizationUserCommandTests.cs @@ -476,7 +476,11 @@ public class InviteOrganizationUserCommandTests orgUserRepository .GetManyByMinimumRoleAsync(inviteOrganization.OrganizationId, OrganizationUserType.Owner) .Returns([ownerDetails]); - orgUserRepository.GetOccupiedSeatCountByOrganizationIdAsync(organization.Id).Returns(1); + orgUserRepository.GetOccupiedSeatCountByOrganizationIdAsync(organization.Id).Returns(new OrganizationSeatCounts + { + Sponsored = 0, + Users = 1 + }); orgUserRepository.GetOccupiedSmSeatCountByOrganizationIdAsync(organization.Id).Returns(1); var orgRepository = sutProvider.GetDependency(); diff --git a/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/CreateSponsorshipCommandTests.cs b/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/CreateSponsorshipCommandTests.cs index 7dc6b7360d..cc69c95d97 100644 --- a/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/CreateSponsorshipCommandTests.cs +++ b/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/CreateSponsorshipCommandTests.cs @@ -5,6 +5,7 @@ using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Exceptions; using Bit.Core.Models.Data; +using Bit.Core.Models.Data.Organizations.OrganizationUsers; using Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise; using Bit.Core.Repositories; using Bit.Core.Services; @@ -171,7 +172,11 @@ public class CreateSponsorshipCommandTests : FamiliesForEnterpriseTestsBase // Setup for checking available seats sutProvider.GetDependency() .GetOccupiedSeatCountByOrganizationIdAsync(sponsoringOrg.Id) - .Returns(0); + .Returns(new OrganizationSeatCounts + { + Sponsored = 0, + Users = 0 + }); await sutProvider.Sut.CreateSponsorshipAsync(sponsoringOrg, sponsoringOrgUser, @@ -320,7 +325,11 @@ public class CreateSponsorshipCommandTests : FamiliesForEnterpriseTestsBase // Setup for checking available seats - organization has plenty of seats sutProvider.GetDependency() .GetOccupiedSeatCountByOrganizationIdAsync(sponsoringOrg.Id) - .Returns(5); + .Returns(new OrganizationSeatCounts + { + Sponsored = 0, + Users = 5 + }); var actual = await sutProvider.Sut.CreateSponsorshipAsync(sponsoringOrg, sponsoringOrgUser, PlanSponsorshipType.FamiliesForEnterprise, sponsoredEmail, friendlyName, true, notes); @@ -380,7 +389,11 @@ public class CreateSponsorshipCommandTests : FamiliesForEnterpriseTestsBase // Setup for checking available seats - organization has no available seats sutProvider.GetDependency() .GetOccupiedSeatCountByOrganizationIdAsync(sponsoringOrg.Id) - .Returns(10); + .Returns(new OrganizationSeatCounts + { + Sponsored = 0, + Users = 10 + }); // Setup for checking if can scale sutProvider.GetDependency() @@ -445,7 +458,11 @@ public class CreateSponsorshipCommandTests : FamiliesForEnterpriseTestsBase // Setup for checking available seats - organization has no available seats sutProvider.GetDependency() .GetOccupiedSeatCountByOrganizationIdAsync(sponsoringOrg.Id) - .Returns(10); + .Returns(new OrganizationSeatCounts + { + Sponsored = 0, + Users = 10 + }); // Setup for checking if can scale - cannot scale var failureReason = "Seat limit has been reached."; diff --git a/util/Migrator/DbScripts/2025-05-20_00_UpdateOrgUserReadOccupiedSeatCountForSponsorships.sql b/util/Migrator/DbScripts/2025-05-20_00_UpdateOrgUserReadOccupiedSeatCountForSponsorships.sql new file mode 100644 index 0000000000..8315d6bc33 --- /dev/null +++ b/util/Migrator/DbScripts/2025-05-20_00_UpdateOrgUserReadOccupiedSeatCountForSponsorships.sql @@ -0,0 +1,72 @@ +IF OBJECT_ID('[dbo].[OrganizationUser_ReadOccupiedSeatCountByOrganizationId]') IS NOT NULL +BEGIN + DROP PROCEDURE [dbo].[OrganizationUser_ReadOccupiedSeatCountByOrganizationId] +END +GO + +CREATE PROCEDURE [dbo].[OrganizationUser_ReadOccupiedSeatCountByOrganizationId] + @OrganizationId UNIQUEIDENTIFIER +AS +BEGIN + SET NOCOUNT ON + + SELECT + ( + -- Count organization users + SELECT COUNT(1) + FROM [dbo].[OrganizationUserView] + WHERE OrganizationId = @OrganizationId + AND Status >= 0 --Invited + ) as Users, + ( + -- Count admin-initiated sponsorships towards the seat count + -- Introduced in https://bitwarden.atlassian.net/browse/PM-17772 + SELECT COUNT(1) + FROM [dbo].[OrganizationSponsorship] + WHERE SponsoringOrganizationId = @OrganizationId + AND IsAdminInitiated = 1 + AND ( + -- Not marked for deletion - always count + (ToDelete = 0) + OR + -- Marked for deletion but has a valid until date in the future (RevokeWhenExpired status) + (ToDelete = 1 AND ValidUntil IS NOT NULL AND ValidUntil > GETUTCDATE()) + ) + AND ( + -- SENT status: When SponsoredOrganizationId is null + SponsoredOrganizationId IS NULL + OR + -- ACCEPTED status: When SponsoredOrganizationId is not null and ValidUntil is null or in the future + (SponsoredOrganizationId IS NOT NULL AND (ValidUntil IS NULL OR ValidUntil > GETUTCDATE())) + ) + ) as Sponsored, + ( + -- Count organization users + SELECT COUNT(1) + FROM [dbo].[OrganizationUserView] + WHERE OrganizationId = @OrganizationId + AND Status >= 0 --Invited + ) + + ( + -- Count admin-initiated sponsorships towards the seat count + SELECT COUNT(1) + FROM [dbo].[OrganizationSponsorship] + WHERE SponsoringOrganizationId = @OrganizationId + AND IsAdminInitiated = 1 + AND ( + -- Not marked for deletion - always count + (ToDelete = 0) + OR + -- Marked for deletion but has a valid until date in the future (RevokeWhenExpired status) + (ToDelete = 1 AND ValidUntil IS NOT NULL AND ValidUntil > GETUTCDATE()) + ) + AND ( + -- SENT status: When SponsoredOrganizationId is null + SponsoredOrganizationId IS NULL + OR + -- ACCEPTED status: When SponsoredOrganizationId is not null and ValidUntil is null or in the future + (SponsoredOrganizationId IS NOT NULL AND (ValidUntil IS NULL OR ValidUntil > GETUTCDATE())) + ) + ) as Total +END +GO