1
0
mirror of https://github.com/bitwarden/server.git synced 2025-04-04 20:50:21 -05:00

Revert filescoped (#2227)

* Revert "Add git blame entry (#2226)"

This reverts commit 239286737d15cb84a893703ee5a8b33a2d67ad3d.

* Revert "Turn on file scoped namespaces (#2225)"

This reverts commit 34fb4cca2aa78deb84d4cbc359992a7c6bba7ea5.
This commit is contained in:
Justin Baur 2022-08-29 15:53:48 -04:00 committed by GitHub
parent 239286737d
commit bae03feffe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1208 changed files with 74317 additions and 73126 deletions

View File

@ -114,9 +114,6 @@ csharp_new_line_before_finally = true
csharp_new_line_before_members_in_object_initializers = true csharp_new_line_before_members_in_object_initializers = true
csharp_new_line_before_members_in_anonymous_types = true csharp_new_line_before_members_in_anonymous_types = true
# Namespace settigns
csharp_style_namespace_declarations = file_scoped:warning
# All files # All files
[*] [*]
guidelines = 120 guidelines = 120

View File

@ -1,5 +1,2 @@
# Apply .NET format https://github.com/bitwarden/server/pull/1764 # Apply .NET format https://github.com/bitwarden/server/pull/1764
23b0a1f9df25058ab29785ecad9a233113c10889 23b0a1f9df25058ab29785ecad9a233113c10889
# Turn on file scoped namespaces https://github.com/bitwarden/server/pull/2225
34fb4cca2aa78deb84d4cbc359992a7c6bba7ea5

View File

@ -84,15 +84,3 @@ We recently migrated to using dotnet-format as code formatter. All previous bran
5. Commit 5. Commit
6. Run `git merge -Xours 23b0a1f9df25058ab29785ecad9a233113c10889` 6. Run `git merge -Xours 23b0a1f9df25058ab29785ecad9a233113c10889`
7. Push 7. Push
### File Scoped Namespaces
We have switched to using file scoped namespace. All previous branches will need to update to avoid large merge conflicts using the following steps:
1. Check out your local Branch
1. Run `git merge 7c4521e0b428d523f2153cda3fb51d51bca9f194`
2. Resolve any merge conflicts, commit.
3. Run `dotnet format`
4. Commit
5. Run `git merge -Xours 34fb4cca2aa78deb84d4cbc359992a7c6bba7ea5`
6. Resolve merge conflicts
7. Push

View File

@ -13,496 +13,497 @@ using Bit.Core.Settings;
using Bit.Core.Utilities; using Bit.Core.Utilities;
using Microsoft.AspNetCore.DataProtection; using Microsoft.AspNetCore.DataProtection;
namespace Bit.Commercial.Core.Services; namespace Bit.Commercial.Core.Services
public class ProviderService : IProviderService
{ {
public static PlanType[] ProviderDisllowedOrganizationTypes = new[] { PlanType.Free, PlanType.FamiliesAnnually, PlanType.FamiliesAnnually2019 }; public class ProviderService : IProviderService
private readonly IDataProtector _dataProtector;
private readonly IMailService _mailService;
private readonly IEventService _eventService;
private readonly GlobalSettings _globalSettings;
private readonly IProviderRepository _providerRepository;
private readonly IProviderUserRepository _providerUserRepository;
private readonly IProviderOrganizationRepository _providerOrganizationRepository;
private readonly IOrganizationRepository _organizationRepository;
private readonly IUserRepository _userRepository;
private readonly IUserService _userService;
private readonly IOrganizationService _organizationService;
private readonly ICurrentContext _currentContext;
public ProviderService(IProviderRepository providerRepository, IProviderUserRepository providerUserRepository,
IProviderOrganizationRepository providerOrganizationRepository, IUserRepository userRepository,
IUserService userService, IOrganizationService organizationService, IMailService mailService,
IDataProtectionProvider dataProtectionProvider, IEventService eventService,
IOrganizationRepository organizationRepository, GlobalSettings globalSettings,
ICurrentContext currentContext)
{ {
_providerRepository = providerRepository; public static PlanType[] ProviderDisllowedOrganizationTypes = new[] { PlanType.Free, PlanType.FamiliesAnnually, PlanType.FamiliesAnnually2019 };
_providerUserRepository = providerUserRepository;
_providerOrganizationRepository = providerOrganizationRepository;
_organizationRepository = organizationRepository;
_userRepository = userRepository;
_userService = userService;
_organizationService = organizationService;
_mailService = mailService;
_eventService = eventService;
_globalSettings = globalSettings;
_dataProtector = dataProtectionProvider.CreateProtector("ProviderServiceDataProtector");
_currentContext = currentContext;
}
public async Task CreateAsync(string ownerEmail) private readonly IDataProtector _dataProtector;
{ private readonly IMailService _mailService;
var owner = await _userRepository.GetByEmailAsync(ownerEmail); private readonly IEventService _eventService;
if (owner == null) private readonly GlobalSettings _globalSettings;
private readonly IProviderRepository _providerRepository;
private readonly IProviderUserRepository _providerUserRepository;
private readonly IProviderOrganizationRepository _providerOrganizationRepository;
private readonly IOrganizationRepository _organizationRepository;
private readonly IUserRepository _userRepository;
private readonly IUserService _userService;
private readonly IOrganizationService _organizationService;
private readonly ICurrentContext _currentContext;
public ProviderService(IProviderRepository providerRepository, IProviderUserRepository providerUserRepository,
IProviderOrganizationRepository providerOrganizationRepository, IUserRepository userRepository,
IUserService userService, IOrganizationService organizationService, IMailService mailService,
IDataProtectionProvider dataProtectionProvider, IEventService eventService,
IOrganizationRepository organizationRepository, GlobalSettings globalSettings,
ICurrentContext currentContext)
{ {
throw new BadRequestException("Invalid owner. Owner must be an existing Bitwarden user."); _providerRepository = providerRepository;
_providerUserRepository = providerUserRepository;
_providerOrganizationRepository = providerOrganizationRepository;
_organizationRepository = organizationRepository;
_userRepository = userRepository;
_userService = userService;
_organizationService = organizationService;
_mailService = mailService;
_eventService = eventService;
_globalSettings = globalSettings;
_dataProtector = dataProtectionProvider.CreateProtector("ProviderServiceDataProtector");
_currentContext = currentContext;
} }
var provider = new Provider public async Task CreateAsync(string ownerEmail)
{ {
Status = ProviderStatusType.Pending, var owner = await _userRepository.GetByEmailAsync(ownerEmail);
Enabled = true, if (owner == null)
UseEvents = true,
};
await _providerRepository.CreateAsync(provider);
var providerUser = new ProviderUser
{
ProviderId = provider.Id,
UserId = owner.Id,
Type = ProviderUserType.ProviderAdmin,
Status = ProviderUserStatusType.Confirmed,
};
await _providerUserRepository.CreateAsync(providerUser);
await SendProviderSetupInviteEmailAsync(provider, owner.Email);
}
public async Task<Provider> CompleteSetupAsync(Provider provider, Guid ownerUserId, string token, string key)
{
var owner = await _userService.GetUserByIdAsync(ownerUserId);
if (owner == null)
{
throw new BadRequestException("Invalid owner.");
}
if (provider.Status != ProviderStatusType.Pending)
{
throw new BadRequestException("Provider is already setup.");
}
if (!CoreHelpers.TokenIsValid("ProviderSetupInvite", _dataProtector, token, owner.Email, provider.Id,
_globalSettings.OrganizationInviteExpirationHours))
{
throw new BadRequestException("Invalid token.");
}
var providerUser = await _providerUserRepository.GetByProviderUserAsync(provider.Id, ownerUserId);
if (!(providerUser is { Type: ProviderUserType.ProviderAdmin }))
{
throw new BadRequestException("Invalid owner.");
}
provider.Status = ProviderStatusType.Created;
await _providerRepository.UpsertAsync(provider);
providerUser.Key = key;
await _providerUserRepository.ReplaceAsync(providerUser);
return provider;
}
public async Task UpdateAsync(Provider provider, bool updateBilling = false)
{
if (provider.Id == default)
{
throw new ArgumentException("Cannot create provider this way.");
}
await _providerRepository.ReplaceAsync(provider);
}
public async Task<List<ProviderUser>> InviteUserAsync(ProviderUserInvite<string> invite)
{
if (!_currentContext.ProviderManageUsers(invite.ProviderId))
{
throw new InvalidOperationException("Invalid permissions.");
}
var emails = invite?.UserIdentifiers;
var invitingUser = await _providerUserRepository.GetByProviderUserAsync(invite.ProviderId, invite.InvitingUserId);
var provider = await _providerRepository.GetByIdAsync(invite.ProviderId);
if (provider == null || emails == null || !emails.Any())
{
throw new NotFoundException();
}
var providerUsers = new List<ProviderUser>();
foreach (var email in emails)
{
// Make sure user is not already invited
var existingProviderUserCount =
await _providerUserRepository.GetCountByProviderAsync(invite.ProviderId, email, false);
if (existingProviderUserCount > 0)
{ {
continue; throw new BadRequestException("Invalid owner. Owner must be an existing Bitwarden user.");
} }
var provider = new Provider
{
Status = ProviderStatusType.Pending,
Enabled = true,
UseEvents = true,
};
await _providerRepository.CreateAsync(provider);
var providerUser = new ProviderUser var providerUser = new ProviderUser
{ {
ProviderId = invite.ProviderId, ProviderId = provider.Id,
UserId = null, UserId = owner.Id,
Email = email.ToLowerInvariant(), Type = ProviderUserType.ProviderAdmin,
Key = null, Status = ProviderUserStatusType.Confirmed,
Type = invite.Type, };
Status = ProviderUserStatusType.Invited, await _providerUserRepository.CreateAsync(providerUser);
CreationDate = DateTime.UtcNow, await SendProviderSetupInviteEmailAsync(provider, owner.Email);
RevisionDate = DateTime.UtcNow, }
public async Task<Provider> CompleteSetupAsync(Provider provider, Guid ownerUserId, string token, string key)
{
var owner = await _userService.GetUserByIdAsync(ownerUserId);
if (owner == null)
{
throw new BadRequestException("Invalid owner.");
}
if (provider.Status != ProviderStatusType.Pending)
{
throw new BadRequestException("Provider is already setup.");
}
if (!CoreHelpers.TokenIsValid("ProviderSetupInvite", _dataProtector, token, owner.Email, provider.Id,
_globalSettings.OrganizationInviteExpirationHours))
{
throw new BadRequestException("Invalid token.");
}
var providerUser = await _providerUserRepository.GetByProviderUserAsync(provider.Id, ownerUserId);
if (!(providerUser is { Type: ProviderUserType.ProviderAdmin }))
{
throw new BadRequestException("Invalid owner.");
}
provider.Status = ProviderStatusType.Created;
await _providerRepository.UpsertAsync(provider);
providerUser.Key = key;
await _providerUserRepository.ReplaceAsync(providerUser);
return provider;
}
public async Task UpdateAsync(Provider provider, bool updateBilling = false)
{
if (provider.Id == default)
{
throw new ArgumentException("Cannot create provider this way.");
}
await _providerRepository.ReplaceAsync(provider);
}
public async Task<List<ProviderUser>> InviteUserAsync(ProviderUserInvite<string> invite)
{
if (!_currentContext.ProviderManageUsers(invite.ProviderId))
{
throw new InvalidOperationException("Invalid permissions.");
}
var emails = invite?.UserIdentifiers;
var invitingUser = await _providerUserRepository.GetByProviderUserAsync(invite.ProviderId, invite.InvitingUserId);
var provider = await _providerRepository.GetByIdAsync(invite.ProviderId);
if (provider == null || emails == null || !emails.Any())
{
throw new NotFoundException();
}
var providerUsers = new List<ProviderUser>();
foreach (var email in emails)
{
// Make sure user is not already invited
var existingProviderUserCount =
await _providerUserRepository.GetCountByProviderAsync(invite.ProviderId, email, false);
if (existingProviderUserCount > 0)
{
continue;
}
var providerUser = new ProviderUser
{
ProviderId = invite.ProviderId,
UserId = null,
Email = email.ToLowerInvariant(),
Key = null,
Type = invite.Type,
Status = ProviderUserStatusType.Invited,
CreationDate = DateTime.UtcNow,
RevisionDate = DateTime.UtcNow,
};
await _providerUserRepository.CreateAsync(providerUser);
await SendInviteAsync(providerUser, provider);
providerUsers.Add(providerUser);
}
await _eventService.LogProviderUsersEventAsync(providerUsers.Select(pu => (pu, EventType.ProviderUser_Invited, null as DateTime?)));
return providerUsers;
}
public async Task<List<Tuple<ProviderUser, string>>> ResendInvitesAsync(ProviderUserInvite<Guid> invite)
{
if (!_currentContext.ProviderManageUsers(invite.ProviderId))
{
throw new BadRequestException("Invalid permissions.");
}
var providerUsers = await _providerUserRepository.GetManyAsync(invite.UserIdentifiers);
var provider = await _providerRepository.GetByIdAsync(invite.ProviderId);
var result = new List<Tuple<ProviderUser, string>>();
foreach (var providerUser in providerUsers)
{
if (providerUser.Status != ProviderUserStatusType.Invited || providerUser.ProviderId != invite.ProviderId)
{
result.Add(Tuple.Create(providerUser, "User invalid."));
continue;
}
await SendInviteAsync(providerUser, provider);
result.Add(Tuple.Create(providerUser, ""));
}
return result;
}
public async Task<ProviderUser> AcceptUserAsync(Guid providerUserId, User user, string token)
{
var providerUser = await _providerUserRepository.GetByIdAsync(providerUserId);
if (providerUser == null)
{
throw new BadRequestException("User invalid.");
}
if (providerUser.Status != ProviderUserStatusType.Invited)
{
throw new BadRequestException("Already accepted.");
}
if (!CoreHelpers.TokenIsValid("ProviderUserInvite", _dataProtector, token, user.Email, providerUser.Id,
_globalSettings.OrganizationInviteExpirationHours))
{
throw new BadRequestException("Invalid token.");
}
if (string.IsNullOrWhiteSpace(providerUser.Email) ||
!providerUser.Email.Equals(user.Email, StringComparison.InvariantCultureIgnoreCase))
{
throw new BadRequestException("User email does not match invite.");
}
providerUser.Status = ProviderUserStatusType.Accepted;
providerUser.UserId = user.Id;
providerUser.Email = null;
await _providerUserRepository.ReplaceAsync(providerUser);
return providerUser;
}
public async Task<List<Tuple<ProviderUser, string>>> ConfirmUsersAsync(Guid providerId, Dictionary<Guid, string> keys,
Guid confirmingUserId)
{
var providerUsers = await _providerUserRepository.GetManyAsync(keys.Keys);
var validProviderUsers = providerUsers
.Where(u => u.UserId != null)
.ToList();
if (!validProviderUsers.Any())
{
return new List<Tuple<ProviderUser, string>>();
}
var validOrganizationUserIds = validProviderUsers.Select(u => u.UserId.Value).ToList();
var provider = await _providerRepository.GetByIdAsync(providerId);
var users = await _userRepository.GetManyAsync(validOrganizationUserIds);
var keyedFilteredUsers = validProviderUsers.ToDictionary(u => u.UserId.Value, u => u);
var result = new List<Tuple<ProviderUser, string>>();
var events = new List<(ProviderUser, EventType, DateTime?)>();
foreach (var user in users)
{
if (!keyedFilteredUsers.ContainsKey(user.Id))
{
continue;
}
var providerUser = keyedFilteredUsers[user.Id];
try
{
if (providerUser.Status != ProviderUserStatusType.Accepted || providerUser.ProviderId != providerId)
{
throw new BadRequestException("Invalid user.");
}
providerUser.Status = ProviderUserStatusType.Confirmed;
providerUser.Key = keys[providerUser.Id];
providerUser.Email = null;
await _providerUserRepository.ReplaceAsync(providerUser);
events.Add((providerUser, EventType.ProviderUser_Confirmed, null));
await _mailService.SendProviderConfirmedEmailAsync(provider.Name, user.Email);
result.Add(Tuple.Create(providerUser, ""));
}
catch (BadRequestException e)
{
result.Add(Tuple.Create(providerUser, e.Message));
}
}
await _eventService.LogProviderUsersEventAsync(events);
return result;
}
public async Task SaveUserAsync(ProviderUser user, Guid savingUserId)
{
if (user.Id.Equals(default))
{
throw new BadRequestException("Invite the user first.");
}
if (user.Type != ProviderUserType.ProviderAdmin &&
!await HasConfirmedProviderAdminExceptAsync(user.ProviderId, new[] { user.Id }))
{
throw new BadRequestException("Provider must have at least one confirmed ProviderAdmin.");
}
await _providerUserRepository.ReplaceAsync(user);
await _eventService.LogProviderUserEventAsync(user, EventType.ProviderUser_Updated);
}
public async Task<List<Tuple<ProviderUser, string>>> DeleteUsersAsync(Guid providerId,
IEnumerable<Guid> providerUserIds, Guid deletingUserId)
{
var provider = await _providerRepository.GetByIdAsync(providerId);
if (provider == null)
{
throw new NotFoundException();
}
var providerUsers = await _providerUserRepository.GetManyAsync(providerUserIds);
var users = await _userRepository.GetManyAsync(providerUsers.Where(pu => pu.UserId.HasValue)
.Select(pu => pu.UserId.Value));
var keyedUsers = users.ToDictionary(u => u.Id);
if (!await HasConfirmedProviderAdminExceptAsync(providerId, providerUserIds))
{
throw new BadRequestException("Provider must have at least one confirmed ProviderAdmin.");
}
var result = new List<Tuple<ProviderUser, string>>();
var deletedUserIds = new List<Guid>();
var events = new List<(ProviderUser, EventType, DateTime?)>();
foreach (var providerUser in providerUsers)
{
try
{
if (providerUser.ProviderId != providerId)
{
throw new BadRequestException("Invalid user.");
}
if (providerUser.UserId == deletingUserId)
{
throw new BadRequestException("You cannot remove yourself.");
}
events.Add((providerUser, EventType.ProviderUser_Removed, null));
var user = keyedUsers.GetValueOrDefault(providerUser.UserId.GetValueOrDefault());
var email = user == null ? providerUser.Email : user.Email;
if (!string.IsNullOrWhiteSpace(email))
{
await _mailService.SendProviderUserRemoved(provider.Name, email);
}
result.Add(Tuple.Create(providerUser, ""));
deletedUserIds.Add(providerUser.Id);
}
catch (BadRequestException e)
{
result.Add(Tuple.Create(providerUser, e.Message));
}
await _providerUserRepository.DeleteManyAsync(deletedUserIds);
}
await _eventService.LogProviderUsersEventAsync(events);
return result;
}
public async Task AddOrganization(Guid providerId, Guid organizationId, Guid addingUserId, string key)
{
var po = await _providerOrganizationRepository.GetByOrganizationId(organizationId);
if (po != null)
{
throw new BadRequestException("Organization already belongs to a provider.");
}
var organization = await _organizationRepository.GetByIdAsync(organizationId);
ThrowOnInvalidPlanType(organization.PlanType);
var providerOrganization = new ProviderOrganization
{
ProviderId = providerId,
OrganizationId = organizationId,
Key = key,
}; };
await _providerUserRepository.CreateAsync(providerUser); await _providerOrganizationRepository.CreateAsync(providerOrganization);
await _eventService.LogProviderOrganizationEventAsync(providerOrganization, EventType.ProviderOrganization_Added);
await SendInviteAsync(providerUser, provider);
providerUsers.Add(providerUser);
} }
await _eventService.LogProviderUsersEventAsync(providerUsers.Select(pu => (pu, EventType.ProviderUser_Invited, null as DateTime?))); public async Task<ProviderOrganization> CreateOrganizationAsync(Guid providerId,
OrganizationSignup organizationSignup, string clientOwnerEmail, User user)
return providerUsers;
}
public async Task<List<Tuple<ProviderUser, string>>> ResendInvitesAsync(ProviderUserInvite<Guid> invite)
{
if (!_currentContext.ProviderManageUsers(invite.ProviderId))
{ {
throw new BadRequestException("Invalid permissions."); ThrowOnInvalidPlanType(organizationSignup.Plan);
}
var providerUsers = await _providerUserRepository.GetManyAsync(invite.UserIdentifiers); var (organization, _) = await _organizationService.SignUpAsync(organizationSignup, true);
var provider = await _providerRepository.GetByIdAsync(invite.ProviderId);
var result = new List<Tuple<ProviderUser, string>>(); var providerOrganization = new ProviderOrganization
foreach (var providerUser in providerUsers)
{
if (providerUser.Status != ProviderUserStatusType.Invited || providerUser.ProviderId != invite.ProviderId)
{ {
result.Add(Tuple.Create(providerUser, "User invalid.")); ProviderId = providerId,
continue; OrganizationId = organization.Id,
} Key = organizationSignup.OwnerKey,
};
await SendInviteAsync(providerUser, provider); await _providerOrganizationRepository.CreateAsync(providerOrganization);
result.Add(Tuple.Create(providerUser, "")); await _eventService.LogProviderOrganizationEventAsync(providerOrganization, EventType.ProviderOrganization_Created);
}
return result; await _organizationService.InviteUsersAsync(organization.Id, user.Id,
} new (OrganizationUserInvite, string)[]
public async Task<ProviderUser> AcceptUserAsync(Guid providerUserId, User user, string token)
{
var providerUser = await _providerUserRepository.GetByIdAsync(providerUserId);
if (providerUser == null)
{
throw new BadRequestException("User invalid.");
}
if (providerUser.Status != ProviderUserStatusType.Invited)
{
throw new BadRequestException("Already accepted.");
}
if (!CoreHelpers.TokenIsValid("ProviderUserInvite", _dataProtector, token, user.Email, providerUser.Id,
_globalSettings.OrganizationInviteExpirationHours))
{
throw new BadRequestException("Invalid token.");
}
if (string.IsNullOrWhiteSpace(providerUser.Email) ||
!providerUser.Email.Equals(user.Email, StringComparison.InvariantCultureIgnoreCase))
{
throw new BadRequestException("User email does not match invite.");
}
providerUser.Status = ProviderUserStatusType.Accepted;
providerUser.UserId = user.Id;
providerUser.Email = null;
await _providerUserRepository.ReplaceAsync(providerUser);
return providerUser;
}
public async Task<List<Tuple<ProviderUser, string>>> ConfirmUsersAsync(Guid providerId, Dictionary<Guid, string> keys,
Guid confirmingUserId)
{
var providerUsers = await _providerUserRepository.GetManyAsync(keys.Keys);
var validProviderUsers = providerUsers
.Where(u => u.UserId != null)
.ToList();
if (!validProviderUsers.Any())
{
return new List<Tuple<ProviderUser, string>>();
}
var validOrganizationUserIds = validProviderUsers.Select(u => u.UserId.Value).ToList();
var provider = await _providerRepository.GetByIdAsync(providerId);
var users = await _userRepository.GetManyAsync(validOrganizationUserIds);
var keyedFilteredUsers = validProviderUsers.ToDictionary(u => u.UserId.Value, u => u);
var result = new List<Tuple<ProviderUser, string>>();
var events = new List<(ProviderUser, EventType, DateTime?)>();
foreach (var user in users)
{
if (!keyedFilteredUsers.ContainsKey(user.Id))
{
continue;
}
var providerUser = keyedFilteredUsers[user.Id];
try
{
if (providerUser.Status != ProviderUserStatusType.Accepted || providerUser.ProviderId != providerId)
{ {
throw new BadRequestException("Invalid user."); (
} new OrganizationUserInvite
{
Emails = new[] { clientOwnerEmail },
AccessAll = true,
Type = OrganizationUserType.Owner,
Permissions = null,
Collections = Array.Empty<SelectionReadOnly>(),
},
null
)
});
providerUser.Status = ProviderUserStatusType.Confirmed; return providerOrganization;
providerUser.Key = keys[providerUser.Id]; }
providerUser.Email = null;
await _providerUserRepository.ReplaceAsync(providerUser); public async Task RemoveOrganizationAsync(Guid providerId, Guid providerOrganizationId, Guid removingUserId)
events.Add((providerUser, EventType.ProviderUser_Confirmed, null)); {
await _mailService.SendProviderConfirmedEmailAsync(provider.Name, user.Email); var providerOrganization = await _providerOrganizationRepository.GetByIdAsync(providerOrganizationId);
result.Add(Tuple.Create(providerUser, "")); if (providerOrganization == null || providerOrganization.ProviderId != providerId)
}
catch (BadRequestException e)
{ {
result.Add(Tuple.Create(providerUser, e.Message)); throw new BadRequestException("Invalid organization.");
}
if (!await _organizationService.HasConfirmedOwnersExceptAsync(providerOrganization.OrganizationId, new Guid[] { }, includeProvider: false))
{
throw new BadRequestException("Organization needs to have at least one confirmed owner.");
}
await _providerOrganizationRepository.DeleteAsync(providerOrganization);
await _eventService.LogProviderOrganizationEventAsync(providerOrganization, EventType.ProviderOrganization_Removed);
}
public async Task ResendProviderSetupInviteEmailAsync(Guid providerId, Guid ownerId)
{
var provider = await _providerRepository.GetByIdAsync(providerId);
var owner = await _userRepository.GetByIdAsync(ownerId);
if (owner == null)
{
throw new BadRequestException("Invalid owner.");
}
await SendProviderSetupInviteEmailAsync(provider, owner.Email);
}
private async Task SendProviderSetupInviteEmailAsync(Provider provider, string ownerEmail)
{
var token = _dataProtector.Protect($"ProviderSetupInvite {provider.Id} {ownerEmail} {CoreHelpers.ToEpocMilliseconds(DateTime.UtcNow)}");
await _mailService.SendProviderSetupInviteEmailAsync(provider, token, ownerEmail);
}
public async Task LogProviderAccessToOrganizationAsync(Guid organizationId)
{
if (organizationId == default)
{
return;
}
var providerOrganization = await _providerOrganizationRepository.GetByOrganizationId(organizationId);
var organization = await _organizationRepository.GetByIdAsync(organizationId);
if (providerOrganization != null)
{
await _eventService.LogProviderOrganizationEventAsync(providerOrganization, EventType.ProviderOrganization_VaultAccessed);
}
if (organization != null)
{
await _eventService.LogOrganizationEventAsync(organization, EventType.Organization_VaultAccessed);
} }
} }
await _eventService.LogProviderUsersEventAsync(events); private async Task SendInviteAsync(ProviderUser providerUser, Provider provider)
return result;
}
public async Task SaveUserAsync(ProviderUser user, Guid savingUserId)
{
if (user.Id.Equals(default))
{ {
throw new BadRequestException("Invite the user first."); var nowMillis = CoreHelpers.ToEpocMilliseconds(DateTime.UtcNow);
var token = _dataProtector.Protect(
$"ProviderUserInvite {providerUser.Id} {providerUser.Email} {nowMillis}");
await _mailService.SendProviderInviteEmailAsync(provider.Name, providerUser, token, providerUser.Email);
} }
if (user.Type != ProviderUserType.ProviderAdmin && private async Task<bool> HasConfirmedProviderAdminExceptAsync(Guid providerId, IEnumerable<Guid> providerUserIds)
!await HasConfirmedProviderAdminExceptAsync(user.ProviderId, new[] { user.Id }))
{ {
throw new BadRequestException("Provider must have at least one confirmed ProviderAdmin."); var providerAdmins = await _providerUserRepository.GetManyByProviderAsync(providerId,
ProviderUserType.ProviderAdmin);
var confirmedOwners = providerAdmins.Where(o => o.Status == ProviderUserStatusType.Confirmed);
var confirmedOwnersIds = confirmedOwners.Select(u => u.Id);
return confirmedOwnersIds.Except(providerUserIds).Any();
} }
await _providerUserRepository.ReplaceAsync(user); private void ThrowOnInvalidPlanType(PlanType requestedType)
await _eventService.LogProviderUserEventAsync(user, EventType.ProviderUser_Updated);
}
public async Task<List<Tuple<ProviderUser, string>>> DeleteUsersAsync(Guid providerId,
IEnumerable<Guid> providerUserIds, Guid deletingUserId)
{
var provider = await _providerRepository.GetByIdAsync(providerId);
if (provider == null)
{ {
throw new NotFoundException(); if (ProviderDisllowedOrganizationTypes.Contains(requestedType))
}
var providerUsers = await _providerUserRepository.GetManyAsync(providerUserIds);
var users = await _userRepository.GetManyAsync(providerUsers.Where(pu => pu.UserId.HasValue)
.Select(pu => pu.UserId.Value));
var keyedUsers = users.ToDictionary(u => u.Id);
if (!await HasConfirmedProviderAdminExceptAsync(providerId, providerUserIds))
{
throw new BadRequestException("Provider must have at least one confirmed ProviderAdmin.");
}
var result = new List<Tuple<ProviderUser, string>>();
var deletedUserIds = new List<Guid>();
var events = new List<(ProviderUser, EventType, DateTime?)>();
foreach (var providerUser in providerUsers)
{
try
{ {
if (providerUser.ProviderId != providerId) throw new BadRequestException($"Providers cannot manage organizations with the requested plan type ({requestedType}). Only Teams and Enterprise accounts are allowed.");
{
throw new BadRequestException("Invalid user.");
}
if (providerUser.UserId == deletingUserId)
{
throw new BadRequestException("You cannot remove yourself.");
}
events.Add((providerUser, EventType.ProviderUser_Removed, null));
var user = keyedUsers.GetValueOrDefault(providerUser.UserId.GetValueOrDefault());
var email = user == null ? providerUser.Email : user.Email;
if (!string.IsNullOrWhiteSpace(email))
{
await _mailService.SendProviderUserRemoved(provider.Name, email);
}
result.Add(Tuple.Create(providerUser, ""));
deletedUserIds.Add(providerUser.Id);
} }
catch (BadRequestException e)
{
result.Add(Tuple.Create(providerUser, e.Message));
}
await _providerUserRepository.DeleteManyAsync(deletedUserIds);
}
await _eventService.LogProviderUsersEventAsync(events);
return result;
}
public async Task AddOrganization(Guid providerId, Guid organizationId, Guid addingUserId, string key)
{
var po = await _providerOrganizationRepository.GetByOrganizationId(organizationId);
if (po != null)
{
throw new BadRequestException("Organization already belongs to a provider.");
}
var organization = await _organizationRepository.GetByIdAsync(organizationId);
ThrowOnInvalidPlanType(organization.PlanType);
var providerOrganization = new ProviderOrganization
{
ProviderId = providerId,
OrganizationId = organizationId,
Key = key,
};
await _providerOrganizationRepository.CreateAsync(providerOrganization);
await _eventService.LogProviderOrganizationEventAsync(providerOrganization, EventType.ProviderOrganization_Added);
}
public async Task<ProviderOrganization> CreateOrganizationAsync(Guid providerId,
OrganizationSignup organizationSignup, string clientOwnerEmail, User user)
{
ThrowOnInvalidPlanType(organizationSignup.Plan);
var (organization, _) = await _organizationService.SignUpAsync(organizationSignup, true);
var providerOrganization = new ProviderOrganization
{
ProviderId = providerId,
OrganizationId = organization.Id,
Key = organizationSignup.OwnerKey,
};
await _providerOrganizationRepository.CreateAsync(providerOrganization);
await _eventService.LogProviderOrganizationEventAsync(providerOrganization, EventType.ProviderOrganization_Created);
await _organizationService.InviteUsersAsync(organization.Id, user.Id,
new (OrganizationUserInvite, string)[]
{
(
new OrganizationUserInvite
{
Emails = new[] { clientOwnerEmail },
AccessAll = true,
Type = OrganizationUserType.Owner,
Permissions = null,
Collections = Array.Empty<SelectionReadOnly>(),
},
null
)
});
return providerOrganization;
}
public async Task RemoveOrganizationAsync(Guid providerId, Guid providerOrganizationId, Guid removingUserId)
{
var providerOrganization = await _providerOrganizationRepository.GetByIdAsync(providerOrganizationId);
if (providerOrganization == null || providerOrganization.ProviderId != providerId)
{
throw new BadRequestException("Invalid organization.");
}
if (!await _organizationService.HasConfirmedOwnersExceptAsync(providerOrganization.OrganizationId, new Guid[] { }, includeProvider: false))
{
throw new BadRequestException("Organization needs to have at least one confirmed owner.");
}
await _providerOrganizationRepository.DeleteAsync(providerOrganization);
await _eventService.LogProviderOrganizationEventAsync(providerOrganization, EventType.ProviderOrganization_Removed);
}
public async Task ResendProviderSetupInviteEmailAsync(Guid providerId, Guid ownerId)
{
var provider = await _providerRepository.GetByIdAsync(providerId);
var owner = await _userRepository.GetByIdAsync(ownerId);
if (owner == null)
{
throw new BadRequestException("Invalid owner.");
}
await SendProviderSetupInviteEmailAsync(provider, owner.Email);
}
private async Task SendProviderSetupInviteEmailAsync(Provider provider, string ownerEmail)
{
var token = _dataProtector.Protect($"ProviderSetupInvite {provider.Id} {ownerEmail} {CoreHelpers.ToEpocMilliseconds(DateTime.UtcNow)}");
await _mailService.SendProviderSetupInviteEmailAsync(provider, token, ownerEmail);
}
public async Task LogProviderAccessToOrganizationAsync(Guid organizationId)
{
if (organizationId == default)
{
return;
}
var providerOrganization = await _providerOrganizationRepository.GetByOrganizationId(organizationId);
var organization = await _organizationRepository.GetByIdAsync(organizationId);
if (providerOrganization != null)
{
await _eventService.LogProviderOrganizationEventAsync(providerOrganization, EventType.ProviderOrganization_VaultAccessed);
}
if (organization != null)
{
await _eventService.LogOrganizationEventAsync(organization, EventType.Organization_VaultAccessed);
}
}
private async Task SendInviteAsync(ProviderUser providerUser, Provider provider)
{
var nowMillis = CoreHelpers.ToEpocMilliseconds(DateTime.UtcNow);
var token = _dataProtector.Protect(
$"ProviderUserInvite {providerUser.Id} {providerUser.Email} {nowMillis}");
await _mailService.SendProviderInviteEmailAsync(provider.Name, providerUser, token, providerUser.Email);
}
private async Task<bool> HasConfirmedProviderAdminExceptAsync(Guid providerId, IEnumerable<Guid> providerUserIds)
{
var providerAdmins = await _providerUserRepository.GetManyByProviderAsync(providerId,
ProviderUserType.ProviderAdmin);
var confirmedOwners = providerAdmins.Where(o => o.Status == ProviderUserStatusType.Confirmed);
var confirmedOwnersIds = confirmedOwners.Select(u => u.Id);
return confirmedOwnersIds.Except(providerUserIds).Any();
}
private void ThrowOnInvalidPlanType(PlanType requestedType)
{
if (ProviderDisllowedOrganizationTypes.Contains(requestedType))
{
throw new BadRequestException($"Providers cannot manage organizations with the requested plan type ({requestedType}). Only Teams and Enterprise accounts are allowed.");
} }
} }
} }

View File

@ -2,12 +2,13 @@
using Bit.Core.Services; using Bit.Core.Services;
using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.DependencyInjection;
namespace Bit.Commercial.Core.Utilities; namespace Bit.Commercial.Core.Utilities
public static class ServiceCollectionExtensions
{ {
public static void AddCommCoreServices(this IServiceCollection services) public static class ServiceCollectionExtensions
{ {
services.AddScoped<IProviderService, ProviderService>(); public static void AddCommCoreServices(this IServiceCollection services)
{
services.AddScoped<IProviderService, ProviderService>();
}
} }
} }

View File

@ -4,17 +4,18 @@ using Bit.Core.Models.OrganizationConnectionConfigs;
using Bit.Core.Repositories; using Bit.Core.Repositories;
using Bit.Core.Settings; using Bit.Core.Settings;
namespace Bit.Scim.Context; namespace Bit.Scim.Context
public interface IScimContext
{ {
ScimProviderType RequestScimProvider { get; set; } public interface IScimContext
ScimConfig ScimConfiguration { get; set; } {
Guid? OrganizationId { get; set; } ScimProviderType RequestScimProvider { get; set; }
Organization Organization { get; set; } ScimConfig ScimConfiguration { get; set; }
Task BuildAsync( Guid? OrganizationId { get; set; }
HttpContext httpContext, Organization Organization { get; set; }
GlobalSettings globalSettings, Task BuildAsync(
IOrganizationRepository organizationRepository, HttpContext httpContext,
IOrganizationConnectionRepository organizationConnectionRepository); GlobalSettings globalSettings,
IOrganizationRepository organizationRepository,
IOrganizationConnectionRepository organizationConnectionRepository);
}
} }

View File

@ -4,60 +4,61 @@ using Bit.Core.Models.OrganizationConnectionConfigs;
using Bit.Core.Repositories; using Bit.Core.Repositories;
using Bit.Core.Settings; using Bit.Core.Settings;
namespace Bit.Scim.Context; namespace Bit.Scim.Context
public class ScimContext : IScimContext
{ {
private bool _builtHttpContext; public class ScimContext : IScimContext
public ScimProviderType RequestScimProvider { get; set; } = ScimProviderType.Default;
public ScimConfig ScimConfiguration { get; set; }
public Guid? OrganizationId { get; set; }
public Organization Organization { get; set; }
public async virtual Task BuildAsync(
HttpContext httpContext,
GlobalSettings globalSettings,
IOrganizationRepository organizationRepository,
IOrganizationConnectionRepository organizationConnectionRepository)
{ {
if (_builtHttpContext) private bool _builtHttpContext;
{
return;
}
_builtHttpContext = true; public ScimProviderType RequestScimProvider { get; set; } = ScimProviderType.Default;
public ScimConfig ScimConfiguration { get; set; }
public Guid? OrganizationId { get; set; }
public Organization Organization { get; set; }
string orgIdString = null; public async virtual Task BuildAsync(
if (httpContext.Request.RouteValues.TryGetValue("organizationId", out var orgIdObject)) HttpContext httpContext,
GlobalSettings globalSettings,
IOrganizationRepository organizationRepository,
IOrganizationConnectionRepository organizationConnectionRepository)
{ {
orgIdString = orgIdObject?.ToString(); if (_builtHttpContext)
}
if (Guid.TryParse(orgIdString, out var orgId))
{
OrganizationId = orgId;
Organization = await organizationRepository.GetByIdAsync(orgId);
if (Organization != null)
{ {
var scimConnections = await organizationConnectionRepository.GetByOrganizationIdTypeAsync(Organization.Id, return;
OrganizationConnectionType.Scim);
ScimConfiguration = scimConnections?.FirstOrDefault()?.GetConfig<ScimConfig>();
} }
}
if (RequestScimProvider == ScimProviderType.Default && _builtHttpContext = true;
httpContext.Request.Headers.TryGetValue("User-Agent", out var userAgent))
{ string orgIdString = null;
if (userAgent.ToString().StartsWith("Okta")) if (httpContext.Request.RouteValues.TryGetValue("organizationId", out var orgIdObject))
{ {
RequestScimProvider = ScimProviderType.Okta; orgIdString = orgIdObject?.ToString();
}
if (Guid.TryParse(orgIdString, out var orgId))
{
OrganizationId = orgId;
Organization = await organizationRepository.GetByIdAsync(orgId);
if (Organization != null)
{
var scimConnections = await organizationConnectionRepository.GetByOrganizationIdTypeAsync(Organization.Id,
OrganizationConnectionType.Scim);
ScimConfiguration = scimConnections?.FirstOrDefault()?.GetConfig<ScimConfig>();
}
}
if (RequestScimProvider == ScimProviderType.Default &&
httpContext.Request.Headers.TryGetValue("User-Agent", out var userAgent))
{
if (userAgent.ToString().StartsWith("Okta"))
{
RequestScimProvider = ScimProviderType.Okta;
}
}
if (RequestScimProvider == ScimProviderType.Default &&
httpContext.Request.Headers.ContainsKey("Adscimversion"))
{
RequestScimProvider = ScimProviderType.AzureAd;
} }
}
if (RequestScimProvider == ScimProviderType.Default &&
httpContext.Request.Headers.ContainsKey("Adscimversion"))
{
RequestScimProvider = ScimProviderType.AzureAd;
} }
} }
} }

View File

@ -2,21 +2,22 @@
using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Authorization;
using Microsoft.AspNetCore.Mvc; using Microsoft.AspNetCore.Mvc;
namespace Bit.Scim.Controllers; namespace Bit.Scim.Controllers
[AllowAnonymous]
public class InfoController : Controller
{ {
[HttpGet("~/alive")] [AllowAnonymous]
[HttpGet("~/now")] public class InfoController : Controller
public DateTime GetAlive()
{ {
return DateTime.UtcNow; [HttpGet("~/alive")]
} [HttpGet("~/now")]
public DateTime GetAlive()
{
return DateTime.UtcNow;
}
[HttpGet("~/version")] [HttpGet("~/version")]
public JsonResult GetVersion() public JsonResult GetVersion()
{ {
return Json(CoreHelpers.GetVersion()); return Json(CoreHelpers.GetVersion());
}
} }
} }

View File

@ -8,320 +8,321 @@ using Microsoft.AspNetCore.Authorization;
using Microsoft.AspNetCore.Mvc; using Microsoft.AspNetCore.Mvc;
using Microsoft.Extensions.Options; using Microsoft.Extensions.Options;
namespace Bit.Scim.Controllers.v2; namespace Bit.Scim.Controllers.v2
[Authorize("Scim")]
[Route("v2/{organizationId}/groups")]
public class GroupsController : Controller
{ {
private readonly ScimSettings _scimSettings; [Authorize("Scim")]
private readonly IGroupRepository _groupRepository; [Route("v2/{organizationId}/groups")]
private readonly IGroupService _groupService; public class GroupsController : Controller
private readonly IScimContext _scimContext;
private readonly ILogger<GroupsController> _logger;
public GroupsController(
IGroupRepository groupRepository,
IGroupService groupService,
IOptions<ScimSettings> scimSettings,
IScimContext scimContext,
ILogger<GroupsController> logger)
{ {
_scimSettings = scimSettings?.Value; private readonly ScimSettings _scimSettings;
_groupRepository = groupRepository; private readonly IGroupRepository _groupRepository;
_groupService = groupService; private readonly IGroupService _groupService;
_scimContext = scimContext; private readonly IScimContext _scimContext;
_logger = logger; private readonly ILogger<GroupsController> _logger;
}
[HttpGet("{id}")] public GroupsController(
public async Task<IActionResult> Get(Guid organizationId, Guid id) IGroupRepository groupRepository,
{ IGroupService groupService,
var group = await _groupRepository.GetByIdAsync(id); IOptions<ScimSettings> scimSettings,
if (group == null || group.OrganizationId != organizationId) IScimContext scimContext,
ILogger<GroupsController> logger)
{ {
return new NotFoundObjectResult(new ScimErrorResponseModel _scimSettings = scimSettings?.Value;
_groupRepository = groupRepository;
_groupService = groupService;
_scimContext = scimContext;
_logger = logger;
}
[HttpGet("{id}")]
public async Task<IActionResult> Get(Guid organizationId, Guid id)
{
var group = await _groupRepository.GetByIdAsync(id);
if (group == null || group.OrganizationId != organizationId)
{ {
Status = 404, return new NotFoundObjectResult(new ScimErrorResponseModel
Detail = "Group not found."
});
}
return new ObjectResult(new ScimGroupResponseModel(group));
}
[HttpGet("")]
public async Task<IActionResult> Get(
Guid organizationId,
[FromQuery] string filter,
[FromQuery] int? count,
[FromQuery] int? startIndex)
{
string nameFilter = null;
string externalIdFilter = null;
if (!string.IsNullOrWhiteSpace(filter))
{
if (filter.StartsWith("displayName eq "))
{
nameFilter = filter.Substring(15).Trim('"');
}
else if (filter.StartsWith("externalId eq "))
{
externalIdFilter = filter.Substring(14).Trim('"');
}
}
var groupList = new List<ScimGroupResponseModel>();
var groups = await _groupRepository.GetManyByOrganizationIdAsync(organizationId);
var totalResults = 0;
if (!string.IsNullOrWhiteSpace(nameFilter))
{
var group = groups.FirstOrDefault(g => g.Name == nameFilter);
if (group != null)
{
groupList.Add(new ScimGroupResponseModel(group));
}
totalResults = groupList.Count;
}
else if (!string.IsNullOrWhiteSpace(externalIdFilter))
{
var group = groups.FirstOrDefault(ou => ou.ExternalId == externalIdFilter);
if (group != null)
{
groupList.Add(new ScimGroupResponseModel(group));
}
totalResults = groupList.Count;
}
else if (string.IsNullOrWhiteSpace(filter) && startIndex.HasValue && count.HasValue)
{
groupList = groups.OrderBy(g => g.Name)
.Skip(startIndex.Value - 1)
.Take(count.Value)
.Select(g => new ScimGroupResponseModel(g))
.ToList();
totalResults = groups.Count;
}
var result = new ScimListResponseModel<ScimGroupResponseModel>
{
Resources = groupList,
ItemsPerPage = count.GetValueOrDefault(groupList.Count),
TotalResults = totalResults,
StartIndex = startIndex.GetValueOrDefault(1),
};
return new ObjectResult(result);
}
[HttpPost("")]
public async Task<IActionResult> Post(Guid organizationId, [FromBody] ScimGroupRequestModel model)
{
if (string.IsNullOrWhiteSpace(model.DisplayName))
{
return new BadRequestResult();
}
var groups = await _groupRepository.GetManyByOrganizationIdAsync(organizationId);
if (!string.IsNullOrWhiteSpace(model.ExternalId) && groups.Any(g => g.ExternalId == model.ExternalId))
{
return new ConflictResult();
}
var group = model.ToGroup(organizationId);
await _groupService.SaveAsync(group, null);
await UpdateGroupMembersAsync(group, model, true);
var response = new ScimGroupResponseModel(group);
return new CreatedResult(Url.Action(nameof(Get), new { group.OrganizationId, group.Id }), response);
}
[HttpPut("{id}")]
public async Task<IActionResult> Put(Guid organizationId, Guid id, [FromBody] ScimGroupRequestModel model)
{
var group = await _groupRepository.GetByIdAsync(id);
if (group == null || group.OrganizationId != organizationId)
{
return new NotFoundObjectResult(new ScimErrorResponseModel
{
Status = 404,
Detail = "Group not found."
});
}
group.Name = model.DisplayName;
await _groupService.SaveAsync(group);
await UpdateGroupMembersAsync(group, model, false);
return new ObjectResult(new ScimGroupResponseModel(group));
}
[HttpPatch("{id}")]
public async Task<IActionResult> Patch(Guid organizationId, Guid id, [FromBody] ScimPatchModel model)
{
var group = await _groupRepository.GetByIdAsync(id);
if (group == null || group.OrganizationId != organizationId)
{
return new NotFoundObjectResult(new ScimErrorResponseModel
{
Status = 404,
Detail = "Group not found."
});
}
var operationHandled = false;
foreach (var operation in model.Operations)
{
// Replace operations
if (operation.Op?.ToLowerInvariant() == "replace")
{
// Replace a list of members
if (operation.Path?.ToLowerInvariant() == "members")
{ {
var ids = GetOperationValueIds(operation.Value); Status = 404,
await _groupRepository.UpdateUsersAsync(group.Id, ids); Detail = "Group not found."
operationHandled = true; });
}
return new ObjectResult(new ScimGroupResponseModel(group));
}
[HttpGet("")]
public async Task<IActionResult> Get(
Guid organizationId,
[FromQuery] string filter,
[FromQuery] int? count,
[FromQuery] int? startIndex)
{
string nameFilter = null;
string externalIdFilter = null;
if (!string.IsNullOrWhiteSpace(filter))
{
if (filter.StartsWith("displayName eq "))
{
nameFilter = filter.Substring(15).Trim('"');
} }
// Replace group name from path else if (filter.StartsWith("externalId eq "))
else if (operation.Path?.ToLowerInvariant() == "displayname")
{ {
group.Name = operation.Value.GetString(); externalIdFilter = filter.Substring(14).Trim('"');
await _groupService.SaveAsync(group);
operationHandled = true;
}
// Replace group name from value object
else if (string.IsNullOrWhiteSpace(operation.Path) &&
operation.Value.TryGetProperty("displayName", out var displayNameProperty))
{
group.Name = displayNameProperty.GetString();
await _groupService.SaveAsync(group);
operationHandled = true;
} }
} }
// Add a single member
else if (operation.Op?.ToLowerInvariant() == "add" && var groupList = new List<ScimGroupResponseModel>();
!string.IsNullOrWhiteSpace(operation.Path) && var groups = await _groupRepository.GetManyByOrganizationIdAsync(organizationId);
operation.Path.ToLowerInvariant().StartsWith("members[value eq ")) var totalResults = 0;
if (!string.IsNullOrWhiteSpace(nameFilter))
{ {
var addId = GetOperationPathId(operation.Path); var group = groups.FirstOrDefault(g => g.Name == nameFilter);
if (addId.HasValue) if (group != null)
{
groupList.Add(new ScimGroupResponseModel(group));
}
totalResults = groupList.Count;
}
else if (!string.IsNullOrWhiteSpace(externalIdFilter))
{
var group = groups.FirstOrDefault(ou => ou.ExternalId == externalIdFilter);
if (group != null)
{
groupList.Add(new ScimGroupResponseModel(group));
}
totalResults = groupList.Count;
}
else if (string.IsNullOrWhiteSpace(filter) && startIndex.HasValue && count.HasValue)
{
groupList = groups.OrderBy(g => g.Name)
.Skip(startIndex.Value - 1)
.Take(count.Value)
.Select(g => new ScimGroupResponseModel(g))
.ToList();
totalResults = groups.Count;
}
var result = new ScimListResponseModel<ScimGroupResponseModel>
{
Resources = groupList,
ItemsPerPage = count.GetValueOrDefault(groupList.Count),
TotalResults = totalResults,
StartIndex = startIndex.GetValueOrDefault(1),
};
return new ObjectResult(result);
}
[HttpPost("")]
public async Task<IActionResult> Post(Guid organizationId, [FromBody] ScimGroupRequestModel model)
{
if (string.IsNullOrWhiteSpace(model.DisplayName))
{
return new BadRequestResult();
}
var groups = await _groupRepository.GetManyByOrganizationIdAsync(organizationId);
if (!string.IsNullOrWhiteSpace(model.ExternalId) && groups.Any(g => g.ExternalId == model.ExternalId))
{
return new ConflictResult();
}
var group = model.ToGroup(organizationId);
await _groupService.SaveAsync(group, null);
await UpdateGroupMembersAsync(group, model, true);
var response = new ScimGroupResponseModel(group);
return new CreatedResult(Url.Action(nameof(Get), new { group.OrganizationId, group.Id }), response);
}
[HttpPut("{id}")]
public async Task<IActionResult> Put(Guid organizationId, Guid id, [FromBody] ScimGroupRequestModel model)
{
var group = await _groupRepository.GetByIdAsync(id);
if (group == null || group.OrganizationId != organizationId)
{
return new NotFoundObjectResult(new ScimErrorResponseModel
{
Status = 404,
Detail = "Group not found."
});
}
group.Name = model.DisplayName;
await _groupService.SaveAsync(group);
await UpdateGroupMembersAsync(group, model, false);
return new ObjectResult(new ScimGroupResponseModel(group));
}
[HttpPatch("{id}")]
public async Task<IActionResult> Patch(Guid organizationId, Guid id, [FromBody] ScimPatchModel model)
{
var group = await _groupRepository.GetByIdAsync(id);
if (group == null || group.OrganizationId != organizationId)
{
return new NotFoundObjectResult(new ScimErrorResponseModel
{
Status = 404,
Detail = "Group not found."
});
}
var operationHandled = false;
foreach (var operation in model.Operations)
{
// Replace operations
if (operation.Op?.ToLowerInvariant() == "replace")
{
// Replace a list of members
if (operation.Path?.ToLowerInvariant() == "members")
{
var ids = GetOperationValueIds(operation.Value);
await _groupRepository.UpdateUsersAsync(group.Id, ids);
operationHandled = true;
}
// Replace group name from path
else if (operation.Path?.ToLowerInvariant() == "displayname")
{
group.Name = operation.Value.GetString();
await _groupService.SaveAsync(group);
operationHandled = true;
}
// Replace group name from value object
else if (string.IsNullOrWhiteSpace(operation.Path) &&
operation.Value.TryGetProperty("displayName", out var displayNameProperty))
{
group.Name = displayNameProperty.GetString();
await _groupService.SaveAsync(group);
operationHandled = true;
}
}
// Add a single member
else if (operation.Op?.ToLowerInvariant() == "add" &&
!string.IsNullOrWhiteSpace(operation.Path) &&
operation.Path.ToLowerInvariant().StartsWith("members[value eq "))
{
var addId = GetOperationPathId(operation.Path);
if (addId.HasValue)
{
var orgUserIds = (await _groupRepository.GetManyUserIdsByIdAsync(group.Id)).ToHashSet();
orgUserIds.Add(addId.Value);
await _groupRepository.UpdateUsersAsync(group.Id, orgUserIds);
operationHandled = true;
}
}
// Add a list of members
else if (operation.Op?.ToLowerInvariant() == "add" &&
operation.Path?.ToLowerInvariant() == "members")
{ {
var orgUserIds = (await _groupRepository.GetManyUserIdsByIdAsync(group.Id)).ToHashSet(); var orgUserIds = (await _groupRepository.GetManyUserIdsByIdAsync(group.Id)).ToHashSet();
orgUserIds.Add(addId.Value); foreach (var v in GetOperationValueIds(operation.Value))
{
orgUserIds.Add(v);
}
await _groupRepository.UpdateUsersAsync(group.Id, orgUserIds);
operationHandled = true;
}
// Remove a single member
else if (operation.Op?.ToLowerInvariant() == "remove" &&
!string.IsNullOrWhiteSpace(operation.Path) &&
operation.Path.ToLowerInvariant().StartsWith("members[value eq "))
{
var removeId = GetOperationPathId(operation.Path);
if (removeId.HasValue)
{
await _groupService.DeleteUserAsync(group, removeId.Value);
operationHandled = true;
}
}
// Remove a list of members
else if (operation.Op?.ToLowerInvariant() == "remove" &&
operation.Path?.ToLowerInvariant() == "members")
{
var orgUserIds = (await _groupRepository.GetManyUserIdsByIdAsync(group.Id)).ToHashSet();
foreach (var v in GetOperationValueIds(operation.Value))
{
orgUserIds.Remove(v);
}
await _groupRepository.UpdateUsersAsync(group.Id, orgUserIds); await _groupRepository.UpdateUsersAsync(group.Id, orgUserIds);
operationHandled = true; operationHandled = true;
} }
} }
// Add a list of members
else if (operation.Op?.ToLowerInvariant() == "add" && if (!operationHandled)
operation.Path?.ToLowerInvariant() == "members")
{ {
var orgUserIds = (await _groupRepository.GetManyUserIdsByIdAsync(group.Id)).ToHashSet(); _logger.LogWarning("Group patch operation not handled: {0} : ",
foreach (var v in GetOperationValueIds(operation.Value)) string.Join(", ", model.Operations.Select(o => $"{o.Op}:{o.Path}")));
{
orgUserIds.Add(v);
}
await _groupRepository.UpdateUsersAsync(group.Id, orgUserIds);
operationHandled = true;
} }
// Remove a single member
else if (operation.Op?.ToLowerInvariant() == "remove" && return new NoContentResult();
!string.IsNullOrWhiteSpace(operation.Path) && }
operation.Path.ToLowerInvariant().StartsWith("members[value eq "))
[HttpDelete("{id}")]
public async Task<IActionResult> Delete(Guid organizationId, Guid id)
{
var group = await _groupRepository.GetByIdAsync(id);
if (group == null || group.OrganizationId != organizationId)
{ {
var removeId = GetOperationPathId(operation.Path); return new NotFoundObjectResult(new ScimErrorResponseModel
if (removeId.HasValue)
{ {
await _groupService.DeleteUserAsync(group, removeId.Value); Status = 404,
operationHandled = true; Detail = "Group not found."
});
}
await _groupService.DeleteAsync(group);
return new NoContentResult();
}
private List<Guid> GetOperationValueIds(JsonElement objArray)
{
var ids = new List<Guid>();
foreach (var obj in objArray.EnumerateArray())
{
if (obj.TryGetProperty("value", out var valueProperty))
{
if (valueProperty.TryGetGuid(out var guid))
{
ids.Add(guid);
}
} }
} }
// Remove a list of members return ids;
else if (operation.Op?.ToLowerInvariant() == "remove" && }
operation.Path?.ToLowerInvariant() == "members")
private Guid? GetOperationPathId(string path)
{
// Parse Guid from string like: members[value eq "{GUID}"}]
if (Guid.TryParse(path.Substring(18).Replace("\"]", string.Empty), out var id))
{ {
var orgUserIds = (await _groupRepository.GetManyUserIdsByIdAsync(group.Id)).ToHashSet(); return id;
foreach (var v in GetOperationValueIds(operation.Value))
{
orgUserIds.Remove(v);
}
await _groupRepository.UpdateUsersAsync(group.Id, orgUserIds);
operationHandled = true;
} }
return null;
} }
if (!operationHandled) private async Task UpdateGroupMembersAsync(Group group, ScimGroupRequestModel model, bool skipIfEmpty)
{ {
_logger.LogWarning("Group patch operation not handled: {0} : ", if (_scimContext.RequestScimProvider != Core.Enums.ScimProviderType.Okta)
string.Join(", ", model.Operations.Select(o => $"{o.Op}:{o.Path}")));
}
return new NoContentResult();
}
[HttpDelete("{id}")]
public async Task<IActionResult> Delete(Guid organizationId, Guid id)
{
var group = await _groupRepository.GetByIdAsync(id);
if (group == null || group.OrganizationId != organizationId)
{
return new NotFoundObjectResult(new ScimErrorResponseModel
{ {
Status = 404, return;
Detail = "Group not found." }
});
}
await _groupService.DeleteAsync(group);
return new NoContentResult();
}
private List<Guid> GetOperationValueIds(JsonElement objArray) if (model.Members == null)
{
var ids = new List<Guid>();
foreach (var obj in objArray.EnumerateArray())
{
if (obj.TryGetProperty("value", out var valueProperty))
{ {
if (valueProperty.TryGetGuid(out var guid)) return;
}
var memberIds = new List<Guid>();
foreach (var id in model.Members.Select(i => i.Value))
{
if (Guid.TryParse(id, out var guidId))
{ {
ids.Add(guid); memberIds.Add(guidId);
} }
} }
}
return ids;
}
private Guid? GetOperationPathId(string path) if (!memberIds.Any() && skipIfEmpty)
{
// Parse Guid from string like: members[value eq "{GUID}"}]
if (Guid.TryParse(path.Substring(18).Replace("\"]", string.Empty), out var id))
{
return id;
}
return null;
}
private async Task UpdateGroupMembersAsync(Group group, ScimGroupRequestModel model, bool skipIfEmpty)
{
if (_scimContext.RequestScimProvider != Core.Enums.ScimProviderType.Okta)
{
return;
}
if (model.Members == null)
{
return;
}
var memberIds = new List<Guid>();
foreach (var id in model.Members.Select(i => i.Value))
{
if (Guid.TryParse(id, out var guidId))
{ {
memberIds.Add(guidId); return;
} }
}
if (!memberIds.Any() && skipIfEmpty) await _groupRepository.UpdateUsersAsync(group.Id, memberIds);
{
return;
} }
await _groupRepository.UpdateUsersAsync(group.Id, memberIds);
} }
} }

View File

@ -9,286 +9,287 @@ using Microsoft.AspNetCore.Authorization;
using Microsoft.AspNetCore.Mvc; using Microsoft.AspNetCore.Mvc;
using Microsoft.Extensions.Options; using Microsoft.Extensions.Options;
namespace Bit.Scim.Controllers.v2; namespace Bit.Scim.Controllers.v2
[Authorize("Scim")]
[Route("v2/{organizationId}/users")]
public class UsersController : Controller
{ {
private readonly IUserService _userService; [Authorize("Scim")]
private readonly IUserRepository _userRepository; [Route("v2/{organizationId}/users")]
private readonly IOrganizationUserRepository _organizationUserRepository; public class UsersController : Controller
private readonly IOrganizationService _organizationService;
private readonly IScimContext _scimContext;
private readonly ScimSettings _scimSettings;
private readonly ILogger<UsersController> _logger;
public UsersController(
IUserService userService,
IUserRepository userRepository,
IOrganizationUserRepository organizationUserRepository,
IOrganizationService organizationService,
IScimContext scimContext,
IOptions<ScimSettings> scimSettings,
ILogger<UsersController> logger)
{ {
_userService = userService; private readonly IUserService _userService;
_userRepository = userRepository; private readonly IUserRepository _userRepository;
_organizationUserRepository = organizationUserRepository; private readonly IOrganizationUserRepository _organizationUserRepository;
_organizationService = organizationService; private readonly IOrganizationService _organizationService;
_scimContext = scimContext; private readonly IScimContext _scimContext;
_scimSettings = scimSettings?.Value; private readonly ScimSettings _scimSettings;
_logger = logger; private readonly ILogger<UsersController> _logger;
}
[HttpGet("{id}")] public UsersController(
public async Task<IActionResult> Get(Guid organizationId, Guid id) IUserService userService,
{ IUserRepository userRepository,
var orgUser = await _organizationUserRepository.GetDetailsByIdAsync(id); IOrganizationUserRepository organizationUserRepository,
if (orgUser == null || orgUser.OrganizationId != organizationId) IOrganizationService organizationService,
IScimContext scimContext,
IOptions<ScimSettings> scimSettings,
ILogger<UsersController> logger)
{ {
return new NotFoundObjectResult(new ScimErrorResponseModel _userService = userService;
{ _userRepository = userRepository;
Status = 404, _organizationUserRepository = organizationUserRepository;
Detail = "User not found." _organizationService = organizationService;
}); _scimContext = scimContext;
_scimSettings = scimSettings?.Value;
_logger = logger;
} }
return new ObjectResult(new ScimUserResponseModel(orgUser));
}
[HttpGet("")] [HttpGet("{id}")]
public async Task<IActionResult> Get( public async Task<IActionResult> Get(Guid organizationId, Guid id)
Guid organizationId,
[FromQuery] string filter,
[FromQuery] int? count,
[FromQuery] int? startIndex)
{
string emailFilter = null;
string usernameFilter = null;
string externalIdFilter = null;
if (!string.IsNullOrWhiteSpace(filter))
{ {
if (filter.StartsWith("userName eq ")) var orgUser = await _organizationUserRepository.GetDetailsByIdAsync(id);
if (orgUser == null || orgUser.OrganizationId != organizationId)
{ {
usernameFilter = filter.Substring(12).Trim('"').ToLowerInvariant(); return new NotFoundObjectResult(new ScimErrorResponseModel
if (usernameFilter.Contains("@"))
{ {
emailFilter = usernameFilter; Status = 404,
} Detail = "User not found."
});
} }
else if (filter.StartsWith("externalId eq ")) return new ObjectResult(new ScimUserResponseModel(orgUser));
}
[HttpGet("")]
public async Task<IActionResult> Get(
Guid organizationId,
[FromQuery] string filter,
[FromQuery] int? count,
[FromQuery] int? startIndex)
{
string emailFilter = null;
string usernameFilter = null;
string externalIdFilter = null;
if (!string.IsNullOrWhiteSpace(filter))
{ {
externalIdFilter = filter.Substring(14).Trim('"'); if (filter.StartsWith("userName eq "))
}
}
var userList = new List<ScimUserResponseModel> { };
var orgUsers = await _organizationUserRepository.GetManyDetailsByOrganizationAsync(organizationId);
var totalResults = 0;
if (!string.IsNullOrWhiteSpace(emailFilter))
{
var orgUser = orgUsers.FirstOrDefault(ou => ou.Email.ToLowerInvariant() == emailFilter);
if (orgUser != null)
{
userList.Add(new ScimUserResponseModel(orgUser));
}
totalResults = userList.Count;
}
else if (!string.IsNullOrWhiteSpace(externalIdFilter))
{
var orgUser = orgUsers.FirstOrDefault(ou => ou.ExternalId == externalIdFilter);
if (orgUser != null)
{
userList.Add(new ScimUserResponseModel(orgUser));
}
totalResults = userList.Count;
}
else if (string.IsNullOrWhiteSpace(filter) && startIndex.HasValue && count.HasValue)
{
userList = orgUsers.OrderBy(ou => ou.Email)
.Skip(startIndex.Value - 1)
.Take(count.Value)
.Select(ou => new ScimUserResponseModel(ou))
.ToList();
totalResults = orgUsers.Count;
}
var result = new ScimListResponseModel<ScimUserResponseModel>
{
Resources = userList,
ItemsPerPage = count.GetValueOrDefault(userList.Count),
TotalResults = totalResults,
StartIndex = startIndex.GetValueOrDefault(1),
};
return new ObjectResult(result);
}
[HttpPost("")]
public async Task<IActionResult> Post(Guid organizationId, [FromBody] ScimUserRequestModel model)
{
var email = model.PrimaryEmail?.ToLowerInvariant();
if (string.IsNullOrWhiteSpace(email))
{
switch (_scimContext.RequestScimProvider)
{
case ScimProviderType.AzureAd:
email = model.UserName?.ToLowerInvariant();
break;
default:
email = model.WorkEmail?.ToLowerInvariant();
if (string.IsNullOrWhiteSpace(email))
{
email = model.Emails?.FirstOrDefault()?.Value?.ToLowerInvariant();
}
break;
}
}
if (string.IsNullOrWhiteSpace(email) || !model.Active)
{
return new BadRequestResult();
}
var orgUsers = await _organizationUserRepository.GetManyDetailsByOrganizationAsync(organizationId);
var orgUserByEmail = orgUsers.FirstOrDefault(ou => ou.Email?.ToLowerInvariant() == email);
if (orgUserByEmail != null)
{
return new ConflictResult();
}
string externalId = null;
if (!string.IsNullOrWhiteSpace(model.ExternalId))
{
externalId = model.ExternalId;
}
else if (!string.IsNullOrWhiteSpace(model.UserName))
{
externalId = model.UserName;
}
else
{
externalId = CoreHelpers.RandomString(15);
}
var orgUserByExternalId = orgUsers.FirstOrDefault(ou => ou.ExternalId == externalId);
if (orgUserByExternalId != null)
{
return new ConflictResult();
}
var invitedOrgUser = await _organizationService.InviteUserAsync(organizationId, null, email,
OrganizationUserType.User, false, externalId, new List<SelectionReadOnly>());
var orgUser = await _organizationUserRepository.GetDetailsByIdAsync(invitedOrgUser.Id);
var response = new ScimUserResponseModel(orgUser);
return new CreatedResult(Url.Action(nameof(Get), new { orgUser.OrganizationId, orgUser.Id }), response);
}
[HttpPut("{id}")]
public async Task<IActionResult> Put(Guid organizationId, Guid id, [FromBody] ScimUserRequestModel model)
{
var orgUser = await _organizationUserRepository.GetByIdAsync(id);
if (orgUser == null || orgUser.OrganizationId != organizationId)
{
return new NotFoundObjectResult(new ScimErrorResponseModel
{
Status = 404,
Detail = "User not found."
});
}
if (model.Active && orgUser.Status == OrganizationUserStatusType.Revoked)
{
await _organizationService.RestoreUserAsync(orgUser, null, _userService);
}
else if (!model.Active && orgUser.Status != OrganizationUserStatusType.Revoked)
{
await _organizationService.RevokeUserAsync(orgUser, null);
}
// Have to get full details object for response model
var orgUserDetails = await _organizationUserRepository.GetDetailsByIdAsync(id);
return new ObjectResult(new ScimUserResponseModel(orgUserDetails));
}
[HttpPatch("{id}")]
public async Task<IActionResult> Patch(Guid organizationId, Guid id, [FromBody] ScimPatchModel model)
{
var orgUser = await _organizationUserRepository.GetByIdAsync(id);
if (orgUser == null || orgUser.OrganizationId != organizationId)
{
return new NotFoundObjectResult(new ScimErrorResponseModel
{
Status = 404,
Detail = "User not found."
});
}
var operationHandled = false;
foreach (var operation in model.Operations)
{
// Replace operations
if (operation.Op?.ToLowerInvariant() == "replace")
{
// Active from path
if (operation.Path?.ToLowerInvariant() == "active")
{ {
var active = operation.Value.ToString()?.ToLowerInvariant(); usernameFilter = filter.Substring(12).Trim('"').ToLowerInvariant();
var handled = await HandleActiveOperationAsync(orgUser, active == "true"); if (usernameFilter.Contains("@"))
if (!operationHandled)
{ {
operationHandled = handled; emailFilter = usernameFilter;
} }
} }
// Active from value object else if (filter.StartsWith("externalId eq "))
else if (string.IsNullOrWhiteSpace(operation.Path) &&
operation.Value.TryGetProperty("active", out var activeProperty))
{ {
var handled = await HandleActiveOperationAsync(orgUser, activeProperty.GetBoolean()); externalIdFilter = filter.Substring(14).Trim('"');
if (!operationHandled) }
}
var userList = new List<ScimUserResponseModel> { };
var orgUsers = await _organizationUserRepository.GetManyDetailsByOrganizationAsync(organizationId);
var totalResults = 0;
if (!string.IsNullOrWhiteSpace(emailFilter))
{
var orgUser = orgUsers.FirstOrDefault(ou => ou.Email.ToLowerInvariant() == emailFilter);
if (orgUser != null)
{
userList.Add(new ScimUserResponseModel(orgUser));
}
totalResults = userList.Count;
}
else if (!string.IsNullOrWhiteSpace(externalIdFilter))
{
var orgUser = orgUsers.FirstOrDefault(ou => ou.ExternalId == externalIdFilter);
if (orgUser != null)
{
userList.Add(new ScimUserResponseModel(orgUser));
}
totalResults = userList.Count;
}
else if (string.IsNullOrWhiteSpace(filter) && startIndex.HasValue && count.HasValue)
{
userList = orgUsers.OrderBy(ou => ou.Email)
.Skip(startIndex.Value - 1)
.Take(count.Value)
.Select(ou => new ScimUserResponseModel(ou))
.ToList();
totalResults = orgUsers.Count;
}
var result = new ScimListResponseModel<ScimUserResponseModel>
{
Resources = userList,
ItemsPerPage = count.GetValueOrDefault(userList.Count),
TotalResults = totalResults,
StartIndex = startIndex.GetValueOrDefault(1),
};
return new ObjectResult(result);
}
[HttpPost("")]
public async Task<IActionResult> Post(Guid organizationId, [FromBody] ScimUserRequestModel model)
{
var email = model.PrimaryEmail?.ToLowerInvariant();
if (string.IsNullOrWhiteSpace(email))
{
switch (_scimContext.RequestScimProvider)
{
case ScimProviderType.AzureAd:
email = model.UserName?.ToLowerInvariant();
break;
default:
email = model.WorkEmail?.ToLowerInvariant();
if (string.IsNullOrWhiteSpace(email))
{
email = model.Emails?.FirstOrDefault()?.Value?.ToLowerInvariant();
}
break;
}
}
if (string.IsNullOrWhiteSpace(email) || !model.Active)
{
return new BadRequestResult();
}
var orgUsers = await _organizationUserRepository.GetManyDetailsByOrganizationAsync(organizationId);
var orgUserByEmail = orgUsers.FirstOrDefault(ou => ou.Email?.ToLowerInvariant() == email);
if (orgUserByEmail != null)
{
return new ConflictResult();
}
string externalId = null;
if (!string.IsNullOrWhiteSpace(model.ExternalId))
{
externalId = model.ExternalId;
}
else if (!string.IsNullOrWhiteSpace(model.UserName))
{
externalId = model.UserName;
}
else
{
externalId = CoreHelpers.RandomString(15);
}
var orgUserByExternalId = orgUsers.FirstOrDefault(ou => ou.ExternalId == externalId);
if (orgUserByExternalId != null)
{
return new ConflictResult();
}
var invitedOrgUser = await _organizationService.InviteUserAsync(organizationId, null, email,
OrganizationUserType.User, false, externalId, new List<SelectionReadOnly>());
var orgUser = await _organizationUserRepository.GetDetailsByIdAsync(invitedOrgUser.Id);
var response = new ScimUserResponseModel(orgUser);
return new CreatedResult(Url.Action(nameof(Get), new { orgUser.OrganizationId, orgUser.Id }), response);
}
[HttpPut("{id}")]
public async Task<IActionResult> Put(Guid organizationId, Guid id, [FromBody] ScimUserRequestModel model)
{
var orgUser = await _organizationUserRepository.GetByIdAsync(id);
if (orgUser == null || orgUser.OrganizationId != organizationId)
{
return new NotFoundObjectResult(new ScimErrorResponseModel
{
Status = 404,
Detail = "User not found."
});
}
if (model.Active && orgUser.Status == OrganizationUserStatusType.Revoked)
{
await _organizationService.RestoreUserAsync(orgUser, null, _userService);
}
else if (!model.Active && orgUser.Status != OrganizationUserStatusType.Revoked)
{
await _organizationService.RevokeUserAsync(orgUser, null);
}
// Have to get full details object for response model
var orgUserDetails = await _organizationUserRepository.GetDetailsByIdAsync(id);
return new ObjectResult(new ScimUserResponseModel(orgUserDetails));
}
[HttpPatch("{id}")]
public async Task<IActionResult> Patch(Guid organizationId, Guid id, [FromBody] ScimPatchModel model)
{
var orgUser = await _organizationUserRepository.GetByIdAsync(id);
if (orgUser == null || orgUser.OrganizationId != organizationId)
{
return new NotFoundObjectResult(new ScimErrorResponseModel
{
Status = 404,
Detail = "User not found."
});
}
var operationHandled = false;
foreach (var operation in model.Operations)
{
// Replace operations
if (operation.Op?.ToLowerInvariant() == "replace")
{
// Active from path
if (operation.Path?.ToLowerInvariant() == "active")
{ {
operationHandled = handled; var active = operation.Value.ToString()?.ToLowerInvariant();
var handled = await HandleActiveOperationAsync(orgUser, active == "true");
if (!operationHandled)
{
operationHandled = handled;
}
}
// Active from value object
else if (string.IsNullOrWhiteSpace(operation.Path) &&
operation.Value.TryGetProperty("active", out var activeProperty))
{
var handled = await HandleActiveOperationAsync(orgUser, activeProperty.GetBoolean());
if (!operationHandled)
{
operationHandled = handled;
}
} }
} }
} }
}
if (!operationHandled) if (!operationHandled)
{
_logger.LogWarning("User patch operation not handled: {operation} : ",
string.Join(", ", model.Operations.Select(o => $"{o.Op}:{o.Path}")));
}
return new NoContentResult();
}
[HttpDelete("{id}")]
public async Task<IActionResult> Delete(Guid organizationId, Guid id, [FromBody] ScimUserRequestModel model)
{
var orgUser = await _organizationUserRepository.GetByIdAsync(id);
if (orgUser == null || orgUser.OrganizationId != organizationId)
{
return new NotFoundObjectResult(new ScimErrorResponseModel
{ {
Status = 404, _logger.LogWarning("User patch operation not handled: {operation} : ",
Detail = "User not found." string.Join(", ", model.Operations.Select(o => $"{o.Op}:{o.Path}")));
}); }
}
await _organizationService.DeleteUserAsync(organizationId, id, null);
return new NoContentResult();
}
private async Task<bool> HandleActiveOperationAsync(Core.Entities.OrganizationUser orgUser, bool active) return new NoContentResult();
{
if (active && orgUser.Status == OrganizationUserStatusType.Revoked)
{
await _organizationService.RestoreUserAsync(orgUser, null, _userService);
return true;
} }
else if (!active && orgUser.Status != OrganizationUserStatusType.Revoked)
[HttpDelete("{id}")]
public async Task<IActionResult> Delete(Guid organizationId, Guid id, [FromBody] ScimUserRequestModel model)
{ {
await _organizationService.RevokeUserAsync(orgUser, null); var orgUser = await _organizationUserRepository.GetByIdAsync(id);
return true; if (orgUser == null || orgUser.OrganizationId != organizationId)
{
return new NotFoundObjectResult(new ScimErrorResponseModel
{
Status = 404,
Detail = "User not found."
});
}
await _organizationService.DeleteUserAsync(organizationId, id, null);
return new NoContentResult();
}
private async Task<bool> HandleActiveOperationAsync(Core.Entities.OrganizationUser orgUser, bool active)
{
if (active && orgUser.Status == OrganizationUserStatusType.Revoked)
{
await _organizationService.RestoreUserAsync(orgUser, null, _userService);
return true;
}
else if (!active && orgUser.Status != OrganizationUserStatusType.Revoked)
{
await _organizationService.RevokeUserAsync(orgUser, null);
return true;
}
return false;
} }
return false;
} }
} }

View File

@ -1,17 +1,18 @@
using Bit.Scim.Utilities; using Bit.Scim.Utilities;
namespace Bit.Scim.Models; namespace Bit.Scim.Models
public abstract class BaseScimGroupModel : BaseScimModel
{ {
public BaseScimGroupModel(bool initSchema = false) public abstract class BaseScimGroupModel : BaseScimModel
{ {
if (initSchema) public BaseScimGroupModel(bool initSchema = false)
{ {
Schemas = new List<string> { ScimConstants.Scim2SchemaGroup }; if (initSchema)
{
Schemas = new List<string> { ScimConstants.Scim2SchemaGroup };
}
} }
}
public string DisplayName { get; set; } public string DisplayName { get; set; }
public string ExternalId { get; set; } public string ExternalId { get; set; }
}
} }

View File

@ -1,14 +1,15 @@
namespace Bit.Scim.Models; namespace Bit.Scim.Models
public abstract class BaseScimModel
{ {
public BaseScimModel() public abstract class BaseScimModel
{ }
public BaseScimModel(string schema)
{ {
Schemas = new List<string> { schema }; public BaseScimModel()
} { }
public List<string> Schemas { get; set; } public BaseScimModel(string schema)
{
Schemas = new List<string> { schema };
}
public List<string> Schemas { get; set; }
}
} }

View File

@ -1,55 +1,56 @@
using Bit.Scim.Utilities; using Bit.Scim.Utilities;
namespace Bit.Scim.Models; namespace Bit.Scim.Models
public abstract class BaseScimUserModel : BaseScimModel
{ {
public BaseScimUserModel(bool initSchema = false) public abstract class BaseScimUserModel : BaseScimModel
{ {
if (initSchema) public BaseScimUserModel(bool initSchema = false)
{ {
Schemas = new List<string> { ScimConstants.Scim2SchemaUser }; if (initSchema)
} {
} Schemas = new List<string> { ScimConstants.Scim2SchemaUser };
}
public string UserName { get; set; }
public NameModel Name { get; set; }
public List<EmailModel> Emails { get; set; }
public string PrimaryEmail => Emails?.FirstOrDefault(e => e.Primary)?.Value;
public string WorkEmail => Emails?.FirstOrDefault(e => e.Type == "work")?.Value;
public string DisplayName { get; set; }
public bool Active { get; set; }
public List<string> Groups { get; set; }
public string ExternalId { get; set; }
public class NameModel
{
public NameModel() { }
public NameModel(string name)
{
Formatted = name;
} }
public string Formatted { get; set; } public string UserName { get; set; }
public string GivenName { get; set; } public NameModel Name { get; set; }
public string MiddleName { get; set; } public List<EmailModel> Emails { get; set; }
public string FamilyName { get; set; } public string PrimaryEmail => Emails?.FirstOrDefault(e => e.Primary)?.Value;
} public string WorkEmail => Emails?.FirstOrDefault(e => e.Type == "work")?.Value;
public string DisplayName { get; set; }
public bool Active { get; set; }
public List<string> Groups { get; set; }
public string ExternalId { get; set; }
public class EmailModel public class NameModel
{
public EmailModel() { }
public EmailModel(string email)
{ {
Primary = true; public NameModel() { }
Value = email;
Type = "work"; public NameModel(string name)
{
Formatted = name;
}
public string Formatted { get; set; }
public string GivenName { get; set; }
public string MiddleName { get; set; }
public string FamilyName { get; set; }
} }
public bool Primary { get; set; } public class EmailModel
public string Value { get; set; } {
public string Type { get; set; } public EmailModel() { }
public EmailModel(string email)
{
Primary = true;
Value = email;
Type = "work";
}
public bool Primary { get; set; }
public string Value { get; set; }
public string Type { get; set; }
}
} }
} }

View File

@ -1,13 +1,14 @@
using Bit.Scim.Utilities; using Bit.Scim.Utilities;
namespace Bit.Scim.Models; namespace Bit.Scim.Models
public class ScimErrorResponseModel : BaseScimModel
{ {
public ScimErrorResponseModel() public class ScimErrorResponseModel : BaseScimModel
: base(ScimConstants.Scim2SchemaError) {
{ } public ScimErrorResponseModel()
: base(ScimConstants.Scim2SchemaError)
{ }
public string Detail { get; set; } public string Detail { get; set; }
public int Status { get; set; } public int Status { get; set; }
}
} }

View File

@ -1,30 +1,31 @@
using Bit.Core.Entities; using Bit.Core.Entities;
using Bit.Core.Utilities; using Bit.Core.Utilities;
namespace Bit.Scim.Models; namespace Bit.Scim.Models
public class ScimGroupRequestModel : BaseScimGroupModel
{ {
public ScimGroupRequestModel() public class ScimGroupRequestModel : BaseScimGroupModel
: base(false)
{ }
public Group ToGroup(Guid organizationId)
{ {
var externalId = string.IsNullOrWhiteSpace(ExternalId) ? CoreHelpers.RandomString(15) : ExternalId; public ScimGroupRequestModel()
return new Group : base(false)
{ }
public Group ToGroup(Guid organizationId)
{ {
Name = DisplayName, var externalId = string.IsNullOrWhiteSpace(ExternalId) ? CoreHelpers.RandomString(15) : ExternalId;
ExternalId = externalId, return new Group
OrganizationId = organizationId {
}; Name = DisplayName,
} ExternalId = externalId,
OrganizationId = organizationId
};
}
public List<GroupMembersModel> Members { get; set; } public List<GroupMembersModel> Members { get; set; }
public class GroupMembersModel public class GroupMembersModel
{ {
public string Value { get; set; } public string Value { get; set; }
public string Display { get; set; } public string Display { get; set; }
}
} }
} }

View File

@ -1,25 +1,26 @@
using Bit.Core.Entities; using Bit.Core.Entities;
namespace Bit.Scim.Models; namespace Bit.Scim.Models
public class ScimGroupResponseModel : BaseScimGroupModel
{ {
public ScimGroupResponseModel() public class ScimGroupResponseModel : BaseScimGroupModel
: base(true)
{ {
Meta = new ScimMetaModel("Group"); public ScimGroupResponseModel()
} : base(true)
{
Meta = new ScimMetaModel("Group");
}
public ScimGroupResponseModel(Group group) public ScimGroupResponseModel(Group group)
: this() : this()
{ {
Id = group.Id.ToString(); Id = group.Id.ToString();
DisplayName = group.Name; DisplayName = group.Name;
ExternalId = group.ExternalId; ExternalId = group.ExternalId;
Meta.Created = group.CreationDate; Meta.Created = group.CreationDate;
Meta.LastModified = group.RevisionDate; Meta.LastModified = group.RevisionDate;
} }
public string Id { get; set; } public string Id { get; set; }
public ScimMetaModel Meta { get; private set; } public ScimMetaModel Meta { get; private set; }
}
} }

View File

@ -1,15 +1,16 @@
using Bit.Scim.Utilities; using Bit.Scim.Utilities;
namespace Bit.Scim.Models; namespace Bit.Scim.Models
public class ScimListResponseModel<T> : BaseScimModel
{ {
public ScimListResponseModel() public class ScimListResponseModel<T> : BaseScimModel
: base(ScimConstants.Scim2SchemaListResponse) {
{ } public ScimListResponseModel()
: base(ScimConstants.Scim2SchemaListResponse)
{ }
public int TotalResults { get; set; } public int TotalResults { get; set; }
public int StartIndex { get; set; } public int StartIndex { get; set; }
public int ItemsPerPage { get; set; } public int ItemsPerPage { get; set; }
public List<T> Resources { get; set; } public List<T> Resources { get; set; }
}
} }

View File

@ -1,13 +1,14 @@
namespace Bit.Scim.Models; namespace Bit.Scim.Models
public class ScimMetaModel
{ {
public ScimMetaModel(string resourceType) public class ScimMetaModel
{ {
ResourceType = resourceType; public ScimMetaModel(string resourceType)
} {
ResourceType = resourceType;
}
public string ResourceType { get; set; } public string ResourceType { get; set; }
public DateTime? Created { get; set; } public DateTime? Created { get; set; }
public DateTime? LastModified { get; set; } public DateTime? LastModified { get; set; }
}
} }

View File

@ -1,18 +1,19 @@
using System.Text.Json; using System.Text.Json;
namespace Bit.Scim.Models; namespace Bit.Scim.Models
public class ScimPatchModel : BaseScimModel
{ {
public ScimPatchModel() public class ScimPatchModel : BaseScimModel
: base() { }
public List<OperationModel> Operations { get; set; }
public class OperationModel
{ {
public string Op { get; set; } public ScimPatchModel()
public string Path { get; set; } : base() { }
public JsonElement Value { get; set; }
public List<OperationModel> Operations { get; set; }
public class OperationModel
{
public string Op { get; set; }
public string Path { get; set; }
public JsonElement Value { get; set; }
}
} }
} }

View File

@ -1,8 +1,9 @@
namespace Bit.Scim.Models; namespace Bit.Scim.Models
public class ScimUserRequestModel : BaseScimUserModel
{ {
public ScimUserRequestModel() public class ScimUserRequestModel : BaseScimUserModel
: base(false) {
{ } public ScimUserRequestModel()
: base(false)
{ }
}
} }

View File

@ -1,28 +1,29 @@
using Bit.Core.Models.Data.Organizations.OrganizationUsers; using Bit.Core.Models.Data.Organizations.OrganizationUsers;
namespace Bit.Scim.Models; namespace Bit.Scim.Models
public class ScimUserResponseModel : BaseScimUserModel
{ {
public ScimUserResponseModel() public class ScimUserResponseModel : BaseScimUserModel
: base(true)
{ {
Meta = new ScimMetaModel("User"); public ScimUserResponseModel()
Groups = new List<string>(); : base(true)
} {
Meta = new ScimMetaModel("User");
Groups = new List<string>();
}
public ScimUserResponseModel(OrganizationUserUserDetails orgUser) public ScimUserResponseModel(OrganizationUserUserDetails orgUser)
: this() : this()
{ {
Id = orgUser.Id.ToString(); Id = orgUser.Id.ToString();
ExternalId = orgUser.ExternalId; ExternalId = orgUser.ExternalId;
UserName = orgUser.Email; UserName = orgUser.Email;
DisplayName = orgUser.Name; DisplayName = orgUser.Name;
Emails = new List<EmailModel> { new EmailModel(orgUser.Email) }; Emails = new List<EmailModel> { new EmailModel(orgUser.Email) };
Name = new NameModel(orgUser.Name); Name = new NameModel(orgUser.Name);
Active = orgUser.Status != Core.Enums.OrganizationUserStatusType.Revoked; Active = orgUser.Status != Core.Enums.OrganizationUserStatusType.Revoked;
} }
public string Id { get; set; } public string Id { get; set; }
public ScimMetaModel Meta { get; private set; } public ScimMetaModel Meta { get; private set; }
}
} }

View File

@ -1,33 +1,34 @@
using Bit.Core.Utilities; using Bit.Core.Utilities;
using Serilog.Events; using Serilog.Events;
namespace Bit.Scim; namespace Bit.Scim
public class Program
{ {
public static void Main(string[] args) public class Program
{ {
Host public static void Main(string[] args)
.CreateDefaultBuilder(args) {
.ConfigureWebHostDefaults(webBuilder => Host
{ .CreateDefaultBuilder(args)
webBuilder.UseStartup<Startup>(); .ConfigureWebHostDefaults(webBuilder =>
webBuilder.ConfigureLogging((hostingContext, logging) => {
logging.AddSerilog(hostingContext, e => webBuilder.UseStartup<Startup>();
{ webBuilder.ConfigureLogging((hostingContext, logging) =>
var context = e.Properties["SourceContext"].ToString(); logging.AddSerilog(hostingContext, e =>
if (e.Properties.ContainsKey("RequestPath") &&
!string.IsNullOrWhiteSpace(e.Properties["RequestPath"]?.ToString()) &&
(context.Contains(".Server.Kestrel") || context.Contains(".Core.IISHttpServer")))
{ {
return false; var context = e.Properties["SourceContext"].ToString();
}
return e.Level >= LogEventLevel.Warning; if (e.Properties.ContainsKey("RequestPath") &&
})); !string.IsNullOrWhiteSpace(e.Properties["RequestPath"]?.ToString()) &&
}) (context.Contains(".Server.Kestrel") || context.Contains(".Core.IISHttpServer")))
.Build() {
.Run(); return false;
}
return e.Level >= LogEventLevel.Warning;
}));
})
.Build()
.Run();
}
} }
} }

View File

@ -1,5 +1,6 @@
namespace Bit.Scim; namespace Bit.Scim
public class ScimSettings
{ {
public class ScimSettings
{
}
} }

View File

@ -9,107 +9,108 @@ using IdentityModel;
using Microsoft.Extensions.DependencyInjection.Extensions; using Microsoft.Extensions.DependencyInjection.Extensions;
using Stripe; using Stripe;
namespace Bit.Scim; namespace Bit.Scim
public class Startup
{ {
public Startup(IWebHostEnvironment env, IConfiguration configuration) public class Startup
{ {
CultureInfo.DefaultThreadCurrentCulture = new CultureInfo("en-US"); public Startup(IWebHostEnvironment env, IConfiguration configuration)
Configuration = configuration;
Environment = env;
}
public IConfiguration Configuration { get; }
public IWebHostEnvironment Environment { get; set; }
public void ConfigureServices(IServiceCollection services)
{
// Options
services.AddOptions();
// Settings
var globalSettings = services.AddGlobalSettingsServices(Configuration, Environment);
services.Configure<ScimSettings>(Configuration.GetSection("ScimSettings"));
// Data Protection
services.AddCustomDataProtectionServices(Environment, globalSettings);
// Stripe Billing
StripeConfiguration.ApiKey = globalSettings.Stripe.ApiKey;
StripeConfiguration.MaxNetworkRetries = globalSettings.Stripe.MaxNetworkRetries;
// Repositories
services.AddSqlServerRepositories(globalSettings);
// Context
services.AddScoped<ICurrentContext, CurrentContext>();
services.AddScoped<IScimContext, ScimContext>();
// Authentication
services.AddAuthentication(ApiKeyAuthenticationOptions.DefaultScheme)
.AddScheme<ApiKeyAuthenticationOptions, ApiKeyAuthenticationHandler>(
ApiKeyAuthenticationOptions.DefaultScheme, null);
services.AddAuthorization(config =>
{ {
config.AddPolicy("Scim", policy => CultureInfo.DefaultThreadCurrentCulture = new CultureInfo("en-US");
{ Configuration = configuration;
policy.RequireAuthenticatedUser(); Environment = env;
policy.RequireClaim(JwtClaimTypes.Scope, "api.scim");
});
});
// Identity
services.AddCustomIdentityServices(globalSettings);
// Services
services.AddBaseServices(globalSettings);
services.AddDefaultServices(globalSettings);
services.TryAddSingleton<IHttpContextAccessor, HttpContextAccessor>();
// Mvc
services.AddMvc(config =>
{
config.Filters.Add(new LoggingExceptionHandlerFilterAttribute());
});
services.Configure<RouteOptions>(options => options.LowercaseUrls = true);
}
public void Configure(
IApplicationBuilder app,
IWebHostEnvironment env,
IHostApplicationLifetime appLifetime,
GlobalSettings globalSettings)
{
app.UseSerilog(env, appLifetime, globalSettings);
// Add general security headers
app.UseMiddleware<SecurityHeadersMiddleware>();
if (env.IsDevelopment())
{
app.UseDeveloperExceptionPage();
} }
// Default Middleware public IConfiguration Configuration { get; }
app.UseDefaultMiddleware(env, globalSettings); public IWebHostEnvironment Environment { get; set; }
// Add routing public void ConfigureServices(IServiceCollection services)
app.UseRouting(); {
// Options
services.AddOptions();
// Add Scim context // Settings
app.UseMiddleware<ScimContextMiddleware>(); var globalSettings = services.AddGlobalSettingsServices(Configuration, Environment);
services.Configure<ScimSettings>(Configuration.GetSection("ScimSettings"));
// Add authentication and authorization to the request pipeline. // Data Protection
app.UseAuthentication(); services.AddCustomDataProtectionServices(Environment, globalSettings);
app.UseAuthorization();
// Add current context // Stripe Billing
app.UseMiddleware<CurrentContextMiddleware>(); StripeConfiguration.ApiKey = globalSettings.Stripe.ApiKey;
StripeConfiguration.MaxNetworkRetries = globalSettings.Stripe.MaxNetworkRetries;
// Add MVC to the request pipeline. // Repositories
app.UseEndpoints(endpoints => endpoints.MapDefaultControllerRoute()); services.AddSqlServerRepositories(globalSettings);
// Context
services.AddScoped<ICurrentContext, CurrentContext>();
services.AddScoped<IScimContext, ScimContext>();
// Authentication
services.AddAuthentication(ApiKeyAuthenticationOptions.DefaultScheme)
.AddScheme<ApiKeyAuthenticationOptions, ApiKeyAuthenticationHandler>(
ApiKeyAuthenticationOptions.DefaultScheme, null);
services.AddAuthorization(config =>
{
config.AddPolicy("Scim", policy =>
{
policy.RequireAuthenticatedUser();
policy.RequireClaim(JwtClaimTypes.Scope, "api.scim");
});
});
// Identity
services.AddCustomIdentityServices(globalSettings);
// Services
services.AddBaseServices(globalSettings);
services.AddDefaultServices(globalSettings);
services.TryAddSingleton<IHttpContextAccessor, HttpContextAccessor>();
// Mvc
services.AddMvc(config =>
{
config.Filters.Add(new LoggingExceptionHandlerFilterAttribute());
});
services.Configure<RouteOptions>(options => options.LowercaseUrls = true);
}
public void Configure(
IApplicationBuilder app,
IWebHostEnvironment env,
IHostApplicationLifetime appLifetime,
GlobalSettings globalSettings)
{
app.UseSerilog(env, appLifetime, globalSettings);
// Add general security headers
app.UseMiddleware<SecurityHeadersMiddleware>();
if (env.IsDevelopment())
{
app.UseDeveloperExceptionPage();
}
// Default Middleware
app.UseDefaultMiddleware(env, globalSettings);
// Add routing
app.UseRouting();
// Add Scim context
app.UseMiddleware<ScimContextMiddleware>();
// Add authentication and authorization to the request pipeline.
app.UseAuthentication();
app.UseAuthorization();
// Add current context
app.UseMiddleware<CurrentContextMiddleware>();
// Add MVC to the request pipeline.
app.UseEndpoints(endpoints => endpoints.MapDefaultControllerRoute());
}
} }
} }

View File

@ -8,82 +8,83 @@ using Microsoft.AspNetCore.Authentication;
using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Authorization;
using Microsoft.Extensions.Options; using Microsoft.Extensions.Options;
namespace Bit.Scim.Utilities; namespace Bit.Scim.Utilities
public class ApiKeyAuthenticationHandler : AuthenticationHandler<ApiKeyAuthenticationOptions>
{ {
private readonly IOrganizationRepository _organizationRepository; public class ApiKeyAuthenticationHandler : AuthenticationHandler<ApiKeyAuthenticationOptions>
private readonly IOrganizationApiKeyRepository _organizationApiKeyRepository;
private readonly IScimContext _scimContext;
public ApiKeyAuthenticationHandler(
IOptionsMonitor<ApiKeyAuthenticationOptions> options,
ILoggerFactory logger,
UrlEncoder encoder,
ISystemClock clock,
IOrganizationRepository organizationRepository,
IOrganizationApiKeyRepository organizationApiKeyRepository,
IScimContext scimContext) :
base(options, logger, encoder, clock)
{ {
_organizationRepository = organizationRepository; private readonly IOrganizationRepository _organizationRepository;
_organizationApiKeyRepository = organizationApiKeyRepository; private readonly IOrganizationApiKeyRepository _organizationApiKeyRepository;
_scimContext = scimContext; private readonly IScimContext _scimContext;
}
protected override async Task<AuthenticateResult> HandleAuthenticateAsync() public ApiKeyAuthenticationHandler(
{ IOptionsMonitor<ApiKeyAuthenticationOptions> options,
var endpoint = Context.GetEndpoint(); ILoggerFactory logger,
if (endpoint?.Metadata?.GetMetadata<IAllowAnonymous>() != null) UrlEncoder encoder,
ISystemClock clock,
IOrganizationRepository organizationRepository,
IOrganizationApiKeyRepository organizationApiKeyRepository,
IScimContext scimContext) :
base(options, logger, encoder, clock)
{ {
return AuthenticateResult.NoResult(); _organizationRepository = organizationRepository;
_organizationApiKeyRepository = organizationApiKeyRepository;
_scimContext = scimContext;
} }
if (!_scimContext.OrganizationId.HasValue || _scimContext.Organization == null) protected override async Task<AuthenticateResult> HandleAuthenticateAsync()
{ {
Logger.LogWarning("No organization."); var endpoint = Context.GetEndpoint();
return AuthenticateResult.Fail("Invalid parameters"); if (endpoint?.Metadata?.GetMetadata<IAllowAnonymous>() != null)
{
return AuthenticateResult.NoResult();
}
if (!_scimContext.OrganizationId.HasValue || _scimContext.Organization == null)
{
Logger.LogWarning("No organization.");
return AuthenticateResult.Fail("Invalid parameters");
}
if (!Request.Headers.TryGetValue("Authorization", out var authHeader) || authHeader.Count != 1)
{
Logger.LogWarning("An API request was received without the Authorization header");
return AuthenticateResult.Fail("Invalid parameters");
}
var apiKey = authHeader.ToString();
if (apiKey.StartsWith("Bearer "))
{
apiKey = apiKey.Substring(7);
}
if (!_scimContext.Organization.Enabled || !_scimContext.Organization.UseScim ||
_scimContext.ScimConfiguration == null || !_scimContext.ScimConfiguration.Enabled)
{
Logger.LogInformation("Org {organizationId} not able to use Scim.", _scimContext.OrganizationId);
return AuthenticateResult.Fail("Invalid parameters");
}
var orgApiKey = (await _organizationApiKeyRepository
.GetManyByOrganizationIdTypeAsync(_scimContext.Organization.Id, OrganizationApiKeyType.Scim))
.FirstOrDefault();
if (orgApiKey?.ApiKey != apiKey)
{
Logger.LogWarning("An API request was received with an invalid API key: {apiKey}", apiKey);
return AuthenticateResult.Fail("Invalid parameters");
}
Logger.LogInformation("Org {organizationId} authenticated", _scimContext.OrganizationId);
var claims = new[]
{
new Claim(JwtClaimTypes.ClientId, $"organization.{_scimContext.OrganizationId.Value}"),
new Claim("client_sub", _scimContext.OrganizationId.Value.ToString()),
new Claim(JwtClaimTypes.Scope, "api.scim"),
};
var identity = new ClaimsIdentity(claims, nameof(ApiKeyAuthenticationHandler));
var ticket = new AuthenticationTicket(new ClaimsPrincipal(identity),
ApiKeyAuthenticationOptions.DefaultScheme);
return AuthenticateResult.Success(ticket);
} }
if (!Request.Headers.TryGetValue("Authorization", out var authHeader) || authHeader.Count != 1)
{
Logger.LogWarning("An API request was received without the Authorization header");
return AuthenticateResult.Fail("Invalid parameters");
}
var apiKey = authHeader.ToString();
if (apiKey.StartsWith("Bearer "))
{
apiKey = apiKey.Substring(7);
}
if (!_scimContext.Organization.Enabled || !_scimContext.Organization.UseScim ||
_scimContext.ScimConfiguration == null || !_scimContext.ScimConfiguration.Enabled)
{
Logger.LogInformation("Org {organizationId} not able to use Scim.", _scimContext.OrganizationId);
return AuthenticateResult.Fail("Invalid parameters");
}
var orgApiKey = (await _organizationApiKeyRepository
.GetManyByOrganizationIdTypeAsync(_scimContext.Organization.Id, OrganizationApiKeyType.Scim))
.FirstOrDefault();
if (orgApiKey?.ApiKey != apiKey)
{
Logger.LogWarning("An API request was received with an invalid API key: {apiKey}", apiKey);
return AuthenticateResult.Fail("Invalid parameters");
}
Logger.LogInformation("Org {organizationId} authenticated", _scimContext.OrganizationId);
var claims = new[]
{
new Claim(JwtClaimTypes.ClientId, $"organization.{_scimContext.OrganizationId.Value}"),
new Claim("client_sub", _scimContext.OrganizationId.Value.ToString()),
new Claim(JwtClaimTypes.Scope, "api.scim"),
};
var identity = new ClaimsIdentity(claims, nameof(ApiKeyAuthenticationHandler));
var ticket = new AuthenticationTicket(new ClaimsPrincipal(identity),
ApiKeyAuthenticationOptions.DefaultScheme);
return AuthenticateResult.Success(ticket);
} }
} }

View File

@ -1,8 +1,9 @@
using Microsoft.AspNetCore.Authentication; using Microsoft.AspNetCore.Authentication;
namespace Bit.Scim.Utilities; namespace Bit.Scim.Utilities
public class ApiKeyAuthenticationOptions : AuthenticationSchemeOptions
{ {
public const string DefaultScheme = "ScimApiKey"; public class ApiKeyAuthenticationOptions : AuthenticationSchemeOptions
{
public const string DefaultScheme = "ScimApiKey";
}
} }

View File

@ -1,9 +1,10 @@
namespace Bit.Scim.Utilities; namespace Bit.Scim.Utilities
public static class ScimConstants
{ {
public const string Scim2SchemaListResponse = "urn:ietf:params:scim:api:messages:2.0:ListResponse"; public static class ScimConstants
public const string Scim2SchemaError = "urn:ietf:params:scim:api:messages:2.0:Error"; {
public const string Scim2SchemaUser = "urn:ietf:params:scim:schemas:core:2.0:User"; public const string Scim2SchemaListResponse = "urn:ietf:params:scim:api:messages:2.0:ListResponse";
public const string Scim2SchemaGroup = "urn:ietf:params:scim:schemas:core:2.0:Group"; public const string Scim2SchemaError = "urn:ietf:params:scim:api:messages:2.0:Error";
public const string Scim2SchemaUser = "urn:ietf:params:scim:schemas:core:2.0:User";
public const string Scim2SchemaGroup = "urn:ietf:params:scim:schemas:core:2.0:Group";
}
} }

View File

@ -2,21 +2,22 @@
using Bit.Core.Settings; using Bit.Core.Settings;
using Bit.Scim.Context; using Bit.Scim.Context;
namespace Bit.Scim.Utilities; namespace Bit.Scim.Utilities
public class ScimContextMiddleware
{ {
private readonly RequestDelegate _next; public class ScimContextMiddleware
public ScimContextMiddleware(RequestDelegate next)
{ {
_next = next; private readonly RequestDelegate _next;
}
public async Task Invoke(HttpContext httpContext, IScimContext scimContext, GlobalSettings globalSettings, public ScimContextMiddleware(RequestDelegate next)
IOrganizationRepository organizationRepository, IOrganizationConnectionRepository organizationConnectionRepository) {
{ _next = next;
await scimContext.BuildAsync(httpContext, globalSettings, organizationRepository, organizationConnectionRepository); }
await _next.Invoke(httpContext);
public async Task Invoke(HttpContext httpContext, IScimContext scimContext, GlobalSettings globalSettings,
IOrganizationRepository organizationRepository, IOrganizationConnectionRepository organizationConnectionRepository)
{
await scimContext.BuildAsync(httpContext, globalSettings, organizationRepository, organizationConnectionRepository);
await _next.Invoke(httpContext);
}
} }
} }

File diff suppressed because it is too large Load Diff

View File

@ -5,50 +5,51 @@ using Microsoft.AspNetCore.Authorization;
using Microsoft.AspNetCore.Diagnostics; using Microsoft.AspNetCore.Diagnostics;
using Microsoft.AspNetCore.Mvc; using Microsoft.AspNetCore.Mvc;
namespace Bit.Sso.Controllers; namespace Bit.Sso.Controllers
public class HomeController : Controller
{ {
private readonly IIdentityServerInteractionService _interaction; public class HomeController : Controller
public HomeController(IIdentityServerInteractionService interaction)
{ {
_interaction = interaction; private readonly IIdentityServerInteractionService _interaction;
}
[Route("~/Error")] public HomeController(IIdentityServerInteractionService interaction)
[Route("~/Home/Error")]
[AllowAnonymous]
public async Task<IActionResult> Error(string errorId)
{
var vm = new ErrorViewModel();
// retrieve error details from identityserver
var message = string.IsNullOrWhiteSpace(errorId) ? null :
await _interaction.GetErrorContextAsync(errorId);
if (message != null)
{ {
vm.Error = message; _interaction = interaction;
} }
else
[Route("~/Error")]
[Route("~/Home/Error")]
[AllowAnonymous]
public async Task<IActionResult> Error(string errorId)
{ {
vm.RequestId = Activity.Current?.Id ?? HttpContext.TraceIdentifier; var vm = new ErrorViewModel();
var exceptionHandlerPathFeature = HttpContext.Features.Get<IExceptionHandlerPathFeature>();
var exception = exceptionHandlerPathFeature?.Error; // retrieve error details from identityserver
if (exception is InvalidOperationException opEx && opEx.Message.Contains("schemes are: ")) var message = string.IsNullOrWhiteSpace(errorId) ? null :
await _interaction.GetErrorContextAsync(errorId);
if (message != null)
{ {
// Messages coming from aspnetcore with a message vm.Error = message;
// similar to "The registered sign-in schemes are: {schemes}." }
// will expose other Org IDs and sign-in schemes enabled on else
// the server. These errors should be truncated to just the {
// scheme impacted (always the first sentence) vm.RequestId = Activity.Current?.Id ?? HttpContext.TraceIdentifier;
var cleanupPoint = opEx.Message.IndexOf(". ") + 1; var exceptionHandlerPathFeature = HttpContext.Features.Get<IExceptionHandlerPathFeature>();
var exMessage = opEx.Message.Substring(0, cleanupPoint); var exception = exceptionHandlerPathFeature?.Error;
exception = new InvalidOperationException(exMessage, opEx); if (exception is InvalidOperationException opEx && opEx.Message.Contains("schemes are: "))
{
// Messages coming from aspnetcore with a message
// similar to "The registered sign-in schemes are: {schemes}."
// will expose other Org IDs and sign-in schemes enabled on
// the server. These errors should be truncated to just the
// scheme impacted (always the first sentence)
var cleanupPoint = opEx.Message.IndexOf(". ") + 1;
var exMessage = opEx.Message.Substring(0, cleanupPoint);
exception = new InvalidOperationException(exMessage, opEx);
}
vm.Exception = exception;
} }
vm.Exception = exception;
}
return View("Error", vm); return View("Error", vm);
}
} }
} }

View File

@ -1,20 +1,21 @@
using Bit.Core.Utilities; using Bit.Core.Utilities;
using Microsoft.AspNetCore.Mvc; using Microsoft.AspNetCore.Mvc;
namespace Bit.Sso.Controllers; namespace Bit.Sso.Controllers
public class InfoController : Controller
{ {
[HttpGet("~/alive")] public class InfoController : Controller
[HttpGet("~/now")]
public DateTime GetAlive()
{ {
return DateTime.UtcNow; [HttpGet("~/alive")]
} [HttpGet("~/now")]
public DateTime GetAlive()
{
return DateTime.UtcNow;
}
[HttpGet("~/version")] [HttpGet("~/version")]
public JsonResult GetVersion() public JsonResult GetVersion()
{ {
return Json(CoreHelpers.GetVersion()); return Json(CoreHelpers.GetVersion());
}
} }
} }

View File

@ -5,65 +5,66 @@ using Microsoft.AspNetCore.Mvc;
using Sustainsys.Saml2.AspNetCore2; using Sustainsys.Saml2.AspNetCore2;
using Sustainsys.Saml2.WebSso; using Sustainsys.Saml2.WebSso;
namespace Bit.Sso.Controllers; namespace Bit.Sso.Controllers
public class MetadataController : Controller
{ {
private readonly IAuthenticationSchemeProvider _schemeProvider; public class MetadataController : Controller
public MetadataController(
IAuthenticationSchemeProvider schemeProvider)
{ {
_schemeProvider = schemeProvider; private readonly IAuthenticationSchemeProvider _schemeProvider;
}
[HttpGet("saml2/{scheme}")] public MetadataController(
public async Task<IActionResult> ViewAsync(string scheme) IAuthenticationSchemeProvider schemeProvider)
{
if (string.IsNullOrWhiteSpace(scheme))
{ {
return NotFound(); _schemeProvider = schemeProvider;
} }
var authScheme = await _schemeProvider.GetSchemeAsync(scheme); [HttpGet("saml2/{scheme}")]
if (authScheme == null || public async Task<IActionResult> ViewAsync(string scheme)
!(authScheme is DynamicAuthenticationScheme dynamicAuthScheme) ||
dynamicAuthScheme?.SsoType != SsoType.Saml2)
{ {
return NotFound(); if (string.IsNullOrWhiteSpace(scheme))
{
return NotFound();
}
var authScheme = await _schemeProvider.GetSchemeAsync(scheme);
if (authScheme == null ||
!(authScheme is DynamicAuthenticationScheme dynamicAuthScheme) ||
dynamicAuthScheme?.SsoType != SsoType.Saml2)
{
return NotFound();
}
if (!(dynamicAuthScheme.Options is Saml2Options options))
{
return NotFound();
}
var uri = new Uri(
Request.Scheme
+ "://"
+ Request.Host
+ Request.Path
+ Request.QueryString);
var pathBase = Request.PathBase.Value;
pathBase = string.IsNullOrEmpty(pathBase) ? "/" : pathBase;
var requestdata = new HttpRequestData(
Request.Method,
uri,
pathBase,
null,
Request.Cookies,
(data) => data);
var metadataResult = CommandFactory
.GetCommand(CommandFactory.MetadataCommand)
.Run(requestdata, options);
//Response.Headers.Add("Content-Disposition", $"filename= bitwarden-saml2-meta-{scheme}.xml");
return new ContentResult
{
Content = metadataResult.Content,
ContentType = "text/xml",
};
} }
if (!(dynamicAuthScheme.Options is Saml2Options options))
{
return NotFound();
}
var uri = new Uri(
Request.Scheme
+ "://"
+ Request.Host
+ Request.Path
+ Request.QueryString);
var pathBase = Request.PathBase.Value;
pathBase = string.IsNullOrEmpty(pathBase) ? "/" : pathBase;
var requestdata = new HttpRequestData(
Request.Method,
uri,
pathBase,
null,
Request.Cookies,
(data) => data);
var metadataResult = CommandFactory
.GetCommand(CommandFactory.MetadataCommand)
.Run(requestdata, options);
//Response.Headers.Add("Content-Disposition", $"filename= bitwarden-saml2-meta-{scheme}.xml");
return new ContentResult
{
Content = metadataResult.Content,
ContentType = "text/xml",
};
} }
} }

View File

@ -1,26 +1,27 @@
using IdentityServer4.Models; using IdentityServer4.Models;
namespace Bit.Sso.Models; namespace Bit.Sso.Models
public class ErrorViewModel
{ {
private string _requestId; public class ErrorViewModel
public ErrorMessage Error { get; set; }
public Exception Exception { get; set; }
public string Message => Error?.Error;
public string Description => Error?.ErrorDescription ?? Exception?.Message;
public string RedirectUri => Error?.RedirectUri;
public string RequestId
{ {
get private string _requestId;
public ErrorMessage Error { get; set; }
public Exception Exception { get; set; }
public string Message => Error?.Error;
public string Description => Error?.ErrorDescription ?? Exception?.Message;
public string RedirectUri => Error?.RedirectUri;
public string RequestId
{ {
return Error?.RequestId ?? _requestId; get
} {
set return Error?.RequestId ?? _requestId;
{ }
_requestId = value; set
{
_requestId = value;
}
} }
} }
} }

View File

@ -1,6 +1,7 @@
namespace Bit.Sso.Models; namespace Bit.Sso.Models
public class RedirectViewModel
{ {
public string RedirectUrl { get; set; } public class RedirectViewModel
{
public string RedirectUrl { get; set; }
}
} }

View File

@ -1,8 +1,9 @@
using System.Security.Cryptography.X509Certificates; using System.Security.Cryptography.X509Certificates;
namespace Bit.Sso.Models; namespace Bit.Sso.Models
public class SamlEnvironment
{ {
public X509Certificate2 SpSigningCertificate { get; set; } public class SamlEnvironment
{
public X509Certificate2 SpSigningCertificate { get; set; }
}
} }

View File

@ -1,12 +1,13 @@
using Microsoft.AspNetCore.Mvc; using Microsoft.AspNetCore.Mvc;
namespace Bit.Sso.Models; namespace Bit.Sso.Models
public class SsoPreValidateResponseModel : JsonResult
{ {
public SsoPreValidateResponseModel(string token) : base(new public class SsoPreValidateResponseModel : JsonResult
{ {
token public SsoPreValidateResponseModel(string token) : base(new
}) {
{ } token
})
{ }
}
} }

View File

@ -2,32 +2,33 @@
using Serilog; using Serilog;
using Serilog.Events; using Serilog.Events;
namespace Bit.Sso; namespace Bit.Sso
public class Program
{ {
public static void Main(string[] args) public class Program
{ {
Host public static void Main(string[] args)
.CreateDefaultBuilder(args) {
.ConfigureCustomAppConfiguration(args) Host
.ConfigureWebHostDefaults(webBuilder => .CreateDefaultBuilder(args)
{ .ConfigureCustomAppConfiguration(args)
webBuilder.UseStartup<Startup>(); .ConfigureWebHostDefaults(webBuilder =>
webBuilder.ConfigureLogging((hostingContext, logging) =>
logging.AddSerilog(hostingContext, e =>
{ {
var context = e.Properties["SourceContext"].ToString(); webBuilder.UseStartup<Startup>();
if (e.Properties.ContainsKey("RequestPath") && webBuilder.ConfigureLogging((hostingContext, logging) =>
!string.IsNullOrWhiteSpace(e.Properties["RequestPath"]?.ToString()) && logging.AddSerilog(hostingContext, e =>
(context.Contains(".Server.Kestrel") || context.Contains(".Core.IISHttpServer")))
{ {
return false; var context = e.Properties["SourceContext"].ToString();
} if (e.Properties.ContainsKey("RequestPath") &&
return e.Level >= LogEventLevel.Error; !string.IsNullOrWhiteSpace(e.Properties["RequestPath"]?.ToString()) &&
})); (context.Contains(".Server.Kestrel") || context.Contains(".Core.IISHttpServer")))
}) {
.Build() return false;
.Run(); }
return e.Level >= LogEventLevel.Error;
}));
})
.Build()
.Run();
}
} }
} }

View File

@ -8,147 +8,148 @@ using IdentityServer4.Extensions;
using Microsoft.IdentityModel.Logging; using Microsoft.IdentityModel.Logging;
using Stripe; using Stripe;
namespace Bit.Sso; namespace Bit.Sso
public class Startup
{ {
public Startup(IWebHostEnvironment env, IConfiguration configuration) public class Startup
{ {
Configuration = configuration; public Startup(IWebHostEnvironment env, IConfiguration configuration)
Environment = env;
}
public IConfiguration Configuration { get; }
public IWebHostEnvironment Environment { get; set; }
public void ConfigureServices(IServiceCollection services)
{
// Options
services.AddOptions();
// Settings
var globalSettings = services.AddGlobalSettingsServices(Configuration, Environment);
// Stripe Billing
StripeConfiguration.ApiKey = globalSettings.Stripe.ApiKey;
StripeConfiguration.MaxNetworkRetries = globalSettings.Stripe.MaxNetworkRetries;
// Data Protection
services.AddCustomDataProtectionServices(Environment, globalSettings);
// Repositories
services.AddSqlServerRepositories(globalSettings);
// Context
services.AddScoped<ICurrentContext, CurrentContext>();
// Caching
services.AddMemoryCache();
services.AddDistributedCache(globalSettings);
// Mvc
services.AddControllersWithViews();
// Cookies
if (Environment.IsDevelopment())
{ {
services.Configure<CookiePolicyOptions>(options => Configuration = configuration;
Environment = env;
}
public IConfiguration Configuration { get; }
public IWebHostEnvironment Environment { get; set; }
public void ConfigureServices(IServiceCollection services)
{
// Options
services.AddOptions();
// Settings
var globalSettings = services.AddGlobalSettingsServices(Configuration, Environment);
// Stripe Billing
StripeConfiguration.ApiKey = globalSettings.Stripe.ApiKey;
StripeConfiguration.MaxNetworkRetries = globalSettings.Stripe.MaxNetworkRetries;
// Data Protection
services.AddCustomDataProtectionServices(Environment, globalSettings);
// Repositories
services.AddSqlServerRepositories(globalSettings);
// Context
services.AddScoped<ICurrentContext, CurrentContext>();
// Caching
services.AddMemoryCache();
services.AddDistributedCache(globalSettings);
// Mvc
services.AddControllersWithViews();
// Cookies
if (Environment.IsDevelopment())
{ {
options.MinimumSameSitePolicy = Microsoft.AspNetCore.Http.SameSiteMode.Unspecified; services.Configure<CookiePolicyOptions>(options =>
options.OnAppendCookie = ctx =>
{ {
ctx.CookieOptions.SameSite = Microsoft.AspNetCore.Http.SameSiteMode.Unspecified; options.MinimumSameSitePolicy = Microsoft.AspNetCore.Http.SameSiteMode.Unspecified;
}; options.OnAppendCookie = ctx =>
}); {
ctx.CookieOptions.SameSite = Microsoft.AspNetCore.Http.SameSiteMode.Unspecified;
};
});
}
// Authentication
services.AddDistributedIdentityServices(globalSettings);
services.AddAuthentication()
.AddCookie(AuthenticationSchemes.BitwardenExternalCookieAuthenticationScheme);
services.AddSsoServices(globalSettings);
// IdentityServer
services.AddSsoIdentityServerServices(Environment, globalSettings);
// Identity
services.AddCustomIdentityServices(globalSettings);
// Services
services.AddBaseServices(globalSettings);
services.AddDefaultServices(globalSettings);
services.AddCoreLocalizationServices();
} }
// Authentication public void Configure(
services.AddDistributedIdentityServices(globalSettings); IApplicationBuilder app,
services.AddAuthentication() IWebHostEnvironment env,
.AddCookie(AuthenticationSchemes.BitwardenExternalCookieAuthenticationScheme); IHostApplicationLifetime appLifetime,
services.AddSsoServices(globalSettings); GlobalSettings globalSettings,
ILogger<Startup> logger)
// IdentityServer
services.AddSsoIdentityServerServices(Environment, globalSettings);
// Identity
services.AddCustomIdentityServices(globalSettings);
// Services
services.AddBaseServices(globalSettings);
services.AddDefaultServices(globalSettings);
services.AddCoreLocalizationServices();
}
public void Configure(
IApplicationBuilder app,
IWebHostEnvironment env,
IHostApplicationLifetime appLifetime,
GlobalSettings globalSettings,
ILogger<Startup> logger)
{
if (env.IsDevelopment() || globalSettings.SelfHosted)
{ {
IdentityModelEventSource.ShowPII = true; if (env.IsDevelopment() || globalSettings.SelfHosted)
}
app.UseSerilog(env, appLifetime, globalSettings);
// Add general security headers
app.UseMiddleware<SecurityHeadersMiddleware>();
if (!env.IsDevelopment())
{
var uri = new Uri(globalSettings.BaseServiceUri.Sso);
app.Use(async (ctx, next) =>
{ {
ctx.SetIdentityServerOrigin($"{uri.Scheme}://{uri.Host}"); IdentityModelEventSource.ShowPII = true;
await next(); }
app.UseSerilog(env, appLifetime, globalSettings);
// Add general security headers
app.UseMiddleware<SecurityHeadersMiddleware>();
if (!env.IsDevelopment())
{
var uri = new Uri(globalSettings.BaseServiceUri.Sso);
app.Use(async (ctx, next) =>
{
ctx.SetIdentityServerOrigin($"{uri.Scheme}://{uri.Host}");
await next();
});
}
if (globalSettings.SelfHosted)
{
app.UsePathBase("/sso");
app.UseForwardedHeaders(globalSettings);
}
if (env.IsDevelopment())
{
app.UseDeveloperExceptionPage();
app.UseCookiePolicy();
}
else
{
app.UseExceptionHandler("/Error");
}
app.UseCoreLocalization();
// Add static files to the request pipeline.
app.UseStaticFiles();
// Add routing
app.UseRouting();
// Add Cors
app.UseCors(policy => policy.SetIsOriginAllowed(o => CoreHelpers.IsCorsOriginAllowed(o, globalSettings))
.AllowAnyMethod().AllowAnyHeader().AllowCredentials());
// Add current context
app.UseMiddleware<CurrentContextMiddleware>();
// Add IdentityServer to the request pipeline.
app.UseIdentityServer(new IdentityServerMiddlewareOptions
{
AuthenticationMiddleware = app => app.UseMiddleware<SsoAuthenticationMiddleware>()
}); });
// Add Mvc stuff
app.UseAuthorization();
app.UseEndpoints(endpoints => endpoints.MapDefaultControllerRoute());
// Log startup
logger.LogInformation(Constants.BypassFiltersEventId, globalSettings.ProjectName + " started.");
} }
if (globalSettings.SelfHosted)
{
app.UsePathBase("/sso");
app.UseForwardedHeaders(globalSettings);
}
if (env.IsDevelopment())
{
app.UseDeveloperExceptionPage();
app.UseCookiePolicy();
}
else
{
app.UseExceptionHandler("/Error");
}
app.UseCoreLocalization();
// Add static files to the request pipeline.
app.UseStaticFiles();
// Add routing
app.UseRouting();
// Add Cors
app.UseCors(policy => policy.SetIsOriginAllowed(o => CoreHelpers.IsCorsOriginAllowed(o, globalSettings))
.AllowAnyMethod().AllowAnyHeader().AllowCredentials());
// Add current context
app.UseMiddleware<CurrentContextMiddleware>();
// Add IdentityServer to the request pipeline.
app.UseIdentityServer(new IdentityServerMiddlewareOptions
{
AuthenticationMiddleware = app => app.UseMiddleware<SsoAuthenticationMiddleware>()
});
// Add Mvc stuff
app.UseAuthorization();
app.UseEndpoints(endpoints => endpoints.MapDefaultControllerRoute());
// Log startup
logger.LogInformation(Constants.BypassFiltersEventId, globalSettings.ProjectName + " started.");
} }
} }

View File

@ -1,45 +1,46 @@
using System.Security.Claims; using System.Security.Claims;
using System.Text.RegularExpressions; using System.Text.RegularExpressions;
namespace Bit.Sso.Utilities; namespace Bit.Sso.Utilities
public static class ClaimsExtensions
{ {
private static readonly Regex _normalizeTextRegEx = public static class ClaimsExtensions
new Regex(@"[^a-zA-Z]", RegexOptions.CultureInvariant | RegexOptions.Singleline);
public static string GetFirstMatch(this IEnumerable<Claim> claims, params string[] possibleNames)
{ {
var normalizedClaims = claims.Select(c => (Normalize(c.Type), c.Value)).ToList(); private static readonly Regex _normalizeTextRegEx =
new Regex(@"[^a-zA-Z]", RegexOptions.CultureInvariant | RegexOptions.Singleline);
// Order of prescendence is by passed in names public static string GetFirstMatch(this IEnumerable<Claim> claims, params string[] possibleNames)
foreach (var name in possibleNames.Select(Normalize))
{ {
// Second by order of claims (find claim by name) var normalizedClaims = claims.Select(c => (Normalize(c.Type), c.Value)).ToList();
foreach (var claim in normalizedClaims)
// Order of prescendence is by passed in names
foreach (var name in possibleNames.Select(Normalize))
{ {
if (Equals(claim.Item1, name)) // Second by order of claims (find claim by name)
foreach (var claim in normalizedClaims)
{ {
return claim.Value; if (Equals(claim.Item1, name))
{
return claim.Value;
}
} }
} }
return null;
} }
return null;
}
private static bool Equals(string text, string compare) private static bool Equals(string text, string compare)
{
return text == compare ||
(string.IsNullOrWhiteSpace(text) && string.IsNullOrWhiteSpace(compare)) ||
string.Equals(Normalize(text), compare, StringComparison.InvariantCultureIgnoreCase);
}
private static string Normalize(string text)
{
if (string.IsNullOrWhiteSpace(text))
{ {
return text; return text == compare ||
(string.IsNullOrWhiteSpace(text) && string.IsNullOrWhiteSpace(compare)) ||
string.Equals(Normalize(text), compare, StringComparison.InvariantCultureIgnoreCase);
}
private static string Normalize(string text)
{
if (string.IsNullOrWhiteSpace(text))
{
return text;
}
return _normalizeTextRegEx.Replace(text, string.Empty);
} }
return _normalizeTextRegEx.Replace(text, string.Empty);
} }
} }

View File

@ -5,31 +5,32 @@ using IdentityServer4.Services;
using IdentityServer4.Stores; using IdentityServer4.Stores;
using IdentityServer4.Validation; using IdentityServer4.Validation;
namespace Bit.Sso.Utilities; namespace Bit.Sso.Utilities
public class DiscoveryResponseGenerator : IdentityServer4.ResponseHandling.DiscoveryResponseGenerator
{ {
private readonly GlobalSettings _globalSettings; public class DiscoveryResponseGenerator : IdentityServer4.ResponseHandling.DiscoveryResponseGenerator
public DiscoveryResponseGenerator(
IdentityServerOptions options,
IResourceStore resourceStore,
IKeyMaterialService keys,
ExtensionGrantValidator extensionGrants,
ISecretsListParser secretParsers,
IResourceOwnerPasswordValidator resourceOwnerValidator,
ILogger<DiscoveryResponseGenerator> logger,
GlobalSettings globalSettings)
: base(options, resourceStore, keys, extensionGrants, secretParsers, resourceOwnerValidator, logger)
{ {
_globalSettings = globalSettings; private readonly GlobalSettings _globalSettings;
}
public override async Task<Dictionary<string, object>> CreateDiscoveryDocumentAsync( public DiscoveryResponseGenerator(
string baseUrl, string issuerUri) IdentityServerOptions options,
{ IResourceStore resourceStore,
var dict = await base.CreateDiscoveryDocumentAsync(baseUrl, issuerUri); IKeyMaterialService keys,
return CoreHelpers.AdjustIdentityServerConfig(dict, _globalSettings.BaseServiceUri.Sso, ExtensionGrantValidator extensionGrants,
_globalSettings.BaseServiceUri.InternalSso); ISecretsListParser secretParsers,
IResourceOwnerPasswordValidator resourceOwnerValidator,
ILogger<DiscoveryResponseGenerator> logger,
GlobalSettings globalSettings)
: base(options, resourceStore, keys, extensionGrants, secretParsers, resourceOwnerValidator, logger)
{
_globalSettings = globalSettings;
}
public override async Task<Dictionary<string, object>> CreateDiscoveryDocumentAsync(
string baseUrl, string issuerUri)
{
var dict = await base.CreateDiscoveryDocumentAsync(baseUrl, issuerUri);
return CoreHelpers.AdjustIdentityServerConfig(dict, _globalSettings.BaseServiceUri.Sso,
_globalSettings.BaseServiceUri.InternalSso);
}
} }
} }

View File

@ -3,87 +3,88 @@ using Microsoft.AspNetCore.Authentication;
using Microsoft.AspNetCore.Authentication.OpenIdConnect; using Microsoft.AspNetCore.Authentication.OpenIdConnect;
using Sustainsys.Saml2.AspNetCore2; using Sustainsys.Saml2.AspNetCore2;
namespace Bit.Sso.Utilities; namespace Bit.Sso.Utilities
public class DynamicAuthenticationScheme : AuthenticationScheme, IDynamicAuthenticationScheme
{ {
public DynamicAuthenticationScheme(string name, string displayName, Type handlerType, public class DynamicAuthenticationScheme : AuthenticationScheme, IDynamicAuthenticationScheme
AuthenticationSchemeOptions options)
: base(name, displayName, handlerType)
{ {
Options = options; public DynamicAuthenticationScheme(string name, string displayName, Type handlerType,
} AuthenticationSchemeOptions options)
public DynamicAuthenticationScheme(string name, string displayName, Type handlerType, : base(name, displayName, handlerType)
AuthenticationSchemeOptions options, SsoType ssoType) {
: this(name, displayName, handlerType, options) Options = options;
{ }
SsoType = ssoType; public DynamicAuthenticationScheme(string name, string displayName, Type handlerType,
} AuthenticationSchemeOptions options, SsoType ssoType)
: this(name, displayName, handlerType, options)
{
SsoType = ssoType;
}
public AuthenticationSchemeOptions Options { get; set; } public AuthenticationSchemeOptions Options { get; set; }
public SsoType SsoType { get; set; } public SsoType SsoType { get; set; }
public async Task Validate() public async Task Validate()
{
switch (SsoType)
{ {
case SsoType.OpenIdConnect: switch (SsoType)
await ValidateOpenIdConnectAsync();
break;
case SsoType.Saml2:
ValidateSaml();
break;
default:
break;
}
}
private void ValidateSaml()
{
if (SsoType != SsoType.Saml2)
{
return;
}
if (!(Options is Saml2Options samlOptions))
{
throw new Exception("InvalidAuthenticationOptionsForSaml2SchemeError");
}
samlOptions.Validate(Name);
}
private async Task ValidateOpenIdConnectAsync()
{
if (SsoType != SsoType.OpenIdConnect)
{
return;
}
if (!(Options is OpenIdConnectOptions oidcOptions))
{
throw new Exception("InvalidAuthenticationOptionsForOidcSchemeError");
}
oidcOptions.Validate();
if (oidcOptions.Configuration == null)
{
if (oidcOptions.ConfigurationManager == null)
{ {
throw new Exception("PostConfigurationNotExecutedError"); case SsoType.OpenIdConnect:
await ValidateOpenIdConnectAsync();
break;
case SsoType.Saml2:
ValidateSaml();
break;
default:
break;
}
}
private void ValidateSaml()
{
if (SsoType != SsoType.Saml2)
{
return;
}
if (!(Options is Saml2Options samlOptions))
{
throw new Exception("InvalidAuthenticationOptionsForSaml2SchemeError");
}
samlOptions.Validate(Name);
}
private async Task ValidateOpenIdConnectAsync()
{
if (SsoType != SsoType.OpenIdConnect)
{
return;
}
if (!(Options is OpenIdConnectOptions oidcOptions))
{
throw new Exception("InvalidAuthenticationOptionsForOidcSchemeError");
}
oidcOptions.Validate();
if (oidcOptions.Configuration == null)
{
if (oidcOptions.ConfigurationManager == null)
{
throw new Exception("PostConfigurationNotExecutedError");
}
if (oidcOptions.Configuration == null)
{
try
{
oidcOptions.Configuration = await oidcOptions.ConfigurationManager
.GetConfigurationAsync(CancellationToken.None);
}
catch (Exception ex)
{
throw new Exception("ReadingOpenIdConnectMetadataFailedError", ex);
}
}
} }
if (oidcOptions.Configuration == null) if (oidcOptions.Configuration == null)
{ {
try throw new Exception("NoOpenIdConnectMetadataError");
{
oidcOptions.Configuration = await oidcOptions.ConfigurationManager
.GetConfigurationAsync(CancellationToken.None);
}
catch (Exception ex)
{
throw new Exception("ReadingOpenIdConnectMetadataFailedError", ex);
}
} }
} }
if (oidcOptions.Configuration == null)
{
throw new Exception("NoOpenIdConnectMetadataError");
}
} }
} }

View File

@ -18,440 +18,441 @@ using Sustainsys.Saml2.AspNetCore2;
using Sustainsys.Saml2.Configuration; using Sustainsys.Saml2.Configuration;
using Sustainsys.Saml2.Saml2P; using Sustainsys.Saml2.Saml2P;
namespace Bit.Core.Business.Sso; namespace Bit.Core.Business.Sso
public class DynamicAuthenticationSchemeProvider : AuthenticationSchemeProvider
{ {
private readonly IPostConfigureOptions<OpenIdConnectOptions> _oidcPostConfigureOptions; public class DynamicAuthenticationSchemeProvider : AuthenticationSchemeProvider
private readonly IExtendedOptionsMonitorCache<OpenIdConnectOptions> _extendedOidcOptionsMonitorCache;
private readonly IPostConfigureOptions<Saml2Options> _saml2PostConfigureOptions;
private readonly IExtendedOptionsMonitorCache<Saml2Options> _extendedSaml2OptionsMonitorCache;
private readonly ISsoConfigRepository _ssoConfigRepository;
private readonly ILogger _logger;
private readonly GlobalSettings _globalSettings;
private readonly SamlEnvironment _samlEnvironment;
private readonly TimeSpan _schemeCacheLifetime;
private readonly Dictionary<string, DynamicAuthenticationScheme> _cachedSchemes;
private readonly Dictionary<string, DynamicAuthenticationScheme> _cachedHandlerSchemes;
private readonly SemaphoreSlim _semaphore;
private readonly IHttpContextAccessor _httpContextAccessor;
private DateTime? _lastSchemeLoad;
private IEnumerable<DynamicAuthenticationScheme> _schemesCopy = Array.Empty<DynamicAuthenticationScheme>();
private IEnumerable<DynamicAuthenticationScheme> _handlerSchemesCopy = Array.Empty<DynamicAuthenticationScheme>();
public DynamicAuthenticationSchemeProvider(
IOptions<AuthenticationOptions> options,
IPostConfigureOptions<OpenIdConnectOptions> oidcPostConfigureOptions,
IOptionsMonitorCache<OpenIdConnectOptions> oidcOptionsMonitorCache,
IPostConfigureOptions<Saml2Options> saml2PostConfigureOptions,
IOptionsMonitorCache<Saml2Options> saml2OptionsMonitorCache,
ISsoConfigRepository ssoConfigRepository,
ILogger<DynamicAuthenticationSchemeProvider> logger,
GlobalSettings globalSettings,
SamlEnvironment samlEnvironment,
IHttpContextAccessor httpContextAccessor)
: base(options)
{ {
_oidcPostConfigureOptions = oidcPostConfigureOptions; private readonly IPostConfigureOptions<OpenIdConnectOptions> _oidcPostConfigureOptions;
_extendedOidcOptionsMonitorCache = oidcOptionsMonitorCache as private readonly IExtendedOptionsMonitorCache<OpenIdConnectOptions> _extendedOidcOptionsMonitorCache;
IExtendedOptionsMonitorCache<OpenIdConnectOptions>; private readonly IPostConfigureOptions<Saml2Options> _saml2PostConfigureOptions;
if (_extendedOidcOptionsMonitorCache == null) private readonly IExtendedOptionsMonitorCache<Saml2Options> _extendedSaml2OptionsMonitorCache;
private readonly ISsoConfigRepository _ssoConfigRepository;
private readonly ILogger _logger;
private readonly GlobalSettings _globalSettings;
private readonly SamlEnvironment _samlEnvironment;
private readonly TimeSpan _schemeCacheLifetime;
private readonly Dictionary<string, DynamicAuthenticationScheme> _cachedSchemes;
private readonly Dictionary<string, DynamicAuthenticationScheme> _cachedHandlerSchemes;
private readonly SemaphoreSlim _semaphore;
private readonly IHttpContextAccessor _httpContextAccessor;
private DateTime? _lastSchemeLoad;
private IEnumerable<DynamicAuthenticationScheme> _schemesCopy = Array.Empty<DynamicAuthenticationScheme>();
private IEnumerable<DynamicAuthenticationScheme> _handlerSchemesCopy = Array.Empty<DynamicAuthenticationScheme>();
public DynamicAuthenticationSchemeProvider(
IOptions<AuthenticationOptions> options,
IPostConfigureOptions<OpenIdConnectOptions> oidcPostConfigureOptions,
IOptionsMonitorCache<OpenIdConnectOptions> oidcOptionsMonitorCache,
IPostConfigureOptions<Saml2Options> saml2PostConfigureOptions,
IOptionsMonitorCache<Saml2Options> saml2OptionsMonitorCache,
ISsoConfigRepository ssoConfigRepository,
ILogger<DynamicAuthenticationSchemeProvider> logger,
GlobalSettings globalSettings,
SamlEnvironment samlEnvironment,
IHttpContextAccessor httpContextAccessor)
: base(options)
{ {
throw new ArgumentNullException("_extendedOidcOptionsMonitorCache could not be resolved."); _oidcPostConfigureOptions = oidcPostConfigureOptions;
_extendedOidcOptionsMonitorCache = oidcOptionsMonitorCache as
IExtendedOptionsMonitorCache<OpenIdConnectOptions>;
if (_extendedOidcOptionsMonitorCache == null)
{
throw new ArgumentNullException("_extendedOidcOptionsMonitorCache could not be resolved.");
}
_saml2PostConfigureOptions = saml2PostConfigureOptions;
_extendedSaml2OptionsMonitorCache = saml2OptionsMonitorCache as
IExtendedOptionsMonitorCache<Saml2Options>;
if (_extendedSaml2OptionsMonitorCache == null)
{
throw new ArgumentNullException("_extendedSaml2OptionsMonitorCache could not be resolved.");
}
_ssoConfigRepository = ssoConfigRepository;
_logger = logger;
_globalSettings = globalSettings;
_schemeCacheLifetime = TimeSpan.FromSeconds(_globalSettings.Sso?.CacheLifetimeInSeconds ?? 30);
_samlEnvironment = samlEnvironment;
_cachedSchemes = new Dictionary<string, DynamicAuthenticationScheme>();
_cachedHandlerSchemes = new Dictionary<string, DynamicAuthenticationScheme>();
_semaphore = new SemaphoreSlim(1);
_httpContextAccessor = httpContextAccessor ?? throw new ArgumentNullException(nameof(httpContextAccessor));
} }
_saml2PostConfigureOptions = saml2PostConfigureOptions; private bool CacheIsValid
_extendedSaml2OptionsMonitorCache = saml2OptionsMonitorCache as
IExtendedOptionsMonitorCache<Saml2Options>;
if (_extendedSaml2OptionsMonitorCache == null)
{ {
throw new ArgumentNullException("_extendedSaml2OptionsMonitorCache could not be resolved."); get => _lastSchemeLoad.HasValue
&& _lastSchemeLoad.Value.Add(_schemeCacheLifetime) >= DateTime.UtcNow;
} }
_ssoConfigRepository = ssoConfigRepository; public override async Task<AuthenticationScheme> GetSchemeAsync(string name)
_logger = logger;
_globalSettings = globalSettings;
_schemeCacheLifetime = TimeSpan.FromSeconds(_globalSettings.Sso?.CacheLifetimeInSeconds ?? 30);
_samlEnvironment = samlEnvironment;
_cachedSchemes = new Dictionary<string, DynamicAuthenticationScheme>();
_cachedHandlerSchemes = new Dictionary<string, DynamicAuthenticationScheme>();
_semaphore = new SemaphoreSlim(1);
_httpContextAccessor = httpContextAccessor ?? throw new ArgumentNullException(nameof(httpContextAccessor));
}
private bool CacheIsValid
{
get => _lastSchemeLoad.HasValue
&& _lastSchemeLoad.Value.Add(_schemeCacheLifetime) >= DateTime.UtcNow;
}
public override async Task<AuthenticationScheme> GetSchemeAsync(string name)
{
var scheme = await base.GetSchemeAsync(name);
if (scheme != null)
{ {
return scheme; var scheme = await base.GetSchemeAsync(name);
if (scheme != null)
{
return scheme;
}
try
{
var dynamicScheme = await GetDynamicSchemeAsync(name);
return dynamicScheme;
}
catch (Exception ex)
{
_logger.LogError(ex, "Unable to load a dynamic authentication scheme for '{0}'", name);
}
return null;
} }
try public override async Task<IEnumerable<AuthenticationScheme>> GetAllSchemesAsync()
{ {
var dynamicScheme = await GetDynamicSchemeAsync(name); var existingSchemes = await base.GetAllSchemesAsync();
return dynamicScheme; var schemes = new List<AuthenticationScheme>();
} schemes.AddRange(existingSchemes);
catch (Exception ex)
{ await LoadAllDynamicSchemesIntoCacheAsync();
_logger.LogError(ex, "Unable to load a dynamic authentication scheme for '{0}'", name); schemes.AddRange(_schemesCopy);
return schemes.ToArray();
} }
return null; public override async Task<IEnumerable<AuthenticationScheme>> GetRequestHandlerSchemesAsync()
}
public override async Task<IEnumerable<AuthenticationScheme>> GetAllSchemesAsync()
{
var existingSchemes = await base.GetAllSchemesAsync();
var schemes = new List<AuthenticationScheme>();
schemes.AddRange(existingSchemes);
await LoadAllDynamicSchemesIntoCacheAsync();
schemes.AddRange(_schemesCopy);
return schemes.ToArray();
}
public override async Task<IEnumerable<AuthenticationScheme>> GetRequestHandlerSchemesAsync()
{
var existingSchemes = await base.GetRequestHandlerSchemesAsync();
var schemes = new List<AuthenticationScheme>();
schemes.AddRange(existingSchemes);
await LoadAllDynamicSchemesIntoCacheAsync();
schemes.AddRange(_handlerSchemesCopy);
return schemes.ToArray();
}
private async Task LoadAllDynamicSchemesIntoCacheAsync()
{
if (CacheIsValid)
{ {
// Our cache hasn't expired or been invalidated, ignore request var existingSchemes = await base.GetRequestHandlerSchemesAsync();
return; var schemes = new List<AuthenticationScheme>();
schemes.AddRange(existingSchemes);
await LoadAllDynamicSchemesIntoCacheAsync();
schemes.AddRange(_handlerSchemesCopy);
return schemes.ToArray();
} }
await _semaphore.WaitAsync();
try private async Task LoadAllDynamicSchemesIntoCacheAsync()
{ {
if (CacheIsValid) if (CacheIsValid)
{ {
// Just in case (double-checked locking pattern) // Our cache hasn't expired or been invalidated, ignore request
return; return;
} }
await _semaphore.WaitAsync();
// Save time just in case the following operation takes longer try
var now = DateTime.UtcNow;
var newSchemes = await _ssoConfigRepository.GetManyByRevisionNotBeforeDate(_lastSchemeLoad);
foreach (var config in newSchemes)
{ {
DynamicAuthenticationScheme scheme; if (CacheIsValid)
try
{ {
scheme = GetSchemeFromSsoConfig(config); // Just in case (double-checked locking pattern)
return;
} }
catch (Exception ex)
// Save time just in case the following operation takes longer
var now = DateTime.UtcNow;
var newSchemes = await _ssoConfigRepository.GetManyByRevisionNotBeforeDate(_lastSchemeLoad);
foreach (var config in newSchemes)
{ {
_logger.LogError(ex, "Error converting configuration to scheme for '{0}'", config.Id); DynamicAuthenticationScheme scheme;
continue; try
{
scheme = GetSchemeFromSsoConfig(config);
}
catch (Exception ex)
{
_logger.LogError(ex, "Error converting configuration to scheme for '{0}'", config.Id);
continue;
}
if (scheme == null)
{
continue;
}
SetSchemeInCache(scheme);
} }
if (scheme == null)
if (newSchemes.Any())
{ {
continue; // Maintain "safe" copy for use in enumeration routines
_schemesCopy = _cachedSchemes.Values.ToArray();
_handlerSchemesCopy = _cachedHandlerSchemes.Values.ToArray();
} }
SetSchemeInCache(scheme); _lastSchemeLoad = now;
}
finally
{
_semaphore.Release();
}
}
private DynamicAuthenticationScheme SetSchemeInCache(DynamicAuthenticationScheme scheme)
{
if (!PostConfigureDynamicScheme(scheme))
{
return null;
}
_cachedSchemes[scheme.Name] = scheme;
if (typeof(IAuthenticationRequestHandler).IsAssignableFrom(scheme.HandlerType))
{
_cachedHandlerSchemes[scheme.Name] = scheme;
}
return scheme;
}
private async Task<DynamicAuthenticationScheme> GetDynamicSchemeAsync(string name)
{
if (_cachedSchemes.TryGetValue(name, out var cachedScheme))
{
return cachedScheme;
} }
if (newSchemes.Any()) var scheme = await GetSchemeFromSsoConfigAsync(name);
{
// Maintain "safe" copy for use in enumeration routines
_schemesCopy = _cachedSchemes.Values.ToArray();
_handlerSchemesCopy = _cachedHandlerSchemes.Values.ToArray();
}
_lastSchemeLoad = now;
}
finally
{
_semaphore.Release();
}
}
private DynamicAuthenticationScheme SetSchemeInCache(DynamicAuthenticationScheme scheme)
{
if (!PostConfigureDynamicScheme(scheme))
{
return null;
}
_cachedSchemes[scheme.Name] = scheme;
if (typeof(IAuthenticationRequestHandler).IsAssignableFrom(scheme.HandlerType))
{
_cachedHandlerSchemes[scheme.Name] = scheme;
}
return scheme;
}
private async Task<DynamicAuthenticationScheme> GetDynamicSchemeAsync(string name)
{
if (_cachedSchemes.TryGetValue(name, out var cachedScheme))
{
return cachedScheme;
}
var scheme = await GetSchemeFromSsoConfigAsync(name);
if (scheme == null)
{
return null;
}
await _semaphore.WaitAsync();
try
{
scheme = SetSchemeInCache(scheme);
if (scheme == null) if (scheme == null)
{ {
return null; return null;
} }
if (typeof(IAuthenticationRequestHandler).IsAssignableFrom(scheme.HandlerType)) await _semaphore.WaitAsync();
try
{ {
_handlerSchemesCopy = _cachedHandlerSchemes.Values.ToArray(); scheme = SetSchemeInCache(scheme);
if (scheme == null)
{
return null;
}
if (typeof(IAuthenticationRequestHandler).IsAssignableFrom(scheme.HandlerType))
{
_handlerSchemesCopy = _cachedHandlerSchemes.Values.ToArray();
}
_schemesCopy = _cachedSchemes.Values.ToArray();
} }
_schemesCopy = _cachedSchemes.Values.ToArray(); finally
}
finally
{
// Note: _lastSchemeLoad is not set here, this is a one-off
// and should not impact loading further cache updates
_semaphore.Release();
}
return scheme;
}
private bool PostConfigureDynamicScheme(DynamicAuthenticationScheme scheme)
{
try
{
if (scheme.SsoType == SsoType.OpenIdConnect && scheme.Options is OpenIdConnectOptions oidcOptions)
{ {
_oidcPostConfigureOptions.PostConfigure(scheme.Name, oidcOptions); // Note: _lastSchemeLoad is not set here, this is a one-off
_extendedOidcOptionsMonitorCache.AddOrUpdate(scheme.Name, oidcOptions); // and should not impact loading further cache updates
_semaphore.Release();
} }
else if (scheme.SsoType == SsoType.Saml2 && scheme.Options is Saml2Options saml2Options) return scheme;
}
private bool PostConfigureDynamicScheme(DynamicAuthenticationScheme scheme)
{
try
{ {
_saml2PostConfigureOptions.PostConfigure(scheme.Name, saml2Options); if (scheme.SsoType == SsoType.OpenIdConnect && scheme.Options is OpenIdConnectOptions oidcOptions)
_extendedSaml2OptionsMonitorCache.AddOrUpdate(scheme.Name, saml2Options); {
_oidcPostConfigureOptions.PostConfigure(scheme.Name, oidcOptions);
_extendedOidcOptionsMonitorCache.AddOrUpdate(scheme.Name, oidcOptions);
}
else if (scheme.SsoType == SsoType.Saml2 && scheme.Options is Saml2Options saml2Options)
{
_saml2PostConfigureOptions.PostConfigure(scheme.Name, saml2Options);
_extendedSaml2OptionsMonitorCache.AddOrUpdate(scheme.Name, saml2Options);
}
return true;
} }
return true; catch (Exception ex)
}
catch (Exception ex)
{
_logger.LogError(ex, "Error performing post configuration for '{0}' ({1})",
scheme.Name, scheme.DisplayName);
}
return false;
}
private DynamicAuthenticationScheme GetSchemeFromSsoConfig(SsoConfig config)
{
var data = config.GetData();
return data.ConfigType switch
{
SsoType.OpenIdConnect => GetOidcAuthenticationScheme(config.OrganizationId.ToString(), data),
SsoType.Saml2 => GetSaml2AuthenticationScheme(config.OrganizationId.ToString(), data),
_ => throw new Exception($"SSO Config Type, '{data.ConfigType}', not supported"),
};
}
private async Task<DynamicAuthenticationScheme> GetSchemeFromSsoConfigAsync(string name)
{
if (!Guid.TryParse(name, out var organizationId))
{
_logger.LogWarning("Could not determine organization id from name, '{0}'", name);
return null;
}
var ssoConfig = await _ssoConfigRepository.GetByOrganizationIdAsync(organizationId);
if (ssoConfig == null || !ssoConfig.Enabled)
{
_logger.LogWarning("Could not find SSO config or config was not enabled for '{0}'", name);
return null;
}
return GetSchemeFromSsoConfig(ssoConfig);
}
private DynamicAuthenticationScheme GetOidcAuthenticationScheme(string name, SsoConfigurationData config)
{
var oidcOptions = new OpenIdConnectOptions
{
Authority = config.Authority,
ClientId = config.ClientId,
ClientSecret = config.ClientSecret,
ResponseType = "code",
ResponseMode = "form_post",
SignInScheme = AuthenticationSchemes.BitwardenExternalCookieAuthenticationScheme,
SignOutScheme = IdentityServerConstants.SignoutScheme,
SaveTokens = false, // reduce overall request size
TokenValidationParameters = new TokenValidationParameters
{ {
NameClaimType = JwtClaimTypes.Name, _logger.LogError(ex, "Error performing post configuration for '{0}' ({1})",
RoleClaimType = JwtClaimTypes.Role, scheme.Name, scheme.DisplayName);
}, }
CallbackPath = SsoConfigurationData.BuildCallbackPath(), return false;
SignedOutCallbackPath = SsoConfigurationData.BuildSignedOutCallbackPath(),
MetadataAddress = config.MetadataAddress,
// Prevents URLs that go beyond 1024 characters which may break for some servers
AuthenticationMethod = config.RedirectBehavior,
GetClaimsFromUserInfoEndpoint = config.GetClaimsFromUserInfoEndpoint,
};
oidcOptions.Scope
.AddIfNotExists(OpenIdConnectScopes.OpenId)
.AddIfNotExists(OpenIdConnectScopes.Email)
.AddIfNotExists(OpenIdConnectScopes.Profile);
foreach (var scope in config.GetAdditionalScopes())
{
oidcOptions.Scope.AddIfNotExists(scope);
}
if (!string.IsNullOrWhiteSpace(config.ExpectedReturnAcrValue))
{
oidcOptions.Scope.AddIfNotExists(OpenIdConnectScopes.Acr);
} }
oidcOptions.StateDataFormat = new DistributedCacheStateDataFormatter(_httpContextAccessor, name); private DynamicAuthenticationScheme GetSchemeFromSsoConfig(SsoConfig config)
// see: https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest (acr_values)
if (!string.IsNullOrWhiteSpace(config.AcrValues))
{ {
oidcOptions.Events ??= new OpenIdConnectEvents(); var data = config.GetData();
oidcOptions.Events.OnRedirectToIdentityProvider = ctx => return data.ConfigType switch
{ {
ctx.ProtocolMessage.AcrValues = config.AcrValues; SsoType.OpenIdConnect => GetOidcAuthenticationScheme(config.OrganizationId.ToString(), data),
return Task.CompletedTask; SsoType.Saml2 => GetSaml2AuthenticationScheme(config.OrganizationId.ToString(), data),
_ => throw new Exception($"SSO Config Type, '{data.ConfigType}', not supported"),
}; };
} }
return new DynamicAuthenticationScheme(name, name, typeof(OpenIdConnectHandler), private async Task<DynamicAuthenticationScheme> GetSchemeFromSsoConfigAsync(string name)
oidcOptions, SsoType.OpenIdConnect);
}
private DynamicAuthenticationScheme GetSaml2AuthenticationScheme(string name, SsoConfigurationData config)
{
if (_samlEnvironment == null)
{ {
throw new Exception($"SSO SAML2 Service Provider profile is missing for {name}"); if (!Guid.TryParse(name, out var organizationId))
{
_logger.LogWarning("Could not determine organization id from name, '{0}'", name);
return null;
}
var ssoConfig = await _ssoConfigRepository.GetByOrganizationIdAsync(organizationId);
if (ssoConfig == null || !ssoConfig.Enabled)
{
_logger.LogWarning("Could not find SSO config or config was not enabled for '{0}'", name);
return null;
}
return GetSchemeFromSsoConfig(ssoConfig);
} }
var spEntityId = new Sustainsys.Saml2.Metadata.EntityId( private DynamicAuthenticationScheme GetOidcAuthenticationScheme(string name, SsoConfigurationData config)
SsoConfigurationData.BuildSaml2ModulePath(_globalSettings.BaseServiceUri.Sso));
bool? allowCreate = null;
if (config.SpNameIdFormat != Saml2NameIdFormat.Transient)
{ {
allowCreate = true; var oidcOptions = new OpenIdConnectOptions
} {
var spOptions = new SPOptions Authority = config.Authority,
{ ClientId = config.ClientId,
EntityId = spEntityId, ClientSecret = config.ClientSecret,
ModulePath = SsoConfigurationData.BuildSaml2ModulePath(null, name), ResponseType = "code",
NameIdPolicy = new Saml2NameIdPolicy(allowCreate, GetNameIdFormat(config.SpNameIdFormat)), ResponseMode = "form_post",
WantAssertionsSigned = config.SpWantAssertionsSigned, SignInScheme = AuthenticationSchemes.BitwardenExternalCookieAuthenticationScheme,
AuthenticateRequestSigningBehavior = GetSigningBehavior(config.SpSigningBehavior), SignOutScheme = IdentityServerConstants.SignoutScheme,
ValidateCertificates = config.SpValidateCertificates, SaveTokens = false, // reduce overall request size
}; TokenValidationParameters = new TokenValidationParameters
if (!string.IsNullOrWhiteSpace(config.SpMinIncomingSigningAlgorithm)) {
{ NameClaimType = JwtClaimTypes.Name,
spOptions.MinIncomingSigningAlgorithm = config.SpMinIncomingSigningAlgorithm; RoleClaimType = JwtClaimTypes.Role,
} },
if (!string.IsNullOrWhiteSpace(config.SpOutboundSigningAlgorithm)) CallbackPath = SsoConfigurationData.BuildCallbackPath(),
{ SignedOutCallbackPath = SsoConfigurationData.BuildSignedOutCallbackPath(),
spOptions.OutboundSigningAlgorithm = config.SpOutboundSigningAlgorithm; MetadataAddress = config.MetadataAddress,
} // Prevents URLs that go beyond 1024 characters which may break for some servers
if (_samlEnvironment.SpSigningCertificate != null) AuthenticationMethod = config.RedirectBehavior,
{ GetClaimsFromUserInfoEndpoint = config.GetClaimsFromUserInfoEndpoint,
spOptions.ServiceCertificates.Add(_samlEnvironment.SpSigningCertificate); };
oidcOptions.Scope
.AddIfNotExists(OpenIdConnectScopes.OpenId)
.AddIfNotExists(OpenIdConnectScopes.Email)
.AddIfNotExists(OpenIdConnectScopes.Profile);
foreach (var scope in config.GetAdditionalScopes())
{
oidcOptions.Scope.AddIfNotExists(scope);
}
if (!string.IsNullOrWhiteSpace(config.ExpectedReturnAcrValue))
{
oidcOptions.Scope.AddIfNotExists(OpenIdConnectScopes.Acr);
}
oidcOptions.StateDataFormat = new DistributedCacheStateDataFormatter(_httpContextAccessor, name);
// see: https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest (acr_values)
if (!string.IsNullOrWhiteSpace(config.AcrValues))
{
oidcOptions.Events ??= new OpenIdConnectEvents();
oidcOptions.Events.OnRedirectToIdentityProvider = ctx =>
{
ctx.ProtocolMessage.AcrValues = config.AcrValues;
return Task.CompletedTask;
};
}
return new DynamicAuthenticationScheme(name, name, typeof(OpenIdConnectHandler),
oidcOptions, SsoType.OpenIdConnect);
} }
var idpEntityId = new Sustainsys.Saml2.Metadata.EntityId(config.IdpEntityId); private DynamicAuthenticationScheme GetSaml2AuthenticationScheme(string name, SsoConfigurationData config)
var idp = new Sustainsys.Saml2.IdentityProvider(idpEntityId, spOptions)
{ {
Binding = GetBindingType(config.IdpBindingType), if (_samlEnvironment == null)
AllowUnsolicitedAuthnResponse = config.IdpAllowUnsolicitedAuthnResponse, {
DisableOutboundLogoutRequests = config.IdpDisableOutboundLogoutRequests, throw new Exception($"SSO SAML2 Service Provider profile is missing for {name}");
WantAuthnRequestsSigned = config.IdpWantAuthnRequestsSigned, }
};
if (!string.IsNullOrWhiteSpace(config.IdpSingleSignOnServiceUrl)) var spEntityId = new Sustainsys.Saml2.Metadata.EntityId(
{ SsoConfigurationData.BuildSaml2ModulePath(_globalSettings.BaseServiceUri.Sso));
idp.SingleSignOnServiceUrl = new Uri(config.IdpSingleSignOnServiceUrl); bool? allowCreate = null;
if (config.SpNameIdFormat != Saml2NameIdFormat.Transient)
{
allowCreate = true;
}
var spOptions = new SPOptions
{
EntityId = spEntityId,
ModulePath = SsoConfigurationData.BuildSaml2ModulePath(null, name),
NameIdPolicy = new Saml2NameIdPolicy(allowCreate, GetNameIdFormat(config.SpNameIdFormat)),
WantAssertionsSigned = config.SpWantAssertionsSigned,
AuthenticateRequestSigningBehavior = GetSigningBehavior(config.SpSigningBehavior),
ValidateCertificates = config.SpValidateCertificates,
};
if (!string.IsNullOrWhiteSpace(config.SpMinIncomingSigningAlgorithm))
{
spOptions.MinIncomingSigningAlgorithm = config.SpMinIncomingSigningAlgorithm;
}
if (!string.IsNullOrWhiteSpace(config.SpOutboundSigningAlgorithm))
{
spOptions.OutboundSigningAlgorithm = config.SpOutboundSigningAlgorithm;
}
if (_samlEnvironment.SpSigningCertificate != null)
{
spOptions.ServiceCertificates.Add(_samlEnvironment.SpSigningCertificate);
}
var idpEntityId = new Sustainsys.Saml2.Metadata.EntityId(config.IdpEntityId);
var idp = new Sustainsys.Saml2.IdentityProvider(idpEntityId, spOptions)
{
Binding = GetBindingType(config.IdpBindingType),
AllowUnsolicitedAuthnResponse = config.IdpAllowUnsolicitedAuthnResponse,
DisableOutboundLogoutRequests = config.IdpDisableOutboundLogoutRequests,
WantAuthnRequestsSigned = config.IdpWantAuthnRequestsSigned,
};
if (!string.IsNullOrWhiteSpace(config.IdpSingleSignOnServiceUrl))
{
idp.SingleSignOnServiceUrl = new Uri(config.IdpSingleSignOnServiceUrl);
}
if (!string.IsNullOrWhiteSpace(config.IdpSingleLogoutServiceUrl))
{
idp.SingleLogoutServiceUrl = new Uri(config.IdpSingleLogoutServiceUrl);
}
if (!string.IsNullOrWhiteSpace(config.IdpOutboundSigningAlgorithm))
{
idp.OutboundSigningAlgorithm = config.IdpOutboundSigningAlgorithm;
}
if (!string.IsNullOrWhiteSpace(config.IdpX509PublicCert))
{
var cert = CoreHelpers.Base64UrlDecode(config.IdpX509PublicCert);
idp.SigningKeys.AddConfiguredKey(new X509Certificate2(cert));
}
idp.ArtifactResolutionServiceUrls.Clear();
// This must happen last since it calls Validate() internally.
idp.LoadMetadata = false;
var options = new Saml2Options
{
SPOptions = spOptions,
SignInScheme = AuthenticationSchemes.BitwardenExternalCookieAuthenticationScheme,
SignOutScheme = IdentityServerConstants.DefaultCookieAuthenticationScheme,
CookieManager = new IdentityServer.DistributedCacheCookieManager(),
};
options.IdentityProviders.Add(idp);
return new DynamicAuthenticationScheme(name, name, typeof(Saml2Handler), options, SsoType.Saml2);
} }
if (!string.IsNullOrWhiteSpace(config.IdpSingleLogoutServiceUrl))
private NameIdFormat GetNameIdFormat(Saml2NameIdFormat format)
{ {
idp.SingleLogoutServiceUrl = new Uri(config.IdpSingleLogoutServiceUrl); return format switch
{
Saml2NameIdFormat.Unspecified => NameIdFormat.Unspecified,
Saml2NameIdFormat.EmailAddress => NameIdFormat.EmailAddress,
Saml2NameIdFormat.X509SubjectName => NameIdFormat.X509SubjectName,
Saml2NameIdFormat.WindowsDomainQualifiedName => NameIdFormat.WindowsDomainQualifiedName,
Saml2NameIdFormat.KerberosPrincipalName => NameIdFormat.KerberosPrincipalName,
Saml2NameIdFormat.EntityIdentifier => NameIdFormat.EntityIdentifier,
Saml2NameIdFormat.Persistent => NameIdFormat.Persistent,
Saml2NameIdFormat.Transient => NameIdFormat.Transient,
_ => NameIdFormat.NotConfigured,
};
} }
if (!string.IsNullOrWhiteSpace(config.IdpOutboundSigningAlgorithm))
private SigningBehavior GetSigningBehavior(Saml2SigningBehavior behavior)
{ {
idp.OutboundSigningAlgorithm = config.IdpOutboundSigningAlgorithm; return behavior switch
{
Saml2SigningBehavior.IfIdpWantAuthnRequestsSigned => SigningBehavior.IfIdpWantAuthnRequestsSigned,
Saml2SigningBehavior.Always => SigningBehavior.Always,
Saml2SigningBehavior.Never => SigningBehavior.Never,
_ => SigningBehavior.IfIdpWantAuthnRequestsSigned,
};
} }
if (!string.IsNullOrWhiteSpace(config.IdpX509PublicCert))
private Sustainsys.Saml2.WebSso.Saml2BindingType GetBindingType(Saml2BindingType bindingType)
{ {
var cert = CoreHelpers.Base64UrlDecode(config.IdpX509PublicCert); return bindingType switch
idp.SigningKeys.AddConfiguredKey(new X509Certificate2(cert)); {
Saml2BindingType.HttpRedirect => Sustainsys.Saml2.WebSso.Saml2BindingType.HttpRedirect,
Saml2BindingType.HttpPost => Sustainsys.Saml2.WebSso.Saml2BindingType.HttpPost,
_ => Sustainsys.Saml2.WebSso.Saml2BindingType.HttpPost,
};
} }
idp.ArtifactResolutionServiceUrls.Clear();
// This must happen last since it calls Validate() internally.
idp.LoadMetadata = false;
var options = new Saml2Options
{
SPOptions = spOptions,
SignInScheme = AuthenticationSchemes.BitwardenExternalCookieAuthenticationScheme,
SignOutScheme = IdentityServerConstants.DefaultCookieAuthenticationScheme,
CookieManager = new IdentityServer.DistributedCacheCookieManager(),
};
options.IdentityProviders.Add(idp);
return new DynamicAuthenticationScheme(name, name, typeof(Saml2Handler), options, SsoType.Saml2);
}
private NameIdFormat GetNameIdFormat(Saml2NameIdFormat format)
{
return format switch
{
Saml2NameIdFormat.Unspecified => NameIdFormat.Unspecified,
Saml2NameIdFormat.EmailAddress => NameIdFormat.EmailAddress,
Saml2NameIdFormat.X509SubjectName => NameIdFormat.X509SubjectName,
Saml2NameIdFormat.WindowsDomainQualifiedName => NameIdFormat.WindowsDomainQualifiedName,
Saml2NameIdFormat.KerberosPrincipalName => NameIdFormat.KerberosPrincipalName,
Saml2NameIdFormat.EntityIdentifier => NameIdFormat.EntityIdentifier,
Saml2NameIdFormat.Persistent => NameIdFormat.Persistent,
Saml2NameIdFormat.Transient => NameIdFormat.Transient,
_ => NameIdFormat.NotConfigured,
};
}
private SigningBehavior GetSigningBehavior(Saml2SigningBehavior behavior)
{
return behavior switch
{
Saml2SigningBehavior.IfIdpWantAuthnRequestsSigned => SigningBehavior.IfIdpWantAuthnRequestsSigned,
Saml2SigningBehavior.Always => SigningBehavior.Always,
Saml2SigningBehavior.Never => SigningBehavior.Never,
_ => SigningBehavior.IfIdpWantAuthnRequestsSigned,
};
}
private Sustainsys.Saml2.WebSso.Saml2BindingType GetBindingType(Saml2BindingType bindingType)
{
return bindingType switch
{
Saml2BindingType.HttpRedirect => Sustainsys.Saml2.WebSso.Saml2BindingType.HttpRedirect,
Saml2BindingType.HttpPost => Sustainsys.Saml2.WebSso.Saml2BindingType.HttpPost,
_ => Sustainsys.Saml2.WebSso.Saml2BindingType.HttpPost,
};
} }
} }

View File

@ -1,36 +1,37 @@
using System.Collections.Concurrent; using System.Collections.Concurrent;
using Microsoft.Extensions.Options; using Microsoft.Extensions.Options;
namespace Bit.Sso.Utilities; namespace Bit.Sso.Utilities
public class ExtendedOptionsMonitorCache<TOptions> : IExtendedOptionsMonitorCache<TOptions> where TOptions : class
{ {
private readonly ConcurrentDictionary<string, Lazy<TOptions>> _cache = public class ExtendedOptionsMonitorCache<TOptions> : IExtendedOptionsMonitorCache<TOptions> where TOptions : class
new ConcurrentDictionary<string, Lazy<TOptions>>(StringComparer.Ordinal);
public void AddOrUpdate(string name, TOptions options)
{ {
_cache.AddOrUpdate(name ?? Options.DefaultName, new Lazy<TOptions>(() => options), private readonly ConcurrentDictionary<string, Lazy<TOptions>> _cache =
(string s, Lazy<TOptions> lazy) => new Lazy<TOptions>(() => options)); new ConcurrentDictionary<string, Lazy<TOptions>>(StringComparer.Ordinal);
}
public void Clear() public void AddOrUpdate(string name, TOptions options)
{ {
_cache.Clear(); _cache.AddOrUpdate(name ?? Options.DefaultName, new Lazy<TOptions>(() => options),
} (string s, Lazy<TOptions> lazy) => new Lazy<TOptions>(() => options));
}
public TOptions GetOrAdd(string name, Func<TOptions> createOptions) public void Clear()
{ {
return _cache.GetOrAdd(name ?? Options.DefaultName, new Lazy<TOptions>(createOptions)).Value; _cache.Clear();
} }
public bool TryAdd(string name, TOptions options) public TOptions GetOrAdd(string name, Func<TOptions> createOptions)
{ {
return _cache.TryAdd(name ?? Options.DefaultName, new Lazy<TOptions>(() => options)); return _cache.GetOrAdd(name ?? Options.DefaultName, new Lazy<TOptions>(createOptions)).Value;
} }
public bool TryRemove(string name) public bool TryAdd(string name, TOptions options)
{ {
return _cache.TryRemove(name ?? Options.DefaultName, out _); return _cache.TryAdd(name ?? Options.DefaultName, new Lazy<TOptions>(() => options));
}
public bool TryRemove(string name)
{
return _cache.TryRemove(name ?? Options.DefaultName, out _);
}
} }
} }

View File

@ -1,12 +1,13 @@
using Bit.Core.Enums; using Bit.Core.Enums;
using Microsoft.AspNetCore.Authentication; using Microsoft.AspNetCore.Authentication;
namespace Bit.Sso.Utilities; namespace Bit.Sso.Utilities
public interface IDynamicAuthenticationScheme
{ {
AuthenticationSchemeOptions Options { get; set; } public interface IDynamicAuthenticationScheme
SsoType SsoType { get; set; } {
AuthenticationSchemeOptions Options { get; set; }
SsoType SsoType { get; set; }
Task Validate(); Task Validate();
}
} }

View File

@ -1,8 +1,9 @@
using Microsoft.Extensions.Options; using Microsoft.Extensions.Options;
namespace Bit.Sso.Utilities; namespace Bit.Sso.Utilities
public interface IExtendedOptionsMonitorCache<TOptions> : IOptionsMonitorCache<TOptions> where TOptions : class
{ {
void AddOrUpdate(string name, TOptions options); public interface IExtendedOptionsMonitorCache<TOptions> : IOptionsMonitorCache<TOptions> where TOptions : class
{
void AddOrUpdate(string name, TOptions options);
}
} }

View File

@ -1,62 +1,63 @@
using Microsoft.AspNetCore.Authentication.OpenIdConnect; using Microsoft.AspNetCore.Authentication.OpenIdConnect;
using Microsoft.IdentityModel.Protocols.OpenIdConnect; using Microsoft.IdentityModel.Protocols.OpenIdConnect;
namespace Bit.Sso.Utilities; namespace Bit.Sso.Utilities
public static class OpenIdConnectOptionsExtensions
{ {
public static async Task<bool> CouldHandleAsync(this OpenIdConnectOptions options, string scheme, HttpContext context) public static class OpenIdConnectOptionsExtensions
{ {
// Determine this is a valid request for our handler public static async Task<bool> CouldHandleAsync(this OpenIdConnectOptions options, string scheme, HttpContext context)
if (options.CallbackPath != context.Request.Path &&
options.RemoteSignOutPath != context.Request.Path &&
options.SignedOutCallbackPath != context.Request.Path)
{ {
return false; // Determine this is a valid request for our handler
} if (options.CallbackPath != context.Request.Path &&
options.RemoteSignOutPath != context.Request.Path &&
if (context.Request.Query["scheme"].FirstOrDefault() == scheme) options.SignedOutCallbackPath != context.Request.Path)
{
return true;
}
try
{
// Parse out the message
OpenIdConnectMessage message = null;
if (string.Equals(context.Request.Method, "GET", StringComparison.OrdinalIgnoreCase))
{ {
message = new OpenIdConnectMessage(context.Request.Query.Select(pair => new KeyValuePair<string, string[]>(pair.Key, pair.Value)));
}
else if (string.Equals(context.Request.Method, "POST", StringComparison.OrdinalIgnoreCase) &&
!string.IsNullOrEmpty(context.Request.ContentType) &&
context.Request.ContentType.StartsWith("application/x-www-form-urlencoded", StringComparison.OrdinalIgnoreCase) &&
context.Request.Body.CanRead)
{
var form = await context.Request.ReadFormAsync();
message = new OpenIdConnectMessage(form.Select(pair => new KeyValuePair<string, string[]>(pair.Key, pair.Value)));
}
var state = message?.State;
if (string.IsNullOrWhiteSpace(state))
{
// State is required, it will fail later on for this reason.
return false; return false;
} }
// Handle State if we've gotten that back if (context.Request.Query["scheme"].FirstOrDefault() == scheme)
var decodedState = options.StateDataFormat.Unprotect(state);
if (decodedState != null && decodedState.Items.ContainsKey("scheme"))
{ {
return decodedState.Items["scheme"] == scheme; return true;
} }
}
catch try
{ {
// Parse out the message
OpenIdConnectMessage message = null;
if (string.Equals(context.Request.Method, "GET", StringComparison.OrdinalIgnoreCase))
{
message = new OpenIdConnectMessage(context.Request.Query.Select(pair => new KeyValuePair<string, string[]>(pair.Key, pair.Value)));
}
else if (string.Equals(context.Request.Method, "POST", StringComparison.OrdinalIgnoreCase) &&
!string.IsNullOrEmpty(context.Request.ContentType) &&
context.Request.ContentType.StartsWith("application/x-www-form-urlencoded", StringComparison.OrdinalIgnoreCase) &&
context.Request.Body.CanRead)
{
var form = await context.Request.ReadFormAsync();
message = new OpenIdConnectMessage(form.Select(pair => new KeyValuePair<string, string[]>(pair.Key, pair.Value)));
}
var state = message?.State;
if (string.IsNullOrWhiteSpace(state))
{
// State is required, it will fail later on for this reason.
return false;
}
// Handle State if we've gotten that back
var decodedState = options.StateDataFormat.Unprotect(state);
if (decodedState != null && decodedState.Items.ContainsKey("scheme"))
{
return decodedState.Items["scheme"] == scheme;
}
}
catch
{
return false;
}
// This is likely not an appropriate handler
return false; return false;
} }
// This is likely not an appropriate handler
return false;
} }
} }

View File

@ -1,63 +1,64 @@
namespace Bit.Sso.Utilities; namespace Bit.Sso.Utilities
/// <summary>
/// OpenID Connect Clients use scope values as defined in 3.3 of OAuth 2.0
/// [RFC6749]. These values represent the standard scope values supported
/// by OAuth 2.0 and therefore OIDC.
/// </summary>
/// <remarks>
/// See: https://openid.net/specs/openid-connect-basic-1_0.html#Scopes
/// </remarks>
public static class OpenIdConnectScopes
{ {
/// <summary> /// <summary>
/// REQUIRED. Informs the Authorization Server that the Client is making /// OpenID Connect Clients use scope values as defined in 3.3 of OAuth 2.0
/// an OpenID Connect request. If the openid scope value is not present, /// [RFC6749]. These values represent the standard scope values supported
/// the behavior is entirely unspecified. /// by OAuth 2.0 and therefore OIDC.
/// </summary>
public const string OpenId = "openid";
/// <summary>
/// OPTIONAL. This scope value requests access to the End-User's default
/// profile Claims, which are: name, family_name, given_name,
/// middle_name, nickname, preferred_username, profile, picture,
/// website, gender, birthdate, zoneinfo, locale, and updated_at.
/// </summary>
public const string Profile = "profile";
/// <summary>
/// OPTIONAL. This scope value requests access to the email and
/// email_verified Claims.
/// </summary>
public const string Email = "email";
/// <summary>
/// OPTIONAL. This scope value requests access to the address Claim.
/// </summary>
public const string Address = "address";
/// <summary>
/// OPTIONAL. This scope value requests access to the phone_number and
/// phone_number_verified Claims.
/// </summary>
public const string Phone = "phone";
/// <summary>
/// OPTIONAL. This scope value requests that an OAuth 2.0 Refresh Token
/// be issued that can be used to obtain an Access Token that grants
/// access to the End-User's UserInfo Endpoint even when the End-User is
/// not present (not logged in).
/// </summary>
public const string OfflineAccess = "offline_access";
/// <summary>
/// OPTIONAL. Authentication Context Class Reference. String specifying
/// an Authentication Context Class Reference value that identifies the
/// Authentication Context Class that the authentication performed
/// satisfied.
/// </summary> /// </summary>
/// <remarks> /// <remarks>
/// See: https://openid.net/specs/openid-connect-core-1_0.html#rfc.section.2 /// See: https://openid.net/specs/openid-connect-basic-1_0.html#Scopes
/// </remarks> /// </remarks>
public const string Acr = "acr"; public static class OpenIdConnectScopes
{
/// <summary>
/// REQUIRED. Informs the Authorization Server that the Client is making
/// an OpenID Connect request. If the openid scope value is not present,
/// the behavior is entirely unspecified.
/// </summary>
public const string OpenId = "openid";
/// <summary>
/// OPTIONAL. This scope value requests access to the End-User's default
/// profile Claims, which are: name, family_name, given_name,
/// middle_name, nickname, preferred_username, profile, picture,
/// website, gender, birthdate, zoneinfo, locale, and updated_at.
/// </summary>
public const string Profile = "profile";
/// <summary>
/// OPTIONAL. This scope value requests access to the email and
/// email_verified Claims.
/// </summary>
public const string Email = "email";
/// <summary>
/// OPTIONAL. This scope value requests access to the address Claim.
/// </summary>
public const string Address = "address";
/// <summary>
/// OPTIONAL. This scope value requests access to the phone_number and
/// phone_number_verified Claims.
/// </summary>
public const string Phone = "phone";
/// <summary>
/// OPTIONAL. This scope value requests that an OAuth 2.0 Refresh Token
/// be issued that can be used to obtain an Access Token that grants
/// access to the End-User's UserInfo Endpoint even when the End-User is
/// not present (not logged in).
/// </summary>
public const string OfflineAccess = "offline_access";
/// <summary>
/// OPTIONAL. Authentication Context Class Reference. String specifying
/// an Authentication Context Class Reference value that identifies the
/// Authentication Context Class that the authentication performed
/// satisfied.
/// </summary>
/// <remarks>
/// See: https://openid.net/specs/openid-connect-core-1_0.html#rfc.section.2
/// </remarks>
public const string Acr = "acr";
}
} }

View File

@ -4,101 +4,102 @@ using System.Xml;
using Sustainsys.Saml2; using Sustainsys.Saml2;
using Sustainsys.Saml2.AspNetCore2; using Sustainsys.Saml2.AspNetCore2;
namespace Bit.Sso.Utilities; namespace Bit.Sso.Utilities
public static class Saml2OptionsExtensions
{ {
public static async Task<bool> CouldHandleAsync(this Saml2Options options, string scheme, HttpContext context) public static class Saml2OptionsExtensions
{ {
// Determine this is a valid request for our handler public static async Task<bool> CouldHandleAsync(this Saml2Options options, string scheme, HttpContext context)
if (!context.Request.Path.StartsWithSegments(options.SPOptions.ModulePath, StringComparison.Ordinal))
{ {
return false; // Determine this is a valid request for our handler
} if (!context.Request.Path.StartsWithSegments(options.SPOptions.ModulePath, StringComparison.Ordinal))
{
return false;
}
var idp = options.IdentityProviders.IsEmpty ? null : options.IdentityProviders.Default; var idp = options.IdentityProviders.IsEmpty ? null : options.IdentityProviders.Default;
if (idp == null) if (idp == null)
{ {
return false; return false;
} }
if (context.Request.Query["scheme"].FirstOrDefault() == scheme)
{
return true;
}
// We need to pull out and parse the response or request SAML envelope
XmlElement envelope = null;
try
{
if (string.Equals(context.Request.Method, "POST", StringComparison.OrdinalIgnoreCase) &&
context.Request.HasFormContentType)
{
string encodedMessage;
if (context.Request.Form.TryGetValue("SAMLResponse", out var response))
{
encodedMessage = response.FirstOrDefault();
}
else
{
encodedMessage = context.Request.Form["SAMLRequest"];
}
if (string.IsNullOrWhiteSpace(encodedMessage))
{
return false;
}
envelope = XmlHelpers.XmlDocumentFromString(
Encoding.UTF8.GetString(Convert.FromBase64String(encodedMessage)))?.DocumentElement;
}
else if (string.Equals(context.Request.Method, "GET", StringComparison.OrdinalIgnoreCase))
{
var encodedPayload = context.Request.Query["SAMLRequest"].FirstOrDefault() ??
context.Request.Query["SAMLResponse"].FirstOrDefault();
try
{
var payload = Convert.FromBase64String(encodedPayload);
using var compressed = new MemoryStream(payload);
using var decompressedStream = new DeflateStream(compressed, CompressionMode.Decompress, true);
using var deCompressed = new MemoryStream();
await decompressedStream.CopyToAsync(deCompressed);
envelope = XmlHelpers.XmlDocumentFromString(
Encoding.UTF8.GetString(deCompressed.GetBuffer(), 0, (int)deCompressed.Length))?.DocumentElement;
}
catch (FormatException ex)
{
throw new FormatException($"\'{encodedPayload}\' is not a valid Base64 encoded string: {ex.Message}", ex);
}
}
}
catch
{
return false;
}
if (envelope == null)
{
return false;
}
// Double check the entity Ids
var entityId = envelope["Issuer", Saml2Namespaces.Saml2Name]?.InnerText.Trim();
if (!string.Equals(entityId, idp.EntityId.Id, StringComparison.InvariantCultureIgnoreCase))
{
return false;
}
if (options.SPOptions.WantAssertionsSigned)
{
var assertion = envelope["Assertion", Saml2Namespaces.Saml2Name];
var isAssertionSigned = assertion != null && XmlHelpers.IsSignedByAny(assertion, idp.SigningKeys,
options.SPOptions.ValidateCertificates, options.SPOptions.MinIncomingSigningAlgorithm);
if (!isAssertionSigned)
{
throw new Exception("Cannot verify SAML assertion signature.");
}
}
if (context.Request.Query["scheme"].FirstOrDefault() == scheme)
{
return true; return true;
} }
// We need to pull out and parse the response or request SAML envelope
XmlElement envelope = null;
try
{
if (string.Equals(context.Request.Method, "POST", StringComparison.OrdinalIgnoreCase) &&
context.Request.HasFormContentType)
{
string encodedMessage;
if (context.Request.Form.TryGetValue("SAMLResponse", out var response))
{
encodedMessage = response.FirstOrDefault();
}
else
{
encodedMessage = context.Request.Form["SAMLRequest"];
}
if (string.IsNullOrWhiteSpace(encodedMessage))
{
return false;
}
envelope = XmlHelpers.XmlDocumentFromString(
Encoding.UTF8.GetString(Convert.FromBase64String(encodedMessage)))?.DocumentElement;
}
else if (string.Equals(context.Request.Method, "GET", StringComparison.OrdinalIgnoreCase))
{
var encodedPayload = context.Request.Query["SAMLRequest"].FirstOrDefault() ??
context.Request.Query["SAMLResponse"].FirstOrDefault();
try
{
var payload = Convert.FromBase64String(encodedPayload);
using var compressed = new MemoryStream(payload);
using var decompressedStream = new DeflateStream(compressed, CompressionMode.Decompress, true);
using var deCompressed = new MemoryStream();
await decompressedStream.CopyToAsync(deCompressed);
envelope = XmlHelpers.XmlDocumentFromString(
Encoding.UTF8.GetString(deCompressed.GetBuffer(), 0, (int)deCompressed.Length))?.DocumentElement;
}
catch (FormatException ex)
{
throw new FormatException($"\'{encodedPayload}\' is not a valid Base64 encoded string: {ex.Message}", ex);
}
}
}
catch
{
return false;
}
if (envelope == null)
{
return false;
}
// Double check the entity Ids
var entityId = envelope["Issuer", Saml2Namespaces.Saml2Name]?.InnerText.Trim();
if (!string.Equals(entityId, idp.EntityId.Id, StringComparison.InvariantCultureIgnoreCase))
{
return false;
}
if (options.SPOptions.WantAssertionsSigned)
{
var assertion = envelope["Assertion", Saml2Namespaces.Saml2Name];
var isAssertionSigned = assertion != null && XmlHelpers.IsSignedByAny(assertion, idp.SigningKeys,
options.SPOptions.ValidateCertificates, options.SPOptions.MinIncomingSigningAlgorithm);
if (!isAssertionSigned)
{
throw new Exception("Cannot verify SAML assertion signature.");
}
}
return true;
} }
} }

View File

@ -1,11 +1,12 @@
namespace Bit.Sso.Utilities; namespace Bit.Sso.Utilities
public static class SamlClaimTypes
{ {
public const string Email = "urn:oid:0.9.2342.19200300.100.1.3"; public static class SamlClaimTypes
public const string GivenName = "urn:oid:2.5.4.42"; {
public const string Surname = "urn:oid:2.5.4.4"; public const string Email = "urn:oid:0.9.2342.19200300.100.1.3";
public const string DisplayName = "urn:oid:2.16.840.1.113730.3.1.241"; public const string GivenName = "urn:oid:2.5.4.42";
public const string CommonName = "urn:oid:2.5.4.3"; public const string Surname = "urn:oid:2.5.4.4";
public const string UserId = "urn:oid:0.9.2342.19200300.100.1.1"; public const string DisplayName = "urn:oid:2.16.840.1.113730.3.1.241";
public const string CommonName = "urn:oid:2.5.4.3";
public const string UserId = "urn:oid:0.9.2342.19200300.100.1.1";
}
} }

View File

@ -1,17 +1,18 @@
namespace Bit.Sso.Utilities; namespace Bit.Sso.Utilities
public static class SamlNameIdFormats
{ {
// Common public static class SamlNameIdFormats
public const string Unspecified = "urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified"; {
public const string Email = "urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress"; // Common
public const string Persistent = "urn:oasis:names:tc:SAML:2.0:nameid-format:persistent"; public const string Unspecified = "urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified";
public const string Transient = "urn:oasis:names:tc:SAML:2.0:nameid-format:transient"; public const string Email = "urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress";
// Not-so-common public const string Persistent = "urn:oasis:names:tc:SAML:2.0:nameid-format:persistent";
public const string Upn = "http://schemas.xmlsoap.org/claims/UPN"; public const string Transient = "urn:oasis:names:tc:SAML:2.0:nameid-format:transient";
public const string CommonName = "http://schemas.xmlsoap.org/claims/CommonName"; // Not-so-common
public const string X509SubjectName = "urn:oasis:names:tc:SAML:1.1:nameid-format:X509SubjectName"; public const string Upn = "http://schemas.xmlsoap.org/claims/UPN";
public const string WindowsQualifiedDomainName = "urn:oasis:names:tc:SAML:1.1:nameid-format:WindowsDomainQualifiedName"; public const string CommonName = "http://schemas.xmlsoap.org/claims/CommonName";
public const string KerberosPrincipalName = "urn:oasis:names:tc:SAML:2.0:nameid-format:kerberos"; public const string X509SubjectName = "urn:oasis:names:tc:SAML:1.1:nameid-format:X509SubjectName";
public const string EntityIdentifier = "urn:oasis:names:tc:SAML:2.0:nameid-format:entity"; public const string WindowsQualifiedDomainName = "urn:oasis:names:tc:SAML:1.1:nameid-format:WindowsDomainQualifiedName";
public const string KerberosPrincipalName = "urn:oasis:names:tc:SAML:2.0:nameid-format:kerberos";
public const string EntityIdentifier = "urn:oasis:names:tc:SAML:2.0:nameid-format:entity";
}
} }

View File

@ -1,6 +1,7 @@
namespace Bit.Sso.Utilities; namespace Bit.Sso.Utilities
public static class SamlPropertyKeys
{ {
public const string ClaimFormat = "http://schemas.xmlsoap.org/ws/2005/05/identity/claimproperties/format"; public static class SamlPropertyKeys
{
public const string ClaimFormat = "http://schemas.xmlsoap.org/ws/2005/05/identity/claimproperties/format";
}
} }

View File

@ -9,69 +9,70 @@ using IdentityServer4.ResponseHandling;
using Microsoft.AspNetCore.Authentication.OpenIdConnect; using Microsoft.AspNetCore.Authentication.OpenIdConnect;
using Sustainsys.Saml2.AspNetCore2; using Sustainsys.Saml2.AspNetCore2;
namespace Bit.Sso.Utilities; namespace Bit.Sso.Utilities
public static class ServiceCollectionExtensions
{ {
public static IServiceCollection AddSsoServices(this IServiceCollection services, public static class ServiceCollectionExtensions
GlobalSettings globalSettings)
{ {
// SAML SP Configuration public static IServiceCollection AddSsoServices(this IServiceCollection services,
var samlEnvironment = new SamlEnvironment GlobalSettings globalSettings)
{ {
SpSigningCertificate = CoreHelpers.GetIdentityServerCertificate(globalSettings), // SAML SP Configuration
}; var samlEnvironment = new SamlEnvironment
services.AddSingleton(s => samlEnvironment);
services.AddSingleton<Microsoft.AspNetCore.Authentication.IAuthenticationSchemeProvider,
DynamicAuthenticationSchemeProvider>();
// Oidc
services.AddSingleton<Microsoft.Extensions.Options.IPostConfigureOptions<OpenIdConnectOptions>,
OpenIdConnectPostConfigureOptions>();
services.AddSingleton<Microsoft.Extensions.Options.IOptionsMonitorCache<OpenIdConnectOptions>,
ExtendedOptionsMonitorCache<OpenIdConnectOptions>>();
// Saml2
services.AddSingleton<Microsoft.Extensions.Options.IPostConfigureOptions<Saml2Options>,
PostConfigureSaml2Options>();
services.AddSingleton<Microsoft.Extensions.Options.IOptionsMonitorCache<Saml2Options>,
ExtendedOptionsMonitorCache<Saml2Options>>();
return services;
}
public static IIdentityServerBuilder AddSsoIdentityServerServices(this IServiceCollection services,
IWebHostEnvironment env, GlobalSettings globalSettings)
{
services.AddTransient<IDiscoveryResponseGenerator, DiscoveryResponseGenerator>();
var issuerUri = new Uri(globalSettings.BaseServiceUri.InternalSso);
var identityServerBuilder = services
.AddIdentityServer(options =>
{ {
options.IssuerUri = $"{issuerUri.Scheme}://{issuerUri.Host}"; SpSigningCertificate = CoreHelpers.GetIdentityServerCertificate(globalSettings),
if (env.IsDevelopment()) };
services.AddSingleton(s => samlEnvironment);
services.AddSingleton<Microsoft.AspNetCore.Authentication.IAuthenticationSchemeProvider,
DynamicAuthenticationSchemeProvider>();
// Oidc
services.AddSingleton<Microsoft.Extensions.Options.IPostConfigureOptions<OpenIdConnectOptions>,
OpenIdConnectPostConfigureOptions>();
services.AddSingleton<Microsoft.Extensions.Options.IOptionsMonitorCache<OpenIdConnectOptions>,
ExtendedOptionsMonitorCache<OpenIdConnectOptions>>();
// Saml2
services.AddSingleton<Microsoft.Extensions.Options.IPostConfigureOptions<Saml2Options>,
PostConfigureSaml2Options>();
services.AddSingleton<Microsoft.Extensions.Options.IOptionsMonitorCache<Saml2Options>,
ExtendedOptionsMonitorCache<Saml2Options>>();
return services;
}
public static IIdentityServerBuilder AddSsoIdentityServerServices(this IServiceCollection services,
IWebHostEnvironment env, GlobalSettings globalSettings)
{
services.AddTransient<IDiscoveryResponseGenerator, DiscoveryResponseGenerator>();
var issuerUri = new Uri(globalSettings.BaseServiceUri.InternalSso);
var identityServerBuilder = services
.AddIdentityServer(options =>
{ {
options.Authentication.CookieSameSiteMode = Microsoft.AspNetCore.Http.SameSiteMode.Unspecified; options.IssuerUri = $"{issuerUri.Scheme}://{issuerUri.Host}";
} if (env.IsDevelopment())
else {
options.Authentication.CookieSameSiteMode = Microsoft.AspNetCore.Http.SameSiteMode.Unspecified;
}
else
{
options.UserInteraction.ErrorUrl = "/Error";
options.UserInteraction.ErrorIdParameter = "errorId";
}
options.InputLengthRestrictions.UserName = 256;
})
.AddInMemoryCaching()
.AddInMemoryClients(new List<Client>
{ {
options.UserInteraction.ErrorUrl = "/Error"; new OidcIdentityClient(globalSettings)
options.UserInteraction.ErrorIdParameter = "errorId"; })
} .AddInMemoryIdentityResources(new List<IdentityResource>
options.InputLengthRestrictions.UserName = 256; {
}) new IdentityResources.OpenId(),
.AddInMemoryCaching() new IdentityResources.Profile()
.AddInMemoryClients(new List<Client> })
{ .AddIdentityServerCertificate(env, globalSettings);
new OidcIdentityClient(globalSettings)
})
.AddInMemoryIdentityResources(new List<IdentityResource>
{
new IdentityResources.OpenId(),
new IdentityResources.Profile()
})
.AddIdentityServerCertificate(env, globalSettings);
return identityServerBuilder; return identityServerBuilder;
}
} }
} }

View File

@ -3,82 +3,83 @@ using Microsoft.AspNetCore.Authentication;
using Microsoft.AspNetCore.Authentication.OpenIdConnect; using Microsoft.AspNetCore.Authentication.OpenIdConnect;
using Sustainsys.Saml2.AspNetCore2; using Sustainsys.Saml2.AspNetCore2;
namespace Bit.Sso.Utilities; namespace Bit.Sso.Utilities
public class SsoAuthenticationMiddleware
{ {
private readonly RequestDelegate _next; public class SsoAuthenticationMiddleware
public SsoAuthenticationMiddleware(RequestDelegate next, IAuthenticationSchemeProvider schemes)
{ {
_next = next ?? throw new ArgumentNullException(nameof(next)); private readonly RequestDelegate _next;
Schemes = schemes ?? throw new ArgumentNullException(nameof(schemes));
}
public IAuthenticationSchemeProvider Schemes { get; set; } public SsoAuthenticationMiddleware(RequestDelegate next, IAuthenticationSchemeProvider schemes)
public async Task Invoke(HttpContext context)
{
if ((context.Request.Method == "GET" && context.Request.Query.ContainsKey("SAMLart"))
|| (context.Request.Method == "POST" && context.Request.Form.ContainsKey("SAMLart")))
{ {
throw new Exception("SAMLart parameter detected. SAML Artifact binding is not allowed."); _next = next ?? throw new ArgumentNullException(nameof(next));
Schemes = schemes ?? throw new ArgumentNullException(nameof(schemes));
} }
context.Features.Set<IAuthenticationFeature>(new AuthenticationFeature public IAuthenticationSchemeProvider Schemes { get; set; }
{
OriginalPath = context.Request.Path,
OriginalPathBase = context.Request.PathBase
});
// Give any IAuthenticationRequestHandler schemes a chance to handle the request public async Task Invoke(HttpContext context)
var handlers = context.RequestServices.GetRequiredService<IAuthenticationHandlerProvider>();
foreach (var scheme in await Schemes.GetRequestHandlerSchemesAsync())
{ {
// Determine if scheme is appropriate for the current context FIRST if ((context.Request.Method == "GET" && context.Request.Query.ContainsKey("SAMLart"))
if (scheme is IDynamicAuthenticationScheme dynamicScheme) || (context.Request.Method == "POST" && context.Request.Form.ContainsKey("SAMLart")))
{ {
switch (dynamicScheme.SsoType) throw new Exception("SAMLart parameter detected. SAML Artifact binding is not allowed.");
}
context.Features.Set<IAuthenticationFeature>(new AuthenticationFeature
{
OriginalPath = context.Request.Path,
OriginalPathBase = context.Request.PathBase
});
// Give any IAuthenticationRequestHandler schemes a chance to handle the request
var handlers = context.RequestServices.GetRequiredService<IAuthenticationHandlerProvider>();
foreach (var scheme in await Schemes.GetRequestHandlerSchemesAsync())
{
// Determine if scheme is appropriate for the current context FIRST
if (scheme is IDynamicAuthenticationScheme dynamicScheme)
{ {
case SsoType.OpenIdConnect: switch (dynamicScheme.SsoType)
default: {
if (dynamicScheme.Options is OpenIdConnectOptions oidcOptions && case SsoType.OpenIdConnect:
!await oidcOptions.CouldHandleAsync(scheme.Name, context)) default:
{ if (dynamicScheme.Options is OpenIdConnectOptions oidcOptions &&
// It's OIDC and Dynamic, but not a good fit !await oidcOptions.CouldHandleAsync(scheme.Name, context))
continue; {
} // It's OIDC and Dynamic, but not a good fit
break; continue;
case SsoType.Saml2: }
if (dynamicScheme.Options is Saml2Options samlOptions && break;
!await samlOptions.CouldHandleAsync(scheme.Name, context)) case SsoType.Saml2:
{ if (dynamicScheme.Options is Saml2Options samlOptions &&
// It's SAML and Dynamic, but not a good fit !await samlOptions.CouldHandleAsync(scheme.Name, context))
continue; {
} // It's SAML and Dynamic, but not a good fit
break; continue;
}
break;
}
}
// This far it's not dynamic OR it is but "could" be handled
if (await handlers.GetHandlerAsync(context, scheme.Name) is IAuthenticationRequestHandler handler &&
await handler.HandleRequestAsync())
{
return;
} }
} }
// This far it's not dynamic OR it is but "could" be handled // Fallback to the default scheme from the provider
if (await handlers.GetHandlerAsync(context, scheme.Name) is IAuthenticationRequestHandler handler && var defaultAuthenticate = await Schemes.GetDefaultAuthenticateSchemeAsync();
await handler.HandleRequestAsync()) if (defaultAuthenticate != null)
{ {
return; var result = await context.AuthenticateAsync(defaultAuthenticate.Name);
if (result?.Principal != null)
{
context.User = result.Principal;
}
} }
}
// Fallback to the default scheme from the provider await _next(context);
var defaultAuthenticate = await Schemes.GetDefaultAuthenticateSchemeAsync();
if (defaultAuthenticate != null)
{
var result = await context.AuthenticateAsync(defaultAuthenticate.Name);
if (result?.Principal != null)
{
context.User = result.Principal;
}
} }
await _next(context);
} }
} }

View File

@ -3,42 +3,43 @@ using AutoFixture;
using AutoFixture.Xunit2; using AutoFixture.Xunit2;
using Bit.Core.Enums.Provider; using Bit.Core.Enums.Provider;
namespace Bit.Commercial.Core.Test.AutoFixture; namespace Bit.Commercial.Core.Test.AutoFixture
internal class ProviderUser : ICustomization
{ {
public ProviderUserStatusType Status { get; set; } internal class ProviderUser : ICustomization
public ProviderUserType Type { get; set; }
public ProviderUser(ProviderUserStatusType status, ProviderUserType type)
{ {
Status = status; public ProviderUserStatusType Status { get; set; }
Type = type; public ProviderUserType Type { get; set; }
public ProviderUser(ProviderUserStatusType status, ProviderUserType type)
{
Status = status;
Type = type;
}
public void Customize(IFixture fixture)
{
fixture.Customize<Bit.Core.Entities.Provider.ProviderUser>(composer => composer
.With(o => o.Type, Type)
.With(o => o.Status, Status));
}
} }
public void Customize(IFixture fixture) public class ProviderUserAttribute : CustomizeAttribute
{ {
fixture.Customize<Bit.Core.Entities.Provider.ProviderUser>(composer => composer private readonly ProviderUserStatusType _status;
.With(o => o.Type, Type) private readonly ProviderUserType _type;
.With(o => o.Status, Status));
} public ProviderUserAttribute(
} ProviderUserStatusType status = ProviderUserStatusType.Confirmed,
ProviderUserType type = ProviderUserType.ProviderAdmin)
public class ProviderUserAttribute : CustomizeAttribute {
{ _status = status;
private readonly ProviderUserStatusType _status; _type = type;
private readonly ProviderUserType _type; }
public ProviderUserAttribute( public override ICustomization GetCustomization(ParameterInfo parameter)
ProviderUserStatusType status = ProviderUserStatusType.Confirmed, {
ProviderUserType type = ProviderUserType.ProviderAdmin) return new ProviderUser(_status, _type);
{ }
_status = status;
_type = type;
}
public override ICustomization GetCustomization(ParameterInfo parameter)
{
return new ProviderUser(_status, _type);
} }
} }

View File

@ -1,15 +1,16 @@
namespace Bit.Admin; namespace Bit.Admin
public class AdminSettings
{ {
public virtual string Admins { get; set; } public class AdminSettings
public virtual CloudflareSettings Cloudflare { get; set; }
public int? DeleteTrashDaysAgo { get; set; }
public class CloudflareSettings
{ {
public string ZoneId { get; set; } public virtual string Admins { get; set; }
public string AuthEmail { get; set; } public virtual CloudflareSettings Cloudflare { get; set; }
public string AuthKey { get; set; } public int? DeleteTrashDaysAgo { get; set; }
public class CloudflareSettings
{
public string ZoneId { get; set; }
public string AuthEmail { get; set; }
public string AuthKey { get; set; }
}
} }
} }

View File

@ -1,23 +1,24 @@
using Microsoft.AspNetCore.Diagnostics; using Microsoft.AspNetCore.Diagnostics;
using Microsoft.AspNetCore.Mvc; using Microsoft.AspNetCore.Mvc;
namespace Bit.Admin.Controllers; namespace Bit.Admin.Controllers
public class ErrorController : Controller
{ {
[Route("/error")] public class ErrorController : Controller
public IActionResult Error(int? statusCode = null)
{ {
var exceptionHandlerPathFeature = HttpContext.Features.Get<IExceptionHandlerPathFeature>(); [Route("/error")]
TempData["Error"] = HttpContext.Features.Get<IExceptionHandlerFeature>()?.Error.Message; public IActionResult Error(int? statusCode = null)
{
var exceptionHandlerPathFeature = HttpContext.Features.Get<IExceptionHandlerPathFeature>();
TempData["Error"] = HttpContext.Features.Get<IExceptionHandlerFeature>()?.Error.Message;
if (exceptionHandlerPathFeature != null) if (exceptionHandlerPathFeature != null)
{ {
return Redirect(exceptionHandlerPathFeature.Path); return Redirect(exceptionHandlerPathFeature.Path);
} }
else else
{ {
return Redirect("/Home"); return Redirect("/Home");
}
} }
} }
} }

View File

@ -6,108 +6,109 @@ using Microsoft.AspNetCore.Authorization;
using Microsoft.AspNetCore.Mvc; using Microsoft.AspNetCore.Mvc;
using Newtonsoft.Json; using Newtonsoft.Json;
namespace Bit.Admin.Controllers; namespace Bit.Admin.Controllers
public class HomeController : Controller
{ {
private readonly GlobalSettings _globalSettings; public class HomeController : Controller
private readonly HttpClient _httpClient = new HttpClient();
private readonly ILogger<HomeController> _logger;
public HomeController(GlobalSettings globalSettings, ILogger<HomeController> logger)
{ {
_globalSettings = globalSettings; private readonly GlobalSettings _globalSettings;
_logger = logger; private readonly HttpClient _httpClient = new HttpClient();
} private readonly ILogger<HomeController> _logger;
[Authorize] public HomeController(GlobalSettings globalSettings, ILogger<HomeController> logger)
public IActionResult Index()
{
return View(new HomeModel
{ {
GlobalSettings = _globalSettings, _globalSettings = globalSettings;
CurrentVersion = Core.Utilities.CoreHelpers.GetVersion() _logger = logger;
}); }
}
public IActionResult Error() [Authorize]
{ public IActionResult Index()
return View(new ErrorViewModel
{ {
RequestId = Activity.Current?.Id ?? HttpContext.TraceIdentifier return View(new HomeModel
});
}
public async Task<IActionResult> GetLatestVersion(ProjectType project, CancellationToken cancellationToken)
{
var requestUri = $"https://selfhost.bitwarden.com/version.json";
try
{
var response = await _httpClient.GetAsync(requestUri, cancellationToken);
if (response.IsSuccessStatusCode)
{ {
var latestVersions = JsonConvert.DeserializeObject<LatestVersions>(await response.Content.ReadAsStringAsync()); GlobalSettings = _globalSettings,
return project switch CurrentVersion = Core.Utilities.CoreHelpers.GetVersion()
});
}
public IActionResult Error()
{
return View(new ErrorViewModel
{
RequestId = Activity.Current?.Id ?? HttpContext.TraceIdentifier
});
}
public async Task<IActionResult> GetLatestVersion(ProjectType project, CancellationToken cancellationToken)
{
var requestUri = $"https://selfhost.bitwarden.com/version.json";
try
{
var response = await _httpClient.GetAsync(requestUri, cancellationToken);
if (response.IsSuccessStatusCode)
{ {
ProjectType.Core => new JsonResult(latestVersions.Versions.CoreVersion), var latestVersions = JsonConvert.DeserializeObject<LatestVersions>(await response.Content.ReadAsStringAsync());
ProjectType.Web => new JsonResult(latestVersions.Versions.WebVersion), return project switch
_ => throw new System.NotImplementedException(), {
}; ProjectType.Core => new JsonResult(latestVersions.Versions.CoreVersion),
ProjectType.Web => new JsonResult(latestVersions.Versions.WebVersion),
_ => throw new System.NotImplementedException(),
};
}
} }
} catch (HttpRequestException e)
catch (HttpRequestException e)
{
_logger.LogError(e, $"Error encountered while sending GET request to {requestUri}");
return new JsonResult("Unable to fetch latest version") { StatusCode = StatusCodes.Status500InternalServerError };
}
return new JsonResult("-");
}
public async Task<IActionResult> GetInstalledWebVersion(CancellationToken cancellationToken)
{
var requestUri = $"{_globalSettings.BaseServiceUri.InternalVault}/version.json";
try
{
var response = await _httpClient.GetAsync(requestUri, cancellationToken);
if (response.IsSuccessStatusCode)
{ {
using var jsonDocument = await JsonDocument.ParseAsync(await response.Content.ReadAsStreamAsync(cancellationToken), cancellationToken: cancellationToken); _logger.LogError(e, $"Error encountered while sending GET request to {requestUri}");
var root = jsonDocument.RootElement; return new JsonResult("Unable to fetch latest version") { StatusCode = StatusCodes.Status500InternalServerError };
return new JsonResult(root.GetProperty("version").GetString());
} }
return new JsonResult("-");
} }
catch (HttpRequestException e)
public async Task<IActionResult> GetInstalledWebVersion(CancellationToken cancellationToken)
{ {
_logger.LogError(e, $"Error encountered while sending GET request to {requestUri}"); var requestUri = $"{_globalSettings.BaseServiceUri.InternalVault}/version.json";
return new JsonResult("Unable to fetch installed version") { StatusCode = StatusCodes.Status500InternalServerError }; try
{
var response = await _httpClient.GetAsync(requestUri, cancellationToken);
if (response.IsSuccessStatusCode)
{
using var jsonDocument = await JsonDocument.ParseAsync(await response.Content.ReadAsStreamAsync(cancellationToken), cancellationToken: cancellationToken);
var root = jsonDocument.RootElement;
return new JsonResult(root.GetProperty("version").GetString());
}
}
catch (HttpRequestException e)
{
_logger.LogError(e, $"Error encountered while sending GET request to {requestUri}");
return new JsonResult("Unable to fetch installed version") { StatusCode = StatusCodes.Status500InternalServerError };
}
return new JsonResult("-");
} }
return new JsonResult("-"); private class LatestVersions
{
[JsonProperty("versions")]
public Versions Versions { get; set; }
}
private class Versions
{
[JsonProperty("coreVersion")]
public string CoreVersion { get; set; }
[JsonProperty("webVersion")]
public string WebVersion { get; set; }
[JsonProperty("keyConnectorVersion")]
public string KeyConnectorVersion { get; set; }
}
} }
private class LatestVersions public enum ProjectType
{ {
[JsonProperty("versions")] Core,
public Versions Versions { get; set; } Web,
}
private class Versions
{
[JsonProperty("coreVersion")]
public string CoreVersion { get; set; }
[JsonProperty("webVersion")]
public string WebVersion { get; set; }
[JsonProperty("keyConnectorVersion")]
public string KeyConnectorVersion { get; set; }
} }
} }
public enum ProjectType
{
Core,
Web,
}

View File

@ -1,20 +1,21 @@
using Bit.Core.Utilities; using Bit.Core.Utilities;
using Microsoft.AspNetCore.Mvc; using Microsoft.AspNetCore.Mvc;
namespace Bit.Admin.Controllers; namespace Bit.Admin.Controllers
public class InfoController : Controller
{ {
[HttpGet("~/alive")] public class InfoController : Controller
[HttpGet("~/now")]
public DateTime GetAlive()
{ {
return DateTime.UtcNow; [HttpGet("~/alive")]
} [HttpGet("~/now")]
public DateTime GetAlive()
{
return DateTime.UtcNow;
}
[HttpGet("~/version")] [HttpGet("~/version")]
public JsonResult GetVersion() public JsonResult GetVersion()
{ {
return Json(CoreHelpers.GetVersion()); return Json(CoreHelpers.GetVersion());
}
} }
} }

View File

@ -3,90 +3,91 @@ using Bit.Core.Identity;
using Microsoft.AspNetCore.Identity; using Microsoft.AspNetCore.Identity;
using Microsoft.AspNetCore.Mvc; using Microsoft.AspNetCore.Mvc;
namespace Bit.Admin.Controllers; namespace Bit.Admin.Controllers
public class LoginController : Controller
{ {
private readonly PasswordlessSignInManager<IdentityUser> _signInManager; public class LoginController : Controller
public LoginController(
PasswordlessSignInManager<IdentityUser> signInManager)
{ {
_signInManager = signInManager; private readonly PasswordlessSignInManager<IdentityUser> _signInManager;
}
public IActionResult Index(string returnUrl = null, int? error = null, int? success = null, public LoginController(
bool accessDenied = false) PasswordlessSignInManager<IdentityUser> signInManager)
{
if (!error.HasValue && accessDenied)
{ {
error = 4; _signInManager = signInManager;
} }
return View(new LoginModel public IActionResult Index(string returnUrl = null, int? error = null, int? success = null,
bool accessDenied = false)
{ {
ReturnUrl = returnUrl, if (!error.HasValue && accessDenied)
Error = GetMessage(error),
Success = GetMessage(success)
});
}
[HttpPost]
[ValidateAntiForgeryToken]
public async Task<IActionResult> Index(LoginModel model)
{
if (ModelState.IsValid)
{
await _signInManager.PasswordlessSignInAsync(model.Email, model.ReturnUrl);
return RedirectToAction("Index", new
{ {
success = 3 error = 4;
}
return View(new LoginModel
{
ReturnUrl = returnUrl,
Error = GetMessage(error),
Success = GetMessage(success)
}); });
} }
return View(model); [HttpPost]
} [ValidateAntiForgeryToken]
public async Task<IActionResult> Index(LoginModel model)
public async Task<IActionResult> Confirm(string email, string token, string returnUrl)
{
var result = await _signInManager.PasswordlessSignInAsync(email, token, true);
if (!result.Succeeded)
{ {
if (ModelState.IsValid)
{
await _signInManager.PasswordlessSignInAsync(model.Email, model.ReturnUrl);
return RedirectToAction("Index", new
{
success = 3
});
}
return View(model);
}
public async Task<IActionResult> Confirm(string email, string token, string returnUrl)
{
var result = await _signInManager.PasswordlessSignInAsync(email, token, true);
if (!result.Succeeded)
{
return RedirectToAction("Index", new
{
error = 2
});
}
if (!string.IsNullOrWhiteSpace(returnUrl) && Url.IsLocalUrl(returnUrl))
{
return Redirect(returnUrl);
}
return RedirectToAction("Index", "Home");
}
[HttpPost]
[ValidateAntiForgeryToken]
public async Task<IActionResult> Logout()
{
await _signInManager.SignOutAsync();
return RedirectToAction("Index", new return RedirectToAction("Index", new
{ {
error = 2 success = 1
}); });
} }
if (!string.IsNullOrWhiteSpace(returnUrl) && Url.IsLocalUrl(returnUrl)) private string GetMessage(int? messageCode)
{ {
return Redirect(returnUrl); return messageCode switch
{
1 => "You have been logged out.",
2 => "This login confirmation link is invalid. Try logging in again.",
3 => "If a valid admin user with this email address exists, " +
"we've sent you an email with a secure link to log in.",
4 => "Access denied. Please log in.",
_ => null,
};
} }
return RedirectToAction("Index", "Home");
}
[HttpPost]
[ValidateAntiForgeryToken]
public async Task<IActionResult> Logout()
{
await _signInManager.SignOutAsync();
return RedirectToAction("Index", new
{
success = 1
});
}
private string GetMessage(int? messageCode)
{
return messageCode switch
{
1 => "You have been logged out.",
2 => "This login confirmation link is invalid. Try logging in again.",
3 => "If a valid admin user with this email address exists, " +
"we've sent you an email with a secure link to log in.",
4 => "Access denied. Please log in.",
_ => null,
};
} }
} }

View File

@ -7,86 +7,87 @@ using Microsoft.Azure.Cosmos;
using Microsoft.Azure.Cosmos.Linq; using Microsoft.Azure.Cosmos.Linq;
using Serilog.Events; using Serilog.Events;
namespace Bit.Admin.Controllers; namespace Bit.Admin.Controllers
[Authorize]
[SelfHosted(NotSelfHostedOnly = true)]
public class LogsController : Controller
{ {
private const string Database = "Diagnostics"; [Authorize]
private const string Container = "Logs"; [SelfHosted(NotSelfHostedOnly = true)]
public class LogsController : Controller
private readonly GlobalSettings _globalSettings;
public LogsController(GlobalSettings globalSettings)
{ {
_globalSettings = globalSettings; private const string Database = "Diagnostics";
} private const string Container = "Logs";
public async Task<IActionResult> Index(string cursor = null, int count = 50, private readonly GlobalSettings _globalSettings;
LogEventLevel? level = null, string project = null, DateTime? start = null, DateTime? end = null)
{ public LogsController(GlobalSettings globalSettings)
using (var client = new CosmosClient(_globalSettings.DocumentDb.Uri,
_globalSettings.DocumentDb.Key))
{ {
var cosmosContainer = client.GetContainer(Database, Container); _globalSettings = globalSettings;
var query = cosmosContainer.GetItemLinqQueryable<LogModel>(
requestOptions: new QueryRequestOptions()
{
MaxItemCount = count
},
continuationToken: cursor
).AsQueryable();
if (level.HasValue)
{
query = query.Where(l => l.Level == level.Value.ToString());
}
if (!string.IsNullOrWhiteSpace(project))
{
query = query.Where(l => l.Properties != null && l.Properties["Project"] == (object)project);
}
if (start.HasValue)
{
query = query.Where(l => l.Timestamp >= start.Value);
}
if (end.HasValue)
{
query = query.Where(l => l.Timestamp <= end.Value);
}
var feedIterator = query.OrderByDescending(l => l.Timestamp).ToFeedIterator();
var response = await feedIterator.ReadNextAsync();
return View(new LogsModel
{
Level = level,
Project = project,
Start = start,
End = end,
Items = response.ToList(),
Count = count,
Cursor = cursor,
NextCursor = response.ContinuationToken
});
} }
}
public async Task<IActionResult> View(Guid id) public async Task<IActionResult> Index(string cursor = null, int count = 50,
{ LogEventLevel? level = null, string project = null, DateTime? start = null, DateTime? end = null)
using (var client = new CosmosClient(_globalSettings.DocumentDb.Uri,
_globalSettings.DocumentDb.Key))
{ {
var cosmosContainer = client.GetContainer(Database, Container); using (var client = new CosmosClient(_globalSettings.DocumentDb.Uri,
var query = cosmosContainer.GetItemLinqQueryable<LogDetailsModel>() _globalSettings.DocumentDb.Key))
.AsQueryable()
.Where(l => l.Id == id.ToString());
var response = await query.ToFeedIterator().ReadNextAsync();
if (response == null || response.Count == 0)
{ {
return RedirectToAction("Index"); var cosmosContainer = client.GetContainer(Database, Container);
var query = cosmosContainer.GetItemLinqQueryable<LogModel>(
requestOptions: new QueryRequestOptions()
{
MaxItemCount = count
},
continuationToken: cursor
).AsQueryable();
if (level.HasValue)
{
query = query.Where(l => l.Level == level.Value.ToString());
}
if (!string.IsNullOrWhiteSpace(project))
{
query = query.Where(l => l.Properties != null && l.Properties["Project"] == (object)project);
}
if (start.HasValue)
{
query = query.Where(l => l.Timestamp >= start.Value);
}
if (end.HasValue)
{
query = query.Where(l => l.Timestamp <= end.Value);
}
var feedIterator = query.OrderByDescending(l => l.Timestamp).ToFeedIterator();
var response = await feedIterator.ReadNextAsync();
return View(new LogsModel
{
Level = level,
Project = project,
Start = start,
End = end,
Items = response.ToList(),
Count = count,
Cursor = cursor,
NextCursor = response.ContinuationToken
});
}
}
public async Task<IActionResult> View(Guid id)
{
using (var client = new CosmosClient(_globalSettings.DocumentDb.Uri,
_globalSettings.DocumentDb.Key))
{
var cosmosContainer = client.GetContainer(Database, Container);
var query = cosmosContainer.GetItemLinqQueryable<LogDetailsModel>()
.AsQueryable()
.Where(l => l.Id == id.ToString());
var response = await query.ToFeedIterator().ReadNextAsync();
if (response == null || response.Count == 0)
{
return RedirectToAction("Index");
}
return View(response.First());
} }
return View(response.First());
} }
} }
} }

View File

@ -11,206 +11,207 @@ using Bit.Core.Utilities;
using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Authorization;
using Microsoft.AspNetCore.Mvc; using Microsoft.AspNetCore.Mvc;
namespace Bit.Admin.Controllers; namespace Bit.Admin.Controllers
[Authorize]
public class OrganizationsController : Controller
{ {
private readonly IOrganizationRepository _organizationRepository; [Authorize]
private readonly IOrganizationUserRepository _organizationUserRepository; public class OrganizationsController : Controller
private readonly IOrganizationConnectionRepository _organizationConnectionRepository;
private readonly ISelfHostedSyncSponsorshipsCommand _syncSponsorshipsCommand;
private readonly ICipherRepository _cipherRepository;
private readonly ICollectionRepository _collectionRepository;
private readonly IGroupRepository _groupRepository;
private readonly IPolicyRepository _policyRepository;
private readonly IPaymentService _paymentService;
private readonly ILicensingService _licensingService;
private readonly IApplicationCacheService _applicationCacheService;
private readonly GlobalSettings _globalSettings;
private readonly IReferenceEventService _referenceEventService;
private readonly IUserService _userService;
private readonly ILogger<OrganizationsController> _logger;
public OrganizationsController(
IOrganizationRepository organizationRepository,
IOrganizationUserRepository organizationUserRepository,
IOrganizationConnectionRepository organizationConnectionRepository,
ISelfHostedSyncSponsorshipsCommand syncSponsorshipsCommand,
ICipherRepository cipherRepository,
ICollectionRepository collectionRepository,
IGroupRepository groupRepository,
IPolicyRepository policyRepository,
IPaymentService paymentService,
ILicensingService licensingService,
IApplicationCacheService applicationCacheService,
GlobalSettings globalSettings,
IReferenceEventService referenceEventService,
IUserService userService,
ILogger<OrganizationsController> logger)
{ {
_organizationRepository = organizationRepository; private readonly IOrganizationRepository _organizationRepository;
_organizationUserRepository = organizationUserRepository; private readonly IOrganizationUserRepository _organizationUserRepository;
_organizationConnectionRepository = organizationConnectionRepository; private readonly IOrganizationConnectionRepository _organizationConnectionRepository;
_syncSponsorshipsCommand = syncSponsorshipsCommand; private readonly ISelfHostedSyncSponsorshipsCommand _syncSponsorshipsCommand;
_cipherRepository = cipherRepository; private readonly ICipherRepository _cipherRepository;
_collectionRepository = collectionRepository; private readonly ICollectionRepository _collectionRepository;
_groupRepository = groupRepository; private readonly IGroupRepository _groupRepository;
_policyRepository = policyRepository; private readonly IPolicyRepository _policyRepository;
_paymentService = paymentService; private readonly IPaymentService _paymentService;
_licensingService = licensingService; private readonly ILicensingService _licensingService;
_applicationCacheService = applicationCacheService; private readonly IApplicationCacheService _applicationCacheService;
_globalSettings = globalSettings; private readonly GlobalSettings _globalSettings;
_referenceEventService = referenceEventService; private readonly IReferenceEventService _referenceEventService;
_userService = userService; private readonly IUserService _userService;
_logger = logger; private readonly ILogger<OrganizationsController> _logger;
}
public async Task<IActionResult> Index(string name = null, string userEmail = null, bool? paid = null, public OrganizationsController(
int page = 1, int count = 25) IOrganizationRepository organizationRepository,
{ IOrganizationUserRepository organizationUserRepository,
if (page < 1) IOrganizationConnectionRepository organizationConnectionRepository,
ISelfHostedSyncSponsorshipsCommand syncSponsorshipsCommand,
ICipherRepository cipherRepository,
ICollectionRepository collectionRepository,
IGroupRepository groupRepository,
IPolicyRepository policyRepository,
IPaymentService paymentService,
ILicensingService licensingService,
IApplicationCacheService applicationCacheService,
GlobalSettings globalSettings,
IReferenceEventService referenceEventService,
IUserService userService,
ILogger<OrganizationsController> logger)
{ {
page = 1; _organizationRepository = organizationRepository;
_organizationUserRepository = organizationUserRepository;
_organizationConnectionRepository = organizationConnectionRepository;
_syncSponsorshipsCommand = syncSponsorshipsCommand;
_cipherRepository = cipherRepository;
_collectionRepository = collectionRepository;
_groupRepository = groupRepository;
_policyRepository = policyRepository;
_paymentService = paymentService;
_licensingService = licensingService;
_applicationCacheService = applicationCacheService;
_globalSettings = globalSettings;
_referenceEventService = referenceEventService;
_userService = userService;
_logger = logger;
} }
if (count < 1) public async Task<IActionResult> Index(string name = null, string userEmail = null, bool? paid = null,
int page = 1, int count = 25)
{ {
count = 1; if (page < 1)
{
page = 1;
}
if (count < 1)
{
count = 1;
}
var skip = (page - 1) * count;
var organizations = await _organizationRepository.SearchAsync(name, userEmail, paid, skip, count);
return View(new OrganizationsModel
{
Items = organizations as List<Organization>,
Name = string.IsNullOrWhiteSpace(name) ? null : name,
UserEmail = string.IsNullOrWhiteSpace(userEmail) ? null : userEmail,
Paid = paid,
Page = page,
Count = count,
Action = _globalSettings.SelfHosted ? "View" : "Edit",
SelfHosted = _globalSettings.SelfHosted
});
} }
var skip = (page - 1) * count; public async Task<IActionResult> View(Guid id)
var organizations = await _organizationRepository.SearchAsync(name, userEmail, paid, skip, count);
return View(new OrganizationsModel
{ {
Items = organizations as List<Organization>, var organization = await _organizationRepository.GetByIdAsync(id);
Name = string.IsNullOrWhiteSpace(name) ? null : name, if (organization == null)
UserEmail = string.IsNullOrWhiteSpace(userEmail) ? null : userEmail, {
Paid = paid, return RedirectToAction("Index");
Page = page, }
Count = count,
Action = _globalSettings.SelfHosted ? "View" : "Edit",
SelfHosted = _globalSettings.SelfHosted
});
}
public async Task<IActionResult> View(Guid id) var ciphers = await _cipherRepository.GetManyByOrganizationIdAsync(id);
{ var collections = await _collectionRepository.GetManyByOrganizationIdAsync(id);
var organization = await _organizationRepository.GetByIdAsync(id); IEnumerable<Group> groups = null;
if (organization == null) if (organization.UseGroups)
{
groups = await _groupRepository.GetManyByOrganizationIdAsync(id);
}
IEnumerable<Policy> policies = null;
if (organization.UsePolicies)
{
policies = await _policyRepository.GetManyByOrganizationIdAsync(id);
}
var users = await _organizationUserRepository.GetManyDetailsByOrganizationAsync(id);
var billingSyncConnection = _globalSettings.EnableCloudCommunication ? await _organizationConnectionRepository.GetByOrganizationIdTypeAsync(id, OrganizationConnectionType.CloudBillingSync) : null;
return View(new OrganizationViewModel(organization, billingSyncConnection, users, ciphers, collections, groups, policies));
}
[SelfHosted(NotSelfHostedOnly = true)]
public async Task<IActionResult> Edit(Guid id)
{ {
var organization = await _organizationRepository.GetByIdAsync(id);
if (organization == null)
{
return RedirectToAction("Index");
}
var ciphers = await _cipherRepository.GetManyByOrganizationIdAsync(id);
var collections = await _collectionRepository.GetManyByOrganizationIdAsync(id);
IEnumerable<Group> groups = null;
if (organization.UseGroups)
{
groups = await _groupRepository.GetManyByOrganizationIdAsync(id);
}
IEnumerable<Policy> policies = null;
if (organization.UsePolicies)
{
policies = await _policyRepository.GetManyByOrganizationIdAsync(id);
}
var users = await _organizationUserRepository.GetManyDetailsByOrganizationAsync(id);
var billingInfo = await _paymentService.GetBillingAsync(organization);
var billingSyncConnection = _globalSettings.EnableCloudCommunication ? await _organizationConnectionRepository.GetByOrganizationIdTypeAsync(id, OrganizationConnectionType.CloudBillingSync) : null;
return View(new OrganizationEditModel(organization, users, ciphers, collections, groups, policies,
billingInfo, billingSyncConnection, _globalSettings));
}
[HttpPost]
[ValidateAntiForgeryToken]
[SelfHosted(NotSelfHostedOnly = true)]
public async Task<IActionResult> Edit(Guid id, OrganizationEditModel model)
{
var organization = await _organizationRepository.GetByIdAsync(id);
model.ToOrganization(organization);
await _organizationRepository.ReplaceAsync(organization);
await _applicationCacheService.UpsertOrganizationAbilityAsync(organization);
await _referenceEventService.RaiseEventAsync(new ReferenceEvent(ReferenceEventType.OrganizationEditedByAdmin, organization)
{
EventRaisedByUser = _userService.GetUserName(User),
SalesAssistedTrialStarted = model.SalesAssistedTrialStarted,
});
return RedirectToAction("Edit", new { id });
}
[HttpPost]
[ValidateAntiForgeryToken]
public async Task<IActionResult> Delete(Guid id)
{
var organization = await _organizationRepository.GetByIdAsync(id);
if (organization != null)
{
await _organizationRepository.DeleteAsync(organization);
await _applicationCacheService.DeleteOrganizationAbilityAsync(organization.Id);
}
return RedirectToAction("Index"); return RedirectToAction("Index");
} }
var ciphers = await _cipherRepository.GetManyByOrganizationIdAsync(id); public async Task<IActionResult> TriggerBillingSync(Guid id)
var collections = await _collectionRepository.GetManyByOrganizationIdAsync(id);
IEnumerable<Group> groups = null;
if (organization.UseGroups)
{ {
groups = await _groupRepository.GetManyByOrganizationIdAsync(id); var organization = await _organizationRepository.GetByIdAsync(id);
} if (organization == null)
IEnumerable<Policy> policies = null; {
if (organization.UsePolicies) return RedirectToAction("Index");
{ }
policies = await _policyRepository.GetManyByOrganizationIdAsync(id); var connection = (await _organizationConnectionRepository.GetEnabledByOrganizationIdTypeAsync(id, OrganizationConnectionType.CloudBillingSync)).FirstOrDefault();
} if (connection != null)
var users = await _organizationUserRepository.GetManyDetailsByOrganizationAsync(id); {
var billingSyncConnection = _globalSettings.EnableCloudCommunication ? await _organizationConnectionRepository.GetByOrganizationIdTypeAsync(id, OrganizationConnectionType.CloudBillingSync) : null; try
return View(new OrganizationViewModel(organization, billingSyncConnection, users, ciphers, collections, groups, policies)); {
} var config = connection.GetConfig<BillingSyncConfig>();
await _syncSponsorshipsCommand.SyncOrganization(id, config.CloudOrganizationId, connection);
TempData["ConnectionActivated"] = id;
TempData["ConnectionError"] = null;
}
catch (Exception ex)
{
TempData["ConnectionError"] = ex.Message;
_logger.LogWarning(ex, "Error while attempting to do billing sync for organization with id '{OrganizationId}'", id);
}
[SelfHosted(NotSelfHostedOnly = true)] if (_globalSettings.SelfHosted)
public async Task<IActionResult> Edit(Guid id) {
{ return RedirectToAction("View", new { id });
var organization = await _organizationRepository.GetByIdAsync(id); }
if (organization == null) else
{ {
return RedirectToAction("Edit", new { id });
}
}
return RedirectToAction("Index"); return RedirectToAction("Index");
} }
var ciphers = await _cipherRepository.GetManyByOrganizationIdAsync(id);
var collections = await _collectionRepository.GetManyByOrganizationIdAsync(id);
IEnumerable<Group> groups = null;
if (organization.UseGroups)
{
groups = await _groupRepository.GetManyByOrganizationIdAsync(id);
}
IEnumerable<Policy> policies = null;
if (organization.UsePolicies)
{
policies = await _policyRepository.GetManyByOrganizationIdAsync(id);
}
var users = await _organizationUserRepository.GetManyDetailsByOrganizationAsync(id);
var billingInfo = await _paymentService.GetBillingAsync(organization);
var billingSyncConnection = _globalSettings.EnableCloudCommunication ? await _organizationConnectionRepository.GetByOrganizationIdTypeAsync(id, OrganizationConnectionType.CloudBillingSync) : null;
return View(new OrganizationEditModel(organization, users, ciphers, collections, groups, policies,
billingInfo, billingSyncConnection, _globalSettings));
} }
[HttpPost]
[ValidateAntiForgeryToken]
[SelfHosted(NotSelfHostedOnly = true)]
public async Task<IActionResult> Edit(Guid id, OrganizationEditModel model)
{
var organization = await _organizationRepository.GetByIdAsync(id);
model.ToOrganization(organization);
await _organizationRepository.ReplaceAsync(organization);
await _applicationCacheService.UpsertOrganizationAbilityAsync(organization);
await _referenceEventService.RaiseEventAsync(new ReferenceEvent(ReferenceEventType.OrganizationEditedByAdmin, organization)
{
EventRaisedByUser = _userService.GetUserName(User),
SalesAssistedTrialStarted = model.SalesAssistedTrialStarted,
});
return RedirectToAction("Edit", new { id });
}
[HttpPost]
[ValidateAntiForgeryToken]
public async Task<IActionResult> Delete(Guid id)
{
var organization = await _organizationRepository.GetByIdAsync(id);
if (organization != null)
{
await _organizationRepository.DeleteAsync(organization);
await _applicationCacheService.DeleteOrganizationAbilityAsync(organization.Id);
}
return RedirectToAction("Index");
}
public async Task<IActionResult> TriggerBillingSync(Guid id)
{
var organization = await _organizationRepository.GetByIdAsync(id);
if (organization == null)
{
return RedirectToAction("Index");
}
var connection = (await _organizationConnectionRepository.GetEnabledByOrganizationIdTypeAsync(id, OrganizationConnectionType.CloudBillingSync)).FirstOrDefault();
if (connection != null)
{
try
{
var config = connection.GetConfig<BillingSyncConfig>();
await _syncSponsorshipsCommand.SyncOrganization(id, config.CloudOrganizationId, connection);
TempData["ConnectionActivated"] = id;
TempData["ConnectionError"] = null;
}
catch (Exception ex)
{
TempData["ConnectionError"] = ex.Message;
_logger.LogWarning(ex, "Error while attempting to do billing sync for organization with id '{OrganizationId}'", id);
}
if (_globalSettings.SelfHosted)
{
return RedirectToAction("View", new { id });
}
else
{
return RedirectToAction("Edit", new { id });
}
}
return RedirectToAction("Index");
}
} }

View File

@ -7,127 +7,128 @@ using Bit.Core.Utilities;
using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Authorization;
using Microsoft.AspNetCore.Mvc; using Microsoft.AspNetCore.Mvc;
namespace Bit.Admin.Controllers; namespace Bit.Admin.Controllers
[Authorize]
[SelfHosted(NotSelfHostedOnly = true)]
public class ProvidersController : Controller
{ {
private readonly IProviderRepository _providerRepository; [Authorize]
private readonly IProviderUserRepository _providerUserRepository;
private readonly IProviderOrganizationRepository _providerOrganizationRepository;
private readonly GlobalSettings _globalSettings;
private readonly IApplicationCacheService _applicationCacheService;
private readonly IProviderService _providerService;
public ProvidersController(IProviderRepository providerRepository, IProviderUserRepository providerUserRepository,
IProviderOrganizationRepository providerOrganizationRepository, IProviderService providerService,
GlobalSettings globalSettings, IApplicationCacheService applicationCacheService)
{
_providerRepository = providerRepository;
_providerUserRepository = providerUserRepository;
_providerOrganizationRepository = providerOrganizationRepository;
_providerService = providerService;
_globalSettings = globalSettings;
_applicationCacheService = applicationCacheService;
}
public async Task<IActionResult> Index(string name = null, string userEmail = null, int page = 1, int count = 25)
{
if (page < 1)
{
page = 1;
}
if (count < 1)
{
count = 1;
}
var skip = (page - 1) * count;
var providers = await _providerRepository.SearchAsync(name, userEmail, skip, count);
return View(new ProvidersModel
{
Items = providers as List<Provider>,
Name = string.IsNullOrWhiteSpace(name) ? null : name,
UserEmail = string.IsNullOrWhiteSpace(userEmail) ? null : userEmail,
Page = page,
Count = count,
Action = _globalSettings.SelfHosted ? "View" : "Edit",
SelfHosted = _globalSettings.SelfHosted
});
}
public IActionResult Create(string ownerEmail = null)
{
return View(new CreateProviderModel
{
OwnerEmail = ownerEmail
});
}
[HttpPost]
[ValidateAntiForgeryToken]
public async Task<IActionResult> Create(CreateProviderModel model)
{
if (!ModelState.IsValid)
{
return View(model);
}
await _providerService.CreateAsync(model.OwnerEmail);
return RedirectToAction("Index");
}
public async Task<IActionResult> View(Guid id)
{
var provider = await _providerRepository.GetByIdAsync(id);
if (provider == null)
{
return RedirectToAction("Index");
}
var users = await _providerUserRepository.GetManyDetailsByProviderAsync(id);
var providerOrganizations = await _providerOrganizationRepository.GetManyDetailsByProviderAsync(id);
return View(new ProviderViewModel(provider, users, providerOrganizations));
}
[SelfHosted(NotSelfHostedOnly = true)] [SelfHosted(NotSelfHostedOnly = true)]
public async Task<IActionResult> Edit(Guid id) public class ProvidersController : Controller
{ {
var provider = await _providerRepository.GetByIdAsync(id); private readonly IProviderRepository _providerRepository;
if (provider == null) private readonly IProviderUserRepository _providerUserRepository;
private readonly IProviderOrganizationRepository _providerOrganizationRepository;
private readonly GlobalSettings _globalSettings;
private readonly IApplicationCacheService _applicationCacheService;
private readonly IProviderService _providerService;
public ProvidersController(IProviderRepository providerRepository, IProviderUserRepository providerUserRepository,
IProviderOrganizationRepository providerOrganizationRepository, IProviderService providerService,
GlobalSettings globalSettings, IApplicationCacheService applicationCacheService)
{ {
_providerRepository = providerRepository;
_providerUserRepository = providerUserRepository;
_providerOrganizationRepository = providerOrganizationRepository;
_providerService = providerService;
_globalSettings = globalSettings;
_applicationCacheService = applicationCacheService;
}
public async Task<IActionResult> Index(string name = null, string userEmail = null, int page = 1, int count = 25)
{
if (page < 1)
{
page = 1;
}
if (count < 1)
{
count = 1;
}
var skip = (page - 1) * count;
var providers = await _providerRepository.SearchAsync(name, userEmail, skip, count);
return View(new ProvidersModel
{
Items = providers as List<Provider>,
Name = string.IsNullOrWhiteSpace(name) ? null : name,
UserEmail = string.IsNullOrWhiteSpace(userEmail) ? null : userEmail,
Page = page,
Count = count,
Action = _globalSettings.SelfHosted ? "View" : "Edit",
SelfHosted = _globalSettings.SelfHosted
});
}
public IActionResult Create(string ownerEmail = null)
{
return View(new CreateProviderModel
{
OwnerEmail = ownerEmail
});
}
[HttpPost]
[ValidateAntiForgeryToken]
public async Task<IActionResult> Create(CreateProviderModel model)
{
if (!ModelState.IsValid)
{
return View(model);
}
await _providerService.CreateAsync(model.OwnerEmail);
return RedirectToAction("Index"); return RedirectToAction("Index");
} }
var users = await _providerUserRepository.GetManyDetailsByProviderAsync(id); public async Task<IActionResult> View(Guid id)
var providerOrganizations = await _providerOrganizationRepository.GetManyDetailsByProviderAsync(id);
return View(new ProviderEditModel(provider, users, providerOrganizations));
}
[HttpPost]
[ValidateAntiForgeryToken]
[SelfHosted(NotSelfHostedOnly = true)]
public async Task<IActionResult> Edit(Guid id, ProviderEditModel model)
{
var provider = await _providerRepository.GetByIdAsync(id);
if (provider == null)
{ {
return RedirectToAction("Index"); var provider = await _providerRepository.GetByIdAsync(id);
if (provider == null)
{
return RedirectToAction("Index");
}
var users = await _providerUserRepository.GetManyDetailsByProviderAsync(id);
var providerOrganizations = await _providerOrganizationRepository.GetManyDetailsByProviderAsync(id);
return View(new ProviderViewModel(provider, users, providerOrganizations));
} }
model.ToProvider(provider); [SelfHosted(NotSelfHostedOnly = true)]
await _providerRepository.ReplaceAsync(provider); public async Task<IActionResult> Edit(Guid id)
await _applicationCacheService.UpsertProviderAbilityAsync(provider); {
return RedirectToAction("Edit", new { id }); var provider = await _providerRepository.GetByIdAsync(id);
} if (provider == null)
{
return RedirectToAction("Index");
}
public async Task<IActionResult> ResendInvite(Guid ownerId, Guid providerId) var users = await _providerUserRepository.GetManyDetailsByProviderAsync(id);
{ var providerOrganizations = await _providerOrganizationRepository.GetManyDetailsByProviderAsync(id);
await _providerService.ResendProviderSetupInviteEmailAsync(providerId, ownerId); return View(new ProviderEditModel(provider, users, providerOrganizations));
TempData["InviteResentTo"] = ownerId; }
return RedirectToAction("Edit", new { id = providerId });
[HttpPost]
[ValidateAntiForgeryToken]
[SelfHosted(NotSelfHostedOnly = true)]
public async Task<IActionResult> Edit(Guid id, ProviderEditModel model)
{
var provider = await _providerRepository.GetByIdAsync(id);
if (provider == null)
{
return RedirectToAction("Index");
}
model.ToProvider(provider);
await _providerRepository.ReplaceAsync(provider);
await _applicationCacheService.UpsertProviderAbilityAsync(provider);
return RedirectToAction("Edit", new { id });
}
public async Task<IActionResult> ResendInvite(Guid ownerId, Guid providerId)
{
await _providerService.ResendProviderSetupInviteEmailAsync(providerId, ownerId);
TempData["InviteResentTo"] = ownerId;
return RedirectToAction("Edit", new { id = providerId });
}
} }
} }

File diff suppressed because it is too large Load Diff

View File

@ -7,104 +7,105 @@ using Bit.Core.Utilities;
using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Authorization;
using Microsoft.AspNetCore.Mvc; using Microsoft.AspNetCore.Mvc;
namespace Bit.Admin.Controllers; namespace Bit.Admin.Controllers
[Authorize]
public class UsersController : Controller
{ {
private readonly IUserRepository _userRepository; [Authorize]
private readonly ICipherRepository _cipherRepository; public class UsersController : Controller
private readonly IPaymentService _paymentService;
private readonly GlobalSettings _globalSettings;
public UsersController(
IUserRepository userRepository,
ICipherRepository cipherRepository,
IPaymentService paymentService,
GlobalSettings globalSettings)
{ {
_userRepository = userRepository; private readonly IUserRepository _userRepository;
_cipherRepository = cipherRepository; private readonly ICipherRepository _cipherRepository;
_paymentService = paymentService; private readonly IPaymentService _paymentService;
_globalSettings = globalSettings; private readonly GlobalSettings _globalSettings;
}
public async Task<IActionResult> Index(string email, int page = 1, int count = 25) public UsersController(
{ IUserRepository userRepository,
if (page < 1) ICipherRepository cipherRepository,
IPaymentService paymentService,
GlobalSettings globalSettings)
{ {
page = 1; _userRepository = userRepository;
_cipherRepository = cipherRepository;
_paymentService = paymentService;
_globalSettings = globalSettings;
} }
if (count < 1) public async Task<IActionResult> Index(string email, int page = 1, int count = 25)
{ {
count = 1; if (page < 1)
{
page = 1;
}
if (count < 1)
{
count = 1;
}
var skip = (page - 1) * count;
var users = await _userRepository.SearchAsync(email, skip, count);
return View(new UsersModel
{
Items = users as List<User>,
Email = string.IsNullOrWhiteSpace(email) ? null : email,
Page = page,
Count = count,
Action = _globalSettings.SelfHosted ? "View" : "Edit"
});
} }
var skip = (page - 1) * count; public async Task<IActionResult> View(Guid id)
var users = await _userRepository.SearchAsync(email, skip, count);
return View(new UsersModel
{ {
Items = users as List<User>, var user = await _userRepository.GetByIdAsync(id);
Email = string.IsNullOrWhiteSpace(email) ? null : email, if (user == null)
Page = page, {
Count = count, return RedirectToAction("Index");
Action = _globalSettings.SelfHosted ? "View" : "Edit" }
});
}
public async Task<IActionResult> View(Guid id) var ciphers = await _cipherRepository.GetManyByUserIdAsync(id);
{ return View(new UserViewModel(user, ciphers));
var user = await _userRepository.GetByIdAsync(id); }
if (user == null)
[SelfHosted(NotSelfHostedOnly = true)]
public async Task<IActionResult> Edit(Guid id)
{ {
var user = await _userRepository.GetByIdAsync(id);
if (user == null)
{
return RedirectToAction("Index");
}
var ciphers = await _cipherRepository.GetManyByUserIdAsync(id);
var billingInfo = await _paymentService.GetBillingAsync(user);
return View(new UserEditModel(user, ciphers, billingInfo, _globalSettings));
}
[HttpPost]
[ValidateAntiForgeryToken]
[SelfHosted(NotSelfHostedOnly = true)]
public async Task<IActionResult> Edit(Guid id, UserEditModel model)
{
var user = await _userRepository.GetByIdAsync(id);
if (user == null)
{
return RedirectToAction("Index");
}
model.ToUser(user);
await _userRepository.ReplaceAsync(user);
return RedirectToAction("Edit", new { id });
}
[HttpPost]
[ValidateAntiForgeryToken]
public async Task<IActionResult> Delete(Guid id)
{
var user = await _userRepository.GetByIdAsync(id);
if (user != null)
{
await _userRepository.DeleteAsync(user);
}
return RedirectToAction("Index"); return RedirectToAction("Index");
} }
var ciphers = await _cipherRepository.GetManyByUserIdAsync(id);
return View(new UserViewModel(user, ciphers));
}
[SelfHosted(NotSelfHostedOnly = true)]
public async Task<IActionResult> Edit(Guid id)
{
var user = await _userRepository.GetByIdAsync(id);
if (user == null)
{
return RedirectToAction("Index");
}
var ciphers = await _cipherRepository.GetManyByUserIdAsync(id);
var billingInfo = await _paymentService.GetBillingAsync(user);
return View(new UserEditModel(user, ciphers, billingInfo, _globalSettings));
}
[HttpPost]
[ValidateAntiForgeryToken]
[SelfHosted(NotSelfHostedOnly = true)]
public async Task<IActionResult> Edit(Guid id, UserEditModel model)
{
var user = await _userRepository.GetByIdAsync(id);
if (user == null)
{
return RedirectToAction("Index");
}
model.ToUser(user);
await _userRepository.ReplaceAsync(user);
return RedirectToAction("Edit", new { id });
}
[HttpPost]
[ValidateAntiForgeryToken]
public async Task<IActionResult> Delete(Guid id)
{
var user = await _userRepository.GetByIdAsync(id);
if (user != null)
{
await _userRepository.DeleteAsync(user);
}
return RedirectToAction("Index");
} }
} }

View File

@ -4,80 +4,81 @@ using Amazon.SQS.Model;
using Bit.Core.Settings; using Bit.Core.Settings;
using Microsoft.Extensions.Options; using Microsoft.Extensions.Options;
namespace Bit.Admin.HostedServices; namespace Bit.Admin.HostedServices
public class AmazonSqsBlockIpHostedService : BlockIpHostedService
{ {
private AmazonSQSClient _client; public class AmazonSqsBlockIpHostedService : BlockIpHostedService
public AmazonSqsBlockIpHostedService(
ILogger<AmazonSqsBlockIpHostedService> logger,
IOptions<AdminSettings> adminSettings,
GlobalSettings globalSettings)
: base(logger, adminSettings, globalSettings)
{ }
public override void Dispose()
{ {
_client?.Dispose(); private AmazonSQSClient _client;
}
protected override async Task ExecuteAsync(CancellationToken cancellationToken) public AmazonSqsBlockIpHostedService(
{ ILogger<AmazonSqsBlockIpHostedService> logger,
_client = new AmazonSQSClient(_globalSettings.Amazon.AccessKeyId, IOptions<AdminSettings> adminSettings,
_globalSettings.Amazon.AccessKeySecret, RegionEndpoint.GetBySystemName(_globalSettings.Amazon.Region)); GlobalSettings globalSettings)
var blockIpQueue = await _client.GetQueueUrlAsync("block-ip", cancellationToken); : base(logger, adminSettings, globalSettings)
var blockIpQueueUrl = blockIpQueue.QueueUrl; { }
var unblockIpQueue = await _client.GetQueueUrlAsync("unblock-ip", cancellationToken);
var unblockIpQueueUrl = unblockIpQueue.QueueUrl;
while (!cancellationToken.IsCancellationRequested) public override void Dispose()
{ {
var blockMessageResponse = await _client.ReceiveMessageAsync(new ReceiveMessageRequest _client?.Dispose();
{ }
QueueUrl = blockIpQueueUrl,
MaxNumberOfMessages = 10,
WaitTimeSeconds = 15
}, cancellationToken);
if (blockMessageResponse.Messages.Any())
{
foreach (var message in blockMessageResponse.Messages)
{
try
{
await BlockIpAsync(message.Body, cancellationToken);
}
catch (Exception e)
{
_logger.LogError(e, "Failed to block IP.");
}
await _client.DeleteMessageAsync(blockIpQueueUrl, message.ReceiptHandle, cancellationToken);
}
}
var unblockMessageResponse = await _client.ReceiveMessageAsync(new ReceiveMessageRequest protected override async Task ExecuteAsync(CancellationToken cancellationToken)
{ {
QueueUrl = unblockIpQueueUrl, _client = new AmazonSQSClient(_globalSettings.Amazon.AccessKeyId,
MaxNumberOfMessages = 10, _globalSettings.Amazon.AccessKeySecret, RegionEndpoint.GetBySystemName(_globalSettings.Amazon.Region));
WaitTimeSeconds = 15 var blockIpQueue = await _client.GetQueueUrlAsync("block-ip", cancellationToken);
}, cancellationToken); var blockIpQueueUrl = blockIpQueue.QueueUrl;
if (unblockMessageResponse.Messages.Any()) var unblockIpQueue = await _client.GetQueueUrlAsync("unblock-ip", cancellationToken);
{ var unblockIpQueueUrl = unblockIpQueue.QueueUrl;
foreach (var message in unblockMessageResponse.Messages)
{
try
{
await UnblockIpAsync(message.Body, cancellationToken);
}
catch (Exception e)
{
_logger.LogError(e, "Failed to unblock IP.");
}
await _client.DeleteMessageAsync(unblockIpQueueUrl, message.ReceiptHandle, cancellationToken);
}
}
await Task.Delay(TimeSpan.FromSeconds(15)); while (!cancellationToken.IsCancellationRequested)
{
var blockMessageResponse = await _client.ReceiveMessageAsync(new ReceiveMessageRequest
{
QueueUrl = blockIpQueueUrl,
MaxNumberOfMessages = 10,
WaitTimeSeconds = 15
}, cancellationToken);
if (blockMessageResponse.Messages.Any())
{
foreach (var message in blockMessageResponse.Messages)
{
try
{
await BlockIpAsync(message.Body, cancellationToken);
}
catch (Exception e)
{
_logger.LogError(e, "Failed to block IP.");
}
await _client.DeleteMessageAsync(blockIpQueueUrl, message.ReceiptHandle, cancellationToken);
}
}
var unblockMessageResponse = await _client.ReceiveMessageAsync(new ReceiveMessageRequest
{
QueueUrl = unblockIpQueueUrl,
MaxNumberOfMessages = 10,
WaitTimeSeconds = 15
}, cancellationToken);
if (unblockMessageResponse.Messages.Any())
{
foreach (var message in unblockMessageResponse.Messages)
{
try
{
await UnblockIpAsync(message.Body, cancellationToken);
}
catch (Exception e)
{
_logger.LogError(e, "Failed to unblock IP.");
}
await _client.DeleteMessageAsync(unblockIpQueueUrl, message.ReceiptHandle, cancellationToken);
}
}
await Task.Delay(TimeSpan.FromSeconds(15));
}
} }
} }
} }

View File

@ -2,62 +2,63 @@
using Bit.Core.Settings; using Bit.Core.Settings;
using Microsoft.Extensions.Options; using Microsoft.Extensions.Options;
namespace Bit.Admin.HostedServices; namespace Bit.Admin.HostedServices
public class AzureQueueBlockIpHostedService : BlockIpHostedService
{ {
private QueueClient _blockIpQueueClient; public class AzureQueueBlockIpHostedService : BlockIpHostedService
private QueueClient _unblockIpQueueClient;
public AzureQueueBlockIpHostedService(
ILogger<AzureQueueBlockIpHostedService> logger,
IOptions<AdminSettings> adminSettings,
GlobalSettings globalSettings)
: base(logger, adminSettings, globalSettings)
{ }
protected override async Task ExecuteAsync(CancellationToken cancellationToken)
{ {
_blockIpQueueClient = new QueueClient(_globalSettings.Storage.ConnectionString, "blockip"); private QueueClient _blockIpQueueClient;
_unblockIpQueueClient = new QueueClient(_globalSettings.Storage.ConnectionString, "unblockip"); private QueueClient _unblockIpQueueClient;
while (!cancellationToken.IsCancellationRequested) public AzureQueueBlockIpHostedService(
ILogger<AzureQueueBlockIpHostedService> logger,
IOptions<AdminSettings> adminSettings,
GlobalSettings globalSettings)
: base(logger, adminSettings, globalSettings)
{ }
protected override async Task ExecuteAsync(CancellationToken cancellationToken)
{ {
var blockMessages = await _blockIpQueueClient.ReceiveMessagesAsync(maxMessages: 32); _blockIpQueueClient = new QueueClient(_globalSettings.Storage.ConnectionString, "blockip");
if (blockMessages.Value?.Any() ?? false) _unblockIpQueueClient = new QueueClient(_globalSettings.Storage.ConnectionString, "unblockip");
{
foreach (var message in blockMessages.Value)
{
try
{
await BlockIpAsync(message.MessageText, cancellationToken);
}
catch (Exception e)
{
_logger.LogError(e, "Failed to block IP.");
}
await _blockIpQueueClient.DeleteMessageAsync(message.MessageId, message.PopReceipt);
}
}
var unblockMessages = await _unblockIpQueueClient.ReceiveMessagesAsync(maxMessages: 32); while (!cancellationToken.IsCancellationRequested)
if (unblockMessages.Value?.Any() ?? false)
{ {
foreach (var message in unblockMessages.Value) var blockMessages = await _blockIpQueueClient.ReceiveMessagesAsync(maxMessages: 32);
if (blockMessages.Value?.Any() ?? false)
{ {
try foreach (var message in blockMessages.Value)
{ {
await UnblockIpAsync(message.MessageText, cancellationToken); try
{
await BlockIpAsync(message.MessageText, cancellationToken);
}
catch (Exception e)
{
_logger.LogError(e, "Failed to block IP.");
}
await _blockIpQueueClient.DeleteMessageAsync(message.MessageId, message.PopReceipt);
} }
catch (Exception e)
{
_logger.LogError(e, "Failed to unblock IP.");
}
await _unblockIpQueueClient.DeleteMessageAsync(message.MessageId, message.PopReceipt);
} }
}
await Task.Delay(TimeSpan.FromSeconds(15)); var unblockMessages = await _unblockIpQueueClient.ReceiveMessagesAsync(maxMessages: 32);
if (unblockMessages.Value?.Any() ?? false)
{
foreach (var message in unblockMessages.Value)
{
try
{
await UnblockIpAsync(message.MessageText, cancellationToken);
}
catch (Exception e)
{
_logger.LogError(e, "Failed to unblock IP.");
}
await _unblockIpQueueClient.DeleteMessageAsync(message.MessageId, message.PopReceipt);
}
}
await Task.Delay(TimeSpan.FromSeconds(15));
}
} }
} }
} }

View File

@ -6,96 +6,97 @@ using Bit.Core.Services;
using Bit.Core.Settings; using Bit.Core.Settings;
using Bit.Core.Utilities; using Bit.Core.Utilities;
namespace Bit.Admin.HostedServices; namespace Bit.Admin.HostedServices
public class AzureQueueMailHostedService : IHostedService
{ {
private readonly ILogger<AzureQueueMailHostedService> _logger; public class AzureQueueMailHostedService : IHostedService
private readonly GlobalSettings _globalSettings;
private readonly IMailService _mailService;
private CancellationTokenSource _cts;
private Task _executingTask;
private QueueClient _mailQueueClient;
public AzureQueueMailHostedService(
ILogger<AzureQueueMailHostedService> logger,
IMailService mailService,
GlobalSettings globalSettings)
{ {
_logger = logger; private readonly ILogger<AzureQueueMailHostedService> _logger;
_mailService = mailService; private readonly GlobalSettings _globalSettings;
_globalSettings = globalSettings; private readonly IMailService _mailService;
} private CancellationTokenSource _cts;
private Task _executingTask;
public Task StartAsync(CancellationToken cancellationToken) private QueueClient _mailQueueClient;
{
_cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
_executingTask = ExecuteAsync(_cts.Token);
return _executingTask.IsCompleted ? _executingTask : Task.CompletedTask;
}
public async Task StopAsync(CancellationToken cancellationToken) public AzureQueueMailHostedService(
{ ILogger<AzureQueueMailHostedService> logger,
if (_executingTask == null) IMailService mailService,
GlobalSettings globalSettings)
{ {
return; _logger = logger;
_mailService = mailService;
_globalSettings = globalSettings;
} }
_cts.Cancel();
await Task.WhenAny(_executingTask, Task.Delay(-1, cancellationToken));
cancellationToken.ThrowIfCancellationRequested();
}
private async Task ExecuteAsync(CancellationToken cancellationToken) public Task StartAsync(CancellationToken cancellationToken)
{
_mailQueueClient = new QueueClient(_globalSettings.Mail.ConnectionString, "mail");
QueueMessage[] mailMessages;
while (!cancellationToken.IsCancellationRequested)
{ {
if (!(mailMessages = await RetrieveMessagesAsync()).Any()) _cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
_executingTask = ExecuteAsync(_cts.Token);
return _executingTask.IsCompleted ? _executingTask : Task.CompletedTask;
}
public async Task StopAsync(CancellationToken cancellationToken)
{
if (_executingTask == null)
{ {
await Task.Delay(TimeSpan.FromSeconds(15)); return;
} }
_cts.Cancel();
await Task.WhenAny(_executingTask, Task.Delay(-1, cancellationToken));
cancellationToken.ThrowIfCancellationRequested();
}
foreach (var message in mailMessages) private async Task ExecuteAsync(CancellationToken cancellationToken)
{
_mailQueueClient = new QueueClient(_globalSettings.Mail.ConnectionString, "mail");
QueueMessage[] mailMessages;
while (!cancellationToken.IsCancellationRequested)
{ {
try if (!(mailMessages = await RetrieveMessagesAsync()).Any())
{ {
using var document = JsonDocument.Parse(message.DecodeMessageText()); await Task.Delay(TimeSpan.FromSeconds(15));
var root = document.RootElement; }
if (root.ValueKind == JsonValueKind.Array) foreach (var message in mailMessages)
{
try
{ {
foreach (var mailQueueMessage in root.ToObject<List<MailQueueMessage>>()) using var document = JsonDocument.Parse(message.DecodeMessageText());
var root = document.RootElement;
if (root.ValueKind == JsonValueKind.Array)
{ {
foreach (var mailQueueMessage in root.ToObject<List<MailQueueMessage>>())
{
await _mailService.SendEnqueuedMailMessageAsync(mailQueueMessage);
}
}
else if (root.ValueKind == JsonValueKind.Object)
{
var mailQueueMessage = root.ToObject<MailQueueMessage>();
await _mailService.SendEnqueuedMailMessageAsync(mailQueueMessage); await _mailService.SendEnqueuedMailMessageAsync(mailQueueMessage);
} }
} }
else if (root.ValueKind == JsonValueKind.Object) catch (Exception e)
{ {
var mailQueueMessage = root.ToObject<MailQueueMessage>(); _logger.LogError(e, "Failed to send email");
await _mailService.SendEnqueuedMailMessageAsync(mailQueueMessage); // TODO: retries?
} }
}
catch (Exception e)
{
_logger.LogError(e, "Failed to send email");
// TODO: retries?
}
await _mailQueueClient.DeleteMessageAsync(message.MessageId, message.PopReceipt); await _mailQueueClient.DeleteMessageAsync(message.MessageId, message.PopReceipt);
if (cancellationToken.IsCancellationRequested) if (cancellationToken.IsCancellationRequested)
{ {
break; break;
}
} }
} }
} }
}
private async Task<QueueMessage[]> RetrieveMessagesAsync() private async Task<QueueMessage[]> RetrieveMessagesAsync()
{ {
return (await _mailQueueClient.ReceiveMessagesAsync(maxMessages: 32))?.Value ?? new QueueMessage[] { }; return (await _mailQueueClient.ReceiveMessagesAsync(maxMessages: 32))?.Value ?? new QueueMessage[] { };
}
} }
} }

View File

@ -1,105 +1,71 @@
using Bit.Core.Settings; using Bit.Core.Settings;
using Microsoft.Extensions.Options; using Microsoft.Extensions.Options;
namespace Bit.Admin.HostedServices; namespace Bit.Admin.HostedServices
public abstract class BlockIpHostedService : IHostedService, IDisposable
{ {
protected readonly ILogger<BlockIpHostedService> _logger; public abstract class BlockIpHostedService : IHostedService, IDisposable
protected readonly GlobalSettings _globalSettings;
private readonly AdminSettings _adminSettings;
private Task _executingTask;
private CancellationTokenSource _cts;
private HttpClient _httpClient = new HttpClient();
public BlockIpHostedService(
ILogger<BlockIpHostedService> logger,
IOptions<AdminSettings> adminSettings,
GlobalSettings globalSettings)
{ {
_logger = logger; protected readonly ILogger<BlockIpHostedService> _logger;
_globalSettings = globalSettings; protected readonly GlobalSettings _globalSettings;
_adminSettings = adminSettings?.Value; private readonly AdminSettings _adminSettings;
}
public Task StartAsync(CancellationToken cancellationToken) private Task _executingTask;
{ private CancellationTokenSource _cts;
_cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); private HttpClient _httpClient = new HttpClient();
_executingTask = ExecuteAsync(_cts.Token);
return _executingTask.IsCompleted ? _executingTask : Task.CompletedTask;
}
public async Task StopAsync(CancellationToken cancellationToken) public BlockIpHostedService(
{ ILogger<BlockIpHostedService> logger,
if (_executingTask == null) IOptions<AdminSettings> adminSettings,
GlobalSettings globalSettings)
{ {
return; _logger = logger;
_globalSettings = globalSettings;
_adminSettings = adminSettings?.Value;
} }
_cts.Cancel();
await Task.WhenAny(_executingTask, Task.Delay(-1, cancellationToken));
cancellationToken.ThrowIfCancellationRequested();
}
public virtual void Dispose() public Task StartAsync(CancellationToken cancellationToken)
{ }
protected abstract Task ExecuteAsync(CancellationToken cancellationToken);
protected async Task BlockIpAsync(string message, CancellationToken cancellationToken)
{
var request = new HttpRequestMessage();
request.Headers.Accept.Clear();
request.Headers.Add("X-Auth-Email", _adminSettings.Cloudflare.AuthEmail);
request.Headers.Add("X-Auth-Key", _adminSettings.Cloudflare.AuthKey);
request.Method = HttpMethod.Post;
request.RequestUri = new Uri("https://api.cloudflare.com/" +
$"client/v4/zones/{_adminSettings.Cloudflare.ZoneId}/firewall/access_rules/rules");
request.Content = JsonContent.Create(new
{ {
mode = "block", _cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
configuration = new _executingTask = ExecuteAsync(_cts.Token);
return _executingTask.IsCompleted ? _executingTask : Task.CompletedTask;
}
public async Task StopAsync(CancellationToken cancellationToken)
{
if (_executingTask == null)
{ {
target = "ip", return;
value = message }
}, _cts.Cancel();
notes = $"Rate limit abuse on {DateTime.UtcNow.ToString()}." await Task.WhenAny(_executingTask, Task.Delay(-1, cancellationToken));
}); cancellationToken.ThrowIfCancellationRequested();
var response = await _httpClient.SendAsync(request, cancellationToken);
if (!response.IsSuccessStatusCode)
{
return;
} }
var accessRuleResponse = await response.Content.ReadFromJsonAsync<AccessRuleResponse>(cancellationToken: cancellationToken); public virtual void Dispose()
if (!accessRuleResponse.Success) { }
{
return;
}
// TODO: Send `accessRuleResponse.Result?.Id` message to unblock queue protected abstract Task ExecuteAsync(CancellationToken cancellationToken);
}
protected async Task UnblockIpAsync(string message, CancellationToken cancellationToken) protected async Task BlockIpAsync(string message, CancellationToken cancellationToken)
{
if (string.IsNullOrWhiteSpace(message))
{ {
return;
}
if (message.Contains(".") || message.Contains(":"))
{
// IP address messages
var request = new HttpRequestMessage(); var request = new HttpRequestMessage();
request.Headers.Accept.Clear(); request.Headers.Accept.Clear();
request.Headers.Add("X-Auth-Email", _adminSettings.Cloudflare.AuthEmail); request.Headers.Add("X-Auth-Email", _adminSettings.Cloudflare.AuthEmail);
request.Headers.Add("X-Auth-Key", _adminSettings.Cloudflare.AuthKey); request.Headers.Add("X-Auth-Key", _adminSettings.Cloudflare.AuthKey);
request.Method = HttpMethod.Get; request.Method = HttpMethod.Post;
request.RequestUri = new Uri("https://api.cloudflare.com/" + request.RequestUri = new Uri("https://api.cloudflare.com/" +
$"client/v4/zones/{_adminSettings.Cloudflare.ZoneId}/firewall/access_rules/rules?" + $"client/v4/zones/{_adminSettings.Cloudflare.ZoneId}/firewall/access_rules/rules");
$"configuration_target=ip&configuration_value={message}");
request.Content = JsonContent.Create(new
{
mode = "block",
configuration = new
{
target = "ip",
value = message
},
notes = $"Rate limit abuse on {DateTime.UtcNow.ToString()}."
});
var response = await _httpClient.SendAsync(request, cancellationToken); var response = await _httpClient.SendAsync(request, cancellationToken);
if (!response.IsSuccessStatusCode) if (!response.IsSuccessStatusCode)
@ -107,58 +73,93 @@ public abstract class BlockIpHostedService : IHostedService, IDisposable
return; return;
} }
var listResponse = await response.Content.ReadFromJsonAsync<ListResponse>(cancellationToken: cancellationToken); var accessRuleResponse = await response.Content.ReadFromJsonAsync<AccessRuleResponse>(cancellationToken: cancellationToken);
if (!listResponse.Success) if (!accessRuleResponse.Success)
{ {
return; return;
} }
foreach (var rule in listResponse.Result) // TODO: Send `accessRuleResponse.Result?.Id` message to unblock queue
}
protected async Task UnblockIpAsync(string message, CancellationToken cancellationToken)
{
if (string.IsNullOrWhiteSpace(message))
{ {
await DeleteAccessRuleAsync(rule.Id, cancellationToken); return;
}
if (message.Contains(".") || message.Contains(":"))
{
// IP address messages
var request = new HttpRequestMessage();
request.Headers.Accept.Clear();
request.Headers.Add("X-Auth-Email", _adminSettings.Cloudflare.AuthEmail);
request.Headers.Add("X-Auth-Key", _adminSettings.Cloudflare.AuthKey);
request.Method = HttpMethod.Get;
request.RequestUri = new Uri("https://api.cloudflare.com/" +
$"client/v4/zones/{_adminSettings.Cloudflare.ZoneId}/firewall/access_rules/rules?" +
$"configuration_target=ip&configuration_value={message}");
var response = await _httpClient.SendAsync(request, cancellationToken);
if (!response.IsSuccessStatusCode)
{
return;
}
var listResponse = await response.Content.ReadFromJsonAsync<ListResponse>(cancellationToken: cancellationToken);
if (!listResponse.Success)
{
return;
}
foreach (var rule in listResponse.Result)
{
await DeleteAccessRuleAsync(rule.Id, cancellationToken);
}
}
else
{
// Rule Id messages
await DeleteAccessRuleAsync(message, cancellationToken);
} }
} }
else
protected async Task DeleteAccessRuleAsync(string ruleId, CancellationToken cancellationToken)
{ {
// Rule Id messages var request = new HttpRequestMessage();
await DeleteAccessRuleAsync(message, cancellationToken); request.Headers.Accept.Clear();
request.Headers.Add("X-Auth-Email", _adminSettings.Cloudflare.AuthEmail);
request.Headers.Add("X-Auth-Key", _adminSettings.Cloudflare.AuthKey);
request.Method = HttpMethod.Delete;
request.RequestUri = new Uri("https://api.cloudflare.com/" +
$"client/v4/zones/{_adminSettings.Cloudflare.ZoneId}/firewall/access_rules/rules/{ruleId}");
await _httpClient.SendAsync(request, cancellationToken);
} }
}
protected async Task DeleteAccessRuleAsync(string ruleId, CancellationToken cancellationToken) public class ListResponse
{
var request = new HttpRequestMessage();
request.Headers.Accept.Clear();
request.Headers.Add("X-Auth-Email", _adminSettings.Cloudflare.AuthEmail);
request.Headers.Add("X-Auth-Key", _adminSettings.Cloudflare.AuthKey);
request.Method = HttpMethod.Delete;
request.RequestUri = new Uri("https://api.cloudflare.com/" +
$"client/v4/zones/{_adminSettings.Cloudflare.ZoneId}/firewall/access_rules/rules/{ruleId}");
await _httpClient.SendAsync(request, cancellationToken);
}
public class ListResponse
{
public bool Success { get; set; }
public List<AccessRuleResultResponse> Result { get; set; }
}
public class AccessRuleResponse
{
public bool Success { get; set; }
public AccessRuleResultResponse Result { get; set; }
}
public class AccessRuleResultResponse
{
public string Id { get; set; }
public string Notes { get; set; }
public ConfigurationResponse Configuration { get; set; }
public class ConfigurationResponse
{ {
public string Target { get; set; } public bool Success { get; set; }
public string Value { get; set; } public List<AccessRuleResultResponse> Result { get; set; }
}
public class AccessRuleResponse
{
public bool Success { get; set; }
public AccessRuleResultResponse Result { get; set; }
}
public class AccessRuleResultResponse
{
public string Id { get; set; }
public string Notes { get; set; }
public ConfigurationResponse Configuration { get; set; }
public class ConfigurationResponse
{
public string Target { get; set; }
public string Value { get; set; }
}
} }
} }
} }

View File

@ -3,61 +3,62 @@ using Bit.Core.Jobs;
using Bit.Core.Settings; using Bit.Core.Settings;
using Bit.Migrator; using Bit.Migrator;
namespace Bit.Admin.HostedServices; namespace Bit.Admin.HostedServices
public class DatabaseMigrationHostedService : IHostedService, IDisposable
{ {
private readonly GlobalSettings _globalSettings; public class DatabaseMigrationHostedService : IHostedService, IDisposable
private readonly ILogger<DatabaseMigrationHostedService> _logger;
private readonly DbMigrator _dbMigrator;
public DatabaseMigrationHostedService(
GlobalSettings globalSettings,
ILogger<DatabaseMigrationHostedService> logger,
ILogger<DbMigrator> migratorLogger,
ILogger<JobListener> listenerLogger)
{ {
_globalSettings = globalSettings; private readonly GlobalSettings _globalSettings;
_logger = logger; private readonly ILogger<DatabaseMigrationHostedService> _logger;
_dbMigrator = new DbMigrator(globalSettings.SqlServer.ConnectionString, migratorLogger); private readonly DbMigrator _dbMigrator;
}
public virtual async Task StartAsync(CancellationToken cancellationToken) public DatabaseMigrationHostedService(
{ GlobalSettings globalSettings,
// Wait 20 seconds to allow database to come online ILogger<DatabaseMigrationHostedService> logger,
await Task.Delay(20000); ILogger<DbMigrator> migratorLogger,
ILogger<JobListener> listenerLogger)
var maxMigrationAttempts = 10;
for (var i = 1; i <= maxMigrationAttempts; i++)
{ {
try _globalSettings = globalSettings;
_logger = logger;
_dbMigrator = new DbMigrator(globalSettings.SqlServer.ConnectionString, migratorLogger);
}
public virtual async Task StartAsync(CancellationToken cancellationToken)
{
// Wait 20 seconds to allow database to come online
await Task.Delay(20000);
var maxMigrationAttempts = 10;
for (var i = 1; i <= maxMigrationAttempts; i++)
{ {
_dbMigrator.MigrateMsSqlDatabase(true, cancellationToken); try
// TODO: Maybe flip a flag somewhere to indicate migration is complete??
break;
}
catch (SqlException e)
{
if (i >= maxMigrationAttempts)
{ {
_logger.LogError(e, "Database failed to migrate."); _dbMigrator.MigrateMsSqlDatabase(true, cancellationToken);
throw; // TODO: Maybe flip a flag somewhere to indicate migration is complete??
break;
} }
else catch (SqlException e)
{ {
_logger.LogError(e, if (i >= maxMigrationAttempts)
"Database unavailable for migration. Trying again (attempt #{0})...", i + 1); {
await Task.Delay(20000); _logger.LogError(e, "Database failed to migrate.");
throw;
}
else
{
_logger.LogError(e,
"Database unavailable for migration. Trying again (attempt #{0})...", i + 1);
await Task.Delay(20000);
}
} }
} }
} }
}
public virtual Task StopAsync(CancellationToken cancellationToken) public virtual Task StopAsync(CancellationToken cancellationToken)
{ {
return Task.FromResult(0); return Task.FromResult(0);
} }
public virtual void Dispose() public virtual void Dispose()
{ } { }
}
} }

View File

@ -3,26 +3,27 @@ using Bit.Core.Jobs;
using Bit.Core.Settings; using Bit.Core.Settings;
using Quartz; using Quartz;
namespace Bit.Admin.Jobs; namespace Bit.Admin.Jobs
public class AliveJob : BaseJob
{ {
private readonly GlobalSettings _globalSettings; public class AliveJob : BaseJob
private HttpClient _httpClient = new HttpClient();
public AliveJob(
GlobalSettings globalSettings,
ILogger<AliveJob> logger)
: base(logger)
{ {
_globalSettings = globalSettings; private readonly GlobalSettings _globalSettings;
} private HttpClient _httpClient = new HttpClient();
protected async override Task ExecuteJobAsync(IJobExecutionContext context) public AliveJob(
{ GlobalSettings globalSettings,
_logger.LogInformation(Constants.BypassFiltersEventId, "Execute job task: Keep alive"); ILogger<AliveJob> logger)
var response = await _httpClient.GetAsync(_globalSettings.BaseServiceUri.Admin); : base(logger)
_logger.LogInformation(Constants.BypassFiltersEventId, "Finished job task: Keep alive, " + {
response.StatusCode); _globalSettings = globalSettings;
}
protected async override Task ExecuteJobAsync(IJobExecutionContext context)
{
_logger.LogInformation(Constants.BypassFiltersEventId, "Execute job task: Keep alive");
var response = await _httpClient.GetAsync(_globalSettings.BaseServiceUri.Admin);
_logger.LogInformation(Constants.BypassFiltersEventId, "Finished job task: Keep alive, " +
response.StatusCode);
}
} }
} }

View File

@ -3,24 +3,25 @@ using Bit.Core.Jobs;
using Bit.Core.Repositories; using Bit.Core.Repositories;
using Quartz; using Quartz;
namespace Bit.Admin.Jobs; namespace Bit.Admin.Jobs
public class DatabaseExpiredGrantsJob : BaseJob
{ {
private readonly IMaintenanceRepository _maintenanceRepository; public class DatabaseExpiredGrantsJob : BaseJob
public DatabaseExpiredGrantsJob(
IMaintenanceRepository maintenanceRepository,
ILogger<DatabaseExpiredGrantsJob> logger)
: base(logger)
{ {
_maintenanceRepository = maintenanceRepository; private readonly IMaintenanceRepository _maintenanceRepository;
}
protected async override Task ExecuteJobAsync(IJobExecutionContext context) public DatabaseExpiredGrantsJob(
{ IMaintenanceRepository maintenanceRepository,
_logger.LogInformation(Constants.BypassFiltersEventId, "Execute job task: DeleteExpiredGrantsAsync"); ILogger<DatabaseExpiredGrantsJob> logger)
await _maintenanceRepository.DeleteExpiredGrantsAsync(); : base(logger)
_logger.LogInformation(Constants.BypassFiltersEventId, "Finished job task: DeleteExpiredGrantsAsync"); {
_maintenanceRepository = maintenanceRepository;
}
protected async override Task ExecuteJobAsync(IJobExecutionContext context)
{
_logger.LogInformation(Constants.BypassFiltersEventId, "Execute job task: DeleteExpiredGrantsAsync");
await _maintenanceRepository.DeleteExpiredGrantsAsync();
_logger.LogInformation(Constants.BypassFiltersEventId, "Finished job task: DeleteExpiredGrantsAsync");
}
} }
} }

View File

@ -4,35 +4,36 @@ using Bit.Core.Repositories;
using Bit.Core.Settings; using Bit.Core.Settings;
using Quartz; using Quartz;
namespace Bit.Admin.Jobs; namespace Bit.Admin.Jobs
public class DatabaseExpiredSponsorshipsJob : BaseJob
{ {
private GlobalSettings _globalSettings; public class DatabaseExpiredSponsorshipsJob : BaseJob
private readonly IMaintenanceRepository _maintenanceRepository;
public DatabaseExpiredSponsorshipsJob(
IMaintenanceRepository maintenanceRepository,
ILogger<DatabaseExpiredSponsorshipsJob> logger,
GlobalSettings globalSettings)
: base(logger)
{ {
_maintenanceRepository = maintenanceRepository; private GlobalSettings _globalSettings;
_globalSettings = globalSettings; private readonly IMaintenanceRepository _maintenanceRepository;
}
protected override async Task ExecuteJobAsync(IJobExecutionContext context) public DatabaseExpiredSponsorshipsJob(
{ IMaintenanceRepository maintenanceRepository,
if (_globalSettings.SelfHosted && !_globalSettings.EnableCloudCommunication) ILogger<DatabaseExpiredSponsorshipsJob> logger,
GlobalSettings globalSettings)
: base(logger)
{ {
return; _maintenanceRepository = maintenanceRepository;
_globalSettings = globalSettings;
} }
_logger.LogInformation(Constants.BypassFiltersEventId, "Execute job task: DeleteExpiredSponsorshipsAsync");
// allow a 90 day grace period before deleting protected override async Task ExecuteJobAsync(IJobExecutionContext context)
var deleteDate = DateTime.UtcNow.AddDays(-90); {
if (_globalSettings.SelfHosted && !_globalSettings.EnableCloudCommunication)
{
return;
}
_logger.LogInformation(Constants.BypassFiltersEventId, "Execute job task: DeleteExpiredSponsorshipsAsync");
await _maintenanceRepository.DeleteExpiredSponsorshipsAsync(deleteDate); // allow a 90 day grace period before deleting
_logger.LogInformation(Constants.BypassFiltersEventId, "Finished job task: DeleteExpiredSponsorshipsAsync"); var deleteDate = DateTime.UtcNow.AddDays(-90);
await _maintenanceRepository.DeleteExpiredSponsorshipsAsync(deleteDate);
_logger.LogInformation(Constants.BypassFiltersEventId, "Finished job task: DeleteExpiredSponsorshipsAsync");
}
} }
} }

View File

@ -3,24 +3,25 @@ using Bit.Core.Jobs;
using Bit.Core.Repositories; using Bit.Core.Repositories;
using Quartz; using Quartz;
namespace Bit.Admin.Jobs; namespace Bit.Admin.Jobs
public class DatabaseRebuildlIndexesJob : BaseJob
{ {
private readonly IMaintenanceRepository _maintenanceRepository; public class DatabaseRebuildlIndexesJob : BaseJob
public DatabaseRebuildlIndexesJob(
IMaintenanceRepository maintenanceRepository,
ILogger<DatabaseRebuildlIndexesJob> logger)
: base(logger)
{ {
_maintenanceRepository = maintenanceRepository; private readonly IMaintenanceRepository _maintenanceRepository;
}
protected async override Task ExecuteJobAsync(IJobExecutionContext context) public DatabaseRebuildlIndexesJob(
{ IMaintenanceRepository maintenanceRepository,
_logger.LogInformation(Constants.BypassFiltersEventId, "Execute job task: RebuildIndexesAsync"); ILogger<DatabaseRebuildlIndexesJob> logger)
await _maintenanceRepository.RebuildIndexesAsync(); : base(logger)
_logger.LogInformation(Constants.BypassFiltersEventId, "Finished job task: RebuildIndexesAsync"); {
_maintenanceRepository = maintenanceRepository;
}
protected async override Task ExecuteJobAsync(IJobExecutionContext context)
{
_logger.LogInformation(Constants.BypassFiltersEventId, "Execute job task: RebuildIndexesAsync");
await _maintenanceRepository.RebuildIndexesAsync();
_logger.LogInformation(Constants.BypassFiltersEventId, "Finished job task: RebuildIndexesAsync");
}
} }
} }

View File

@ -3,27 +3,28 @@ using Bit.Core.Jobs;
using Bit.Core.Repositories; using Bit.Core.Repositories;
using Quartz; using Quartz;
namespace Bit.Admin.Jobs; namespace Bit.Admin.Jobs
public class DatabaseUpdateStatisticsJob : BaseJob
{ {
private readonly IMaintenanceRepository _maintenanceRepository; public class DatabaseUpdateStatisticsJob : BaseJob
public DatabaseUpdateStatisticsJob(
IMaintenanceRepository maintenanceRepository,
ILogger<DatabaseUpdateStatisticsJob> logger)
: base(logger)
{ {
_maintenanceRepository = maintenanceRepository; private readonly IMaintenanceRepository _maintenanceRepository;
}
protected async override Task ExecuteJobAsync(IJobExecutionContext context) public DatabaseUpdateStatisticsJob(
{ IMaintenanceRepository maintenanceRepository,
_logger.LogInformation(Constants.BypassFiltersEventId, "Execute job task: UpdateStatisticsAsync"); ILogger<DatabaseUpdateStatisticsJob> logger)
await _maintenanceRepository.UpdateStatisticsAsync(); : base(logger)
_logger.LogInformation(Constants.BypassFiltersEventId, "Finished job task: UpdateStatisticsAsync"); {
_logger.LogInformation(Constants.BypassFiltersEventId, "Execute job task: DisableCipherAutoStatsAsync"); _maintenanceRepository = maintenanceRepository;
await _maintenanceRepository.DisableCipherAutoStatsAsync(); }
_logger.LogInformation(Constants.BypassFiltersEventId, "Finished job task: DisableCipherAutoStatsAsync");
protected async override Task ExecuteJobAsync(IJobExecutionContext context)
{
_logger.LogInformation(Constants.BypassFiltersEventId, "Execute job task: UpdateStatisticsAsync");
await _maintenanceRepository.UpdateStatisticsAsync();
_logger.LogInformation(Constants.BypassFiltersEventId, "Finished job task: UpdateStatisticsAsync");
_logger.LogInformation(Constants.BypassFiltersEventId, "Execute job task: DisableCipherAutoStatsAsync");
await _maintenanceRepository.DisableCipherAutoStatsAsync();
_logger.LogInformation(Constants.BypassFiltersEventId, "Finished job task: DisableCipherAutoStatsAsync");
}
} }
} }

View File

@ -4,33 +4,34 @@ using Bit.Core.Repositories;
using Microsoft.Extensions.Options; using Microsoft.Extensions.Options;
using Quartz; using Quartz;
namespace Bit.Admin.Jobs; namespace Bit.Admin.Jobs
public class DeleteCiphersJob : BaseJob
{ {
private readonly ICipherRepository _cipherRepository; public class DeleteCiphersJob : BaseJob
private readonly AdminSettings _adminSettings;
public DeleteCiphersJob(
ICipherRepository cipherRepository,
IOptions<AdminSettings> adminSettings,
ILogger<DeleteCiphersJob> logger)
: base(logger)
{ {
_cipherRepository = cipherRepository; private readonly ICipherRepository _cipherRepository;
_adminSettings = adminSettings?.Value; private readonly AdminSettings _adminSettings;
}
protected async override Task ExecuteJobAsync(IJobExecutionContext context) public DeleteCiphersJob(
{ ICipherRepository cipherRepository,
_logger.LogInformation(Constants.BypassFiltersEventId, "Execute job task: DeleteDeletedAsync"); IOptions<AdminSettings> adminSettings,
var deleteDate = DateTime.UtcNow.AddDays(-30); ILogger<DeleteCiphersJob> logger)
var daysAgoSetting = (_adminSettings?.DeleteTrashDaysAgo).GetValueOrDefault(); : base(logger)
if (daysAgoSetting > 0)
{ {
deleteDate = DateTime.UtcNow.AddDays(-1 * daysAgoSetting); _cipherRepository = cipherRepository;
_adminSettings = adminSettings?.Value;
}
protected async override Task ExecuteJobAsync(IJobExecutionContext context)
{
_logger.LogInformation(Constants.BypassFiltersEventId, "Execute job task: DeleteDeletedAsync");
var deleteDate = DateTime.UtcNow.AddDays(-30);
var daysAgoSetting = (_adminSettings?.DeleteTrashDaysAgo).GetValueOrDefault();
if (daysAgoSetting > 0)
{
deleteDate = DateTime.UtcNow.AddDays(-1 * daysAgoSetting);
}
await _cipherRepository.DeleteDeletedAsync(deleteDate);
_logger.LogInformation(Constants.BypassFiltersEventId, "Finished job task: DeleteDeletedAsync");
} }
await _cipherRepository.DeleteDeletedAsync(deleteDate);
_logger.LogInformation(Constants.BypassFiltersEventId, "Finished job task: DeleteDeletedAsync");
} }
} }

View File

@ -4,37 +4,38 @@ using Bit.Core.Repositories;
using Bit.Core.Services; using Bit.Core.Services;
using Quartz; using Quartz;
namespace Bit.Admin.Jobs; namespace Bit.Admin.Jobs
public class DeleteSendsJob : BaseJob
{ {
private readonly ISendRepository _sendRepository; public class DeleteSendsJob : BaseJob
private readonly IServiceProvider _serviceProvider;
public DeleteSendsJob(
ISendRepository sendRepository,
IServiceProvider serviceProvider,
ILogger<DatabaseExpiredGrantsJob> logger)
: base(logger)
{ {
_sendRepository = sendRepository; private readonly ISendRepository _sendRepository;
_serviceProvider = serviceProvider; private readonly IServiceProvider _serviceProvider;
}
protected async override Task ExecuteJobAsync(IJobExecutionContext context) public DeleteSendsJob(
{ ISendRepository sendRepository,
var sends = await _sendRepository.GetManyByDeletionDateAsync(DateTime.UtcNow); IServiceProvider serviceProvider,
_logger.LogInformation(Constants.BypassFiltersEventId, "Deleting {0} sends.", sends.Count); ILogger<DatabaseExpiredGrantsJob> logger)
if (!sends.Any()) : base(logger)
{ {
return; _sendRepository = sendRepository;
_serviceProvider = serviceProvider;
} }
using (var scope = _serviceProvider.CreateScope())
protected async override Task ExecuteJobAsync(IJobExecutionContext context)
{ {
var sendService = scope.ServiceProvider.GetRequiredService<ISendService>(); var sends = await _sendRepository.GetManyByDeletionDateAsync(DateTime.UtcNow);
foreach (var send in sends) _logger.LogInformation(Constants.BypassFiltersEventId, "Deleting {0} sends.", sends.Count);
if (!sends.Any())
{ {
await sendService.DeleteSendAsync(send); return;
}
using (var scope = _serviceProvider.CreateScope())
{
var sendService = scope.ServiceProvider.GetRequiredService<ISendService>();
foreach (var send in sends)
{
await sendService.DeleteSendAsync(send);
}
} }
} }
} }

View File

@ -3,93 +3,94 @@ using Bit.Core.Jobs;
using Bit.Core.Settings; using Bit.Core.Settings;
using Quartz; using Quartz;
namespace Bit.Admin.Jobs; namespace Bit.Admin.Jobs
public class JobsHostedService : BaseJobsHostedService
{ {
public JobsHostedService( public class JobsHostedService : BaseJobsHostedService
GlobalSettings globalSettings,
IServiceProvider serviceProvider,
ILogger<JobsHostedService> logger,
ILogger<JobListener> listenerLogger)
: base(globalSettings, serviceProvider, logger, listenerLogger) { }
public override async Task StartAsync(CancellationToken cancellationToken)
{ {
var timeZone = RuntimeInformation.IsOSPlatform(OSPlatform.Windows) ? public JobsHostedService(
TimeZoneInfo.FindSystemTimeZoneById("Eastern Standard Time") : GlobalSettings globalSettings,
TimeZoneInfo.FindSystemTimeZoneById("America/New_York"); IServiceProvider serviceProvider,
if (_globalSettings.SelfHosted) ILogger<JobsHostedService> logger,
ILogger<JobListener> listenerLogger)
: base(globalSettings, serviceProvider, logger, listenerLogger) { }
public override async Task StartAsync(CancellationToken cancellationToken)
{ {
timeZone = TimeZoneInfo.Local; var timeZone = RuntimeInformation.IsOSPlatform(OSPlatform.Windows) ?
TimeZoneInfo.FindSystemTimeZoneById("Eastern Standard Time") :
TimeZoneInfo.FindSystemTimeZoneById("America/New_York");
if (_globalSettings.SelfHosted)
{
timeZone = TimeZoneInfo.Local;
}
var everyTopOfTheHourTrigger = TriggerBuilder.Create()
.WithIdentity("EveryTopOfTheHourTrigger")
.StartNow()
.WithCronSchedule("0 0 * * * ?")
.Build();
var everyFiveMinutesTrigger = TriggerBuilder.Create()
.WithIdentity("EveryFiveMinutesTrigger")
.StartNow()
.WithCronSchedule("0 */5 * * * ?")
.Build();
var everyFridayAt10pmTrigger = TriggerBuilder.Create()
.WithIdentity("EveryFridayAt10pmTrigger")
.StartNow()
.WithCronSchedule("0 0 22 ? * FRI", x => x.InTimeZone(timeZone))
.Build();
var everySaturdayAtMidnightTrigger = TriggerBuilder.Create()
.WithIdentity("EverySaturdayAtMidnightTrigger")
.StartNow()
.WithCronSchedule("0 0 0 ? * SAT", x => x.InTimeZone(timeZone))
.Build();
var everySundayAtMidnightTrigger = TriggerBuilder.Create()
.WithIdentity("EverySundayAtMidnightTrigger")
.StartNow()
.WithCronSchedule("0 0 0 ? * SUN", x => x.InTimeZone(timeZone))
.Build();
var everyMondayAtMidnightTrigger = TriggerBuilder.Create()
.WithIdentity("EveryMondayAtMidnightTrigger")
.StartNow()
.WithCronSchedule("0 0 0 ? * MON", x => x.InTimeZone(timeZone))
.Build();
var everyDayAtMidnightUtc = TriggerBuilder.Create()
.WithIdentity("EveryDayAtMidnightUtc")
.StartNow()
.WithCronSchedule("0 0 0 * * ?")
.Build();
var jobs = new List<Tuple<Type, ITrigger>>
{
new Tuple<Type, ITrigger>(typeof(DeleteSendsJob), everyFiveMinutesTrigger),
new Tuple<Type, ITrigger>(typeof(DatabaseExpiredGrantsJob), everyFridayAt10pmTrigger),
new Tuple<Type, ITrigger>(typeof(DatabaseUpdateStatisticsJob), everySaturdayAtMidnightTrigger),
new Tuple<Type, ITrigger>(typeof(DatabaseRebuildlIndexesJob), everySundayAtMidnightTrigger),
new Tuple<Type, ITrigger>(typeof(DeleteCiphersJob), everyDayAtMidnightUtc),
new Tuple<Type, ITrigger>(typeof(DatabaseExpiredSponsorshipsJob), everyMondayAtMidnightTrigger)
};
if (!_globalSettings.SelfHosted)
{
jobs.Add(new Tuple<Type, ITrigger>(typeof(AliveJob), everyTopOfTheHourTrigger));
}
Jobs = jobs;
await base.StartAsync(cancellationToken);
} }
var everyTopOfTheHourTrigger = TriggerBuilder.Create() public static void AddJobsServices(IServiceCollection services, bool selfHosted)
.WithIdentity("EveryTopOfTheHourTrigger")
.StartNow()
.WithCronSchedule("0 0 * * * ?")
.Build();
var everyFiveMinutesTrigger = TriggerBuilder.Create()
.WithIdentity("EveryFiveMinutesTrigger")
.StartNow()
.WithCronSchedule("0 */5 * * * ?")
.Build();
var everyFridayAt10pmTrigger = TriggerBuilder.Create()
.WithIdentity("EveryFridayAt10pmTrigger")
.StartNow()
.WithCronSchedule("0 0 22 ? * FRI", x => x.InTimeZone(timeZone))
.Build();
var everySaturdayAtMidnightTrigger = TriggerBuilder.Create()
.WithIdentity("EverySaturdayAtMidnightTrigger")
.StartNow()
.WithCronSchedule("0 0 0 ? * SAT", x => x.InTimeZone(timeZone))
.Build();
var everySundayAtMidnightTrigger = TriggerBuilder.Create()
.WithIdentity("EverySundayAtMidnightTrigger")
.StartNow()
.WithCronSchedule("0 0 0 ? * SUN", x => x.InTimeZone(timeZone))
.Build();
var everyMondayAtMidnightTrigger = TriggerBuilder.Create()
.WithIdentity("EveryMondayAtMidnightTrigger")
.StartNow()
.WithCronSchedule("0 0 0 ? * MON", x => x.InTimeZone(timeZone))
.Build();
var everyDayAtMidnightUtc = TriggerBuilder.Create()
.WithIdentity("EveryDayAtMidnightUtc")
.StartNow()
.WithCronSchedule("0 0 0 * * ?")
.Build();
var jobs = new List<Tuple<Type, ITrigger>>
{ {
new Tuple<Type, ITrigger>(typeof(DeleteSendsJob), everyFiveMinutesTrigger), if (!selfHosted)
new Tuple<Type, ITrigger>(typeof(DatabaseExpiredGrantsJob), everyFridayAt10pmTrigger), {
new Tuple<Type, ITrigger>(typeof(DatabaseUpdateStatisticsJob), everySaturdayAtMidnightTrigger), services.AddTransient<AliveJob>();
new Tuple<Type, ITrigger>(typeof(DatabaseRebuildlIndexesJob), everySundayAtMidnightTrigger), }
new Tuple<Type, ITrigger>(typeof(DeleteCiphersJob), everyDayAtMidnightUtc), services.AddTransient<DatabaseUpdateStatisticsJob>();
new Tuple<Type, ITrigger>(typeof(DatabaseExpiredSponsorshipsJob), everyMondayAtMidnightTrigger) services.AddTransient<DatabaseRebuildlIndexesJob>();
}; services.AddTransient<DatabaseExpiredGrantsJob>();
services.AddTransient<DatabaseExpiredSponsorshipsJob>();
if (!_globalSettings.SelfHosted) services.AddTransient<DeleteSendsJob>();
{ services.AddTransient<DeleteCiphersJob>();
jobs.Add(new Tuple<Type, ITrigger>(typeof(AliveJob), everyTopOfTheHourTrigger));
} }
Jobs = jobs;
await base.StartAsync(cancellationToken);
}
public static void AddJobsServices(IServiceCollection services, bool selfHosted)
{
if (!selfHosted)
{
services.AddTransient<AliveJob>();
}
services.AddTransient<DatabaseUpdateStatisticsJob>();
services.AddTransient<DatabaseRebuildlIndexesJob>();
services.AddTransient<DatabaseExpiredGrantsJob>();
services.AddTransient<DatabaseExpiredSponsorshipsJob>();
services.AddTransient<DeleteSendsJob>();
services.AddTransient<DeleteCiphersJob>();
} }
} }

View File

@ -1,10 +1,11 @@
using Bit.Core.Models.Business; using Bit.Core.Models.Business;
namespace Bit.Admin.Models; namespace Bit.Admin.Models
public class BillingInformationModel
{ {
public BillingInfo BillingInfo { get; set; } public class BillingInformationModel
public Guid? UserId { get; set; } {
public Guid? OrganizationId { get; set; } public BillingInfo BillingInfo { get; set; }
public Guid? UserId { get; set; }
public Guid? OrganizationId { get; set; }
}
} }

View File

@ -1,26 +1,27 @@
using System.ComponentModel.DataAnnotations; using System.ComponentModel.DataAnnotations;
namespace Bit.Admin.Models; namespace Bit.Admin.Models
public class ChargeBraintreeModel : IValidatableObject
{ {
[Required] public class ChargeBraintreeModel : IValidatableObject
[Display(Name = "Braintree Customer Id")]
public string Id { get; set; }
[Required]
[Display(Name = "Amount")]
public decimal? Amount { get; set; }
public string TransactionId { get; set; }
public string PayPalTransactionId { get; set; }
public IEnumerable<ValidationResult> Validate(ValidationContext validationContext)
{ {
if (Id != null) [Required]
[Display(Name = "Braintree Customer Id")]
public string Id { get; set; }
[Required]
[Display(Name = "Amount")]
public decimal? Amount { get; set; }
public string TransactionId { get; set; }
public string PayPalTransactionId { get; set; }
public IEnumerable<ValidationResult> Validate(ValidationContext validationContext)
{ {
if (Id.Length != 36 || (Id[0] != 'o' && Id[0] != 'u') || if (Id != null)
!Guid.TryParse(Id.Substring(1, 32), out var guid))
{ {
yield return new ValidationResult("Customer Id is not a valid format."); if (Id.Length != 36 || (Id[0] != 'o' && Id[0] != 'u') ||
!Guid.TryParse(Id.Substring(1, 32), out var guid))
{
yield return new ValidationResult("Customer Id is not a valid format.");
}
} }
} }
} }

View File

@ -1,12 +1,13 @@
using System.ComponentModel.DataAnnotations; using System.ComponentModel.DataAnnotations;
namespace Bit.Admin.Models; namespace Bit.Admin.Models
public class CreateProviderModel
{ {
public CreateProviderModel() { } public class CreateProviderModel
{
public CreateProviderModel() { }
[Display(Name = "Owner Email")] [Display(Name = "Owner Email")]
[Required] [Required]
public string OwnerEmail { get; set; } public string OwnerEmail { get; set; }
}
} }

View File

@ -2,76 +2,77 @@
using Bit.Core.Entities; using Bit.Core.Entities;
using Bit.Core.Enums; using Bit.Core.Enums;
namespace Bit.Admin.Models; namespace Bit.Admin.Models
public class CreateUpdateTransactionModel : IValidatableObject
{ {
public CreateUpdateTransactionModel() { } public class CreateUpdateTransactionModel : IValidatableObject
public CreateUpdateTransactionModel(Transaction transaction)
{ {
Edit = true; public CreateUpdateTransactionModel() { }
UserId = transaction.UserId;
OrganizationId = transaction.OrganizationId;
Amount = transaction.Amount;
RefundedAmount = transaction.RefundedAmount;
Refunded = transaction.Refunded.GetValueOrDefault();
Details = transaction.Details;
Date = transaction.CreationDate;
PaymentMethod = transaction.PaymentMethodType;
Gateway = transaction.Gateway;
GatewayId = transaction.GatewayId;
Type = transaction.Type;
}
public bool Edit { get; set; } public CreateUpdateTransactionModel(Transaction transaction)
[Display(Name = "User Id")]
public Guid? UserId { get; set; }
[Display(Name = "Organization Id")]
public Guid? OrganizationId { get; set; }
[Required]
public decimal? Amount { get; set; }
[Display(Name = "Refunded Amount")]
public decimal? RefundedAmount { get; set; }
public bool Refunded { get; set; }
[Required]
public string Details { get; set; }
[Required]
public DateTime? Date { get; set; }
[Display(Name = "Payment Method")]
public PaymentMethodType? PaymentMethod { get; set; }
public GatewayType? Gateway { get; set; }
[Display(Name = "Gateway Id")]
public string GatewayId { get; set; }
[Required]
public TransactionType? Type { get; set; }
public IEnumerable<ValidationResult> Validate(ValidationContext validationContext)
{
if ((!UserId.HasValue && !OrganizationId.HasValue) || (UserId.HasValue && OrganizationId.HasValue))
{ {
yield return new ValidationResult("Must provide either User Id, or Organization Id."); Edit = true;
UserId = transaction.UserId;
OrganizationId = transaction.OrganizationId;
Amount = transaction.Amount;
RefundedAmount = transaction.RefundedAmount;
Refunded = transaction.Refunded.GetValueOrDefault();
Details = transaction.Details;
Date = transaction.CreationDate;
PaymentMethod = transaction.PaymentMethodType;
Gateway = transaction.Gateway;
GatewayId = transaction.GatewayId;
Type = transaction.Type;
}
public bool Edit { get; set; }
[Display(Name = "User Id")]
public Guid? UserId { get; set; }
[Display(Name = "Organization Id")]
public Guid? OrganizationId { get; set; }
[Required]
public decimal? Amount { get; set; }
[Display(Name = "Refunded Amount")]
public decimal? RefundedAmount { get; set; }
public bool Refunded { get; set; }
[Required]
public string Details { get; set; }
[Required]
public DateTime? Date { get; set; }
[Display(Name = "Payment Method")]
public PaymentMethodType? PaymentMethod { get; set; }
public GatewayType? Gateway { get; set; }
[Display(Name = "Gateway Id")]
public string GatewayId { get; set; }
[Required]
public TransactionType? Type { get; set; }
public IEnumerable<ValidationResult> Validate(ValidationContext validationContext)
{
if ((!UserId.HasValue && !OrganizationId.HasValue) || (UserId.HasValue && OrganizationId.HasValue))
{
yield return new ValidationResult("Must provide either User Id, or Organization Id.");
}
}
public Transaction ToTransaction(Guid? id = null)
{
return new Transaction
{
Id = id.GetValueOrDefault(),
UserId = UserId,
OrganizationId = OrganizationId,
Amount = Amount.Value,
RefundedAmount = RefundedAmount,
Refunded = Refunded ? true : (bool?)null,
Details = Details,
CreationDate = Date.Value,
PaymentMethodType = PaymentMethod,
Gateway = Gateway,
GatewayId = GatewayId,
Type = Type.Value
};
} }
} }
public Transaction ToTransaction(Guid? id = null)
{
return new Transaction
{
Id = id.GetValueOrDefault(),
UserId = UserId,
OrganizationId = OrganizationId,
Amount = Amount.Value,
RefundedAmount = RefundedAmount,
Refunded = Refunded ? true : (bool?)null,
Details = Details,
CreationDate = Date.Value,
PaymentMethodType = PaymentMethod,
Gateway = Gateway,
GatewayId = GatewayId,
Type = Type.Value
};
}
} }

View File

@ -1,9 +1,10 @@
namespace Bit.Admin.Models; namespace Bit.Admin.Models
public class CursorPagedModel<T>
{ {
public List<T> Items { get; set; } public class CursorPagedModel<T>
public int Count { get; set; } {
public string Cursor { get; set; } public List<T> Items { get; set; }
public string NextCursor { get; set; } public int Count { get; set; }
public string Cursor { get; set; }
public string NextCursor { get; set; }
}
} }

View File

@ -1,8 +1,9 @@
namespace Bit.Admin.Models; namespace Bit.Admin.Models
public class ErrorViewModel
{ {
public string RequestId { get; set; } public class ErrorViewModel
{
public string RequestId { get; set; }
public bool ShowRequestId => !string.IsNullOrEmpty(RequestId); public bool ShowRequestId => !string.IsNullOrEmpty(RequestId);
}
} }

View File

@ -1,9 +1,10 @@
using Bit.Core.Settings; using Bit.Core.Settings;
namespace Bit.Admin.Models; namespace Bit.Admin.Models
public class HomeModel
{ {
public string CurrentVersion { get; set; } public class HomeModel
public GlobalSettings GlobalSettings { get; set; } {
public string CurrentVersion { get; set; }
public GlobalSettings GlobalSettings { get; set; }
}
} }

View File

@ -1,34 +1,35 @@
using System.ComponentModel.DataAnnotations; using System.ComponentModel.DataAnnotations;
namespace Bit.Admin.Models; namespace Bit.Admin.Models
public class LicenseModel : IValidatableObject
{ {
[Display(Name = "User Id")] public class LicenseModel : IValidatableObject
public Guid? UserId { get; set; }
[Display(Name = "Organization Id")]
public Guid? OrganizationId { get; set; }
[Display(Name = "Installation Id")]
public Guid? InstallationId { get; set; }
[Required]
[Display(Name = "Version")]
public int Version { get; set; }
public IEnumerable<ValidationResult> Validate(ValidationContext validationContext)
{ {
if (UserId.HasValue && OrganizationId.HasValue) [Display(Name = "User Id")]
{ public Guid? UserId { get; set; }
yield return new ValidationResult("Use either User Id or Organization Id. Not both."); [Display(Name = "Organization Id")]
} public Guid? OrganizationId { get; set; }
[Display(Name = "Installation Id")]
public Guid? InstallationId { get; set; }
[Required]
[Display(Name = "Version")]
public int Version { get; set; }
if (!UserId.HasValue && !OrganizationId.HasValue) public IEnumerable<ValidationResult> Validate(ValidationContext validationContext)
{ {
yield return new ValidationResult("User Id or Organization Id is required."); if (UserId.HasValue && OrganizationId.HasValue)
} {
yield return new ValidationResult("Use either User Id or Organization Id. Not both.");
}
if (OrganizationId.HasValue && !InstallationId.HasValue) if (!UserId.HasValue && !OrganizationId.HasValue)
{ {
yield return new ValidationResult("Installation Id is required for organization licenses."); yield return new ValidationResult("User Id or Organization Id is required.");
}
if (OrganizationId.HasValue && !InstallationId.HasValue)
{
yield return new ValidationResult("Installation Id is required for organization licenses.");
}
} }
} }
} }

View File

@ -1,54 +1,55 @@
using Microsoft.Azure.Documents; using Microsoft.Azure.Documents;
using Newtonsoft.Json.Linq; using Newtonsoft.Json.Linq;
namespace Bit.Admin.Models; namespace Bit.Admin.Models
public class LogModel : Resource
{ {
public long EventIdHash { get; set; } public class LogModel : Resource
public string Level { get; set; }
public string Message { get; set; }
public string MessageTruncated => Message.Length > 200 ? $"{Message.Substring(0, 200)}..." : Message;
public string MessageTemplate { get; set; }
public IDictionary<string, object> Properties { get; set; }
public string Project => Properties?.ContainsKey("Project") ?? false ? Properties["Project"].ToString() : null;
}
public class LogDetailsModel : LogModel
{
public JObject Exception { get; set; }
public string ExceptionToString(JObject e)
{ {
if (e == null) public long EventIdHash { get; set; }
{ public string Level { get; set; }
return null; public string Message { get; set; }
} public string MessageTruncated => Message.Length > 200 ? $"{Message.Substring(0, 200)}..." : Message;
public string MessageTemplate { get; set; }
public IDictionary<string, object> Properties { get; set; }
public string Project => Properties?.ContainsKey("Project") ?? false ? Properties["Project"].ToString() : null;
}
var val = string.Empty; public class LogDetailsModel : LogModel
if (e["Message"] != null && e["Message"].ToObject<string>() != null) {
{ public JObject Exception { get; set; }
val += "Message:\n";
val += e["Message"] + "\n";
}
if (e["StackTrace"] != null && e["StackTrace"].ToObject<string>() != null) public string ExceptionToString(JObject e)
{ {
val += "\nStack Trace:\n"; if (e == null)
val += e["StackTrace"]; {
} return null;
else if (e["StackTraceString"] != null && e["StackTraceString"].ToObject<string>() != null) }
{
val += "\nStack Trace String:\n";
val += e["StackTraceString"];
}
if (e["InnerException"] != null && e["InnerException"].ToObject<JObject>() != null) var val = string.Empty;
{ if (e["Message"] != null && e["Message"].ToObject<string>() != null)
val += "\n\n=== Inner Exception ===\n\n"; {
val += ExceptionToString(e["InnerException"].ToObject<JObject>()); val += "Message:\n";
} val += e["Message"] + "\n";
}
return val; if (e["StackTrace"] != null && e["StackTrace"].ToObject<string>() != null)
{
val += "\nStack Trace:\n";
val += e["StackTrace"];
}
else if (e["StackTraceString"] != null && e["StackTraceString"].ToObject<string>() != null)
{
val += "\nStack Trace String:\n";
val += e["StackTraceString"];
}
if (e["InnerException"] != null && e["InnerException"].ToObject<JObject>() != null)
{
val += "\n\n=== Inner Exception ===\n\n";
val += ExceptionToString(e["InnerException"].ToObject<JObject>());
}
return val;
}
} }
} }

View File

@ -1,13 +1,14 @@
using System.ComponentModel.DataAnnotations; using System.ComponentModel.DataAnnotations;
namespace Bit.Admin.Models; namespace Bit.Admin.Models
public class LoginModel
{ {
[Required] public class LoginModel
[EmailAddress] {
public string Email { get; set; } [Required]
public string ReturnUrl { get; set; } [EmailAddress]
public string Error { get; set; } public string Email { get; set; }
public string Success { get; set; } public string ReturnUrl { get; set; }
public string Error { get; set; }
public string Success { get; set; }
}
} }

View File

@ -1,11 +1,12 @@
using Serilog.Events; using Serilog.Events;
namespace Bit.Admin.Models; namespace Bit.Admin.Models
public class LogsModel : CursorPagedModel<LogModel>
{ {
public LogEventLevel? Level { get; set; } public class LogsModel : CursorPagedModel<LogModel>
public string Project { get; set; } {
public DateTime? Start { get; set; } public LogEventLevel? Level { get; set; }
public DateTime? End { get; set; } public string Project { get; set; }
public DateTime? Start { get; set; }
public DateTime? End { get; set; }
}
} }

View File

@ -6,147 +6,148 @@ using Bit.Core.Models.Data.Organizations.OrganizationUsers;
using Bit.Core.Settings; using Bit.Core.Settings;
using Bit.Core.Utilities; using Bit.Core.Utilities;
namespace Bit.Admin.Models; namespace Bit.Admin.Models
public class OrganizationEditModel : OrganizationViewModel
{ {
public OrganizationEditModel() { } public class OrganizationEditModel : OrganizationViewModel
public OrganizationEditModel(Organization org, IEnumerable<OrganizationUserUserDetails> orgUsers,
IEnumerable<Cipher> ciphers, IEnumerable<Collection> collections, IEnumerable<Group> groups,
IEnumerable<Policy> policies, BillingInfo billingInfo, IEnumerable<OrganizationConnection> connections,
GlobalSettings globalSettings)
: base(org, connections, orgUsers, ciphers, collections, groups, policies)
{ {
BillingInfo = billingInfo; public OrganizationEditModel() { }
BraintreeMerchantId = globalSettings.Braintree.MerchantId;
Name = org.Name; public OrganizationEditModel(Organization org, IEnumerable<OrganizationUserUserDetails> orgUsers,
BusinessName = org.BusinessName; IEnumerable<Cipher> ciphers, IEnumerable<Collection> collections, IEnumerable<Group> groups,
BillingEmail = org.BillingEmail; IEnumerable<Policy> policies, BillingInfo billingInfo, IEnumerable<OrganizationConnection> connections,
PlanType = org.PlanType; GlobalSettings globalSettings)
Plan = org.Plan; : base(org, connections, orgUsers, ciphers, collections, groups, policies)
Seats = org.Seats; {
MaxAutoscaleSeats = org.MaxAutoscaleSeats; BillingInfo = billingInfo;
MaxCollections = org.MaxCollections; BraintreeMerchantId = globalSettings.Braintree.MerchantId;
UsePolicies = org.UsePolicies;
UseSso = org.UseSso;
UseKeyConnector = org.UseKeyConnector;
UseScim = org.UseScim;
UseGroups = org.UseGroups;
UseDirectory = org.UseDirectory;
UseEvents = org.UseEvents;
UseTotp = org.UseTotp;
Use2fa = org.Use2fa;
UseApi = org.UseApi;
UseResetPassword = org.UseResetPassword;
SelfHost = org.SelfHost;
UsersGetPremium = org.UsersGetPremium;
MaxStorageGb = org.MaxStorageGb;
Gateway = org.Gateway;
GatewayCustomerId = org.GatewayCustomerId;
GatewaySubscriptionId = org.GatewaySubscriptionId;
Enabled = org.Enabled;
LicenseKey = org.LicenseKey;
ExpirationDate = org.ExpirationDate;
}
public BillingInfo BillingInfo { get; set; } Name = org.Name;
public string RandomLicenseKey => CoreHelpers.SecureRandomString(20); BusinessName = org.BusinessName;
public string FourteenDayExpirationDate => DateTime.Now.AddDays(14).ToString("yyyy-MM-ddTHH:mm"); BillingEmail = org.BillingEmail;
public string BraintreeMerchantId { get; set; } PlanType = org.PlanType;
Plan = org.Plan;
Seats = org.Seats;
MaxAutoscaleSeats = org.MaxAutoscaleSeats;
MaxCollections = org.MaxCollections;
UsePolicies = org.UsePolicies;
UseSso = org.UseSso;
UseKeyConnector = org.UseKeyConnector;
UseScim = org.UseScim;
UseGroups = org.UseGroups;
UseDirectory = org.UseDirectory;
UseEvents = org.UseEvents;
UseTotp = org.UseTotp;
Use2fa = org.Use2fa;
UseApi = org.UseApi;
UseResetPassword = org.UseResetPassword;
SelfHost = org.SelfHost;
UsersGetPremium = org.UsersGetPremium;
MaxStorageGb = org.MaxStorageGb;
Gateway = org.Gateway;
GatewayCustomerId = org.GatewayCustomerId;
GatewaySubscriptionId = org.GatewaySubscriptionId;
Enabled = org.Enabled;
LicenseKey = org.LicenseKey;
ExpirationDate = org.ExpirationDate;
}
[Required] public BillingInfo BillingInfo { get; set; }
[Display(Name = "Name")] public string RandomLicenseKey => CoreHelpers.SecureRandomString(20);
public string Name { get; set; } public string FourteenDayExpirationDate => DateTime.Now.AddDays(14).ToString("yyyy-MM-ddTHH:mm");
[Display(Name = "Business Name")] public string BraintreeMerchantId { get; set; }
public string BusinessName { get; set; }
[Display(Name = "Billing Email")]
public string BillingEmail { get; set; }
[Required]
[Display(Name = "Plan")]
public PlanType? PlanType { get; set; }
[Required]
[Display(Name = "Plan Name")]
public string Plan { get; set; }
[Display(Name = "Seats")]
public int? Seats { get; set; }
[Display(Name = "Max. Autoscale Seats")]
public int? MaxAutoscaleSeats { get; set; }
[Display(Name = "Max. Collections")]
public short? MaxCollections { get; set; }
[Display(Name = "Policies")]
public bool UsePolicies { get; set; }
[Display(Name = "SSO")]
public bool UseSso { get; set; }
[Display(Name = "Key Connector with Customer Encryption")]
public bool UseKeyConnector { get; set; }
[Display(Name = "Groups")]
public bool UseGroups { get; set; }
[Display(Name = "Directory")]
public bool UseDirectory { get; set; }
[Display(Name = "Events")]
public bool UseEvents { get; set; }
[Display(Name = "TOTP")]
public bool UseTotp { get; set; }
[Display(Name = "2FA")]
public bool Use2fa { get; set; }
[Display(Name = "API")]
public bool UseApi { get; set; }
[Display(Name = "Reset Password")]
public bool UseResetPassword { get; set; }
[Display(Name = "SCIM")]
public bool UseScim { get; set; }
[Display(Name = "Self Host")]
public bool SelfHost { get; set; }
[Display(Name = "Users Get Premium")]
public bool UsersGetPremium { get; set; }
[Display(Name = "Max. Storage GB")]
public short? MaxStorageGb { get; set; }
[Display(Name = "Gateway")]
public GatewayType? Gateway { get; set; }
[Display(Name = "Gateway Customer Id")]
public string GatewayCustomerId { get; set; }
[Display(Name = "Gateway Subscription Id")]
public string GatewaySubscriptionId { get; set; }
[Display(Name = "Enabled")]
public bool Enabled { get; set; }
[Display(Name = "License Key")]
public string LicenseKey { get; set; }
[Display(Name = "Expiration Date")]
public DateTime? ExpirationDate { get; set; }
public bool SalesAssistedTrialStarted { get; set; }
public Organization ToOrganization(Organization existingOrganization) [Required]
{ [Display(Name = "Name")]
existingOrganization.Name = Name; public string Name { get; set; }
existingOrganization.BusinessName = BusinessName; [Display(Name = "Business Name")]
existingOrganization.BillingEmail = BillingEmail?.ToLowerInvariant()?.Trim(); public string BusinessName { get; set; }
existingOrganization.PlanType = PlanType.Value; [Display(Name = "Billing Email")]
existingOrganization.Plan = Plan; public string BillingEmail { get; set; }
existingOrganization.Seats = Seats; [Required]
existingOrganization.MaxCollections = MaxCollections; [Display(Name = "Plan")]
existingOrganization.UsePolicies = UsePolicies; public PlanType? PlanType { get; set; }
existingOrganization.UseSso = UseSso; [Required]
existingOrganization.UseKeyConnector = UseKeyConnector; [Display(Name = "Plan Name")]
existingOrganization.UseScim = UseScim; public string Plan { get; set; }
existingOrganization.UseGroups = UseGroups; [Display(Name = "Seats")]
existingOrganization.UseDirectory = UseDirectory; public int? Seats { get; set; }
existingOrganization.UseEvents = UseEvents; [Display(Name = "Max. Autoscale Seats")]
existingOrganization.UseTotp = UseTotp; public int? MaxAutoscaleSeats { get; set; }
existingOrganization.Use2fa = Use2fa; [Display(Name = "Max. Collections")]
existingOrganization.UseApi = UseApi; public short? MaxCollections { get; set; }
existingOrganization.UseResetPassword = UseResetPassword; [Display(Name = "Policies")]
existingOrganization.SelfHost = SelfHost; public bool UsePolicies { get; set; }
existingOrganization.UsersGetPremium = UsersGetPremium; [Display(Name = "SSO")]
existingOrganization.MaxStorageGb = MaxStorageGb; public bool UseSso { get; set; }
existingOrganization.Gateway = Gateway; [Display(Name = "Key Connector with Customer Encryption")]
existingOrganization.GatewayCustomerId = GatewayCustomerId; public bool UseKeyConnector { get; set; }
existingOrganization.GatewaySubscriptionId = GatewaySubscriptionId; [Display(Name = "Groups")]
existingOrganization.Enabled = Enabled; public bool UseGroups { get; set; }
existingOrganization.LicenseKey = LicenseKey; [Display(Name = "Directory")]
existingOrganization.ExpirationDate = ExpirationDate; public bool UseDirectory { get; set; }
existingOrganization.MaxAutoscaleSeats = MaxAutoscaleSeats; [Display(Name = "Events")]
return existingOrganization; public bool UseEvents { get; set; }
[Display(Name = "TOTP")]
public bool UseTotp { get; set; }
[Display(Name = "2FA")]
public bool Use2fa { get; set; }
[Display(Name = "API")]
public bool UseApi { get; set; }
[Display(Name = "Reset Password")]
public bool UseResetPassword { get; set; }
[Display(Name = "SCIM")]
public bool UseScim { get; set; }
[Display(Name = "Self Host")]
public bool SelfHost { get; set; }
[Display(Name = "Users Get Premium")]
public bool UsersGetPremium { get; set; }
[Display(Name = "Max. Storage GB")]
public short? MaxStorageGb { get; set; }
[Display(Name = "Gateway")]
public GatewayType? Gateway { get; set; }
[Display(Name = "Gateway Customer Id")]
public string GatewayCustomerId { get; set; }
[Display(Name = "Gateway Subscription Id")]
public string GatewaySubscriptionId { get; set; }
[Display(Name = "Enabled")]
public bool Enabled { get; set; }
[Display(Name = "License Key")]
public string LicenseKey { get; set; }
[Display(Name = "Expiration Date")]
public DateTime? ExpirationDate { get; set; }
public bool SalesAssistedTrialStarted { get; set; }
public Organization ToOrganization(Organization existingOrganization)
{
existingOrganization.Name = Name;
existingOrganization.BusinessName = BusinessName;
existingOrganization.BillingEmail = BillingEmail?.ToLowerInvariant()?.Trim();
existingOrganization.PlanType = PlanType.Value;
existingOrganization.Plan = Plan;
existingOrganization.Seats = Seats;
existingOrganization.MaxCollections = MaxCollections;
existingOrganization.UsePolicies = UsePolicies;
existingOrganization.UseSso = UseSso;
existingOrganization.UseKeyConnector = UseKeyConnector;
existingOrganization.UseScim = UseScim;
existingOrganization.UseGroups = UseGroups;
existingOrganization.UseDirectory = UseDirectory;
existingOrganization.UseEvents = UseEvents;
existingOrganization.UseTotp = UseTotp;
existingOrganization.Use2fa = Use2fa;
existingOrganization.UseApi = UseApi;
existingOrganization.UseResetPassword = UseResetPassword;
existingOrganization.SelfHost = SelfHost;
existingOrganization.UsersGetPremium = UsersGetPremium;
existingOrganization.MaxStorageGb = MaxStorageGb;
existingOrganization.Gateway = Gateway;
existingOrganization.GatewayCustomerId = GatewayCustomerId;
existingOrganization.GatewaySubscriptionId = GatewaySubscriptionId;
existingOrganization.Enabled = Enabled;
existingOrganization.LicenseKey = LicenseKey;
existingOrganization.ExpirationDate = ExpirationDate;
existingOrganization.MaxAutoscaleSeats = MaxAutoscaleSeats;
return existingOrganization;
}
} }
} }

View File

@ -2,48 +2,49 @@
using Bit.Core.Enums; using Bit.Core.Enums;
using Bit.Core.Models.Data.Organizations.OrganizationUsers; using Bit.Core.Models.Data.Organizations.OrganizationUsers;
namespace Bit.Admin.Models; namespace Bit.Admin.Models
public class OrganizationViewModel
{ {
public OrganizationViewModel() { } public class OrganizationViewModel
public OrganizationViewModel(Organization org, IEnumerable<OrganizationConnection> connections,
IEnumerable<OrganizationUserUserDetails> orgUsers, IEnumerable<Cipher> ciphers, IEnumerable<Collection> collections,
IEnumerable<Group> groups, IEnumerable<Policy> policies)
{ {
Organization = org; public OrganizationViewModel() { }
Connections = connections ?? Enumerable.Empty<OrganizationConnection>();
HasPublicPrivateKeys = org.PublicKey != null && org.PrivateKey != null;
UserInvitedCount = orgUsers.Count(u => u.Status == OrganizationUserStatusType.Invited);
UserAcceptedCount = orgUsers.Count(u => u.Status == OrganizationUserStatusType.Accepted);
UserConfirmedCount = orgUsers.Count(u => u.Status == OrganizationUserStatusType.Confirmed);
UserCount = orgUsers.Count();
CipherCount = ciphers.Count();
CollectionCount = collections.Count();
GroupCount = groups?.Count() ?? 0;
PolicyCount = policies?.Count() ?? 0;
Owners = string.Join(", ",
orgUsers
.Where(u => u.Type == OrganizationUserType.Owner && u.Status == OrganizationUserStatusType.Confirmed)
.Select(u => u.Email));
Admins = string.Join(", ",
orgUsers
.Where(u => u.Type == OrganizationUserType.Admin && u.Status == OrganizationUserStatusType.Confirmed)
.Select(u => u.Email));
}
public Organization Organization { get; set; } public OrganizationViewModel(Organization org, IEnumerable<OrganizationConnection> connections,
public IEnumerable<OrganizationConnection> Connections { get; set; } IEnumerable<OrganizationUserUserDetails> orgUsers, IEnumerable<Cipher> ciphers, IEnumerable<Collection> collections,
public string Owners { get; set; } IEnumerable<Group> groups, IEnumerable<Policy> policies)
public string Admins { get; set; } {
public int UserInvitedCount { get; set; } Organization = org;
public int UserConfirmedCount { get; set; } Connections = connections ?? Enumerable.Empty<OrganizationConnection>();
public int UserAcceptedCount { get; set; } HasPublicPrivateKeys = org.PublicKey != null && org.PrivateKey != null;
public int UserCount { get; set; } UserInvitedCount = orgUsers.Count(u => u.Status == OrganizationUserStatusType.Invited);
public int CipherCount { get; set; } UserAcceptedCount = orgUsers.Count(u => u.Status == OrganizationUserStatusType.Accepted);
public int CollectionCount { get; set; } UserConfirmedCount = orgUsers.Count(u => u.Status == OrganizationUserStatusType.Confirmed);
public int GroupCount { get; set; } UserCount = orgUsers.Count();
public int PolicyCount { get; set; } CipherCount = ciphers.Count();
public bool HasPublicPrivateKeys { get; set; } CollectionCount = collections.Count();
GroupCount = groups?.Count() ?? 0;
PolicyCount = policies?.Count() ?? 0;
Owners = string.Join(", ",
orgUsers
.Where(u => u.Type == OrganizationUserType.Owner && u.Status == OrganizationUserStatusType.Confirmed)
.Select(u => u.Email));
Admins = string.Join(", ",
orgUsers
.Where(u => u.Type == OrganizationUserType.Admin && u.Status == OrganizationUserStatusType.Confirmed)
.Select(u => u.Email));
}
public Organization Organization { get; set; }
public IEnumerable<OrganizationConnection> Connections { get; set; }
public string Owners { get; set; }
public string Admins { get; set; }
public int UserInvitedCount { get; set; }
public int UserConfirmedCount { get; set; }
public int UserAcceptedCount { get; set; }
public int UserCount { get; set; }
public int CipherCount { get; set; }
public int CollectionCount { get; set; }
public int GroupCount { get; set; }
public int PolicyCount { get; set; }
public bool HasPublicPrivateKeys { get; set; }
}
} }

View File

@ -1,12 +1,13 @@
using Bit.Core.Entities; using Bit.Core.Entities;
namespace Bit.Admin.Models; namespace Bit.Admin.Models
public class OrganizationsModel : PagedModel<Organization>
{ {
public string Name { get; set; } public class OrganizationsModel : PagedModel<Organization>
public string UserEmail { get; set; } {
public bool? Paid { get; set; } public string Name { get; set; }
public string Action { get; set; } public string UserEmail { get; set; }
public bool SelfHosted { get; set; } public bool? Paid { get; set; }
public string Action { get; set; }
public bool SelfHosted { get; set; }
}
} }

View File

@ -1,10 +1,11 @@
namespace Bit.Admin.Models; namespace Bit.Admin.Models
public abstract class PagedModel<T>
{ {
public List<T> Items { get; set; } public abstract class PagedModel<T>
public int Page { get; set; } {
public int Count { get; set; } public List<T> Items { get; set; }
public int? PreviousPage => Page < 2 ? (int?)null : Page - 1; public int Page { get; set; }
public int? NextPage => Items.Count < Count ? (int?)null : Page + 1; public int Count { get; set; }
public int? PreviousPage => Page < 2 ? (int?)null : Page - 1;
public int? NextPage => Items.Count < Count ? (int?)null : Page + 1;
}
} }

View File

@ -1,13 +1,14 @@
using System.ComponentModel.DataAnnotations; using System.ComponentModel.DataAnnotations;
namespace Bit.Admin.Models; namespace Bit.Admin.Models
public class PromoteAdminModel
{ {
[Required] public class PromoteAdminModel
[Display(Name = "Admin User Id")] {
public Guid? UserId { get; set; } [Required]
[Required] [Display(Name = "Admin User Id")]
[Display(Name = "Organization Id")] public Guid? UserId { get; set; }
public Guid? OrganizationId { get; set; } [Required]
[Display(Name = "Organization Id")]
public Guid? OrganizationId { get; set; }
}
} }

View File

@ -2,32 +2,33 @@
using Bit.Core.Entities.Provider; using Bit.Core.Entities.Provider;
using Bit.Core.Models.Data; using Bit.Core.Models.Data;
namespace Bit.Admin.Models; namespace Bit.Admin.Models
public class ProviderEditModel : ProviderViewModel
{ {
public ProviderEditModel() { } public class ProviderEditModel : ProviderViewModel
public ProviderEditModel(Provider provider, IEnumerable<ProviderUserUserDetails> providerUsers, IEnumerable<ProviderOrganizationOrganizationDetails> organizations)
: base(provider, providerUsers, organizations)
{ {
Name = provider.Name; public ProviderEditModel() { }
BusinessName = provider.BusinessName;
BillingEmail = provider.BillingEmail;
}
[Display(Name = "Billing Email")] public ProviderEditModel(Provider provider, IEnumerable<ProviderUserUserDetails> providerUsers, IEnumerable<ProviderOrganizationOrganizationDetails> organizations)
public string BillingEmail { get; set; } : base(provider, providerUsers, organizations)
[Display(Name = "Business Name")] {
public string BusinessName { get; set; } Name = provider.Name;
public string Name { get; set; } BusinessName = provider.BusinessName;
[Display(Name = "Events")] BillingEmail = provider.BillingEmail;
}
public Provider ToProvider(Provider existingProvider) [Display(Name = "Billing Email")]
{ public string BillingEmail { get; set; }
existingProvider.Name = Name; [Display(Name = "Business Name")]
existingProvider.BusinessName = BusinessName; public string BusinessName { get; set; }
existingProvider.BillingEmail = BillingEmail?.ToLowerInvariant()?.Trim(); public string Name { get; set; }
return existingProvider; [Display(Name = "Events")]
public Provider ToProvider(Provider existingProvider)
{
existingProvider.Name = Name;
existingProvider.BusinessName = BusinessName;
existingProvider.BillingEmail = BillingEmail?.ToLowerInvariant()?.Trim();
return existingProvider;
}
} }
} }

View File

@ -2,23 +2,24 @@
using Bit.Core.Enums.Provider; using Bit.Core.Enums.Provider;
using Bit.Core.Models.Data; using Bit.Core.Models.Data;
namespace Bit.Admin.Models; namespace Bit.Admin.Models
public class ProviderViewModel
{ {
public ProviderViewModel() { } public class ProviderViewModel
public ProviderViewModel(Provider provider, IEnumerable<ProviderUserUserDetails> providerUsers, IEnumerable<ProviderOrganizationOrganizationDetails> organizations)
{ {
Provider = provider; public ProviderViewModel() { }
UserCount = providerUsers.Count();
ProviderAdmins = providerUsers.Where(u => u.Type == ProviderUserType.ProviderAdmin);
ProviderOrganizations = organizations.Where(o => o.ProviderId == provider.Id); public ProviderViewModel(Provider provider, IEnumerable<ProviderUserUserDetails> providerUsers, IEnumerable<ProviderOrganizationOrganizationDetails> organizations)
{
Provider = provider;
UserCount = providerUsers.Count();
ProviderAdmins = providerUsers.Where(u => u.Type == ProviderUserType.ProviderAdmin);
ProviderOrganizations = organizations.Where(o => o.ProviderId == provider.Id);
}
public int UserCount { get; set; }
public Provider Provider { get; set; }
public IEnumerable<ProviderUserUserDetails> ProviderAdmins { get; set; }
public IEnumerable<ProviderOrganizationOrganizationDetails> ProviderOrganizations { get; set; }
} }
public int UserCount { get; set; }
public Provider Provider { get; set; }
public IEnumerable<ProviderUserUserDetails> ProviderAdmins { get; set; }
public IEnumerable<ProviderOrganizationOrganizationDetails> ProviderOrganizations { get; set; }
} }

View File

@ -1,12 +1,13 @@
using Bit.Core.Entities.Provider; using Bit.Core.Entities.Provider;
namespace Bit.Admin.Models; namespace Bit.Admin.Models
public class ProvidersModel : PagedModel<Provider>
{ {
public string Name { get; set; } public class ProvidersModel : PagedModel<Provider>
public string UserEmail { get; set; } {
public bool? Paid { get; set; } public string Name { get; set; }
public string Action { get; set; } public string UserEmail { get; set; }
public bool SelfHosted { get; set; } public bool? Paid { get; set; }
public string Action { get; set; }
public bool SelfHosted { get; set; }
}
} }

View File

@ -1,42 +1,43 @@
using System.ComponentModel.DataAnnotations; using System.ComponentModel.DataAnnotations;
using Bit.Core.Models.BitStripe; using Bit.Core.Models.BitStripe;
namespace Bit.Admin.Models; namespace Bit.Admin.Models
public class StripeSubscriptionRowModel
{ {
public Stripe.Subscription Subscription { get; set; } public class StripeSubscriptionRowModel
public bool Selected { get; set; }
public StripeSubscriptionRowModel() { }
public StripeSubscriptionRowModel(Stripe.Subscription subscription)
{ {
Subscription = subscription; public Stripe.Subscription Subscription { get; set; }
} public bool Selected { get; set; }
}
public enum StripeSubscriptionsAction public StripeSubscriptionRowModel() { }
{ public StripeSubscriptionRowModel(Stripe.Subscription subscription)
Search,
PreviousPage,
NextPage,
Export,
BulkCancel
}
public class StripeSubscriptionsModel : IValidatableObject
{
public List<StripeSubscriptionRowModel> Items { get; set; }
public StripeSubscriptionsAction Action { get; set; } = StripeSubscriptionsAction.Search;
public string Message { get; set; }
public List<Stripe.Price> Prices { get; set; }
public List<Stripe.TestHelpers.TestClock> TestClocks { get; set; }
public StripeSubscriptionListOptions Filter { get; set; } = new StripeSubscriptionListOptions();
public IEnumerable<ValidationResult> Validate(ValidationContext validationContext)
{
if (Action == StripeSubscriptionsAction.BulkCancel && Filter.Status != "unpaid")
{ {
yield return new ValidationResult("Bulk cancel is currently only supported for unpaid subscriptions"); Subscription = subscription;
}
}
public enum StripeSubscriptionsAction
{
Search,
PreviousPage,
NextPage,
Export,
BulkCancel
}
public class StripeSubscriptionsModel : IValidatableObject
{
public List<StripeSubscriptionRowModel> Items { get; set; }
public StripeSubscriptionsAction Action { get; set; } = StripeSubscriptionsAction.Search;
public string Message { get; set; }
public List<Stripe.Price> Prices { get; set; }
public List<Stripe.TestHelpers.TestClock> TestClocks { get; set; }
public StripeSubscriptionListOptions Filter { get; set; } = new StripeSubscriptionListOptions();
public IEnumerable<ValidationResult> Validate(ValidationContext validationContext)
{
if (Action == StripeSubscriptionsAction.BulkCancel && Filter.Status != "unpaid")
{
yield return new ValidationResult("Bulk cancel is currently only supported for unpaid subscriptions");
}
} }
} }
} }

View File

@ -1,10 +1,11 @@
namespace Bit.Admin.Models; namespace Bit.Admin.Models
public class TaxRateAddEditModel
{ {
public string StripeTaxRateId { get; set; } public class TaxRateAddEditModel
public string Country { get; set; } {
public string State { get; set; } public string StripeTaxRateId { get; set; }
public string PostalCode { get; set; } public string Country { get; set; }
public decimal Rate { get; set; } public string State { get; set; }
public string PostalCode { get; set; }
public decimal Rate { get; set; }
}
} }

View File

@ -1,8 +1,9 @@
using Bit.Core.Entities; using Bit.Core.Entities;
namespace Bit.Admin.Models; namespace Bit.Admin.Models
public class TaxRatesModel : PagedModel<TaxRate>
{ {
public string Message { get; set; } public class TaxRatesModel : PagedModel<TaxRate>
{
public string Message { get; set; }
}
} }

Some files were not shown because too many files have changed in this diff Show More