1
0
mirror of https://github.com/bitwarden/server.git synced 2025-04-13 17:18:14 -05:00

Merge branch 'main' into ac/pm-15621/refactor-delete-command

This commit is contained in:
Jimmy Vo 2025-02-04 14:39:24 -05:00
commit 48ced72b0b
No known key found for this signature in database
GPG Key ID: 7CB834D6F4FFCA11
58 changed files with 1323 additions and 126 deletions

View File

@ -314,7 +314,7 @@ jobs:
output-format: sarif
- name: Upload Grype results to GitHub
uses: github/codeql-action/upload-sarif@48ab28a6f5dbc2a99bf1e0131198dd8f1df78169 # v3.28.0
uses: github/codeql-action/upload-sarif@dd746615b3b9d728a6a37ca2045b68ca76d4841a # v3.28.8
with:
sarif_file: ${{ steps.container-scan.outputs.sarif }}

View File

@ -46,7 +46,7 @@ jobs:
--output-path . ${{ env.INCREMENTAL }}
- name: Upload Checkmarx results to GitHub
uses: github/codeql-action/upload-sarif@48ab28a6f5dbc2a99bf1e0131198dd8f1df78169 # v3.28.0
uses: github/codeql-action/upload-sarif@dd746615b3b9d728a6a37ca2045b68ca76d4841a # v3.28.8
with:
sarif_file: cx_result.sarif
@ -85,6 +85,6 @@ jobs:
/d:sonar.test.inclusions=test/,bitwarden_license/test/ \
/d:sonar.exclusions=test/,bitwarden_license/test/ \
/o:"${{ github.repository_owner }}" /d:sonar.token="${{ secrets.SONAR_TOKEN }}" \
/d:sonar.host.url="https://sonarcloud.io"
/d:sonar.host.url="https://sonarcloud.io" ${{ contains(github.event_name, 'pull_request') && format('/d:sonar.pullrequest.key={0}', github.event.pull_request.number) || '' }}
dotnet build
dotnet-sonarscanner end /d:sonar.token="${{ secrets.SONAR_TOKEN }}"

View File

@ -125,6 +125,8 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Notifications.Test", "test\
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Infrastructure.Dapper.Test", "test\Infrastructure.Dapper.Test\Infrastructure.Dapper.Test.csproj", "{4A725DB3-BE4F-4C23-9087-82D0610D67AF}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Events.IntegrationTest", "test\Events.IntegrationTest\Events.IntegrationTest.csproj", "{4F4C63A9-AEE2-48C4-AB86-A5BCD665E401}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
@ -313,6 +315,10 @@ Global
{4A725DB3-BE4F-4C23-9087-82D0610D67AF}.Debug|Any CPU.Build.0 = Debug|Any CPU
{4A725DB3-BE4F-4C23-9087-82D0610D67AF}.Release|Any CPU.ActiveCfg = Release|Any CPU
{4A725DB3-BE4F-4C23-9087-82D0610D67AF}.Release|Any CPU.Build.0 = Release|Any CPU
{4F4C63A9-AEE2-48C4-AB86-A5BCD665E401}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{4F4C63A9-AEE2-48C4-AB86-A5BCD665E401}.Debug|Any CPU.Build.0 = Debug|Any CPU
{4F4C63A9-AEE2-48C4-AB86-A5BCD665E401}.Release|Any CPU.ActiveCfg = Release|Any CPU
{4F4C63A9-AEE2-48C4-AB86-A5BCD665E401}.Release|Any CPU.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
@ -363,6 +369,7 @@ Global
{81673EFB-7134-4B4B-A32F-1EA05F0EF3CE} = {DD5BD056-4AAE-43EF-BBD2-0B569B8DA84F}
{90D85D8F-5577-4570-A96E-5A2E185F0F6F} = {DD5BD056-4AAE-43EF-BBD2-0B569B8DA84F}
{4A725DB3-BE4F-4C23-9087-82D0610D67AF} = {DD5BD056-4AAE-43EF-BBD2-0B569B8DA84F}
{4F4C63A9-AEE2-48C4-AB86-A5BCD665E401} = {DD5BD056-4AAE-43EF-BBD2-0B569B8DA84F}
EndGlobalSection
GlobalSection(ExtensibilityGlobals) = postSolution
SolutionGuid = {E01CBF68-2E20-425F-9EDB-E0A6510CA92F}

View File

