From 04cf513d7868085e0c954b5bd2220991e75cd19a Mon Sep 17 00:00:00 2001 From: Conner Turnbull <133619638+cturnbull-bitwarden@users.noreply.github.com> Date: Thu, 5 Dec 2024 09:31:14 -0500 Subject: [PATCH] [PM-11516] Initial license file refactor (#5002) * Added the ability to create a JWT on an organization license that contains all license properties as claims * Added the ability to create a JWT on a user license that contains all license properties as claims * Added ability to consume JWT licenses * Resolved generic type issues when getting claim value * Now validating the jwt signature, exp, and iat * Moved creation of ClaimsPrincipal outside of licenses given dependecy on cert * Ran dotnet format. Resolved identity error * Updated claim types to use string constants * Updated jwt expires to be one year * Fixed bug requiring email verification to be on the token * dotnet format * Patch build process --------- Co-authored-by: Matt Bishop --- .../Implementations/OrganizationService.cs | 3 +- .../Extensions/ServiceCollectionExtensions.cs | 2 + .../Licenses/Extensions/LicenseExtensions.cs | 151 ++++++++++ .../LicenseServiceCollectionExtensions.cs | 16 + src/Core/Billing/Licenses/LicenseConstants.cs | 58 ++++ .../Billing/Licenses/Models/LicenseContext.cs | 10 + .../Services/ILicenseClaimsFactory.cs | 9 + .../OrganizationLicenseClaimsFactory.cs | 75 +++++ .../UserLicenseClaimsFactory.cs | 37 +++ src/Core/Constants.cs | 1 + src/Core/Models/Business/ILicense.cs | 1 + .../Models/Business/OrganizationLicense.cs | 273 +++++++++++++----- src/Core/Models/Business/UserLicense.cs | 83 +++++- .../Cloud/CloudGetOrganizationLicenseQuery.cs | 6 +- .../UpdateOrganizationLicenseCommand.cs | 3 +- src/Core/Services/ILicensingService.cs | 10 +- .../Implementations/LicensingService.cs | 145 +++++++++- .../Services/Implementations/UserService.cs | 21 +- .../NoopLicensingService.cs | 18 +- .../Utilities/ServiceCollectionExtensions.cs | 2 +- .../Business/OrganizationLicenseTests.cs | 7 +- .../UpdateOrganizationLicenseCommandTests.cs | 15 +- test/Core.Test/Services/UserServiceTests.cs | 6 +- 23 files changed, 846 insertions(+), 106 deletions(-) create mode 100644 src/Core/Billing/Licenses/Extensions/LicenseExtensions.cs create mode 100644 src/Core/Billing/Licenses/Extensions/LicenseServiceCollectionExtensions.cs create mode 100644 src/Core/Billing/Licenses/LicenseConstants.cs create mode 100644 src/Core/Billing/Licenses/Models/LicenseContext.cs create mode 100644 src/Core/Billing/Licenses/Services/ILicenseClaimsFactory.cs create mode 100644 src/Core/Billing/Licenses/Services/Implementations/OrganizationLicenseClaimsFactory.cs create mode 100644 src/Core/Billing/Licenses/Services/Implementations/UserLicenseClaimsFactory.cs diff --git a/src/Core/AdminConsole/Services/Implementations/OrganizationService.cs b/src/Core/AdminConsole/Services/Implementations/OrganizationService.cs index e8756fb325..862b566c91 100644 --- a/src/Core/AdminConsole/Services/Implementations/OrganizationService.cs +++ b/src/Core/AdminConsole/Services/Implementations/OrganizationService.cs @@ -642,7 +642,8 @@ public class OrganizationService : IOrganizationService OrganizationLicense license, User owner, string ownerKey, string collectionName, string publicKey, string privateKey) { - var canUse = license.CanUse(_globalSettings, _licensingService, out var exception); + var claimsPrincipal = _licensingService.GetClaimsPrincipalFromLicense(license); + var canUse = license.CanUse(_globalSettings, _licensingService, claimsPrincipal, out var exception); if (!canUse) { throw new BadRequestException(exception); diff --git a/src/Core/Billing/Extensions/ServiceCollectionExtensions.cs b/src/Core/Billing/Extensions/ServiceCollectionExtensions.cs index abfceac736..78253f7399 100644 --- a/src/Core/Billing/Extensions/ServiceCollectionExtensions.cs +++ b/src/Core/Billing/Extensions/ServiceCollectionExtensions.cs @@ -1,5 +1,6 @@ using Bit.Core.Billing.Caches; using Bit.Core.Billing.Caches.Implementations; +using Bit.Core.Billing.Licenses.Extensions; using Bit.Core.Billing.Services; using Bit.Core.Billing.Services.Implementations; @@ -15,5 +16,6 @@ public static class ServiceCollectionExtensions services.AddTransient(); services.AddTransient(); services.AddTransient(); + services.AddLicenseServices(); } } diff --git a/src/Core/Billing/Licenses/Extensions/LicenseExtensions.cs b/src/Core/Billing/Licenses/Extensions/LicenseExtensions.cs new file mode 100644 index 0000000000..184d8dad23 --- /dev/null +++ b/src/Core/Billing/Licenses/Extensions/LicenseExtensions.cs @@ -0,0 +1,151 @@ +using System.Security.Claims; +using Bit.Core.AdminConsole.Entities; +using Bit.Core.Billing.Enums; +using Bit.Core.Models.Business; + +namespace Bit.Core.Billing.Licenses.Extensions; + +public static class LicenseExtensions +{ + public static DateTime CalculateFreshExpirationDate(this Organization org, SubscriptionInfo subscriptionInfo) + { + if (subscriptionInfo?.Subscription == null) + { + if (org.PlanType == PlanType.Custom && org.ExpirationDate.HasValue) + { + return org.ExpirationDate.Value; + } + + return DateTime.UtcNow.AddDays(7); + } + + var subscription = subscriptionInfo.Subscription; + + if (subscription.TrialEndDate > DateTime.UtcNow) + { + return subscription.TrialEndDate.Value; + } + + if (org.ExpirationDate.HasValue && org.ExpirationDate.Value < DateTime.UtcNow) + { + return org.ExpirationDate.Value; + } + + if (subscription.PeriodEndDate.HasValue && subscription.PeriodDuration > TimeSpan.FromDays(180)) + { + return subscription.PeriodEndDate + .Value + .AddDays(Bit.Core.Constants.OrganizationSelfHostSubscriptionGracePeriodDays); + } + + return org.ExpirationDate?.AddMonths(11) ?? DateTime.UtcNow.AddYears(1); + } + + public static DateTime CalculateFreshRefreshDate(this Organization org, SubscriptionInfo subscriptionInfo, DateTime expirationDate) + { + if (subscriptionInfo?.Subscription == null || + subscriptionInfo.Subscription.TrialEndDate > DateTime.UtcNow || + org.ExpirationDate < DateTime.UtcNow) + { + return expirationDate; + } + + return subscriptionInfo.Subscription.PeriodDuration > TimeSpan.FromDays(180) || + DateTime.UtcNow - expirationDate > TimeSpan.FromDays(30) + ? DateTime.UtcNow.AddDays(30) + : expirationDate; + } + + public static DateTime CalculateFreshExpirationDateWithoutGracePeriod(this Organization org, SubscriptionInfo subscriptionInfo, DateTime expirationDate) + { + if (subscriptionInfo?.Subscription is null) + { + return expirationDate; + } + + var subscription = subscriptionInfo.Subscription; + + if (subscription.TrialEndDate <= DateTime.UtcNow && + org.ExpirationDate >= DateTime.UtcNow && + subscription.PeriodEndDate.HasValue && + subscription.PeriodDuration > TimeSpan.FromDays(180)) + { + return subscription.PeriodEndDate.Value; + } + + return expirationDate; + } + + public static T GetValue(this ClaimsPrincipal principal, string claimType) + { + var claim = principal.FindFirst(claimType); + + if (claim is null) + { + return default; + } + + // Handle Guid + if (typeof(T) == typeof(Guid)) + { + return Guid.TryParse(claim.Value, out var guid) + ? (T)(object)guid + : default; + } + + // Handle DateTime + if (typeof(T) == typeof(DateTime)) + { + return DateTime.TryParse(claim.Value, out var dateTime) + ? (T)(object)dateTime + : default; + } + + // Handle TimeSpan + if (typeof(T) == typeof(TimeSpan)) + { + return TimeSpan.TryParse(claim.Value, out var timeSpan) + ? (T)(object)timeSpan + : default; + } + + // Check for Nullable Types + var underlyingType = Nullable.GetUnderlyingType(typeof(T)) ?? typeof(T); + + // Handle Enums + if (underlyingType.IsEnum) + { + if (Enum.TryParse(underlyingType, claim.Value, true, out var enumValue)) + { + return (T)enumValue; // Cast back to T + } + + return default; // Return default value for non-nullable enums or null for nullable enums + } + + // Handle other Nullable Types (e.g., int?, bool?) + if (underlyingType == typeof(int)) + { + return int.TryParse(claim.Value, out var intValue) + ? (T)(object)intValue + : default; + } + + if (underlyingType == typeof(bool)) + { + return bool.TryParse(claim.Value, out var boolValue) + ? (T)(object)boolValue + : default; + } + + if (underlyingType == typeof(double)) + { + return double.TryParse(claim.Value, out var doubleValue) + ? (T)(object)doubleValue + : default; + } + + // Fallback to Convert.ChangeType for other types including strings + return (T)Convert.ChangeType(claim.Value, underlyingType); + } +} diff --git a/src/Core/Billing/Licenses/Extensions/LicenseServiceCollectionExtensions.cs b/src/Core/Billing/Licenses/Extensions/LicenseServiceCollectionExtensions.cs new file mode 100644 index 0000000000..b08adbd004 --- /dev/null +++ b/src/Core/Billing/Licenses/Extensions/LicenseServiceCollectionExtensions.cs @@ -0,0 +1,16 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.Billing.Licenses.Services; +using Bit.Core.Billing.Licenses.Services.Implementations; +using Bit.Core.Entities; +using Microsoft.Extensions.DependencyInjection; + +namespace Bit.Core.Billing.Licenses.Extensions; + +public static class LicenseServiceCollectionExtensions +{ + public static void AddLicenseServices(this IServiceCollection services) + { + services.AddTransient, OrganizationLicenseClaimsFactory>(); + services.AddTransient, UserLicenseClaimsFactory>(); + } +} diff --git a/src/Core/Billing/Licenses/LicenseConstants.cs b/src/Core/Billing/Licenses/LicenseConstants.cs new file mode 100644 index 0000000000..564019affc --- /dev/null +++ b/src/Core/Billing/Licenses/LicenseConstants.cs @@ -0,0 +1,58 @@ +namespace Bit.Core.Billing.Licenses; + +public static class OrganizationLicenseConstants +{ + public const string LicenseType = nameof(LicenseType); + public const string LicenseKey = nameof(LicenseKey); + public const string InstallationId = nameof(InstallationId); + public const string Id = nameof(Id); + public const string Name = nameof(Name); + public const string BusinessName = nameof(BusinessName); + public const string BillingEmail = nameof(BillingEmail); + public const string Enabled = nameof(Enabled); + public const string Plan = nameof(Plan); + public const string PlanType = nameof(PlanType); + public const string Seats = nameof(Seats); + public const string MaxCollections = nameof(MaxCollections); + public const string UsePolicies = nameof(UsePolicies); + public const string UseSso = nameof(UseSso); + public const string UseKeyConnector = nameof(UseKeyConnector); + public const string UseScim = nameof(UseScim); + public const string UseGroups = nameof(UseGroups); + public const string UseEvents = nameof(UseEvents); + public const string UseDirectory = nameof(UseDirectory); + public const string UseTotp = nameof(UseTotp); + public const string Use2fa = nameof(Use2fa); + public const string UseApi = nameof(UseApi); + public const string UseResetPassword = nameof(UseResetPassword); + public const string MaxStorageGb = nameof(MaxStorageGb); + public const string SelfHost = nameof(SelfHost); + public const string UsersGetPremium = nameof(UsersGetPremium); + public const string UseCustomPermissions = nameof(UseCustomPermissions); + public const string Issued = nameof(Issued); + public const string UsePasswordManager = nameof(UsePasswordManager); + public const string UseSecretsManager = nameof(UseSecretsManager); + public const string SmSeats = nameof(SmSeats); + public const string SmServiceAccounts = nameof(SmServiceAccounts); + public const string LimitCollectionCreationDeletion = nameof(LimitCollectionCreationDeletion); + public const string AllowAdminAccessToAllCollectionItems = nameof(AllowAdminAccessToAllCollectionItems); + public const string Expires = nameof(Expires); + public const string Refresh = nameof(Refresh); + public const string ExpirationWithoutGracePeriod = nameof(ExpirationWithoutGracePeriod); + public const string Trial = nameof(Trial); +} + +public static class UserLicenseConstants +{ + public const string LicenseType = nameof(LicenseType); + public const string LicenseKey = nameof(LicenseKey); + public const string Id = nameof(Id); + public const string Name = nameof(Name); + public const string Email = nameof(Email); + public const string Premium = nameof(Premium); + public const string MaxStorageGb = nameof(MaxStorageGb); + public const string Issued = nameof(Issued); + public const string Expires = nameof(Expires); + public const string Refresh = nameof(Refresh); + public const string Trial = nameof(Trial); +} diff --git a/src/Core/Billing/Licenses/Models/LicenseContext.cs b/src/Core/Billing/Licenses/Models/LicenseContext.cs new file mode 100644 index 0000000000..8dcc24e939 --- /dev/null +++ b/src/Core/Billing/Licenses/Models/LicenseContext.cs @@ -0,0 +1,10 @@ +#nullable enable +using Bit.Core.Models.Business; + +namespace Bit.Core.Billing.Licenses.Models; + +public class LicenseContext +{ + public Guid? InstallationId { get; init; } + public required SubscriptionInfo SubscriptionInfo { get; init; } +} diff --git a/src/Core/Billing/Licenses/Services/ILicenseClaimsFactory.cs b/src/Core/Billing/Licenses/Services/ILicenseClaimsFactory.cs new file mode 100644 index 0000000000..926ad04683 --- /dev/null +++ b/src/Core/Billing/Licenses/Services/ILicenseClaimsFactory.cs @@ -0,0 +1,9 @@ +using System.Security.Claims; +using Bit.Core.Billing.Licenses.Models; + +namespace Bit.Core.Billing.Licenses.Services; + +public interface ILicenseClaimsFactory +{ + Task> GenerateClaims(T entity, LicenseContext licenseContext); +} diff --git a/src/Core/Billing/Licenses/Services/Implementations/OrganizationLicenseClaimsFactory.cs b/src/Core/Billing/Licenses/Services/Implementations/OrganizationLicenseClaimsFactory.cs new file mode 100644 index 0000000000..300b87dcca --- /dev/null +++ b/src/Core/Billing/Licenses/Services/Implementations/OrganizationLicenseClaimsFactory.cs @@ -0,0 +1,75 @@ +using System.Globalization; +using System.Security.Claims; +using Bit.Core.AdminConsole.Entities; +using Bit.Core.Billing.Enums; +using Bit.Core.Billing.Licenses.Extensions; +using Bit.Core.Billing.Licenses.Models; +using Bit.Core.Enums; +using Bit.Core.Models.Business; + +namespace Bit.Core.Billing.Licenses.Services.Implementations; + +public class OrganizationLicenseClaimsFactory : ILicenseClaimsFactory +{ + public Task> GenerateClaims(Organization entity, LicenseContext licenseContext) + { + var subscriptionInfo = licenseContext.SubscriptionInfo; + var expires = entity.CalculateFreshExpirationDate(subscriptionInfo); + var refresh = entity.CalculateFreshRefreshDate(subscriptionInfo, expires); + var expirationWithoutGracePeriod = entity.CalculateFreshExpirationDateWithoutGracePeriod(subscriptionInfo, expires); + var trial = IsTrialing(entity, subscriptionInfo); + + var claims = new List + { + new(nameof(OrganizationLicenseConstants.LicenseType), LicenseType.Organization.ToString()), + new Claim(nameof(OrganizationLicenseConstants.LicenseKey), entity.LicenseKey), + new(nameof(OrganizationLicenseConstants.InstallationId), licenseContext.InstallationId.ToString()), + new(nameof(OrganizationLicenseConstants.Id), entity.Id.ToString()), + new(nameof(OrganizationLicenseConstants.Name), entity.Name), + new(nameof(OrganizationLicenseConstants.BillingEmail), entity.BillingEmail), + new(nameof(OrganizationLicenseConstants.Enabled), entity.Enabled.ToString()), + new(nameof(OrganizationLicenseConstants.Plan), entity.Plan), + new(nameof(OrganizationLicenseConstants.PlanType), entity.PlanType.ToString()), + new(nameof(OrganizationLicenseConstants.Seats), entity.Seats.ToString()), + new(nameof(OrganizationLicenseConstants.MaxCollections), entity.MaxCollections.ToString()), + new(nameof(OrganizationLicenseConstants.UsePolicies), entity.UsePolicies.ToString()), + new(nameof(OrganizationLicenseConstants.UseSso), entity.UseSso.ToString()), + new(nameof(OrganizationLicenseConstants.UseKeyConnector), entity.UseKeyConnector.ToString()), + new(nameof(OrganizationLicenseConstants.UseScim), entity.UseScim.ToString()), + new(nameof(OrganizationLicenseConstants.UseGroups), entity.UseGroups.ToString()), + new(nameof(OrganizationLicenseConstants.UseEvents), entity.UseEvents.ToString()), + new(nameof(OrganizationLicenseConstants.UseDirectory), entity.UseDirectory.ToString()), + new(nameof(OrganizationLicenseConstants.UseTotp), entity.UseTotp.ToString()), + new(nameof(OrganizationLicenseConstants.Use2fa), entity.Use2fa.ToString()), + new(nameof(OrganizationLicenseConstants.UseApi), entity.UseApi.ToString()), + new(nameof(OrganizationLicenseConstants.UseResetPassword), entity.UseResetPassword.ToString()), + new(nameof(OrganizationLicenseConstants.MaxStorageGb), entity.MaxStorageGb.ToString()), + new(nameof(OrganizationLicenseConstants.SelfHost), entity.SelfHost.ToString()), + new(nameof(OrganizationLicenseConstants.UsersGetPremium), entity.UsersGetPremium.ToString()), + new(nameof(OrganizationLicenseConstants.UseCustomPermissions), entity.UseCustomPermissions.ToString()), + new(nameof(OrganizationLicenseConstants.Issued), DateTime.UtcNow.ToString(CultureInfo.InvariantCulture)), + new(nameof(OrganizationLicenseConstants.UsePasswordManager), entity.UsePasswordManager.ToString()), + new(nameof(OrganizationLicenseConstants.UseSecretsManager), entity.UseSecretsManager.ToString()), + new(nameof(OrganizationLicenseConstants.SmSeats), entity.SmSeats.ToString()), + new(nameof(OrganizationLicenseConstants.SmServiceAccounts), entity.SmServiceAccounts.ToString()), + new(nameof(OrganizationLicenseConstants.LimitCollectionCreationDeletion), entity.LimitCollectionCreationDeletion.ToString()), + new(nameof(OrganizationLicenseConstants.AllowAdminAccessToAllCollectionItems), entity.AllowAdminAccessToAllCollectionItems.ToString()), + new(nameof(OrganizationLicenseConstants.Expires), expires.ToString(CultureInfo.InvariantCulture)), + new(nameof(OrganizationLicenseConstants.Refresh), refresh.ToString(CultureInfo.InvariantCulture)), + new(nameof(OrganizationLicenseConstants.ExpirationWithoutGracePeriod), expirationWithoutGracePeriod.ToString(CultureInfo.InvariantCulture)), + new(nameof(OrganizationLicenseConstants.Trial), trial.ToString()), + }; + + if (entity.BusinessName is not null) + { + claims.Add(new Claim(nameof(OrganizationLicenseConstants.BusinessName), entity.BusinessName)); + } + + return Task.FromResult(claims); + } + + private static bool IsTrialing(Organization org, SubscriptionInfo subscriptionInfo) => + subscriptionInfo?.Subscription is null + ? org.PlanType != PlanType.Custom || !org.ExpirationDate.HasValue + : subscriptionInfo.Subscription.TrialEndDate > DateTime.UtcNow; +} diff --git a/src/Core/Billing/Licenses/Services/Implementations/UserLicenseClaimsFactory.cs b/src/Core/Billing/Licenses/Services/Implementations/UserLicenseClaimsFactory.cs new file mode 100644 index 0000000000..28c779c3d6 --- /dev/null +++ b/src/Core/Billing/Licenses/Services/Implementations/UserLicenseClaimsFactory.cs @@ -0,0 +1,37 @@ +using System.Globalization; +using System.Security.Claims; +using Bit.Core.Billing.Licenses.Models; +using Bit.Core.Entities; +using Bit.Core.Enums; + +namespace Bit.Core.Billing.Licenses.Services.Implementations; + +public class UserLicenseClaimsFactory : ILicenseClaimsFactory +{ + public Task> GenerateClaims(User entity, LicenseContext licenseContext) + { + var subscriptionInfo = licenseContext.SubscriptionInfo; + + var expires = subscriptionInfo.UpcomingInvoice?.Date?.AddDays(7) ?? entity.PremiumExpirationDate?.AddDays(7); + var refresh = subscriptionInfo.UpcomingInvoice?.Date ?? entity.PremiumExpirationDate; + var trial = (subscriptionInfo.Subscription?.TrialEndDate.HasValue ?? false) && + subscriptionInfo.Subscription.TrialEndDate.Value > DateTime.UtcNow; + + var claims = new List + { + new(nameof(UserLicenseConstants.LicenseType), LicenseType.User.ToString()), + new(nameof(UserLicenseConstants.LicenseKey), entity.LicenseKey), + new(nameof(UserLicenseConstants.Id), entity.Id.ToString()), + new(nameof(UserLicenseConstants.Name), entity.Name), + new(nameof(UserLicenseConstants.Email), entity.Email), + new(nameof(UserLicenseConstants.Premium), entity.Premium.ToString()), + new(nameof(UserLicenseConstants.MaxStorageGb), entity.MaxStorageGb.ToString()), + new(nameof(UserLicenseConstants.Issued), DateTime.UtcNow.ToString(CultureInfo.InvariantCulture)), + new(nameof(UserLicenseConstants.Expires), expires.ToString()), + new(nameof(UserLicenseConstants.Refresh), refresh.ToString()), + new(nameof(UserLicenseConstants.Trial), trial.ToString()), + }; + + return Task.FromResult(claims); + } +} diff --git a/src/Core/Constants.cs b/src/Core/Constants.cs index 46ed985cb8..14258353d6 100644 --- a/src/Core/Constants.cs +++ b/src/Core/Constants.cs @@ -157,6 +157,7 @@ public static class FeatureFlagKeys public const string MacOsNativeCredentialSync = "macos-native-credential-sync"; public const string PM9111ExtensionPersistAddEditForm = "pm-9111-extension-persist-add-edit-form"; public const string InlineMenuTotp = "inline-menu-totp"; + public const string SelfHostLicenseRefactor = "pm-11516-self-host-license-refactor"; public static List GetAllKeys() { diff --git a/src/Core/Models/Business/ILicense.cs b/src/Core/Models/Business/ILicense.cs index ad389b0a12..b0e295bdd9 100644 --- a/src/Core/Models/Business/ILicense.cs +++ b/src/Core/Models/Business/ILicense.cs @@ -12,6 +12,7 @@ public interface ILicense bool Trial { get; set; } string Hash { get; set; } string Signature { get; set; } + string Token { get; set; } byte[] SignatureBytes { get; } byte[] GetDataBytes(bool forHash = false); byte[] ComputeHash(); diff --git a/src/Core/Models/Business/OrganizationLicense.cs b/src/Core/Models/Business/OrganizationLicense.cs index ea51273645..37a086646c 100644 --- a/src/Core/Models/Business/OrganizationLicense.cs +++ b/src/Core/Models/Business/OrganizationLicense.cs @@ -1,10 +1,12 @@ using System.Reflection; +using System.Security.Claims; using System.Security.Cryptography; using System.Security.Cryptography.X509Certificates; using System.Text; using System.Text.Json.Serialization; using Bit.Core.AdminConsole.Entities; using Bit.Core.Billing.Enums; +using Bit.Core.Billing.Licenses.Extensions; using Bit.Core.Enums; using Bit.Core.Services; using Bit.Core.Settings; @@ -151,6 +153,7 @@ public class OrganizationLicense : ILicense public LicenseType? LicenseType { get; set; } public string Hash { get; set; } public string Signature { get; set; } + public string Token { get; set; } [JsonIgnore] public byte[] SignatureBytes => Convert.FromBase64String(Signature); /// @@ -176,6 +179,7 @@ public class OrganizationLicense : ILicense !p.Name.Equals(nameof(Signature)) && !p.Name.Equals(nameof(SignatureBytes)) && !p.Name.Equals(nameof(LicenseType)) && + !p.Name.Equals(nameof(Token)) && // UsersGetPremium was added in Version 2 (Version >= 2 || !p.Name.Equals(nameof(UsersGetPremium))) && // UseEvents was added in Version 3 @@ -236,8 +240,65 @@ public class OrganizationLicense : ILicense } } - public bool CanUse(IGlobalSettings globalSettings, ILicensingService licensingService, out string exception) + public bool CanUse( + IGlobalSettings globalSettings, + ILicensingService licensingService, + ClaimsPrincipal claimsPrincipal, + out string exception) { + if (string.IsNullOrWhiteSpace(Token) || claimsPrincipal is null) + { + return ObsoleteCanUse(globalSettings, licensingService, out exception); + } + + var errorMessages = new StringBuilder(); + + var enabled = claimsPrincipal.GetValue(nameof(Enabled)); + if (!enabled) + { + errorMessages.AppendLine("Your cloud-hosted organization is currently disabled."); + } + + var installationId = claimsPrincipal.GetValue(nameof(InstallationId)); + if (installationId != globalSettings.Installation.Id) + { + errorMessages.AppendLine("The installation ID does not match the current installation."); + } + + var selfHost = claimsPrincipal.GetValue(nameof(SelfHost)); + if (!selfHost) + { + errorMessages.AppendLine("The license does not allow for on-premise hosting of organizations."); + } + + var licenseType = claimsPrincipal.GetValue(nameof(LicenseType)); + if (licenseType != Enums.LicenseType.Organization) + { + errorMessages.AppendLine("Premium licenses cannot be applied to an organization. " + + "Upload this license from your personal account settings page."); + } + + if (errorMessages.Length > 0) + { + exception = $"Invalid license. {errorMessages.ToString().TrimEnd()}"; + return false; + } + + exception = ""; + return true; + } + + /// + /// Do not extend this method. It is only here for backwards compatibility with old licenses. + /// Instead, extend the CanUse method using the ClaimsPrincipal. + /// + /// + /// + /// + /// + private bool ObsoleteCanUse(IGlobalSettings globalSettings, ILicensingService licensingService, out string exception) + { + // Do not extend this method. It is only here for backwards compatibility with old licenses. var errorMessages = new StringBuilder(); if (!Enabled) @@ -291,101 +352,177 @@ public class OrganizationLicense : ILicense return true; } - public bool VerifyData(Organization organization, IGlobalSettings globalSettings) + public bool VerifyData( + Organization organization, + ClaimsPrincipal claimsPrincipal, + IGlobalSettings globalSettings) { + if (string.IsNullOrWhiteSpace(Token)) + { + return ObsoleteVerifyData(organization, globalSettings); + } + + var issued = claimsPrincipal.GetValue(nameof(Issued)); + var expires = claimsPrincipal.GetValue(nameof(Expires)); + var installationId = claimsPrincipal.GetValue(nameof(InstallationId)); + var licenseKey = claimsPrincipal.GetValue(nameof(LicenseKey)); + var enabled = claimsPrincipal.GetValue(nameof(Enabled)); + var planType = claimsPrincipal.GetValue(nameof(PlanType)); + var seats = claimsPrincipal.GetValue(nameof(Seats)); + var maxCollections = claimsPrincipal.GetValue(nameof(MaxCollections)); + var useGroups = claimsPrincipal.GetValue(nameof(UseGroups)); + var useDirectory = claimsPrincipal.GetValue(nameof(UseDirectory)); + var useTotp = claimsPrincipal.GetValue(nameof(UseTotp)); + var selfHost = claimsPrincipal.GetValue(nameof(SelfHost)); + var name = claimsPrincipal.GetValue(nameof(Name)); + var usersGetPremium = claimsPrincipal.GetValue(nameof(UsersGetPremium)); + var useEvents = claimsPrincipal.GetValue(nameof(UseEvents)); + var use2fa = claimsPrincipal.GetValue(nameof(Use2fa)); + var useApi = claimsPrincipal.GetValue(nameof(UseApi)); + var usePolicies = claimsPrincipal.GetValue(nameof(UsePolicies)); + var useSso = claimsPrincipal.GetValue(nameof(UseSso)); + var useResetPassword = claimsPrincipal.GetValue(nameof(UseResetPassword)); + var useKeyConnector = claimsPrincipal.GetValue(nameof(UseKeyConnector)); + var useScim = claimsPrincipal.GetValue(nameof(UseScim)); + var useCustomPermissions = claimsPrincipal.GetValue(nameof(UseCustomPermissions)); + var useSecretsManager = claimsPrincipal.GetValue(nameof(UseSecretsManager)); + var usePasswordManager = claimsPrincipal.GetValue(nameof(UsePasswordManager)); + var smSeats = claimsPrincipal.GetValue(nameof(SmSeats)); + var smServiceAccounts = claimsPrincipal.GetValue(nameof(SmServiceAccounts)); + + return issued <= DateTime.UtcNow && + expires >= DateTime.UtcNow && + installationId == globalSettings.Installation.Id && + licenseKey == organization.LicenseKey && + enabled == organization.Enabled && + planType == organization.PlanType && + seats == organization.Seats && + maxCollections == organization.MaxCollections && + useGroups == organization.UseGroups && + useDirectory == organization.UseDirectory && + useTotp == organization.UseTotp && + selfHost == organization.SelfHost && + name == organization.Name && + usersGetPremium == organization.UsersGetPremium && + useEvents == organization.UseEvents && + use2fa == organization.Use2fa && + useApi == organization.UseApi && + usePolicies == organization.UsePolicies && + useSso == organization.UseSso && + useResetPassword == organization.UseResetPassword && + useKeyConnector == organization.UseKeyConnector && + useScim == organization.UseScim && + useCustomPermissions == organization.UseCustomPermissions && + useSecretsManager == organization.UseSecretsManager && + usePasswordManager == organization.UsePasswordManager && + smSeats == organization.SmSeats && + smServiceAccounts == organization.SmServiceAccounts; + } + + /// + /// Do not extend this method. It is only here for backwards compatibility with old licenses. + /// Instead, extend the VerifyData method using the ClaimsPrincipal. + /// + /// + /// + /// + /// + private bool ObsoleteVerifyData(Organization organization, IGlobalSettings globalSettings) + { + // Do not extend this method. It is only here for backwards compatibility with old licenses. if (Issued > DateTime.UtcNow || Expires < DateTime.UtcNow) { return false; } - if (ValidLicenseVersion) + if (!ValidLicenseVersion) { - var valid = - globalSettings.Installation.Id == InstallationId && - organization.LicenseKey != null && organization.LicenseKey.Equals(LicenseKey) && - organization.Enabled == Enabled && - organization.PlanType == PlanType && - organization.Seats == Seats && - organization.MaxCollections == MaxCollections && - organization.UseGroups == UseGroups && - organization.UseDirectory == UseDirectory && - organization.UseTotp == UseTotp && - organization.SelfHost == SelfHost && - organization.Name.Equals(Name); + throw new NotSupportedException($"Version {Version} is not supported."); + } - if (valid && Version >= 2) - { - valid = organization.UsersGetPremium == UsersGetPremium; - } + var valid = + globalSettings.Installation.Id == InstallationId && + organization.LicenseKey != null && organization.LicenseKey.Equals(LicenseKey) && + organization.Enabled == Enabled && + organization.PlanType == PlanType && + organization.Seats == Seats && + organization.MaxCollections == MaxCollections && + organization.UseGroups == UseGroups && + organization.UseDirectory == UseDirectory && + organization.UseTotp == UseTotp && + organization.SelfHost == SelfHost && + organization.Name.Equals(Name); - if (valid && Version >= 3) - { - valid = organization.UseEvents == UseEvents; - } + if (valid && Version >= 2) + { + valid = organization.UsersGetPremium == UsersGetPremium; + } - if (valid && Version >= 4) - { - valid = organization.Use2fa == Use2fa; - } + if (valid && Version >= 3) + { + valid = organization.UseEvents == UseEvents; + } - if (valid && Version >= 5) - { - valid = organization.UseApi == UseApi; - } + if (valid && Version >= 4) + { + valid = organization.Use2fa == Use2fa; + } - if (valid && Version >= 6) - { - valid = organization.UsePolicies == UsePolicies; - } + if (valid && Version >= 5) + { + valid = organization.UseApi == UseApi; + } - if (valid && Version >= 7) - { - valid = organization.UseSso == UseSso; - } + if (valid && Version >= 6) + { + valid = organization.UsePolicies == UsePolicies; + } - if (valid && Version >= 8) - { - valid = organization.UseResetPassword == UseResetPassword; - } + if (valid && Version >= 7) + { + valid = organization.UseSso == UseSso; + } - if (valid && Version >= 9) - { - valid = organization.UseKeyConnector == UseKeyConnector; - } + if (valid && Version >= 8) + { + valid = organization.UseResetPassword == UseResetPassword; + } - if (valid && Version >= 10) - { - valid = organization.UseScim == UseScim; - } + if (valid && Version >= 9) + { + valid = organization.UseKeyConnector == UseKeyConnector; + } - if (valid && Version >= 11) - { - valid = organization.UseCustomPermissions == UseCustomPermissions; - } + if (valid && Version >= 10) + { + valid = organization.UseScim == UseScim; + } - /*Version 12 added ExpirationWithoutDatePeriod, but that property is informational only and is not saved + if (valid && Version >= 11) + { + valid = organization.UseCustomPermissions == UseCustomPermissions; + } + + /*Version 12 added ExpirationWithoutDatePeriod, but that property is informational only and is not saved to the Organization object. It's validated as part of the hash but does not need to be validated here. */ - if (valid && Version >= 13) - { - valid = organization.UseSecretsManager == UseSecretsManager && - organization.UsePasswordManager == UsePasswordManager && - organization.SmSeats == SmSeats && - organization.SmServiceAccounts == SmServiceAccounts; - } + if (valid && Version >= 13) + { + valid = organization.UseSecretsManager == UseSecretsManager && + organization.UsePasswordManager == UsePasswordManager && + organization.SmSeats == SmSeats && + organization.SmServiceAccounts == SmServiceAccounts; + } - /* + /* * Version 14 added LimitCollectionCreationDeletion and Version * 15 added AllowAdminAccessToAllCollectionItems, however they * are no longer used and are intentionally excluded from * validation. */ - return valid; - } - - throw new NotSupportedException($"Version {Version} is not supported."); + return valid; } public bool VerifySignature(X509Certificate2 certificate) diff --git a/src/Core/Models/Business/UserLicense.cs b/src/Core/Models/Business/UserLicense.cs index 0f1b191a1d..797aa6692a 100644 --- a/src/Core/Models/Business/UserLicense.cs +++ b/src/Core/Models/Business/UserLicense.cs @@ -1,8 +1,10 @@ using System.Reflection; +using System.Security.Claims; using System.Security.Cryptography; using System.Security.Cryptography.X509Certificates; using System.Text; using System.Text.Json.Serialization; +using Bit.Core.Billing.Licenses.Extensions; using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Services; @@ -70,6 +72,7 @@ public class UserLicense : ILicense public LicenseType? LicenseType { get; set; } public string Hash { get; set; } public string Signature { get; set; } + public string Token { get; set; } [JsonIgnore] public byte[] SignatureBytes => Convert.FromBase64String(Signature); @@ -84,6 +87,7 @@ public class UserLicense : ILicense !p.Name.Equals(nameof(Signature)) && !p.Name.Equals(nameof(SignatureBytes)) && !p.Name.Equals(nameof(LicenseType)) && + !p.Name.Equals(nameof(Token)) && ( !forHash || ( @@ -113,8 +117,47 @@ public class UserLicense : ILicense } } - public bool CanUse(User user, out string exception) + public bool CanUse(User user, ClaimsPrincipal claimsPrincipal, out string exception) { + if (string.IsNullOrWhiteSpace(Token) || claimsPrincipal is null) + { + return ObsoleteCanUse(user, out exception); + } + + var errorMessages = new StringBuilder(); + + if (!user.EmailVerified) + { + errorMessages.AppendLine("The user's email is not verified."); + } + + var email = claimsPrincipal.GetValue(nameof(Email)); + if (!email.Equals(user.Email, StringComparison.InvariantCultureIgnoreCase)) + { + errorMessages.AppendLine("The user's email does not match the license email."); + } + + if (errorMessages.Length > 0) + { + exception = $"Invalid license. {errorMessages.ToString().TrimEnd()}"; + return false; + } + + exception = ""; + return true; + } + + /// + /// Do not extend this method. It is only here for backwards compatibility with old licenses. + /// Instead, extend the CanUse method using the ClaimsPrincipal. + /// + /// + /// + /// + /// + private bool ObsoleteCanUse(User user, out string exception) + { + // Do not extend this method. It is only here for backwards compatibility with old licenses. var errorMessages = new StringBuilder(); if (Issued > DateTime.UtcNow) @@ -152,22 +195,46 @@ public class UserLicense : ILicense return true; } - public bool VerifyData(User user) + public bool VerifyData(User user, ClaimsPrincipal claimsPrincipal) { + if (string.IsNullOrWhiteSpace(Token) || claimsPrincipal is null) + { + return ObsoleteVerifyData(user); + } + + var licenseKey = claimsPrincipal.GetValue(nameof(LicenseKey)); + var premium = claimsPrincipal.GetValue(nameof(Premium)); + var email = claimsPrincipal.GetValue(nameof(Email)); + + return licenseKey == user.LicenseKey && + premium == user.Premium && + email.Equals(user.Email, StringComparison.InvariantCultureIgnoreCase); + } + + /// + /// Do not extend this method. It is only here for backwards compatibility with old licenses. + /// Instead, extend the VerifyData method using the ClaimsPrincipal. + /// + /// + /// + /// + private bool ObsoleteVerifyData(User user) + { + // Do not extend this method. It is only here for backwards compatibility with old licenses. if (Issued > DateTime.UtcNow || Expires < DateTime.UtcNow) { return false; } - if (Version == 1) + if (Version != 1) { - return - user.LicenseKey != null && user.LicenseKey.Equals(LicenseKey) && - user.Premium == Premium && - user.Email.Equals(Email, StringComparison.InvariantCultureIgnoreCase); + throw new NotSupportedException($"Version {Version} is not supported."); } - throw new NotSupportedException($"Version {Version} is not supported."); + return + user.LicenseKey != null && user.LicenseKey.Equals(LicenseKey) && + user.Premium == Premium && + user.Email.Equals(Email, StringComparison.InvariantCultureIgnoreCase); } public bool VerifySignature(X509Certificate2 certificate) diff --git a/src/Core/OrganizationFeatures/OrganizationLicenses/Cloud/CloudGetOrganizationLicenseQuery.cs b/src/Core/OrganizationFeatures/OrganizationLicenses/Cloud/CloudGetOrganizationLicenseQuery.cs index b8fad451e2..a4b08736c2 100644 --- a/src/Core/OrganizationFeatures/OrganizationLicenses/Cloud/CloudGetOrganizationLicenseQuery.cs +++ b/src/Core/OrganizationFeatures/OrganizationLicenses/Cloud/CloudGetOrganizationLicenseQuery.cs @@ -33,6 +33,10 @@ public class CloudGetOrganizationLicenseQuery : ICloudGetOrganizationLicenseQuer } var subscriptionInfo = await _paymentService.GetSubscriptionAsync(organization); - return new OrganizationLicense(organization, subscriptionInfo, installationId, _licensingService, version); + + return new OrganizationLicense(organization, subscriptionInfo, installationId, _licensingService, version) + { + Token = await _licensingService.CreateOrganizationTokenAsync(organization, installationId, subscriptionInfo) + }; } } diff --git a/src/Core/OrganizationFeatures/OrganizationLicenses/UpdateOrganizationLicenseCommand.cs b/src/Core/OrganizationFeatures/OrganizationLicenses/UpdateOrganizationLicenseCommand.cs index 1f8c6604b8..ffeee39c07 100644 --- a/src/Core/OrganizationFeatures/OrganizationLicenses/UpdateOrganizationLicenseCommand.cs +++ b/src/Core/OrganizationFeatures/OrganizationLicenses/UpdateOrganizationLicenseCommand.cs @@ -39,7 +39,8 @@ public class UpdateOrganizationLicenseCommand : IUpdateOrganizationLicenseComman throw new BadRequestException("License is already in use by another organization."); } - var canUse = license.CanUse(_globalSettings, _licensingService, out var exception) && + var claimsPrincipal = _licensingService.GetClaimsPrincipalFromLicense(license); + var canUse = license.CanUse(_globalSettings, _licensingService, claimsPrincipal, out var exception) && selfHostedOrganization.CanUseLicense(license, out exception); if (!canUse) diff --git a/src/Core/Services/ILicensingService.cs b/src/Core/Services/ILicensingService.cs index e92fa87fd6..7301f7c689 100644 --- a/src/Core/Services/ILicensingService.cs +++ b/src/Core/Services/ILicensingService.cs @@ -1,4 +1,5 @@ -using Bit.Core.AdminConsole.Entities; +using System.Security.Claims; +using Bit.Core.AdminConsole.Entities; using Bit.Core.Entities; using Bit.Core.Models.Business; @@ -13,5 +14,12 @@ public interface ILicensingService byte[] SignLicense(ILicense license); Task ReadOrganizationLicenseAsync(Organization organization); Task ReadOrganizationLicenseAsync(Guid organizationId); + ClaimsPrincipal GetClaimsPrincipalFromLicense(ILicense license); + Task CreateOrganizationTokenAsync( + Organization organization, + Guid installationId, + SubscriptionInfo subscriptionInfo); + + Task CreateUserTokenAsync(User user, SubscriptionInfo subscriptionInfo); } diff --git a/src/Core/Services/Implementations/LicensingService.cs b/src/Core/Services/Implementations/LicensingService.cs index 85b8f31200..866f0bb6e1 100644 --- a/src/Core/Services/Implementations/LicensingService.cs +++ b/src/Core/Services/Implementations/LicensingService.cs @@ -1,15 +1,22 @@ -using System.Security.Cryptography.X509Certificates; +using System.IdentityModel.Tokens.Jwt; +using System.Security.Claims; +using System.Security.Cryptography.X509Certificates; using System.Text; using System.Text.Json; using Bit.Core.AdminConsole.Entities; +using Bit.Core.Billing.Licenses.Models; +using Bit.Core.Billing.Licenses.Services; using Bit.Core.Entities; +using Bit.Core.Exceptions; using Bit.Core.Models.Business; using Bit.Core.Repositories; using Bit.Core.Settings; using Bit.Core.Utilities; +using IdentityModel; using Microsoft.AspNetCore.Hosting; using Microsoft.Extensions.Hosting; using Microsoft.Extensions.Logging; +using Microsoft.IdentityModel.Tokens; namespace Bit.Core.Services; @@ -19,27 +26,33 @@ public class LicensingService : ILicensingService private readonly IGlobalSettings _globalSettings; private readonly IUserRepository _userRepository; private readonly IOrganizationRepository _organizationRepository; - private readonly IOrganizationUserRepository _organizationUserRepository; private readonly IMailService _mailService; private readonly ILogger _logger; + private readonly ILicenseClaimsFactory _organizationLicenseClaimsFactory; + private readonly ILicenseClaimsFactory _userLicenseClaimsFactory; + private readonly IFeatureService _featureService; private IDictionary _userCheckCache = new Dictionary(); public LicensingService( IUserRepository userRepository, IOrganizationRepository organizationRepository, - IOrganizationUserRepository organizationUserRepository, IMailService mailService, IWebHostEnvironment environment, ILogger logger, - IGlobalSettings globalSettings) + IGlobalSettings globalSettings, + ILicenseClaimsFactory organizationLicenseClaimsFactory, + IFeatureService featureService, + ILicenseClaimsFactory userLicenseClaimsFactory) { _userRepository = userRepository; _organizationRepository = organizationRepository; - _organizationUserRepository = organizationUserRepository; _mailService = mailService; _logger = logger; _globalSettings = globalSettings; + _organizationLicenseClaimsFactory = organizationLicenseClaimsFactory; + _featureService = featureService; + _userLicenseClaimsFactory = userLicenseClaimsFactory; var certThumbprint = environment.IsDevelopment() ? "207E64A231E8AA32AAF68A61037C075EBEBD553F" : @@ -104,13 +117,13 @@ public class LicensingService : ILicensingService continue; } - if (!license.VerifyData(org, _globalSettings)) + if (!license.VerifyData(org, GetClaimsPrincipalFromLicense(license), _globalSettings)) { await DisableOrganizationAsync(org, license, "Invalid data."); continue; } - if (!license.VerifySignature(_certificate)) + if (string.IsNullOrWhiteSpace(license.Token) && !license.VerifySignature(_certificate)) { await DisableOrganizationAsync(org, license, "Invalid signature."); continue; @@ -203,13 +216,14 @@ public class LicensingService : ILicensingService return false; } - if (!license.VerifyData(user)) + var claimsPrincipal = GetClaimsPrincipalFromLicense(license); + if (!license.VerifyData(user, claimsPrincipal)) { await DisablePremiumAsync(user, license, "Invalid data."); return false; } - if (!license.VerifySignature(_certificate)) + if (string.IsNullOrWhiteSpace(license.Token) && !license.VerifySignature(_certificate)) { await DisablePremiumAsync(user, license, "Invalid signature."); return false; @@ -234,7 +248,21 @@ public class LicensingService : ILicensingService public bool VerifyLicense(ILicense license) { - return license.VerifySignature(_certificate); + if (string.IsNullOrWhiteSpace(license.Token)) + { + return license.VerifySignature(_certificate); + } + + try + { + _ = GetClaimsPrincipalFromLicense(license); + return true; + } + catch (Exception e) + { + _logger.LogWarning(e, "Invalid token."); + return false; + } } public byte[] SignLicense(ILicense license) @@ -272,4 +300,101 @@ public class LicensingService : ILicensingService using var fs = File.OpenRead(filePath); return await JsonSerializer.DeserializeAsync(fs); } + + public ClaimsPrincipal GetClaimsPrincipalFromLicense(ILicense license) + { + if (string.IsNullOrWhiteSpace(license.Token)) + { + return null; + } + + var audience = license switch + { + OrganizationLicense orgLicense => $"organization:{orgLicense.Id}", + UserLicense userLicense => $"user:{userLicense.Id}", + _ => throw new ArgumentException("Unsupported license type.", nameof(license)), + }; + + var token = license.Token; + var tokenHandler = new JwtSecurityTokenHandler(); + var validationParameters = new TokenValidationParameters + { + ValidateIssuerSigningKey = true, + IssuerSigningKey = new X509SecurityKey(_certificate), + ValidateIssuer = true, + ValidIssuer = "bitwarden", + ValidateAudience = true, + ValidAudience = audience, + ValidateLifetime = true, + ClockSkew = TimeSpan.Zero, + RequireExpirationTime = true + }; + + try + { + return tokenHandler.ValidateToken(token, validationParameters, out _); + } + catch (Exception ex) + { + // Token exceptions thrown are interpreted by the client as Identity errors and cause the user to logout + // Mask them by rethrowing as BadRequestException + throw new BadRequestException($"Invalid license. {ex.Message}"); + } + } + + public async Task CreateOrganizationTokenAsync(Organization organization, Guid installationId, SubscriptionInfo subscriptionInfo) + { + if (!_featureService.IsEnabled(FeatureFlagKeys.SelfHostLicenseRefactor)) + { + return null; + } + + var licenseContext = new LicenseContext + { + InstallationId = installationId, + SubscriptionInfo = subscriptionInfo, + }; + + var claims = await _organizationLicenseClaimsFactory.GenerateClaims(organization, licenseContext); + var audience = $"organization:{organization.Id}"; + + return GenerateToken(claims, audience); + } + + public async Task CreateUserTokenAsync(User user, SubscriptionInfo subscriptionInfo) + { + if (!_featureService.IsEnabled(FeatureFlagKeys.SelfHostLicenseRefactor)) + { + return null; + } + + var licenseContext = new LicenseContext { SubscriptionInfo = subscriptionInfo }; + var claims = await _userLicenseClaimsFactory.GenerateClaims(user, licenseContext); + var audience = $"user:{user.Id}"; + + return GenerateToken(claims, audience); + } + + private string GenerateToken(List claims, string audience) + { + if (claims.All(claim => claim.Type != JwtClaimTypes.JwtId)) + { + claims.Add(new Claim(JwtClaimTypes.JwtId, Guid.NewGuid().ToString())); + } + + var securityKey = new RsaSecurityKey(_certificate.GetRSAPrivateKey()); + var tokenDescriptor = new SecurityTokenDescriptor + { + Subject = new ClaimsIdentity(claims), + Issuer = "bitwarden", + Audience = audience, + NotBefore = DateTime.UtcNow, + Expires = DateTime.UtcNow.AddYears(1), // Org expiration is a claim + SigningCredentials = new SigningCredentials(securityKey, SecurityAlgorithms.RsaSha256Signature) + }; + + var tokenHandler = new JwtSecurityTokenHandler(); + var token = tokenHandler.CreateToken(tokenDescriptor); + return tokenHandler.WriteToken(token); + } } diff --git a/src/Core/Services/Implementations/UserService.cs b/src/Core/Services/Implementations/UserService.cs index 2199d0a7af..fa8cd3cef8 100644 --- a/src/Core/Services/Implementations/UserService.cs +++ b/src/Core/Services/Implementations/UserService.cs @@ -908,7 +908,9 @@ public class UserService : UserManager, IUserService, IDisposable throw new BadRequestException("Invalid license."); } - if (!license.CanUse(user, out var exceptionMessage)) + var claimsPrincipal = _licenseService.GetClaimsPrincipalFromLicense(license); + + if (!license.CanUse(user, claimsPrincipal, out var exceptionMessage)) { throw new BadRequestException(exceptionMessage); } @@ -987,7 +989,9 @@ public class UserService : UserManager, IUserService, IDisposable throw new BadRequestException("Invalid license."); } - if (!license.CanUse(user, out var exceptionMessage)) + var claimsPrincipal = _licenseService.GetClaimsPrincipalFromLicense(license); + + if (!license.CanUse(user, claimsPrincipal, out var exceptionMessage)) { throw new BadRequestException(exceptionMessage); } @@ -1111,7 +1115,9 @@ public class UserService : UserManager, IUserService, IDisposable } } - public async Task GenerateLicenseAsync(User user, SubscriptionInfo subscriptionInfo = null, + public async Task GenerateLicenseAsync( + User user, + SubscriptionInfo subscriptionInfo = null, int? version = null) { if (user == null) @@ -1124,8 +1130,13 @@ public class UserService : UserManager, IUserService, IDisposable subscriptionInfo = await _paymentService.GetSubscriptionAsync(user); } - return subscriptionInfo == null ? new UserLicense(user, _licenseService) : - new UserLicense(user, subscriptionInfo, _licenseService); + var userLicense = subscriptionInfo == null + ? new UserLicense(user, _licenseService) + : new UserLicense(user, subscriptionInfo, _licenseService); + + userLicense.Token = await _licenseService.CreateUserTokenAsync(user, subscriptionInfo); + + return userLicense; } public override async Task CheckPasswordAsync(User user, string password) diff --git a/src/Core/Services/NoopImplementations/NoopLicensingService.cs b/src/Core/Services/NoopImplementations/NoopLicensingService.cs index 8eb42a318c..dc733e9a33 100644 --- a/src/Core/Services/NoopImplementations/NoopLicensingService.cs +++ b/src/Core/Services/NoopImplementations/NoopLicensingService.cs @@ -1,4 +1,5 @@ -using Bit.Core.AdminConsole.Entities; +using System.Security.Claims; +using Bit.Core.AdminConsole.Entities; using Bit.Core.Entities; using Bit.Core.Models.Business; using Bit.Core.Settings; @@ -53,4 +54,19 @@ public class NoopLicensingService : ILicensingService { return Task.FromResult(null); } + + public ClaimsPrincipal GetClaimsPrincipalFromLicense(ILicense license) + { + return null; + } + + public Task CreateOrganizationTokenAsync(Organization organization, Guid installationId, SubscriptionInfo subscriptionInfo) + { + return Task.FromResult(null); + } + + public Task CreateUserTokenAsync(User user, SubscriptionInfo subscriptionInfo) + { + return Task.FromResult(null); + } } diff --git a/src/SharedWeb/Utilities/ServiceCollectionExtensions.cs b/src/SharedWeb/Utilities/ServiceCollectionExtensions.cs index ab4108bfef..7585739d82 100644 --- a/src/SharedWeb/Utilities/ServiceCollectionExtensions.cs +++ b/src/SharedWeb/Utilities/ServiceCollectionExtensions.cs @@ -230,7 +230,7 @@ public static class ServiceCollectionExtensions services.AddScoped(); services.AddSingleton(); services.AddSingleton(); - services.AddSingleton(); + services.AddScoped(); services.AddSingleton(_ => { var options = new LookupClientOptions { Timeout = TimeSpan.FromSeconds(15), UseTcpOnly = true }; diff --git a/test/Core.Test/Models/Business/OrganizationLicenseTests.cs b/test/Core.Test/Models/Business/OrganizationLicenseTests.cs index c2eb0dd934..26945f533e 100644 --- a/test/Core.Test/Models/Business/OrganizationLicenseTests.cs +++ b/test/Core.Test/Models/Business/OrganizationLicenseTests.cs @@ -1,4 +1,5 @@ -using System.Text.Json; +using System.Security.Claims; +using System.Text.Json; using Bit.Core.Models.Business; using Bit.Core.Services; using Bit.Core.Settings; @@ -36,7 +37,7 @@ public class OrganizationLicenseTests [Theory] [BitAutoData(OrganizationLicense.CurrentLicenseFileVersion)] // Previous version (this property is 1 behind) [BitAutoData(OrganizationLicense.CurrentLicenseFileVersion + 1)] // Current version - public void OrganizationLicense_LoadedFromDisk_VerifyData_Passes(int licenseVersion) + public void OrganizationLicense_LoadedFromDisk_VerifyData_Passes(int licenseVersion, ClaimsPrincipal claimsPrincipal) { var license = OrganizationLicenseFileFixtures.GetVersion(licenseVersion); @@ -49,7 +50,7 @@ public class OrganizationLicenseTests { Id = new Guid(OrganizationLicenseFileFixtures.InstallationId) }); - Assert.True(license.VerifyData(organization, globalSettings)); + Assert.True(license.VerifyData(organization, claimsPrincipal, globalSettings)); } /// diff --git a/test/Core.Test/OrganizationFeatures/OrganizationLicenses/UpdateOrganizationLicenseCommandTests.cs b/test/Core.Test/OrganizationFeatures/OrganizationLicenses/UpdateOrganizationLicenseCommandTests.cs index 565f2f32c4..b8e677177c 100644 --- a/test/Core.Test/OrganizationFeatures/OrganizationLicenses/UpdateOrganizationLicenseCommandTests.cs +++ b/test/Core.Test/OrganizationFeatures/OrganizationLicenses/UpdateOrganizationLicenseCommandTests.cs @@ -1,4 +1,5 @@ -using Bit.Core.AdminConsole.Entities; +using System.Security.Claims; +using Bit.Core.AdminConsole.Entities; using Bit.Core.Enums; using Bit.Core.Models.Business; using Bit.Core.Models.Data.Organizations; @@ -48,6 +49,9 @@ public class UpdateOrganizationLicenseCommandTests license.InstallationId = globalSettings.Installation.Id; license.LicenseType = LicenseType.Organization; sutProvider.GetDependency().VerifyLicense(license).Returns(true); + sutProvider.GetDependency() + .GetClaimsPrincipalFromLicense(license) + .Returns((ClaimsPrincipal)null); // Passing values for SelfHostedOrganizationDetails.CanUseLicense // NSubstitute cannot override non-virtual members so we have to ensure the real method passes @@ -79,10 +83,11 @@ public class UpdateOrganizationLicenseCommandTests .Received(1) .ReplaceAndUpdateCacheAsync(Arg.Is( org => AssertPropertyEqual(license, org, - "Id", "MaxStorageGb", "Issued", "Refresh", "Version", "Trial", "LicenseType", - "Hash", "Signature", "SignatureBytes", "InstallationId", "Expires", "ExpirationWithoutGracePeriod") && - // Same property but different name, use explicit mapping - org.ExpirationDate == license.Expires)); + "Id", "MaxStorageGb", "Issued", "Refresh", "Version", "Trial", "LicenseType", + "Hash", "Signature", "SignatureBytes", "InstallationId", "Expires", + "ExpirationWithoutGracePeriod", "Token") && + // Same property but different name, use explicit mapping + org.ExpirationDate == license.Expires)); } finally { diff --git a/test/Core.Test/Services/UserServiceTests.cs b/test/Core.Test/Services/UserServiceTests.cs index aa2c0a5cc9..71cceb86ad 100644 --- a/test/Core.Test/Services/UserServiceTests.cs +++ b/test/Core.Test/Services/UserServiceTests.cs @@ -1,4 +1,5 @@ -using System.Text.Json; +using System.Security.Claims; +using System.Text.Json; using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces; using Bit.Core.AdminConsole.Repositories; @@ -62,6 +63,9 @@ public class UserServiceTests sutProvider.GetDependency() .VerifyLicense(userLicense) .Returns(true); + sutProvider.GetDependency() + .GetClaimsPrincipalFromLicense(userLicense) + .Returns((ClaimsPrincipal)null); await sutProvider.Sut.UpdateLicenseAsync(user, userLicense);