1
0
mirror of https://github.com/bitwarden/server.git synced 2025-07-17 15:40:59 -05:00

Merge branch 'main' into ac/pm-18238/add-requiretwofactorpolicyrequirement

# Conflicts:
#	src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/AcceptOrgUserCommand.cs
This commit is contained in:
Rui Tome
2025-05-20 10:25:45 +01:00
393 changed files with 17349 additions and 5548 deletions

View File

@ -114,6 +114,11 @@ public class Organization : ITableObject<Guid>, IStorableSubscriber, IRevisable,
/// </summary>
public bool UseRiskInsights { get; set; }
/// <summary>
/// If true, the organization can claim domains, which unlocks additional enterprise features
/// </summary>
public bool UseOrganizationDomains { get; set; }
/// <summary>
/// If set to true, admins can initiate organization-issued sponsorships.
/// </summary>
@ -319,5 +324,7 @@ public class Organization : ITableObject<Guid>, IStorableSubscriber, IRevisable,
SmSeats = license.SmSeats;
SmServiceAccounts = license.SmServiceAccounts;
UseRiskInsights = license.UseRiskInsights;
UseOrganizationDomains = license.UseOrganizationDomains;
UseAdminSponsoredFamilies = license.UseAdminSponsoredFamilies;
}
}

View File

@ -26,6 +26,7 @@ public class OrganizationAbility
LimitItemDeletion = organization.LimitItemDeletion;
AllowAdminAccessToAllCollectionItems = organization.AllowAdminAccessToAllCollectionItems;
UseRiskInsights = organization.UseRiskInsights;
UseOrganizationDomains = organization.UseOrganizationDomains;
UseAdminSponsoredFamilies = organization.UseAdminSponsoredFamilies;
}
@ -46,5 +47,6 @@ public class OrganizationAbility
public bool LimitItemDeletion { get; set; }
public bool AllowAdminAccessToAllCollectionItems { get; set; }
public bool UseRiskInsights { get; set; }
public bool UseOrganizationDomains { get; set; }
public bool UseAdminSponsoredFamilies { get; set; }
}

View File

@ -59,5 +59,7 @@ public class OrganizationUserOrganizationDetails
public bool LimitItemDeletion { get; set; }
public bool AllowAdminAccessToAllCollectionItems { get; set; }
public bool UseRiskInsights { get; set; }
public bool UseOrganizationDomains { get; set; }
public bool UseAdminSponsoredFamilies { get; set; }
public bool? IsAdminInitiated { get; set; }
}

View File

@ -150,6 +150,7 @@ public class SelfHostedOrganizationDetails : Organization
AllowAdminAccessToAllCollectionItems = AllowAdminAccessToAllCollectionItems,
Status = Status,
UseRiskInsights = UseRiskInsights,
UseAdminSponsoredFamilies = UseAdminSponsoredFamilies,
};
}
}

View File

@ -45,6 +45,7 @@ public class ProviderUserOrganizationDetails
public bool LimitItemDeletion { get; set; }
public bool AllowAdminAccessToAllCollectionItems { get; set; }
public bool UseRiskInsights { get; set; }
public bool UseOrganizationDomains { get; set; }
public bool UseAdminSponsoredFamilies { get; set; }
public ProviderType ProviderType { get; set; }
}

View File

@ -20,7 +20,6 @@ public class VerifyOrganizationDomainCommand(
IDnsResolverService dnsResolverService,
IEventService eventService,
IGlobalSettings globalSettings,
IFeatureService featureService,
ICurrentContext currentContext,
ISavePolicyCommand savePolicyCommand,
IMailService mailService,
@ -125,11 +124,8 @@ public class VerifyOrganizationDomainCommand(
private async Task DomainVerificationSideEffectsAsync(OrganizationDomain domain, IActingUser actingUser)
{
if (featureService.IsEnabled(FeatureFlagKeys.AccountDeprovisioning))
{
await EnableSingleOrganizationPolicyAsync(domain.OrganizationId, actingUser);
await SendVerifiedDomainUserEmailAsync(domain);
}
await EnableSingleOrganizationPolicyAsync(domain.OrganizationId, actingUser);
await SendVerifiedDomainUserEmailAsync(domain);
}
private async Task EnableSingleOrganizationPolicyAsync(Guid organizationId, IActingUser actingUser) =>

View File

@ -2,6 +2,7 @@
using Bit.Core.AdminConsole.OrganizationFeatures.Policies;
using Bit.Core.AdminConsole.Services;
using Bit.Core.Auth.Models.Business.Tokenables;
using Bit.Core.Auth.UserFeatures.TwoFactorAuth.Interfaces;
using Bit.Core.Billing.Enums;
using Bit.Core.Entities;
using Bit.Core.Enums;
@ -25,6 +26,7 @@ public class AcceptOrgUserCommand : IAcceptOrgUserCommand
private readonly IPolicyService _policyService;
private readonly IMailService _mailService;
private readonly IUserRepository _userRepository;
private readonly ITwoFactorIsEnabledQuery _twoFactorIsEnabledQuery;
private readonly IDataProtectorTokenFactory<OrgUserInviteTokenable> _orgUserInviteTokenDataFactory;
private readonly IFeatureService _featureService;
private readonly IPolicyRequirementQuery _policyRequirementQuery;
@ -37,6 +39,7 @@ public class AcceptOrgUserCommand : IAcceptOrgUserCommand
IPolicyService policyService,
IMailService mailService,
IUserRepository userRepository,
ITwoFactorIsEnabledQuery twoFactorIsEnabledQuery,
IDataProtectorTokenFactory<OrgUserInviteTokenable> orgUserInviteTokenDataFactory,
IFeatureService featureService,
IPolicyRequirementQuery policyRequirementQuery)
@ -49,6 +52,7 @@ public class AcceptOrgUserCommand : IAcceptOrgUserCommand
_policyService = policyService;
_mailService = mailService;
_userRepository = userRepository;
_twoFactorIsEnabledQuery = twoFactorIsEnabledQuery;
_orgUserInviteTokenDataFactory = orgUserInviteTokenDataFactory;
_featureService = featureService;
_policyRequirementQuery = policyRequirementQuery;
@ -198,7 +202,7 @@ public class AcceptOrgUserCommand : IAcceptOrgUserCommand
}
// Enforce Two Factor Authentication Policy of organization user is trying to join
if (!await userService.TwoFactorIsEnabledAsync(user))
if (!await _twoFactorIsEnabledQuery.TwoFactorIsEnabledAsync(user))
{
await ValidateTwoFactorAuthenticationPolicy(user, orgUser.OrganizationId);
}

View File

@ -24,9 +24,7 @@ public class GetOrganizationUsersClaimedStatusQuery : IGetOrganizationUsersClaim
// Users can only be claimed by an Organization that is enabled and can have organization domains
var organizationAbility = await _applicationCacheService.GetOrganizationAbilityAsync(organizationId);
// TODO: Replace "UseSso" with a new organization ability like "UseOrganizationDomains" (PM-11622).
// Verified domains were tied to SSO, so we currently check the "UseSso" organization ability.
if (organizationAbility is { Enabled: true, UseSso: true })
if (organizationAbility is { Enabled: true, UseOrganizationDomains: true })
{
// Get all organization users with claimed domains by the organization
var organizationUsersWithClaimedDomain = await _organizationUserRepository.GetManyByOrganizationWithClaimedDomainsAsync(organizationId);

View File

@ -1,4 +1,5 @@
using Bit.Core.Entities;
using Bit.Core.Exceptions;
namespace Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces;

View File

@ -1,5 +1,5 @@
using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Requests;
using Bit.Core.Models.Commands;
using Bit.Core.AdminConsole.Utilities.Commands;
namespace Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces;

View File

@ -1,4 +1,4 @@
using Bit.Core.AdminConsole.Errors;
using Bit.Core.AdminConsole.Utilities.Errors;
using Bit.Core.Exceptions;
namespace Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers.Errors;

View File

@ -1,5 +1,5 @@
using Bit.Core.AdminConsole.Errors;
using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers.Models;
using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers.Models;
using Bit.Core.AdminConsole.Utilities.Errors;
namespace Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers.Errors;

View File

@ -1,5 +1,5 @@
using Bit.Core.AdminConsole.Errors;
using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers.Models;
using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers.Models;
using Bit.Core.AdminConsole.Utilities.Errors;
namespace Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers.Errors;

View File

@ -1,5 +1,5 @@
using Bit.Core.AdminConsole.Errors;
using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers.Models;
using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers.Models;
using Bit.Core.AdminConsole.Utilities.Errors;
namespace Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers.Errors;

View File

@ -1,5 +1,5 @@
using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers.Models;
using Bit.Core.Models.Commands;
using Bit.Core.AdminConsole.Utilities.Commands;
namespace Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers;

View File

@ -1,17 +1,17 @@
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Enums.Provider;
using Bit.Core.AdminConsole.Errors;
using Bit.Core.AdminConsole.Interfaces;
using Bit.Core.AdminConsole.Models.Business;
using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers.Errors;
using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers.Models;
using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers.Validation;
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.AdminConsole.Shared.Validation;
using Bit.Core.AdminConsole.Utilities.Commands;
using Bit.Core.AdminConsole.Utilities.Errors;
using Bit.Core.AdminConsole.Utilities.Validation;
using Bit.Core.Context;
using Bit.Core.Enums;
using Bit.Core.Models.Business;
using Bit.Core.Models.Commands;
using Bit.Core.OrganizationFeatures.OrganizationSubscriptions.Interface;
using Bit.Core.Repositories;
using Bit.Core.Services;
@ -50,11 +50,11 @@ public class InviteOrganizationUsersCommand(IEventService eventService,
{
case Failure<InviteOrganizationUsersResponse> failure:
return new Failure<ScimInviteOrganizationUsersResponse>(
failure.Errors.Select(error => new Error<ScimInviteOrganizationUsersResponse>(error.Message,
new Error<ScimInviteOrganizationUsersResponse>(failure.Error.Message,
new ScimInviteOrganizationUsersResponse
{
InvitedUser = error.ErroredValue.InvitedUsers.FirstOrDefault()
})));
InvitedUser = failure.Error.ErroredValue.InvitedUsers.FirstOrDefault()
}));
case Success<InviteOrganizationUsersResponse> success when success.Value.InvitedUsers.Any():
var user = success.Value.InvitedUsers.First();

View File

@ -1,4 +1,4 @@
using Bit.Core.AdminConsole.Errors;
using Bit.Core.AdminConsole.Utilities.Errors;
namespace Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers.Validation.GlobalSettings;

View File

@ -1,4 +1,4 @@
using Bit.Core.AdminConsole.Shared.Validation;
using Bit.Core.AdminConsole.Utilities.Validation;
namespace Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers.Validation.GlobalSettings;

View File

@ -1,7 +1,7 @@
using Bit.Core.AdminConsole.Errors;
using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers.Models;
using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers.Models;
using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers.Validation.PasswordManager;
using Bit.Core.AdminConsole.Shared.Validation;
using Bit.Core.AdminConsole.Utilities.Errors;
using Bit.Core.AdminConsole.Utilities.Validation;
using Bit.Core.Models.Business;
using Bit.Core.OrganizationFeatures.OrganizationSubscriptions.Interface;
using Bit.Core.Repositories;

View File

@ -1,5 +1,5 @@
using Bit.Core.AdminConsole.Errors;
using Bit.Core.AdminConsole.Models.Business;
using Bit.Core.AdminConsole.Models.Business;
using Bit.Core.AdminConsole.Utilities.Errors;
namespace Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers.Validation.Organization;

View File

@ -1,5 +1,5 @@
using Bit.Core.AdminConsole.Models.Business;
using Bit.Core.AdminConsole.Shared.Validation;
using Bit.Core.AdminConsole.Utilities.Validation;
namespace Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers.Validation.Organization;

View File

@ -1,4 +1,4 @@
using Bit.Core.AdminConsole.Errors;
using Bit.Core.AdminConsole.Utilities.Errors;
namespace Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers.Validation.PasswordManager;

View File

@ -4,7 +4,7 @@ using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers.V
using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers.Validation.Organization;
using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers.Validation.Provider;
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.AdminConsole.Shared.Validation;
using Bit.Core.AdminConsole.Utilities.Validation;
using Bit.Core.Repositories;
using Bit.Core.Services;
using Bit.Core.Settings;

View File

@ -1,5 +1,5 @@
using Bit.Core.AdminConsole.Errors;
using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers.Validation.Models;
using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers.Validation.Models;
using Bit.Core.AdminConsole.Utilities.Errors;
namespace Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers.Validation.Payments;

View File

@ -1,6 +1,6 @@
using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers.Validation.Models;
using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers.Validation.Payments;
using Bit.Core.AdminConsole.Shared.Validation;
using Bit.Core.AdminConsole.Utilities.Validation;
using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Enums;

View File

@ -1,4 +1,4 @@
using Bit.Core.AdminConsole.Errors;
using Bit.Core.AdminConsole.Utilities.Errors;
namespace Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers.Validation.Provider;

View File

@ -1,5 +1,5 @@
using Bit.Core.AdminConsole.Enums.Provider;
using Bit.Core.AdminConsole.Shared.Validation;
using Bit.Core.AdminConsole.Utilities.Validation;
using Bit.Core.Billing.Extensions;
namespace Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers.Validation.Provider;

View File

@ -159,7 +159,7 @@ public class RemoveOrganizationUserCommand : IRemoveOrganizationUserCommand
throw new BadRequestException(RemoveAdminByCustomUserErrorMessage);
}
if (_featureService.IsEnabled(FeatureFlagKeys.AccountDeprovisioning) && deletingUserId.HasValue && eventSystemUser == null)
if (deletingUserId.HasValue && eventSystemUser == null)
{
var claimedStatus = await _getOrganizationUsersClaimedStatusQuery.GetUsersOrganizationClaimedStatusAsync(orgUser.OrganizationId, new[] { orgUser.Id });
if (claimedStatus.TryGetValue(orgUser.Id, out var isClaimed) && isClaimed)
@ -214,7 +214,7 @@ public class RemoveOrganizationUserCommand : IRemoveOrganizationUserCommand
deletingUserIsOwner = await _currentContext.OrganizationOwner(organizationId);
}
var claimedStatus = _featureService.IsEnabled(FeatureFlagKeys.AccountDeprovisioning) && deletingUserId.HasValue && eventSystemUser == null
var claimedStatus = deletingUserId.HasValue && eventSystemUser == null
? await _getOrganizationUsersClaimedStatusQuery.GetUsersOrganizationClaimedStatusAsync(organizationId, filteredUsers.Select(u => u.Id))
: filteredUsers.ToDictionary(u => u.Id, u => false);
var result = new List<(OrganizationUser OrganizationUser, string ErrorMessage)>();