@ -1,12 +1,15 @@
using System.Globalization;
using Bit.Commercial.Core.Billing.Models;
using Bit.Core;
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.AdminConsole.Enums.Provider;
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Billing;
using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Entities;
using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Models;
using Bit.Core.Billing.Repositories;
using Bit.Core.Billing.Services;
using Bit.Core.Billing.Services.Contracts;
@ -24,6 +27,7 @@ using Stripe;
namespace Bit.Commercial.Core.Billing;
public class ProviderBillingService(
IEventService eventService,
IGlobalSettings globalSettings,
ILogger<ProviderBillingService> logger,
IOrganizationRepository organizationRepository,
@ -31,10 +35,93 @@ public class ProviderBillingService(
IProviderInvoiceItemRepository providerInvoiceItemRepository,
IProviderOrganizationRepository providerOrganizationRepository,
IProviderPlanRepository providerPlanRepository,
IProviderUserRepository providerUserRepository,
IStripeAdapter stripeAdapter,
ISubscriberService subscriberService,
ITaxService taxService) : IProviderBillingService
{
[RequireFeature(FeatureFlagKeys.P15179_AddExistingOrgsFromProviderPortal)]
public async Task AddExistingOrganization(
Provider provider,
Organization organization,
string key)
{
await stripeAdapter.SubscriptionUpdateAsync(organization.GatewaySubscriptionId,
new SubscriptionUpdateOptions
{
CancelAtPeriodEnd = false
});
var subscription =
await stripeAdapter.SubscriptionCancelAsync(organization.GatewaySubscriptionId,
new SubscriptionCancelOptions
{
CancellationDetails = new SubscriptionCancellationDetailsOptions
{
Comment = $"Organization was added to Provider with ID {provider.Id}"
},
InvoiceNow = true,
Prorate = true,
Expand = ["latest_invoice", "test_clock"]
});
var now = subscription.TestClock?.FrozenTime ?? DateTime.UtcNow;
var wasTrialing = subscription.TrialEnd.HasValue && subscription.TrialEnd.Value > now;
if (!wasTrialing && subscription.LatestInvoice.Status == StripeConstants.InvoiceStatus.Draft)
{
await stripeAdapter.InvoiceFinalizeInvoiceAsync(subscription.LatestInvoiceId,
new InvoiceFinalizeOptions { AutoAdvance = true });
}
var managedPlanType = await GetManagedPlanTypeAsync(provider, organization);
// TODO: Replace with PricingClient
var plan = StaticStore.GetPlan(managedPlanType);
organization.Plan = plan.Name;
organization.PlanType = plan.Type;
organization.MaxCollections = plan.PasswordManager.MaxCollections;
organization.MaxStorageGb = plan.PasswordManager.BaseStorageGb;
organization.UsePolicies = plan.HasPolicies;
organization.UseSso = plan.HasSso;
organization.UseGroups = plan.HasGroups;
organization.UseEvents = plan.HasEvents;
organization.UseDirectory = plan.HasDirectory;
organization.UseTotp = plan.HasTotp;
organization.Use2fa = plan.Has2fa;
organization.UseApi = plan.HasApi;
organization.UseResetPassword = plan.HasResetPassword;
organization.SelfHost = plan.HasSelfHost;
organization.UsersGetPremium = plan.UsersGetPremium;
organization.UseCustomPermissions = plan.HasCustomPermissions;
organization.UseScim = plan.HasScim;
organization.UseKeyConnector = plan.HasKeyConnector;
organization.MaxStorageGb = plan.PasswordManager.BaseStorageGb;
organization.BillingEmail = provider.BillingEmail!;
organization.GatewaySubscriptionId = null;
organization.ExpirationDate = null;
organization.MaxAutoscaleSeats = null;
organization.Status = OrganizationStatusType.Managed;
var providerOrganization = new ProviderOrganization
{
ProviderId = provider.Id,
OrganizationId = organization.Id,
Key = key
};
await Task.WhenAll(
organizationRepository.ReplaceAsync(organization),
providerOrganizationRepository.CreateAsync(providerOrganization),
ScaleSeats(provider, organization.PlanType, organization.Seats!.Value)
);
await eventService.LogProviderOrganizationEventAsync(
providerOrganization,
EventType.ProviderOrganization_Added);
}
public async Task ChangePlan(ChangeProviderPlanCommand command)
{
var plan = await providerPlanRepository.GetByIdAsync(command.ProviderPlanId);
@ -206,6 +293,81 @@ public class ProviderBillingService(
return memoryStream.ToArray();
}
[RequireFeature(FeatureFlagKeys.P15179_AddExistingOrgsFromProviderPortal)]
public async Task<IEnumerable<AddableOrganization>> GetAddableOrganizations(
Provider provider,
Guid userId)
{
var providerUser = await providerUserRepository.GetByProviderUserAsync(provider.Id, userId);
if (providerUser is not { Status: ProviderUserStatusType.Confirmed })
{
throw new UnauthorizedAccessException();
}
var candidates = await organizationRepository.GetAddableToProviderByUserIdAsync(userId, provider.Type);
var active = (await Task.WhenAll(candidates.Select(async organization =>
{
var subscription = await subscriberService.GetSubscription(organization);
return (organization, subscription);
})))
.Where(pair => pair.subscription is
{
Status:
StripeConstants.SubscriptionStatus.Active or
StripeConstants.SubscriptionStatus.Trialing or
StripeConstants.SubscriptionStatus.PastDue
}).ToList();
if (active.Count == 0)
{
return [];
}
return await Task.WhenAll(active.Select(async pair =>
{
var (organization, _) = pair;
var planName = DerivePlanName(provider, organization);
var addable = new AddableOrganization(
organization.Id,
organization.Name,
planName,
organization.Seats!.Value);
if (providerUser.Type != ProviderUserType.ServiceUser)
{
return addable;
}
var applicablePlanType = await GetManagedPlanTypeAsync(provider, organization);
var requiresPurchase =
await SeatAdjustmentResultsInPurchase(provider, applicablePlanType, organization.Seats!.Value);
return addable with { Disabled = requiresPurchase };
}));
string DerivePlanName(Provider localProvider, Organization localOrganization)
{
if (localProvider.Type == ProviderType.Msp)
{
return localOrganization.PlanType switch
{
var planType when PlanConstants.EnterprisePlanTypes.Contains(planType) => "Enterprise",
var planType when PlanConstants.TeamsPlanTypes.Contains(planType) => "Teams",
_ => throw new BillingException()
};
}
// TODO: Replace with PricingClient
var plan = StaticStore.GetPlan(localOrganization.PlanType);
return plan.Name;
}
}
public async Task ScaleSeats(
Provider provider,
PlanType planType,
@ -582,4 +744,21 @@ public class ProviderBillingService(
return providerPlan;
}
private async Task<PlanType> GetManagedPlanTypeAsync(
Provider provider,
Organization organization)
{
if (provider.Type == ProviderType.MultiOrganizationEnterprise)
{
return (await providerPlanRepository.GetByProviderId(provider.Id)).First().PlanType;
}
return organization.PlanType switch
{
var planType when PlanConstants.TeamsPlanTypes.Contains(planType) => PlanType.TeamsMonthly,
var planType when PlanConstants.EnterprisePlanTypes.Contains(planType) => PlanType.EnterpriseMonthly,
_ => throw new BillingException()
};
}
}

View File

@ -20,4 +20,8 @@ IDP_SP_ACS_URL=http://localhost:51822/saml2/yourOrgIdHere/Acs
# Optional reverse proxy configuration
# Should match server listen ports in reverse-proxy.conf
API_PROXY_PORT=4100
IDENTITY_PROXY_PORT=33756
IDENTITY_PROXY_PORT=33756
# Optional RabbitMQ configuration
RABBITMQ_DEFAULT_USER=bitwarden
RABBITMQ_DEFAULT_PASS=SET_A_PASSWORD_HERE_123

View File

@ -84,6 +84,20 @@ services:
profiles:
- idp
rabbitmq:
image: rabbitmq:management
container_name: rabbitmq
ports:
- "5672:5672"
- "15672:15672"
environment:
RABBITMQ_DEFAULT_USER: ${RABBITMQ_DEFAULT_USER}
RABBITMQ_DEFAULT_PASS: ${RABBITMQ_DEFAULT_PASS}
volumes:
- rabbitmq_data:/var/lib/rabbitmq_data
profiles:
- rabbitmq
reverse-proxy:
image: nginx:alpine
container_name: reverse-proxy
@ -99,3 +113,4 @@ volumes:
mssql_dev_data:
postgres_dev_data:
mysql_dev_data:
rabbitmq_data:

View File

@ -421,6 +421,11 @@ public class OrganizationsController : Controller
private void UpdateOrganization(Organization organization, OrganizationEditModel model)
{
if (_accessControlService.UserHasPermission(Permission.Org_Name_Edit))
{
organization.Name = WebUtility.HtmlEncode(model.Name);
}
if (_accessControlService.UserHasPermission(Permission.Org_CheckEnabledBox))
{
organization.Enabled = model.Enabled;

View File

@ -12,6 +12,7 @@
var canViewBilling = AccessControlService.UserHasPermission(Permission.Org_Billing_View);
var canViewPlan = AccessControlService.UserHasPermission(Permission.Org_Plan_View);
var canViewLicensing = AccessControlService.UserHasPermission(Permission.Org_Licensing_View);
var canEditName = AccessControlService.UserHasPermission(Permission.Org_Name_Edit);
var canCheckEnabled = AccessControlService.UserHasPermission(Permission.Org_CheckEnabledBox);
var canEditPlan = AccessControlService.UserHasPermission(Permission.Org_Plan_Edit);
var canEditLicensing = AccessControlService.UserHasPermission(Permission.Org_Licensing_Edit);
@ -28,7 +29,7 @@
<div class="col-sm">
<div class="mb-3">
<label class="form-label" asp-for="Name"></label>
<input type="text" class="form-control" asp-for="Name" value="@Model.Name" required>
<input type="text" class="form-control" asp-for="Name" value="@Model.Name" required disabled="@(canEditName ? null : "disabled")">
</div>
</div>
</div>

View File

@ -22,6 +22,7 @@ public enum Permission
Org_List_View,
Org_OrgInformation_View,
Org_GeneralDetails_View,
Org_Name_Edit,
Org_CheckEnabledBox,
Org_BusinessInformation_View,
Org_InitiateTrial,

View File

@ -24,6 +24,7 @@ public static class RolePermissionMapping
Permission.User_Billing_Edit,
Permission.User_Billing_LaunchGateway,
Permission.User_NewDeviceException_Edit,
Permission.Org_Name_Edit,
Permission.Org_CheckEnabledBox,
Permission.Org_List_View,
Permission.Org_OrgInformation_View,
@ -71,6 +72,7 @@ public static class RolePermissionMapping
Permission.User_Billing_Edit,
Permission.User_Billing_LaunchGateway,
Permission.User_NewDeviceException_Edit,
Permission.Org_Name_Edit,
Permission.Org_CheckEnabledBox,
Permission.Org_List_View,
Permission.Org_OrgInformation_View,
@ -116,6 +118,7 @@ public static class RolePermissionMapping
Permission.User_Billing_View,
Permission.User_Billing_LaunchGateway,
Permission.User_NewDeviceException_Edit,
Permission.Org_Name_Edit,
Permission.Org_CheckEnabledBox,
Permission.Org_List_View,
Permission.Org_OrgInformation_View,
@ -148,6 +151,7 @@ public static class RolePermissionMapping
Permission.User_Billing_View,
Permission.User_Billing_Edit,
Permission.User_Billing_LaunchGateway,
Permission.Org_Name_Edit,
Permission.Org_CheckEnabledBox,
Permission.Org_List_View,
Permission.Org_OrgInformation_View,
@ -185,6 +189,7 @@ public static class RolePermissionMapping
Permission.User_Premium_View,
Permission.User_Licensing_View,
Permission.User_Licensing_Edit,
Permission.Org_Name_Edit,
Permission.Org_CheckEnabledBox,
Permission.Org_List_View,
Permission.Org_OrgInformation_View,

View File

@ -1,4 +1,6 @@
using Bit.Api.Billing.Models.Requests;
using Bit.Api.Billing.Controllers;
using Bit.Api.Billing.Models.Requests;
using Bit.Core;
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.AdminConsole.Services;
using Bit.Core.Billing.Services;
@ -7,13 +9,15 @@ using Bit.Core.Enums;
using Bit.Core.Models.Business;
using Bit.Core.Repositories;
using Bit.Core.Services;
using Bit.Core.Utilities;
using Microsoft.AspNetCore.Mvc;
namespace Bit.Api.Billing.Controllers;
namespace Bit.Api.AdminConsole.Controllers;
[Route("providers/{providerId:guid}/clients")]
public class ProviderClientsController(
ICurrentContext currentContext,
IFeatureService featureService,
ILogger<BaseProviderController> logger,
IOrganizationRepository organizationRepository,
IProviderBillingService providerBillingService,
@ -22,7 +26,10 @@ public class ProviderClientsController(
IProviderService providerService,
IUserService userService) : BaseProviderController(currentContext, logger, providerRepository, userService)
{
private readonly ICurrentContext _currentContext = currentContext;
[HttpPost]
[SelfHosted(NotSelfHostedOnly = true)]
public async Task<IResult> CreateAsync(
[FromRoute] Guid providerId,
[FromBody] CreateClientOrganizationRequestBody requestBody)
@ -80,6 +87,7 @@ public class ProviderClientsController(
}
[HttpPut("{providerOrganizationId:guid}")]
[SelfHosted(NotSelfHostedOnly = true)]
public async Task<IResult> UpdateAsync(
[FromRoute] Guid providerId,
[FromRoute] Guid providerOrganizationId,
@ -113,7 +121,7 @@ public class ProviderClientsController(
clientOrganization.PlanType,
seatAdjustment);
if (seatAdjustmentResultsInPurchase && !currentContext.ProviderProviderAdmin(provider.Id))
if (seatAdjustmentResultsInPurchase && !_currentContext.ProviderProviderAdmin(provider.Id))
{
return Error.Unauthorized("Service users cannot purchase additional seats.");
}
@ -127,4 +135,58 @@ public class ProviderClientsController(
return TypedResults.Ok();
}
[HttpGet("addable")]
[SelfHosted(NotSelfHostedOnly = true)]
public async Task<IResult> GetAddableOrganizationsAsync([FromRoute] Guid providerId)
{
if (!featureService.IsEnabled(FeatureFlagKeys.P15179_AddExistingOrgsFromProviderPortal))
{
return Error.NotFound();
}
var (provider, result) = await TryGetBillableProviderForServiceUserOperation(providerId);
if (provider == null)
{
return result;
}
var userId = _currentContext.UserId;
if (!userId.HasValue)
{
return Error.Unauthorized();
}
var addable =
await providerBillingService.GetAddableOrganizations(provider, userId.Value);
return TypedResults.Ok(addable);
}
[HttpPost("existing")]
[SelfHosted(NotSelfHostedOnly = true)]
public async Task<IResult> AddExistingOrganizationAsync(
[FromRoute] Guid providerId,
[FromBody] AddExistingOrganizationRequestBody requestBody)
{
var (provider, result) = await TryGetBillableProviderForServiceUserOperation(providerId);
if (provider == null)
{
return result;
}
var organization = await organizationRepository.GetByIdAsync(requestBody.OrganizationId);
if (organization == null)
{
return Error.BadRequest("The organization being added to the provider does not exist.");
}
await providerBillingService.AddExistingOrganization(provider, organization, requestBody.Key);
return TypedResults.Ok();
}
}

View File

@ -36,7 +36,7 @@ public class EventsController : Controller
/// If no filters are provided, it will return the last 30 days of event for the organization.
/// </remarks>
[HttpGet]
[ProducesResponseType(typeof(ListResponseModel<EventResponseModel>), (int)HttpStatusCode.OK)]
[ProducesResponseType(typeof(PagedListResponseModel<EventResponseModel>), (int)HttpStatusCode.OK)]
public async Task<IActionResult> List([FromQuery] EventFilterRequestModel request)
{
var dateRange = request.ToDateRange();
@ -65,7 +65,7 @@ public class EventsController : Controller
}
var eventResponses = result.Data.Select(e => new EventResponseModel(e));
var response = new ListResponseModel<EventResponseModel>(eventResponses, result.ContinuationToken);
var response = new PagedListResponseModel<EventResponseModel>(eventResponses, result.ContinuationToken);
return new JsonResult(response);
}
}

View File

@ -0,0 +1,12 @@
using System.ComponentModel.DataAnnotations;
namespace Bit.Api.Billing.Models.Requests;
public class AddExistingOrganizationRequestBody
{
[Required(ErrorMessage = "'key' must be provided")]
public string Key { get; set; }
[Required(ErrorMessage = "'organizationId' must be provided")]
public Guid OrganizationId { get; set; }
}

View File

@ -4,10 +4,9 @@ namespace Bit.Api.Models.Public.Response;
public class ListResponseModel<T> : IResponseModel where T : IResponseModel
{
public ListResponseModel(IEnumerable<T> data, string continuationToken = null)
public ListResponseModel(IEnumerable<T> data)
{
Data = data;
ContinuationToken = continuationToken;
}
/// <summary>
@ -21,8 +20,4 @@ public class ListResponseModel<T> : IResponseModel where T : IResponseModel
/// </summary>
[Required]
public IEnumerable<T> Data { get; set; }
/// <summary>
/// A cursor for use in pagination.
/// </summary>
public string ContinuationToken { get; set; }
}

View File

@ -0,0 +1,10 @@
namespace Bit.Api.Models.Public.Response;
public class PagedListResponseModel<T>(IEnumerable<T> data, string continuationToken) : ListResponseModel<T>(data)
where T : IResponseModel
{
/// <summary>
/// A cursor for use in pagination.
/// </summary>
public string ContinuationToken { get; set; } = continuationToken;
}

View File

@ -1,4 +1,5 @@
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Enums.Provider;
using Bit.Core.Models.Data.Organizations;
#nullable enable
@ -22,4 +23,5 @@ public interface IOrganizationRepository : IRepository<Organization, Guid>
/// Gets the organizations that have a verified domain matching the user's email domain.
/// </summary>
Task<ICollection<Organization>> GetByVerifiedUserEmailDomainAsync(Guid userId);
Task<ICollection<Organization>> GetAddableToProviderByUserIdAsync(Guid userId, ProviderType providerType);
}

View File

@ -0,0 +1,8 @@
using Bit.Core.Models.Data;
namespace Bit.Core.Services;
public interface IEventMessageHandler
{
Task HandleEventAsync(EventMessage eventMessage);
}

View File

@ -0,0 +1,14 @@
using Bit.Core.Models.Data;
using Microsoft.Extensions.DependencyInjection;
namespace Bit.Core.Services;
public class EventRepositoryHandler(
[FromKeyedServices("persistent")] IEventWriteService eventWriteService)
: IEventMessageHandler
{
public Task HandleEventAsync(EventMessage eventMessage)
{
return eventWriteService.CreateAsync(eventMessage);
}
}

View File

@ -0,0 +1,28 @@
using System.Net.Http.Json;
using Bit.Core.Models.Data;
using Bit.Core.Settings;
namespace Bit.Core.Services;
public class HttpPostEventHandler : IEventMessageHandler
{
private readonly HttpClient _httpClient;
private readonly string _httpPostUrl;
public const string HttpClientName = "HttpPostEventHandlerHttpClient";
public HttpPostEventHandler(
IHttpClientFactory httpClientFactory,
GlobalSettings globalSettings)
{
_httpClient = httpClientFactory.CreateClient(HttpClientName);
_httpPostUrl = globalSettings.EventLogging.RabbitMq.HttpPostUrl;
}
public async Task HandleEventAsync(EventMessage eventMessage)
{
var content = JsonContent.Create(eventMessage);
var response = await _httpClient.PostAsync(_httpPostUrl, content);
response.EnsureSuccessStatusCode();
}
}

View File

@ -0,0 +1,65 @@
using System.Text.Json;
using Bit.Core.Models.Data;
using Bit.Core.Settings;
using RabbitMQ.Client;
namespace Bit.Core.Services;
public class RabbitMqEventWriteService : IEventWriteService, IAsyncDisposable
{
private readonly ConnectionFactory _factory;
private readonly Lazy<Task<IConnection>> _lazyConnection;
private readonly string _exchangeName;
public RabbitMqEventWriteService(GlobalSettings globalSettings)
{
_factory = new ConnectionFactory
{
HostName = globalSettings.EventLogging.RabbitMq.HostName,
UserName = globalSettings.EventLogging.RabbitMq.Username,
Password = globalSettings.EventLogging.RabbitMq.Password
};
_exchangeName = globalSettings.EventLogging.RabbitMq.ExchangeName;
_lazyConnection = new Lazy<Task<IConnection>>(CreateConnectionAsync);
}
public async Task CreateAsync(IEvent e)
{
var connection = await _lazyConnection.Value;
using var channel = await connection.CreateChannelAsync();
await channel.ExchangeDeclareAsync(exchange: _exchangeName, type: ExchangeType.Fanout, durable: true);
var body = JsonSerializer.SerializeToUtf8Bytes(e);
await channel.BasicPublishAsync(exchange: _exchangeName, routingKey: string.Empty, body: body);
}
public async Task CreateManyAsync(IEnumerable<IEvent> events)
{
var connection = await _lazyConnection.Value;
using var channel = await connection.CreateChannelAsync();
await channel.ExchangeDeclareAsync(exchange: _exchangeName, type: ExchangeType.Fanout, durable: true);
foreach (var e in events)
{
var body = JsonSerializer.SerializeToUtf8Bytes(e);
await channel.BasicPublishAsync(exchange: _exchangeName, routingKey: string.Empty, body: body);
}
}
public async ValueTask DisposeAsync()
{
if (_lazyConnection.IsValueCreated)
{
var connection = await _lazyConnection.Value;
await connection.DisposeAsync();
}
}
private async Task<IConnection> CreateConnectionAsync()
{
return await _factory.CreateConnectionAsync();
}
}

View File

@ -0,0 +1,30 @@
using Bit.Core.Billing.Enums;
namespace Bit.Core.Billing.Constants;
public static class PlanConstants
{
public static List<PlanType> EnterprisePlanTypes =>
[
PlanType.EnterpriseAnnually2019,
PlanType.EnterpriseAnnually2020,
PlanType.EnterpriseAnnually2023,
PlanType.EnterpriseAnnually,
PlanType.EnterpriseMonthly2019,
PlanType.EnterpriseMonthly2020,
PlanType.EnterpriseMonthly2023,
PlanType.EnterpriseMonthly
];
public static List<PlanType> TeamsPlanTypes =>
[
PlanType.TeamsAnnually2019,
PlanType.TeamsAnnually2020,
PlanType.TeamsAnnually2023,
PlanType.TeamsAnnually,
PlanType.TeamsMonthly2019,
PlanType.TeamsMonthly2020,
PlanType.TeamsMonthly2023,
PlanType.TeamsMonthly
];
}

View File

@ -31,6 +31,16 @@ public static class StripeConstants
public const string TaxIdInvalid = "tax_id_invalid";
}
public static class InvoiceStatus
{
public const string Draft = "draft";
}
public static class MetadataKeys
{
public const string OrganizationId = "organizationId";
}
public static class PaymentBehavior
{
public const string DefaultIncomplete = "default_incomplete";

View File

@ -0,0 +1,8 @@
namespace Bit.Core.Billing.Models;
public record AddableOrganization(
Guid Id,
string Name,
string Plan,
int Seats,
bool Disabled = false);

View File

@ -1,4 +1,5 @@
using Bit.Core.Billing.Models.Sales;
using Bit.Core.Billing.Models;
using Bit.Core.Billing.Models.Sales;
using Bit.Core.Entities;
namespace Bit.Core.Billing.Services;
@ -27,4 +28,9 @@ public interface IPremiumUserBillingService
/// </code>
/// </example>
Task Finalize(PremiumUserSale sale);
Task UpdatePaymentMethod(
User user,
TokenizedPaymentSource tokenizedPaymentSource,
TaxInformation taxInformation);
}

View File

@ -2,6 +2,7 @@
using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.Billing.Entities;
using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Models;
using Bit.Core.Billing.Services.Contracts;
using Bit.Core.Models.Business;
using Stripe;
@ -10,6 +11,11 @@ namespace Bit.Core.Billing.Services;
public interface IProviderBillingService
{
Task AddExistingOrganization(
Provider provider,
Organization organization,
string key);
/// <summary>
/// Changes the assigned provider plan for the provider.
/// </summary>
@ -35,6 +41,10 @@ public interface IProviderBillingService
Task<byte[]> GenerateClientInvoiceReport(
string invoiceId);
Task<IEnumerable<AddableOrganization>> GetAddableOrganizations(
Provider provider,
Guid userId);
/// <summary>
/// Scales the <paramref name="provider"/>'s seats for the specified <paramref name="planType"/> using the provided <paramref name="seatAdjustment"/>.
/// This operation may autoscale the provider's Stripe <see cref="Stripe.Subscription"/> depending on the <paramref name="provider"/>'s seat minimum for the

View File

@ -356,11 +356,20 @@ public class OrganizationBillingService(
}
}
var customerHasTaxInfo = customer is
{
Address:
{
Country: not null and not "",
PostalCode: not null and not ""
}
};
var subscriptionCreateOptions = new SubscriptionCreateOptions
{
AutomaticTax = new SubscriptionAutomaticTaxOptions
{
Enabled = true
Enabled = customerHasTaxInfo
},
CollectionMethod = StripeConstants.CollectionMethod.ChargeAutomatically,
Customer = customer.Id,

View File

@ -1,5 +1,6 @@
using Bit.Core.Billing.Caches;
using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Models;
using Bit.Core.Billing.Models.Sales;
using Bit.Core.Entities;
using Bit.Core.Enums;
@ -58,6 +59,28 @@ public class PremiumUserBillingService(
await userRepository.ReplaceAsync(user);
}
public async Task UpdatePaymentMethod(
User user,
TokenizedPaymentSource tokenizedPaymentSource,
TaxInformation taxInformation)
{
if (string.IsNullOrEmpty(user.GatewayCustomerId))
{
var customer = await CreateCustomerAsync(user,
new CustomerSetup { TokenizedPaymentSource = tokenizedPaymentSource, TaxInformation = taxInformation });
user.Gateway = GatewayType.Stripe;
user.GatewayCustomerId = customer.Id;
await userRepository.ReplaceAsync(user);
}
else
{
await subscriberService.UpdatePaymentSource(user, tokenizedPaymentSource);
await subscriberService.UpdateTaxInformation(user, taxInformation);
}
}
private async Task<Customer> CreateCustomerAsync(
User user,
CustomerSetup customerSetup)

View File

@ -107,10 +107,17 @@ public static class FeatureFlagKeys
public const string VerifiedSsoDomainEndpoint = "pm-12337-refactor-sso-details-endpoint";
public const string IntegrationPage = "pm-14505-admin-console-integration-page";
public const string DeviceApprovalRequestAdminNotifications = "pm-15637-device-approval-request-admin-notifications";
public const string LimitItemDeletion = "pm-15493-restrict-item-deletion-to-can-manage-permission";
/* Tools Team */
public const string ItemShare = "item-share";
public const string GeneratorToolsModernization = "generator-tools-modernization";
public const string MemberAccessReport = "ac-2059-member-access-report";
public const string RiskInsightsCriticalApplication = "pm-14466-risk-insights-critical-application";
public const string EnableRiskInsightsNotifications = "enable-risk-insights-notifications";
public const string ReturnErrorOnExistingKeypair = "return-error-on-existing-keypair";
public const string UseTreeWalkerApiForPageDetailsCollection = "use-tree-walker-api-for-page-details-collection";
public const string ItemShare = "item-share";
public const string DuoRedirect = "duo-redirect";
public const string AC2101UpdateTrialInitiationEmail = "AC-2101-update-trial-initiation-email";
public const string AC1795_UpdatedSubscriptionStatusSection = "AC-1795_updated-subscription-status-section";
@ -120,7 +127,6 @@ public static class FeatureFlagKeys
public const string RestrictProviderAccess = "restrict-provider-access";
public const string PM4154BulkEncryptionService = "PM-4154-bulk-encryption-service";
public const string VaultBulkManagementAction = "vault-bulk-management-action";
public const string MemberAccessReport = "ac-2059-member-access-report";
public const string InlineMenuFieldQualification = "inline-menu-field-qualification";
public const string TwoFactorComponentRefactor = "two-factor-component-refactor";
public const string InlineMenuPositioningImprovements = "inline-menu-positioning-improvements";
@ -145,9 +151,7 @@ public static class FeatureFlagKeys
public const string TrialPayment = "PM-8163-trial-payment";
public const string RemoveServerVersionHeader = "remove-server-version-header";
public const string PM12275_MultiOrganizationEnterprises = "pm-12275-multi-organization-enterprises";
public const string GeneratorToolsModernization = "generator-tools-modernization";
public const string NewDeviceVerification = "new-device-verification";
public const string RiskInsightsCriticalApplication = "pm-14466-risk-insights-critical-application";
public const string NewDeviceVerificationTemporaryDismiss = "new-device-temporary-dismiss";
public const string NewDeviceVerificationPermanentDismiss = "new-device-permanent-dismiss";
public const string SecurityTasks = "security-tasks";
@ -169,7 +173,8 @@ public static class FeatureFlagKeys
public const string AccountDeprovisioningBanner = "pm-17120-account-deprovisioning-admin-console-banner";
public const string SingleTapPasskeyCreation = "single-tap-passkey-creation";
public const string SingleTapPasskeyAuthentication = "single-tap-passkey-authentication";
public const string EnableRiskInsightsNotifications = "enable-risk-insights-notifications";
public const string EnablePMAuthenticatorSync = "enable-pm-bwa-sync";
public const string P15179_AddExistingOrgsFromProviderPortal = "PM-15179-add-existing-orgs-from-provider-portal";
public static List<string> GetAllKeys()
{

View File

@ -21,8 +21,8 @@
<ItemGroup>
<PackageReference Include="AspNetCoreRateLimit.Redis" Version="2.0.0" />
<PackageReference Include="AWSSDK.SimpleEmail" Version="3.7.402.18" />
<PackageReference Include="AWSSDK.SQS" Version="3.7.400.75" />
<PackageReference Include="AWSSDK.SimpleEmail" Version="3.7.402.28" />
<PackageReference Include="AWSSDK.SQS" Version="3.7.400.85" />
<PackageReference Include="Azure.Data.Tables" Version="12.9.0" />
<PackageReference Include="Azure.Extensions.AspNetCore.DataProtection.Blobs" Version="1.3.4" />
<PackageReference Include="Google.Protobuf" Version="3.29.2" />
@ -40,7 +40,7 @@
<PackageReference Include="DnsClient" Version="1.8.0" />
<PackageReference Include="Fido2.AspNet" Version="3.0.1" />
<PackageReference Include="Handlebars.Net" Version="2.1.6" />
<PackageReference Include="MailKit" Version="4.9.0" />
<PackageReference Include="MailKit" Version="4.9.0" />
<PackageReference Include="Microsoft.AspNetCore.Authentication.JwtBearer" Version="8.0.10" />
<PackageReference Include="Microsoft.Azure.Cosmos" Version="3.46.1" />
<PackageReference Include="Microsoft.Azure.NotificationHubs" Version="4.2.0" />
@ -70,12 +70,13 @@
<PackageReference Include="Quartz" Version="3.13.1" />
<PackageReference Include="Quartz.Extensions.Hosting" Version="3.13.1" />
<PackageReference Include="Quartz.Extensions.DependencyInjection" Version="3.13.1" />
<PackageReference Include="RabbitMQ.Client" Version="7.0.0" />
</ItemGroup>
<ItemGroup>
<Protobuf Include="Billing\Pricing\Protos\password-manager.proto" GrpcServices="Client" />
</ItemGroup>
<ItemGroup>
<Folder Include="Resources\" />
<Folder Include="Properties\" />

View File

@ -0,0 +1,13 @@
using Microsoft.Extensions.Hosting;
namespace Bit.Core.Services;
public abstract class EventLoggingListenerService : BackgroundService
{
protected readonly IEventMessageHandler _handler;
protected EventLoggingListenerService(IEventMessageHandler handler)
{
_handler = handler ?? throw new ArgumentNullException(nameof(handler));
}
}

View File

@ -0,0 +1,92 @@
using System.Text.Json;
using Bit.Core.Models.Data;
using Bit.Core.Settings;
using Microsoft.Extensions.Logging;
using RabbitMQ.Client;
using RabbitMQ.Client.Events;
namespace Bit.Core.Services;
public class RabbitMqEventListenerService : EventLoggingListenerService
{
private IChannel _channel;
private IConnection _connection;
private readonly string _exchangeName;
private readonly ConnectionFactory _factory;
private readonly ILogger<RabbitMqEventListenerService> _logger;
private readonly string _queueName;
public RabbitMqEventListenerService(
IEventMessageHandler handler,
ILogger<RabbitMqEventListenerService> logger,
GlobalSettings globalSettings,
string queueName) : base(handler)
{
_factory = new ConnectionFactory
{
HostName = globalSettings.EventLogging.RabbitMq.HostName,
UserName = globalSettings.EventLogging.RabbitMq.Username,
Password = globalSettings.EventLogging.RabbitMq.Password
};
_exchangeName = globalSettings.EventLogging.RabbitMq.ExchangeName;
_logger = logger;
_queueName = queueName;
}
public override async Task StartAsync(CancellationToken cancellationToken)
{
_connection = await _factory.CreateConnectionAsync(cancellationToken);
_channel = await _connection.CreateChannelAsync(cancellationToken: cancellationToken);
await _channel.ExchangeDeclareAsync(exchange: _exchangeName, type: ExchangeType.Fanout, durable: true);
await _channel.QueueDeclareAsync(queue: _queueName,
durable: true,
exclusive: false,
autoDelete: false,
arguments: null,
cancellationToken: cancellationToken);
await _channel.QueueBindAsync(queue: _queueName,
exchange: _exchangeName,
routingKey: string.Empty,
cancellationToken: cancellationToken);
await base.StartAsync(cancellationToken);
}
protected override async Task ExecuteAsync(CancellationToken stoppingToken)
{
var consumer = new AsyncEventingBasicConsumer(_channel);
consumer.ReceivedAsync += async (_, eventArgs) =>
{
try
{
var eventMessage = JsonSerializer.Deserialize<EventMessage>(eventArgs.Body.Span);
await _handler.HandleEventAsync(eventMessage);
}
catch (Exception ex)
{
_logger.LogError(ex, "An error occurred while processing the message");
}
};
await _channel.BasicConsumeAsync(_queueName, autoAck: true, consumer: consumer, cancellationToken: stoppingToken);
while (!stoppingToken.IsCancellationRequested)
{
await Task.Delay(1_000, stoppingToken);
}
}
public override async Task StopAsync(CancellationToken cancellationToken)
{
await _channel.CloseAsync();
await _connection.CloseAsync();
await base.StopAsync(cancellationToken);
}
public override void Dispose()
{
_channel.Dispose();
_connection.Dispose();
base.Dispose();
}
}

View File

@ -9,6 +9,7 @@ using Bit.Core.AdminConsole.Services;
using Bit.Core.Auth.Enums;
using Bit.Core.Auth.Models;
using Bit.Core.Auth.Models.Business.Tokenables;
using Bit.Core.Billing.Models;
using Bit.Core.Billing.Models.Sales;
using Bit.Core.Billing.Services;
using Bit.Core.Context;
@ -1044,11 +1045,11 @@ public class UserService : UserManager<User>, IUserService, IDisposable
throw new BadRequestException("Invalid token.");
}
var updated = await _paymentService.UpdatePaymentMethodAsync(user, paymentMethodType, paymentToken, taxInfo: taxInfo);
if (updated)
{
await SaveUserAsync(user);
}
var tokenizedPaymentSource = new TokenizedPaymentSource(paymentMethodType, paymentToken);
var taxInformation = TaxInformation.From(taxInfo);
await _premiumUserBillingService.UpdatePaymentMethod(user, tokenizedPaymentSource, taxInformation);
await SaveUserAsync(user);
}
public async Task CancelPremiumAsync(User user, bool? endOfPeriod = null)

View File

@ -53,6 +53,7 @@ public class GlobalSettings : IGlobalSettings
public virtual SqlSettings PostgreSql { get; set; } = new SqlSettings();
public virtual SqlSettings MySql { get; set; } = new SqlSettings();
public virtual SqlSettings Sqlite { get; set; } = new SqlSettings() { ConnectionString = "Data Source=:memory:" };
public virtual EventLoggingSettings EventLogging { get; set; } = new EventLoggingSettings();
public virtual MailSettings Mail { get; set; } = new MailSettings();
public virtual IConnectionStringSettings Storage { get; set; } = new ConnectionStringSettings();
public virtual ConnectionStringSettings Events { get; set; } = new ConnectionStringSettings();
@ -256,6 +257,44 @@ public class GlobalSettings : IGlobalSettings
}
}
public class EventLoggingSettings
{
public RabbitMqSettings RabbitMq { get; set; } = new RabbitMqSettings();
public class RabbitMqSettings
{
private string _hostName;
private string _username;
private string _password;
private string _exchangeName;
public virtual string EventRepositoryQueueName { get; set; } = "events-write-queue";
public virtual string HttpPostQueueName { get; set; } = "events-httpPost-queue";
public virtual string HttpPostUrl { get; set; }
public string HostName
{
get => _hostName;
set => _hostName = value.Trim('"');
}
public string Username
{
get => _username;
set => _username = value.Trim('"');
}
public string Password
{
get => _password;
set => _password = value.Trim('"');
}
public string ExchangeName
{
get => _exchangeName;
set => _exchangeName = value.Trim('"');
}
}
}
public class ConnectionStringSettings : IConnectionStringSettings
{
private string _connectionString;

View File

@ -27,4 +27,5 @@ public interface IGlobalSettings
string DatabaseProvider { get; set; }
GlobalSettings.SqlSettings SqlServer { get; set; }
string DevelopmentDirectory { get; set; }
GlobalSettings.EventLoggingSettings EventLogging { get; set; }
}

View File

@ -82,6 +82,35 @@ public class Startup
{
services.AddHostedService<Core.HostedServices.ApplicationCacheHostedService>();
}
// Optional RabbitMQ Listeners
if (CoreHelpers.SettingHasValue(globalSettings.EventLogging.RabbitMq.HostName) &&
CoreHelpers.SettingHasValue(globalSettings.EventLogging.RabbitMq.Username) &&
CoreHelpers.SettingHasValue(globalSettings.EventLogging.RabbitMq.Password) &&
CoreHelpers.SettingHasValue(globalSettings.EventLogging.RabbitMq.ExchangeName))
{
services.AddSingleton<EventRepositoryHandler>();
services.AddKeyedSingleton<IEventWriteService, RepositoryEventWriteService>("persistent");
services.AddSingleton<IHostedService>(provider =>
new RabbitMqEventListenerService(
provider.GetRequiredService<EventRepositoryHandler>(),
provider.GetRequiredService<ILogger<RabbitMqEventListenerService>>(),
provider.GetRequiredService<GlobalSettings>(),
globalSettings.EventLogging.RabbitMq.EventRepositoryQueueName));
if (CoreHelpers.SettingHasValue(globalSettings.EventLogging.RabbitMq.HttpPostUrl))
{
services.AddSingleton<HttpPostEventHandler>();
services.AddHttpClient(HttpPostEventHandler.HttpClientName);
services.AddSingleton<IHostedService>(provider =>
new RabbitMqEventListenerService(
provider.GetRequiredService<HttpPostEventHandler>(),
provider.GetRequiredService<ILogger<RabbitMqEventListenerService>>(),
provider.GetRequiredService<GlobalSettings>(),
globalSettings.EventLogging.RabbitMq.HttpPostQueueName));
}
}
}
public void Configure(

View File

@ -4,6 +4,7 @@ using Bit.Core.Context;
using Bit.Core.Tools.Enums;
using Bit.Core.Tools.Models.Business;
using Bit.Core.Tools.Services;
using Bit.Core.Utilities;
using Bit.SharedWeb.Utilities;
using Microsoft.AspNetCore.Mvc;
@ -17,6 +18,7 @@ public class AccountsController(
IReferenceEventService referenceEventService) : Microsoft.AspNetCore.Mvc.Controller
{
[HttpPost("trial/send-verification-email")]
[SelfHosted(NotSelfHostedOnly = true)]
public async Task<IActionResult> PostTrialInitiationSendVerificationEmailAsync([FromBody] TrialSendVerificationEmailRequestModel model)
{
var token = await sendTrialInitiationEmailForRegistrationCommand.Handle(

View File

@ -85,28 +85,17 @@ public class DeviceValidator(
}
}
// At this point we have established either new device verification is not required or the NewDeviceOtp is valid
// At this point we have established either new device verification is not required or the NewDeviceOtp is valid,
// so we save the device to the database and proceed with authentication
requestDevice.UserId = context.User.Id;
await _deviceService.SaveAsync(requestDevice);
context.Device = requestDevice;
// backwards compatibility -- If NewDeviceVerification not enabled send the new login emails
// PM-13340: removal Task; remove entire if block emails should no longer be sent
if (!_featureService.IsEnabled(FeatureFlagKeys.NewDeviceVerification))
if (!_globalSettings.DisableEmailNewDevice)
{
// This ensures the user doesn't receive a "new device" email on the first login
var now = DateTime.UtcNow;
if (now - context.User.CreationDate > TimeSpan.FromMinutes(10))
{
var deviceType = requestDevice.Type.GetType().GetMember(requestDevice.Type.ToString())
.FirstOrDefault()?.GetCustomAttribute<DisplayAttribute>()?.GetName();
if (!_globalSettings.DisableEmailNewDevice)
{
await _mailService.SendNewDeviceLoggedInEmail(context.User.Email, deviceType, now,
_currentContext.IpAddress);
}
}
await SendNewDeviceLoginEmail(context.User, requestDevice);
}
return true;
}
@ -174,6 +163,19 @@ public class DeviceValidator(
return DeviceValidationResultType.NewDeviceVerificationRequired;
}
private async Task SendNewDeviceLoginEmail(User user, Device requestDevice)
{
// Ensure that the user doesn't receive a "new device" email on the first login
var now = DateTime.UtcNow;
if (now - user.CreationDate > TimeSpan.FromMinutes(10))
{
var deviceType = requestDevice.Type.GetType().GetMember(requestDevice.Type.ToString())
.FirstOrDefault()?.GetCustomAttribute<DisplayAttribute>()?.GetName();
await _mailService.SendNewDeviceLoggedInEmail(user.Email, deviceType, now,
_currentContext.IpAddress);
}
}
public async Task<Device> GetKnownDeviceAsync(User user, Device device)
{
if (user == null || device == null)

View File

@ -1,5 +1,6 @@
using System.Data;
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Enums.Provider;
using Bit.Core.Auth.Entities;
using Bit.Core.Entities;
using Bit.Core.Models.Data.Organizations;
@ -180,4 +181,19 @@ public class OrganizationRepository : Repository<Organization, Guid>, IOrganizat
return result.ToList();
}
}
public async Task<ICollection<Organization>> GetAddableToProviderByUserIdAsync(
Guid userId,
ProviderType providerType)
{
using (var connection = new SqlConnection(ConnectionString))
{
var result = await connection.QueryAsync<Organization>(
$"[{Schema}].[{Table}_ReadAddableToProviderByUserId]",
new { UserId = userId, ProviderType = providerType },
commandType: CommandType.StoredProcedure);
return result.ToList();
}
}
}

