diff --git a/bitwarden_license/src/Commercial.Core/Billing/ProviderBillingService.cs b/bitwarden_license/src/Commercial.Core/Billing/ProviderBillingService.cs index 294a926022..74cfc1f916 100644 --- a/bitwarden_license/src/Commercial.Core/Billing/ProviderBillingService.cs +++ b/bitwarden_license/src/Commercial.Core/Billing/ProviderBillingService.cs @@ -628,6 +628,19 @@ public class ProviderBillingService( } } + public async Task UpdatePaymentMethod( + Provider provider, + TokenizedPaymentSource tokenizedPaymentSource, + TaxInformation taxInformation) + { + await Task.WhenAll( + subscriberService.UpdatePaymentSource(provider, tokenizedPaymentSource), + subscriberService.UpdateTaxInformation(provider, taxInformation)); + + await stripeAdapter.SubscriptionUpdateAsync(provider.GatewaySubscriptionId, + new SubscriptionUpdateOptions { CollectionMethod = StripeConstants.CollectionMethod.ChargeAutomatically }); + } + public async Task UpdateSeatMinimums(UpdateProviderSeatMinimumsCommand command) { if (command.Configuration.Any(x => x.SeatsMinimum < 0)) diff --git a/src/Api/Billing/Controllers/ProviderBillingController.cs b/src/Api/Billing/Controllers/ProviderBillingController.cs index 73c992040c..bb1fd7bb25 100644 --- a/src/Api/Billing/Controllers/ProviderBillingController.cs +++ b/src/Api/Billing/Controllers/ProviderBillingController.cs @@ -1,5 +1,6 @@ using Bit.Api.Billing.Models.Requests; using Bit.Api.Billing.Models.Responses; +using Bit.Core; using Bit.Core.AdminConsole.Repositories; using Bit.Core.Billing.Models; using Bit.Core.Billing.Pricing; @@ -20,6 +21,7 @@ namespace Bit.Api.Billing.Controllers; [Authorize("Application")] public class ProviderBillingController( ICurrentContext currentContext, + IFeatureService featureService, ILogger logger, IPricingClient pricingClient, IProviderBillingService providerBillingService, @@ -71,6 +73,65 @@ public class ProviderBillingController( "text/csv"); } + [HttpPut("payment-method")] + public async Task UpdatePaymentMethodAsync( + [FromRoute] Guid providerId, + [FromBody] UpdatePaymentMethodRequestBody requestBody) + { + var allowProviderPaymentMethod = featureService.IsEnabled(FeatureFlagKeys.PM18794_ProviderPaymentMethod); + + if (!allowProviderPaymentMethod) + { + return TypedResults.NotFound(); + } + + var (provider, result) = await TryGetBillableProviderForAdminOperation(providerId); + + if (provider == null) + { + return result; + } + + var tokenizedPaymentSource = requestBody.PaymentSource.ToDomain(); + var taxInformation = requestBody.TaxInformation.ToDomain(); + + await providerBillingService.UpdatePaymentMethod( + provider, + tokenizedPaymentSource, + taxInformation); + + return TypedResults.Ok(); + } + + [HttpPost("payment-method/verify-bank-account")] + public async Task VerifyBankAccountAsync( + [FromRoute] Guid providerId, + [FromBody] VerifyBankAccountRequestBody requestBody) + { + var allowProviderPaymentMethod = featureService.IsEnabled(FeatureFlagKeys.PM18794_ProviderPaymentMethod); + + if (!allowProviderPaymentMethod) + { + return TypedResults.NotFound(); + } + + var (provider, result) = await TryGetBillableProviderForAdminOperation(providerId); + + if (provider == null) + { + return result; + } + + if (requestBody.DescriptorCode.Length != 6 || !requestBody.DescriptorCode.StartsWith("SM")) + { + return Error.BadRequest("Statement descriptor should be a 6-character value that starts with 'SM'"); + } + + await subscriberService.VerifyBankAccount(provider, requestBody.DescriptorCode); + + return TypedResults.Ok(); + } + [HttpGet("subscription")] public async Task GetSubscriptionAsync([FromRoute] Guid providerId) { @@ -102,12 +163,32 @@ public class ProviderBillingController( var subscriptionSuspension = await GetSubscriptionSuspensionAsync(stripeAdapter, subscription); + var paymentSource = await subscriberService.GetPaymentSource(provider); + var response = ProviderSubscriptionResponse.From( subscription, configuredProviderPlans, taxInformation, subscriptionSuspension, - provider); + provider, + paymentSource); + + return TypedResults.Ok(response); + } + + [HttpGet("tax-information")] + public async Task GetTaxInformationAsync([FromRoute] Guid providerId) + { + var (provider, result) = await TryGetBillableProviderForAdminOperation(providerId); + + if (provider == null) + { + return result; + } + + var taxInformation = await subscriberService.GetTaxInformation(provider); + + var response = TaxInformationResponse.From(taxInformation); return TypedResults.Ok(response); } diff --git a/src/Api/Billing/Models/Responses/ProviderSubscriptionResponse.cs b/src/Api/Billing/Models/Responses/ProviderSubscriptionResponse.cs index 34c3817e51..ea1479c9df 100644 --- a/src/Api/Billing/Models/Responses/ProviderSubscriptionResponse.cs +++ b/src/Api/Billing/Models/Responses/ProviderSubscriptionResponse.cs @@ -16,7 +16,8 @@ public record ProviderSubscriptionResponse( TaxInformation TaxInformation, DateTime? CancelAt, SubscriptionSuspension Suspension, - ProviderType ProviderType) + ProviderType ProviderType, + PaymentSource PaymentSource) { private const string _annualCadence = "Annual"; private const string _monthlyCadence = "Monthly"; @@ -26,7 +27,8 @@ public record ProviderSubscriptionResponse( ICollection providerPlans, TaxInformation taxInformation, SubscriptionSuspension subscriptionSuspension, - Provider provider) + Provider provider, + PaymentSource paymentSource) { var providerPlanResponses = providerPlans .Select(providerPlan => @@ -57,7 +59,8 @@ public record ProviderSubscriptionResponse( taxInformation, subscription.CancelAt, subscriptionSuspension, - provider.Type); + provider.Type, + paymentSource); } } diff --git a/src/Api/Utilities/CommandResultExtensions.cs b/src/Api/Utilities/CommandResultExtensions.cs index 39104db7ff..c7315a0fa0 100644 --- a/src/Api/Utilities/CommandResultExtensions.cs +++ b/src/Api/Utilities/CommandResultExtensions.cs @@ -12,7 +12,7 @@ public static class CommandResultExtensions NoRecordFoundFailure failure => new ObjectResult(failure.ErrorMessages) { StatusCode = StatusCodes.Status404NotFound }, BadRequestFailure failure => new ObjectResult(failure.ErrorMessages) { StatusCode = StatusCodes.Status400BadRequest }, Failure failure => new ObjectResult(failure.ErrorMessages) { StatusCode = StatusCodes.Status400BadRequest }, - Success success => new ObjectResult(success.Data) { StatusCode = StatusCodes.Status200OK }, + Success success => new ObjectResult(success.Value) { StatusCode = StatusCodes.Status200OK }, _ => throw new InvalidOperationException($"Unhandled commandResult type: {commandResult.GetType().Name}") }; } diff --git a/src/Core/AdminConsole/Errors/Error.cs b/src/Core/AdminConsole/Errors/Error.cs new file mode 100644 index 0000000000..6c8eed41a4 --- /dev/null +++ b/src/Core/AdminConsole/Errors/Error.cs @@ -0,0 +1,3 @@ +namespace Bit.Core.AdminConsole.Errors; + +public record Error(string Message, T ErroredValue); diff --git a/src/Core/AdminConsole/Errors/InsufficientPermissionsError.cs b/src/Core/AdminConsole/Errors/InsufficientPermissionsError.cs new file mode 100644 index 0000000000..d04ceba7c9 --- /dev/null +++ b/src/Core/AdminConsole/Errors/InsufficientPermissionsError.cs @@ -0,0 +1,11 @@ +namespace Bit.Core.AdminConsole.Errors; + +public record InsufficientPermissionsError(string Message, T ErroredValue) : Error(Message, ErroredValue) +{ + public const string Code = "Insufficient Permissions"; + + public InsufficientPermissionsError(T ErroredValue) : this(Code, ErroredValue) + { + + } +} diff --git a/src/Core/AdminConsole/Errors/RecordNotFoundError.cs b/src/Core/AdminConsole/Errors/RecordNotFoundError.cs new file mode 100644 index 0000000000..25a169efe1 --- /dev/null +++ b/src/Core/AdminConsole/Errors/RecordNotFoundError.cs @@ -0,0 +1,11 @@ +namespace Bit.Core.AdminConsole.Errors; + +public record RecordNotFoundError(string Message, T ErroredValue) : Error(Message, ErroredValue) +{ + public const string Code = "Record Not Found"; + + public RecordNotFoundError(T ErroredValue) : this(Code, ErroredValue) + { + + } +} diff --git a/src/Core/AdminConsole/Shared/Validation/IValidator.cs b/src/Core/AdminConsole/Shared/Validation/IValidator.cs new file mode 100644 index 0000000000..d90386f00e --- /dev/null +++ b/src/Core/AdminConsole/Shared/Validation/IValidator.cs @@ -0,0 +1,6 @@ +namespace Bit.Core.AdminConsole.Shared.Validation; + +public interface IValidator +{ + public Task> ValidateAsync(T value); +} diff --git a/src/Core/AdminConsole/Shared/Validation/ValidationResult.cs b/src/Core/AdminConsole/Shared/Validation/ValidationResult.cs new file mode 100644 index 0000000000..e25103e701 --- /dev/null +++ b/src/Core/AdminConsole/Shared/Validation/ValidationResult.cs @@ -0,0 +1,15 @@ +using Bit.Core.AdminConsole.Errors; + +namespace Bit.Core.AdminConsole.Shared.Validation; + +public abstract record ValidationResult; + +public record Valid : ValidationResult +{ + public T Value { get; init; } +} + +public record Invalid : ValidationResult +{ + public IEnumerable> Errors { get; init; } +} diff --git a/src/Core/Billing/Services/IProviderBillingService.cs b/src/Core/Billing/Services/IProviderBillingService.cs index d6983da03e..64585f3361 100644 --- a/src/Core/Billing/Services/IProviderBillingService.cs +++ b/src/Core/Billing/Services/IProviderBillingService.cs @@ -95,5 +95,16 @@ public interface IProviderBillingService Task SetupSubscription( Provider provider); + /// + /// Updates the 's payment source and tax information and then sets their subscription's collection_method to be "charge_automatically". + /// + /// The to update the payment source and tax information for. + /// The tokenized payment source (ex. Credit Card) to attach to the . + /// The 's updated tax information. + Task UpdatePaymentMethod( + Provider provider, + TokenizedPaymentSource tokenizedPaymentSource, + TaxInformation taxInformation); + Task UpdateSeatMinimums(UpdateProviderSeatMinimumsCommand command); } diff --git a/src/Core/Constants.cs b/src/Core/Constants.cs index 72224319fb..e09422871d 100644 --- a/src/Core/Constants.cs +++ b/src/Core/Constants.cs @@ -175,6 +175,8 @@ public static class FeatureFlagKeys public const string WebPush = "web-push"; public const string AndroidImportLoginsFlow = "import-logins-flow"; public const string PM12276Breadcrumbing = "pm-12276-breadcrumbing-for-business-features"; + public const string PM18794_ProviderPaymentMethod = "pm-18794-provider-payment-method"; + public const string PM3553_MobileSimpleLoginSelfHostAlias = "simple-login-self-host-alias"; public static List GetAllKeys() { diff --git a/src/Core/Models/Commands/CommandResult.cs b/src/Core/Models/Commands/CommandResult.cs index ae14b7d2f9..a8ec772fc1 100644 --- a/src/Core/Models/Commands/CommandResult.cs +++ b/src/Core/Models/Commands/CommandResult.cs @@ -1,5 +1,7 @@ #nullable enable +using Bit.Core.AdminConsole.Errors; + namespace Bit.Core.Models.Commands; public class CommandResult(IEnumerable errors) @@ -9,7 +11,6 @@ public class CommandResult(IEnumerable errors) public bool Success => ErrorMessages.Count == 0; public bool HasErrors => ErrorMessages.Count > 0; public List ErrorMessages { get; } = errors.ToList(); - public CommandResult() : this(Array.Empty()) { } } @@ -29,22 +30,30 @@ public class Success : CommandResult { } -public abstract class CommandResult -{ +public abstract class CommandResult; +public class Success(T value) : CommandResult +{ + public T Value { get; } = value; } -public class Success(T data) : CommandResult +public class Failure(IEnumerable errorMessages) : CommandResult { - public T? Data { get; init; } = data; + public List ErrorMessages { get; } = errorMessages.ToList(); + + public string ErrorMessage => string.Join(" ", ErrorMessages); + + public Failure(string error) : this([error]) { } } -public class Failure(IEnumerable errorMessage) : CommandResult +public class Partial : CommandResult { - public IEnumerable ErrorMessages { get; init; } = errorMessage; + public T[] Successes { get; set; } = []; + public Error[] Failures { get; set; } = []; - public Failure(string errorMessage) : this(new[] { errorMessage }) + public Partial(IEnumerable successfulItems, IEnumerable> failedItems) { + Successes = successfulItems.ToArray(); + Failures = failedItems.ToArray(); } } - diff --git a/src/Identity/IdentityServer/RequestValidators/DeviceValidator.cs b/src/Identity/IdentityServer/RequestValidators/DeviceValidator.cs index 4744f4aca3..36a08326ab 100644 --- a/src/Identity/IdentityServer/RequestValidators/DeviceValidator.cs +++ b/src/Identity/IdentityServer/RequestValidators/DeviceValidator.cs @@ -250,21 +250,26 @@ public class DeviceValidator( var customResponse = new Dictionary(); switch (errorType) { + /* + * The ErrorMessage is brittle and is used to control the flow in the clients. Do not change them without updating the client as well. + * There is a backwards compatibility issue as well: if you make a change on the clients then ensure that they are backwards + * compatible. + */ case DeviceValidationResultType.InvalidUser: result.ErrorDescription = "Invalid user"; - customResponse.Add("ErrorModel", new ErrorResponseModel("Invalid user.")); + customResponse.Add("ErrorModel", new ErrorResponseModel("invalid user")); break; case DeviceValidationResultType.InvalidNewDeviceOtp: result.ErrorDescription = "Invalid New Device OTP"; - customResponse.Add("ErrorModel", new ErrorResponseModel("Invalid new device OTP. Try again.")); + customResponse.Add("ErrorModel", new ErrorResponseModel("invalid new device otp")); break; case DeviceValidationResultType.NewDeviceVerificationRequired: result.ErrorDescription = "New device verification required"; - customResponse.Add("ErrorModel", new ErrorResponseModel("New device verification required.")); + customResponse.Add("ErrorModel", new ErrorResponseModel("new device verification required")); break; case DeviceValidationResultType.NoDeviceInformationProvided: result.ErrorDescription = "No device information provided"; - customResponse.Add("ErrorModel", new ErrorResponseModel("No device information provided.")); + customResponse.Add("ErrorModel", new ErrorResponseModel("no device information provided")); break; } return (result, customResponse); diff --git a/test/Core.Test/AdminConsole/Shared/IValidatorTests.cs b/test/Core.Test/AdminConsole/Shared/IValidatorTests.cs new file mode 100644 index 0000000000..abb49c25c6 --- /dev/null +++ b/test/Core.Test/AdminConsole/Shared/IValidatorTests.cs @@ -0,0 +1,58 @@ +using Bit.Core.AdminConsole.Errors; +using Bit.Core.AdminConsole.Shared.Validation; +using Xunit; + +namespace Bit.Core.Test.AdminConsole.Shared; + +public class IValidatorTests +{ + public class TestClass + { + public string Name { get; set; } = string.Empty; + } + + public record InvalidRequestError(T ErroredValue) : Error(Code, ErroredValue) + { + public const string Code = "InvalidRequest"; + } + + public class TestClassValidator : IValidator + { + public Task> ValidateAsync(TestClass value) + { + if (string.IsNullOrWhiteSpace(value.Name)) + { + return Task.FromResult>(new Invalid + { + Errors = [new InvalidRequestError(value)] + }); + } + + return Task.FromResult>(new Valid { Value = value }); + } + } + + [Fact] + public async Task ValidateAsync_WhenSomethingIsInvalid_ReturnsInvalidWithError() + { + var example = new TestClass(); + + var result = await new TestClassValidator().ValidateAsync(example); + + Assert.IsType>(result); + var invalidResult = result as Invalid; + Assert.Equal(InvalidRequestError.Code, invalidResult.Errors.First().Message); + } + + [Fact] + public async Task ValidateAsync_WhenIsValid_ReturnsValid() + { + var example = new TestClass { Name = "Valid" }; + + var result = await new TestClassValidator().ValidateAsync(example); + + Assert.IsType>(result); + var validResult = result as Valid; + Assert.Equal(example.Name, validResult.Value.Name); + } +} diff --git a/test/Core.Test/Models/Commands/CommandResultTests.cs b/test/Core.Test/Models/Commands/CommandResultTests.cs new file mode 100644 index 0000000000..c500fef4f5 --- /dev/null +++ b/test/Core.Test/Models/Commands/CommandResultTests.cs @@ -0,0 +1,53 @@ +using Bit.Core.AdminConsole.Errors; +using Bit.Core.Models.Commands; +using Bit.Test.Common.AutoFixture.Attributes; +using Xunit; + +namespace Bit.Core.Test.Models.Commands; + +public class CommandResultTests +{ + public class TestItem + { + public Guid Id { get; set; } + public string Value { get; set; } + } + + public CommandResult BulkAction(IEnumerable items) + { + var itemList = items.ToList(); + var successfulItems = items.Where(x => x.Value == "SuccessfulRequest").ToArray(); + + var failedItems = itemList.Except(successfulItems).ToArray(); + + var notFound = failedItems.First(x => x.Value == "Failed due to not found"); + var invalidPermissions = failedItems.First(x => x.Value == "Failed due to invalid permissions"); + + var notFoundError = new RecordNotFoundError(notFound); + var insufficientPermissionsError = new InsufficientPermissionsError(invalidPermissions); + + return new Partial(successfulItems.ToArray(), [notFoundError, insufficientPermissionsError]); + } + + [Theory] + [BitAutoData] + public void Partial_CommandResult_BulkRequestWithSuccessAndFailures(Guid successId1, Guid failureId1, Guid failureId2) + { + var listOfRecords = new List + { + new TestItem() { Id = successId1, Value = "SuccessfulRequest" }, + new TestItem() { Id = failureId1, Value = "Failed due to not found" }, + new TestItem() { Id = failureId2, Value = "Failed due to invalid permissions" } + }; + + var result = BulkAction(listOfRecords); + + Assert.IsType>(result); + + var failures = (result as Partial).Failures.ToArray(); + var success = (result as Partial).Successes.First(); + + Assert.Equal(listOfRecords.First(), success); + Assert.Equal(2, failures.Length); + } +} diff --git a/test/Identity.Test/IdentityServer/DeviceValidatorTests.cs b/test/Identity.Test/IdentityServer/DeviceValidatorTests.cs index bcb357d640..b71dd6c230 100644 --- a/test/Identity.Test/IdentityServer/DeviceValidatorTests.cs +++ b/test/Identity.Test/IdentityServer/DeviceValidatorTests.cs @@ -172,7 +172,7 @@ public class DeviceValidatorTests Assert.False(result); Assert.NotNull(context.CustomResponse["ErrorModel"]); - var expectedErrorModel = new ErrorResponseModel("No device information provided."); + var expectedErrorModel = new ErrorResponseModel("no device information provided"); var actualResponse = (ErrorResponseModel)context.CustomResponse["ErrorModel"]; Assert.Equal(expectedErrorModel.Message, actualResponse.Message); } @@ -418,7 +418,7 @@ public class DeviceValidatorTests Assert.False(result); Assert.NotNull(context.CustomResponse["ErrorModel"]); // PM-13340: The error message should be "invalid user" instead of "no device information provided" - var expectedErrorMessage = "No device information provided."; + var expectedErrorMessage = "no device information provided"; var actualResponse = (ErrorResponseModel)context.CustomResponse["ErrorModel"]; Assert.Equal(expectedErrorMessage, actualResponse.Message); } @@ -552,7 +552,7 @@ public class DeviceValidatorTests Assert.False(result); Assert.NotNull(context.CustomResponse["ErrorModel"]); - var expectedErrorMessage = "Invalid new device OTP. Try again."; + var expectedErrorMessage = "invalid new device otp"; var actualResponse = (ErrorResponseModel)context.CustomResponse["ErrorModel"]; Assert.Equal(expectedErrorMessage, actualResponse.Message); } @@ -604,7 +604,7 @@ public class DeviceValidatorTests Assert.False(result); Assert.NotNull(context.CustomResponse["ErrorModel"]); - var expectedErrorMessage = "New device verification required."; + var expectedErrorMessage = "new device verification required"; var actualResponse = (ErrorResponseModel)context.CustomResponse["ErrorModel"]; Assert.Equal(expectedErrorMessage, actualResponse.Message); }