View File

@ -1,8 +1,8 @@
using Bit.Core.AdminConsole.Models.Data;
using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces;
using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Requests;
using Bit.Core.AdminConsole.Utilities.Commands;
using Bit.Core.Enums;
using Bit.Core.Models.Commands;
using Bit.Core.Models.Data.Organizations.OrganizationUsers;
using Bit.Core.Repositories;
using Bit.Core.Services;

View File

@ -104,7 +104,8 @@ public class CloudOrganizationSignUpCommand(
RevisionDate = DateTime.UtcNow,
Status = OrganizationStatusType.Created,
UsePasswordManager = true,
UseSecretsManager = signup.UseSecretsManager
UseSecretsManager = signup.UseSecretsManager,
UseOrganizationDomains = plan.HasOrganizationDomains,
};
if (signup.UseSecretsManager)

View File

@ -61,16 +61,9 @@ public class SingleOrgPolicyValidator : IPolicyValidator
{
if (currentPolicy is not { Enabled: true } && policyUpdate is { Enabled: true })
{
if (_featureService.IsEnabled(FeatureFlagKeys.AccountDeprovisioning))
{
var currentUser = _currentContext.UserId ?? Guid.Empty;
var isOwnerOrProvider = await _currentContext.OrganizationOwner(policyUpdate.OrganizationId);
await RevokeNonCompliantUsersAsync(policyUpdate.OrganizationId, policyUpdate.PerformedBy ?? new StandardUser(currentUser, isOwnerOrProvider));
}
else
{
await RemoveNonCompliantUsersAsync(policyUpdate.OrganizationId);
}
var currentUser = _currentContext.UserId ?? Guid.Empty;
var isOwnerOrProvider = await _currentContext.OrganizationOwner(policyUpdate.OrganizationId);
await RevokeNonCompliantUsersAsync(policyUpdate.OrganizationId, policyUpdate.PerformedBy ?? new StandardUser(currentUser, isOwnerOrProvider));
}
}
@ -116,42 +109,6 @@ public class SingleOrgPolicyValidator : IPolicyValidator
_mailService.SendOrganizationUserRevokedForPolicySingleOrgEmailAsync(organization.DisplayName(), x.Email)));
}
private async Task RemoveNonCompliantUsersAsync(Guid organizationId)
{
// Remove non-compliant users
var savingUserId = _currentContext.UserId;
// Note: must get OrganizationUserUserDetails so that Email is always populated from the User object
var orgUsers = await _organizationUserRepository.GetManyDetailsByOrganizationAsync(organizationId);
var org = await _organizationRepository.GetByIdAsync(organizationId);
if (org == null)
{
throw new NotFoundException(OrganizationNotFoundErrorMessage);
}
var removableOrgUsers = orgUsers.Where(ou =>
ou.Status != OrganizationUserStatusType.Invited &&
ou.Status != OrganizationUserStatusType.Revoked &&
ou.Type != OrganizationUserType.Owner &&
ou.Type != OrganizationUserType.Admin &&
ou.UserId != savingUserId
).ToList();
var userOrgs = await _organizationUserRepository.GetManyByManyUsersAsync(
removableOrgUsers.Select(ou => ou.UserId!.Value));
foreach (var orgUser in removableOrgUsers)
{
if (userOrgs.Any(ou => ou.UserId == orgUser.UserId
&& ou.OrganizationId != org.Id
&& ou.Status != OrganizationUserStatusType.Invited))
{
await _removeOrganizationUserCommand.RemoveUserAsync(organizationId, orgUser.Id, savingUserId);
await _mailService.SendOrganizationUserRemovedForPolicySingleOrgEmailAsync(
org.DisplayName(), orgUser.Email);
}
}
}
public async Task<string> ValidateAsync(PolicyUpdate policyUpdate, Policy? currentPolicy)
{
if (policyUpdate is not { Enabled: true })
@ -165,8 +122,7 @@ public class SingleOrgPolicyValidator : IPolicyValidator
return validateDecryptionErrorMessage;
}
if (_featureService.IsEnabled(FeatureFlagKeys.AccountDeprovisioning)
&& await _organizationHasVerifiedDomainsQuery.HasVerifiedDomainsAsync(policyUpdate.OrganizationId))
if (await _organizationHasVerifiedDomainsQuery.HasVerifiedDomainsAsync(policyUpdate.OrganizationId))
{
return ClaimedDomainSingleOrganizationRequiredErrorMessage;
}

View File

@ -23,8 +23,6 @@ public class TwoFactorAuthenticationPolicyValidator : IPolicyValidator
private readonly IOrganizationRepository _organizationRepository;
private readonly ICurrentContext _currentContext;
private readonly ITwoFactorIsEnabledQuery _twoFactorIsEnabledQuery;
private readonly IRemoveOrganizationUserCommand _removeOrganizationUserCommand;
private readonly IFeatureService _featureService;
private readonly IRevokeNonCompliantOrganizationUserCommand _revokeNonCompliantOrganizationUserCommand;
public const string NonCompliantMembersWillLoseAccessMessage = "Policy could not be enabled. Non-compliant members will lose access to their accounts. Identify members without two-step login from the policies column in the members page.";
@ -38,8 +36,6 @@ public class TwoFactorAuthenticationPolicyValidator : IPolicyValidator
IOrganizationRepository organizationRepository,
ICurrentContext currentContext,
ITwoFactorIsEnabledQuery twoFactorIsEnabledQuery,
IRemoveOrganizationUserCommand removeOrganizationUserCommand,
IFeatureService featureService,
IRevokeNonCompliantOrganizationUserCommand revokeNonCompliantOrganizationUserCommand)
{
_organizationUserRepository = organizationUserRepository;
@ -47,8 +43,6 @@ public class TwoFactorAuthenticationPolicyValidator : IPolicyValidator
_organizationRepository = organizationRepository;
_currentContext = currentContext;
_twoFactorIsEnabledQuery = twoFactorIsEnabledQuery;
_removeOrganizationUserCommand = removeOrganizationUserCommand;
_featureService = featureService;
_revokeNonCompliantOrganizationUserCommand = revokeNonCompliantOrganizationUserCommand;
}
@ -56,16 +50,9 @@ public class TwoFactorAuthenticationPolicyValidator : IPolicyValidator
{
if (currentPolicy is not { Enabled: true } && policyUpdate is { Enabled: true })
{
if (_featureService.IsEnabled(FeatureFlagKeys.AccountDeprovisioning))
{
var currentUser = _currentContext.UserId ?? Guid.Empty;
var isOwnerOrProvider = await _currentContext.OrganizationOwner(policyUpdate.OrganizationId);
await RevokeNonCompliantUsersAsync(policyUpdate.OrganizationId, policyUpdate.PerformedBy ?? new StandardUser(currentUser, isOwnerOrProvider));
}
else
{
await RemoveNonCompliantUsersAsync(policyUpdate.OrganizationId);
}
var currentUser = _currentContext.UserId ?? Guid.Empty;
var isOwnerOrProvider = await _currentContext.OrganizationOwner(policyUpdate.OrganizationId);
await RevokeNonCompliantUsersAsync(policyUpdate.OrganizationId, policyUpdate.PerformedBy ?? new StandardUser(currentUser, isOwnerOrProvider));
}
}
@ -121,40 +108,6 @@ public class TwoFactorAuthenticationPolicyValidator : IPolicyValidator
_mailService.SendOrganizationUserRevokedForTwoFactorPolicyEmailAsync(organization.DisplayName(), x.Email)));
}
private async Task RemoveNonCompliantUsersAsync(Guid organizationId)
{
var org = await _organizationRepository.GetByIdAsync(organizationId);
var savingUserId = _currentContext.UserId;
var orgUsers = await _organizationUserRepository.GetManyDetailsByOrganizationAsync(organizationId);
var organizationUsersTwoFactorEnabled = await _twoFactorIsEnabledQuery.TwoFactorIsEnabledAsync(orgUsers);
var removableOrgUsers = orgUsers.Where(ou =>
ou.Status != OrganizationUserStatusType.Invited && ou.Status != OrganizationUserStatusType.Revoked &&
ou.Type != OrganizationUserType.Owner && ou.Type != OrganizationUserType.Admin &&
ou.UserId != savingUserId);
// Reorder by HasMasterPassword to prioritize checking users without a master if they have 2FA enabled
foreach (var orgUser in removableOrgUsers.OrderBy(ou => ou.HasMasterPassword))
{
var userTwoFactorEnabled = organizationUsersTwoFactorEnabled.FirstOrDefault(u => u.user.Id == orgUser.Id)
.twoFactorIsEnabled;
if (!userTwoFactorEnabled)
{
if (!orgUser.HasMasterPassword)
{
throw new BadRequestException(
"Policy could not be enabled. Non-compliant members will lose access to their accounts. Identify members without two-step login from the policies column in the members page.");
}
await _removeOrganizationUserCommand.RemoveUserAsync(organizationId, orgUser.Id,
savingUserId);
await _mailService.SendOrganizationUserRemovedForPolicyTwoStepEmailAsync(
org!.DisplayName(), orgUser.Email);
}
}
}
private static bool MembersWithNoMasterPasswordWillLoseAccess(
IEnumerable<OrganizationUserUserDetails> orgUserDetails,
IEnumerable<(OrganizationUserUserDetails user, bool isTwoFactorEnabled)> organizationUsersTwoFactorEnabled) =>

View File

@ -11,8 +11,6 @@ namespace Bit.Core.Services;
public interface IOrganizationService
{
Task ReplacePaymentMethodAsync(Guid organizationId, string paymentToken, PaymentMethodType paymentMethodType,
TaxInfo taxInfo);
Task CancelSubscriptionAsync(Guid organizationId, bool? endOfPeriod = null);
Task ReinstateSubscriptionAsync(Guid organizationId);
Task<string> AdjustStorageAsync(Guid organizationId, short storageAdjustmentGb);

View File

@ -93,16 +93,8 @@ public class OrganizationDomainService : IOrganizationDomainService
//Send email to administrators
if (adminEmails.Count > 0)
{
if (_featureService.IsEnabled(FeatureFlagKeys.AccountDeprovisioning))
{
await _mailService.SendUnclaimedOrganizationDomainEmailAsync(adminEmails,
domain.OrganizationId.ToString(), domain.DomainName);
}
else
{
await _mailService.SendUnverifiedOrganizationDomainEmailAsync(adminEmails,
domain.OrganizationId.ToString(), domain.DomainName);
}
await _mailService.SendUnclaimedOrganizationDomainEmailAsync(adminEmails,
domain.OrganizationId.ToString(), domain.DomainName);
}
_logger.LogInformation(Constants.BypassFiltersEventId, "Expired domain: {domainName}", domain.DomainName);

View File

@ -144,27 +144,6 @@ public class OrganizationService : IOrganizationService
_sendOrganizationInvitesCommand = sendOrganizationInvitesCommand;
}
public async Task ReplacePaymentMethodAsync(Guid organizationId, string paymentToken,
PaymentMethodType paymentMethodType, TaxInfo taxInfo)
{
var organization = await GetOrgById(organizationId);
if (organization == null)
{
throw new NotFoundException();
}
await _paymentService.SaveTaxInfoAsync(organization, taxInfo);
var updated = await _paymentService.UpdatePaymentMethodAsync(
organization,
paymentMethodType,
paymentToken,
taxInfo);
if (updated)
{
await ReplaceAndUpdateCacheAsync(organization);
}
}
public async Task CancelSubscriptionAsync(Guid organizationId, bool? endOfPeriod = null)
{
var organization = await GetOrgById(organizationId);
@ -449,6 +428,7 @@ public class OrganizationService : IOrganizationService
MaxStorageGb = 1,
UsePolicies = plan.HasPolicies,
UseSso = plan.HasSso,
UseOrganizationDomains = plan.HasOrganizationDomains,
UseGroups = plan.HasGroups,
UseEvents = plan.HasEvents,
UseDirectory = plan.HasDirectory,
@ -570,6 +550,8 @@ public class OrganizationService : IOrganizationService
SmSeats = license.SmSeats,
SmServiceAccounts = license.SmServiceAccounts,
UseRiskInsights = license.UseRiskInsights,
UseOrganizationDomains = license.UseOrganizationDomains,
UseAdminSponsoredFamilies = license.UseAdminSponsoredFamilies,
};
var result = await SignUpAsync(organization, owner.Id, ownerKey, collectionName, false);