View File

@ -1,9 +1,12 @@
using AutoMapper;
using AutoMapper.QueryableExtensions;
using Bit.Core.AdminConsole.Enums.Provider;
using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Enums;
using Bit.Core.Enums;
using Bit.Core.Models.Data.Organizations;
using Bit.Core.Repositories;
using LinqToDB.Tools;
using Microsoft.EntityFrameworkCore;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
@ -114,13 +117,19 @@ public class OrganizationRepository : Repository<Core.AdminConsole.Entities.Orga
var dbContext = GetDatabaseContext(scope);
var disallowedPlanTypes = new List<PlanType>
{
PlanType.Free,
PlanType.Custom,
PlanType.FamiliesAnnually2019,
PlanType.FamiliesAnnually
};
var query =
from o in dbContext.Organizations
where
((o.PlanType >= PlanType.TeamsMonthly2019 && o.PlanType <= PlanType.EnterpriseAnnually2019) ||
(o.PlanType >= PlanType.TeamsMonthly2020 && o.PlanType <= PlanType.EnterpriseAnnually)) &&
!dbContext.ProviderOrganizations.Any(po => po.OrganizationId == o.Id) &&
(string.IsNullOrWhiteSpace(name) || EF.Functions.Like(o.Name, $"%{name}%"))
where o.PlanType.NotIn(disallowedPlanTypes) &&
!dbContext.ProviderOrganizations.Any(po => po.OrganizationId == o.Id) &&
(string.IsNullOrWhiteSpace(name) || EF.Functions.Like(o.Name, $"%{name}%"))
select o;
if (string.IsNullOrWhiteSpace(ownerEmail))
@ -152,7 +161,7 @@ public class OrganizationRepository : Repository<Core.AdminConsole.Entities.Orga
select o;
}
return await query.OrderByDescending(o => o.CreationDate).Skip(skip).Take(take).ToArrayAsync();
return await query.OrderByDescending(o => o.CreationDate).ThenByDescending(o => o.Id).Skip(skip).Take(take).ToArrayAsync();
}
public async Task UpdateStorageAsync(Guid id)
@ -298,6 +307,41 @@ public class OrganizationRepository : Repository<Core.AdminConsole.Entities.Orga
}
}
public async Task<ICollection<Core.AdminConsole.Entities.Organization>> GetAddableToProviderByUserIdAsync(
Guid userId,
ProviderType providerType)
{
using (var scope = ServiceScopeFactory.CreateScope())
{
var dbContext = GetDatabaseContext(scope);
var planTypes = providerType switch
{
ProviderType.Msp => PlanConstants.EnterprisePlanTypes.Concat(PlanConstants.TeamsPlanTypes),
ProviderType.MultiOrganizationEnterprise => PlanConstants.EnterprisePlanTypes,
_ => []
};
var query =
from organizationUser in dbContext.OrganizationUsers
join organization in dbContext.Organizations on organizationUser.OrganizationId equals organization.Id
where
organizationUser.UserId == userId &&
organizationUser.Type == OrganizationUserType.Owner &&
organizationUser.Status == OrganizationUserStatusType.Confirmed &&
organization.Enabled &&
organization.GatewayCustomerId != null &&
organization.GatewaySubscriptionId != null &&
organization.Seats > 0 &&
organization.Status == OrganizationStatusType.Created &&
!organization.UseSecretsManager &&
organization.PlanType.In(planTypes)
select organization;
return await query.ToArrayAsync();
}
}
public Task EnableCollectionEnhancements(Guid organizationId)
{
throw new NotImplementedException("Collection enhancements migration is not yet supported for Entity Framework.");

View File

@ -46,27 +46,17 @@ public class OrganizationDomainRepository : Repository<Core.Entities.Organizatio
using var scope = ServiceScopeFactory.CreateScope();
var dbContext = GetDatabaseContext(scope);
var domains = await dbContext.OrganizationDomains
.Where(x => x.VerifiedDate == null
&& x.JobRunCount != 3
&& x.NextRunDate.Year == date.Year
&& x.NextRunDate.Month == date.Month
&& x.NextRunDate.Day == date.Day
&& x.NextRunDate.Hour == date.Hour)
.AsNoTracking()
var start36HoursWindow = date.AddHours(-36);
var end36HoursWindow = date;
var pastDomains = await dbContext.OrganizationDomains
.Where(x => x.NextRunDate >= start36HoursWindow
&& x.NextRunDate <= end36HoursWindow
&& x.VerifiedDate == null
&& x.JobRunCount != 3)
.ToListAsync();
//Get records that have ignored/failed by the background service
var pastDomains = dbContext.OrganizationDomains
.AsEnumerable()
.Where(x => (date - x.NextRunDate).TotalHours > 36
&& x.VerifiedDate == null
&& x.JobRunCount != 3)
.ToList();
var results = domains.Union(pastDomains);
return Mapper.Map<List<Core.Entities.OrganizationDomain>>(results);
return Mapper.Map<List<Core.Entities.OrganizationDomain>>(pastDomains);
}
public async Task<OrganizationDomainSsoDetailsData?> GetOrganizationDomainSsoDetailsAsync(string email)

