From eb1eb0554c6d29f333596a74a1a7467292dd16f2 Mon Sep 17 00:00:00 2001 From: Alex Morask <144709477+amorask-bitwarden@users.noreply.github.com> Date: Thu, 29 Feb 2024 08:15:18 -0500 Subject: [PATCH] Resolve RisksSubscriptionFailure bugs (#3790) --- .../Controllers/OrganizationsController.cs | 8 ++++---- .../OrganizationBillingStatusResponseModel.cs | 13 +++++++++++++ ...tionRisksSubscriptionFailureResponseModel.cs | 17 ----------------- src/Core/Constants.cs | 1 + src/Core/Models/Business/SubscriptionInfo.cs | 2 ++ .../Implementations/StripePaymentService.cs | 10 +++++++--- 6 files changed, 27 insertions(+), 24 deletions(-) create mode 100644 src/Api/AdminConsole/Models/Response/Organizations/OrganizationBillingStatusResponseModel.cs delete mode 100644 src/Api/AdminConsole/Models/Response/Organizations/OrganizationRisksSubscriptionFailureResponseModel.cs diff --git a/src/Api/AdminConsole/Controllers/OrganizationsController.cs b/src/Api/AdminConsole/Controllers/OrganizationsController.cs index 8c9fb9091a..75296d091f 100644 --- a/src/Api/AdminConsole/Controllers/OrganizationsController.cs +++ b/src/Api/AdminConsole/Controllers/OrganizationsController.cs @@ -261,19 +261,19 @@ public class OrganizationsController : Controller return new OrganizationAutoEnrollStatusResponseModel(organization.Id, data?.AutoEnrollEnabled ?? false); } - [HttpGet("{id}/risks-subscription-failure")] - public async Task RisksSubscriptionFailure(Guid id) + [HttpGet("{id}/billing-status")] + public async Task GetBillingStatus(Guid id) { if (!await _currentContext.EditPaymentMethods(id)) { - return new OrganizationRisksSubscriptionFailureResponseModel(id, false); + throw new NotFoundException(); } var organization = await _organizationRepository.GetByIdAsync(id); var risksSubscriptionFailure = await _paymentService.RisksSubscriptionFailure(organization); - return new OrganizationRisksSubscriptionFailureResponseModel(id, risksSubscriptionFailure); + return new OrganizationBillingStatusResponseModel(organization, risksSubscriptionFailure); } [HttpPost("")] diff --git a/src/Api/AdminConsole/Models/Response/Organizations/OrganizationBillingStatusResponseModel.cs b/src/Api/AdminConsole/Models/Response/Organizations/OrganizationBillingStatusResponseModel.cs new file mode 100644 index 0000000000..635383fa3b --- /dev/null +++ b/src/Api/AdminConsole/Models/Response/Organizations/OrganizationBillingStatusResponseModel.cs @@ -0,0 +1,13 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.Models.Api; + +namespace Bit.Api.AdminConsole.Models.Response.Organizations; + +public class OrganizationBillingStatusResponseModel( + Organization organization, + bool risksSubscriptionFailure) : ResponseModel("organizationBillingStatus") +{ + public Guid OrganizationId { get; } = organization.Id; + public string OrganizationName { get; } = organization.Name; + public bool RisksSubscriptionFailure { get; } = risksSubscriptionFailure; +} diff --git a/src/Api/AdminConsole/Models/Response/Organizations/OrganizationRisksSubscriptionFailureResponseModel.cs b/src/Api/AdminConsole/Models/Response/Organizations/OrganizationRisksSubscriptionFailureResponseModel.cs deleted file mode 100644 index e91275da3c..0000000000 --- a/src/Api/AdminConsole/Models/Response/Organizations/OrganizationRisksSubscriptionFailureResponseModel.cs +++ /dev/null @@ -1,17 +0,0 @@ -using Bit.Core.Models.Api; - -namespace Bit.Api.AdminConsole.Models.Response.Organizations; - -public class OrganizationRisksSubscriptionFailureResponseModel : ResponseModel -{ - public Guid OrganizationId { get; } - public bool RisksSubscriptionFailure { get; } - - public OrganizationRisksSubscriptionFailureResponseModel( - Guid organizationId, - bool risksSubscriptionFailure) : base("organizationRisksSubscriptionFailure") - { - OrganizationId = organizationId; - RisksSubscriptionFailure = risksSubscriptionFailure; - } -} diff --git a/src/Core/Constants.cs b/src/Core/Constants.cs index 4d1154a0b1..438154f3ce 100644 --- a/src/Core/Constants.cs +++ b/src/Core/Constants.cs @@ -133,6 +133,7 @@ public static class FeatureFlagKeys public const string PM5766AutomaticTax = "PM-5766-automatic-tax"; public const string PM5864DollarThreshold = "PM-5864-dollar-threshold"; public const string AC2101UpdateTrialInitiationEmail = "AC-2101-update-trial-initiation-email"; + public const string ShowPaymentMethodWarningBanners = "show-payment-method-warning-banners"; public static List GetAllKeys() { diff --git a/src/Core/Models/Business/SubscriptionInfo.cs b/src/Core/Models/Business/SubscriptionInfo.cs index e2a689f613..23f8f95278 100644 --- a/src/Core/Models/Business/SubscriptionInfo.cs +++ b/src/Core/Models/Business/SubscriptionInfo.cs @@ -42,6 +42,7 @@ public class SubscriptionInfo { Items = sub.Items.Data.Select(i => new BillingSubscriptionItem(i)); } + CollectionMethod = sub.CollectionMethod; } public DateTime? TrialStartDate { get; set; } @@ -54,6 +55,7 @@ public class SubscriptionInfo public string Status { get; set; } public bool Cancelled { get; set; } public IEnumerable Items { get; set; } = new List(); + public string CollectionMethod { get; set; } public class BillingSubscriptionItem { diff --git a/src/Core/Services/Implementations/StripePaymentService.cs b/src/Core/Services/Implementations/StripePaymentService.cs index 1f7d488179..2aa715eefb 100644 --- a/src/Core/Services/Implementations/StripePaymentService.cs +++ b/src/Core/Services/Implementations/StripePaymentService.cs @@ -1738,13 +1738,17 @@ public class StripePaymentService : IPaymentService { var subscriptionInfo = await GetSubscriptionAsync(organization); - if (subscriptionInfo.Subscription is not { Status: "active" or "trialing" or "past_due" } || - subscriptionInfo.UpcomingInvoice == null) + if (subscriptionInfo.Subscription is not + { + Status: "active" or "trialing" or "past_due", + CollectionMethod: "charge_automatically" + } + || subscriptionInfo.UpcomingInvoice == null) { return false; } - var customer = await GetCustomerAsync(organization.GatewayCustomerId); + var customer = await GetCustomerAsync(organization.GatewayCustomerId, GetCustomerPaymentOptions()); var paymentSource = await GetBillingPaymentSourceAsync(customer);