View File

@ -1,44 +0,0 @@
using Bit.Core.AdminConsole.Errors;
namespace Bit.Core.AdminConsole.Shared.Validation;
public abstract record ValidationResult<T>;
public record Valid<T> : ValidationResult<T>
{
public Valid() { }
public Valid(T Value)
{
this.Value = Value;
}
public T Value { get; init; }
}
public record Invalid<T> : ValidationResult<T>
{
public IEnumerable<Error<T>> Errors { get; init; } = [];
public string ErrorMessageString => string.Join(" ", Errors.Select(e => e.Message));
public Invalid() { }
public Invalid(Error<T> error) : this([error]) { }
public Invalid(IEnumerable<Error<T>> errors)
{
Errors = errors;
}
}
public static class ValidationResultMappers
{
public static ValidationResult<B> Map<A, B>(this ValidationResult<A> validationResult, B invalidValue) =>
validationResult switch
{
Valid<A> => new Valid<B>(invalidValue),
Invalid<A> invalid => new Invalid<B>(invalid.Errors.Select(x => x.ToError(invalidValue))),
_ => throw new ArgumentOutOfRangeException(nameof(validationResult), "Unhandled validation result type")
};
}

View File

@ -0,0 +1,51 @@
#nullable enable
using Bit.Core.AdminConsole.Utilities.Errors;
using Bit.Core.AdminConsole.Utilities.Validation;
namespace Bit.Core.AdminConsole.Utilities.Commands;
public abstract class CommandResult<T>;
public class Success<T>(T value) : CommandResult<T>
{
public T Value { get; } = value;
}
public class Failure<T>(Error<T> error) : CommandResult<T>
{
public Error<T> Error { get; } = error;
}
public class Partial<T>(IEnumerable<T> successfulItems, IEnumerable<Error<T>> failedItems)
: CommandResult<T>
{
public IEnumerable<T> Successes { get; } = successfulItems;
public IEnumerable<Error<T>> Failures { get; } = failedItems;
}
public static class CommandResultExtensions
{
/// <summary>
/// This is to help map between the InvalidT ValidationResult and the FailureT CommandResult types.
///
/// </summary>
/// <param name="invalidResult">This is the invalid type from validating the object.</param>
/// <param name="mappingFunction">This function will map between the two types for the inner ErrorT</param>
/// <typeparam name="A">Invalid object's type</typeparam>
/// <typeparam name="B">Failure object's type</typeparam>
/// <returns></returns>
public static CommandResult<B> MapToFailure<A, B>(this Invalid<A> invalidResult, Func<A, B> mappingFunction) =>
new Failure<B>(invalidResult.Error.ToError(mappingFunction(invalidResult.Error.ErroredValue)));
}
[Obsolete("Use CommandResult<T> instead. This will be removed once old code is updated.")]
public class CommandResult(IEnumerable<string> errors)
{
public CommandResult(string error) : this([error]) { }
public bool Success => ErrorMessages.Count == 0;
public bool HasErrors => ErrorMessages.Count > 0;
public List<string> ErrorMessages { get; } = errors.ToList();
public CommandResult() : this(Array.Empty<string>()) { }
}

View File

@ -1,4 +1,4 @@
namespace Bit.Core.AdminConsole.Errors;
namespace Bit.Core.AdminConsole.Utilities.Errors;
public record Error<T>(string Message, T ErroredValue);

View File