View File

@ -325,7 +325,17 @@ public static class ServiceCollectionExtensions
}
else if (globalSettings.SelfHosted)
{
services.AddSingleton<IEventWriteService, RepositoryEventWriteService>();
if (CoreHelpers.SettingHasValue(globalSettings.EventLogging.RabbitMq.HostName) &&
CoreHelpers.SettingHasValue(globalSettings.EventLogging.RabbitMq.Username) &&
CoreHelpers.SettingHasValue(globalSettings.EventLogging.RabbitMq.Password) &&
CoreHelpers.SettingHasValue(globalSettings.EventLogging.RabbitMq.ExchangeName))
{
services.AddSingleton<IEventWriteService, RabbitMqEventWriteService>();
}
else
{
services.AddSingleton<IEventWriteService, RepositoryEventWriteService>();
}
}
else
{

View File

@ -0,0 +1,23 @@
CREATE PROCEDURE [dbo].[Organization_ReadAddableToProviderByUserId]
@UserId UNIQUEIDENTIFIER,
@ProviderType TINYINT
AS
BEGIN
SET NOCOUNT ON
SELECT O.* FROM [dbo].[OrganizationUser] AS OU
JOIN [dbo].[Organization] AS O ON O.[Id] = OU.[OrganizationId]
WHERE
OU.[UserId] = @UserId AND
OU.[Type] = 0 AND
OU.[Status] = 2 AND
O.[Enabled] = 1 AND
O.[GatewayCustomerId] IS NOT NULL AND
O.[GatewaySubscriptionId] IS NOT NULL AND
O.[Seats] > 0 AND
O.[Status] = 1 AND
O.[UseSecretsManager] = 0 AND
-- All Teams & Enterprise for MSP
(@ProviderType = 0 AND O.[PlanType] IN (2, 3, 4, 5, 8, 9, 10, 11, 12, 13, 14, 15, 17, 18, 19, 20) OR
-- All Enterprise for MOE
@ProviderType = 2 AND O.[PlanType] IN (4, 5, 10, 11, 14, 15, 19, 20));
END

View File

@ -1,5 +1,5 @@
CREATE PROCEDURE [dbo].[Organization_UnassignedToProviderSearch]
@Name NVARCHAR(50),
@Name NVARCHAR(55),
@OwnerEmail NVARCHAR(256),
@Skip INT = 0,
@Take INT = 25
@ -9,7 +9,7 @@ BEGIN
SET NOCOUNT ON
DECLARE @NameLikeSearch NVARCHAR(55) = '%' + @Name + '%'
DECLARE @OwnerLikeSearch NVARCHAR(55) = @OwnerEmail + '%'
IF @OwnerEmail IS NOT NULL
BEGIN
SELECT
@ -21,11 +21,11 @@ BEGIN
INNER JOIN
[dbo].[User] U ON U.[Id] = OU.[UserId]
WHERE
((O.[PlanType] >= 2 AND O.[PlanType] <= 5) OR (O.[PlanType] >= 8 AND O.[PlanType] <= 20) AND (O.PlanType <> 16)) -- All 'Teams' and 'Enterprise' organizations
O.[PlanType] NOT IN (0, 1, 6, 7) -- Not 'Free', 'Custom' or 'Families'
AND NOT EXISTS (SELECT * FROM [dbo].[ProviderOrganizationView] PO WHERE PO.[OrganizationId] = O.[Id])
AND (@Name IS NULL OR O.[Name] LIKE @NameLikeSearch)
AND (U.[Email] LIKE @OwnerLikeSearch)
ORDER BY O.[CreationDate] DESC
ORDER BY O.[CreationDate] DESC, O.[Id]
OFFSET @Skip ROWS
FETCH NEXT @Take ROWS ONLY
END
@ -36,11 +36,11 @@ BEGIN
FROM
[dbo].[OrganizationView] O
WHERE
((O.[PlanType] >= 2 AND O.[PlanType] <= 5) OR (O.[PlanType] >= 8 AND O.[PlanType] <= 20) AND (O.PlanType <> 16)) -- All 'Teams' and 'Enterprise' organizations
O.[PlanType] NOT IN (0, 1, 6, 7) -- Not 'Free', 'Custom' or 'Families'
AND NOT EXISTS (SELECT * FROM [dbo].[ProviderOrganizationView] PO WHERE PO.[OrganizationId] = O.[Id])
AND (@Name IS NULL OR O.[Name] LIKE @NameLikeSearch)
ORDER BY O.[CreationDate] DESC
ORDER BY O.[CreationDate] DESC, O.[Id]
OFFSET @Skip ROWS
FETCH NEXT @Take ROWS ONLY
END
END
END

View File

@ -1,5 +1,5 @@
using System.Security.Claims;
using Bit.Api.Billing.Controllers;
using Bit.Api.AdminConsole.Controllers;
using Bit.Api.Billing.Models.Requests;
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Entities.Provider;
@ -19,10 +19,9 @@ using Microsoft.AspNetCore.Http.HttpResults;
using NSubstitute;
using NSubstitute.ReturnsExtensions;
using Xunit;
using static Bit.Api.Test.Billing.Utilities;
namespace Bit.Api.Test.Billing.Controllers;
namespace Bit.Api.Test.AdminConsole.Controllers;
[ControllerCustomize(typeof(ProviderClientsController))]
[SutProviderCustomize]

View File

@ -8,6 +8,8 @@ public class MockedHttpMessageHandler : HttpMessageHandler
{
private readonly List<IHttpRequestMatcher> _matchers = new();
public List<HttpRequestMessage> CapturedRequests { get; } = new List<HttpRequestMessage>();
/// <summary>
/// The fallback handler to use when the request does not match any of the provided matchers.
/// </summary>
@ -16,6 +18,7 @@ public class MockedHttpMessageHandler : HttpMessageHandler
protected override async Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, CancellationToken cancellationToken)
{
CapturedRequests.Add(request);
var matcher = _matchers.FirstOrDefault(x => x.Matches(request));
if (matcher == null)
{

View File

@ -0,0 +1,24 @@
using Bit.Core.Models.Data;
using Bit.Core.Services;
using Bit.Test.Common.AutoFixture;
using Bit.Test.Common.AutoFixture.Attributes;
using Bit.Test.Common.Helpers;
using NSubstitute;
using Xunit;
namespace Bit.Core.Test.Services;
[SutProviderCustomize]
public class EventRepositoryHandlerTests
{
[Theory, BitAutoData]
public async Task HandleEventAsync_WritesEventToIEventWriteService(
EventMessage eventMessage,
SutProvider<EventRepositoryHandler> sutProvider)
{
await sutProvider.Sut.HandleEventAsync(eventMessage);
await sutProvider.GetDependency<IEventWriteService>().Received(1).CreateAsync(
Arg.Is(AssertHelper.AssertPropertyEqual<IEvent>(eventMessage))
);
}
}

View File

@ -0,0 +1,66 @@
using System.Net;
using System.Net.Http.Json;
using Bit.Core.Models.Data;
using Bit.Core.Services;
using Bit.Test.Common.AutoFixture;
using Bit.Test.Common.AutoFixture.Attributes;
using Bit.Test.Common.Helpers;
using Bit.Test.Common.MockedHttpClient;
using NSubstitute;
using Xunit;
using GlobalSettings = Bit.Core.Settings.GlobalSettings;
namespace Bit.Core.Test.Services;
[SutProviderCustomize]
public class HttpPostEventHandlerTests
{
private readonly MockedHttpMessageHandler _handler;
private HttpClient _httpClient;
private const string _httpPostUrl = "http://localhost/test/event";
public HttpPostEventHandlerTests()
{
_handler = new MockedHttpMessageHandler();
_handler.Fallback
.WithStatusCode(HttpStatusCode.OK)
.WithContent(new StringContent("<html><head><title>test</title></head><body>test</body></html>"));
_httpClient = _handler.ToHttpClient();
}
public SutProvider<HttpPostEventHandler> GetSutProvider()
{
var clientFactory = Substitute.For<IHttpClientFactory>();
clientFactory.CreateClient(HttpPostEventHandler.HttpClientName).Returns(_httpClient);
var globalSettings = new GlobalSettings();
globalSettings.EventLogging.RabbitMq.HttpPostUrl = _httpPostUrl;
return new SutProvider<HttpPostEventHandler>()
.SetDependency(globalSettings)
.SetDependency(clientFactory)
.Create();
}
[Theory, BitAutoData]
public async Task HandleEventAsync_PostsEventsToUrl(EventMessage eventMessage)
{
var sutProvider = GetSutProvider();
var content = JsonContent.Create(eventMessage);
await sutProvider.Sut.HandleEventAsync(eventMessage);
sutProvider.GetDependency<IHttpClientFactory>().Received(1).CreateClient(
Arg.Is(AssertHelper.AssertPropertyEqual<string>(HttpPostEventHandler.HttpClientName))
);
Assert.Single(_handler.CapturedRequests);
var request = _handler.CapturedRequests[0];
Assert.NotNull(request);
var returned = await request.Content.ReadFromJsonAsync<EventMessage>();
Assert.Equal(HttpMethod.Post, request.Method);
Assert.Equal(_httpPostUrl, request.RequestUri.ToString());
AssertHelper.AssertPropertyEqual(eventMessage, returned, new[] { "IdempotencyId" });
}
}

View File

@ -0,0 +1,29 @@
using System.Net.Http.Json;
using Bit.Core.Enums;
using Bit.Events.Models;
namespace Bit.Events.IntegrationTest.Controllers;
public class CollectControllerTests
{
// This is a very simple test, and should be updated to assert more things, but for now
// it ensures that the events startup doesn't throw any errors with fairly basic configuration.
[Fact]
public async Task Post_Works()
{
var eventsApplicationFactory = new EventsApplicationFactory();
var (accessToken, _) = await eventsApplicationFactory.LoginWithNewAccount();
var client = eventsApplicationFactory.CreateAuthedClient(accessToken);
var response = await client.PostAsJsonAsync<IEnumerable<EventModel>>("collect",
[
new EventModel
{
Type = EventType.User_ClientExportedVault,
Date = DateTime.UtcNow,
},
]);
response.EnsureSuccessStatusCode();
}
}

View File

@ -0,0 +1,29 @@
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<TargetFramework>net8.0</TargetFramework>
<ImplicitUsings>enable</ImplicitUsings>
<Nullable>enable</Nullable>
<IsPackable>false</IsPackable>
<IsTestProject>true</IsTestProject>
</PropertyGroup>
<ItemGroup>
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="$(MicrosoftNetTestSdkVersion)" />
<PackageReference Include="xunit" Version="$(XUnitVersion)" />
<PackageReference Include="xunit.runner.visualstudio" Version="$(XUnitRunnerVisualStudioVersion)">
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
<PrivateAssets>all</PrivateAssets>
</PackageReference>
<PackageReference Include="coverlet.collector" Version="$(CoverletCollectorVersion)">
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
<PrivateAssets>all</PrivateAssets>
</PackageReference>
</ItemGroup>
<ItemGroup>
<ProjectReference Include="..\..\src\Events\Events.csproj" />
<ProjectReference Include="..\IntegrationTestCommon\IntegrationTestCommon.csproj" />
</ItemGroup>
</Project>

View File

@ -0,0 +1,57 @@
using Bit.Identity.Models.Request.Accounts;
using Bit.IntegrationTestCommon.Factories;
using Microsoft.AspNetCore.Authentication.JwtBearer;
using Microsoft.AspNetCore.Hosting;
using Microsoft.AspNetCore.TestHost;
using Microsoft.Data.Sqlite;
using Microsoft.Extensions.DependencyInjection;
namespace Bit.Events.IntegrationTest;
public class EventsApplicationFactory : WebApplicationFactoryBase<Startup>
{
private readonly IdentityApplicationFactory _identityApplicationFactory;
private const string _connectionString = "DataSource=:memory:";
public EventsApplicationFactory()
{
SqliteConnection = new SqliteConnection(_connectionString);
SqliteConnection.Open();
_identityApplicationFactory = new IdentityApplicationFactory();
_identityApplicationFactory.SqliteConnection = SqliteConnection;
}
protected override void ConfigureWebHost(IWebHostBuilder builder)
{
base.ConfigureWebHost(builder);
builder.ConfigureTestServices(services =>
{
services.Configure<JwtBearerOptions>(JwtBearerDefaults.AuthenticationScheme, options =>
{
options.BackchannelHttpHandler = _identityApplicationFactory.Server.CreateHandler();
});
});
}
/// <summary>
/// Helper for registering and logging in to a new account
/// </summary>
public async Task<(string Token, string RefreshToken)> LoginWithNewAccount(string email = "integration-test@bitwarden.com", string masterPasswordHash = "master_password_hash")
{
await _identityApplicationFactory.RegisterAsync(new RegisterRequestModel
{
Email = email,
MasterPasswordHash = masterPasswordHash,
});
return await _identityApplicationFactory.TokenFromPasswordAsync(email, masterPasswordHash);
}
protected override void Dispose(bool disposing)
{
base.Dispose(disposing);
SqliteConnection!.Dispose();
}
}

View File

@ -0,0 +1 @@
global using Xunit;

View File

@ -227,7 +227,7 @@ public class DeviceValidatorTests
}
[Theory, BitAutoData]
public async void ValidateRequestDeviceAsync_NewDeviceVerificationFeatureFlagFalse_SendsEmail_ReturnsTrue(
public async void ValidateRequestDeviceAsync_ExistingUserNewDeviceLogin_SendNewDeviceLoginEmail_ReturnsTrue(
CustomValidatorRequestContext context,
[AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest request)
{
@ -237,8 +237,6 @@ public class DeviceValidatorTests
_globalSettings.DisableEmailNewDevice = false;
_deviceRepository.GetByIdentifierAsync(context.Device.Identifier, context.User.Id)
.Returns(null as Device);
_featureService.IsEnabled(FeatureFlagKeys.NewDeviceVerification)
.Returns(false);
// set user creation to more than 10 minutes ago
context.User.CreationDate = DateTime.UtcNow - TimeSpan.FromMinutes(11);
@ -253,7 +251,7 @@ public class DeviceValidatorTests
}
[Theory, BitAutoData]
public async void ValidateRequestDeviceAsync_NewDeviceVerificationFeatureFlagFalse_NewUser_DoesNotSendEmail_ReturnsTrue(
public async void ValidateRequestDeviceAsync_NewUserNewDeviceLogin_DoesNotSendNewDeviceLoginEmail_ReturnsTrue(
CustomValidatorRequestContext context,
[AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest request)
{
@ -263,8 +261,6 @@ public class DeviceValidatorTests
_globalSettings.DisableEmailNewDevice = false;
_deviceRepository.GetByIdentifierAsync(context.Device.Identifier, context.User.Id)
.Returns(null as Device);
_featureService.IsEnabled(FeatureFlagKeys.NewDeviceVerification)
.Returns(false);
// set user creation to less than 10 minutes ago
context.User.CreationDate = DateTime.UtcNow - TimeSpan.FromMinutes(9);
@ -279,7 +275,7 @@ public class DeviceValidatorTests
}
[Theory, BitAutoData]
public async void ValidateRequestDeviceAsync_NewDeviceVerificationFeatureFlagFalse_DisableEmailTrue_DoesNotSendEmail_ReturnsTrue(
public async void ValidateRequestDeviceAsynce_DisableNewDeviceLoginEmailTrue_DoesNotSendNewDeviceEmail_ReturnsTrue(
CustomValidatorRequestContext context,
[AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest request)
{
@ -289,8 +285,6 @@ public class DeviceValidatorTests
_globalSettings.DisableEmailNewDevice = true;
_deviceRepository.GetByIdentifierAsync(context.Device.Identifier, context.User.Id)
.Returns(null as Device);
_featureService.IsEnabled(FeatureFlagKeys.NewDeviceVerification)
.Returns(false);
// Act
var result = await _sut.ValidateRequestDeviceAsync(request, context);

View File

@ -11,7 +11,7 @@
<ItemGroup>
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="17.12.0" />
<PackageReference Include="xunit" Version="2.9.2" />
<PackageReference Include="xunit" Version="2.9.3" />
<PackageReference Include="xunit.runner.visualstudio" Version="3.0.1">
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
<PrivateAssets>all</PrivateAssets>

View File

@ -188,4 +188,122 @@ public class OrganizationDomainRepositoryTests
var expectedDomain2 = domains.FirstOrDefault(domain => domain.DomainName == organizationDomain2.DomainName);
Assert.Null(expectedDomain2);
}
[DatabaseTheory, DatabaseData]
public async Task GetManyByNextRunDateAsync_ShouldReturnUnverifiedDomains(
IOrganizationRepository organizationRepository,
IOrganizationDomainRepository organizationDomainRepository)
{
// Arrange
var id = Guid.NewGuid();
var organization1 = await organizationRepository.CreateAsync(new Organization
{
Name = $"Test Org {id}",
BillingEmail = $"test+{id}@example.com",
Plan = "Test",
PrivateKey = "privatekey",
});
var organizationDomain = new OrganizationDomain
{
OrganizationId = organization1.Id,
DomainName = $"domain2+{id}@example.com",
Txt = "btw+12345"
};
var within36HoursWindow = 1;
organizationDomain.SetNextRunDate(within36HoursWindow);
await organizationDomainRepository.CreateAsync(organizationDomain);
var date = organizationDomain.NextRunDate;
// Act
var domains = await organizationDomainRepository.GetManyByNextRunDateAsync(date);
// Assert
var expectedDomain = domains.FirstOrDefault(domain => domain.DomainName == organizationDomain.DomainName);
Assert.NotNull(expectedDomain);
}
[DatabaseTheory, DatabaseData]
public async Task GetManyByNextRunDateAsync_ShouldNotReturnUnverifiedDomains_WhenNextRunDateIsOutside36hoursWindow(
IOrganizationRepository organizationRepository,
IOrganizationDomainRepository organizationDomainRepository)
{
// Arrange
var id = Guid.NewGuid();
var organization1 = await organizationRepository.CreateAsync(new Organization
{
Name = $"Test Org {id}",
BillingEmail = $"test+{id}@example.com",
Plan = "Test",
PrivateKey = "privatekey",
});
var organizationDomain = new OrganizationDomain
{
OrganizationId = organization1.Id,
DomainName = $"domain2+{id}@example.com",
Txt = "btw+12345"
};
var outside36HoursWindow = 20;
organizationDomain.SetNextRunDate(outside36HoursWindow);
await organizationDomainRepository.CreateAsync(organizationDomain);
var date = DateTimeOffset.UtcNow.Date.AddDays(1);
// Act
var domains = await organizationDomainRepository.GetManyByNextRunDateAsync(date);
// Assert
var expectedDomain = domains.FirstOrDefault(domain => domain.DomainName == organizationDomain.DomainName);
Assert.Null(expectedDomain);
}
[DatabaseTheory, DatabaseData]
public async Task GetManyByNextRunDateAsync_ShouldNotReturnVerifiedDomains(
IOrganizationRepository organizationRepository,
IOrganizationDomainRepository organizationDomainRepository)
{
// Arrange
var id = Guid.NewGuid();
var organization1 = await organizationRepository.CreateAsync(new Organization
{
Name = $"Test Org {id}",
BillingEmail = $"test+{id}@example.com",
Plan = "Test",
PrivateKey = "privatekey",
});
var organizationDomain = new OrganizationDomain
{
OrganizationId = organization1.Id,
DomainName = $"domain2+{id}@example.com",
Txt = "btw+12345"
};
var within36HoursWindow = 1;
organizationDomain.SetNextRunDate(within36HoursWindow);
organizationDomain.SetVerifiedDate();
await organizationDomainRepository.CreateAsync(organizationDomain);
var date = DateTimeOffset.UtcNow.Date.AddDays(1);
// Act
var domains = await organizationDomainRepository.GetManyByNextRunDateAsync(date);
// Assert
var expectedDomain = domains.FirstOrDefault(domain => domain.DomainName == organizationDomain.DomainName);
Assert.Null(expectedDomain);
}
}

View File

@ -14,6 +14,7 @@ using Microsoft.Data.Sqlite;
using Microsoft.EntityFrameworkCore;
using Microsoft.Extensions.Configuration;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.DependencyInjection.Extensions;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;
using NSubstitute;
@ -188,44 +189,27 @@ public abstract class WebApplicationFactoryBase<T> : WebApplicationFactory<T>
// QUESTION: The normal licensing service should run fine on developer machines but not in CI
// should we have a fork here to leave the normal service for developers?
// TODO: Eventually add the license file to CI
var licensingService = services.First(sd => sd.ServiceType == typeof(ILicensingService));
services.Remove(licensingService);
services.AddSingleton<ILicensingService, NoopLicensingService>();
Replace<ILicensingService, NoopLicensingService>(services);
// FUTURE CONSIDERATION: Add way to run this self hosted/cloud, for now it is cloud only
var pushRegistrationService = services.First(sd => sd.ServiceType == typeof(IPushRegistrationService));
services.Remove(pushRegistrationService);
services.AddSingleton<IPushRegistrationService, NoopPushRegistrationService>();
Replace<IPushRegistrationService, NoopPushRegistrationService>(services);
// Even though we are cloud we currently set this up as cloud, we can use the EF/selfhosted service
// instead of using Noop for this service
// TODO: Install and use azurite in CI pipeline
var eventWriteService = services.First(sd => sd.ServiceType == typeof(IEventWriteService));
services.Remove(eventWriteService);
services.AddSingleton<IEventWriteService, RepositoryEventWriteService>();
Replace<IEventWriteService, RepositoryEventWriteService>(services);
var eventRepositoryService = services.First(sd => sd.ServiceType == typeof(IEventRepository));
services.Remove(eventRepositoryService);
services.AddSingleton<IEventRepository, EventRepository>();
Replace<IEventRepository, EventRepository>(services);
var mailDeliveryService = services.First(sd => sd.ServiceType == typeof(IMailDeliveryService));
services.Remove(mailDeliveryService);
services.AddSingleton<IMailDeliveryService, NoopMailDeliveryService>();
Replace<IMailDeliveryService, NoopMailDeliveryService>(services);
var captchaValidationService = services.First(sd => sd.ServiceType == typeof(ICaptchaValidationService));
services.Remove(captchaValidationService);
services.AddSingleton<ICaptchaValidationService, NoopCaptchaValidationService>();
Replace<ICaptchaValidationService, NoopCaptchaValidationService>(services);
// TODO: Install and use azurite in CI pipeline
var installationDeviceRepository =
services.First(sd => sd.ServiceType == typeof(IInstallationDeviceRepository));
services.Remove(installationDeviceRepository);
services.AddSingleton<IInstallationDeviceRepository, NoopRepos.InstallationDeviceRepository>();
Replace<IInstallationDeviceRepository, NoopRepos.InstallationDeviceRepository>(services);
// TODO: Install and use azurite in CI pipeline
var referenceEventService = services.First(sd => sd.ServiceType == typeof(IReferenceEventService));
services.Remove(referenceEventService);
services.AddSingleton<IReferenceEventService, NoopReferenceEventService>();
Replace<IReferenceEventService, NoopReferenceEventService>(services);
// Our Rate limiter works so well that it begins to fail tests unless we carve out
// one whitelisted ip. We should still test the rate limiter though and they should change the Ip
@ -245,14 +229,9 @@ public abstract class WebApplicationFactoryBase<T> : WebApplicationFactory<T>
services.AddSingleton<ILoggerFactory, NullLoggerFactory>();
// Noop StripePaymentService - this could be changed to integrate with our Stripe test account
var stripePaymentService = services.First(sd => sd.ServiceType == typeof(IPaymentService));
services.Remove(stripePaymentService);
services.AddSingleton(Substitute.For<IPaymentService>());
Replace(services, Substitute.For<IPaymentService>());
var organizationBillingService =
services.First(sd => sd.ServiceType == typeof(IOrganizationBillingService));
services.Remove(organizationBillingService);
services.AddSingleton(Substitute.For<IOrganizationBillingService>());
Replace(services, Substitute.For<IOrganizationBillingService>());
});
foreach (var configureTestService in _configureTestServices)
@ -261,6 +240,35 @@ public abstract class WebApplicationFactoryBase<T> : WebApplicationFactory<T>
}
}
private static void Replace<TService, TNewImplementation>(IServiceCollection services)
where TService : class
where TNewImplementation : class, TService
{
services.RemoveAll<TService>();
services.AddSingleton<TService, TNewImplementation>();
}
private static void Replace<TService>(IServiceCollection services, TService implementation)
where TService : class
{
services.RemoveAll<TService>();
services.AddSingleton<TService>(implementation);
}
public HttpClient CreateAuthedClient(string accessToken)
{
var handler = Server.CreateHandler((context) =>
{
context.Request.Headers.Authorization = $"Bearer {accessToken}";
});
return new HttpClient(handler)
{
BaseAddress = Server.BaseAddress,
Timeout = TimeSpan.FromSeconds(200),
};
}
public DatabaseContext GetDatabaseContext()
{
var scope = Services.CreateScope();

View File

@ -0,0 +1,31 @@
-- Drop existing SPROC
IF OBJECT_ID('[dbo].[Organization_ReadAddableToProviderByUserId') IS NOT NULL
BEGIN
DROP PROCEDURE [dbo].[Organization_ReadAddableToProviderByUserId]
END
GO
CREATE PROCEDURE [dbo].[Organization_ReadAddableToProviderByUserId]
@UserId UNIQUEIDENTIFIER,
@ProviderType TINYINT
AS
BEGIN
SET NOCOUNT ON
SELECT O.* FROM [dbo].[OrganizationUser] AS OU
JOIN [dbo].[Organization] AS O ON O.[Id] = OU.[OrganizationId]
WHERE
OU.[UserId] = @UserId AND
OU.[Type] = 0 AND
OU.[Status] = 2 AND
O.[Enabled] = 1 AND
O.[GatewayCustomerId] IS NOT NULL AND
O.[GatewaySubscriptionId] IS NOT NULL AND
O.[Seats] > 0 AND
O.[Status] = 1 AND
O.[UseSecretsManager] = 0 AND
-- All Teams & Enterprise for MSP
(@ProviderType = 0 AND O.[PlanType] IN (2, 3, 4, 5, 8, 9, 10, 11, 12, 13, 14, 15, 17, 18, 19, 20) OR
-- All Enterprise for MOE
@ProviderType = 2 AND O.[PlanType] IN (4, 5, 10, 11, 14, 15, 19, 20));
END
GO

View File

@ -0,0 +1,54 @@
-- Drop existing SPROC
IF OBJECT_ID('[dbo].[Organization_UnassignedToProviderSearch]') IS NOT NULL
BEGIN
DROP PROCEDURE [dbo].[Organization_UnassignedToProviderSearch]
END
GO
CREATE PROCEDURE [dbo].[Organization_UnassignedToProviderSearch]
@Name NVARCHAR(55),
@OwnerEmail NVARCHAR(256),
@Skip INT = 0,
@Take INT = 25
WITH RECOMPILE
AS
BEGIN
SET NOCOUNT ON
DECLARE @NameLikeSearch NVARCHAR(55) = '%' + @Name + '%'
DECLARE @OwnerLikeSearch NVARCHAR(55) = @OwnerEmail + '%'
IF @OwnerEmail IS NOT NULL
BEGIN
SELECT
O.*
FROM
[dbo].[OrganizationView] O
INNER JOIN
[dbo].[OrganizationUser] OU ON O.[Id] = OU.[OrganizationId]
INNER JOIN
[dbo].[User] U ON U.[Id] = OU.[UserId]
WHERE
O.[PlanType] NOT IN (0, 1, 6, 7) -- Not 'Free', 'Custom' or 'Families'
AND NOT EXISTS (SELECT * FROM [dbo].[ProviderOrganizationView] PO WHERE PO.[OrganizationId] = O.[Id])
AND (@Name IS NULL OR O.[Name] LIKE @NameLikeSearch)
AND (U.[Email] LIKE @OwnerLikeSearch)
ORDER BY O.[CreationDate] DESC, O.[Id]
OFFSET @Skip ROWS
FETCH NEXT @Take ROWS ONLY
END
ELSE
BEGIN
SELECT
O.*
FROM
[dbo].[OrganizationView] O
WHERE
O.[PlanType] NOT IN (0, 1, 6, 7) -- Not 'Free', 'Custom' or 'Families'
AND NOT EXISTS (SELECT * FROM [dbo].[ProviderOrganizationView] PO WHERE PO.[OrganizationId] = O.[Id])
AND (@Name IS NULL OR O.[Name] LIKE @NameLikeSearch)
ORDER BY O.[CreationDate] DESC, O.[Id]
OFFSET @Skip ROWS
FETCH NEXT @Take ROWS ONLY
END
END
GO

View File

@ -0,0 +1,7 @@
-- Refresh Views
IF OBJECT_ID('[dbo].[OrganizationView]') IS NOT NULL
BEGIN
EXECUTE sp_refreshview N'[dbo].[OrganizationView]';
END
GO