@ -1,4 +1,4 @@
namespace Bit.Core.AdminConsole.Errors;
namespace Bit.Core.AdminConsole.Utilities.Errors;
public record InsufficientPermissionsError<T>(string Message, T ErroredValue) : Error<T>(Message, ErroredValue)
{

View File

@ -1,4 +1,4 @@
namespace Bit.Core.AdminConsole.Errors;
namespace Bit.Core.AdminConsole.Utilities.Errors;
public record InvalidResultTypeError<T>(T Value) : Error<T>(Code, Value)
{

View File

@ -1,4 +1,4 @@
namespace Bit.Core.AdminConsole.Errors;
namespace Bit.Core.AdminConsole.Utilities.Errors;
public record RecordNotFoundError<T>(string Message, T ErroredValue) : Error<T>(Message, ErroredValue)
{

View File

@ -1,4 +1,4 @@
namespace Bit.Core.AdminConsole.Shared.Validation;
namespace Bit.Core.AdminConsole.Utilities.Validation;
public interface IValidator<T>
{

View File

@ -0,0 +1,20 @@
using Bit.Core.AdminConsole.Utilities.Errors;
namespace Bit.Core.AdminConsole.Utilities.Validation;
public abstract record ValidationResult<T>;
public record Valid<T>(T Value) : ValidationResult<T>;
public record Invalid<T>(Error<T> Error) : ValidationResult<T>;
public static class ValidationResultMappers
{
public static ValidationResult<B> Map<A, B>(this ValidationResult<A> validationResult, B invalidValue) =>
validationResult switch
{
Valid<A> => new Valid<B>(invalidValue),
Invalid<A> invalid => new Invalid<B>(invalid.Error.ToError(invalidValue)),
_ => throw new ArgumentOutOfRangeException(nameof(validationResult), "Unhandled validation result type")
};
}

View File

@ -2,9 +2,24 @@
public enum EmergencyAccessStatusType : byte
{
/// <summary>
/// The user has been invited to be an emergency contact.
/// </summary>
Invited = 0,
/// <summary>
/// The invited user, "grantee", has accepted the request to be an emergency contact.
/// </summary>
Accepted = 1,
/// <summary>
/// The inviting user, "grantor", has approved the grantee's acceptance.
/// </summary>
Confirmed = 2,
/// <summary>
/// The grantee has initiated the recovery process.
/// </summary>
RecoveryInitiated = 3,
/// <summary>
/// The grantee has excercised their emergency access.
/// </summary>
RecoveryApproved = 4,
}

View File

@ -6,7 +6,8 @@ public enum TwoFactorProviderType : byte
Email = 1,
Duo = 2,
YubiKey = 3,
U2f = 4, // Deprecated
[Obsolete("Deprecated in favor of WebAuthn.")]
U2f = 4,
Remember = 5,
OrganizationDuo = 6,
WebAuthn = 7,

View File

@ -1,6 +1,5 @@
using Bit.Core.Auth.Enums;
using Bit.Core.Entities;
using Bit.Core.Services;
using Microsoft.AspNetCore.Identity;
using Microsoft.Extensions.Caching.Distributed;
using Microsoft.Extensions.DependencyInjection;
@ -12,16 +11,13 @@ public class AuthenticatorTokenProvider : IUserTwoFactorTokenProvider<User>
{
private const string CacheKeyFormat = "Authenticator_TOTP_{0}_{1}";
private readonly IServiceProvider _serviceProvider;
private readonly IDistributedCache _distributedCache;
private readonly DistributedCacheEntryOptions _distributedCacheEntryOptions;
public AuthenticatorTokenProvider(
IServiceProvider serviceProvider,
[FromKeyedServices("persistent")]
IDistributedCache distributedCache)
{
_serviceProvider = serviceProvider;
_distributedCache = distributedCache;
_distributedCacheEntryOptions = new DistributedCacheEntryOptions
{
@ -29,15 +25,14 @@ public class AuthenticatorTokenProvider : IUserTwoFactorTokenProvider<User>
};
}
public async Task<bool> CanGenerateTwoFactorTokenAsync(UserManager<User> manager, User user)
public Task<bool> CanGenerateTwoFactorTokenAsync(UserManager<User> manager, User user)
{
var provider = user.GetTwoFactorProvider(TwoFactorProviderType.Authenticator);
if (string.IsNullOrWhiteSpace((string)provider?.MetaData["Key"]))
var authenticatorProvider = user.GetTwoFactorProvider(TwoFactorProviderType.Authenticator);
if (string.IsNullOrWhiteSpace((string)authenticatorProvider?.MetaData["Key"]))
{
return false;
return Task.FromResult(false);
}
return await _serviceProvider.GetRequiredService<IUserService>()
.TwoFactorProviderIsEnabledAsync(TwoFactorProviderType.Authenticator, user);
return Task.FromResult(authenticatorProvider.Enabled);
}
public Task<string> GenerateAsync(string purpose, UserManager<User> manager, User user)

View File

@ -16,10 +16,11 @@ public class DuoUniversalTokenProvider(
IDuoUniversalTokenService duoUniversalTokenService) : IUserTwoFactorTokenProvider<User>
{
/// <summary>
/// We need the IServiceProvider to resolve the IUserService. There is a complex dependency dance
/// occurring between IUserService, which extends the UserManager<User>, and the usage of the
/// UserManager<User> within this class. Trying to resolve the IUserService using the DI pipeline
/// will not allow the server to start and it will hang and give no helpful indication as to the problem.
/// We need the IServiceProvider to resolve the <see cref="IUserService"/>. There is a complex dependency dance
/// occurring between <see cref="IUserService"/>, which extends the <see cref="UserManager{User}"/>, and the usage
/// of the <see cref="UserManager{User}"/> within this class. Trying to resolve the <see cref="IUserService"/> using
/// the DI pipeline will not allow the server to start and it will hang and give no helpful indication as to the
/// problem.
/// </summary>
private readonly IServiceProvider _serviceProvider = serviceProvider;
private readonly IDataProtectorTokenFactory<DuoUserStateTokenable> _tokenDataFactory = tokenDataFactory;
@ -28,12 +29,13 @@ public class DuoUniversalTokenProvider(
public async Task<bool> CanGenerateTwoFactorTokenAsync(UserManager<User> manager, User user)
{
var userService = _serviceProvider.GetRequiredService<IUserService>();
var provider = await GetDuoTwoFactorProvider(user, userService);
if (provider == null)
var duoUniversalTokenProvider = await GetDuoTwoFactorProvider(user, userService);
if (duoUniversalTokenProvider == null)
{
return false;
}
return await userService.TwoFactorProviderIsEnabledAsync(TwoFactorProviderType.Duo, user);
return duoUniversalTokenProvider.Enabled;
}
public async Task<string> GenerateAsync(string purpose, UserManager<User> manager, User user)
@ -57,7 +59,7 @@ public class DuoUniversalTokenProvider(
}
/// <summary>
/// Get the Duo Two Factor Provider for the user if they have access to Duo
/// Get the Duo Two Factor Provider for the user if they have premium access to Duo
/// </summary>
/// <param name="user">Active User</param>
/// <returns>null or Duo TwoFactorProvider</returns>

View File

@ -1,7 +1,6 @@
using Bit.Core.Auth.Enums;
using Bit.Core.Auth.Models;
using Bit.Core.Entities;
using Bit.Core.Services;
using Microsoft.AspNetCore.Identity;
using Microsoft.Extensions.Caching.Distributed;
using Microsoft.Extensions.DependencyInjection;
@ -10,31 +9,25 @@ namespace Bit.Core.Auth.Identity.TokenProviders;
public class EmailTwoFactorTokenProvider : EmailTokenProvider
{
private readonly IServiceProvider _serviceProvider;
public EmailTwoFactorTokenProvider(
IServiceProvider serviceProvider,
[FromKeyedServices("persistent")]
IDistributedCache distributedCache) :
base(distributedCache)
{
_serviceProvider = serviceProvider;
TokenAlpha = false;
TokenNumeric = true;
TokenLength = 6;
}
public override async Task<bool> CanGenerateTwoFactorTokenAsync(UserManager<User> manager, User user)
public override Task<bool> CanGenerateTwoFactorTokenAsync(UserManager<User> manager, User user)
{
var provider = user.GetTwoFactorProvider(TwoFactorProviderType.Email);
if (!HasProperMetaData(provider))
var emailTokenProvider = user.GetTwoFactorProvider(TwoFactorProviderType.Email);
if (!HasProperMetaData(emailTokenProvider))
{
return false;
return Task.FromResult(false);
}
return await _serviceProvider.GetRequiredService<IUserService>().
TwoFactorProviderIsEnabledAsync(TwoFactorProviderType.Email, user);
return Task.FromResult(emailTokenProvider.Enabled);
}
public override Task<string> GenerateAsync(string purpose, UserManager<User> manager, User user)

View File

@ -25,17 +25,16 @@ public class WebAuthnTokenProvider : IUserTwoFactorTokenProvider<User>
_globalSettings = globalSettings;
}
public async Task<bool> CanGenerateTwoFactorTokenAsync(UserManager<User> manager, User user)
public Task<bool> CanGenerateTwoFactorTokenAsync(UserManager<User> manager, User user)
{
var userService = _serviceProvider.GetRequiredService<IUserService>();
var webAuthnProvider = user.GetTwoFactorProvider(TwoFactorProviderType.WebAuthn);
// null check happens in this method
if (!HasProperMetaData(webAuthnProvider))
{
return false;
return Task.FromResult(false);
}
return await userService.TwoFactorProviderIsEnabledAsync(TwoFactorProviderType.WebAuthn, user);
return Task.FromResult(webAuthnProvider.Enabled);
}
public async Task<string> GenerateAsync(string purpose, UserManager<User> manager, User user)
@ -81,7 +80,7 @@ public class WebAuthnTokenProvider : IUserTwoFactorTokenProvider<User>
var provider = user.GetTwoFactorProvider(TwoFactorProviderType.WebAuthn);
var keys = LoadKeys(provider);
if (!provider.MetaData.ContainsKey("login"))
if (!provider.MetaData.TryGetValue("login", out var value))
{
return false;
}
@ -89,7 +88,7 @@ public class WebAuthnTokenProvider : IUserTwoFactorTokenProvider<User>
var clientResponse = JsonSerializer.Deserialize<AuthenticatorAssertionRawResponse>(token,
new JsonSerializerOptions { PropertyNameCaseInsensitive = true });
var jsonOptions = provider.MetaData["login"].ToString();
var jsonOptions = value.ToString();
var options = AssertionOptions.FromJson(jsonOptions);
var webAuthCred = keys.Find(k => k.Item2.Descriptor.Id.SequenceEqual(clientResponse.Id));
@ -126,6 +125,12 @@ public class WebAuthnTokenProvider : IUserTwoFactorTokenProvider<User>
}
/// <summary>
/// Checks if the provider has proper metadata.
/// This is used to determine if the provider has been properly configured.
/// </summary>
/// <param name="provider"></param>
/// <returns>true if metadata is present; false if empty or null</returns>
private bool HasProperMetaData(TwoFactorProvider provider)
{
return provider?.MetaData?.Any() ?? false;

View File

@ -23,19 +23,21 @@ public class YubicoOtpTokenProvider : IUserTwoFactorTokenProvider<User>
public async Task<bool> CanGenerateTwoFactorTokenAsync(UserManager<User> manager, User user)
{
// Ensure the user has access to premium
var userService = _serviceProvider.GetRequiredService<IUserService>();
if (!await userService.CanAccessPremium(user))
{
return false;
}
var provider = user.GetTwoFactorProvider(TwoFactorProviderType.YubiKey);
if (!provider?.MetaData.Values.Any(v => !string.IsNullOrWhiteSpace((string)v)) ?? true)
// Check if the user has a YubiKey provider configured
var yubicoProvider = user.GetTwoFactorProvider(TwoFactorProviderType.YubiKey);
if (!yubicoProvider?.MetaData.Values.Any(v => !string.IsNullOrWhiteSpace((string)v)) ?? true)
{
return false;
}
return await userService.TwoFactorProviderIsEnabledAsync(TwoFactorProviderType.YubiKey, user);
return yubicoProvider.Enabled;
}
public Task<string> GenerateAsync(string purpose, UserManager<User> manager, User user)

View File

@ -1,7 +1,7 @@
using Bit.Core.Context;
using Bit.Core.Auth.UserFeatures.TwoFactorAuth.Interfaces;
using Bit.Core.Context;
using Bit.Core.Entities;
using Bit.Core.Repositories;
using Bit.Core.Services;
using Microsoft.AspNetCore.Identity;
using Microsoft.Extensions.DependencyInjection;
@ -167,7 +167,7 @@ public class UserStore :
public async Task<bool> GetTwoFactorEnabledAsync(User user, CancellationToken cancellationToken)
{
return await _serviceProvider.GetRequiredService<IUserService>().TwoFactorIsEnabledAsync(user);
return await _serviceProvider.GetRequiredService<ITwoFactorIsEnabledQuery>().TwoFactorIsEnabledAsync(user);
}
public Task SetSecurityStampAsync(User user, string stamp, CancellationToken cancellationToken)

View File

@ -1,6 +0,0 @@
namespace Bit.Core.Auth.Models.Api;
public interface ICaptchaProtectedModel
{
string CaptchaResponse { get; set; }
}

View File

@ -1,9 +0,0 @@
namespace Bit.Core.Auth.Models.Business;
public class CaptchaResponse
{
public bool Success { get; set; }
public bool MaybeBot { get; set; }
public bool IsBot { get; set; }
public double Score { get; set; }
}

View File

@ -1,43 +0,0 @@
using System.Text.Json.Serialization;
using Bit.Core.Entities;
using Bit.Core.Tokens;
namespace Bit.Core.Auth.Models.Business.Tokenables;
public class HCaptchaTokenable : ExpiringTokenable
{
private const double _tokenLifetimeInHours = (double)5 / 60; // 5 minutes
public const string ClearTextPrefix = "BWCaptchaBypass_";
public const string DataProtectorPurpose = "CaptchaServiceDataProtector";
public const string TokenIdentifier = "CaptchaBypassToken";
public string Identifier { get; set; } = TokenIdentifier;
public Guid Id { get; set; }
public string Email { get; set; }
[JsonConstructor]
public HCaptchaTokenable()
{
ExpirationDate = DateTime.UtcNow.AddHours(_tokenLifetimeInHours);
}
public HCaptchaTokenable(User user) : this()
{
Id = user?.Id ?? default;
Email = user?.Email;
}
public bool TokenIsValid(User user)
{
if (Id == default || Email == default || user == null)
{
return false;
}
return Id == user.Id &&
Email.Equals(user.Email, StringComparison.InvariantCultureIgnoreCase);
}
// Validates deserialized
protected override bool TokenIsValid() => Identifier == TokenIdentifier && Id != default && !string.IsNullOrWhiteSpace(Email);
}

View File

@ -4,9 +4,10 @@ using Bit.Core.Tokens;
namespace Bit.Core.Auth.Models.Business.Tokenables;
// This token just provides a verifiable authN mechanism for the API service
// TwoFactorController.cs SendEmailLogin anonymous endpoint so it cannot be
// used maliciously.
/// <summary>
/// This token provides a verifiable authN mechanism for the TwoFactorController.SendEmailLoginAsync
/// anonymous endpoint so it cannot used maliciously.
/// </summary>
public class SsoEmail2faSessionTokenable : ExpiringTokenable
{
// Just over 2 min expiration (client expires session after 2 min)

View File

@ -1,10 +1,18 @@
using Bit.Core.Auth.Enums;
using Bit.Core.Services;
namespace Bit.Core.Auth.Models;
public interface ITwoFactorProvidersUser
{
string TwoFactorProviders { get; }
/// <summary>
/// Get the two factor providers for the user. Currently it can be assumed providers are enabled
/// if they exists in the dictionary. When two factor providers are disabled they are removed
/// from the dictionary. <see cref="IUserService.DisableTwoFactorProviderAsync"/>
/// <see cref="IOrganizationService.DisableTwoFactorProviderAsync"/>
/// </summary>
/// <returns>Dictionary of providers with the type enum as the key</returns>
Dictionary<TwoFactorProviderType, TwoFactorProvider> GetTwoFactorProviders();
Guid? GetUserId();
bool GetPremium();

View File

@ -1,15 +0,0 @@
using Bit.Core.Auth.Models.Business;
using Bit.Core.Context;
using Bit.Core.Entities;
namespace Bit.Core.Auth.Services;
public interface ICaptchaValidationService
{
string SiteKey { get; }
string SiteKeyResponseKeyName { get; }
bool RequireCaptchaValidation(ICurrentContext currentContext, User user = null);
Task<CaptchaResponse> ValidateCaptchaResponseAsync(string captchResponse, string clientIpAddress,
User user = null);
string GenerateCaptchaBypassToken(User user);
}

View File

@ -3,6 +3,7 @@ using Bit.Core.Auth.Entities;
using Bit.Core.Auth.Enums;
using Bit.Core.Auth.Models.Data;
using Bit.Core.Entities;
using Bit.Core.Enums;
using Bit.Core.Services;
using Bit.Core.Vault.Models.Data;
@ -20,6 +21,15 @@ public interface IEmergencyAccessService
Task InitiateAsync(Guid id, User initiatingUser);
Task ApproveAsync(Guid id, User approvingUser);
Task RejectAsync(Guid id, User rejectingUser);
/// <summary>
/// This request is made by the Grantee user to fetch the policies <see cref="Policy"/> for the Grantor User.
/// The Grantor User has to be the owner of the organization. <see cref="OrganizationUserType"/>
/// If the Grantor user has OrganizationUserType.Owner then the policies for the _Grantor_ user
/// are returned.
/// </summary>
/// <param name="id">EmergencyAccess.Id being acted on</param>
/// <param name="requestingUser">User making the request, this is the Grantee</param>
/// <returns>null if the GrantorUser is not an organization owner; A list of policies otherwise.</returns>
Task<ICollection<Policy>> GetPoliciesAsync(Guid id, User requestingUser);
Task<(EmergencyAccess, User)> TakeoverAsync(Guid id, User initiatingUser);
Task PasswordAsync(Guid id, User user, string newMasterPasswordHash, string key);

View File

@ -3,7 +3,6 @@ using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces;
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Auth.Entities;
using Bit.Core.Auth.Enums;
using Bit.Core.Auth.Models;
using Bit.Core.Auth.Models.Business.Tokenables;
using Bit.Core.Auth.Models.Data;
using Bit.Core.Entities;
@ -16,7 +15,6 @@ using Bit.Core.Tokens;
using Bit.Core.Vault.Models.Data;
using Bit.Core.Vault.Repositories;
using Bit.Core.Vault.Services;
using Microsoft.AspNetCore.Identity;
namespace Bit.Core.Auth.Services;
@ -31,8 +29,6 @@ public class EmergencyAccessService : IEmergencyAccessService
private readonly IMailService _mailService;
private readonly IUserService _userService;
private readonly GlobalSettings _globalSettings;
private readonly IPasswordHasher<User> _passwordHasher;
private readonly IOrganizationService _organizationService;
private readonly IDataProtectorTokenFactory<EmergencyAccessInviteTokenable> _dataProtectorTokenizer;
private readonly IRemoveOrganizationUserCommand _removeOrganizationUserCommand;
@ -45,9 +41,7 @@ public class EmergencyAccessService : IEmergencyAccessService
ICipherService cipherService,
IMailService mailService,
IUserService userService,
IPasswordHasher<User> passwordHasher,
GlobalSettings globalSettings,
IOrganizationService organizationService,
IDataProtectorTokenFactory<EmergencyAccessInviteTokenable> dataProtectorTokenizer,
IRemoveOrganizationUserCommand removeOrganizationUserCommand)
{
@ -59,9 +53,7 @@ public class EmergencyAccessService : IEmergencyAccessService
_cipherService = cipherService;
_mailService = mailService;
_userService = userService;
_passwordHasher = passwordHasher;
_globalSettings = globalSettings;
_organizationService = organizationService;
_dataProtectorTokenizer = dataProtectorTokenizer;
_removeOrganizationUserCommand = removeOrganizationUserCommand;
}
@ -126,7 +118,12 @@ public class EmergencyAccessService : IEmergencyAccessService
throw new BadRequestException("Emergency Access not valid.");
}
if (!_dataProtectorTokenizer.TryUnprotect(token, out var data) && data.IsValid(emergencyAccessId, user.Email))
if (!_dataProtectorTokenizer.TryUnprotect(token, out var data))
{
throw new BadRequestException("Invalid token.");
}
if (!data.IsValid(emergencyAccessId, user.Email))
{
throw new BadRequestException("Invalid token.");
}
@ -140,6 +137,8 @@ public class EmergencyAccessService : IEmergencyAccessService
throw new BadRequestException("Invitation already accepted.");
}
// TODO PM-21687
// Might not be reachable since the Tokenable.IsValid() does an email comparison
if (string.IsNullOrWhiteSpace(emergencyAccess.Email) ||
!emergencyAccess.Email.Equals(user.Email, StringComparison.InvariantCultureIgnoreCase))
{
@ -163,6 +162,8 @@ public class EmergencyAccessService : IEmergencyAccessService
public async Task DeleteAsync(Guid emergencyAccessId, Guid grantorId)
{
var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(emergencyAccessId);
// TODO PM-19438/PM-21687
// Not sure why the GrantorId and the GranteeId are supposed to be the same?
if (emergencyAccess == null || (emergencyAccess.GrantorId != grantorId && emergencyAccess.GranteeId != grantorId))
{
throw new BadRequestException("Emergency Access not valid.");
@ -171,9 +172,9 @@ public class EmergencyAccessService : IEmergencyAccessService
await _emergencyAccessRepository.DeleteAsync(emergencyAccess);
}
public async Task<EmergencyAccess> ConfirmUserAsync(Guid emergencyAcccessId, string key, Guid confirmingUserId)
public async Task<EmergencyAccess> ConfirmUserAsync(Guid emergencyAccessId, string key, Guid confirmingUserId)
{
var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(emergencyAcccessId);
var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(emergencyAccessId);
if (emergencyAccess == null || emergencyAccess.Status != EmergencyAccessStatusType.Accepted ||
emergencyAccess.GrantorId != confirmingUserId)
{
@ -224,7 +225,6 @@ public class EmergencyAccessService : IEmergencyAccessService
public async Task InitiateAsync(Guid id, User initiatingUser)
{
var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(id);
if (emergencyAccess == null || emergencyAccess.GranteeId != initiatingUser.Id ||
emergencyAccess.Status != EmergencyAccessStatusType.Confirmed)
{
@ -285,6 +285,9 @@ public class EmergencyAccessService : IEmergencyAccessService
public async Task<ICollection<Policy>> GetPoliciesAsync(Guid id, User requestingUser)
{
// TODO PM-21687
// Should we look up policies here or just verify the EmergencyAccess is correct
// and handle policy logic else where? Should this be a query/Command?
var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(id);
if (!IsValidRequest(emergencyAccess, requestingUser, EmergencyAccessType.Takeover))
@ -295,7 +298,9 @@ public class EmergencyAccessService : IEmergencyAccessService
var grantor = await _userRepository.GetByIdAsync(emergencyAccess.GrantorId);
var grantorOrganizations = await _organizationUserRepository.GetManyByUserAsync(grantor.Id);
var isOrganizationOwner = grantorOrganizations.Any<OrganizationUser>(organization => organization.Type == OrganizationUserType.Owner);
var isOrganizationOwner = grantorOrganizations
.Any(organization => organization.Type == OrganizationUserType.Owner);
var policies = isOrganizationOwner ? await _policyRepository.GetManyByUserIdAsync(grantor.Id) : null;
return policies;
@ -311,7 +316,8 @@ public class EmergencyAccessService : IEmergencyAccessService
}
var grantor = await _userRepository.GetByIdAsync(emergencyAccess.GrantorId);
// TODO PM-21687
// Redundant check of the EmergencyAccessType -> checked in IsValidRequest() ln 308
if (emergencyAccess.Type == EmergencyAccessType.Takeover && grantor.UsesKeyConnector)
{
throw new BadRequestException("You cannot takeover an account that is using Key Connector.");
@ -336,7 +342,9 @@ public class EmergencyAccessService : IEmergencyAccessService
grantor.LastPasswordChangeDate = grantor.RevisionDate;
grantor.Key = key;
// Disable TwoFactor providers since they will otherwise block logins
grantor.SetTwoFactorProviders(new Dictionary<TwoFactorProviderType, TwoFactorProvider>());
grantor.SetTwoFactorProviders([]);
// Disable New Device Verification since it will otherwise block logins
grantor.VerifyDevices = false;
await _userRepository.ReplaceAsync(grantor);
// Remove grantor from all organizations unless Owner
@ -421,12 +429,22 @@ public class EmergencyAccessService : IEmergencyAccessService
await _mailService.SendEmergencyAccessInviteEmailAsync(emergencyAccess, invitingUsersName, token);
}
private string NameOrEmail(User user)
private static string NameOrEmail(User user)
{
return string.IsNullOrWhiteSpace(user.Name) ? user.Email : user.Name;
}
private bool IsValidRequest(EmergencyAccess availableAccess, User requestingUser, EmergencyAccessType requestedAccessType)
/*
* Checks if EmergencyAccess Object is null
* Checks the requesting user is the same as the granteeUser (So we are checking for proper grantee action)
* Status _must_ equal RecoveryApproved (This means the grantor has invited, the grantee has accepted, and the grantor has approved so the shared key exists but hasn't been exercised yet)
* request type must equal the type of access requested (View or Takeover)
*/
private static bool IsValidRequest(
EmergencyAccess availableAccess,
User requestingUser,
EmergencyAccessType requestedAccessType)
{
return availableAccess != null &&
availableAccess.GranteeId == requestingUser.Id &&

View File

@ -1,132 +0,0 @@
using System.Net.Http.Json;
using System.Text.Json.Serialization;
using Bit.Core.Auth.Models.Business;
using Bit.Core.Auth.Models.Business.Tokenables;
using Bit.Core.Context;
using Bit.Core.Entities;
using Bit.Core.Settings;
using Bit.Core.Tokens;
using Microsoft.Extensions.Logging;
namespace Bit.Core.Auth.Services;
public class HCaptchaValidationService : ICaptchaValidationService
{
private readonly ILogger<HCaptchaValidationService> _logger;
private readonly IHttpClientFactory _httpClientFactory;
private readonly GlobalSettings _globalSettings;
private readonly IDataProtectorTokenFactory<HCaptchaTokenable> _tokenizer;
public HCaptchaValidationService(
ILogger<HCaptchaValidationService> logger,
IHttpClientFactory httpClientFactory,
IDataProtectorTokenFactory<HCaptchaTokenable> tokenizer,
GlobalSettings globalSettings)
{
_logger = logger;
_httpClientFactory = httpClientFactory;
_globalSettings = globalSettings;
_tokenizer = tokenizer;
}
public string SiteKeyResponseKeyName => "HCaptcha_SiteKey";
public string SiteKey => _globalSettings.Captcha.HCaptchaSiteKey;
public string GenerateCaptchaBypassToken(User user) => _tokenizer.Protect(new HCaptchaTokenable(user));
public async Task<CaptchaResponse> ValidateCaptchaResponseAsync(string captchaResponse, string clientIpAddress,
User user = null)
{
var response = new CaptchaResponse { Success = false };
if (string.IsNullOrWhiteSpace(captchaResponse))
{
return response;
}
if (user != null && ValidateCaptchaBypassToken(captchaResponse, user))
{
response.Success = true;
return response;
}
var httpClient = _httpClientFactory.CreateClient("HCaptchaValidationService");
var requestMessage = new HttpRequestMessage
{
Method = HttpMethod.Post,
RequestUri = new Uri("https://hcaptcha.com/siteverify"),
Content = new FormUrlEncodedContent(new Dictionary<string, string>
{
{ "response", captchaResponse.TrimStart("hcaptcha|".ToCharArray()) },
{ "secret", _globalSettings.Captcha.HCaptchaSecretKey },
{ "sitekey", SiteKey },
{ "remoteip", clientIpAddress }
})
};
HttpResponseMessage responseMessage;
try
{
responseMessage = await httpClient.SendAsync(requestMessage);
}
catch (Exception e)
{
_logger.LogError(11389, e, "Unable to verify with HCaptcha.");
return response;
}
if (!responseMessage.IsSuccessStatusCode)
{
return response;
}
using var hcaptchaResponse = await responseMessage.Content.ReadFromJsonAsync<HCaptchaResponse>();
response.Success = hcaptchaResponse.Success;
var score = hcaptchaResponse.Score.GetValueOrDefault();
response.MaybeBot = score >= _globalSettings.Captcha.MaybeBotScoreThreshold;
response.IsBot = score >= _globalSettings.Captcha.IsBotScoreThreshold;
response.Score = score;
return response;
}
public bool RequireCaptchaValidation(ICurrentContext currentContext, User user = null)
{
if (user == null)
{
return currentContext.IsBot || _globalSettings.Captcha.ForceCaptchaRequired;
}
var failedLoginCeiling = _globalSettings.Captcha.MaximumFailedLoginAttempts;
var failedLoginCount = user?.FailedLoginCount ?? 0;
var requireOnCloud = !_globalSettings.SelfHosted && !user.EmailVerified &&
user.CreationDate < DateTime.UtcNow.AddHours(-24);
return currentContext.IsBot ||
_globalSettings.Captcha.ForceCaptchaRequired ||
requireOnCloud ||
failedLoginCeiling > 0 && failedLoginCount >= failedLoginCeiling;
}
private static bool TokenIsValidApiKey(string bypassToken, User user) =>
!string.IsNullOrWhiteSpace(bypassToken) && user != null && user.ApiKey == bypassToken;
private bool TokenIsValidCaptchaBypassToken(string encryptedToken, User user)
{
return _tokenizer.TryUnprotect(encryptedToken, out var data) &&
data.Valid && data.TokenIsValid(user);
}
private bool ValidateCaptchaBypassToken(string bypassToken, User user) =>
TokenIsValidApiKey(bypassToken, user) || TokenIsValidCaptchaBypassToken(bypassToken, user);
public class HCaptchaResponse : IDisposable
{
[JsonPropertyName("success")]
public bool Success { get; set; }
[JsonPropertyName("score")]
public double? Score { get; set; }
[JsonPropertyName("score_reason")]
public List<string> ScoreReason { get; set; }
public void Dispose() { }
}
}

View File

@ -1,18 +0,0 @@
using Bit.Core.Auth.Models.Business;
using Bit.Core.Context;
using Bit.Core.Entities;
namespace Bit.Core.Auth.Services;
public class NoopCaptchaValidationService : ICaptchaValidationService
{
public string SiteKeyResponseKeyName => null;
public string SiteKey => null;
public bool RequireCaptchaValidation(ICurrentContext currentContext, User user = null) => false;
public string GenerateCaptchaBypassToken(User user) => "";
public Task<CaptchaResponse> ValidateCaptchaResponseAsync(string captchaResponse, string clientIpAddress,
User user = null)
{
return Task.FromResult(new CaptchaResponse { Success = true });
}
}

View File

@ -108,6 +108,7 @@ public class RegisterUserCommand : IRegisterUserCommand
var result = await _userService.CreateUserAsync(user, masterPasswordHash);
if (result == IdentityResult.Success)
{
var sentWelcomeEmail = false;
if (!string.IsNullOrEmpty(user.ReferenceData))
{
var referenceData = JsonConvert.DeserializeObject<Dictionary<string, object>>(user.ReferenceData);
@ -115,6 +116,7 @@ public class RegisterUserCommand : IRegisterUserCommand
{
var initiationPath = value.ToString();
await SendAppropriateWelcomeEmailAsync(user, initiationPath);
sentWelcomeEmail = true;
if (!string.IsNullOrEmpty(initiationPath))
{
await _referenceEventService.RaiseEventAsync(
@ -128,6 +130,11 @@ public class RegisterUserCommand : IRegisterUserCommand
}
}
if (!sentWelcomeEmail)
{
await _mailService.SendWelcomeEmailAsync(user);
}
await _referenceEventService.RaiseEventAsync(new ReferenceEvent(ReferenceEventType.Signup, user, _currentContext));
}

View File

@ -2,6 +2,7 @@
namespace Bit.Core.Auth.UserFeatures.TwoFactorAuth.Interfaces;
public interface ITwoFactorIsEnabledQuery
{
/// <summary>
@ -16,7 +17,8 @@ public interface ITwoFactorIsEnabledQuery
/// <typeparam name="T">The type of user in the list. Must implement <see cref="ITwoFactorProvidersUser"/>.</typeparam>
Task<IEnumerable<(T user, bool twoFactorIsEnabled)>> TwoFactorIsEnabledAsync<T>(IEnumerable<T> users) where T : ITwoFactorProvidersUser;
/// <summary>
/// Returns whether two factor is enabled for the user.
/// Returns whether two factor is enabled for the user. A user is able to have a TwoFactorProvider that is enabled but requires Premium.
/// If the user does not have premium then the TwoFactorProvider is considered _not_ enabled.
/// </summary>
/// <param name="user">The user to check.</param>
Task<bool> TwoFactorIsEnabledAsync(ITwoFactorProvidersUser user);

View File

@ -1,17 +1,13 @@
using Bit.Core.Auth.Models;
using Bit.Core.Auth.Enums;
using Bit.Core.Auth.Models;
using Bit.Core.Auth.UserFeatures.TwoFactorAuth.Interfaces;
using Bit.Core.Repositories;
namespace Bit.Core.Auth.UserFeatures.TwoFactorAuth;
public class TwoFactorIsEnabledQuery : ITwoFactorIsEnabledQuery
public class TwoFactorIsEnabledQuery(IUserRepository userRepository) : ITwoFactorIsEnabledQuery
{
private readonly IUserRepository _userRepository;
public TwoFactorIsEnabledQuery(IUserRepository userRepository)
{
_userRepository = userRepository;
}
private readonly IUserRepository _userRepository = userRepository;
public async Task<IEnumerable<(Guid userId, bool twoFactorIsEnabled)>> TwoFactorIsEnabledAsync(IEnumerable<Guid> userIds)
{
@ -21,26 +17,15 @@ public class TwoFactorIsEnabledQuery : ITwoFactorIsEnabledQuery
return result;
}
var userDetails = await _userRepository.GetManyWithCalculatedPremiumAsync(userIds.ToList());
var userDetails = await _userRepository.GetManyWithCalculatedPremiumAsync([.. userIds]);
foreach (var userDetail in userDetails)
{
var hasTwoFactor = false;
var providers = userDetail.GetTwoFactorProviders();
if (providers != null)
{
// Get all enabled providers
var enabledProviderKeys = from provider in providers
where provider.Value?.Enabled ?? false
select provider.Key;
// Find the first provider that is enabled and passes the premium check
hasTwoFactor = enabledProviderKeys
.Select(type => userDetail.HasPremiumAccess || !TwoFactorProvider.RequiresPremium(type))
.FirstOrDefault();
}
result.Add((userDetail.Id, hasTwoFactor));
result.Add(
(userDetail.Id,
await TwoFactorEnabledAsync(userDetail.GetTwoFactorProviders(),
() => Task.FromResult(userDetail.HasPremiumAccess))
)
);
}
return result;
@ -83,41 +68,56 @@ public class TwoFactorIsEnabledQuery : ITwoFactorIsEnabledQuery
return false;
}
var providers = user.GetTwoFactorProviders();
if (providers == null || !providers.Any())
return await TwoFactorEnabledAsync(
user.GetTwoFactorProviders(),
async () =>
{
var calcUser = await _userRepository.GetCalculatedPremiumAsync(userId.Value);
return calcUser?.HasPremiumAccess ?? false;
});
}
/// <summary>
/// Checks to see what kind of two-factor is enabled.
/// We use a delegate to check if the user has premium access, since there are multiple ways to
/// determine if a user has premium access.
/// </summary>
/// <param name="providers">dictionary of two factor providers</param>
/// <param name="hasPremiumAccessDelegate">function to check if the user has premium access</param>
/// <returns> true if the user has two factor enabled; false otherwise;</returns>
private async static Task<bool> TwoFactorEnabledAsync(
Dictionary<TwoFactorProviderType, TwoFactorProvider> providers,
Func<Task<bool>> hasPremiumAccessDelegate)
{
// If there are no providers, then two factor is not enabled
if (providers == null || providers.Count == 0)
{
return false;
}
// Get all enabled providers
var enabledProviderKeys = providers
.Where(provider => provider.Value?.Enabled ?? false)
.Select(provider => provider.Key);
// TODO: PM-21210: In practice we don't save disabled providers to the database, worth looking into.
var enabledProviderKeys = from provider in providers
where provider.Value?.Enabled ?? false
select provider.Key;
// If no providers are enabled then two factor is not enabled
if (!enabledProviderKeys.Any())
{
return false;
}
// Determine if any enabled provider passes the premium check
var hasTwoFactor = enabledProviderKeys
.Select(type => user.GetPremium() || !TwoFactorProvider.RequiresPremium(type))
.FirstOrDefault();
// If no enabled provider passes the check, check the repository for organization premium access
if (!hasTwoFactor)
// If there are only premium two factor options then standard two factor is not enabled
var onlyHasPremiumTwoFactor = enabledProviderKeys.All(TwoFactorProvider.RequiresPremium);
if (onlyHasPremiumTwoFactor)
{
var userDetails = await _userRepository.GetManyWithCalculatedPremiumAsync(new List<Guid> { userId.Value });
var userDetail = userDetails.FirstOrDefault();
if (userDetail != null)
{
hasTwoFactor = enabledProviderKeys
.Select(type => userDetail.HasPremiumAccess || !TwoFactorProvider.RequiresPremium(type))
.FirstOrDefault();
}
// There are no Standard two factor options, check if the user has premium access
// If the user has premium access, then two factor is enabled
var premiumAccess = await hasPremiumAccessDelegate();
return premiumAccess;
}
return hasTwoFactor;
// The user has at least one non-premium two factor option
return true;
}
}

View File

@ -1,36 +0,0 @@
using Bit.Core.Auth.Models.Api;
using Bit.Core.Auth.Services;
using Bit.Core.Context;
using Bit.Core.Exceptions;
using Microsoft.AspNetCore.Mvc.Filters;
using Microsoft.Extensions.DependencyInjection;
namespace Bit.Core.Auth.Utilities;
public class CaptchaProtectedAttribute : ActionFilterAttribute
{
public string ModelParameterName { get; set; } = "model";
public override void OnActionExecuting(ActionExecutingContext context)
{
var currentContext = context.HttpContext.RequestServices.GetRequiredService<ICurrentContext>();
var captchaValidationService = context.HttpContext.RequestServices.GetRequiredService<ICaptchaValidationService>();
if (captchaValidationService.RequireCaptchaValidation(currentContext, null))
{
var captchaResponse = (context.ActionArguments[ModelParameterName] as ICaptchaProtectedModel)?.CaptchaResponse;
if (string.IsNullOrWhiteSpace(captchaResponse))
{
throw new BadRequestException(captchaValidationService.SiteKeyResponseKeyName, captchaValidationService.SiteKey);
}
var captchaValidationResponse = captchaValidationService.ValidateCaptchaResponseAsync(captchaResponse,
currentContext.IpAddress, null).GetAwaiter().GetResult();
if (!captchaValidationResponse.Success || captchaValidationResponse.IsBot)
{
throw new BadRequestException("Captcha is invalid. Please refresh and try again");
}
}
}
}

View File

@ -2,10 +2,6 @@
public static class StripeConstants
{
public static class Prices
{
public const string StoragePlanPersonal = "personal-storage-gb-annually";
}
public static class AutomaticTaxStatus
{
public const string Failed = "failed";
@ -69,6 +65,11 @@ public static class StripeConstants
public const string USBankAccount = "us_bank_account";
}
public static class Prices
{
public const string StoragePlanPersonal = "personal-storage-gb-annually";
}
public static class ProrationBehavior
{
public const string AlwaysInvoice = "always_invoice";
@ -88,6 +89,13 @@ public static class StripeConstants
public const string Paused = "paused";
}
public static class TaxExempt
{
public const string Exempt = "exempt";
public const string None = "none";
public const string Reverse = "reverse";
}
public static class ValidateTaxLocationTiming
{
public const string Deferred = "deferred";

View File

@ -15,12 +15,7 @@ public static class CustomerExtensions
}
};
/// <summary>
/// Determines if a Stripe customer supports automatic tax
/// </summary>
/// <param name="customer"></param>
/// <returns></returns>
public static bool HasTaxLocationVerified(this Customer customer) =>
public static bool HasRecognizedTaxLocation(this Customer customer) =>
customer?.Tax?.AutomaticTax != StripeConstants.AutomaticTaxStatus.UnrecognizedLocation;
public static decimal GetBillingBalance(this Customer customer)

View File

@ -4,7 +4,9 @@ using Bit.Core.Billing.Licenses.Extensions;
using Bit.Core.Billing.Pricing;
using Bit.Core.Billing.Services;
using Bit.Core.Billing.Services.Implementations;
using Bit.Core.Billing.Services.Implementations.AutomaticTax;
using Bit.Core.Billing.Tax.Commands;
using Bit.Core.Billing.Tax.Services;
using Bit.Core.Billing.Tax.Services.Implementations;
namespace Bit.Core.Billing.Extensions;
@ -24,5 +26,6 @@ public static class ServiceCollectionExtensions
services.AddTransient<IAutomaticTaxFactory, AutomaticTaxFactory>();
services.AddLicenseServices();
services.AddPricingClient();
services.AddTransient<IPreviewTaxAmountCommand, PreviewTaxAmountCommand>();
}
}

View File

@ -22,7 +22,7 @@ public static class SubscriptionUpdateOptionsExtensions
}
// We might only need to check the automatic tax status.
if (!customer.HasTaxLocationVerified() && string.IsNullOrWhiteSpace(customer.Address?.Country))
if (!customer.HasRecognizedTaxLocation() && string.IsNullOrWhiteSpace(customer.Address?.Country))
{
return false;
}

View File

@ -22,7 +22,7 @@ public static class UpcomingInvoiceOptionsExtensions
}
// We might only need to check the automatic tax status.
if (!customer.HasTaxLocationVerified() && string.IsNullOrWhiteSpace(customer.Address?.Country))
if (!customer.HasRecognizedTaxLocation() && string.IsNullOrWhiteSpace(customer.Address?.Country))
{
return false;
}

View File

@ -34,7 +34,6 @@ public static class OrganizationLicenseConstants
public const string UseSecretsManager = nameof(UseSecretsManager);
public const string SmSeats = nameof(SmSeats);
public const string SmServiceAccounts = nameof(SmServiceAccounts);
public const string SmMaxProjects = nameof(SmMaxProjects);
public const string LimitCollectionCreationDeletion = nameof(LimitCollectionCreationDeletion);
public const string AllowAdminAccessToAllCollectionItems = nameof(AllowAdminAccessToAllCollectionItems);
public const string UseRiskInsights = nameof(UseRiskInsights);
@ -43,6 +42,7 @@ public static class OrganizationLicenseConstants
public const string ExpirationWithoutGracePeriod = nameof(ExpirationWithoutGracePeriod);
public const string Trial = nameof(Trial);
public const string UseAdminSponsoredFamilies = nameof(UseAdminSponsoredFamilies);
public const string UseOrganizationDomains = nameof(UseOrganizationDomains);
}
public static class UserLicenseConstants

View File

@ -7,5 +7,4 @@ public class LicenseContext
{
public Guid? InstallationId { get; init; }
public required SubscriptionInfo SubscriptionInfo { get; init; }
public int? SmMaxProjects { get; set; }
}

View File

@ -54,6 +54,7 @@ public class OrganizationLicenseClaimsFactory : ILicenseClaimsFactory<Organizati
new(nameof(OrganizationLicenseConstants.ExpirationWithoutGracePeriod), expirationWithoutGracePeriod.ToString(CultureInfo.InvariantCulture)),
new(nameof(OrganizationLicenseConstants.Trial), trial.ToString()),
new(nameof(OrganizationLicenseConstants.UseAdminSponsoredFamilies), entity.UseAdminSponsoredFamilies.ToString()),
new(nameof(OrganizationLicenseConstants.UseOrganizationDomains), entity.UseOrganizationDomains.ToString()),
};
if (entity.Name is not null)
@ -112,11 +113,6 @@ public class OrganizationLicenseClaimsFactory : ILicenseClaimsFactory<Organizati
}
claims.Add(new Claim(nameof(OrganizationLicenseConstants.UseAdminSponsoredFamilies), entity.UseAdminSponsoredFamilies.ToString()));
if (licenseContext.SmMaxProjects.HasValue)
{
claims.Add(new Claim(nameof(OrganizationLicenseConstants.SmMaxProjects), licenseContext.SmMaxProjects.ToString()));
}
return Task.FromResult(claims);
}

View File

@ -309,6 +309,7 @@ public class OrganizationMigrator(
organization.MaxStorageGb = plan.PasswordManager.BaseStorageGb;
organization.UsePolicies = plan.HasPolicies;
organization.UseSso = plan.HasSso;
organization.UseOrganizationDomains = plan.HasOrganizationDomains;
organization.UseGroups = plan.HasGroups;
organization.UseEvents = plan.HasEvents;
organization.UseDirectory = plan.HasDirectory;

View File

@ -7,4 +7,5 @@ public class TrialSendVerificationEmailRequestModel : RegisterSendVerificationEm
{
public ProductTierType ProductTier { get; set; }
public IEnumerable<ProductType> Products { get; set; }
public int? TrialLength { get; set; }
}

View File

@ -0,0 +1,36 @@
using OneOf;
namespace Bit.Core.Billing.Models;
public record BadRequest(string TranslationKey)
{
public static BadRequest TaxIdNumberInvalid => new(BillingErrorTranslationKeys.TaxIdInvalid);
public static BadRequest TaxLocationInvalid => new(BillingErrorTranslationKeys.CustomerTaxLocationInvalid);
public static BadRequest UnknownTaxIdType => new(BillingErrorTranslationKeys.UnknownTaxIdType);
}
public record Unhandled(string TranslationKey = BillingErrorTranslationKeys.UnhandledError);
public class BillingCommandResult<T> : OneOfBase<T, BadRequest, Unhandled>
{
private BillingCommandResult(OneOf<T, BadRequest, Unhandled> input) : base(input) { }
public static implicit operator BillingCommandResult<T>(T output) => new(output);
public static implicit operator BillingCommandResult<T>(BadRequest badRequest) => new(badRequest);
public static implicit operator BillingCommandResult<T>(Unhandled unhandled) => new(unhandled);
}
public static class BillingErrorTranslationKeys
{
// "The tax ID number you provided was invalid. Please try again or contact support."
public const string TaxIdInvalid = "taxIdInvalid";
// "Your location wasn't recognized. Please ensure your country and postal code are valid and try again."
public const string CustomerTaxLocationInvalid = "customerTaxLocationInvalid";
// "Something went wrong with your request. Please contact support."
public const string UnhandledError = "unhandledBillingError";
// "We couldn't find a corresponding tax ID type for the tax ID you provided. Please try again or contact support."
public const string UnknownTaxIdType = "unknownTaxIdType";
}

View File

@ -1,5 +1,6 @@
using Bit.Core.Auth.Models.Mail;
using Bit.Core.Billing.Enums;
using Bit.Core.Enums;
namespace Bit.Core.Billing.Models.Mail;
@ -16,13 +17,26 @@ public class TrialInitiationVerifyEmail : RegisterVerifyEmail
$"&email={Email}" +
$"&fromEmail=true" +
$"&productTier={(int)ProductTier}" +
$"&product={string.Join(",", Product.Select(p => (int)p))}";
$"&product={string.Join(",", Product.Select(p => (int)p))}" +
$"&trialLength={TrialLength}";
}
public string VerifyYourEmailHTMLCopy =>
TrialLength == 7
? "Verify your email address below to finish signing up for your free trial."
: $"Verify your email address below to finish signing up for your {ProductTier.GetDisplayName()} plan.";
public string VerifyYourEmailTextCopy =>
TrialLength == 7
? "Verify your email address using the link below and start your free trial of Bitwarden."
: $"Verify your email address using the link below and start your {ProductTier.GetDisplayName()} Bitwarden plan.";
public ProductTierType ProductTier { get; set; }
public IEnumerable<ProductType> Product { get; set; }
public int TrialLength { get; set; }
/// <summary>
/// Currently we only support one product type at a time, despite Product being a collection.
/// If we receive both PasswordManager and SecretsManager, we'll send the user to the PM trial route

View File

@ -1,4 +1,6 @@
namespace Bit.Core.Billing.Models;
using Bit.Core.Billing.Tax.Models;
namespace Bit.Core.Billing.Models;
public record PaymentMethod(
long AccountCredit,

View File

@ -1,4 +1,6 @@
namespace Bit.Core.Billing.Models.Sales;
using Bit.Core.Billing.Tax.Models;
namespace Bit.Core.Billing.Models.Sales;
#nullable enable

View File

@ -1,5 +1,6 @@
using Bit.Core.AdminConsole.Entities;
using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Tax.Models;
using Bit.Core.Models.Business;
namespace Bit.Core.Billing.Models.Sales;
@ -26,12 +27,21 @@ public class OrganizationSale
public static OrganizationSale From(
Organization organization,
OrganizationSignup signup) => new()
OrganizationSignup signup)
{
var customerSetup = string.IsNullOrEmpty(organization.GatewayCustomerId) ? GetCustomerSetup(signup) : null;
var subscriptionSetup = GetSubscriptionSetup(signup);
subscriptionSetup.SkipTrial = signup.SkipTrial;
return new OrganizationSale
{
Organization = organization,
CustomerSetup = string.IsNullOrEmpty(organization.GatewayCustomerId) ? GetCustomerSetup(signup) : null,
SubscriptionSetup = GetSubscriptionSetup(signup)
CustomerSetup = customerSetup,
SubscriptionSetup = subscriptionSetup
};
}
public static OrganizationSale From(
Organization organization,

View File

@ -1,4 +1,5 @@
using Bit.Core.Entities;
using Bit.Core.Billing.Tax.Models;
using Bit.Core.Entities;
using Bit.Core.Enums;
using Bit.Core.Models.Business;

View File

@ -24,6 +24,7 @@ public abstract record Plan
public bool Has2fa { get; protected init; }
public bool HasApi { get; protected init; }
public bool HasSso { get; protected init; }
public bool HasOrganizationDomains { get; protected init; }
public bool HasKeyConnector { get; protected init; }
public bool HasScim { get; protected init; }
public bool HasResetPassword { get; protected init; }

View File

@ -26,6 +26,7 @@ public record Enterprise2019Plan : Plan
Has2fa = true;
HasApi = true;
HasSso = true;
HasOrganizationDomains = true;
HasKeyConnector = true;
HasScim = true;
HasResetPassword = true;

View File

@ -26,6 +26,7 @@ public record Enterprise2020Plan : Plan
Has2fa = true;
HasApi = true;
HasSso = true;
HasOrganizationDomains = true;
HasKeyConnector = true;
HasScim = true;
HasResetPassword = true;

View File

@ -26,6 +26,7 @@ public record EnterprisePlan : Plan
Has2fa = true;
HasApi = true;
HasSso = true;
HasOrganizationDomains = true;
HasKeyConnector = true;
HasScim = true;
HasResetPassword = true;

View File

@ -26,6 +26,7 @@ public record Enterprise2023Plan : Plan
Has2fa = true;
HasApi = true;
HasSso = true;
HasOrganizationDomains = true;
HasKeyConnector = true;
HasScim = true;
HasResetPassword = true;

View File

@ -26,6 +26,7 @@ public record PlanAdapter : Plan
Has2fa = HasFeature("2fa");
HasApi = HasFeature("api");
HasSso = HasFeature("sso");
HasOrganizationDomains = HasFeature("organizationDomains");
HasKeyConnector = HasFeature("keyConnector");
HasScim = HasFeature("scim");
HasResetPassword = HasFeature("resetPassword");

View File

@ -1,6 +1,7 @@
using Bit.Core.AdminConsole.Entities;
using Bit.Core.Billing.Models;
using Bit.Core.Billing.Models.Sales;
using Bit.Core.Billing.Tax.Models;
namespace Bit.Core.Billing.Services;

View File

@ -1,5 +1,6 @@
using Bit.Core.Billing.Models;
using Bit.Core.Billing.Models.Sales;
using Bit.Core.Billing.Tax.Models;
using Bit.Core.Entities;
namespace Bit.Core.Billing.Services;

View File

@ -4,6 +4,7 @@ using Bit.Core.Billing.Entities;
using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Models;
using Bit.Core.Billing.Services.Contracts;
using Bit.Core.Billing.Tax.Models;
using Bit.Core.Models.Business;
using Stripe;
@ -59,7 +60,7 @@ public interface IProviderBillingService
int seatAdjustment);
/// <summary>
/// Determines whether the provided <paramref name="seatAdjustment"/> will result in a purchase for the <paramref name="provider"/>'s <see cref="planType"/>.
/// Determines whether the provided <paramref name="seatAdjustment"/> will result in a purchase for the <paramref name="provider"/>'s <see cref="PlanType"/>.
/// Seat adjustments that result in purchases include:
/// <list type="bullet">
/// <item>The <paramref name="provider"/> going from below the seat minimum to above the seat minimum for the provided <paramref name="planType"/></item>

View File

@ -1,4 +1,5 @@
using Bit.Core.Billing.Models;
using Bit.Core.Billing.Tax.Models;
using Bit.Core.Entities;
using Bit.Core.Enums;
using Stripe;

View File

@ -1,11 +1,13 @@
using Bit.Core.AdminConsole.Entities;
using Bit.Core.Billing.Caches;
using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Extensions;
using Bit.Core.Billing.Models;
using Bit.Core.Billing.Models.Sales;
using Bit.Core.Billing.Pricing;
using Bit.Core.Billing.Services.Contracts;
using Bit.Core.Billing.Tax.Models;
using Bit.Core.Billing.Tax.Services;
using Bit.Core.Enums;
using Bit.Core.Exceptions;
using Bit.Core.Repositories;
@ -33,16 +35,15 @@ public class OrganizationBillingService(
ISetupIntentCache setupIntentCache,
IStripeAdapter stripeAdapter,
ISubscriberService subscriberService,
ITaxService taxService,
IAutomaticTaxFactory automaticTaxFactory) : IOrganizationBillingService
ITaxService taxService) : IOrganizationBillingService
{
public async Task Finalize(OrganizationSale sale)
{
var (organization, customerSetup, subscriptionSetup) = sale;
var customer = string.IsNullOrEmpty(organization.GatewayCustomerId) && customerSetup != null
? await CreateCustomerAsync(organization, customerSetup)
: await subscriberService.GetCustomerOrThrow(organization, new CustomerGetOptions { Expand = ["tax", "tax_ids"] });
? await CreateCustomerAsync(organization, customerSetup, subscriptionSetup.PlanType)
: await GetCustomerWhileEnsuringCorrectTaxExemptionAsync(organization, subscriptionSetup);
var subscription = await CreateSubscriptionAsync(organization.Id, customer, subscriptionSetup);
@ -119,7 +120,8 @@ public class OrganizationBillingService(
subscription.CurrentPeriodEnd);
}
public async Task UpdatePaymentMethod(
public async Task
UpdatePaymentMethod(
Organization organization,
TokenizedPaymentSource tokenizedPaymentSource,
TaxInformation taxInformation)
@ -149,8 +151,11 @@ public class OrganizationBillingService(
private async Task<Customer> CreateCustomerAsync(
Organization organization,
CustomerSetup customerSetup)
CustomerSetup customerSetup,
PlanType? updatedPlanType = null)
{
var planType = updatedPlanType ?? organization.PlanType;
var displayName = organization.DisplayName();
var customerCreateOptions = new CustomerCreateOptions
@ -210,13 +215,24 @@ public class OrganizationBillingService(
City = customerSetup.TaxInformation.City,
PostalCode = customerSetup.TaxInformation.PostalCode,
State = customerSetup.TaxInformation.State,
Country = customerSetup.TaxInformation.Country,
Country = customerSetup.TaxInformation.Country
};
customerCreateOptions.Tax = new CustomerTaxOptions
{
ValidateLocation = StripeConstants.ValidateTaxLocationTiming.Immediately
};
var setNonUSBusinessUseToReverseCharge =
featureService.IsEnabled(FeatureFlagKeys.PM21092_SetNonUSBusinessUseToReverseCharge);
if (setNonUSBusinessUseToReverseCharge &&
planType.GetProductTier() is not ProductTierType.Free and not ProductTierType.Families &&
customerSetup.TaxInformation.Country != "US")
{
customerCreateOptions.TaxExempt = StripeConstants.TaxExempt.Reverse;
}
if (!string.IsNullOrEmpty(customerSetup.TaxInformation.TaxId))
{
var taxIdType = taxService.GetStripeTaxCode(customerSetup.TaxInformation.Country,
@ -397,21 +413,68 @@ public class OrganizationBillingService(
TrialPeriodDays = subscriptionSetup.SkipTrial ? 0 : plan.TrialPeriodDays
};
if (featureService.IsEnabled(FeatureFlagKeys.PM19147_AutomaticTaxImprovements))
var setNonUSBusinessUseToReverseCharge =
featureService.IsEnabled(FeatureFlagKeys.PM21092_SetNonUSBusinessUseToReverseCharge);
if (setNonUSBusinessUseToReverseCharge)
{
var automaticTaxParameters = new AutomaticTaxFactoryParameters(subscriptionSetup.PlanType);
var automaticTaxStrategy = await automaticTaxFactory.CreateAsync(automaticTaxParameters);
automaticTaxStrategy.SetCreateOptions(subscriptionCreateOptions, customer);
subscriptionCreateOptions.AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true };
}
else
else if (customer.HasRecognizedTaxLocation())
{
subscriptionCreateOptions.AutomaticTax ??= new SubscriptionAutomaticTaxOptions();
subscriptionCreateOptions.AutomaticTax.Enabled = customer.HasBillingLocation();
subscriptionCreateOptions.AutomaticTax = new SubscriptionAutomaticTaxOptions
{
Enabled =
subscriptionSetup.PlanType.GetProductTier() == ProductTierType.Families ||
customer.Address.Country == "US" ||
customer.TaxIds.Any()
};
}
return await stripeAdapter.SubscriptionCreateAsync(subscriptionCreateOptions);
}
private async Task<Customer> GetCustomerWhileEnsuringCorrectTaxExemptionAsync(
Organization organization,
SubscriptionSetup subscriptionSetup)
{
var customer = await subscriberService.GetCustomerOrThrow(organization,
new CustomerGetOptions { Expand = ["tax", "tax_ids"] });
var setNonUSBusinessUseToReverseCharge = featureService.IsEnabled(FeatureFlagKeys.PM21092_SetNonUSBusinessUseToReverseCharge);
if (!setNonUSBusinessUseToReverseCharge || subscriptionSetup.PlanType.GetProductTier() is
not (ProductTierType.Teams or
ProductTierType.TeamsStarter or
ProductTierType.Enterprise))
{
return customer;
}
List<string> expansions = ["tax", "tax_ids"];
customer = customer switch
{
{ Address.Country: not "US", TaxExempt: not StripeConstants.TaxExempt.Reverse } => await
stripeAdapter.CustomerUpdateAsync(customer.Id,
new CustomerUpdateOptions
{
Expand = expansions,
TaxExempt = StripeConstants.TaxExempt.Reverse
}),
{ Address.Country: "US", TaxExempt: StripeConstants.TaxExempt.Reverse } => await
stripeAdapter.CustomerUpdateAsync(customer.Id,
new CustomerUpdateOptions
{
Expand = expansions,
TaxExempt = StripeConstants.TaxExempt.None
}),
_ => customer
};
return customer;
}
private async Task<bool> IsEligibleForSelfHostAsync(
Organization organization)
{

View File

@ -5,14 +5,12 @@ using Bit.Core.Entities;
using Bit.Core.Models.BitStripe;
using Bit.Core.Repositories;
using Bit.Core.Services;
using Microsoft.Extensions.Logging;
namespace Bit.Core.Billing.Services.Implementations;
public class PaymentHistoryService(
IStripeAdapter stripeAdapter,
ITransactionRepository transactionRepository,
ILogger<PaymentHistoryService> logger) : IPaymentHistoryService
ITransactionRepository transactionRepository) : IPaymentHistoryService
{
public async Task<IEnumerable<BillingHistoryInfo.BillingInvoice>> GetInvoiceHistoryAsync(
ISubscriber subscriber,

View File

@ -2,7 +2,7 @@
using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Models;
using Bit.Core.Billing.Models.Sales;
using Bit.Core.Billing.Services.Implementations.AutomaticTax;
using Bit.Core.Billing.Tax.Models;
using Bit.Core.Entities;
using Bit.Core.Enums;
using Bit.Core.Exceptions;
@ -10,7 +10,6 @@ using Bit.Core.Repositories;
using Bit.Core.Services;
using Bit.Core.Settings;
using Braintree;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Stripe;
using Customer = Stripe.Customer;
@ -22,20 +21,18 @@ using static Utilities;
public class PremiumUserBillingService(
IBraintreeGateway braintreeGateway,
IFeatureService featureService,
IGlobalSettings globalSettings,
ILogger<PremiumUserBillingService> logger,
ISetupIntentCache setupIntentCache,
IStripeAdapter stripeAdapter,
ISubscriberService subscriberService,
IUserRepository userRepository,
[FromKeyedServices(AutomaticTaxFactory.PersonalUse)] IAutomaticTaxStrategy automaticTaxStrategy) : IPremiumUserBillingService
IUserRepository userRepository) : IPremiumUserBillingService
{
public async Task Credit(User user, decimal amount)
{
var customer = await subscriberService.GetCustomer(user);
// Negative credit represents a balance and all Stripe denomination is in cents.
// Negative credit represents a balance, and all Stripe denomination is in cents.
var credit = (long)(amount * -100);
if (customer == null)
@ -182,7 +179,7 @@ public class PremiumUserBillingService(
City = customerSetup.TaxInformation.City,
PostalCode = customerSetup.TaxInformation.PostalCode,
State = customerSetup.TaxInformation.State,
Country = customerSetup.TaxInformation.Country,
Country = customerSetup.TaxInformation.Country
},
Description = user.Name,
Email = user.Email,
@ -322,6 +319,10 @@ public class PremiumUserBillingService(
var subscriptionCreateOptions = new SubscriptionCreateOptions
{
AutomaticTax = new SubscriptionAutomaticTaxOptions
{
Enabled = true
},
CollectionMethod = StripeConstants.CollectionMethod.ChargeAutomatically,
Customer = customer.Id,
Items = subscriptionItemOptionsList,
@ -335,18 +336,6 @@ public class PremiumUserBillingService(
OffSession = true
};
if (featureService.IsEnabled(FeatureFlagKeys.PM19147_AutomaticTaxImprovements))
{
automaticTaxStrategy.SetCreateOptions(subscriptionCreateOptions, customer);
}
else
{
subscriptionCreateOptions.AutomaticTax = new SubscriptionAutomaticTaxOptions
{
Enabled = customer.Tax?.AutomaticTax == StripeConstants.AutomaticTaxStatus.Supported,
};
}
var subscription = await stripeAdapter.SubscriptionCreateAsync(subscriptionCreateOptions);
if (usingPayPal)
@ -378,7 +367,7 @@ public class PremiumUserBillingService(
City = taxInformation.City,
PostalCode = taxInformation.PostalCode,
State = taxInformation.State,
Country = taxInformation.Country,
Country = taxInformation.Country
},
Expand = ["tax"],
Tax = new CustomerTaxOptions

View File

@ -1,7 +1,12 @@
using Bit.Core.Billing.Caches;
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.Billing.Caches;
using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Extensions;
using Bit.Core.Billing.Models;
using Bit.Core.Billing.Services.Contracts;
using Bit.Core.Billing.Tax.Models;
using Bit.Core.Billing.Tax.Services;
using Bit.Core.Entities;
using Bit.Core.Enums;
using Bit.Core.Exceptions;
@ -26,8 +31,7 @@ public class SubscriberService(
ILogger<SubscriberService> logger,
ISetupIntentCache setupIntentCache,
IStripeAdapter stripeAdapter,
ITaxService taxService,
IAutomaticTaxFactory automaticTaxFactory) : ISubscriberService
ITaxService taxService) : ISubscriberService
{
public async Task CancelSubscription(
ISubscriber subscriber,
@ -126,7 +130,7 @@ public class SubscriberService(
[subscriber.BraintreeCloudRegionField()] = globalSettings.BaseServiceUri.CloudRegion
},
Email = subscriber.BillingEmailAddress(),
PaymentMethodNonce = paymentMethodNonce,
PaymentMethodNonce = paymentMethodNonce
});
if (customerResult.IsSuccess())
@ -480,7 +484,7 @@ public class SubscriberService(
var matchingSetupIntent = setupIntentsForUpdatedPaymentMethod.First();
// Find the customer's existing setup intents that should be cancelled.
// Find the customer's existing setup intents that should be canceled.
var existingSetupIntentsForCustomer = (await getExistingSetupIntentsForCustomer)
.Where(si =>
si.Status is "requires_payment_method" or "requires_confirmation" or "requires_action");
@ -517,7 +521,7 @@ public class SubscriberService(
await stripeAdapter.PaymentMethodAttachAsync(token,
new PaymentMethodAttachOptions { Customer = subscriber.GatewayCustomerId });
// Find the customer's existing setup intents that should be cancelled.
// Find the customer's existing setup intents that should be canceled.
var existingSetupIntentsForCustomer = (await getExistingSetupIntentsForCustomer)
.Where(si =>
si.Status is "requires_payment_method" or "requires_confirmation" or "requires_action");
@ -635,7 +639,8 @@ public class SubscriberService(
logger.LogWarning("Could not infer tax ID type in country '{Country}' with tax ID '{TaxID}'.",
taxInformation.Country,
taxInformation.TaxId);
throw new Exceptions.BadRequestException("billingTaxIdTypeInferenceError");
throw new BadRequestException("billingTaxIdTypeInferenceError");
}
}
@ -652,53 +657,84 @@ public class SubscriberService(
logger.LogWarning("Invalid tax ID '{TaxID}' for country '{Country}'.",
taxInformation.TaxId,
taxInformation.Country);
throw new Exceptions.BadRequestException("billingInvalidTaxIdError");
throw new BadRequestException("billingInvalidTaxIdError");
default:
logger.LogError(e,
"Error creating tax ID '{TaxId}' in country '{Country}' for customer '{CustomerID}'.",
taxInformation.TaxId,
taxInformation.Country,
customer.Id);
throw new Exceptions.BadRequestException("billingTaxIdCreationError");
throw new BadRequestException("billingTaxIdCreationError");
}
}
}
if (featureService.IsEnabled(FeatureFlagKeys.PM19147_AutomaticTaxImprovements))
var subscription =
customer.Subscriptions.First(subscription => subscription.Id == subscriber.GatewaySubscriptionId);
var isBusinessUseSubscriber = subscriber switch
{
if (!string.IsNullOrEmpty(subscriber.GatewaySubscriptionId))
Organization organization => organization.PlanType.GetProductTier() is not ProductTierType.Free and not ProductTierType.Families,
Provider => true,
_ => false
};
var setNonUSBusinessUseToReverseCharge =
featureService.IsEnabled(FeatureFlagKeys.PM21092_SetNonUSBusinessUseToReverseCharge);
if (setNonUSBusinessUseToReverseCharge && isBusinessUseSubscriber)
{
switch (customer)
{
var subscriptionGetOptions = new SubscriptionGetOptions
case
{
Expand = ["customer.tax", "customer.tax_ids"]
};
var subscription = await stripeAdapter.SubscriptionGetAsync(subscriber.GatewaySubscriptionId, subscriptionGetOptions);
var automaticTaxParameters = new AutomaticTaxFactoryParameters(subscriber, subscription.Items.Select(x => x.Price.Id));
var automaticTaxStrategy = await automaticTaxFactory.CreateAsync(automaticTaxParameters);
var automaticTaxOptions = automaticTaxStrategy.GetUpdateOptions(subscription);
if (automaticTaxOptions?.AutomaticTax?.Enabled != null)
Address.Country: not "US",
TaxExempt: not StripeConstants.TaxExempt.Reverse
}:
await stripeAdapter.CustomerUpdateAsync(customer.Id,
new CustomerUpdateOptions { TaxExempt = StripeConstants.TaxExempt.Reverse });
break;
case
{
await stripeAdapter.SubscriptionUpdateAsync(subscriber.GatewaySubscriptionId, automaticTaxOptions);
}
Address.Country: "US",
TaxExempt: StripeConstants.TaxExempt.Reverse
}:
await stripeAdapter.CustomerUpdateAsync(customer.Id,
new CustomerUpdateOptions { TaxExempt = StripeConstants.TaxExempt.None });
break;
}
}
else
{
if (SubscriberIsEligibleForAutomaticTax(subscriber, customer))
if (!subscription.AutomaticTax.Enabled)
{
await stripeAdapter.SubscriptionUpdateAsync(subscriber.GatewaySubscriptionId,
await stripeAdapter.SubscriptionUpdateAsync(subscription.Id,
new SubscriptionUpdateOptions
{
AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true }
});
}
}
else
{
var automaticTaxShouldBeEnabled = subscriber switch
{
User => true,
Organization organization => organization.PlanType.GetProductTier() == ProductTierType.Families ||
customer.Address.Country == "US" || (customer.TaxIds?.Any() ?? false),
Provider => customer.Address.Country == "US" || (customer.TaxIds?.Any() ?? false),
_ => false
};
return;
bool SubscriberIsEligibleForAutomaticTax(ISubscriber localSubscriber, Customer localCustomer)
=> !string.IsNullOrEmpty(localSubscriber.GatewaySubscriptionId) &&
(localCustomer.Subscriptions?.Any(sub => sub.Id == localSubscriber.GatewaySubscriptionId && !sub.AutomaticTax.Enabled) ?? false) &&
localCustomer.Tax?.AutomaticTax == StripeConstants.AutomaticTaxStatus.Supported;
if (automaticTaxShouldBeEnabled && !subscription.AutomaticTax.Enabled)
{
await stripeAdapter.SubscriptionUpdateAsync(subscription.Id,
new SubscriptionUpdateOptions
{
AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true }
});
}
}
}

View File

@ -0,0 +1,147 @@
#nullable enable
using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Extensions;
using Bit.Core.Billing.Models;
using Bit.Core.Billing.Pricing;
using Bit.Core.Billing.Tax.Services;
using Bit.Core.Services;
using Microsoft.Extensions.Logging;
using Stripe;
namespace Bit.Core.Billing.Tax.Commands;
public interface IPreviewTaxAmountCommand
{
Task<BillingCommandResult<decimal>> Run(OrganizationTrialParameters parameters);
}
public class PreviewTaxAmountCommand(
ILogger<PreviewTaxAmountCommand> logger,
IPricingClient pricingClient,
IStripeAdapter stripeAdapter,
ITaxService taxService) : IPreviewTaxAmountCommand
{
public async Task<BillingCommandResult<decimal>> Run(OrganizationTrialParameters parameters)
{
var (planType, productType, taxInformation) = parameters;
var plan = await pricingClient.GetPlanOrThrow(planType);
var options = new InvoiceCreatePreviewOptions
{
Currency = "usd",
CustomerDetails = new InvoiceCustomerDetailsOptions
{
Address = new AddressOptions
{
Country = taxInformation.Country,
PostalCode = taxInformation.PostalCode
}
},
SubscriptionDetails = new InvoiceSubscriptionDetailsOptions
{
Items = [
new InvoiceSubscriptionDetailsItemOptions
{
Price = plan.HasNonSeatBasedPasswordManagerPlan() ? plan.PasswordManager.StripePlanId : plan.PasswordManager.StripeSeatPlanId,
Quantity = 1
}
]
}
};
if (productType == ProductType.SecretsManager)
{
options.SubscriptionDetails.Items.Add(new InvoiceSubscriptionDetailsItemOptions
{
Price = plan.SecretsManager.StripeSeatPlanId,
Quantity = 1
});
options.Coupon = StripeConstants.CouponIDs.SecretsManagerStandalone;
}
if (!string.IsNullOrEmpty(taxInformation.TaxId))
{
var taxIdType = taxService.GetStripeTaxCode(
taxInformation.Country,
taxInformation.TaxId);
if (string.IsNullOrEmpty(taxIdType))
{
return BadRequest.UnknownTaxIdType;
}
options.CustomerDetails.TaxIds = [
new InvoiceCustomerDetailsTaxIdOptions
{
Type = taxIdType,
Value = taxInformation.TaxId
}
];
}
if (planType.GetProductTier() == ProductTierType.Families)
{
options.AutomaticTax = new InvoiceAutomaticTaxOptions { Enabled = true };
}
else
{
options.AutomaticTax = new InvoiceAutomaticTaxOptions
{
Enabled = options.CustomerDetails.Address.Country == "US" ||
options.CustomerDetails.TaxIds is [_, ..]
};
}
try
{
var invoice = await stripeAdapter.InvoiceCreatePreviewAsync(options);
return Convert.ToDecimal(invoice.Tax) / 100;
}
catch (StripeException stripeException) when (stripeException.StripeError.Code ==
StripeConstants.ErrorCodes.CustomerTaxLocationInvalid)
{
return BadRequest.TaxLocationInvalid;
}
catch (StripeException stripeException) when (stripeException.StripeError.Code ==
StripeConstants.ErrorCodes.TaxIdInvalid)
{
return BadRequest.TaxIdNumberInvalid;
}
catch (StripeException stripeException)
{
logger.LogError(stripeException, "Stripe responded with an error during {Operation}. Code: {Code}", nameof(PreviewTaxAmountCommand), stripeException.StripeError.Code);
return new Unhandled();
}
}
}
#region Command Parameters
public record OrganizationTrialParameters
{
public required PlanType PlanType { get; set; }
public required ProductType ProductType { get; set; }
public required TaxInformationDTO TaxInformation { get; set; }
public void Deconstruct(
out PlanType planType,
out ProductType productType,
out TaxInformationDTO taxInformation)
{
planType = PlanType;
productType = ProductType;
taxInformation = TaxInformation;
}
public record TaxInformationDTO
{
public required string Country { get; set; }
public required string PostalCode { get; set; }
public string? TaxId { get; set; }
}
}
#endregion

View File

@ -1,6 +1,6 @@
using System.Text.RegularExpressions;
namespace Bit.Core.Billing.Models;
namespace Bit.Core.Billing.Tax.Models;
public class TaxIdType
{

View File

@ -1,6 +1,6 @@
using Bit.Core.Models.Business;
namespace Bit.Core.Billing.Models;
namespace Bit.Core.Billing.Tax.Models;
public record TaxInformation(
string Country,

View File

@ -1,6 +1,6 @@
using System.ComponentModel.DataAnnotations;
namespace Bit.Core.Billing.Models.Api.Requests.Accounts;
namespace Bit.Core.Billing.Tax.Requests;
public class PreviewIndividualInvoiceRequestBody
{

View File

@ -2,7 +2,7 @@
using Bit.Core.Billing.Enums;
using Bit.Core.Enums;
namespace Bit.Core.Billing.Models.Api.Requests.Organizations;
namespace Bit.Core.Billing.Tax.Requests;
public class PreviewOrganizationInvoiceRequestBody
{

Some files were not shown because too many files have changed in this diff Show More