From bae03feffecbef488cb52f5f5bc133dfdbbaa316 Mon Sep 17 00:00:00 2001 From: Justin Baur <19896123+justindbaur@users.noreply.github.com> Date: Mon, 29 Aug 2022 15:53:48 -0400 Subject: [PATCH] Revert filescoped (#2227) * Revert "Add git blame entry (#2226)" This reverts commit 239286737d15cb84a893703ee5a8b33a2d67ad3d. * Revert "Turn on file scoped namespaces (#2225)" This reverts commit 34fb4cca2aa78deb84d4cbc359992a7c6bba7ea5. --- .editorconfig | 3 - .git-blame-ignore-revs | 3 - README.md | 12 - .../Services/ProviderService.cs | 905 ++-- .../Utilities/ServiceCollectionExtensions.cs | 11 +- .../src/Scim/Context/IScimContext.cs | 25 +- .../src/Scim/Context/ScimContext.cs | 89 +- .../src/Scim/Controllers/InfoController.cs | 27 +- .../Scim/Controllers/v2/GroupsController.cs | 553 +- .../Scim/Controllers/v2/UsersController.cs | 511 +- .../src/Scim/Models/BaseScimGroupModel.cs | 19 +- .../src/Scim/Models/BaseScimModel.cs | 21 +- .../src/Scim/Models/BaseScimUserModel.cs | 85 +- .../src/Scim/Models/ScimErrorResponseModel.cs | 17 +- .../src/Scim/Models/ScimGroupRequestModel.cs | 41 +- .../src/Scim/Models/ScimGroupResponseModel.cs | 37 +- .../src/Scim/Models/ScimListResponseModel.cs | 21 +- .../src/Scim/Models/ScimMetaModel.cs | 19 +- .../src/Scim/Models/ScimPatchModel.cs | 25 +- .../src/Scim/Models/ScimUserRequestModel.cs | 13 +- .../src/Scim/Models/ScimUserResponseModel.cs | 43 +- bitwarden_license/src/Scim/Program.cs | 49 +- bitwarden_license/src/Scim/ScimSettings.cs | 7 +- bitwarden_license/src/Scim/Startup.cs | 185 +- .../Utilities/ApiKeyAuthenticationHandler.cs | 139 +- .../Utilities/ApiKeyAuthenticationOptions.cs | 9 +- .../src/Scim/Utilities/ScimConstants.cs | 15 +- .../Scim/Utilities/ScimContextMiddleware.cs | 27 +- .../src/Sso/Controllers/AccountController.cs | 1217 ++--- .../src/Sso/Controllers/HomeController.cs | 73 +- .../src/Sso/Controllers/InfoController.cs | 25 +- .../src/Sso/Controllers/MetadataController.cs | 105 +- .../src/Sso/Models/ErrorViewModel.cs | 37 +- .../src/Sso/Models/RedirectViewModel.cs | 9 +- .../src/Sso/Models/SamlEnvironment.cs | 9 +- .../Sso/Models/SsoPreValidateResponseModel.cs | 15 +- bitwarden_license/src/Sso/Program.cs | 47 +- bitwarden_license/src/Sso/Startup.cs | 263 +- .../src/Sso/Utilities/ClaimsExtensions.cs | 57 +- .../Utilities/DiscoveryResponseGenerator.cs | 47 +- .../Utilities/DynamicAuthenticationScheme.cs | 143 +- .../DynamicAuthenticationSchemeProvider.cs | 745 +-- .../Utilities/ExtendedOptionsMonitorCache.cs | 51 +- .../Utilities/IDynamicAuthenticationScheme.cs | 13 +- .../Utilities/IExtendedOptionsMonitorCache.cs | 9 +- .../OpenIdConnectOptionsExtensions.cs | 93 +- .../src/Sso/Utilities/OpenIdConnectScopes.cs | 115 +- .../Sso/Utilities/Saml2OptionsExtensions.cs | 179 +- .../src/Sso/Utilities/SamlClaimTypes.cs | 19 +- .../src/Sso/Utilities/SamlNameIdFormats.cs | 31 +- .../src/Sso/Utilities/SamlPropertyKeys.cs | 9 +- .../Utilities/ServiceCollectionExtensions.cs | 115 +- .../Utilities/SsoAuthenticationMiddleware.cs | 121 +- .../AutoFixture/ProviderUserFixtures.cs | 65 +- .../Services/ProviderServiceTests.cs | 975 ++-- src/Admin/AdminSettings.cs | 23 +- src/Admin/Controllers/ErrorController.cs | 29 +- src/Admin/Controllers/HomeController.cs | 171 +- src/Admin/Controllers/InfoController.cs | 25 +- src/Admin/Controllers/LoginController.cs | 133 +- src/Admin/Controllers/LogsController.cs | 143 +- .../Controllers/OrganizationsController.cs | 363 +- src/Admin/Controllers/ProvidersController.cs | 223 +- src/Admin/Controllers/ToolsController.cs | 925 ++-- src/Admin/Controllers/UsersController.cs | 171 +- .../AmazonSqsBlockIpHostedService.cs | 135 +- .../AzureQueueBlockIpHostedService.cs | 91 +- .../AzureQueueMailHostedService.cs | 129 +- .../HostedServices/BlockIpHostedService.cs | 243 +- .../DatabaseMigrationHostedService.cs | 89 +- src/Admin/Jobs/AliveJob.cs | 37 +- src/Admin/Jobs/DatabaseExpiredGrantsJob.cs | 33 +- .../Jobs/DatabaseExpiredSponsorshipsJob.cs | 47 +- src/Admin/Jobs/DatabaseRebuildlIndexesJob.cs | 33 +- src/Admin/Jobs/DatabaseUpdateStatisticsJob.cs | 39 +- src/Admin/Jobs/DeleteCiphersJob.cs | 47 +- src/Admin/Jobs/DeleteSendsJob.cs | 49 +- src/Admin/Jobs/JobsHostedService.cs | 163 +- src/Admin/Models/BillingInformationModel.cs | 13 +- src/Admin/Models/ChargeBraintreeModel.cs | 35 +- src/Admin/Models/CreateProviderModel.cs | 15 +- .../Models/CreateUpdateTransactionModel.cs | 133 +- src/Admin/Models/CursorPagedModel.cs | 15 +- src/Admin/Models/ErrorViewModel.cs | 11 +- src/Admin/Models/HomeModel.cs | 11 +- src/Admin/Models/LicenseModel.cs | 49 +- src/Admin/Models/LogModel.cs | 85 +- src/Admin/Models/LoginModel.cs | 19 +- src/Admin/Models/LogsModel.cs | 15 +- src/Admin/Models/OrganizationEditModel.cs | 273 +- src/Admin/Models/OrganizationViewModel.cs | 83 +- src/Admin/Models/OrganizationsModel.cs | 17 +- src/Admin/Models/PagedModel.cs | 17 +- src/Admin/Models/PromoteAdminModel.cs | 19 +- src/Admin/Models/ProviderEditModel.cs | 47 +- src/Admin/Models/ProviderViewModel.cs | 31 +- src/Admin/Models/ProvidersModel.cs | 17 +- src/Admin/Models/StripeSubscriptionsModel.cs | 65 +- src/Admin/Models/TaxRateAddEditModel.cs | 17 +- src/Admin/Models/TaxRatesModel.cs | 9 +- src/Admin/Models/UserEditModel.cs | 119 +- src/Admin/Models/UserViewModel.cs | 23 +- src/Admin/Models/UsersModel.cs | 11 +- src/Admin/Program.cs | 55 +- src/Admin/Startup.cs | 205 +- src/Admin/TagHelpers/ActivePageTagHelper.cs | 107 +- .../TagHelpers/OptionSelectedTagHelper.cs | 55 +- .../Controllers/AccountsBillingController.cs | 69 +- src/Api/Controllers/AccountsController.cs | 1601 +++--- src/Api/Controllers/CiphersController.cs | 1513 +++--- src/Api/Controllers/CollectionsController.cs | 411 +- src/Api/Controllers/DevicesController.cs | 195 +- .../Controllers/EmergencyAccessController.cs | 293 +- src/Api/Controllers/EventsController.cs | 275 +- src/Api/Controllers/FoldersController.cs | 133 +- src/Api/Controllers/GroupsController.cs | 225 +- src/Api/Controllers/HibpController.cs | 137 +- src/Api/Controllers/InfoController.cs | 53 +- .../Controllers/InstallationsController.cs | 55 +- src/Api/Controllers/LicensesController.cs | 111 +- src/Api/Controllers/MiscController.cs | 65 +- .../OrganizationConnectionsController.cs | 317 +- .../OrganizationExportController.cs | 91 +- .../OrganizationSponsorshipsController.cs | 307 +- .../OrganizationUsersController.cs | 733 +-- .../Controllers/OrganizationsController.cs | 1317 ++--- src/Api/Controllers/PlansController.cs | 47 +- src/Api/Controllers/PoliciesController.cs | 239 +- .../ProviderOrganizationsController.cs | 129 +- .../Controllers/ProviderUsersController.cs | 291 +- src/Api/Controllers/ProvidersController.cs | 129 +- src/Api/Controllers/PushController.cs | 183 +- ...ostedOrganizationSponsorshipsController.cs | 97 +- src/Api/Controllers/SendsController.cs | 543 +- src/Api/Controllers/SettingsController.cs | 63 +- src/Api/Controllers/SyncController.cs | 141 +- src/Api/Controllers/TwoFactorController.cs | 827 +-- src/Api/Controllers/UsersController.cs | 39 +- src/Api/Jobs/AliveJob.cs | 19 +- .../Jobs/EmergencyAccessNotificationJob.cs | 29 +- src/Api/Jobs/EmergencyAccessTimeoutJob.cs | 29 +- src/Api/Jobs/JobsHostedService.cs | 135 +- src/Api/Jobs/SelfHostedSponsorshipSyncJob.cs | 85 +- src/Api/Jobs/ValidateOrganizationsJob.cs | 29 +- src/Api/Jobs/ValidateUsersJob.cs | 29 +- src/Api/Models/CipherAttachmentModel.cs | 27 +- src/Api/Models/CipherCardModel.cs | 63 +- src/Api/Models/CipherFieldModel.cs | 53 +- src/Api/Models/CipherIdentityModel.cs | 159 +- src/Api/Models/CipherLoginModel.cs | 131 +- src/Api/Models/CipherPasswordHistoryModel.cs | 37 +- src/Api/Models/CipherSecureNoteModel.cs | 19 +- .../AssociationWithPermissionsBaseModel.cs | 29 +- src/Api/Models/Public/CollectionBaseModel.cs | 19 +- src/Api/Models/Public/GroupBaseModel.cs | 45 +- src/Api/Models/Public/MemberBaseModel.cs | 91 +- src/Api/Models/Public/PolicyBaseModel.cs | 25 +- .../AssociationWithPermissionsRequestModel.cs | 17 +- .../Request/CollectionUpdateRequestModel.cs | 23 +- .../Public/Request/EventFilterRequestModel.cs | 81 +- .../Request/GroupCreateUpdateRequestModel.cs | 39 +- .../Request/MemberCreateRequestModel.cs | 29 +- .../Request/MemberUpdateRequestModel.cs | 27 +- .../Request/OrganizationImportRequestModel.cs | 171 +- .../Request/PolicyUpdateRequestModel.cs | 27 +- .../Request/UpdateGroupIdsRequestModel.cs | 15 +- .../Request/UpdateMemberIdsRequestModel.cs | 15 +- ...AssociationWithPermissionsResponseModel.cs | 17 +- .../Response/CollectionResponseModel.cs | 59 +- .../Public/Response/ErrorResponseModel.cs | 135 +- .../Public/Response/EventResponseModel.cs | 163 +- .../Public/Response/GroupResponseModel.cs | 63 +- .../Models/Public/Response/IResponseModel.cs | 9 +- .../Public/Response/ListResponseModel.cs | 45 +- .../Public/Response/MemberResponseModel.cs | 155 +- .../Public/Response/PolicyResponseModel.cs | 69 +- .../Accounts/DeleteRecoverRequestModel.cs | 15 +- .../Request/Accounts/EmailRequestModel.cs | 29 +- .../Accounts/EmailTokenRequestModel.cs | 15 +- .../Accounts/ImportCiphersRequestModel.cs | 13 +- .../Request/Accounts/KdfRequestModel.cs | 39 +- .../OrganizationApiKeyRequestModel.cs | 9 +- .../Accounts/PasswordHintRequestModel.cs | 15 +- .../Request/Accounts/PasswordRequestModel.cs | 21 +- .../Request/Accounts/PremiumRequestModel.cs | 57 +- .../RegenerateTwoFactorRequestModel.cs | 17 +- .../SecretVerificationRequestModel.cs | 23 +- .../SetKeyConnectorKeyRequestModel.cs | 41 +- .../Accounts/SetPasswordRequestModel.cs | 51 +- .../Request/Accounts/StorageRequestModel.cs | 21 +- .../Accounts/TaxInfoUpdateRequestModel.cs | 23 +- .../Request/Accounts/UpdateKeyRequestModel.cs | 31 +- .../Accounts/UpdateProfileRequestModel.cs | 27 +- .../UpdateTempPasswordRequestModel.cs | 11 +- .../VerifyDeleteRecoverRequestModel.cs | 15 +- .../Accounts/VerifyEmailRequestModel.cs | 15 +- .../Request/Accounts/VerifyOTPRequestModel.cs | 11 +- .../Models/Request/AttachmentRequestModel.cs | 15 +- .../Request/BitPayInvoiceRequestModel.cs | 103 +- .../Request/CipherPartialRequestModel.cs | 13 +- src/Api/Models/Request/CipherRequestModel.cs | 625 +-- .../Models/Request/CollectionRequestModel.cs | 43 +- src/Api/Models/Request/DeviceRequestModels.cs | 73 +- .../Request/DeviceVerificationRequestModel.cs | 19 +- .../Request/EmergencyAccessRequstModels.cs | 75 +- src/Api/Models/Request/FolderRequestModel.cs | 41 +- src/Api/Models/Request/GroupRequestModel.cs | 47 +- .../Models/Request/IapCheckRequestModel.cs | 21 +- .../Request/InstallationRequestModel.cs | 29 +- src/Api/Models/Request/LicenseRequestModel.cs | 11 +- .../ImportOrganizationCiphersRequestModel.cs | 13 +- .../ImportOrganizationUsersRequestModel.cs | 105 +- .../OrganizationConnectionRequestModel.cs | 75 +- .../OrganizationCreateLicenseRequestModel.cs | 19 +- .../OrganizationCreateRequestModel.cs | 169 +- .../OrganizationKeysRequestModel.cs | 85 +- .../OrganizationSeatRequestModel.cs | 19 +- ...ganizationSponsorshipCreateRequestModel.cs | 23 +- ...ganizationSponsorshipRedeemRequestModel.cs | 15 +- .../OrganizationSsoRequestModel.cs | 411 +- ...anizationSubscriptionUpdateRequestModel.cs | 13 +- .../OrganizationTaxInfoUpdateRequestModel.cs | 17 +- .../OrganizationUpdateRequestModel.cs | 53 +- .../OrganizationUpgradeRequestModel.cs | 59 +- .../OrganizationUserRequestModels.cs | 167 +- ...ganizationUserResetPasswordRequestModel.cs | 17 +- .../OrganizationVerifyBankRequestModel.cs | 19 +- src/Api/Models/Request/PaymentRequestModel.cs | 15 +- src/Api/Models/Request/PolicyRequestModel.cs | 41 +- .../ProviderOrganizationAddRequestModel.cs | 15 +- .../ProviderOrganizationCreateRequestModel.cs | 19 +- .../Providers/ProviderSetupRequestModel.cs | 45 +- .../Providers/ProviderUpdateRequestModel.cs | 41 +- .../Providers/ProviderUserRequestModels.cs | 111 +- .../Request/SelectionReadOnlyRequestModel.cs | 29 +- .../Models/Request/SendAccessRequestModel.cs | 11 +- src/Api/Models/Request/SendRequestModel.cs | 229 +- .../Models/Request/TwoFactorRequestModels.cs | 445 +- .../Request/UpdateDomainsRequestModel.cs | 23 +- .../Models/Response/ApiKeyResponseModel.cs | 41 +- .../Response/AttachmentResponseModel.cs | 77 +- .../AttachmentUploadDataResponseModel.cs | 19 +- .../Response/BillingHistoryResponseModel.cs | 19 +- .../Response/BillingPaymentResponseModel.cs | 21 +- .../Models/Response/BillingResponseModel.cs | 131 +- .../Models/Response/CipherResponseModel.cs | 261 +- .../Response/CollectionResponseModel.cs | 69 +- .../Models/Response/DeviceResponseModel.cs | 37 +- .../DeviceVerificationResponseModel.cs | 21 +- .../Models/Response/DomainsResponseModel.cs | 81 +- .../Response/EmergencyAccessResponseModel.cs | 159 +- src/Api/Models/Response/EventResponseModel.cs | 81 +- .../Models/Response/FolderResponseModel.cs | 29 +- src/Api/Models/Response/GroupResponseModel.cs | 53 +- .../Response/InstallationResponseModel.cs | 25 +- src/Api/Models/Response/KeysResponseModel.cs | 29 +- src/Api/Models/Response/ListResponseModel.cs | 21 +- .../OrganizationExportResponseModel.cs | 17 +- ...anizationApiKeyInformationResponseModel.cs | 19 +- ...ganizationAutoEnrollStatusResponseModel.cs | 19 +- .../OrganizationConnectionResponseModel.cs | 39 +- .../OrganizationKeysResponseModel.cs | 23 +- .../OrganizationResponseModel.cs | 187 +- ...ationSponsorshipSyncStatusResponseModel.cs | 17 +- .../OrganizationSsoResponseModel.cs | 57 +- .../OrganizationUserResponseModel.cs | 215 +- .../Models/Response/PaymentResponseModel.cs | 19 +- src/Api/Models/Response/PlanResponseModel.cs | 181 +- .../Models/Response/PolicyResponseModel.cs | 43 +- .../ProfileOrganizationResponseModel.cs | 177 +- ...rofileProviderOrganizationResponseModel.cs | 71 +- .../Models/Response/ProfileResponseModel.cs | 101 +- .../Providers/ProfileProviderResponseModel.cs | 49 +- .../ProviderOrganizationResponseModel.cs | 115 +- .../Providers/ProviderResponseModel.cs | 51 +- .../Providers/ProviderUserResponseModel.cs | 127 +- .../SelectionReadOnlyResponseModel.cs | 27 +- .../Response/SendAccessResponseModel.cs | 75 +- .../SendFileDownloadDataResponseModel.cs | 13 +- .../SendFileUploadDataResponseModel.cs | 15 +- src/Api/Models/Response/SendResponseModel.cs | 113 +- .../Response/SubscriptionResponseModel.cs | 173 +- src/Api/Models/Response/SyncResponseModel.cs | 73 +- .../Models/Response/TaxInfoResponseModel.cs | 51 +- .../Models/Response/TaxRateResponseModel.cs | 37 +- .../TwoFactorAuthenticatorResponseModel.cs | 45 +- .../TwoFactor/TwoFactorDuoResponseModel.cs | 87 +- .../TwoFactor/TwoFactorEmailResponseModel.cs | 41 +- .../TwoFactorProviderResponseModel.cs | 69 +- .../TwoFactorRecoverResponseModel.cs | 21 +- .../TwoFactorWebAuthnResponseModel.cs | 53 +- .../TwoFactorYubiKeyResponseModel.cs | 99 +- .../Models/Response/UserKeyResponseModel.cs | 21 +- src/Api/Models/SendFileModel.cs | 37 +- src/Api/Models/SendTextModel.cs | 27 +- src/Api/Program.cs | 69 +- .../Controllers/CollectionsController.cs | 195 +- .../Public/Controllers/EventsController.cs | 99 +- .../Public/Controllers/GroupsController.cs | 309 +- .../Public/Controllers/MembersController.cs | 397 +- .../Controllers/OrganizationController.cs | 81 +- .../Public/Controllers/PoliciesController.cs | 163 +- src/Api/Startup.cs | 365 +- .../Utilities/ApiExplorerGroupConvention.cs | 13 +- src/Api/Utilities/ApiHelpers.cs | 109 +- .../DisableFormValueModelBindingAttribute.cs | 25 +- .../ExceptionHandlerFilterAttribute.cs | 175 +- .../ModelStateValidationFilterAttribute.cs | 29 +- src/Api/Utilities/MultipartFormDataHelper.cs | 187 +- .../PublicApiControllersModelConvention.cs | 17 +- src/Api/Utilities/SecretsManagerAttribute.cs | 28 +- .../Utilities/ServiceCollectionExtensions.cs | 117 +- src/Billing/BillingSettings.cs | 37 +- src/Billing/Constants/HandledStripeWebhook.cs | 23 +- src/Billing/Controllers/AppleController.cs | 85 +- src/Billing/Controllers/BitPayController.cs | 351 +- .../Controllers/FreshdeskController.cs | 257 +- .../Controllers/FreshsalesController.cs | 395 +- src/Billing/Controllers/InfoController.cs | 25 +- src/Billing/Controllers/LoginController.cs | 81 +- src/Billing/Controllers/PayPalController.cs | 409 +- src/Billing/Controllers/StripeController.cs | 1495 +++--- src/Billing/Jobs/JobsHostedService.cs | 65 +- src/Billing/Models/BitPayEventModel.cs | 45 +- src/Billing/Models/FreshdeskWebhookModel.cs | 19 +- src/Billing/Models/LoginModel.cs | 13 +- src/Billing/Program.cs | 57 +- src/Billing/Startup.cs | 167 +- src/Billing/Utilities/PayPalIpnClient.cs | 273 +- src/Core/Constants.cs | 35 +- .../Context/CurrentContentOrganization.cs | 27 +- src/Core/Context/CurrentContentProvider.cs | 27 +- src/Core/Context/CurrentContext.cs | 827 +-- src/Core/Context/ICurrentContext.cs | 107 +- src/Core/Entities/Cipher.cs | 161 +- src/Core/Entities/Collection.cs | 27 +- src/Core/Entities/CollectionCipher.cs | 11 +- src/Core/Entities/CollectionGroup.cs | 15 +- src/Core/Entities/CollectionUser.cs | 15 +- src/Core/Entities/Device.cs | 35 +- src/Core/Entities/EmergencyAccess.cs | 73 +- src/Core/Entities/Event.cs | 89 +- src/Core/Entities/Folder.cs | 23 +- src/Core/Entities/Grant.cs | 39 +- src/Core/Entities/Group.cs | 31 +- src/Core/Entities/GroupUser.cs | 11 +- src/Core/Entities/IReferenceable.cs | 13 +- src/Core/Entities/IRevisable.cs | 11 +- src/Core/Entities/IStorable.cs | 15 +- src/Core/Entities/IStorableSubscriber.cs | 9 +- src/Core/Entities/ISubscriber.cs | 27 +- src/Core/Entities/ITableObject.cs | 11 +- src/Core/Entities/Installation.cs | 27 +- src/Core/Entities/Organization.cs | 339 +- src/Core/Entities/OrganizationApiKey.cs | 25 +- src/Core/Entities/OrganizationConnection.cs | 65 +- src/Core/Entities/OrganizationSponsorship.cs | 37 +- src/Core/Entities/OrganizationUser.cs | 43 +- src/Core/Entities/Policy.cs | 43 +- src/Core/Entities/Provider/Provider.cs | 43 +- .../Entities/Provider/ProviderOrganization.cs | 29 +- src/Core/Entities/Provider/ProviderUser.cs | 35 +- src/Core/Entities/Role.cs | 15 +- src/Core/Entities/Send.cs | 45 +- src/Core/Entities/SsoConfig.cs | 43 +- src/Core/Entities/SsoUser.cs | 27 +- src/Core/Entities/TaxRate.cs | 35 +- src/Core/Entities/Transaction.cs | 41 +- src/Core/Entities/User.cs | 325 +- src/Core/Enums/ApplicationCacheMessageType.cs | 11 +- src/Core/Enums/BitwardenClient.cs | 21 +- src/Core/Enums/CipherRepromptType.cs | 11 +- src/Core/Enums/CipherStateAction.cs | 13 +- src/Core/Enums/CipherType.cs | 19 +- src/Core/Enums/DeviceType.cs | 91 +- src/Core/Enums/EmergencyAccessStatusType.cs | 17 +- src/Core/Enums/EmergencyAccessType.cs | 11 +- src/Core/Enums/EncryptionType.cs | 21 +- src/Core/Enums/EventType.cs | 137 +- src/Core/Enums/FieldType.cs | 15 +- src/Core/Enums/FileUploadType.cs | 11 +- src/Core/Enums/GatewayType.cs | 35 +- src/Core/Enums/GlobalEquivalentDomainsType.cs | 185 +- src/Core/Enums/KdfType.cs | 9 +- src/Core/Enums/LicenseType.cs | 11 +- src/Core/Enums/OrganizationApiKeyType.cs | 13 +- src/Core/Enums/OrganizationConnectionType.cs | 11 +- src/Core/Enums/OrganizationUserStatusType.cs | 15 +- src/Core/Enums/OrganizationUserType.cs | 17 +- src/Core/Enums/PaymentMethodType.cs | 47 +- src/Core/Enums/PlanSponsorshipType.cs | 11 +- src/Core/Enums/PlanType.cs | 55 +- src/Core/Enums/PolicyType.cs | 29 +- src/Core/Enums/ProductType.cs | 23 +- src/Core/Enums/Provider/ProviderStatusType.cs | 11 +- .../Enums/Provider/ProviderUserStatusType.cs | 13 +- src/Core/Enums/Provider/ProviderUserType.cs | 11 +- src/Core/Enums/PushType.cs | 37 +- src/Core/Enums/ReferenceEventSource.cs | 15 +- src/Core/Enums/ReferenceEventType.cs | 79 +- src/Core/Enums/Saml2BindingType.cs | 11 +- src/Core/Enums/Saml2NameIdFormat.cs | 25 +- src/Core/Enums/Saml2SigningBehavior.cs | 13 +- src/Core/Enums/ScimProviderType.cs | 21 +- src/Core/Enums/SecureNoteType.cs | 9 +- src/Core/Enums/SendType.cs | 11 +- src/Core/Enums/SsoType.cs | 11 +- src/Core/Enums/SupportedDatabaseProviders.cs | 13 +- src/Core/Enums/TransactionType.cs | 27 +- src/Core/Enums/TwoFactorProviderType.cs | 23 +- src/Core/Enums/UriMatchType.cs | 19 +- src/Core/Exceptions/BadRequestException.cs | 41 +- src/Core/Exceptions/GatewayException.cs | 13 +- src/Core/Exceptions/InvalidEmailException.cs | 11 +- .../InvalidGatewayCustomerIdException.cs | 11 +- src/Core/Exceptions/NotFoundException.cs | 7 +- .../ApplicationCacheHostedService.cs | 159 +- .../IpRateLimitSeedStartupService.cs | 65 +- .../Identity/AuthenticatorTokenProvider.cs | 55 +- ...stomIdentityServiceCollectionExtensions.cs | 77 +- src/Core/Identity/DuoWebTokenProvider.cs | 117 +- src/Core/Identity/EmailTokenProvider.cs | 119 +- .../IOrganizationTwoFactorTokenProvider.cs | 13 +- .../LowerInvariantLookupNormalizer.cs | 27 +- .../OrganizationDuoWebTokenProvider.cs | 99 +- .../Identity/PasswordlessSignInManager.cs | 139 +- .../ReadOnlyDatabaseIdentityUserStore.cs | 53 +- .../Identity/ReadOnlyEnvIdentityUserStore.cs | 91 +- .../Identity/ReadOnlyIdentityUserStore.cs | 191 +- src/Core/Identity/RoleStore.cs | 89 +- .../TwoFactorRememberTokenProvider.cs | 23 +- src/Core/Identity/UserStore.cs | 285 +- src/Core/Identity/WebAuthnTokenProvider.cs | 233 +- src/Core/Identity/YubicoOtpTokenProvider.cs | 97 +- src/Core/IdentityServer/ApiClient.cs | 127 +- src/Core/IdentityServer/ApiResources.cs | 55 +- src/Core/IdentityServer/ApiScopes.cs | 25 +- .../IdentityServer/AuthorizationCodeStore.cs | 57 +- .../IdentityServer/BaseRequestValidator.cs | 1077 ++-- src/Core/IdentityServer/ClientStore.cs | 305 +- ...onfigureOpenIdConnectDistributedOptions.cs | 69 +- .../CustomTokenRequestValidator.cs | 237 +- .../CustomValidatorRequestContext.cs | 13 +- .../DistributedCacheCookieManager.cs | 105 +- .../DistributedCacheTicketDataFormatter.cs | 93 +- .../IdentityServer/MemoryCacheTicketStore.cs | 77 +- src/Core/IdentityServer/OidcIdentityClient.cs | 33 +- .../IdentityServer/PersistedGrantStore.cs | 143 +- src/Core/IdentityServer/ProfileService.cs | 133 +- .../IdentityServer/RedisCacheTicketStore.cs | 89 +- .../ResourceOwnerPasswordValidator.cs | 279 +- src/Core/IdentityServer/StaticClientStore.cs | 29 +- src/Core/IdentityServer/TokenRetrieval.cs | 39 +- .../IdentityServer/VaultCorsPolicyService.cs | 23 +- src/Core/Jobs/BaseJob.cs | 37 +- src/Core/Jobs/BaseJobsHostedService.cs | 225 +- src/Core/Jobs/JobFactory.cs | 33 +- src/Core/Jobs/JobListener.cs | 55 +- .../Api/Request/Accounts/KeysRequestModel.cs | 33 +- .../Request/Accounts/PreloginRequestModel.cs | 15 +- .../Request/Accounts/RegisterRequestModel.cs | 125 +- .../Api/Request/ICaptchaProtectedModel.cs | 9 +- .../OrganizationSponsorshipRequestModel.cs | 87 +- ...OrganizationSponsorshipSyncRequestModel.cs | 57 +- .../Request/PushRegistrationRequestModel.cs | 27 +- .../Api/Request/PushSendRequestModel.cs | 31 +- .../Api/Request/PushUpdateRequestModel.cs | 29 +- .../Accounts/PreloginResponseModel.cs | 19 +- .../Models/Api/Response/ErrorResponseModel.cs | 129 +- .../OrganizationSponsorshipResponseModel.cs | 75 +- ...rganizationSponsorshipSyncResponseModel.cs | 39 +- src/Core/Models/Api/Response/ResponseModel.cs | 21 +- .../Models/Business/AppleReceiptStatus.cs | 235 +- src/Core/Models/Business/BillingInfo.cs | 249 +- src/Core/Models/Business/CaptchaResponse.cs | 15 +- src/Core/Models/Business/ExpiringToken.cs | 19 +- src/Core/Models/Business/ILicense.cs | 33 +- src/Core/Models/Business/ImportedGroup.cs | 11 +- .../Business/ImportedOrganizationUser.cs | 11 +- .../Models/Business/OrganizationLicense.cs | 545 +- .../Models/Business/OrganizationSignup.cs | 23 +- .../Models/Business/OrganizationUpgrade.cs | 23 +- .../Models/Business/OrganizationUserInvite.cs | 35 +- .../Business/Provider/ProviderUserInvite.cs | 51 +- src/Core/Models/Business/ReferenceEvent.cs | 103 +- .../Business/SubscriptionCreateOptions.cs | 131 +- src/Core/Models/Business/SubscriptionInfo.cs | 135 +- .../Models/Business/SubscriptionUpdate.cs | 369 +- src/Core/Models/Business/TaxInfo.cs | 283 +- .../EmergencyAccessInviteTokenable.cs | 53 +- .../Business/Tokenables/HCaptchaTokenable.cs | 61 +- .../OrganizationSponsorshipOfferTokenable.cs | 87 +- .../Business/Tokenables/SsoTokenable.cs | 63 +- src/Core/Models/Business/UserLicense.cs | 285 +- .../Models/Data/AttachmentResponseData.cs | 15 +- src/Core/Models/Data/CipherAttachment.cs | 57 +- src/Core/Models/Data/CipherCardData.cs | 21 +- src/Core/Models/Data/CipherData.cs | 17 +- src/Core/Models/Data/CipherDetails.cs | 15 +- src/Core/Models/Data/CipherFieldData.cs | 17 +- src/Core/Models/Data/CipherIdentityData.cs | 45 +- src/Core/Models/Data/CipherLoginData.cs | 45 +- .../Models/Data/CipherOrganizationDetails.cs | 9 +- .../Models/Data/CipherPasswordHistoryData.cs | 13 +- src/Core/Models/Data/CipherSecureNoteData.cs | 11 +- src/Core/Models/Data/CollectionDetails.cs | 11 +- src/Core/Models/Data/DictionaryEntity.cs | 205 +- .../Models/Data/EmergencyAccessDetails.cs | 15 +- src/Core/Models/Data/EmergencyAccessNotify.cs | 13 +- .../Models/Data/EmergencyAccessViewData.cs | 11 +- src/Core/Models/Data/EventMessage.cs | 55 +- src/Core/Models/Data/EventTableEntity.cs | 253 +- src/Core/Models/Data/GroupWithCollections.cs | 9 +- src/Core/Models/Data/IEvent.cs | 39 +- .../Models/Data/InstallationDeviceEntity.cs | 49 +- .../Data/Organizations/OrganizationAbility.cs | 57 +- .../OrganizationConnectionData.cs | 43 +- .../OrganizationSponsorshipData.cs | 47 +- .../OrganizationSponsorshipSyncData.cs | 13 +- .../OrganizationUserInviteData.cs | 17 +- .../OrganizationUserOrganizationDetails.cs | 81 +- .../OrganizationUserPublicKey.cs | 13 +- .../OrganizationUserResetPasswordDetails.cs | 47 +- .../OrganizationUserUserDetails.cs | 87 +- .../OrganizationUserWithCollections.cs | 9 +- .../Policies/IPolicyDataModel.cs | 7 +- .../Policies/ResetPasswordDataModel.cs | 11 +- .../Policies/SendOptionsPolicyData.cs | 11 +- src/Core/Models/Data/PageOptions.cs | 11 +- src/Core/Models/Data/PagedResult.cs | 11 +- src/Core/Models/Data/Permissions.cs | 79 +- .../Models/Data/Provider/ProviderAbility.cs | 27 +- ...ProviderOrganizationOrganizationDetails.cs | 29 +- .../ProviderUserOrganizationDetails.cs | 65 +- .../Provider/ProviderUserProviderDetails.cs | 27 +- .../Data/Provider/ProviderUserPublicKey.cs | 13 +- .../Data/Provider/ProviderUserUserDetails.cs | 23 +- src/Core/Models/Data/SelectionReadOnly.cs | 13 +- src/Core/Models/Data/SendData.cs | 23 +- src/Core/Models/Data/SendFileData.cs | 33 +- src/Core/Models/Data/SendTextData.cs | 25 +- src/Core/Models/Data/SsoConfigurationData.cs | 227 +- src/Core/Models/Data/UserKdfInformation.cs | 11 +- src/Core/Models/IExternal.cs | 9 +- src/Core/Models/ITwoFactorProvidersUser.cs | 15 +- src/Core/Models/Mail/AddedCreditViewModel.cs | 9 +- .../Mail/AdminResetPasswordViewModel.cs | 11 +- src/Core/Models/Mail/BaseMailModel.cs | 33 +- .../Models/Mail/ChangeEmailExistsViewModel.cs | 11 +- src/Core/Models/Mail/EmailTokenViewModel.cs | 9 +- .../Mail/EmergencyAccessAcceptedViewModel.cs | 9 +- .../Mail/EmergencyAccessApprovedViewModel.cs | 9 +- .../Mail/EmergencyAccessConfirmedViewModel.cs | 9 +- .../Mail/EmergencyAccessInvitedViewModel.cs | 17 +- ...mergencyAccessRecoveryTimedOutViewModel.cs | 11 +- .../Mail/EmergencyAccessRecoveryViewModel.cs | 13 +- .../Mail/EmergencyAccessRejectedViewModel.cs | 9 +- .../Models/Mail/FailedAuthAttemptsModel.cs | 9 +- .../FamiliesForEnterpriseOfferViewModel.cs | 29 +- ...EnterpriseSponsorshipRevertingViewModel.cs | 9 +- src/Core/Models/Mail/IMailQueueMessage.cs | 19 +- .../Models/Mail/InvoiceUpcomingViewModel.cs | 15 +- .../Models/Mail/LicenseExpiredViewModel.cs | 11 +- src/Core/Models/Mail/MailMessage.cs | 21 +- src/Core/Models/Mail/MailQueueMessage.cs | 43 +- .../Mail/MasterPasswordHintViewModel.cs | 9 +- .../Models/Mail/NewDeviceLoggedInModel.cs | 17 +- .../OrganizationSeatsAutoscaledViewModel.cs | 13 +- .../OrganizationSeatsMaxReachedViewModel.cs | 11 +- .../Mail/OrganizationUserAcceptedViewModel.cs | 13 +- .../OrganizationUserConfirmedViewModel.cs | 9 +- .../Mail/OrganizationUserInvitedViewModel.cs | 37 +- ...nUserRemovedForPolicySingleOrgViewModel.cs | 9 +- ...ionUserRemovedForPolicyTwoStepViewModel.cs | 9 +- .../Models/Mail/PasswordlessSignInModel.cs | 9 +- .../Models/Mail/PaymentFailedViewModel.cs | 11 +- .../Provider/ProviderSetupInviteViewModel.cs | 23 +- .../ProviderUserConfirmedViewModel.cs | 9 +- .../Provider/ProviderUserInvitedViewModel.cs | 35 +- .../Provider/ProviderUserRemovedViewModel.cs | 9 +- src/Core/Models/Mail/RecoverTwoFactorModel.cs | 15 +- .../Mail/UpdateTempPasswordViewModel.cs | 9 +- src/Core/Models/Mail/VerifyDeleteModel.cs | 25 +- src/Core/Models/Mail/VerifyEmailModel.cs | 19 +- .../BillingSyncConfig.cs | 11 +- .../ScimConfig.cs | 13 +- src/Core/Models/PushNotification.cs | 73 +- src/Core/Models/StaticStore/Plan.cs | 93 +- src/Core/Models/StaticStore/SponsoredPlan.cs | 17 +- .../Stripe/StripeSubscriptionListOptions.cs | 79 +- src/Core/Models/TwoFactorProvider.cs | 97 +- .../GetOrganizationApiKeyCommand.cs | 55 +- .../IGetOrganizationApiKeyCommand.cs | 9 +- .../IRotateOrganizationApiKeyCommand.cs | 9 +- .../RotateOrganizationApiKeyCommand.cs | 29 +- .../CreateOrganizationConnectionCommand.cs | 23 +- .../DeleteOrganizationConnectionCommand.cs | 23 +- .../ICreateOrganizationConnectionCommand.cs | 9 +- .../IDeleteOrganizationConnectionCommand.cs | 9 +- .../IUpdateOrganizationConnectionCommand.cs | 9 +- .../UpdateOrganizationConnectionCommand.cs | 41 +- ...OrganizationServiceCollectionExtensions.cs | 103 +- .../CancelSponsorshipCommand.cs | 49 +- .../Cloud/CloudRevokeSponsorshipCommand.cs | 37 +- .../Cloud/CloudSyncSponsorshipsCommand.cs | 197 +- .../OrganizationSponsorshipRenewCommand.cs | 33 +- .../Cloud/RemoveSponsorshipCommand.cs | 27 +- .../Cloud/SendSponsorshipOfferCommand.cs | 93 +- .../Cloud/SetUpSponsorshipCommand.cs | 99 +- .../Cloud/ValidateBillingSyncKeyCommand.cs | 49 +- .../Cloud/ValidateRedemptionTokenCommand.cs | 41 +- .../Cloud/ValidateSponsorshipCommand.cs | 189 +- .../CreateSponsorshipCommand.cs | 123 +- .../Interfaces/ICreateSponsorshipCommand.cs | 11 +- .../IOrganizationSponsorshipRenewCommand.cs | 9 +- .../Interfaces/IRemoveSponsorshipCommand.cs | 9 +- .../Interfaces/IRevokeSponsorshipCommand.cs | 9 +- .../ISendSponsorshipOfferCommand.cs | 15 +- .../Interfaces/ISetUpSponsorshipCommand.cs | 11 +- .../ISyncOrganizationSponsorshipsCommand.cs | 17 +- .../IValidateBillingSyncKeyCommand.cs | 9 +- .../IValidateRedemptionTokenCommand.cs | 9 +- .../Interfaces/IValidateSponsorshipCommand.cs | 9 +- .../SelfHostedRevokeSponsorshipCommand.cs | 37 +- .../SelfHostedSyncSponsorshipsCommand.cs | 203 +- src/Core/Repositories/ICipherRepository.cs | 65 +- .../ICollectionCipherRepository.cs | 21 +- .../Repositories/ICollectionRepository.cs | 29 +- src/Core/Repositories/IDeviceRepository.cs | 17 +- .../IEmergencyAccessRepository.cs | 19 +- src/Core/Repositories/IEventRepository.cs | 35 +- src/Core/Repositories/IFolderRepository.cs | 11 +- src/Core/Repositories/IGrantRepository.cs | 17 +- src/Core/Repositories/IGroupRepository.cs | 25 +- .../IInstallationDeviceRepository.cs | 13 +- .../Repositories/IInstallationRepository.cs | 7 +- .../Repositories/IMaintenanceRepository.cs | 17 +- src/Core/Repositories/IMetaDataRepository.cs | 17 +- .../IOrganizationApiKeyRepository.cs | 9 +- .../IOrganizationConnectionRepository.cs | 11 +- .../Repositories/IOrganizationRepository.cs | 19 +- .../IOrganizationSponsorshipRepository.cs | 23 +- .../IOrganizationUserRepository.cs | 69 +- src/Core/Repositories/IPolicyRepository.cs | 21 +- .../IProviderOrganizationRepository.cs | 11 +- src/Core/Repositories/IProviderRepository.cs | 11 +- .../Repositories/IProviderUserRepository.cs | 31 +- src/Core/Repositories/IRepository.cs | 17 +- src/Core/Repositories/ISendRepository.cs | 11 +- src/Core/Repositories/ISsoConfigRepository.cs | 13 +- src/Core/Repositories/ISsoUserRepository.cs | 11 +- src/Core/Repositories/ITaxRateRepository.cs | 15 +- .../Repositories/ITransactionRepository.cs | 13 +- src/Core/Repositories/IUserRepository.cs | 27 +- .../Noop/InstallationDeviceRepository.cs | 27 +- .../Repositories/Noop/MetaDataRepository.cs | 43 +- .../TableStorage/EventRepository.cs | 287 +- .../InstallationDeviceRepository.cs | 105 +- .../TableStorage/MetaDataRepository.cs | 141 +- src/Core/Resources/SharedResources.cs | 7 +- src/Core/Services/IAppleIapService.cs | 13 +- src/Core/Services/IApplicationCacheService.cs | 17 +- .../Services/IAttachmentStorageService.cs | 33 +- src/Core/Services/IBlockIpService.cs | 9 +- .../Services/ICaptchaValidationService.cs | 19 +- src/Core/Services/ICipherService.cs | 75 +- src/Core/Services/ICollectionService.cs | 15 +- src/Core/Services/IDeviceService.cs | 13 +- src/Core/Services/IEmergencyAccessService.cs | 41 +- src/Core/Services/IEventService.cs | 31 +- src/Core/Services/IEventWriteService.cs | 11 +- src/Core/Services/IGroupService.cs | 13 +- src/Core/Services/II18nService.cs | 15 +- src/Core/Services/ILicensingService.cs | 21 +- src/Core/Services/IMailDeliveryService.cs | 9 +- src/Core/Services/IMailEnqueuingService.cs | 11 +- src/Core/Services/IMailService.cs | 101 +- src/Core/Services/IOrganizationService.cs | 121 +- src/Core/Services/IPaymentService.cs | 61 +- src/Core/Services/IPolicyService.cs | 11 +- src/Core/Services/IProviderService.cs | 39 +- src/Core/Services/IPushNotificationService.cs | 41 +- src/Core/Services/IPushRegistrationService.cs | 17 +- src/Core/Services/IReferenceEventService.cs | 9 +- src/Core/Services/ISendService.cs | 23 +- src/Core/Services/ISendStorageService.cs | 23 +- src/Core/Services/ISsoConfigService.cs | 9 +- src/Core/Services/IStripeAdapter.cs | 71 +- src/Core/Services/IStripeSyncService.cs | 9 +- src/Core/Services/IUserService.cs | 143 +- .../AmazonSesMailDeliveryService.cs | 227 +- .../AmazonSqsBlockIpService.cs | 127 +- .../Implementations/AppleIapService.cs | 215 +- .../AzureAttachmentStorageService.cs | 429 +- .../AzureQueueBlockIpService.cs | 47 +- .../AzureQueueEventWriteService.cs | 17 +- .../Implementations/AzureQueueMailService.cs | 23 +- .../AzureQueuePushNotificationService.cs | 311 +- .../AzureQueueReferenceEventService.cs | 59 +- .../Implementations/AzureQueueService.cs | 109 +- .../AzureSendFileStorageService.cs | 205 +- .../BaseIdentityClientService.cs | 335 +- .../BlockingMailQueueService.cs | 21 +- .../Services/Implementations/CipherService.cs | 1777 +++---- .../Implementations/CollectionService.cs | 211 +- .../Services/Implementations/DeviceService.cs | 67 +- .../Implementations/EmergencyAccessService.cs | 761 +-- .../Services/Implementations/EventService.cs | 547 +- .../Services/Implementations/GroupService.cs | 117 +- .../HCaptchaValidationService.cs | 193 +- .../Implementations/HandlebarsMailService.cs | 1705 +++---- .../Services/Implementations/I18nService.cs | 49 +- .../Implementations/I18nViewLocalizer.cs | 43 +- .../InMemoryApplicationCacheService.cs | 159 +- ...MemoryServiceBusApplicationCacheService.cs | 95 +- .../Implementations/LicensingService.cs | 409 +- .../LocalAttachmentStorageService.cs | 327 +- .../LocalSendStorageService.cs | 175 +- .../MailKitSmtpMailDeliveryService.cs | 145 +- .../MultiServiceMailDeliveryService.cs | 53 +- .../MultiServicePushNotificationService.cs | 287 +- .../NotificationHubPushNotificationService.cs | 383 +- .../NotificationHubPushRegistrationService.cs | 311 +- ...NotificationsApiPushNotificationService.cs | 329 +- .../Implementations/OrganizationService.cs | 4437 +++++++++-------- .../Services/Implementations/PolicyService.cs | 273 +- .../RelayPushNotificationService.cs | 385 +- .../RelayPushRegistrationService.cs | 101 +- .../RepositoryEventWriteService.cs | 33 +- .../SendGridMailDeliveryService.cs | 173 +- .../Services/Implementations/SendService.cs | 609 +-- .../Implementations/SsoConfigService.cs | 163 +- .../Services/Implementations/StripeAdapter.cs | 405 +- .../Implementations/StripePaymentService.cs | 3189 ++++++------ .../Implementations/StripeSyncService.cs | 39 +- .../Services/Implementations/UserService.cs | 2799 +++++------ .../NoopAttachmentStorageService.cs | 103 +- .../NoopImplementations/NoopBlockIpService.cs | 13 +- .../NoopCaptchaValidationService.cs | 21 +- .../NoopImplementations/NoopEventService.cs | 103 +- .../NoopEventWriteService.cs | 19 +- .../NoopLicensingService.cs | 87 +- .../NoopMailDeliveryService.cs | 11 +- .../NoopImplementations/NoopMailService.cs | 373 +- .../NoopProviderService.cs | 35 +- .../NoopPushNotificationService.cs | 135 +- .../NoopPushRegistrationService.cs | 37 +- .../NoopReferenceEventService.cs | 11 +- .../NoopSendFileStorageService.cs | 63 +- src/Core/Settings/GlobalSettings.cs | 945 ++-- src/Core/Settings/IBaseServiceUriSettings.cs | 37 +- .../Settings/IConnectionStringSettings.cs | 10 +- src/Core/Settings/IFileStorageSettings.cs | 13 +- src/Core/Settings/IGlobalSettings.cs | 33 +- src/Core/Settings/IInstallationSettings.cs | 15 +- src/Core/Settings/ISsoSettings.cs | 11 +- src/Core/Settings/ITwoFactorAuthSettings.cs | 9 +- src/Core/Sso/SamlSigningAlgorithms.cs | 29 +- src/Core/Tokens/BadTokenException.cs | 15 +- src/Core/Tokens/DataProtectorTokenFactory.cs | 81 +- src/Core/Tokens/ExpiringTokenable.cs | 15 +- src/Core/Tokens/IBillingSyncTokenable.cs | 11 +- src/Core/Tokens/IDataProtectorTokenFactory.cs | 15 +- src/Core/Tokens/Token.cs | 53 +- src/Core/Tokens/Tokenable.cs | 25 +- src/Core/Utilities/BillingHelpers.cs | 83 +- src/Core/Utilities/BitPayClient.cs | 39 +- .../Utilities/CaptchaProtectedAttribute.cs | 37 +- src/Core/Utilities/ClaimsExtensions.cs | 11 +- src/Core/Utilities/CoreHelpers.cs | 1529 +++--- .../Utilities/CurrentContextMiddleware.cs | 25 +- .../Utilities/CustomIpRateLimitMiddleware.cs | 137 +- .../Utilities/DistributedCacheExtensions.cs | 73 +- src/Core/Utilities/DuoApi.cs | 491 +- src/Core/Utilities/DuoWeb.cs | 353 +- .../EncryptedStringLengthAttribute.cs | 21 +- src/Core/Utilities/EncryptedValueAttribute.cs | 249 +- .../Utilities/EpochDateTimeJsonConverter.cs | 19 +- .../HandlebarsObjectJsonConverter.cs | 21 +- src/Core/Utilities/HostBuilderExtensions.cs | 57 +- src/Core/Utilities/JsonHelpers.cs | 373 +- src/Core/Utilities/LoggerFactoryExtensions.cs | 247 +- .../LoggingExceptionHandlerFilterAttribute.cs | 25 +- .../Utilities/SecurityHeadersMiddleware.cs | 35 +- src/Core/Utilities/SelfHostedAttribute.cs | 29 +- src/Core/Utilities/StaticStore.cs | 977 ++-- .../Utilities/StrictEmailAddressAttribute.cs | 75 +- .../StrictEmailAddressListAttribute.cs | 49 +- src/Events/Controllers/CollectController.cs | 165 +- src/Events/Controllers/InfoController.cs | 25 +- src/Events/Models/EventModel.cs | 15 +- src/Events/Program.cs | 59 +- src/Events/Startup.cs | 193 +- .../AzureQueueHostedService.cs | 185 +- src/EventsProcessor/Program.cs | 29 +- src/EventsProcessor/Startup.cs | 87 +- src/Icons/Controllers/IconsController.cs | 165 +- src/Icons/Controllers/InfoController.cs | 25 +- src/Icons/IconsSettings.cs | 13 +- src/Icons/Models/DomainName.cs | 565 +-- src/Icons/Models/Icon.cs | 11 +- src/Icons/Models/IconResult.cs | 99 +- src/Icons/Program.cs | 29 +- src/Icons/Services/DomainMappingService.cs | 35 +- src/Icons/Services/IDomainMappingService.cs | 9 +- src/Icons/Services/IIconFetchingService.cs | 9 +- src/Icons/Services/IconFetchingService.cs | 727 +-- src/Icons/Startup.cs | 123 +- .../Controllers/AccountsController.cs | 93 +- src/Identity/Controllers/InfoController.cs | 25 +- src/Identity/Controllers/SsoController.cs | 467 +- src/Identity/Models/RedirectViewModel.cs | 9 +- src/Identity/Program.cs | 65 +- src/Identity/Startup.cs | 361 +- .../Utilities/DiscoveryResponseGenerator.cs | 47 +- .../Utilities/ServiceCollectionExtensions.cs | 73 +- src/Infrastructure.Dapper/DapperHelpers.cs | 211 +- .../DapperServiceCollectionExtensions.cs | 67 +- .../Repositories/BaseRepository.cs | 41 +- .../Repositories/CipherRepository.cs | 1551 +++--- .../CollectionCipherRepository.cs | 137 +- .../Repositories/CollectionRepository.cs | 259 +- .../Repositories/DateTimeHandler.cs | 19 +- .../Repositories/DeviceRepository.cs | 119 +- .../Repositories/EmergencyAccessRepository.cs | 125 +- .../Repositories/EventRepository.cs | 383 +- .../Repositories/FolderRepository.cs | 53 +- .../Repositories/GrantRepository.cs | 107 +- .../Repositories/GroupRepository.cs | 191 +- .../Repositories/InstallationRepository.cs | 19 +- .../Repositories/MaintenanceRepository.cs | 105 +- .../OrganizationApiKeyRepository.cs | 47 +- .../OrganizationConnectionRepository.cs | 43 +- .../Repositories/OrganizationRepository.cs | 129 +- .../OrganizationSponsorshipRepository.cs | 243 +- .../OrganizationUserRepository.cs | 763 +-- .../Repositories/PolicyRepository.cs | 115 +- .../ProviderOrganizationRepository.cs | 57 +- .../Repositories/ProviderRepository.cs | 57 +- .../Repositories/ProviderUserRepository.cs | 219 +- .../Repositories/Repository.cs | 135 +- .../Repositories/SendRepository.cs | 57 +- .../Repositories/SsoConfigRepository.cs | 75 +- .../Repositories/SsoUserRepository.cs | 53 +- .../Repositories/TaxRateRepository.cs | 89 +- .../Repositories/TransactionRepository.cs | 75 +- .../Repositories/UserRepository.cs | 237 +- ...ityFrameworkServiceCollectionExtensions.cs | 97 +- .../Models/Cipher.cs | 25 +- .../Models/Collection.cs | 27 +- .../Models/CollectionCipher.cs | 23 +- .../Models/CollectionGroup.cs | 23 +- .../Models/CollectionUser.cs | 23 +- .../Models/Device.cs | 21 +- .../Models/EmergencyAccess.cs | 23 +- .../Models/Event.cs | 19 +- .../Models/Folder.cs | 21 +- .../Models/Grant.cs | 19 +- .../Models/Group.cs | 23 +- .../Models/GroupUser.cs | 23 +- .../Models/Installation.cs | 19 +- .../Models/Organization.cs | 37 +- .../Models/OrganizationApiKey.cs | 21 +- .../Models/OrganizationConnection.cs | 21 +- .../Models/OrganizationSponsorship.cs | 23 +- .../Models/OrganizationUser.cs | 25 +- .../Models/Policy.cs | 21 +- .../Models/Provider/Provider.cs | 19 +- .../Models/Provider/ProviderOrganization.cs | 23 +- .../Models/Provider/ProviderUser.cs | 23 +- .../Models/Role.cs | 19 +- .../Models/Send.cs | 23 +- .../Models/SsoConfig.cs | 21 +- .../Models/SsoUser.cs | 23 +- .../Models/TaxRate.cs | 19 +- .../Models/Transaction.cs | 23 +- .../Models/User.cs | 33 +- .../BaseEntityFrameworkRepository.cs | 423 +- .../Repositories/CipherRepository.cs | 1145 ++--- .../CollectionCipherRepository.cs | 429 +- .../Repositories/CollectionRepository.cs | 437 +- .../Repositories/DatabaseContext.cs | 255 +- .../Repositories/DeviceRepository.cs | 95 +- .../Repositories/EmergencyAccessRepository.cs | 175 +- .../Repositories/EventRepository.cs | 331 +- .../Repositories/FolderRepository.cs | 47 +- .../Repositories/GrantRepository.cs | 155 +- .../Repositories/GroupRepository.cs | 261 +- .../Repositories/InstallationRepository.cs | 13 +- .../Repositories/MaintenanceRepository.cs | 75 +- .../OrganizationApiKeyRepository.cs | 33 +- .../OrganizationConnectionRepository.cs | 51 +- .../Repositories/OrganizationRepository.cs | 179 +- .../OrganizationSponsorshipRepository.cs | 213 +- .../OrganizationUserRepository.cs | 869 ++-- .../Repositories/PolicyRepository.cs | 95 +- .../ProviderOrganizationRepository.cs | 41 +- .../Repositories/ProviderRepository.cs | 79 +- .../Repositories/ProviderUserRepository.cs | 267 +- .../Queries/CipherDetailsQuery.cs | 59 +- .../CipherOrganizationDetailsReadByIdQuery.cs | 63 +- ...anizationDetailsReadByOrgizationIdQuery.cs | 63 +- .../CipherReadCanEditByIdUserIdQuery.cs | 95 +- .../Queries/CipherUpdateCollectionsQuery.cs | 103 +- ...llectionCipherReadByUserIdCipherIdQuery.cs | 25 +- .../CollectionCipherReadByUserIdQuery.cs | 71 +- ...ollectionReadCountByOrganizationIdQuery.cs | 29 +- .../Queries/CollectionUserUpdateUsersQuery.cs | 203 +- .../EmergencyAccessDetailsViewQuery.cs | 61 +- ...ncyAccessReadCountByGrantorIdEmailQuery.cs | 47 +- .../Queries/EventReadPageByCipherIdQuery.cs | 75 +- ...adPageByOrganizationIdActingUserIdQuery.cs | 61 +- .../EventReadPageByOrganizationIdQuery.cs | 55 +- ...ntReadPageByProviderIdActingUserIdQuery.cs | 61 +- .../Queries/EventReadPageByProviderIdQuery.cs | 55 +- .../Queries/EventReadPageByUserIdQuery.cs | 57 +- .../Queries/GroupUserUpdateGroupsQuery.cs | 121 +- .../Repositories/Queries/IQuery.cs | 9 +- ...izationUserOrganizationDetailsViewQuery.cs | 115 +- ...adCountByFreeOrganizationAdminUserQuery.cs | 39 +- ...ganizationUserReadCountByOnlyOwnerQuery.cs | 53 +- ...UserReadCountByOrganizationIdEmailQuery.cs | 47 +- ...ationUserReadCountByOrganizationIdQuery.cs | 29 +- ...anizationUserUpdateWithCollectionsQuery.cs | 183 +- .../Queries/OrganizationUserUserViewQuery.cs | 55 +- .../PolicyReadByTypeApplicableToUserQuery.cs | 79 +- .../Queries/PolicyReadByUserIdQuery.cs | 41 +- ...rganizationDetailsReadByProviderIdQuery.cs | 61 +- ...roviderUserOrganizationDetailsViewQuery.cs | 77 +- ...rProviderDetailsReadByUserIdStatusQuery.cs | 61 +- .../ProviderUserReadCountByOnlyOwnerQuery.cs | 53 +- ...rBumpAccountRevisionDateByCipherIdQuery.cs | 85 +- ...ccountRevisionDateByOrganizationIdQuery.cs | 35 +- .../Queries/UserCipherDetailsQuery.cs | 125 +- .../Queries/UserCollectionDetailsQuery.cs | 89 +- ...serReadPublicKeysByProviderUserIdsQuery.cs | 47 +- .../Repositories/Repository.cs | 187 +- .../Repositories/SendRepository.cs | 55 +- .../Repositories/SsoConfigRepository.cs | 55 +- .../Repositories/SsoUserRepository.cs | 43 +- .../Repositories/TaxRateRepository.cs | 91 +- .../Repositories/TransactionRepository.cs | 63 +- .../Repositories/UserRepository.cs | 235 +- src/Notifications/AzureQueueHostedService.cs | 127 +- src/Notifications/ConnectionCounter.cs | 39 +- .../Controllers/InfoController.cs | 25 +- .../Controllers/SendController.cs | 35 +- src/Notifications/HeartbeatHostedService.cs | 85 +- src/Notifications/HubHelpers.cs | 113 +- src/Notifications/Jobs/JobsHostedService.cs | 49 +- .../Jobs/LogConnectionCounterJob.cs | 33 +- src/Notifications/NotificationsHub.cs | 65 +- src/Notifications/Program.cs | 77 +- src/Notifications/Startup.cs | 187 +- src/Notifications/SubjectUserIdProvider.cs | 11 +- .../ExceptionHandlerFilterAttribute.cs | 131 +- .../ModelStateValidationFilterAttribute.cs | 37 +- .../Utilities/ServiceCollectionExtensions.cs | 1131 ++--- .../Factories/ApiApplicationFactory.cs | 55 +- .../Controllers/AccountsControllerTests.cs | 775 +-- .../Controllers/CollectionsControllerTests.cs | 125 +- .../OrganizationConnectionsControllerTests.cs | 567 +-- ...OrganizationSponsorshipsControllerTests.cs | 223 +- .../OrganizationUsersControllerTests.cs | 93 +- .../OrganizationsControllerTests.cs | 183 +- .../Controllers/SendsControllerTests.cs | 103 +- .../Accounts/PremiumRequestModelTests.cs | 97 +- .../Models/Request/SendRequestModelTests.cs | 81 +- test/Api.Test/Utilities/ApiHelpersTests.cs | 27 +- .../Controllers/FreshdeskControllerTests.cs | 115 +- .../Controllers/FreshsalesControllerTests.cs | 115 +- .../Attributes/BitAutoDataAttribute.cs | 37 +- .../Attributes/BitCustomizeAttribute.cs | 29 +- .../Attributes/BitMemberAutoDataAttribute.cs | 31 +- .../ControllerCustomizeAttribute.cs | 29 +- .../Attributes/CustomAutoDataAttribute.cs | 31 +- .../Attributes/EnvironmentDataAttribute.cs | 59 +- .../InlineCustomAutoDataAttribute.cs | 27 +- .../Attributes/InlineSutAutoDataAttribute.cs | 25 +- .../JsonDocumentCustomizeAttribute.cs | 11 +- .../RequiredEnvironmentTheoryAttribute.cs | 49 +- .../Attributes/SutAutoDataAttribute.cs | 21 +- .../BuilderWithoutAutoProperties.cs | 53 +- .../AutoFixture/ControllerCustomization.cs | 39 +- test/Common/AutoFixture/FixtureExtensions.cs | 15 +- .../AutoFixture/GlobalSettingsFixtures.cs | 19 +- test/Common/AutoFixture/ISutProvider.cs | 11 +- .../AutoFixture/JsonDocumentFixtures.cs | 39 +- test/Common/AutoFixture/SutProvider.cs | 201 +- .../AutoFixture/SutProviderCustomization.cs | 47 +- test/Common/Helpers/AssertHelper.cs | 385 +- .../Helpers/BitAutoDataAttributeHelpers.cs | 71 +- test/Common/Helpers/TestCaseHelper.cs | 55 +- test/Common/Test/TestCaseHelperTests.cs | 73 +- .../AutoFixture/Attributes/CiSkippedTheory.cs | 15 +- .../CipherAttachmentMetaDataFixtures.cs | 45 +- test/Core.Test/AutoFixture/CipherFixtures.cs | 119 +- .../AutoFixture/CollectionFixtures.cs | 11 +- .../AutoFixture/CurrentContextFixtures.cs | 57 +- .../AutoFixture/GlobalSettingsFixtures.cs | 37 +- test/Core.Test/AutoFixture/GroupFixtures.cs | 25 +- .../AutoFixture/OrganizationFixtures.cs | 319 +- .../OrganizationLicenseCustomization.cs | 21 +- .../OrganizationSponsorshipFixtures.cs | 49 +- .../AutoFixture/OrganizationUserFixtures.cs | 65 +- test/Core.Test/AutoFixture/PolicyFixtures.cs | 55 +- test/Core.Test/AutoFixture/SendFixtures.cs | 111 +- test/Core.Test/AutoFixture/UserFixtures.cs | 69 +- test/Core.Test/Entities/OrganizationTests.cs | 159 +- test/Core.Test/Entities/UserTests.cs | 229 +- test/Core.Test/Helpers/Factories.cs | 17 +- .../AuthenticationTokenProviderTests.cs | 55 +- .../Identity/BaseTokenProviderTests.cs | 137 +- .../Identity/EmailTokenProviderTests.cs | 69 +- .../IdentityServer/TokenRetrievalTests.cs | 135 +- test/Core.Test/Models/Business/BillingInfo.cs | 25 +- .../Core.Test/Models/Business/TaxInfoTests.cs | 195 +- .../EmergencyAccessInviteTokenableTests.cs | 41 +- .../Tokenables/HCaptchaTokenableTests.cs | 125 +- ...anizationSponsorshipOfferTokenableTests.cs | 221 +- .../Business/Tokenables/SsoTokenableTests.cs | 131 +- test/Core.Test/Models/CipherTests.cs | 17 +- .../Models/Data/SendFileDataTests.cs | 33 +- test/Core.Test/Models/PermissionsTests.cs | 95 +- .../GetOrganizationApiKeyCommandTests.cs | 153 +- .../RotateOrganizationApiKeyCommandTests.cs | 23 +- ...reateOrganizationConnectionCommandTests.cs | 23 +- ...eleteOrganizationConnectionCommandTests.cs | 23 +- ...pdateOrganizationConnectionCommandTests.cs | 71 +- .../CancelSponsorshipCommandTestsBase.cs | 123 +- .../CloudRevokeSponsorshipCommandTests.cs | 67 +- .../CloudSyncSponsorshipsCommandTests.cs | 344 +- ...rganizationSponsorshipRenewCommandTests.cs | 25 +- .../Cloud/RemoveSponsorshipCommandTests.cs | 53 +- .../Cloud/SendSponsorshipOfferCommandTests.cs | 177 +- .../Cloud/SetUpSponsorshipCommandTests.cs | 135 +- .../ValidateBillingSyncKeyCommandTests.cs | 75 +- .../ValidateRedemptionTokenCommandTests.cs | 127 +- .../Cloud/ValidateSponsorshipCommandTests.cs | 381 +- .../CreateSponsorshipCommandTests.cs | 287 +- .../FamiliesForEnterpriseTestsBase.cs | 31 +- ...SelfHostedRevokeSponsorshipCommandTests.cs | 69 +- .../SelfHostedSyncSponsorshipsCommandTests.cs | 286 +- test/Core.Test/Resources/VerifyResources.cs | 37 +- .../AmazonSesMailDeliveryServiceTests.cs | 133 +- .../Services/AmazonSqsBlockIpServiceTests.cs | 103 +- .../Services/AppleIapServiceTests.cs | 49 +- .../AzureAttachmentStorageServiceTests.cs | 39 +- .../Services/AzureQueueBlockIpServiceTests.cs | 35 +- .../AzureQueueEventWriteServiceTests.cs | 43 +- .../AzureQueuePushNotificationServiceTests.cs | 45 +- test/Core.Test/Services/CipherServiceTests.cs | 379 +- .../Services/CollectionServiceTests.cs | 263 +- test/Core.Test/Services/DeviceServiceTests.cs | 47 +- .../Services/EmergencyAccessServiceTests.cs | 299 +- test/Core.Test/Services/EventServiceTests.cs | 165 +- test/Core.Test/Services/GroupServiceTests.cs | 203 +- .../Services/HandlebarsMailServiceTests.cs | 297 +- .../InMemoryApplicationCacheServiceTests.cs | 39 +- ...yServiceBusApplicationCacheServiceTests.cs | 51 +- .../Services/LicensingServiceTests.cs | 81 +- .../LocalAttachmentStorageServiceTests.cs | 365 +- .../MailKitSmtpMailDeliveryServiceTests.cs | 49 +- ...ultiServicePushNotificationServiceTests.cs | 81 +- ...ficationHubPushNotificationServiceTests.cs | 51 +- ...ficationHubPushRegistrationServiceTests.cs | 45 +- ...icationsApiPushNotificationServiceTests.cs | 57 +- .../Services/OrganizationServiceTests.cs | 1825 +++---- test/Core.Test/Services/PolicyServiceTests.cs | 695 +-- .../RelayPushNotificationServiceTests.cs | 63 +- .../RelayPushRegistrationServiceTests.cs | 51 +- .../RepositoryEventWriteServiceTests.cs | 35 +- .../SendGridMailDeliveryServiceTests.cs | 119 +- test/Core.Test/Services/SendServiceTests.cs | 1289 ++--- .../Services/SsoConfigServiceTests.cs | 573 +-- .../Services/StripePaymentServiceTests.cs | 647 +-- test/Core.Test/Services/UserServiceTests.cs | 679 +-- test/Core.Test/TempDirectory.cs | 63 +- .../Tokens/DataProtectorTokenFactoryTests.cs | 229 +- test/Core.Test/Tokens/ExpiringTokenTests.cs | 91 +- test/Core.Test/Tokens/TestTokenable.cs | 37 +- test/Core.Test/Tokens/TokenTests.cs | 49 +- .../Utilities/ClaimsExtensionsTests.cs | 51 +- test/Core.Test/Utilities/CoreHelpersTests.cs | 793 +-- .../EncryptedStringAttributeTests.cs | 61 +- test/Core.Test/Utilities/JsonHelpersTests.cs | 107 +- .../PermissiveStringConverterTests.cs | 297 +- .../Utilities/SelfHostedAttributeTests.cs | 141 +- .../StrictEmailAddressAttributeTests.cs | 93 +- .../StrictEmailAddressListAttributeTests.cs | 79 +- test/Icons.Test/Resources/VerifyResources.cs | 23 +- .../Services/IconFetchingServiceTests.cs | 77 +- .../Controllers/AccountsControllerTests.cs | 41 +- .../Endpoints/IdentityServerTests.cs | 931 ++-- .../Controllers/AccountsControllerTests.cs | 173 +- .../AutoFixture/CipherFixtures.cs | 191 +- .../AutoFixture/CollectionCipherFixtures.cs | 89 +- .../AutoFixture/CollectionFixtures.cs | 81 +- .../AutoFixture/DeviceFixtures.cs | 81 +- .../AutoFixture/EmergencyAccessFixtures.cs | 83 +- .../EntityFrameworkRepositoryFixtures.cs | 181 +- .../AutoFixture/EventFixtures.cs | 77 +- .../AutoFixture/FolderFixtures.cs | 81 +- .../AutoFixture/GrantFixtures.cs | 77 +- .../AutoFixture/GroupFixtures.cs | 81 +- .../AutoFixture/GroupUserFixtures.cs | 75 +- .../AutoFixture/InstallationFixtures.cs | 75 +- .../AutoFixture/OrganizationFixtures.cs | 79 +- .../OrganizationSponsorshipFixtures.cs | 79 +- .../AutoFixture/OrganizationUserFixtures.cs | 105 +- .../AutoFixture/PolicyFixtures.cs | 117 +- .../Relays/MaxLengthStringRelay.cs | 59 +- .../AutoFixture/SendFixtures.cs | 105 +- .../AutoFixture/SsoConfigFixtures.cs | 83 +- .../AutoFixture/SsoUserFixtures.cs | 51 +- .../AutoFixture/TaxRateFixtures.cs | 77 +- .../AutoFixture/TransactionFixutres.cs | 105 +- .../AutoFixture/UserFixtures.cs | 45 +- .../Helpers/DatabaseOptionsFactory.cs | 31 +- .../Repositories/CipherRepositoryTests.cs | 323 +- .../Repositories/CollectionRepository.cs | 67 +- .../Repositories/DeviceRepositoryTests.cs | 57 +- .../EmergencyAccessRepositoryTests.cs | 73 +- .../EqualityComparers/CipherCompare.cs | 25 +- .../EqualityComparers/CollectionCompare.cs | 21 +- .../EqualityComparers/DeviceCompare.cs | 25 +- .../EmergencyAccessCompare.cs | 31 +- .../EqualityComparers/EventCompare.cs | 23 +- .../EqualityComparers/FolderCompare.cs | 19 +- .../EqualityComparers/GrantCompare.cs | 33 +- .../EqualityComparers/GroupCompare.cs | 23 +- .../EqualityComparers/InstallationCompare.cs | 23 +- .../EqualityComparers/OrganizationCompare.cs | 91 +- .../OrganizationSponsorshipCompare.cs | 29 +- .../OrganizationUserCompare.cs | 29 +- .../EqualityComparers/PolicyCompare.cs | 37 +- .../EqualityComparers/SendCompare.cs | 37 +- .../EqualityComparers/SsoConfigCompare.cs | 23 +- .../EqualityComparers/SsoUserCompare.cs | 19 +- .../EqualityComparers/TaxRateCompare.cs | 27 +- .../EqualityComparers/TransactionCompare.cs | 31 +- .../EqualityComparers/UserCompare.cs | 63 +- .../EqualityComparers/UserKdfInformation.cs | 21 +- .../Repositories/FolderRepositoryTests.cs | 63 +- .../InstallationRepositoryTests.cs | 47 +- .../OrganizationRepositoryTests.cs | 223 +- .../OrganizationSponsorshipRepositoryTests.cs | 205 +- .../OrganizationUserRepositoryTests.cs | 235 +- .../Repositories/PolicyRepositoryTests.cs | 317 +- .../Repositories/SendRepositoryTests.cs | 89 +- .../Repositories/SsoConfigRepositoryTests.cs | 355 +- .../Repositories/SsoUserRepositoryTests.cs | 295 +- .../Repositories/TaxRateRepositoryTests.cs | 49 +- .../TransactionRepositoryTests.cs | 89 +- .../Repositories/UserRepositoryTests.cs | 485 +- .../Factories/IdentityApplicationFactory.cs | 57 +- .../Factories/WebApplicationFactoryBase.cs | 177 +- .../WebApplicationFactoryExtensions.cs | 99 +- util/EfShared/MigrationBuilderExtensions.cs | 43 +- util/Migrator/DbMigrator.cs | 173 +- util/Migrator/DbUpLogger.cs | 39 +- util/MySqlMigrations/Factories.cs | 51 +- .../Migrations/20210617183900_Init.cs | 2179 ++++---- ..._RemoveProviderOrganizationProviderUser.cs | 149 +- .../20210716142145_UserForcePasswordReset.cs | 33 +- ...2418_AddMaxAutoscaleSeatsToOrganization.cs | 65 +- ...4835_SplitManageCollectionsPermissions2.cs | 25 +- ..._SetMaxAutoscaleSeatsToCurrentSeatCount.cs | 25 +- .../Migrations/20211108041911_KeyConnector.cs | 33 +- .../20211108225243_OrganizationSponsorship.cs | 145 +- .../20211115145402_KeyConnectorFlag.cs | 33 +- .../Migrations/20220121092546_RemoveU2F.cs | 83 +- .../20220301215315_FailedLoginCaptcha.cs | 49 +- .../Migrations/20220322191314_SelfHostF4E.cs | 257 +- .../20220411191518_SponsorshipBulkActions.cs | 135 +- ...0220420170738_AddInstallationIdToEvents.cs | 113 +- ...0220524171600_DeviceUnknownVerification.cs | 33 +- .../20220608191914_DeactivatedUserStatus.cs | 43 +- .../Migrations/20220707163017_UseScimFlag.cs | 33 +- util/PostgresMigrations/Factories.cs | 49 +- .../Migrations/20210708191531_Init.cs | 1937 +++---- ..._RemoveProviderOrganizationProviderUser.cs | 123 +- .../20210716141748_UserForcePasswordReset.cs | 33 +- ...1829_AddMaxAutoscaleSeatsToOrganization.cs | 63 +- ...5128_SplitManageCollectionsPermissions2.cs | 25 +- ..._SetMaxAutoscaleSeatsToCurrentSeatCount.cs | 25 +- .../Migrations/20211108041547_KeyConnector.cs | 33 +- .../20211108225011_OrganizationSponsorship.cs | 139 +- .../20211115142623_KeyConnectorFlag.cs | 33 +- .../Migrations/20220121092321_RemoveU2F.cs | 73 +- .../20220301211818_FailedLoginCaptcha.cs | 49 +- .../Migrations/20220322183505_SelfHostF4E.cs | 245 +- .../20220411190525_SponsorshipBulkActions.cs | 119 +- ...0220420171153_AddInstallationIdToEvents.cs | 103 +- ...0220524170740_DeviceUnknownVerification.cs | 33 +- .../Migrations/20220707162231_UseScimFlag.cs | 33 +- util/Server/Program.cs | 63 +- util/Server/Startup.cs | 151 +- util/Setup/AppIdBuilder.cs | 49 +- util/Setup/CertBuilder.cs | 189 +- util/Setup/Configuration.cs | 221 +- util/Setup/Context.cs | 279 +- util/Setup/DockerComposeBuilder.cs | 123 +- util/Setup/EnvironmentFileBuilder.cs | 387 +- util/Setup/Helpers.cs | 365 +- util/Setup/NginxConfigBuilder.cs | 229 +- util/Setup/Program.cs | 565 +-- util/Setup/YamlComments.cs | 165 +- 1208 files changed, 74317 insertions(+), 73126 deletions(-) diff --git a/.editorconfig b/.editorconfig index 828d0a7b39..9c2beed95a 100644 --- a/.editorconfig +++ b/.editorconfig @@ -114,9 +114,6 @@ csharp_new_line_before_finally = true csharp_new_line_before_members_in_object_initializers = true csharp_new_line_before_members_in_anonymous_types = true -# Namespace settigns -csharp_style_namespace_declarations = file_scoped:warning - # All files [*] guidelines = 120 diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs index 307748073b..af12ea908b 100644 --- a/.git-blame-ignore-revs +++ b/.git-blame-ignore-revs @@ -1,5 +1,2 @@ # Apply .NET format https://github.com/bitwarden/server/pull/1764 23b0a1f9df25058ab29785ecad9a233113c10889 - -# Turn on file scoped namespaces https://github.com/bitwarden/server/pull/2225 -34fb4cca2aa78deb84d4cbc359992a7c6bba7ea5 diff --git a/README.md b/README.md index 8469bdc315..70c0df3603 100644 --- a/README.md +++ b/README.md @@ -84,15 +84,3 @@ We recently migrated to using dotnet-format as code formatter. All previous bran 5. Commit 6. Run `git merge -Xours 23b0a1f9df25058ab29785ecad9a233113c10889` 7. Push - -### File Scoped Namespaces -We have switched to using file scoped namespace. All previous branches will need to update to avoid large merge conflicts using the following steps: - -1. Check out your local Branch -1. Run `git merge 7c4521e0b428d523f2153cda3fb51d51bca9f194` -2. Resolve any merge conflicts, commit. -3. Run `dotnet format` -4. Commit -5. Run `git merge -Xours 34fb4cca2aa78deb84d4cbc359992a7c6bba7ea5` -6. Resolve merge conflicts -7. Push diff --git a/bitwarden_license/src/Commercial.Core/Services/ProviderService.cs b/bitwarden_license/src/Commercial.Core/Services/ProviderService.cs index e50388f401..22c180524c 100644 --- a/bitwarden_license/src/Commercial.Core/Services/ProviderService.cs +++ b/bitwarden_license/src/Commercial.Core/Services/ProviderService.cs @@ -13,496 +13,497 @@ using Bit.Core.Settings; using Bit.Core.Utilities; using Microsoft.AspNetCore.DataProtection; -namespace Bit.Commercial.Core.Services; - -public class ProviderService : IProviderService +namespace Bit.Commercial.Core.Services { - public static PlanType[] ProviderDisllowedOrganizationTypes = new[] { PlanType.Free, PlanType.FamiliesAnnually, PlanType.FamiliesAnnually2019 }; - - private readonly IDataProtector _dataProtector; - private readonly IMailService _mailService; - private readonly IEventService _eventService; - private readonly GlobalSettings _globalSettings; - private readonly IProviderRepository _providerRepository; - private readonly IProviderUserRepository _providerUserRepository; - private readonly IProviderOrganizationRepository _providerOrganizationRepository; - private readonly IOrganizationRepository _organizationRepository; - private readonly IUserRepository _userRepository; - private readonly IUserService _userService; - private readonly IOrganizationService _organizationService; - private readonly ICurrentContext _currentContext; - - public ProviderService(IProviderRepository providerRepository, IProviderUserRepository providerUserRepository, - IProviderOrganizationRepository providerOrganizationRepository, IUserRepository userRepository, - IUserService userService, IOrganizationService organizationService, IMailService mailService, - IDataProtectionProvider dataProtectionProvider, IEventService eventService, - IOrganizationRepository organizationRepository, GlobalSettings globalSettings, - ICurrentContext currentContext) + public class ProviderService : IProviderService { - _providerRepository = providerRepository; - _providerUserRepository = providerUserRepository; - _providerOrganizationRepository = providerOrganizationRepository; - _organizationRepository = organizationRepository; - _userRepository = userRepository; - _userService = userService; - _organizationService = organizationService; - _mailService = mailService; - _eventService = eventService; - _globalSettings = globalSettings; - _dataProtector = dataProtectionProvider.CreateProtector("ProviderServiceDataProtector"); - _currentContext = currentContext; - } + public static PlanType[] ProviderDisllowedOrganizationTypes = new[] { PlanType.Free, PlanType.FamiliesAnnually, PlanType.FamiliesAnnually2019 }; - public async Task CreateAsync(string ownerEmail) - { - var owner = await _userRepository.GetByEmailAsync(ownerEmail); - if (owner == null) + private readonly IDataProtector _dataProtector; + private readonly IMailService _mailService; + private readonly IEventService _eventService; + private readonly GlobalSettings _globalSettings; + private readonly IProviderRepository _providerRepository; + private readonly IProviderUserRepository _providerUserRepository; + private readonly IProviderOrganizationRepository _providerOrganizationRepository; + private readonly IOrganizationRepository _organizationRepository; + private readonly IUserRepository _userRepository; + private readonly IUserService _userService; + private readonly IOrganizationService _organizationService; + private readonly ICurrentContext _currentContext; + + public ProviderService(IProviderRepository providerRepository, IProviderUserRepository providerUserRepository, + IProviderOrganizationRepository providerOrganizationRepository, IUserRepository userRepository, + IUserService userService, IOrganizationService organizationService, IMailService mailService, + IDataProtectionProvider dataProtectionProvider, IEventService eventService, + IOrganizationRepository organizationRepository, GlobalSettings globalSettings, + ICurrentContext currentContext) { - throw new BadRequestException("Invalid owner. Owner must be an existing Bitwarden user."); + _providerRepository = providerRepository; + _providerUserRepository = providerUserRepository; + _providerOrganizationRepository = providerOrganizationRepository; + _organizationRepository = organizationRepository; + _userRepository = userRepository; + _userService = userService; + _organizationService = organizationService; + _mailService = mailService; + _eventService = eventService; + _globalSettings = globalSettings; + _dataProtector = dataProtectionProvider.CreateProtector("ProviderServiceDataProtector"); + _currentContext = currentContext; } - var provider = new Provider + public async Task CreateAsync(string ownerEmail) { - Status = ProviderStatusType.Pending, - Enabled = true, - UseEvents = true, - }; - await _providerRepository.CreateAsync(provider); - - var providerUser = new ProviderUser - { - ProviderId = provider.Id, - UserId = owner.Id, - Type = ProviderUserType.ProviderAdmin, - Status = ProviderUserStatusType.Confirmed, - }; - await _providerUserRepository.CreateAsync(providerUser); - await SendProviderSetupInviteEmailAsync(provider, owner.Email); - } - - public async Task CompleteSetupAsync(Provider provider, Guid ownerUserId, string token, string key) - { - var owner = await _userService.GetUserByIdAsync(ownerUserId); - if (owner == null) - { - throw new BadRequestException("Invalid owner."); - } - - if (provider.Status != ProviderStatusType.Pending) - { - throw new BadRequestException("Provider is already setup."); - } - - if (!CoreHelpers.TokenIsValid("ProviderSetupInvite", _dataProtector, token, owner.Email, provider.Id, - _globalSettings.OrganizationInviteExpirationHours)) - { - throw new BadRequestException("Invalid token."); - } - - var providerUser = await _providerUserRepository.GetByProviderUserAsync(provider.Id, ownerUserId); - if (!(providerUser is { Type: ProviderUserType.ProviderAdmin })) - { - throw new BadRequestException("Invalid owner."); - } - - provider.Status = ProviderStatusType.Created; - await _providerRepository.UpsertAsync(provider); - - providerUser.Key = key; - await _providerUserRepository.ReplaceAsync(providerUser); - - return provider; - } - - public async Task UpdateAsync(Provider provider, bool updateBilling = false) - { - if (provider.Id == default) - { - throw new ArgumentException("Cannot create provider this way."); - } - - await _providerRepository.ReplaceAsync(provider); - } - - public async Task> InviteUserAsync(ProviderUserInvite invite) - { - if (!_currentContext.ProviderManageUsers(invite.ProviderId)) - { - throw new InvalidOperationException("Invalid permissions."); - } - - var emails = invite?.UserIdentifiers; - var invitingUser = await _providerUserRepository.GetByProviderUserAsync(invite.ProviderId, invite.InvitingUserId); - - var provider = await _providerRepository.GetByIdAsync(invite.ProviderId); - if (provider == null || emails == null || !emails.Any()) - { - throw new NotFoundException(); - } - - var providerUsers = new List(); - foreach (var email in emails) - { - // Make sure user is not already invited - var existingProviderUserCount = - await _providerUserRepository.GetCountByProviderAsync(invite.ProviderId, email, false); - if (existingProviderUserCount > 0) + var owner = await _userRepository.GetByEmailAsync(ownerEmail); + if (owner == null) { - continue; + throw new BadRequestException("Invalid owner. Owner must be an existing Bitwarden user."); } + var provider = new Provider + { + Status = ProviderStatusType.Pending, + Enabled = true, + UseEvents = true, + }; + await _providerRepository.CreateAsync(provider); + var providerUser = new ProviderUser { - ProviderId = invite.ProviderId, - UserId = null, - Email = email.ToLowerInvariant(), - Key = null, - Type = invite.Type, - Status = ProviderUserStatusType.Invited, - CreationDate = DateTime.UtcNow, - RevisionDate = DateTime.UtcNow, + ProviderId = provider.Id, + UserId = owner.Id, + Type = ProviderUserType.ProviderAdmin, + Status = ProviderUserStatusType.Confirmed, + }; + await _providerUserRepository.CreateAsync(providerUser); + await SendProviderSetupInviteEmailAsync(provider, owner.Email); + } + + public async Task CompleteSetupAsync(Provider provider, Guid ownerUserId, string token, string key) + { + var owner = await _userService.GetUserByIdAsync(ownerUserId); + if (owner == null) + { + throw new BadRequestException("Invalid owner."); + } + + if (provider.Status != ProviderStatusType.Pending) + { + throw new BadRequestException("Provider is already setup."); + } + + if (!CoreHelpers.TokenIsValid("ProviderSetupInvite", _dataProtector, token, owner.Email, provider.Id, + _globalSettings.OrganizationInviteExpirationHours)) + { + throw new BadRequestException("Invalid token."); + } + + var providerUser = await _providerUserRepository.GetByProviderUserAsync(provider.Id, ownerUserId); + if (!(providerUser is { Type: ProviderUserType.ProviderAdmin })) + { + throw new BadRequestException("Invalid owner."); + } + + provider.Status = ProviderStatusType.Created; + await _providerRepository.UpsertAsync(provider); + + providerUser.Key = key; + await _providerUserRepository.ReplaceAsync(providerUser); + + return provider; + } + + public async Task UpdateAsync(Provider provider, bool updateBilling = false) + { + if (provider.Id == default) + { + throw new ArgumentException("Cannot create provider this way."); + } + + await _providerRepository.ReplaceAsync(provider); + } + + public async Task> InviteUserAsync(ProviderUserInvite invite) + { + if (!_currentContext.ProviderManageUsers(invite.ProviderId)) + { + throw new InvalidOperationException("Invalid permissions."); + } + + var emails = invite?.UserIdentifiers; + var invitingUser = await _providerUserRepository.GetByProviderUserAsync(invite.ProviderId, invite.InvitingUserId); + + var provider = await _providerRepository.GetByIdAsync(invite.ProviderId); + if (provider == null || emails == null || !emails.Any()) + { + throw new NotFoundException(); + } + + var providerUsers = new List(); + foreach (var email in emails) + { + // Make sure user is not already invited + var existingProviderUserCount = + await _providerUserRepository.GetCountByProviderAsync(invite.ProviderId, email, false); + if (existingProviderUserCount > 0) + { + continue; + } + + var providerUser = new ProviderUser + { + ProviderId = invite.ProviderId, + UserId = null, + Email = email.ToLowerInvariant(), + Key = null, + Type = invite.Type, + Status = ProviderUserStatusType.Invited, + CreationDate = DateTime.UtcNow, + RevisionDate = DateTime.UtcNow, + }; + + await _providerUserRepository.CreateAsync(providerUser); + + await SendInviteAsync(providerUser, provider); + providerUsers.Add(providerUser); + } + + await _eventService.LogProviderUsersEventAsync(providerUsers.Select(pu => (pu, EventType.ProviderUser_Invited, null as DateTime?))); + + return providerUsers; + } + + public async Task>> ResendInvitesAsync(ProviderUserInvite invite) + { + if (!_currentContext.ProviderManageUsers(invite.ProviderId)) + { + throw new BadRequestException("Invalid permissions."); + } + + var providerUsers = await _providerUserRepository.GetManyAsync(invite.UserIdentifiers); + var provider = await _providerRepository.GetByIdAsync(invite.ProviderId); + + var result = new List>(); + foreach (var providerUser in providerUsers) + { + if (providerUser.Status != ProviderUserStatusType.Invited || providerUser.ProviderId != invite.ProviderId) + { + result.Add(Tuple.Create(providerUser, "User invalid.")); + continue; + } + + await SendInviteAsync(providerUser, provider); + result.Add(Tuple.Create(providerUser, "")); + } + + return result; + } + + public async Task AcceptUserAsync(Guid providerUserId, User user, string token) + { + var providerUser = await _providerUserRepository.GetByIdAsync(providerUserId); + if (providerUser == null) + { + throw new BadRequestException("User invalid."); + } + + if (providerUser.Status != ProviderUserStatusType.Invited) + { + throw new BadRequestException("Already accepted."); + } + + if (!CoreHelpers.TokenIsValid("ProviderUserInvite", _dataProtector, token, user.Email, providerUser.Id, + _globalSettings.OrganizationInviteExpirationHours)) + { + throw new BadRequestException("Invalid token."); + } + + if (string.IsNullOrWhiteSpace(providerUser.Email) || + !providerUser.Email.Equals(user.Email, StringComparison.InvariantCultureIgnoreCase)) + { + throw new BadRequestException("User email does not match invite."); + } + + providerUser.Status = ProviderUserStatusType.Accepted; + providerUser.UserId = user.Id; + providerUser.Email = null; + + await _providerUserRepository.ReplaceAsync(providerUser); + + return providerUser; + } + + public async Task>> ConfirmUsersAsync(Guid providerId, Dictionary keys, + Guid confirmingUserId) + { + var providerUsers = await _providerUserRepository.GetManyAsync(keys.Keys); + var validProviderUsers = providerUsers + .Where(u => u.UserId != null) + .ToList(); + + if (!validProviderUsers.Any()) + { + return new List>(); + } + + var validOrganizationUserIds = validProviderUsers.Select(u => u.UserId.Value).ToList(); + + var provider = await _providerRepository.GetByIdAsync(providerId); + var users = await _userRepository.GetManyAsync(validOrganizationUserIds); + + var keyedFilteredUsers = validProviderUsers.ToDictionary(u => u.UserId.Value, u => u); + + var result = new List>(); + var events = new List<(ProviderUser, EventType, DateTime?)>(); + + foreach (var user in users) + { + if (!keyedFilteredUsers.ContainsKey(user.Id)) + { + continue; + } + var providerUser = keyedFilteredUsers[user.Id]; + try + { + if (providerUser.Status != ProviderUserStatusType.Accepted || providerUser.ProviderId != providerId) + { + throw new BadRequestException("Invalid user."); + } + + providerUser.Status = ProviderUserStatusType.Confirmed; + providerUser.Key = keys[providerUser.Id]; + providerUser.Email = null; + + await _providerUserRepository.ReplaceAsync(providerUser); + events.Add((providerUser, EventType.ProviderUser_Confirmed, null)); + await _mailService.SendProviderConfirmedEmailAsync(provider.Name, user.Email); + result.Add(Tuple.Create(providerUser, "")); + } + catch (BadRequestException e) + { + result.Add(Tuple.Create(providerUser, e.Message)); + } + } + + await _eventService.LogProviderUsersEventAsync(events); + + return result; + } + + public async Task SaveUserAsync(ProviderUser user, Guid savingUserId) + { + if (user.Id.Equals(default)) + { + throw new BadRequestException("Invite the user first."); + } + + if (user.Type != ProviderUserType.ProviderAdmin && + !await HasConfirmedProviderAdminExceptAsync(user.ProviderId, new[] { user.Id })) + { + throw new BadRequestException("Provider must have at least one confirmed ProviderAdmin."); + } + + await _providerUserRepository.ReplaceAsync(user); + await _eventService.LogProviderUserEventAsync(user, EventType.ProviderUser_Updated); + } + + public async Task>> DeleteUsersAsync(Guid providerId, + IEnumerable providerUserIds, Guid deletingUserId) + { + var provider = await _providerRepository.GetByIdAsync(providerId); + + if (provider == null) + { + throw new NotFoundException(); + } + + var providerUsers = await _providerUserRepository.GetManyAsync(providerUserIds); + var users = await _userRepository.GetManyAsync(providerUsers.Where(pu => pu.UserId.HasValue) + .Select(pu => pu.UserId.Value)); + var keyedUsers = users.ToDictionary(u => u.Id); + + if (!await HasConfirmedProviderAdminExceptAsync(providerId, providerUserIds)) + { + throw new BadRequestException("Provider must have at least one confirmed ProviderAdmin."); + } + + var result = new List>(); + var deletedUserIds = new List(); + var events = new List<(ProviderUser, EventType, DateTime?)>(); + + foreach (var providerUser in providerUsers) + { + try + { + if (providerUser.ProviderId != providerId) + { + throw new BadRequestException("Invalid user."); + } + if (providerUser.UserId == deletingUserId) + { + throw new BadRequestException("You cannot remove yourself."); + } + + events.Add((providerUser, EventType.ProviderUser_Removed, null)); + + var user = keyedUsers.GetValueOrDefault(providerUser.UserId.GetValueOrDefault()); + var email = user == null ? providerUser.Email : user.Email; + if (!string.IsNullOrWhiteSpace(email)) + { + await _mailService.SendProviderUserRemoved(provider.Name, email); + } + + result.Add(Tuple.Create(providerUser, "")); + deletedUserIds.Add(providerUser.Id); + } + catch (BadRequestException e) + { + result.Add(Tuple.Create(providerUser, e.Message)); + } + + await _providerUserRepository.DeleteManyAsync(deletedUserIds); + } + + await _eventService.LogProviderUsersEventAsync(events); + + return result; + } + + public async Task AddOrganization(Guid providerId, Guid organizationId, Guid addingUserId, string key) + { + var po = await _providerOrganizationRepository.GetByOrganizationId(organizationId); + if (po != null) + { + throw new BadRequestException("Organization already belongs to a provider."); + } + + var organization = await _organizationRepository.GetByIdAsync(organizationId); + ThrowOnInvalidPlanType(organization.PlanType); + + var providerOrganization = new ProviderOrganization + { + ProviderId = providerId, + OrganizationId = organizationId, + Key = key, }; - await _providerUserRepository.CreateAsync(providerUser); - - await SendInviteAsync(providerUser, provider); - providerUsers.Add(providerUser); + await _providerOrganizationRepository.CreateAsync(providerOrganization); + await _eventService.LogProviderOrganizationEventAsync(providerOrganization, EventType.ProviderOrganization_Added); } - await _eventService.LogProviderUsersEventAsync(providerUsers.Select(pu => (pu, EventType.ProviderUser_Invited, null as DateTime?))); - - return providerUsers; - } - - public async Task>> ResendInvitesAsync(ProviderUserInvite invite) - { - if (!_currentContext.ProviderManageUsers(invite.ProviderId)) + public async Task CreateOrganizationAsync(Guid providerId, + OrganizationSignup organizationSignup, string clientOwnerEmail, User user) { - throw new BadRequestException("Invalid permissions."); - } + ThrowOnInvalidPlanType(organizationSignup.Plan); - var providerUsers = await _providerUserRepository.GetManyAsync(invite.UserIdentifiers); - var provider = await _providerRepository.GetByIdAsync(invite.ProviderId); + var (organization, _) = await _organizationService.SignUpAsync(organizationSignup, true); - var result = new List>(); - foreach (var providerUser in providerUsers) - { - if (providerUser.Status != ProviderUserStatusType.Invited || providerUser.ProviderId != invite.ProviderId) + var providerOrganization = new ProviderOrganization { - result.Add(Tuple.Create(providerUser, "User invalid.")); - continue; - } + ProviderId = providerId, + OrganizationId = organization.Id, + Key = organizationSignup.OwnerKey, + }; - await SendInviteAsync(providerUser, provider); - result.Add(Tuple.Create(providerUser, "")); - } + await _providerOrganizationRepository.CreateAsync(providerOrganization); + await _eventService.LogProviderOrganizationEventAsync(providerOrganization, EventType.ProviderOrganization_Created); - return result; - } - - public async Task AcceptUserAsync(Guid providerUserId, User user, string token) - { - var providerUser = await _providerUserRepository.GetByIdAsync(providerUserId); - if (providerUser == null) - { - throw new BadRequestException("User invalid."); - } - - if (providerUser.Status != ProviderUserStatusType.Invited) - { - throw new BadRequestException("Already accepted."); - } - - if (!CoreHelpers.TokenIsValid("ProviderUserInvite", _dataProtector, token, user.Email, providerUser.Id, - _globalSettings.OrganizationInviteExpirationHours)) - { - throw new BadRequestException("Invalid token."); - } - - if (string.IsNullOrWhiteSpace(providerUser.Email) || - !providerUser.Email.Equals(user.Email, StringComparison.InvariantCultureIgnoreCase)) - { - throw new BadRequestException("User email does not match invite."); - } - - providerUser.Status = ProviderUserStatusType.Accepted; - providerUser.UserId = user.Id; - providerUser.Email = null; - - await _providerUserRepository.ReplaceAsync(providerUser); - - return providerUser; - } - - public async Task>> ConfirmUsersAsync(Guid providerId, Dictionary keys, - Guid confirmingUserId) - { - var providerUsers = await _providerUserRepository.GetManyAsync(keys.Keys); - var validProviderUsers = providerUsers - .Where(u => u.UserId != null) - .ToList(); - - if (!validProviderUsers.Any()) - { - return new List>(); - } - - var validOrganizationUserIds = validProviderUsers.Select(u => u.UserId.Value).ToList(); - - var provider = await _providerRepository.GetByIdAsync(providerId); - var users = await _userRepository.GetManyAsync(validOrganizationUserIds); - - var keyedFilteredUsers = validProviderUsers.ToDictionary(u => u.UserId.Value, u => u); - - var result = new List>(); - var events = new List<(ProviderUser, EventType, DateTime?)>(); - - foreach (var user in users) - { - if (!keyedFilteredUsers.ContainsKey(user.Id)) - { - continue; - } - var providerUser = keyedFilteredUsers[user.Id]; - try - { - if (providerUser.Status != ProviderUserStatusType.Accepted || providerUser.ProviderId != providerId) + await _organizationService.InviteUsersAsync(organization.Id, user.Id, + new (OrganizationUserInvite, string)[] { - throw new BadRequestException("Invalid user."); - } + ( + new OrganizationUserInvite + { + Emails = new[] { clientOwnerEmail }, + AccessAll = true, + Type = OrganizationUserType.Owner, + Permissions = null, + Collections = Array.Empty(), + }, + null + ) + }); - providerUser.Status = ProviderUserStatusType.Confirmed; - providerUser.Key = keys[providerUser.Id]; - providerUser.Email = null; + return providerOrganization; + } - await _providerUserRepository.ReplaceAsync(providerUser); - events.Add((providerUser, EventType.ProviderUser_Confirmed, null)); - await _mailService.SendProviderConfirmedEmailAsync(provider.Name, user.Email); - result.Add(Tuple.Create(providerUser, "")); - } - catch (BadRequestException e) + public async Task RemoveOrganizationAsync(Guid providerId, Guid providerOrganizationId, Guid removingUserId) + { + var providerOrganization = await _providerOrganizationRepository.GetByIdAsync(providerOrganizationId); + if (providerOrganization == null || providerOrganization.ProviderId != providerId) { - result.Add(Tuple.Create(providerUser, e.Message)); + throw new BadRequestException("Invalid organization."); + } + + if (!await _organizationService.HasConfirmedOwnersExceptAsync(providerOrganization.OrganizationId, new Guid[] { }, includeProvider: false)) + { + throw new BadRequestException("Organization needs to have at least one confirmed owner."); + } + + await _providerOrganizationRepository.DeleteAsync(providerOrganization); + await _eventService.LogProviderOrganizationEventAsync(providerOrganization, EventType.ProviderOrganization_Removed); + } + + public async Task ResendProviderSetupInviteEmailAsync(Guid providerId, Guid ownerId) + { + var provider = await _providerRepository.GetByIdAsync(providerId); + var owner = await _userRepository.GetByIdAsync(ownerId); + if (owner == null) + { + throw new BadRequestException("Invalid owner."); + } + await SendProviderSetupInviteEmailAsync(provider, owner.Email); + } + + private async Task SendProviderSetupInviteEmailAsync(Provider provider, string ownerEmail) + { + var token = _dataProtector.Protect($"ProviderSetupInvite {provider.Id} {ownerEmail} {CoreHelpers.ToEpocMilliseconds(DateTime.UtcNow)}"); + await _mailService.SendProviderSetupInviteEmailAsync(provider, token, ownerEmail); + } + + public async Task LogProviderAccessToOrganizationAsync(Guid organizationId) + { + if (organizationId == default) + { + return; + } + + var providerOrganization = await _providerOrganizationRepository.GetByOrganizationId(organizationId); + var organization = await _organizationRepository.GetByIdAsync(organizationId); + if (providerOrganization != null) + { + await _eventService.LogProviderOrganizationEventAsync(providerOrganization, EventType.ProviderOrganization_VaultAccessed); + } + if (organization != null) + { + await _eventService.LogOrganizationEventAsync(organization, EventType.Organization_VaultAccessed); } } - await _eventService.LogProviderUsersEventAsync(events); - - return result; - } - - public async Task SaveUserAsync(ProviderUser user, Guid savingUserId) - { - if (user.Id.Equals(default)) + private async Task SendInviteAsync(ProviderUser providerUser, Provider provider) { - throw new BadRequestException("Invite the user first."); + var nowMillis = CoreHelpers.ToEpocMilliseconds(DateTime.UtcNow); + var token = _dataProtector.Protect( + $"ProviderUserInvite {providerUser.Id} {providerUser.Email} {nowMillis}"); + await _mailService.SendProviderInviteEmailAsync(provider.Name, providerUser, token, providerUser.Email); } - if (user.Type != ProviderUserType.ProviderAdmin && - !await HasConfirmedProviderAdminExceptAsync(user.ProviderId, new[] { user.Id })) + private async Task HasConfirmedProviderAdminExceptAsync(Guid providerId, IEnumerable providerUserIds) { - throw new BadRequestException("Provider must have at least one confirmed ProviderAdmin."); + var providerAdmins = await _providerUserRepository.GetManyByProviderAsync(providerId, + ProviderUserType.ProviderAdmin); + var confirmedOwners = providerAdmins.Where(o => o.Status == ProviderUserStatusType.Confirmed); + var confirmedOwnersIds = confirmedOwners.Select(u => u.Id); + return confirmedOwnersIds.Except(providerUserIds).Any(); } - await _providerUserRepository.ReplaceAsync(user); - await _eventService.LogProviderUserEventAsync(user, EventType.ProviderUser_Updated); - } - - public async Task>> DeleteUsersAsync(Guid providerId, - IEnumerable providerUserIds, Guid deletingUserId) - { - var provider = await _providerRepository.GetByIdAsync(providerId); - - if (provider == null) + private void ThrowOnInvalidPlanType(PlanType requestedType) { - throw new NotFoundException(); - } - - var providerUsers = await _providerUserRepository.GetManyAsync(providerUserIds); - var users = await _userRepository.GetManyAsync(providerUsers.Where(pu => pu.UserId.HasValue) - .Select(pu => pu.UserId.Value)); - var keyedUsers = users.ToDictionary(u => u.Id); - - if (!await HasConfirmedProviderAdminExceptAsync(providerId, providerUserIds)) - { - throw new BadRequestException("Provider must have at least one confirmed ProviderAdmin."); - } - - var result = new List>(); - var deletedUserIds = new List(); - var events = new List<(ProviderUser, EventType, DateTime?)>(); - - foreach (var providerUser in providerUsers) - { - try + if (ProviderDisllowedOrganizationTypes.Contains(requestedType)) { - if (providerUser.ProviderId != providerId) - { - throw new BadRequestException("Invalid user."); - } - if (providerUser.UserId == deletingUserId) - { - throw new BadRequestException("You cannot remove yourself."); - } - - events.Add((providerUser, EventType.ProviderUser_Removed, null)); - - var user = keyedUsers.GetValueOrDefault(providerUser.UserId.GetValueOrDefault()); - var email = user == null ? providerUser.Email : user.Email; - if (!string.IsNullOrWhiteSpace(email)) - { - await _mailService.SendProviderUserRemoved(provider.Name, email); - } - - result.Add(Tuple.Create(providerUser, "")); - deletedUserIds.Add(providerUser.Id); + throw new BadRequestException($"Providers cannot manage organizations with the requested plan type ({requestedType}). Only Teams and Enterprise accounts are allowed."); } - catch (BadRequestException e) - { - result.Add(Tuple.Create(providerUser, e.Message)); - } - - await _providerUserRepository.DeleteManyAsync(deletedUserIds); - } - - await _eventService.LogProviderUsersEventAsync(events); - - return result; - } - - public async Task AddOrganization(Guid providerId, Guid organizationId, Guid addingUserId, string key) - { - var po = await _providerOrganizationRepository.GetByOrganizationId(organizationId); - if (po != null) - { - throw new BadRequestException("Organization already belongs to a provider."); - } - - var organization = await _organizationRepository.GetByIdAsync(organizationId); - ThrowOnInvalidPlanType(organization.PlanType); - - var providerOrganization = new ProviderOrganization - { - ProviderId = providerId, - OrganizationId = organizationId, - Key = key, - }; - - await _providerOrganizationRepository.CreateAsync(providerOrganization); - await _eventService.LogProviderOrganizationEventAsync(providerOrganization, EventType.ProviderOrganization_Added); - } - - public async Task CreateOrganizationAsync(Guid providerId, - OrganizationSignup organizationSignup, string clientOwnerEmail, User user) - { - ThrowOnInvalidPlanType(organizationSignup.Plan); - - var (organization, _) = await _organizationService.SignUpAsync(organizationSignup, true); - - var providerOrganization = new ProviderOrganization - { - ProviderId = providerId, - OrganizationId = organization.Id, - Key = organizationSignup.OwnerKey, - }; - - await _providerOrganizationRepository.CreateAsync(providerOrganization); - await _eventService.LogProviderOrganizationEventAsync(providerOrganization, EventType.ProviderOrganization_Created); - - await _organizationService.InviteUsersAsync(organization.Id, user.Id, - new (OrganizationUserInvite, string)[] - { - ( - new OrganizationUserInvite - { - Emails = new[] { clientOwnerEmail }, - AccessAll = true, - Type = OrganizationUserType.Owner, - Permissions = null, - Collections = Array.Empty(), - }, - null - ) - }); - - return providerOrganization; - } - - public async Task RemoveOrganizationAsync(Guid providerId, Guid providerOrganizationId, Guid removingUserId) - { - var providerOrganization = await _providerOrganizationRepository.GetByIdAsync(providerOrganizationId); - if (providerOrganization == null || providerOrganization.ProviderId != providerId) - { - throw new BadRequestException("Invalid organization."); - } - - if (!await _organizationService.HasConfirmedOwnersExceptAsync(providerOrganization.OrganizationId, new Guid[] { }, includeProvider: false)) - { - throw new BadRequestException("Organization needs to have at least one confirmed owner."); - } - - await _providerOrganizationRepository.DeleteAsync(providerOrganization); - await _eventService.LogProviderOrganizationEventAsync(providerOrganization, EventType.ProviderOrganization_Removed); - } - - public async Task ResendProviderSetupInviteEmailAsync(Guid providerId, Guid ownerId) - { - var provider = await _providerRepository.GetByIdAsync(providerId); - var owner = await _userRepository.GetByIdAsync(ownerId); - if (owner == null) - { - throw new BadRequestException("Invalid owner."); - } - await SendProviderSetupInviteEmailAsync(provider, owner.Email); - } - - private async Task SendProviderSetupInviteEmailAsync(Provider provider, string ownerEmail) - { - var token = _dataProtector.Protect($"ProviderSetupInvite {provider.Id} {ownerEmail} {CoreHelpers.ToEpocMilliseconds(DateTime.UtcNow)}"); - await _mailService.SendProviderSetupInviteEmailAsync(provider, token, ownerEmail); - } - - public async Task LogProviderAccessToOrganizationAsync(Guid organizationId) - { - if (organizationId == default) - { - return; - } - - var providerOrganization = await _providerOrganizationRepository.GetByOrganizationId(organizationId); - var organization = await _organizationRepository.GetByIdAsync(organizationId); - if (providerOrganization != null) - { - await _eventService.LogProviderOrganizationEventAsync(providerOrganization, EventType.ProviderOrganization_VaultAccessed); - } - if (organization != null) - { - await _eventService.LogOrganizationEventAsync(organization, EventType.Organization_VaultAccessed); - } - } - - private async Task SendInviteAsync(ProviderUser providerUser, Provider provider) - { - var nowMillis = CoreHelpers.ToEpocMilliseconds(DateTime.UtcNow); - var token = _dataProtector.Protect( - $"ProviderUserInvite {providerUser.Id} {providerUser.Email} {nowMillis}"); - await _mailService.SendProviderInviteEmailAsync(provider.Name, providerUser, token, providerUser.Email); - } - - private async Task HasConfirmedProviderAdminExceptAsync(Guid providerId, IEnumerable providerUserIds) - { - var providerAdmins = await _providerUserRepository.GetManyByProviderAsync(providerId, - ProviderUserType.ProviderAdmin); - var confirmedOwners = providerAdmins.Where(o => o.Status == ProviderUserStatusType.Confirmed); - var confirmedOwnersIds = confirmedOwners.Select(u => u.Id); - return confirmedOwnersIds.Except(providerUserIds).Any(); - } - - private void ThrowOnInvalidPlanType(PlanType requestedType) - { - if (ProviderDisllowedOrganizationTypes.Contains(requestedType)) - { - throw new BadRequestException($"Providers cannot manage organizations with the requested plan type ({requestedType}). Only Teams and Enterprise accounts are allowed."); } } } diff --git a/bitwarden_license/src/Commercial.Core/Utilities/ServiceCollectionExtensions.cs b/bitwarden_license/src/Commercial.Core/Utilities/ServiceCollectionExtensions.cs index 5bb1a5bde0..4074fe5f74 100644 --- a/bitwarden_license/src/Commercial.Core/Utilities/ServiceCollectionExtensions.cs +++ b/bitwarden_license/src/Commercial.Core/Utilities/ServiceCollectionExtensions.cs @@ -2,12 +2,13 @@ using Bit.Core.Services; using Microsoft.Extensions.DependencyInjection; -namespace Bit.Commercial.Core.Utilities; - -public static class ServiceCollectionExtensions +namespace Bit.Commercial.Core.Utilities { - public static void AddCommCoreServices(this IServiceCollection services) + public static class ServiceCollectionExtensions { - services.AddScoped(); + public static void AddCommCoreServices(this IServiceCollection services) + { + services.AddScoped(); + } } } diff --git a/bitwarden_license/src/Scim/Context/IScimContext.cs b/bitwarden_license/src/Scim/Context/IScimContext.cs index 1e7010bd26..90e5aca3ab 100644 --- a/bitwarden_license/src/Scim/Context/IScimContext.cs +++ b/bitwarden_license/src/Scim/Context/IScimContext.cs @@ -4,17 +4,18 @@ using Bit.Core.Models.OrganizationConnectionConfigs; using Bit.Core.Repositories; using Bit.Core.Settings; -namespace Bit.Scim.Context; - -public interface IScimContext +namespace Bit.Scim.Context { - ScimProviderType RequestScimProvider { get; set; } - ScimConfig ScimConfiguration { get; set; } - Guid? OrganizationId { get; set; } - Organization Organization { get; set; } - Task BuildAsync( - HttpContext httpContext, - GlobalSettings globalSettings, - IOrganizationRepository organizationRepository, - IOrganizationConnectionRepository organizationConnectionRepository); + public interface IScimContext + { + ScimProviderType RequestScimProvider { get; set; } + ScimConfig ScimConfiguration { get; set; } + Guid? OrganizationId { get; set; } + Organization Organization { get; set; } + Task BuildAsync( + HttpContext httpContext, + GlobalSettings globalSettings, + IOrganizationRepository organizationRepository, + IOrganizationConnectionRepository organizationConnectionRepository); + } } diff --git a/bitwarden_license/src/Scim/Context/ScimContext.cs b/bitwarden_license/src/Scim/Context/ScimContext.cs index ae8d30807d..0e489d33d0 100644 --- a/bitwarden_license/src/Scim/Context/ScimContext.cs +++ b/bitwarden_license/src/Scim/Context/ScimContext.cs @@ -4,60 +4,61 @@ using Bit.Core.Models.OrganizationConnectionConfigs; using Bit.Core.Repositories; using Bit.Core.Settings; -namespace Bit.Scim.Context; - -public class ScimContext : IScimContext +namespace Bit.Scim.Context { - private bool _builtHttpContext; - - public ScimProviderType RequestScimProvider { get; set; } = ScimProviderType.Default; - public ScimConfig ScimConfiguration { get; set; } - public Guid? OrganizationId { get; set; } - public Organization Organization { get; set; } - - public async virtual Task BuildAsync( - HttpContext httpContext, - GlobalSettings globalSettings, - IOrganizationRepository organizationRepository, - IOrganizationConnectionRepository organizationConnectionRepository) + public class ScimContext : IScimContext { - if (_builtHttpContext) - { - return; - } + private bool _builtHttpContext; - _builtHttpContext = true; + public ScimProviderType RequestScimProvider { get; set; } = ScimProviderType.Default; + public ScimConfig ScimConfiguration { get; set; } + public Guid? OrganizationId { get; set; } + public Organization Organization { get; set; } - string orgIdString = null; - if (httpContext.Request.RouteValues.TryGetValue("organizationId", out var orgIdObject)) + public async virtual Task BuildAsync( + HttpContext httpContext, + GlobalSettings globalSettings, + IOrganizationRepository organizationRepository, + IOrganizationConnectionRepository organizationConnectionRepository) { - orgIdString = orgIdObject?.ToString(); - } - - if (Guid.TryParse(orgIdString, out var orgId)) - { - OrganizationId = orgId; - Organization = await organizationRepository.GetByIdAsync(orgId); - if (Organization != null) + if (_builtHttpContext) { - var scimConnections = await organizationConnectionRepository.GetByOrganizationIdTypeAsync(Organization.Id, - OrganizationConnectionType.Scim); - ScimConfiguration = scimConnections?.FirstOrDefault()?.GetConfig(); + return; } - } - if (RequestScimProvider == ScimProviderType.Default && - httpContext.Request.Headers.TryGetValue("User-Agent", out var userAgent)) - { - if (userAgent.ToString().StartsWith("Okta")) + _builtHttpContext = true; + + string orgIdString = null; + if (httpContext.Request.RouteValues.TryGetValue("organizationId", out var orgIdObject)) { - RequestScimProvider = ScimProviderType.Okta; + orgIdString = orgIdObject?.ToString(); + } + + if (Guid.TryParse(orgIdString, out var orgId)) + { + OrganizationId = orgId; + Organization = await organizationRepository.GetByIdAsync(orgId); + if (Organization != null) + { + var scimConnections = await organizationConnectionRepository.GetByOrganizationIdTypeAsync(Organization.Id, + OrganizationConnectionType.Scim); + ScimConfiguration = scimConnections?.FirstOrDefault()?.GetConfig(); + } + } + + if (RequestScimProvider == ScimProviderType.Default && + httpContext.Request.Headers.TryGetValue("User-Agent", out var userAgent)) + { + if (userAgent.ToString().StartsWith("Okta")) + { + RequestScimProvider = ScimProviderType.Okta; + } + } + if (RequestScimProvider == ScimProviderType.Default && + httpContext.Request.Headers.ContainsKey("Adscimversion")) + { + RequestScimProvider = ScimProviderType.AzureAd; } - } - if (RequestScimProvider == ScimProviderType.Default && - httpContext.Request.Headers.ContainsKey("Adscimversion")) - { - RequestScimProvider = ScimProviderType.AzureAd; } } } diff --git a/bitwarden_license/src/Scim/Controllers/InfoController.cs b/bitwarden_license/src/Scim/Controllers/InfoController.cs index aa08ce9bf7..67967ed374 100644 --- a/bitwarden_license/src/Scim/Controllers/InfoController.cs +++ b/bitwarden_license/src/Scim/Controllers/InfoController.cs @@ -2,21 +2,22 @@ using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; -namespace Bit.Scim.Controllers; - -[AllowAnonymous] -public class InfoController : Controller +namespace Bit.Scim.Controllers { - [HttpGet("~/alive")] - [HttpGet("~/now")] - public DateTime GetAlive() + [AllowAnonymous] + public class InfoController : Controller { - return DateTime.UtcNow; - } + [HttpGet("~/alive")] + [HttpGet("~/now")] + public DateTime GetAlive() + { + return DateTime.UtcNow; + } - [HttpGet("~/version")] - public JsonResult GetVersion() - { - return Json(CoreHelpers.GetVersion()); + [HttpGet("~/version")] + public JsonResult GetVersion() + { + return Json(CoreHelpers.GetVersion()); + } } } diff --git a/bitwarden_license/src/Scim/Controllers/v2/GroupsController.cs b/bitwarden_license/src/Scim/Controllers/v2/GroupsController.cs index 6fe47db87f..ff55c411d5 100644 --- a/bitwarden_license/src/Scim/Controllers/v2/GroupsController.cs +++ b/bitwarden_license/src/Scim/Controllers/v2/GroupsController.cs @@ -8,320 +8,321 @@ using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; using Microsoft.Extensions.Options; -namespace Bit.Scim.Controllers.v2; - -[Authorize("Scim")] -[Route("v2/{organizationId}/groups")] -public class GroupsController : Controller +namespace Bit.Scim.Controllers.v2 { - private readonly ScimSettings _scimSettings; - private readonly IGroupRepository _groupRepository; - private readonly IGroupService _groupService; - private readonly IScimContext _scimContext; - private readonly ILogger _logger; - - public GroupsController( - IGroupRepository groupRepository, - IGroupService groupService, - IOptions scimSettings, - IScimContext scimContext, - ILogger logger) + [Authorize("Scim")] + [Route("v2/{organizationId}/groups")] + public class GroupsController : Controller { - _scimSettings = scimSettings?.Value; - _groupRepository = groupRepository; - _groupService = groupService; - _scimContext = scimContext; - _logger = logger; - } + private readonly ScimSettings _scimSettings; + private readonly IGroupRepository _groupRepository; + private readonly IGroupService _groupService; + private readonly IScimContext _scimContext; + private readonly ILogger _logger; - [HttpGet("{id}")] - public async Task Get(Guid organizationId, Guid id) - { - var group = await _groupRepository.GetByIdAsync(id); - if (group == null || group.OrganizationId != organizationId) + public GroupsController( + IGroupRepository groupRepository, + IGroupService groupService, + IOptions scimSettings, + IScimContext scimContext, + ILogger logger) { - return new NotFoundObjectResult(new ScimErrorResponseModel + _scimSettings = scimSettings?.Value; + _groupRepository = groupRepository; + _groupService = groupService; + _scimContext = scimContext; + _logger = logger; + } + + [HttpGet("{id}")] + public async Task Get(Guid organizationId, Guid id) + { + var group = await _groupRepository.GetByIdAsync(id); + if (group == null || group.OrganizationId != organizationId) { - Status = 404, - Detail = "Group not found." - }); - } - return new ObjectResult(new ScimGroupResponseModel(group)); - } - - [HttpGet("")] - public async Task Get( - Guid organizationId, - [FromQuery] string filter, - [FromQuery] int? count, - [FromQuery] int? startIndex) - { - string nameFilter = null; - string externalIdFilter = null; - if (!string.IsNullOrWhiteSpace(filter)) - { - if (filter.StartsWith("displayName eq ")) - { - nameFilter = filter.Substring(15).Trim('"'); - } - else if (filter.StartsWith("externalId eq ")) - { - externalIdFilter = filter.Substring(14).Trim('"'); - } - } - - var groupList = new List(); - var groups = await _groupRepository.GetManyByOrganizationIdAsync(organizationId); - var totalResults = 0; - if (!string.IsNullOrWhiteSpace(nameFilter)) - { - var group = groups.FirstOrDefault(g => g.Name == nameFilter); - if (group != null) - { - groupList.Add(new ScimGroupResponseModel(group)); - } - totalResults = groupList.Count; - } - else if (!string.IsNullOrWhiteSpace(externalIdFilter)) - { - var group = groups.FirstOrDefault(ou => ou.ExternalId == externalIdFilter); - if (group != null) - { - groupList.Add(new ScimGroupResponseModel(group)); - } - totalResults = groupList.Count; - } - else if (string.IsNullOrWhiteSpace(filter) && startIndex.HasValue && count.HasValue) - { - groupList = groups.OrderBy(g => g.Name) - .Skip(startIndex.Value - 1) - .Take(count.Value) - .Select(g => new ScimGroupResponseModel(g)) - .ToList(); - totalResults = groups.Count; - } - - var result = new ScimListResponseModel - { - Resources = groupList, - ItemsPerPage = count.GetValueOrDefault(groupList.Count), - TotalResults = totalResults, - StartIndex = startIndex.GetValueOrDefault(1), - }; - return new ObjectResult(result); - } - - [HttpPost("")] - public async Task Post(Guid organizationId, [FromBody] ScimGroupRequestModel model) - { - if (string.IsNullOrWhiteSpace(model.DisplayName)) - { - return new BadRequestResult(); - } - - var groups = await _groupRepository.GetManyByOrganizationIdAsync(organizationId); - if (!string.IsNullOrWhiteSpace(model.ExternalId) && groups.Any(g => g.ExternalId == model.ExternalId)) - { - return new ConflictResult(); - } - - var group = model.ToGroup(organizationId); - await _groupService.SaveAsync(group, null); - await UpdateGroupMembersAsync(group, model, true); - var response = new ScimGroupResponseModel(group); - return new CreatedResult(Url.Action(nameof(Get), new { group.OrganizationId, group.Id }), response); - } - - [HttpPut("{id}")] - public async Task Put(Guid organizationId, Guid id, [FromBody] ScimGroupRequestModel model) - { - var group = await _groupRepository.GetByIdAsync(id); - if (group == null || group.OrganizationId != organizationId) - { - return new NotFoundObjectResult(new ScimErrorResponseModel - { - Status = 404, - Detail = "Group not found." - }); - } - - group.Name = model.DisplayName; - await _groupService.SaveAsync(group); - await UpdateGroupMembersAsync(group, model, false); - return new ObjectResult(new ScimGroupResponseModel(group)); - } - - [HttpPatch("{id}")] - public async Task Patch(Guid organizationId, Guid id, [FromBody] ScimPatchModel model) - { - var group = await _groupRepository.GetByIdAsync(id); - if (group == null || group.OrganizationId != organizationId) - { - return new NotFoundObjectResult(new ScimErrorResponseModel - { - Status = 404, - Detail = "Group not found." - }); - } - - var operationHandled = false; - foreach (var operation in model.Operations) - { - // Replace operations - if (operation.Op?.ToLowerInvariant() == "replace") - { - // Replace a list of members - if (operation.Path?.ToLowerInvariant() == "members") + return new NotFoundObjectResult(new ScimErrorResponseModel { - var ids = GetOperationValueIds(operation.Value); - await _groupRepository.UpdateUsersAsync(group.Id, ids); - operationHandled = true; + Status = 404, + Detail = "Group not found." + }); + } + return new ObjectResult(new ScimGroupResponseModel(group)); + } + + [HttpGet("")] + public async Task Get( + Guid organizationId, + [FromQuery] string filter, + [FromQuery] int? count, + [FromQuery] int? startIndex) + { + string nameFilter = null; + string externalIdFilter = null; + if (!string.IsNullOrWhiteSpace(filter)) + { + if (filter.StartsWith("displayName eq ")) + { + nameFilter = filter.Substring(15).Trim('"'); } - // Replace group name from path - else if (operation.Path?.ToLowerInvariant() == "displayname") + else if (filter.StartsWith("externalId eq ")) { - group.Name = operation.Value.GetString(); - await _groupService.SaveAsync(group); - operationHandled = true; - } - // Replace group name from value object - else if (string.IsNullOrWhiteSpace(operation.Path) && - operation.Value.TryGetProperty("displayName", out var displayNameProperty)) - { - group.Name = displayNameProperty.GetString(); - await _groupService.SaveAsync(group); - operationHandled = true; + externalIdFilter = filter.Substring(14).Trim('"'); } } - // Add a single member - else if (operation.Op?.ToLowerInvariant() == "add" && - !string.IsNullOrWhiteSpace(operation.Path) && - operation.Path.ToLowerInvariant().StartsWith("members[value eq ")) + + var groupList = new List(); + var groups = await _groupRepository.GetManyByOrganizationIdAsync(organizationId); + var totalResults = 0; + if (!string.IsNullOrWhiteSpace(nameFilter)) { - var addId = GetOperationPathId(operation.Path); - if (addId.HasValue) + var group = groups.FirstOrDefault(g => g.Name == nameFilter); + if (group != null) + { + groupList.Add(new ScimGroupResponseModel(group)); + } + totalResults = groupList.Count; + } + else if (!string.IsNullOrWhiteSpace(externalIdFilter)) + { + var group = groups.FirstOrDefault(ou => ou.ExternalId == externalIdFilter); + if (group != null) + { + groupList.Add(new ScimGroupResponseModel(group)); + } + totalResults = groupList.Count; + } + else if (string.IsNullOrWhiteSpace(filter) && startIndex.HasValue && count.HasValue) + { + groupList = groups.OrderBy(g => g.Name) + .Skip(startIndex.Value - 1) + .Take(count.Value) + .Select(g => new ScimGroupResponseModel(g)) + .ToList(); + totalResults = groups.Count; + } + + var result = new ScimListResponseModel + { + Resources = groupList, + ItemsPerPage = count.GetValueOrDefault(groupList.Count), + TotalResults = totalResults, + StartIndex = startIndex.GetValueOrDefault(1), + }; + return new ObjectResult(result); + } + + [HttpPost("")] + public async Task Post(Guid organizationId, [FromBody] ScimGroupRequestModel model) + { + if (string.IsNullOrWhiteSpace(model.DisplayName)) + { + return new BadRequestResult(); + } + + var groups = await _groupRepository.GetManyByOrganizationIdAsync(organizationId); + if (!string.IsNullOrWhiteSpace(model.ExternalId) && groups.Any(g => g.ExternalId == model.ExternalId)) + { + return new ConflictResult(); + } + + var group = model.ToGroup(organizationId); + await _groupService.SaveAsync(group, null); + await UpdateGroupMembersAsync(group, model, true); + var response = new ScimGroupResponseModel(group); + return new CreatedResult(Url.Action(nameof(Get), new { group.OrganizationId, group.Id }), response); + } + + [HttpPut("{id}")] + public async Task Put(Guid organizationId, Guid id, [FromBody] ScimGroupRequestModel model) + { + var group = await _groupRepository.GetByIdAsync(id); + if (group == null || group.OrganizationId != organizationId) + { + return new NotFoundObjectResult(new ScimErrorResponseModel + { + Status = 404, + Detail = "Group not found." + }); + } + + group.Name = model.DisplayName; + await _groupService.SaveAsync(group); + await UpdateGroupMembersAsync(group, model, false); + return new ObjectResult(new ScimGroupResponseModel(group)); + } + + [HttpPatch("{id}")] + public async Task Patch(Guid organizationId, Guid id, [FromBody] ScimPatchModel model) + { + var group = await _groupRepository.GetByIdAsync(id); + if (group == null || group.OrganizationId != organizationId) + { + return new NotFoundObjectResult(new ScimErrorResponseModel + { + Status = 404, + Detail = "Group not found." + }); + } + + var operationHandled = false; + foreach (var operation in model.Operations) + { + // Replace operations + if (operation.Op?.ToLowerInvariant() == "replace") + { + // Replace a list of members + if (operation.Path?.ToLowerInvariant() == "members") + { + var ids = GetOperationValueIds(operation.Value); + await _groupRepository.UpdateUsersAsync(group.Id, ids); + operationHandled = true; + } + // Replace group name from path + else if (operation.Path?.ToLowerInvariant() == "displayname") + { + group.Name = operation.Value.GetString(); + await _groupService.SaveAsync(group); + operationHandled = true; + } + // Replace group name from value object + else if (string.IsNullOrWhiteSpace(operation.Path) && + operation.Value.TryGetProperty("displayName", out var displayNameProperty)) + { + group.Name = displayNameProperty.GetString(); + await _groupService.SaveAsync(group); + operationHandled = true; + } + } + // Add a single member + else if (operation.Op?.ToLowerInvariant() == "add" && + !string.IsNullOrWhiteSpace(operation.Path) && + operation.Path.ToLowerInvariant().StartsWith("members[value eq ")) + { + var addId = GetOperationPathId(operation.Path); + if (addId.HasValue) + { + var orgUserIds = (await _groupRepository.GetManyUserIdsByIdAsync(group.Id)).ToHashSet(); + orgUserIds.Add(addId.Value); + await _groupRepository.UpdateUsersAsync(group.Id, orgUserIds); + operationHandled = true; + } + } + // Add a list of members + else if (operation.Op?.ToLowerInvariant() == "add" && + operation.Path?.ToLowerInvariant() == "members") { var orgUserIds = (await _groupRepository.GetManyUserIdsByIdAsync(group.Id)).ToHashSet(); - orgUserIds.Add(addId.Value); + foreach (var v in GetOperationValueIds(operation.Value)) + { + orgUserIds.Add(v); + } + await _groupRepository.UpdateUsersAsync(group.Id, orgUserIds); + operationHandled = true; + } + // Remove a single member + else if (operation.Op?.ToLowerInvariant() == "remove" && + !string.IsNullOrWhiteSpace(operation.Path) && + operation.Path.ToLowerInvariant().StartsWith("members[value eq ")) + { + var removeId = GetOperationPathId(operation.Path); + if (removeId.HasValue) + { + await _groupService.DeleteUserAsync(group, removeId.Value); + operationHandled = true; + } + } + // Remove a list of members + else if (operation.Op?.ToLowerInvariant() == "remove" && + operation.Path?.ToLowerInvariant() == "members") + { + var orgUserIds = (await _groupRepository.GetManyUserIdsByIdAsync(group.Id)).ToHashSet(); + foreach (var v in GetOperationValueIds(operation.Value)) + { + orgUserIds.Remove(v); + } await _groupRepository.UpdateUsersAsync(group.Id, orgUserIds); operationHandled = true; } } - // Add a list of members - else if (operation.Op?.ToLowerInvariant() == "add" && - operation.Path?.ToLowerInvariant() == "members") + + if (!operationHandled) { - var orgUserIds = (await _groupRepository.GetManyUserIdsByIdAsync(group.Id)).ToHashSet(); - foreach (var v in GetOperationValueIds(operation.Value)) - { - orgUserIds.Add(v); - } - await _groupRepository.UpdateUsersAsync(group.Id, orgUserIds); - operationHandled = true; + _logger.LogWarning("Group patch operation not handled: {0} : ", + string.Join(", ", model.Operations.Select(o => $"{o.Op}:{o.Path}"))); } - // Remove a single member - else if (operation.Op?.ToLowerInvariant() == "remove" && - !string.IsNullOrWhiteSpace(operation.Path) && - operation.Path.ToLowerInvariant().StartsWith("members[value eq ")) + + return new NoContentResult(); + } + + [HttpDelete("{id}")] + public async Task Delete(Guid organizationId, Guid id) + { + var group = await _groupRepository.GetByIdAsync(id); + if (group == null || group.OrganizationId != organizationId) { - var removeId = GetOperationPathId(operation.Path); - if (removeId.HasValue) + return new NotFoundObjectResult(new ScimErrorResponseModel { - await _groupService.DeleteUserAsync(group, removeId.Value); - operationHandled = true; + Status = 404, + Detail = "Group not found." + }); + } + await _groupService.DeleteAsync(group); + return new NoContentResult(); + } + + private List GetOperationValueIds(JsonElement objArray) + { + var ids = new List(); + foreach (var obj in objArray.EnumerateArray()) + { + if (obj.TryGetProperty("value", out var valueProperty)) + { + if (valueProperty.TryGetGuid(out var guid)) + { + ids.Add(guid); + } } } - // Remove a list of members - else if (operation.Op?.ToLowerInvariant() == "remove" && - operation.Path?.ToLowerInvariant() == "members") + return ids; + } + + private Guid? GetOperationPathId(string path) + { + // Parse Guid from string like: members[value eq "{GUID}"}] + if (Guid.TryParse(path.Substring(18).Replace("\"]", string.Empty), out var id)) { - var orgUserIds = (await _groupRepository.GetManyUserIdsByIdAsync(group.Id)).ToHashSet(); - foreach (var v in GetOperationValueIds(operation.Value)) - { - orgUserIds.Remove(v); - } - await _groupRepository.UpdateUsersAsync(group.Id, orgUserIds); - operationHandled = true; + return id; } + return null; } - if (!operationHandled) + private async Task UpdateGroupMembersAsync(Group group, ScimGroupRequestModel model, bool skipIfEmpty) { - _logger.LogWarning("Group patch operation not handled: {0} : ", - string.Join(", ", model.Operations.Select(o => $"{o.Op}:{o.Path}"))); - } - - return new NoContentResult(); - } - - [HttpDelete("{id}")] - public async Task Delete(Guid organizationId, Guid id) - { - var group = await _groupRepository.GetByIdAsync(id); - if (group == null || group.OrganizationId != organizationId) - { - return new NotFoundObjectResult(new ScimErrorResponseModel + if (_scimContext.RequestScimProvider != Core.Enums.ScimProviderType.Okta) { - Status = 404, - Detail = "Group not found." - }); - } - await _groupService.DeleteAsync(group); - return new NoContentResult(); - } + return; + } - private List GetOperationValueIds(JsonElement objArray) - { - var ids = new List(); - foreach (var obj in objArray.EnumerateArray()) - { - if (obj.TryGetProperty("value", out var valueProperty)) + if (model.Members == null) { - if (valueProperty.TryGetGuid(out var guid)) + return; + } + + var memberIds = new List(); + foreach (var id in model.Members.Select(i => i.Value)) + { + if (Guid.TryParse(id, out var guidId)) { - ids.Add(guid); + memberIds.Add(guidId); } } - } - return ids; - } - private Guid? GetOperationPathId(string path) - { - // Parse Guid from string like: members[value eq "{GUID}"}] - if (Guid.TryParse(path.Substring(18).Replace("\"]", string.Empty), out var id)) - { - return id; - } - return null; - } - - private async Task UpdateGroupMembersAsync(Group group, ScimGroupRequestModel model, bool skipIfEmpty) - { - if (_scimContext.RequestScimProvider != Core.Enums.ScimProviderType.Okta) - { - return; - } - - if (model.Members == null) - { - return; - } - - var memberIds = new List(); - foreach (var id in model.Members.Select(i => i.Value)) - { - if (Guid.TryParse(id, out var guidId)) + if (!memberIds.Any() && skipIfEmpty) { - memberIds.Add(guidId); + return; } - } - if (!memberIds.Any() && skipIfEmpty) - { - return; + await _groupRepository.UpdateUsersAsync(group.Id, memberIds); } - - await _groupRepository.UpdateUsersAsync(group.Id, memberIds); } } diff --git a/bitwarden_license/src/Scim/Controllers/v2/UsersController.cs b/bitwarden_license/src/Scim/Controllers/v2/UsersController.cs index 7291be7f67..ff650c64ee 100644 --- a/bitwarden_license/src/Scim/Controllers/v2/UsersController.cs +++ b/bitwarden_license/src/Scim/Controllers/v2/UsersController.cs @@ -9,286 +9,287 @@ using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; using Microsoft.Extensions.Options; -namespace Bit.Scim.Controllers.v2; - -[Authorize("Scim")] -[Route("v2/{organizationId}/users")] -public class UsersController : Controller +namespace Bit.Scim.Controllers.v2 { - private readonly IUserService _userService; - private readonly IUserRepository _userRepository; - private readonly IOrganizationUserRepository _organizationUserRepository; - private readonly IOrganizationService _organizationService; - private readonly IScimContext _scimContext; - private readonly ScimSettings _scimSettings; - private readonly ILogger _logger; - - public UsersController( - IUserService userService, - IUserRepository userRepository, - IOrganizationUserRepository organizationUserRepository, - IOrganizationService organizationService, - IScimContext scimContext, - IOptions scimSettings, - ILogger logger) + [Authorize("Scim")] + [Route("v2/{organizationId}/users")] + public class UsersController : Controller { - _userService = userService; - _userRepository = userRepository; - _organizationUserRepository = organizationUserRepository; - _organizationService = organizationService; - _scimContext = scimContext; - _scimSettings = scimSettings?.Value; - _logger = logger; - } + private readonly IUserService _userService; + private readonly IUserRepository _userRepository; + private readonly IOrganizationUserRepository _organizationUserRepository; + private readonly IOrganizationService _organizationService; + private readonly IScimContext _scimContext; + private readonly ScimSettings _scimSettings; + private readonly ILogger _logger; - [HttpGet("{id}")] - public async Task Get(Guid organizationId, Guid id) - { - var orgUser = await _organizationUserRepository.GetDetailsByIdAsync(id); - if (orgUser == null || orgUser.OrganizationId != organizationId) + public UsersController( + IUserService userService, + IUserRepository userRepository, + IOrganizationUserRepository organizationUserRepository, + IOrganizationService organizationService, + IScimContext scimContext, + IOptions scimSettings, + ILogger logger) { - return new NotFoundObjectResult(new ScimErrorResponseModel - { - Status = 404, - Detail = "User not found." - }); + _userService = userService; + _userRepository = userRepository; + _organizationUserRepository = organizationUserRepository; + _organizationService = organizationService; + _scimContext = scimContext; + _scimSettings = scimSettings?.Value; + _logger = logger; } - return new ObjectResult(new ScimUserResponseModel(orgUser)); - } - [HttpGet("")] - public async Task Get( - Guid organizationId, - [FromQuery] string filter, - [FromQuery] int? count, - [FromQuery] int? startIndex) - { - string emailFilter = null; - string usernameFilter = null; - string externalIdFilter = null; - if (!string.IsNullOrWhiteSpace(filter)) + [HttpGet("{id}")] + public async Task Get(Guid organizationId, Guid id) { - if (filter.StartsWith("userName eq ")) + var orgUser = await _organizationUserRepository.GetDetailsByIdAsync(id); + if (orgUser == null || orgUser.OrganizationId != organizationId) { - usernameFilter = filter.Substring(12).Trim('"').ToLowerInvariant(); - if (usernameFilter.Contains("@")) + return new NotFoundObjectResult(new ScimErrorResponseModel { - emailFilter = usernameFilter; - } + Status = 404, + Detail = "User not found." + }); } - else if (filter.StartsWith("externalId eq ")) + return new ObjectResult(new ScimUserResponseModel(orgUser)); + } + + [HttpGet("")] + public async Task Get( + Guid organizationId, + [FromQuery] string filter, + [FromQuery] int? count, + [FromQuery] int? startIndex) + { + string emailFilter = null; + string usernameFilter = null; + string externalIdFilter = null; + if (!string.IsNullOrWhiteSpace(filter)) { - externalIdFilter = filter.Substring(14).Trim('"'); - } - } - - var userList = new List { }; - var orgUsers = await _organizationUserRepository.GetManyDetailsByOrganizationAsync(organizationId); - var totalResults = 0; - if (!string.IsNullOrWhiteSpace(emailFilter)) - { - var orgUser = orgUsers.FirstOrDefault(ou => ou.Email.ToLowerInvariant() == emailFilter); - if (orgUser != null) - { - userList.Add(new ScimUserResponseModel(orgUser)); - } - totalResults = userList.Count; - } - else if (!string.IsNullOrWhiteSpace(externalIdFilter)) - { - var orgUser = orgUsers.FirstOrDefault(ou => ou.ExternalId == externalIdFilter); - if (orgUser != null) - { - userList.Add(new ScimUserResponseModel(orgUser)); - } - totalResults = userList.Count; - } - else if (string.IsNullOrWhiteSpace(filter) && startIndex.HasValue && count.HasValue) - { - userList = orgUsers.OrderBy(ou => ou.Email) - .Skip(startIndex.Value - 1) - .Take(count.Value) - .Select(ou => new ScimUserResponseModel(ou)) - .ToList(); - totalResults = orgUsers.Count; - } - - var result = new ScimListResponseModel - { - Resources = userList, - ItemsPerPage = count.GetValueOrDefault(userList.Count), - TotalResults = totalResults, - StartIndex = startIndex.GetValueOrDefault(1), - }; - return new ObjectResult(result); - } - - [HttpPost("")] - public async Task Post(Guid organizationId, [FromBody] ScimUserRequestModel model) - { - var email = model.PrimaryEmail?.ToLowerInvariant(); - if (string.IsNullOrWhiteSpace(email)) - { - switch (_scimContext.RequestScimProvider) - { - case ScimProviderType.AzureAd: - email = model.UserName?.ToLowerInvariant(); - break; - default: - email = model.WorkEmail?.ToLowerInvariant(); - if (string.IsNullOrWhiteSpace(email)) - { - email = model.Emails?.FirstOrDefault()?.Value?.ToLowerInvariant(); - } - break; - } - } - - if (string.IsNullOrWhiteSpace(email) || !model.Active) - { - return new BadRequestResult(); - } - - var orgUsers = await _organizationUserRepository.GetManyDetailsByOrganizationAsync(organizationId); - var orgUserByEmail = orgUsers.FirstOrDefault(ou => ou.Email?.ToLowerInvariant() == email); - if (orgUserByEmail != null) - { - return new ConflictResult(); - } - - string externalId = null; - if (!string.IsNullOrWhiteSpace(model.ExternalId)) - { - externalId = model.ExternalId; - } - else if (!string.IsNullOrWhiteSpace(model.UserName)) - { - externalId = model.UserName; - } - else - { - externalId = CoreHelpers.RandomString(15); - } - - var orgUserByExternalId = orgUsers.FirstOrDefault(ou => ou.ExternalId == externalId); - if (orgUserByExternalId != null) - { - return new ConflictResult(); - } - - var invitedOrgUser = await _organizationService.InviteUserAsync(organizationId, null, email, - OrganizationUserType.User, false, externalId, new List()); - var orgUser = await _organizationUserRepository.GetDetailsByIdAsync(invitedOrgUser.Id); - var response = new ScimUserResponseModel(orgUser); - return new CreatedResult(Url.Action(nameof(Get), new { orgUser.OrganizationId, orgUser.Id }), response); - } - - [HttpPut("{id}")] - public async Task Put(Guid organizationId, Guid id, [FromBody] ScimUserRequestModel model) - { - var orgUser = await _organizationUserRepository.GetByIdAsync(id); - if (orgUser == null || orgUser.OrganizationId != organizationId) - { - return new NotFoundObjectResult(new ScimErrorResponseModel - { - Status = 404, - Detail = "User not found." - }); - } - - if (model.Active && orgUser.Status == OrganizationUserStatusType.Revoked) - { - await _organizationService.RestoreUserAsync(orgUser, null, _userService); - } - else if (!model.Active && orgUser.Status != OrganizationUserStatusType.Revoked) - { - await _organizationService.RevokeUserAsync(orgUser, null); - } - - // Have to get full details object for response model - var orgUserDetails = await _organizationUserRepository.GetDetailsByIdAsync(id); - return new ObjectResult(new ScimUserResponseModel(orgUserDetails)); - } - - [HttpPatch("{id}")] - public async Task Patch(Guid organizationId, Guid id, [FromBody] ScimPatchModel model) - { - var orgUser = await _organizationUserRepository.GetByIdAsync(id); - if (orgUser == null || orgUser.OrganizationId != organizationId) - { - return new NotFoundObjectResult(new ScimErrorResponseModel - { - Status = 404, - Detail = "User not found." - }); - } - - var operationHandled = false; - foreach (var operation in model.Operations) - { - // Replace operations - if (operation.Op?.ToLowerInvariant() == "replace") - { - // Active from path - if (operation.Path?.ToLowerInvariant() == "active") + if (filter.StartsWith("userName eq ")) { - var active = operation.Value.ToString()?.ToLowerInvariant(); - var handled = await HandleActiveOperationAsync(orgUser, active == "true"); - if (!operationHandled) + usernameFilter = filter.Substring(12).Trim('"').ToLowerInvariant(); + if (usernameFilter.Contains("@")) { - operationHandled = handled; + emailFilter = usernameFilter; } } - // Active from value object - else if (string.IsNullOrWhiteSpace(operation.Path) && - operation.Value.TryGetProperty("active", out var activeProperty)) + else if (filter.StartsWith("externalId eq ")) { - var handled = await HandleActiveOperationAsync(orgUser, activeProperty.GetBoolean()); - if (!operationHandled) + externalIdFilter = filter.Substring(14).Trim('"'); + } + } + + var userList = new List { }; + var orgUsers = await _organizationUserRepository.GetManyDetailsByOrganizationAsync(organizationId); + var totalResults = 0; + if (!string.IsNullOrWhiteSpace(emailFilter)) + { + var orgUser = orgUsers.FirstOrDefault(ou => ou.Email.ToLowerInvariant() == emailFilter); + if (orgUser != null) + { + userList.Add(new ScimUserResponseModel(orgUser)); + } + totalResults = userList.Count; + } + else if (!string.IsNullOrWhiteSpace(externalIdFilter)) + { + var orgUser = orgUsers.FirstOrDefault(ou => ou.ExternalId == externalIdFilter); + if (orgUser != null) + { + userList.Add(new ScimUserResponseModel(orgUser)); + } + totalResults = userList.Count; + } + else if (string.IsNullOrWhiteSpace(filter) && startIndex.HasValue && count.HasValue) + { + userList = orgUsers.OrderBy(ou => ou.Email) + .Skip(startIndex.Value - 1) + .Take(count.Value) + .Select(ou => new ScimUserResponseModel(ou)) + .ToList(); + totalResults = orgUsers.Count; + } + + var result = new ScimListResponseModel + { + Resources = userList, + ItemsPerPage = count.GetValueOrDefault(userList.Count), + TotalResults = totalResults, + StartIndex = startIndex.GetValueOrDefault(1), + }; + return new ObjectResult(result); + } + + [HttpPost("")] + public async Task Post(Guid organizationId, [FromBody] ScimUserRequestModel model) + { + var email = model.PrimaryEmail?.ToLowerInvariant(); + if (string.IsNullOrWhiteSpace(email)) + { + switch (_scimContext.RequestScimProvider) + { + case ScimProviderType.AzureAd: + email = model.UserName?.ToLowerInvariant(); + break; + default: + email = model.WorkEmail?.ToLowerInvariant(); + if (string.IsNullOrWhiteSpace(email)) + { + email = model.Emails?.FirstOrDefault()?.Value?.ToLowerInvariant(); + } + break; + } + } + + if (string.IsNullOrWhiteSpace(email) || !model.Active) + { + return new BadRequestResult(); + } + + var orgUsers = await _organizationUserRepository.GetManyDetailsByOrganizationAsync(organizationId); + var orgUserByEmail = orgUsers.FirstOrDefault(ou => ou.Email?.ToLowerInvariant() == email); + if (orgUserByEmail != null) + { + return new ConflictResult(); + } + + string externalId = null; + if (!string.IsNullOrWhiteSpace(model.ExternalId)) + { + externalId = model.ExternalId; + } + else if (!string.IsNullOrWhiteSpace(model.UserName)) + { + externalId = model.UserName; + } + else + { + externalId = CoreHelpers.RandomString(15); + } + + var orgUserByExternalId = orgUsers.FirstOrDefault(ou => ou.ExternalId == externalId); + if (orgUserByExternalId != null) + { + return new ConflictResult(); + } + + var invitedOrgUser = await _organizationService.InviteUserAsync(organizationId, null, email, + OrganizationUserType.User, false, externalId, new List()); + var orgUser = await _organizationUserRepository.GetDetailsByIdAsync(invitedOrgUser.Id); + var response = new ScimUserResponseModel(orgUser); + return new CreatedResult(Url.Action(nameof(Get), new { orgUser.OrganizationId, orgUser.Id }), response); + } + + [HttpPut("{id}")] + public async Task Put(Guid organizationId, Guid id, [FromBody] ScimUserRequestModel model) + { + var orgUser = await _organizationUserRepository.GetByIdAsync(id); + if (orgUser == null || orgUser.OrganizationId != organizationId) + { + return new NotFoundObjectResult(new ScimErrorResponseModel + { + Status = 404, + Detail = "User not found." + }); + } + + if (model.Active && orgUser.Status == OrganizationUserStatusType.Revoked) + { + await _organizationService.RestoreUserAsync(orgUser, null, _userService); + } + else if (!model.Active && orgUser.Status != OrganizationUserStatusType.Revoked) + { + await _organizationService.RevokeUserAsync(orgUser, null); + } + + // Have to get full details object for response model + var orgUserDetails = await _organizationUserRepository.GetDetailsByIdAsync(id); + return new ObjectResult(new ScimUserResponseModel(orgUserDetails)); + } + + [HttpPatch("{id}")] + public async Task Patch(Guid organizationId, Guid id, [FromBody] ScimPatchModel model) + { + var orgUser = await _organizationUserRepository.GetByIdAsync(id); + if (orgUser == null || orgUser.OrganizationId != organizationId) + { + return new NotFoundObjectResult(new ScimErrorResponseModel + { + Status = 404, + Detail = "User not found." + }); + } + + var operationHandled = false; + foreach (var operation in model.Operations) + { + // Replace operations + if (operation.Op?.ToLowerInvariant() == "replace") + { + // Active from path + if (operation.Path?.ToLowerInvariant() == "active") { - operationHandled = handled; + var active = operation.Value.ToString()?.ToLowerInvariant(); + var handled = await HandleActiveOperationAsync(orgUser, active == "true"); + if (!operationHandled) + { + operationHandled = handled; + } + } + // Active from value object + else if (string.IsNullOrWhiteSpace(operation.Path) && + operation.Value.TryGetProperty("active", out var activeProperty)) + { + var handled = await HandleActiveOperationAsync(orgUser, activeProperty.GetBoolean()); + if (!operationHandled) + { + operationHandled = handled; + } } } } - } - if (!operationHandled) - { - _logger.LogWarning("User patch operation not handled: {operation} : ", - string.Join(", ", model.Operations.Select(o => $"{o.Op}:{o.Path}"))); - } - - return new NoContentResult(); - } - - [HttpDelete("{id}")] - public async Task Delete(Guid organizationId, Guid id, [FromBody] ScimUserRequestModel model) - { - var orgUser = await _organizationUserRepository.GetByIdAsync(id); - if (orgUser == null || orgUser.OrganizationId != organizationId) - { - return new NotFoundObjectResult(new ScimErrorResponseModel + if (!operationHandled) { - Status = 404, - Detail = "User not found." - }); - } - await _organizationService.DeleteUserAsync(organizationId, id, null); - return new NoContentResult(); - } + _logger.LogWarning("User patch operation not handled: {operation} : ", + string.Join(", ", model.Operations.Select(o => $"{o.Op}:{o.Path}"))); + } - private async Task HandleActiveOperationAsync(Core.Entities.OrganizationUser orgUser, bool active) - { - if (active && orgUser.Status == OrganizationUserStatusType.Revoked) - { - await _organizationService.RestoreUserAsync(orgUser, null, _userService); - return true; + return new NoContentResult(); } - else if (!active && orgUser.Status != OrganizationUserStatusType.Revoked) + + [HttpDelete("{id}")] + public async Task Delete(Guid organizationId, Guid id, [FromBody] ScimUserRequestModel model) { - await _organizationService.RevokeUserAsync(orgUser, null); - return true; + var orgUser = await _organizationUserRepository.GetByIdAsync(id); + if (orgUser == null || orgUser.OrganizationId != organizationId) + { + return new NotFoundObjectResult(new ScimErrorResponseModel + { + Status = 404, + Detail = "User not found." + }); + } + await _organizationService.DeleteUserAsync(organizationId, id, null); + return new NoContentResult(); + } + + private async Task HandleActiveOperationAsync(Core.Entities.OrganizationUser orgUser, bool active) + { + if (active && orgUser.Status == OrganizationUserStatusType.Revoked) + { + await _organizationService.RestoreUserAsync(orgUser, null, _userService); + return true; + } + else if (!active && orgUser.Status != OrganizationUserStatusType.Revoked) + { + await _organizationService.RevokeUserAsync(orgUser, null); + return true; + } + return false; } - return false; } } diff --git a/bitwarden_license/src/Scim/Models/BaseScimGroupModel.cs b/bitwarden_license/src/Scim/Models/BaseScimGroupModel.cs index 150885fb50..06d57bfad0 100644 --- a/bitwarden_license/src/Scim/Models/BaseScimGroupModel.cs +++ b/bitwarden_license/src/Scim/Models/BaseScimGroupModel.cs @@ -1,17 +1,18 @@ using Bit.Scim.Utilities; -namespace Bit.Scim.Models; - -public abstract class BaseScimGroupModel : BaseScimModel +namespace Bit.Scim.Models { - public BaseScimGroupModel(bool initSchema = false) + public abstract class BaseScimGroupModel : BaseScimModel { - if (initSchema) + public BaseScimGroupModel(bool initSchema = false) { - Schemas = new List { ScimConstants.Scim2SchemaGroup }; + if (initSchema) + { + Schemas = new List { ScimConstants.Scim2SchemaGroup }; + } } - } - public string DisplayName { get; set; } - public string ExternalId { get; set; } + public string DisplayName { get; set; } + public string ExternalId { get; set; } + } } diff --git a/bitwarden_license/src/Scim/Models/BaseScimModel.cs b/bitwarden_license/src/Scim/Models/BaseScimModel.cs index 8f3adfbe4a..a2a0717866 100644 --- a/bitwarden_license/src/Scim/Models/BaseScimModel.cs +++ b/bitwarden_license/src/Scim/Models/BaseScimModel.cs @@ -1,14 +1,15 @@ -namespace Bit.Scim.Models; - -public abstract class BaseScimModel +namespace Bit.Scim.Models { - public BaseScimModel() - { } - - public BaseScimModel(string schema) + public abstract class BaseScimModel { - Schemas = new List { schema }; - } + public BaseScimModel() + { } - public List Schemas { get; set; } + public BaseScimModel(string schema) + { + Schemas = new List { schema }; + } + + public List Schemas { get; set; } + } } diff --git a/bitwarden_license/src/Scim/Models/BaseScimUserModel.cs b/bitwarden_license/src/Scim/Models/BaseScimUserModel.cs index d3c69d574d..0af9e652b8 100644 --- a/bitwarden_license/src/Scim/Models/BaseScimUserModel.cs +++ b/bitwarden_license/src/Scim/Models/BaseScimUserModel.cs @@ -1,55 +1,56 @@ using Bit.Scim.Utilities; -namespace Bit.Scim.Models; - -public abstract class BaseScimUserModel : BaseScimModel +namespace Bit.Scim.Models { - public BaseScimUserModel(bool initSchema = false) + public abstract class BaseScimUserModel : BaseScimModel { - if (initSchema) + public BaseScimUserModel(bool initSchema = false) { - Schemas = new List { ScimConstants.Scim2SchemaUser }; - } - } - - public string UserName { get; set; } - public NameModel Name { get; set; } - public List Emails { get; set; } - public string PrimaryEmail => Emails?.FirstOrDefault(e => e.Primary)?.Value; - public string WorkEmail => Emails?.FirstOrDefault(e => e.Type == "work")?.Value; - public string DisplayName { get; set; } - public bool Active { get; set; } - public List Groups { get; set; } - public string ExternalId { get; set; } - - public class NameModel - { - public NameModel() { } - - public NameModel(string name) - { - Formatted = name; + if (initSchema) + { + Schemas = new List { ScimConstants.Scim2SchemaUser }; + } } - public string Formatted { get; set; } - public string GivenName { get; set; } - public string MiddleName { get; set; } - public string FamilyName { get; set; } - } + public string UserName { get; set; } + public NameModel Name { get; set; } + public List Emails { get; set; } + public string PrimaryEmail => Emails?.FirstOrDefault(e => e.Primary)?.Value; + public string WorkEmail => Emails?.FirstOrDefault(e => e.Type == "work")?.Value; + public string DisplayName { get; set; } + public bool Active { get; set; } + public List Groups { get; set; } + public string ExternalId { get; set; } - public class EmailModel - { - public EmailModel() { } - - public EmailModel(string email) + public class NameModel { - Primary = true; - Value = email; - Type = "work"; + public NameModel() { } + + public NameModel(string name) + { + Formatted = name; + } + + public string Formatted { get; set; } + public string GivenName { get; set; } + public string MiddleName { get; set; } + public string FamilyName { get; set; } } - public bool Primary { get; set; } - public string Value { get; set; } - public string Type { get; set; } + public class EmailModel + { + public EmailModel() { } + + public EmailModel(string email) + { + Primary = true; + Value = email; + Type = "work"; + } + + public bool Primary { get; set; } + public string Value { get; set; } + public string Type { get; set; } + } } } diff --git a/bitwarden_license/src/Scim/Models/ScimErrorResponseModel.cs b/bitwarden_license/src/Scim/Models/ScimErrorResponseModel.cs index d1dce35ef0..6055001f56 100644 --- a/bitwarden_license/src/Scim/Models/ScimErrorResponseModel.cs +++ b/bitwarden_license/src/Scim/Models/ScimErrorResponseModel.cs @@ -1,13 +1,14 @@ using Bit.Scim.Utilities; -namespace Bit.Scim.Models; - -public class ScimErrorResponseModel : BaseScimModel +namespace Bit.Scim.Models { - public ScimErrorResponseModel() - : base(ScimConstants.Scim2SchemaError) - { } + public class ScimErrorResponseModel : BaseScimModel + { + public ScimErrorResponseModel() + : base(ScimConstants.Scim2SchemaError) + { } - public string Detail { get; set; } - public int Status { get; set; } + public string Detail { get; set; } + public int Status { get; set; } + } } diff --git a/bitwarden_license/src/Scim/Models/ScimGroupRequestModel.cs b/bitwarden_license/src/Scim/Models/ScimGroupRequestModel.cs index ac99eca2e9..6de96655b0 100644 --- a/bitwarden_license/src/Scim/Models/ScimGroupRequestModel.cs +++ b/bitwarden_license/src/Scim/Models/ScimGroupRequestModel.cs @@ -1,30 +1,31 @@ using Bit.Core.Entities; using Bit.Core.Utilities; -namespace Bit.Scim.Models; - -public class ScimGroupRequestModel : BaseScimGroupModel +namespace Bit.Scim.Models { - public ScimGroupRequestModel() - : base(false) - { } - - public Group ToGroup(Guid organizationId) + public class ScimGroupRequestModel : BaseScimGroupModel { - var externalId = string.IsNullOrWhiteSpace(ExternalId) ? CoreHelpers.RandomString(15) : ExternalId; - return new Group + public ScimGroupRequestModel() + : base(false) + { } + + public Group ToGroup(Guid organizationId) { - Name = DisplayName, - ExternalId = externalId, - OrganizationId = organizationId - }; - } + var externalId = string.IsNullOrWhiteSpace(ExternalId) ? CoreHelpers.RandomString(15) : ExternalId; + return new Group + { + Name = DisplayName, + ExternalId = externalId, + OrganizationId = organizationId + }; + } - public List Members { get; set; } + public List Members { get; set; } - public class GroupMembersModel - { - public string Value { get; set; } - public string Display { get; set; } + public class GroupMembersModel + { + public string Value { get; set; } + public string Display { get; set; } + } } } diff --git a/bitwarden_license/src/Scim/Models/ScimGroupResponseModel.cs b/bitwarden_license/src/Scim/Models/ScimGroupResponseModel.cs index d5bd64a32c..df5d9b22a2 100644 --- a/bitwarden_license/src/Scim/Models/ScimGroupResponseModel.cs +++ b/bitwarden_license/src/Scim/Models/ScimGroupResponseModel.cs @@ -1,25 +1,26 @@ using Bit.Core.Entities; -namespace Bit.Scim.Models; - -public class ScimGroupResponseModel : BaseScimGroupModel +namespace Bit.Scim.Models { - public ScimGroupResponseModel() - : base(true) + public class ScimGroupResponseModel : BaseScimGroupModel { - Meta = new ScimMetaModel("Group"); - } + public ScimGroupResponseModel() + : base(true) + { + Meta = new ScimMetaModel("Group"); + } - public ScimGroupResponseModel(Group group) - : this() - { - Id = group.Id.ToString(); - DisplayName = group.Name; - ExternalId = group.ExternalId; - Meta.Created = group.CreationDate; - Meta.LastModified = group.RevisionDate; - } + public ScimGroupResponseModel(Group group) + : this() + { + Id = group.Id.ToString(); + DisplayName = group.Name; + ExternalId = group.ExternalId; + Meta.Created = group.CreationDate; + Meta.LastModified = group.RevisionDate; + } - public string Id { get; set; } - public ScimMetaModel Meta { get; private set; } + public string Id { get; set; } + public ScimMetaModel Meta { get; private set; } + } } diff --git a/bitwarden_license/src/Scim/Models/ScimListResponseModel.cs b/bitwarden_license/src/Scim/Models/ScimListResponseModel.cs index 77ab52356c..e7b9521680 100644 --- a/bitwarden_license/src/Scim/Models/ScimListResponseModel.cs +++ b/bitwarden_license/src/Scim/Models/ScimListResponseModel.cs @@ -1,15 +1,16 @@ using Bit.Scim.Utilities; -namespace Bit.Scim.Models; - -public class ScimListResponseModel : BaseScimModel +namespace Bit.Scim.Models { - public ScimListResponseModel() - : base(ScimConstants.Scim2SchemaListResponse) - { } + public class ScimListResponseModel : BaseScimModel + { + public ScimListResponseModel() + : base(ScimConstants.Scim2SchemaListResponse) + { } - public int TotalResults { get; set; } - public int StartIndex { get; set; } - public int ItemsPerPage { get; set; } - public List Resources { get; set; } + public int TotalResults { get; set; } + public int StartIndex { get; set; } + public int ItemsPerPage { get; set; } + public List Resources { get; set; } + } } diff --git a/bitwarden_license/src/Scim/Models/ScimMetaModel.cs b/bitwarden_license/src/Scim/Models/ScimMetaModel.cs index 862c054b79..f3d95f5f38 100644 --- a/bitwarden_license/src/Scim/Models/ScimMetaModel.cs +++ b/bitwarden_license/src/Scim/Models/ScimMetaModel.cs @@ -1,13 +1,14 @@ -namespace Bit.Scim.Models; - -public class ScimMetaModel +namespace Bit.Scim.Models { - public ScimMetaModel(string resourceType) + public class ScimMetaModel { - ResourceType = resourceType; - } + public ScimMetaModel(string resourceType) + { + ResourceType = resourceType; + } - public string ResourceType { get; set; } - public DateTime? Created { get; set; } - public DateTime? LastModified { get; set; } + public string ResourceType { get; set; } + public DateTime? Created { get; set; } + public DateTime? LastModified { get; set; } + } } diff --git a/bitwarden_license/src/Scim/Models/ScimPatchModel.cs b/bitwarden_license/src/Scim/Models/ScimPatchModel.cs index 6707ced85f..d421267656 100644 --- a/bitwarden_license/src/Scim/Models/ScimPatchModel.cs +++ b/bitwarden_license/src/Scim/Models/ScimPatchModel.cs @@ -1,18 +1,19 @@ using System.Text.Json; -namespace Bit.Scim.Models; - -public class ScimPatchModel : BaseScimModel +namespace Bit.Scim.Models { - public ScimPatchModel() - : base() { } - - public List Operations { get; set; } - - public class OperationModel + public class ScimPatchModel : BaseScimModel { - public string Op { get; set; } - public string Path { get; set; } - public JsonElement Value { get; set; } + public ScimPatchModel() + : base() { } + + public List Operations { get; set; } + + public class OperationModel + { + public string Op { get; set; } + public string Path { get; set; } + public JsonElement Value { get; set; } + } } } diff --git a/bitwarden_license/src/Scim/Models/ScimUserRequestModel.cs b/bitwarden_license/src/Scim/Models/ScimUserRequestModel.cs index a489e03adf..17f5e85931 100644 --- a/bitwarden_license/src/Scim/Models/ScimUserRequestModel.cs +++ b/bitwarden_license/src/Scim/Models/ScimUserRequestModel.cs @@ -1,8 +1,9 @@ -namespace Bit.Scim.Models; - -public class ScimUserRequestModel : BaseScimUserModel +namespace Bit.Scim.Models { - public ScimUserRequestModel() - : base(false) - { } + public class ScimUserRequestModel : BaseScimUserModel + { + public ScimUserRequestModel() + : base(false) + { } + } } diff --git a/bitwarden_license/src/Scim/Models/ScimUserResponseModel.cs b/bitwarden_license/src/Scim/Models/ScimUserResponseModel.cs index 95d5184daf..6f96506616 100644 --- a/bitwarden_license/src/Scim/Models/ScimUserResponseModel.cs +++ b/bitwarden_license/src/Scim/Models/ScimUserResponseModel.cs @@ -1,28 +1,29 @@ using Bit.Core.Models.Data.Organizations.OrganizationUsers; -namespace Bit.Scim.Models; - -public class ScimUserResponseModel : BaseScimUserModel +namespace Bit.Scim.Models { - public ScimUserResponseModel() - : base(true) + public class ScimUserResponseModel : BaseScimUserModel { - Meta = new ScimMetaModel("User"); - Groups = new List(); - } + public ScimUserResponseModel() + : base(true) + { + Meta = new ScimMetaModel("User"); + Groups = new List(); + } - public ScimUserResponseModel(OrganizationUserUserDetails orgUser) - : this() - { - Id = orgUser.Id.ToString(); - ExternalId = orgUser.ExternalId; - UserName = orgUser.Email; - DisplayName = orgUser.Name; - Emails = new List { new EmailModel(orgUser.Email) }; - Name = new NameModel(orgUser.Name); - Active = orgUser.Status != Core.Enums.OrganizationUserStatusType.Revoked; - } + public ScimUserResponseModel(OrganizationUserUserDetails orgUser) + : this() + { + Id = orgUser.Id.ToString(); + ExternalId = orgUser.ExternalId; + UserName = orgUser.Email; + DisplayName = orgUser.Name; + Emails = new List { new EmailModel(orgUser.Email) }; + Name = new NameModel(orgUser.Name); + Active = orgUser.Status != Core.Enums.OrganizationUserStatusType.Revoked; + } - public string Id { get; set; } - public ScimMetaModel Meta { get; private set; } + public string Id { get; set; } + public ScimMetaModel Meta { get; private set; } + } } diff --git a/bitwarden_license/src/Scim/Program.cs b/bitwarden_license/src/Scim/Program.cs index 48d5711e15..f8d6cb15b0 100644 --- a/bitwarden_license/src/Scim/Program.cs +++ b/bitwarden_license/src/Scim/Program.cs @@ -1,33 +1,34 @@ using Bit.Core.Utilities; using Serilog.Events; -namespace Bit.Scim; - -public class Program +namespace Bit.Scim { - public static void Main(string[] args) + public class Program { - Host - .CreateDefaultBuilder(args) - .ConfigureWebHostDefaults(webBuilder => - { - webBuilder.UseStartup(); - webBuilder.ConfigureLogging((hostingContext, logging) => - logging.AddSerilog(hostingContext, e => - { - var context = e.Properties["SourceContext"].ToString(); - - if (e.Properties.ContainsKey("RequestPath") && - !string.IsNullOrWhiteSpace(e.Properties["RequestPath"]?.ToString()) && - (context.Contains(".Server.Kestrel") || context.Contains(".Core.IISHttpServer"))) + public static void Main(string[] args) + { + Host + .CreateDefaultBuilder(args) + .ConfigureWebHostDefaults(webBuilder => + { + webBuilder.UseStartup(); + webBuilder.ConfigureLogging((hostingContext, logging) => + logging.AddSerilog(hostingContext, e => { - return false; - } + var context = e.Properties["SourceContext"].ToString(); - return e.Level >= LogEventLevel.Warning; - })); - }) - .Build() - .Run(); + if (e.Properties.ContainsKey("RequestPath") && + !string.IsNullOrWhiteSpace(e.Properties["RequestPath"]?.ToString()) && + (context.Contains(".Server.Kestrel") || context.Contains(".Core.IISHttpServer"))) + { + return false; + } + + return e.Level >= LogEventLevel.Warning; + })); + }) + .Build() + .Run(); + } } } diff --git a/bitwarden_license/src/Scim/ScimSettings.cs b/bitwarden_license/src/Scim/ScimSettings.cs index ef4ebfb501..5c25dbf378 100644 --- a/bitwarden_license/src/Scim/ScimSettings.cs +++ b/bitwarden_license/src/Scim/ScimSettings.cs @@ -1,5 +1,6 @@ -namespace Bit.Scim; - -public class ScimSettings +namespace Bit.Scim { + public class ScimSettings + { + } } diff --git a/bitwarden_license/src/Scim/Startup.cs b/bitwarden_license/src/Scim/Startup.cs index 65e9220a77..daa5752e98 100644 --- a/bitwarden_license/src/Scim/Startup.cs +++ b/bitwarden_license/src/Scim/Startup.cs @@ -9,107 +9,108 @@ using IdentityModel; using Microsoft.Extensions.DependencyInjection.Extensions; using Stripe; -namespace Bit.Scim; - -public class Startup +namespace Bit.Scim { - public Startup(IWebHostEnvironment env, IConfiguration configuration) + public class Startup { - CultureInfo.DefaultThreadCurrentCulture = new CultureInfo("en-US"); - Configuration = configuration; - Environment = env; - } - - public IConfiguration Configuration { get; } - public IWebHostEnvironment Environment { get; set; } - - public void ConfigureServices(IServiceCollection services) - { - // Options - services.AddOptions(); - - // Settings - var globalSettings = services.AddGlobalSettingsServices(Configuration, Environment); - services.Configure(Configuration.GetSection("ScimSettings")); - - // Data Protection - services.AddCustomDataProtectionServices(Environment, globalSettings); - - // Stripe Billing - StripeConfiguration.ApiKey = globalSettings.Stripe.ApiKey; - StripeConfiguration.MaxNetworkRetries = globalSettings.Stripe.MaxNetworkRetries; - - // Repositories - services.AddSqlServerRepositories(globalSettings); - - // Context - services.AddScoped(); - services.AddScoped(); - - // Authentication - services.AddAuthentication(ApiKeyAuthenticationOptions.DefaultScheme) - .AddScheme( - ApiKeyAuthenticationOptions.DefaultScheme, null); - - services.AddAuthorization(config => + public Startup(IWebHostEnvironment env, IConfiguration configuration) { - config.AddPolicy("Scim", policy => - { - policy.RequireAuthenticatedUser(); - policy.RequireClaim(JwtClaimTypes.Scope, "api.scim"); - }); - }); - - // Identity - services.AddCustomIdentityServices(globalSettings); - - // Services - services.AddBaseServices(globalSettings); - services.AddDefaultServices(globalSettings); - - services.TryAddSingleton(); - - // Mvc - services.AddMvc(config => - { - config.Filters.Add(new LoggingExceptionHandlerFilterAttribute()); - }); - services.Configure(options => options.LowercaseUrls = true); - } - - public void Configure( - IApplicationBuilder app, - IWebHostEnvironment env, - IHostApplicationLifetime appLifetime, - GlobalSettings globalSettings) - { - app.UseSerilog(env, appLifetime, globalSettings); - - // Add general security headers - app.UseMiddleware(); - - if (env.IsDevelopment()) - { - app.UseDeveloperExceptionPage(); + CultureInfo.DefaultThreadCurrentCulture = new CultureInfo("en-US"); + Configuration = configuration; + Environment = env; } - // Default Middleware - app.UseDefaultMiddleware(env, globalSettings); + public IConfiguration Configuration { get; } + public IWebHostEnvironment Environment { get; set; } - // Add routing - app.UseRouting(); + public void ConfigureServices(IServiceCollection services) + { + // Options + services.AddOptions(); - // Add Scim context - app.UseMiddleware(); + // Settings + var globalSettings = services.AddGlobalSettingsServices(Configuration, Environment); + services.Configure(Configuration.GetSection("ScimSettings")); - // Add authentication and authorization to the request pipeline. - app.UseAuthentication(); - app.UseAuthorization(); + // Data Protection + services.AddCustomDataProtectionServices(Environment, globalSettings); - // Add current context - app.UseMiddleware(); + // Stripe Billing + StripeConfiguration.ApiKey = globalSettings.Stripe.ApiKey; + StripeConfiguration.MaxNetworkRetries = globalSettings.Stripe.MaxNetworkRetries; - // Add MVC to the request pipeline. - app.UseEndpoints(endpoints => endpoints.MapDefaultControllerRoute()); + // Repositories + services.AddSqlServerRepositories(globalSettings); + + // Context + services.AddScoped(); + services.AddScoped(); + + // Authentication + services.AddAuthentication(ApiKeyAuthenticationOptions.DefaultScheme) + .AddScheme( + ApiKeyAuthenticationOptions.DefaultScheme, null); + + services.AddAuthorization(config => + { + config.AddPolicy("Scim", policy => + { + policy.RequireAuthenticatedUser(); + policy.RequireClaim(JwtClaimTypes.Scope, "api.scim"); + }); + }); + + // Identity + services.AddCustomIdentityServices(globalSettings); + + // Services + services.AddBaseServices(globalSettings); + services.AddDefaultServices(globalSettings); + + services.TryAddSingleton(); + + // Mvc + services.AddMvc(config => + { + config.Filters.Add(new LoggingExceptionHandlerFilterAttribute()); + }); + services.Configure(options => options.LowercaseUrls = true); + } + + public void Configure( + IApplicationBuilder app, + IWebHostEnvironment env, + IHostApplicationLifetime appLifetime, + GlobalSettings globalSettings) + { + app.UseSerilog(env, appLifetime, globalSettings); + + // Add general security headers + app.UseMiddleware(); + + if (env.IsDevelopment()) + { + app.UseDeveloperExceptionPage(); + } + + // Default Middleware + app.UseDefaultMiddleware(env, globalSettings); + + // Add routing + app.UseRouting(); + + // Add Scim context + app.UseMiddleware(); + + // Add authentication and authorization to the request pipeline. + app.UseAuthentication(); + app.UseAuthorization(); + + // Add current context + app.UseMiddleware(); + + // Add MVC to the request pipeline. + app.UseEndpoints(endpoints => endpoints.MapDefaultControllerRoute()); + } } } diff --git a/bitwarden_license/src/Scim/Utilities/ApiKeyAuthenticationHandler.cs b/bitwarden_license/src/Scim/Utilities/ApiKeyAuthenticationHandler.cs index 4e7e7ceb7a..c1b08b1b9e 100644 --- a/bitwarden_license/src/Scim/Utilities/ApiKeyAuthenticationHandler.cs +++ b/bitwarden_license/src/Scim/Utilities/ApiKeyAuthenticationHandler.cs @@ -8,82 +8,83 @@ using Microsoft.AspNetCore.Authentication; using Microsoft.AspNetCore.Authorization; using Microsoft.Extensions.Options; -namespace Bit.Scim.Utilities; - -public class ApiKeyAuthenticationHandler : AuthenticationHandler +namespace Bit.Scim.Utilities { - private readonly IOrganizationRepository _organizationRepository; - private readonly IOrganizationApiKeyRepository _organizationApiKeyRepository; - private readonly IScimContext _scimContext; - - public ApiKeyAuthenticationHandler( - IOptionsMonitor options, - ILoggerFactory logger, - UrlEncoder encoder, - ISystemClock clock, - IOrganizationRepository organizationRepository, - IOrganizationApiKeyRepository organizationApiKeyRepository, - IScimContext scimContext) : - base(options, logger, encoder, clock) + public class ApiKeyAuthenticationHandler : AuthenticationHandler { - _organizationRepository = organizationRepository; - _organizationApiKeyRepository = organizationApiKeyRepository; - _scimContext = scimContext; - } + private readonly IOrganizationRepository _organizationRepository; + private readonly IOrganizationApiKeyRepository _organizationApiKeyRepository; + private readonly IScimContext _scimContext; - protected override async Task HandleAuthenticateAsync() - { - var endpoint = Context.GetEndpoint(); - if (endpoint?.Metadata?.GetMetadata() != null) + public ApiKeyAuthenticationHandler( + IOptionsMonitor options, + ILoggerFactory logger, + UrlEncoder encoder, + ISystemClock clock, + IOrganizationRepository organizationRepository, + IOrganizationApiKeyRepository organizationApiKeyRepository, + IScimContext scimContext) : + base(options, logger, encoder, clock) { - return AuthenticateResult.NoResult(); + _organizationRepository = organizationRepository; + _organizationApiKeyRepository = organizationApiKeyRepository; + _scimContext = scimContext; } - if (!_scimContext.OrganizationId.HasValue || _scimContext.Organization == null) + protected override async Task HandleAuthenticateAsync() { - Logger.LogWarning("No organization."); - return AuthenticateResult.Fail("Invalid parameters"); + var endpoint = Context.GetEndpoint(); + if (endpoint?.Metadata?.GetMetadata() != null) + { + return AuthenticateResult.NoResult(); + } + + if (!_scimContext.OrganizationId.HasValue || _scimContext.Organization == null) + { + Logger.LogWarning("No organization."); + return AuthenticateResult.Fail("Invalid parameters"); + } + + if (!Request.Headers.TryGetValue("Authorization", out var authHeader) || authHeader.Count != 1) + { + Logger.LogWarning("An API request was received without the Authorization header"); + return AuthenticateResult.Fail("Invalid parameters"); + } + var apiKey = authHeader.ToString(); + if (apiKey.StartsWith("Bearer ")) + { + apiKey = apiKey.Substring(7); + } + + if (!_scimContext.Organization.Enabled || !_scimContext.Organization.UseScim || + _scimContext.ScimConfiguration == null || !_scimContext.ScimConfiguration.Enabled) + { + Logger.LogInformation("Org {organizationId} not able to use Scim.", _scimContext.OrganizationId); + return AuthenticateResult.Fail("Invalid parameters"); + } + + var orgApiKey = (await _organizationApiKeyRepository + .GetManyByOrganizationIdTypeAsync(_scimContext.Organization.Id, OrganizationApiKeyType.Scim)) + .FirstOrDefault(); + if (orgApiKey?.ApiKey != apiKey) + { + Logger.LogWarning("An API request was received with an invalid API key: {apiKey}", apiKey); + return AuthenticateResult.Fail("Invalid parameters"); + } + + Logger.LogInformation("Org {organizationId} authenticated", _scimContext.OrganizationId); + + var claims = new[] + { + new Claim(JwtClaimTypes.ClientId, $"organization.{_scimContext.OrganizationId.Value}"), + new Claim("client_sub", _scimContext.OrganizationId.Value.ToString()), + new Claim(JwtClaimTypes.Scope, "api.scim"), + }; + var identity = new ClaimsIdentity(claims, nameof(ApiKeyAuthenticationHandler)); + var ticket = new AuthenticationTicket(new ClaimsPrincipal(identity), + ApiKeyAuthenticationOptions.DefaultScheme); + + return AuthenticateResult.Success(ticket); } - - if (!Request.Headers.TryGetValue("Authorization", out var authHeader) || authHeader.Count != 1) - { - Logger.LogWarning("An API request was received without the Authorization header"); - return AuthenticateResult.Fail("Invalid parameters"); - } - var apiKey = authHeader.ToString(); - if (apiKey.StartsWith("Bearer ")) - { - apiKey = apiKey.Substring(7); - } - - if (!_scimContext.Organization.Enabled || !_scimContext.Organization.UseScim || - _scimContext.ScimConfiguration == null || !_scimContext.ScimConfiguration.Enabled) - { - Logger.LogInformation("Org {organizationId} not able to use Scim.", _scimContext.OrganizationId); - return AuthenticateResult.Fail("Invalid parameters"); - } - - var orgApiKey = (await _organizationApiKeyRepository - .GetManyByOrganizationIdTypeAsync(_scimContext.Organization.Id, OrganizationApiKeyType.Scim)) - .FirstOrDefault(); - if (orgApiKey?.ApiKey != apiKey) - { - Logger.LogWarning("An API request was received with an invalid API key: {apiKey}", apiKey); - return AuthenticateResult.Fail("Invalid parameters"); - } - - Logger.LogInformation("Org {organizationId} authenticated", _scimContext.OrganizationId); - - var claims = new[] - { - new Claim(JwtClaimTypes.ClientId, $"organization.{_scimContext.OrganizationId.Value}"), - new Claim("client_sub", _scimContext.OrganizationId.Value.ToString()), - new Claim(JwtClaimTypes.Scope, "api.scim"), - }; - var identity = new ClaimsIdentity(claims, nameof(ApiKeyAuthenticationHandler)); - var ticket = new AuthenticationTicket(new ClaimsPrincipal(identity), - ApiKeyAuthenticationOptions.DefaultScheme); - - return AuthenticateResult.Success(ticket); } } diff --git a/bitwarden_license/src/Scim/Utilities/ApiKeyAuthenticationOptions.cs b/bitwarden_license/src/Scim/Utilities/ApiKeyAuthenticationOptions.cs index f0015226b2..7d2bb3e81f 100644 --- a/bitwarden_license/src/Scim/Utilities/ApiKeyAuthenticationOptions.cs +++ b/bitwarden_license/src/Scim/Utilities/ApiKeyAuthenticationOptions.cs @@ -1,8 +1,9 @@ using Microsoft.AspNetCore.Authentication; -namespace Bit.Scim.Utilities; - -public class ApiKeyAuthenticationOptions : AuthenticationSchemeOptions +namespace Bit.Scim.Utilities { - public const string DefaultScheme = "ScimApiKey"; + public class ApiKeyAuthenticationOptions : AuthenticationSchemeOptions + { + public const string DefaultScheme = "ScimApiKey"; + } } diff --git a/bitwarden_license/src/Scim/Utilities/ScimConstants.cs b/bitwarden_license/src/Scim/Utilities/ScimConstants.cs index 219be6534f..4c9d11f6cd 100644 --- a/bitwarden_license/src/Scim/Utilities/ScimConstants.cs +++ b/bitwarden_license/src/Scim/Utilities/ScimConstants.cs @@ -1,9 +1,10 @@ -namespace Bit.Scim.Utilities; - -public static class ScimConstants +namespace Bit.Scim.Utilities { - public const string Scim2SchemaListResponse = "urn:ietf:params:scim:api:messages:2.0:ListResponse"; - public const string Scim2SchemaError = "urn:ietf:params:scim:api:messages:2.0:Error"; - public const string Scim2SchemaUser = "urn:ietf:params:scim:schemas:core:2.0:User"; - public const string Scim2SchemaGroup = "urn:ietf:params:scim:schemas:core:2.0:Group"; + public static class ScimConstants + { + public const string Scim2SchemaListResponse = "urn:ietf:params:scim:api:messages:2.0:ListResponse"; + public const string Scim2SchemaError = "urn:ietf:params:scim:api:messages:2.0:Error"; + public const string Scim2SchemaUser = "urn:ietf:params:scim:schemas:core:2.0:User"; + public const string Scim2SchemaGroup = "urn:ietf:params:scim:schemas:core:2.0:Group"; + } } diff --git a/bitwarden_license/src/Scim/Utilities/ScimContextMiddleware.cs b/bitwarden_license/src/Scim/Utilities/ScimContextMiddleware.cs index 6d5f3e1bf2..9550814de3 100644 --- a/bitwarden_license/src/Scim/Utilities/ScimContextMiddleware.cs +++ b/bitwarden_license/src/Scim/Utilities/ScimContextMiddleware.cs @@ -2,21 +2,22 @@ using Bit.Core.Settings; using Bit.Scim.Context; -namespace Bit.Scim.Utilities; - -public class ScimContextMiddleware +namespace Bit.Scim.Utilities { - private readonly RequestDelegate _next; - - public ScimContextMiddleware(RequestDelegate next) + public class ScimContextMiddleware { - _next = next; - } + private readonly RequestDelegate _next; - public async Task Invoke(HttpContext httpContext, IScimContext scimContext, GlobalSettings globalSettings, - IOrganizationRepository organizationRepository, IOrganizationConnectionRepository organizationConnectionRepository) - { - await scimContext.BuildAsync(httpContext, globalSettings, organizationRepository, organizationConnectionRepository); - await _next.Invoke(httpContext); + public ScimContextMiddleware(RequestDelegate next) + { + _next = next; + } + + public async Task Invoke(HttpContext httpContext, IScimContext scimContext, GlobalSettings globalSettings, + IOrganizationRepository organizationRepository, IOrganizationConnectionRepository organizationConnectionRepository) + { + await scimContext.BuildAsync(httpContext, globalSettings, organizationRepository, organizationConnectionRepository); + await _next.Invoke(httpContext); + } } } diff --git a/bitwarden_license/src/Sso/Controllers/AccountController.cs b/bitwarden_license/src/Sso/Controllers/AccountController.cs index 4ab9d7ef03..fbbf3084dc 100644 --- a/bitwarden_license/src/Sso/Controllers/AccountController.cs +++ b/bitwarden_license/src/Sso/Controllers/AccountController.cs @@ -22,688 +22,689 @@ using Microsoft.AspNetCore.Authentication; using Microsoft.AspNetCore.Identity; using Microsoft.AspNetCore.Mvc; -namespace Bit.Sso.Controllers; - -public class AccountController : Controller +namespace Bit.Sso.Controllers { - private readonly IAuthenticationSchemeProvider _schemeProvider; - private readonly IClientStore _clientStore; - - private readonly IIdentityServerInteractionService _interaction; - private readonly ILogger _logger; - private readonly IOrganizationRepository _organizationRepository; - private readonly IOrganizationUserRepository _organizationUserRepository; - private readonly IOrganizationService _organizationService; - private readonly ISsoConfigRepository _ssoConfigRepository; - private readonly ISsoUserRepository _ssoUserRepository; - private readonly IUserRepository _userRepository; - private readonly IPolicyRepository _policyRepository; - private readonly IUserService _userService; - private readonly II18nService _i18nService; - private readonly UserManager _userManager; - private readonly IGlobalSettings _globalSettings; - private readonly Core.Services.IEventService _eventService; - private readonly IDataProtectorTokenFactory _dataProtector; - - public AccountController( - IAuthenticationSchemeProvider schemeProvider, - IClientStore clientStore, - IIdentityServerInteractionService interaction, - ILogger logger, - IOrganizationRepository organizationRepository, - IOrganizationUserRepository organizationUserRepository, - IOrganizationService organizationService, - ISsoConfigRepository ssoConfigRepository, - ISsoUserRepository ssoUserRepository, - IUserRepository userRepository, - IPolicyRepository policyRepository, - IUserService userService, - II18nService i18nService, - UserManager userManager, - IGlobalSettings globalSettings, - Core.Services.IEventService eventService, - IDataProtectorTokenFactory dataProtector) + public class AccountController : Controller { - _schemeProvider = schemeProvider; - _clientStore = clientStore; - _interaction = interaction; - _logger = logger; - _organizationRepository = organizationRepository; - _organizationUserRepository = organizationUserRepository; - _organizationService = organizationService; - _userRepository = userRepository; - _ssoConfigRepository = ssoConfigRepository; - _ssoUserRepository = ssoUserRepository; - _policyRepository = policyRepository; - _userService = userService; - _i18nService = i18nService; - _userManager = userManager; - _eventService = eventService; - _globalSettings = globalSettings; - _dataProtector = dataProtector; - } + private readonly IAuthenticationSchemeProvider _schemeProvider; + private readonly IClientStore _clientStore; - [HttpGet] - public async Task PreValidate(string domainHint) - { - try + private readonly IIdentityServerInteractionService _interaction; + private readonly ILogger _logger; + private readonly IOrganizationRepository _organizationRepository; + private readonly IOrganizationUserRepository _organizationUserRepository; + private readonly IOrganizationService _organizationService; + private readonly ISsoConfigRepository _ssoConfigRepository; + private readonly ISsoUserRepository _ssoUserRepository; + private readonly IUserRepository _userRepository; + private readonly IPolicyRepository _policyRepository; + private readonly IUserService _userService; + private readonly II18nService _i18nService; + private readonly UserManager _userManager; + private readonly IGlobalSettings _globalSettings; + private readonly Core.Services.IEventService _eventService; + private readonly IDataProtectorTokenFactory _dataProtector; + + public AccountController( + IAuthenticationSchemeProvider schemeProvider, + IClientStore clientStore, + IIdentityServerInteractionService interaction, + ILogger logger, + IOrganizationRepository organizationRepository, + IOrganizationUserRepository organizationUserRepository, + IOrganizationService organizationService, + ISsoConfigRepository ssoConfigRepository, + ISsoUserRepository ssoUserRepository, + IUserRepository userRepository, + IPolicyRepository policyRepository, + IUserService userService, + II18nService i18nService, + UserManager userManager, + IGlobalSettings globalSettings, + Core.Services.IEventService eventService, + IDataProtectorTokenFactory dataProtector) { - // Validate domain_hint provided - if (string.IsNullOrWhiteSpace(domainHint)) + _schemeProvider = schemeProvider; + _clientStore = clientStore; + _interaction = interaction; + _logger = logger; + _organizationRepository = organizationRepository; + _organizationUserRepository = organizationUserRepository; + _organizationService = organizationService; + _userRepository = userRepository; + _ssoConfigRepository = ssoConfigRepository; + _ssoUserRepository = ssoUserRepository; + _policyRepository = policyRepository; + _userService = userService; + _i18nService = i18nService; + _userManager = userManager; + _eventService = eventService; + _globalSettings = globalSettings; + _dataProtector = dataProtector; + } + + [HttpGet] + public async Task PreValidate(string domainHint) + { + try { - return InvalidJson("NoOrganizationIdentifierProvidedError"); + // Validate domain_hint provided + if (string.IsNullOrWhiteSpace(domainHint)) + { + return InvalidJson("NoOrganizationIdentifierProvidedError"); + } + + // Validate organization exists from domain_hint + var organization = await _organizationRepository.GetByIdentifierAsync(domainHint); + if (organization == null) + { + return InvalidJson("OrganizationNotFoundByIdentifierError"); + } + if (!organization.UseSso) + { + return InvalidJson("SsoNotAllowedForOrganizationError"); + } + + // Validate SsoConfig exists and is Enabled + var ssoConfig = await _ssoConfigRepository.GetByIdentifierAsync(domainHint); + if (ssoConfig == null) + { + return InvalidJson("SsoConfigurationNotFoundForOrganizationError"); + } + if (!ssoConfig.Enabled) + { + return InvalidJson("SsoNotEnabledForOrganizationError"); + } + + // Validate Authentication Scheme exists and is loaded (cache) + var scheme = await _schemeProvider.GetSchemeAsync(organization.Id.ToString()); + if (scheme == null || !(scheme is IDynamicAuthenticationScheme dynamicScheme)) + { + return InvalidJson("NoSchemeOrHandlerForSsoConfigurationFoundError"); + } + + // Run scheme validation + try + { + await dynamicScheme.Validate(); + } + catch (Exception ex) + { + var translatedException = _i18nService.GetLocalizedHtmlString(ex.Message); + var errorKey = "InvalidSchemeConfigurationError"; + if (!translatedException.ResourceNotFound) + { + errorKey = ex.Message; + } + return InvalidJson(errorKey, translatedException.ResourceNotFound ? ex : null); + } + + var tokenable = new SsoTokenable(organization, _globalSettings.Sso.SsoTokenLifetimeInSeconds); + var token = _dataProtector.Protect(tokenable); + + return new SsoPreValidateResponseModel(token); + } + catch (Exception ex) + { + return InvalidJson("PreValidationError", ex); + } + } + + [HttpGet] + public async Task Login(string returnUrl) + { + var context = await _interaction.GetAuthorizationContextAsync(returnUrl); + + if (!context.Parameters.AllKeys.Contains("domain_hint") || + string.IsNullOrWhiteSpace(context.Parameters["domain_hint"])) + { + throw new Exception(_i18nService.T("NoDomainHintProvided")); } - // Validate organization exists from domain_hint + var ssoToken = context.Parameters[SsoTokenable.TokenIdentifier]; + + if (string.IsNullOrWhiteSpace(ssoToken)) + { + return Unauthorized("A valid SSO token is required to continue with SSO login"); + } + + var domainHint = context.Parameters["domain_hint"]; var organization = await _organizationRepository.GetByIdentifierAsync(domainHint); + if (organization == null) { return InvalidJson("OrganizationNotFoundByIdentifierError"); } - if (!organization.UseSso) + + var tokenable = _dataProtector.Unprotect(ssoToken); + + if (!tokenable.TokenIsValid(organization)) { - return InvalidJson("SsoNotAllowedForOrganizationError"); + return Unauthorized("The SSO token associated with your request is expired. A valid SSO token is required to continue."); } - // Validate SsoConfig exists and is Enabled - var ssoConfig = await _ssoConfigRepository.GetByIdentifierAsync(domainHint); - if (ssoConfig == null) + return RedirectToAction(nameof(ExternalChallenge), new { - return InvalidJson("SsoConfigurationNotFoundForOrganizationError"); - } - if (!ssoConfig.Enabled) + scheme = organization.Id.ToString(), + returnUrl, + state = context.Parameters["state"], + userIdentifier = context.Parameters["session_state"], + }); + } + + [HttpGet] + public IActionResult ExternalChallenge(string scheme, string returnUrl, string state, string userIdentifier) + { + if (string.IsNullOrEmpty(returnUrl)) { - return InvalidJson("SsoNotEnabledForOrganizationError"); + returnUrl = "~/"; } - // Validate Authentication Scheme exists and is loaded (cache) - var scheme = await _schemeProvider.GetSchemeAsync(organization.Id.ToString()); - if (scheme == null || !(scheme is IDynamicAuthenticationScheme dynamicScheme)) + if (!Url.IsLocalUrl(returnUrl) && !_interaction.IsValidReturnUrl(returnUrl)) { - return InvalidJson("NoSchemeOrHandlerForSsoConfigurationFoundError"); + throw new Exception(_i18nService.T("InvalidReturnUrl")); } - // Run scheme validation - try + var props = new AuthenticationProperties { - await dynamicScheme.Validate(); - } - catch (Exception ex) - { - var translatedException = _i18nService.GetLocalizedHtmlString(ex.Message); - var errorKey = "InvalidSchemeConfigurationError"; - if (!translatedException.ResourceNotFound) + RedirectUri = Url.Action(nameof(ExternalCallback)), + Items = { - errorKey = ex.Message; + // scheme will get serialized into `State` and returned back + { "scheme", scheme }, + { "return_url", returnUrl }, + { "state", state }, + { "user_identifier", userIdentifier }, } - return InvalidJson(errorKey, translatedException.ResourceNotFound ? ex : null); - } - - var tokenable = new SsoTokenable(organization, _globalSettings.Sso.SsoTokenLifetimeInSeconds); - var token = _dataProtector.Protect(tokenable); - - return new SsoPreValidateResponseModel(token); - } - catch (Exception ex) - { - return InvalidJson("PreValidationError", ex); - } - } - - [HttpGet] - public async Task Login(string returnUrl) - { - var context = await _interaction.GetAuthorizationContextAsync(returnUrl); - - if (!context.Parameters.AllKeys.Contains("domain_hint") || - string.IsNullOrWhiteSpace(context.Parameters["domain_hint"])) - { - throw new Exception(_i18nService.T("NoDomainHintProvided")); - } - - var ssoToken = context.Parameters[SsoTokenable.TokenIdentifier]; - - if (string.IsNullOrWhiteSpace(ssoToken)) - { - return Unauthorized("A valid SSO token is required to continue with SSO login"); - } - - var domainHint = context.Parameters["domain_hint"]; - var organization = await _organizationRepository.GetByIdentifierAsync(domainHint); - - if (organization == null) - { - return InvalidJson("OrganizationNotFoundByIdentifierError"); - } - - var tokenable = _dataProtector.Unprotect(ssoToken); - - if (!tokenable.TokenIsValid(organization)) - { - return Unauthorized("The SSO token associated with your request is expired. A valid SSO token is required to continue."); - } - - return RedirectToAction(nameof(ExternalChallenge), new - { - scheme = organization.Id.ToString(), - returnUrl, - state = context.Parameters["state"], - userIdentifier = context.Parameters["session_state"], - }); - } - - [HttpGet] - public IActionResult ExternalChallenge(string scheme, string returnUrl, string state, string userIdentifier) - { - if (string.IsNullOrEmpty(returnUrl)) - { - returnUrl = "~/"; - } - - if (!Url.IsLocalUrl(returnUrl) && !_interaction.IsValidReturnUrl(returnUrl)) - { - throw new Exception(_i18nService.T("InvalidReturnUrl")); - } - - var props = new AuthenticationProperties - { - RedirectUri = Url.Action(nameof(ExternalCallback)), - Items = - { - // scheme will get serialized into `State` and returned back - { "scheme", scheme }, - { "return_url", returnUrl }, - { "state", state }, - { "user_identifier", userIdentifier }, - } - }; - - return Challenge(props, scheme); - } - - [HttpGet] - public async Task ExternalCallback() - { - // Read external identity from the temporary cookie - var result = await HttpContext.AuthenticateAsync( - AuthenticationSchemes.BitwardenExternalCookieAuthenticationScheme); - if (result?.Succeeded != true) - { - throw new Exception(_i18nService.T("ExternalAuthenticationError")); - } - - // Debugging - var externalClaims = result.Principal.Claims.Select(c => $"{c.Type}: {c.Value}"); - _logger.LogDebug("External claims: {@claims}", externalClaims); - - // Lookup our user and external provider info - var (user, provider, providerUserId, claims, ssoConfigData) = await FindUserFromExternalProviderAsync(result); - if (user == null) - { - // This might be where you might initiate a custom workflow for user registration - // in this sample we don't show how that would be done, as our sample implementation - // simply auto-provisions new external user - var userIdentifier = result.Properties.Items.Keys.Contains("user_identifier") ? - result.Properties.Items["user_identifier"] : null; - user = await AutoProvisionUserAsync(provider, providerUserId, claims, userIdentifier, ssoConfigData); - } - - if (user != null) - { - // This allows us to collect any additional claims or properties - // for the specific protocols used and store them in the local auth cookie. - // this is typically used to store data needed for signout from those protocols. - var additionalLocalClaims = new List(); - var localSignInProps = new AuthenticationProperties - { - IsPersistent = true, - ExpiresUtc = DateTimeOffset.UtcNow.AddMinutes(1) }; - ProcessLoginCallback(result, additionalLocalClaims, localSignInProps); - // Issue authentication cookie for user - await HttpContext.SignInAsync(new IdentityServerUser(user.Id.ToString()) - { - DisplayName = user.Email, - IdentityProvider = provider, - AdditionalClaims = additionalLocalClaims.ToArray() - }, localSignInProps); + return Challenge(props, scheme); } - // Delete temporary cookie used during external authentication - await HttpContext.SignOutAsync(AuthenticationSchemes.BitwardenExternalCookieAuthenticationScheme); - - // Retrieve return URL - var returnUrl = result.Properties.Items["return_url"] ?? "~/"; - - // Check if external login is in the context of an OIDC request - var context = await _interaction.GetAuthorizationContextAsync(returnUrl); - if (context != null) + [HttpGet] + public async Task ExternalCallback() { - if (IsNativeClient(context)) + // Read external identity from the temporary cookie + var result = await HttpContext.AuthenticateAsync( + AuthenticationSchemes.BitwardenExternalCookieAuthenticationScheme); + if (result?.Succeeded != true) { - // The client is native, so this change in how to - // return the response is for better UX for the end user. - HttpContext.Response.StatusCode = 200; - HttpContext.Response.Headers["Location"] = string.Empty; - return View("Redirect", new RedirectViewModel { RedirectUrl = returnUrl }); + throw new Exception(_i18nService.T("ExternalAuthenticationError")); } - } - return Redirect(returnUrl); - } + // Debugging + var externalClaims = result.Principal.Claims.Select(c => $"{c.Type}: {c.Value}"); + _logger.LogDebug("External claims: {@claims}", externalClaims); - [HttpGet] - public async Task Logout(string logoutId) - { - // Build a model so the logged out page knows what to display - var (updatedLogoutId, redirectUri, externalAuthenticationScheme) = await GetLoggedOutDataAsync(logoutId); - - if (User?.Identity.IsAuthenticated == true) - { - // Delete local authentication cookie - await HttpContext.SignOutAsync(); - } - - // HACK: Temporary workaroud for the time being that doesn't try to sign out of OneLogin schemes, - // which doesnt support SLO - if (externalAuthenticationScheme != null && !externalAuthenticationScheme.Contains("onelogin")) - { - // Build a return URL so the upstream provider will redirect back - // to us after the user has logged out. this allows us to then - // complete our single sign-out processing. - var url = Url.Action("Logout", new { logoutId = updatedLogoutId }); - - // This triggers a redirect to the external provider for sign-out - return SignOut(new AuthenticationProperties { RedirectUri = url }, externalAuthenticationScheme); - } - if (redirectUri != null) - { - return View("Redirect", new RedirectViewModel { RedirectUrl = redirectUri }); - } - else - { - return Redirect("~/"); - } - } - - private async Task<(User user, string provider, string providerUserId, IEnumerable claims, SsoConfigurationData config)> - FindUserFromExternalProviderAsync(AuthenticateResult result) - { - var provider = result.Properties.Items["scheme"]; - var orgId = new Guid(provider); - var ssoConfig = await _ssoConfigRepository.GetByOrganizationIdAsync(orgId); - if (ssoConfig == null || !ssoConfig.Enabled) - { - throw new Exception(_i18nService.T("OrganizationOrSsoConfigNotFound")); - } - - var ssoConfigData = ssoConfig.GetData(); - var externalUser = result.Principal; - - // Validate acr claim against expectation before going further - if (!string.IsNullOrWhiteSpace(ssoConfigData.ExpectedReturnAcrValue)) - { - var acrClaim = externalUser.FindFirst(JwtClaimTypes.AuthenticationContextClassReference); - if (acrClaim?.Value != ssoConfigData.ExpectedReturnAcrValue) + // Lookup our user and external provider info + var (user, provider, providerUserId, claims, ssoConfigData) = await FindUserFromExternalProviderAsync(result); + if (user == null) { - throw new Exception(_i18nService.T("AcrMissingOrInvalid")); + // This might be where you might initiate a custom workflow for user registration + // in this sample we don't show how that would be done, as our sample implementation + // simply auto-provisions new external user + var userIdentifier = result.Properties.Items.Keys.Contains("user_identifier") ? + result.Properties.Items["user_identifier"] : null; + user = await AutoProvisionUserAsync(provider, providerUserId, claims, userIdentifier, ssoConfigData); } - } - // Ensure the NameIdentifier used is not a transient name ID, if so, we need a different attribute - // for the user identifier. - static bool nameIdIsNotTransient(Claim c) => c.Type == ClaimTypes.NameIdentifier - && (c.Properties == null - || !c.Properties.ContainsKey(SamlPropertyKeys.ClaimFormat) - || c.Properties[SamlPropertyKeys.ClaimFormat] != SamlNameIdFormats.Transient); - - // Try to determine the unique id of the external user (issued by the provider) - // the most common claim type for that are the sub claim and the NameIdentifier - // depending on the external provider, some other claim type might be used - var customUserIdClaimTypes = ssoConfigData.GetAdditionalUserIdClaimTypes(); - var userIdClaim = externalUser.FindFirst(c => customUserIdClaimTypes.Contains(c.Type)) ?? - externalUser.FindFirst(JwtClaimTypes.Subject) ?? - externalUser.FindFirst(nameIdIsNotTransient) ?? - // Some SAML providers may use the `uid` attribute for this - // where a transient NameID has been sent in the subject - externalUser.FindFirst("uid") ?? - externalUser.FindFirst("upn") ?? - externalUser.FindFirst("eppn") ?? - throw new Exception(_i18nService.T("UnknownUserId")); - - // Remove the user id claim so we don't include it as an extra claim if/when we provision the user - var claims = externalUser.Claims.ToList(); - claims.Remove(userIdClaim); - - // find external user - var providerUserId = userIdClaim.Value; - - var user = await _userRepository.GetBySsoUserAsync(providerUserId, orgId); - - return (user, provider, providerUserId, claims, ssoConfigData); - } - - private async Task AutoProvisionUserAsync(string provider, string providerUserId, - IEnumerable claims, string userIdentifier, SsoConfigurationData config) - { - var name = GetName(claims, config.GetAdditionalNameClaimTypes()); - var email = GetEmailAddress(claims, config.GetAdditionalEmailClaimTypes()); - if (string.IsNullOrWhiteSpace(email) && providerUserId.Contains("@")) - { - email = providerUserId; - } - - if (!Guid.TryParse(provider, out var orgId)) - { - // TODO: support non-org (server-wide) SSO in the future? - throw new Exception(_i18nService.T("SSOProviderIsNotAnOrgId", provider)); - } - - User existingUser = null; - if (string.IsNullOrWhiteSpace(userIdentifier)) - { - if (string.IsNullOrWhiteSpace(email)) + if (user != null) { - throw new Exception(_i18nService.T("CannotFindEmailClaim")); - } - existingUser = await _userRepository.GetByEmailAsync(email); - } - else - { - var split = userIdentifier.Split(","); - if (split.Length < 2) - { - throw new Exception(_i18nService.T("InvalidUserIdentifier")); - } - var userId = split[0]; - var token = split[1]; - - var tokenOptions = new TokenOptions(); - - var claimedUser = await _userService.GetUserByIdAsync(userId); - if (claimedUser != null) - { - var tokenIsValid = await _userManager.VerifyUserTokenAsync( - claimedUser, tokenOptions.PasswordResetTokenProvider, TokenPurposes.LinkSso, token); - if (tokenIsValid) + // This allows us to collect any additional claims or properties + // for the specific protocols used and store them in the local auth cookie. + // this is typically used to store data needed for signout from those protocols. + var additionalLocalClaims = new List(); + var localSignInProps = new AuthenticationProperties { - existingUser = claimedUser; - } - else + IsPersistent = true, + ExpiresUtc = DateTimeOffset.UtcNow.AddMinutes(1) + }; + ProcessLoginCallback(result, additionalLocalClaims, localSignInProps); + + // Issue authentication cookie for user + await HttpContext.SignInAsync(new IdentityServerUser(user.Id.ToString()) { - throw new Exception(_i18nService.T("UserIdAndTokenMismatch")); + DisplayName = user.Email, + IdentityProvider = provider, + AdditionalClaims = additionalLocalClaims.ToArray() + }, localSignInProps); + } + + // Delete temporary cookie used during external authentication + await HttpContext.SignOutAsync(AuthenticationSchemes.BitwardenExternalCookieAuthenticationScheme); + + // Retrieve return URL + var returnUrl = result.Properties.Items["return_url"] ?? "~/"; + + // Check if external login is in the context of an OIDC request + var context = await _interaction.GetAuthorizationContextAsync(returnUrl); + if (context != null) + { + if (IsNativeClient(context)) + { + // The client is native, so this change in how to + // return the response is for better UX for the end user. + HttpContext.Response.StatusCode = 200; + HttpContext.Response.Headers["Location"] = string.Empty; + return View("Redirect", new RedirectViewModel { RedirectUrl = returnUrl }); } } + + return Redirect(returnUrl); } - OrganizationUser orgUser = null; - var organization = await _organizationRepository.GetByIdAsync(orgId); - if (organization == null) + [HttpGet] + public async Task Logout(string logoutId) { - throw new Exception(_i18nService.T("CouldNotFindOrganization", orgId)); - } + // Build a model so the logged out page knows what to display + var (updatedLogoutId, redirectUri, externalAuthenticationScheme) = await GetLoggedOutDataAsync(logoutId); - // Try to find OrgUser via existing User Id (accepted/confirmed user) - if (existingUser != null) - { - var orgUsersByUserId = await _organizationUserRepository.GetManyByUserAsync(existingUser.Id); - orgUser = orgUsersByUserId.SingleOrDefault(u => u.OrganizationId == orgId); - } - - // If no Org User found by Existing User Id - search all organization users via email - orgUser ??= await _organizationUserRepository.GetByOrganizationEmailAsync(orgId, email); - - // All Existing User flows handled below - if (existingUser != null) - { - if (existingUser.UsesKeyConnector && - (orgUser == null || orgUser.Status == OrganizationUserStatusType.Invited)) + if (User?.Identity.IsAuthenticated == true) { - throw new Exception(_i18nService.T("UserAlreadyExistsKeyConnector")); + // Delete local authentication cookie + await HttpContext.SignOutAsync(); } + // HACK: Temporary workaroud for the time being that doesn't try to sign out of OneLogin schemes, + // which doesnt support SLO + if (externalAuthenticationScheme != null && !externalAuthenticationScheme.Contains("onelogin")) + { + // Build a return URL so the upstream provider will redirect back + // to us after the user has logged out. this allows us to then + // complete our single sign-out processing. + var url = Url.Action("Logout", new { logoutId = updatedLogoutId }); + + // This triggers a redirect to the external provider for sign-out + return SignOut(new AuthenticationProperties { RedirectUri = url }, externalAuthenticationScheme); + } + if (redirectUri != null) + { + return View("Redirect", new RedirectViewModel { RedirectUrl = redirectUri }); + } + else + { + return Redirect("~/"); + } + } + + private async Task<(User user, string provider, string providerUserId, IEnumerable claims, SsoConfigurationData config)> + FindUserFromExternalProviderAsync(AuthenticateResult result) + { + var provider = result.Properties.Items["scheme"]; + var orgId = new Guid(provider); + var ssoConfig = await _ssoConfigRepository.GetByOrganizationIdAsync(orgId); + if (ssoConfig == null || !ssoConfig.Enabled) + { + throw new Exception(_i18nService.T("OrganizationOrSsoConfigNotFound")); + } + + var ssoConfigData = ssoConfig.GetData(); + var externalUser = result.Principal; + + // Validate acr claim against expectation before going further + if (!string.IsNullOrWhiteSpace(ssoConfigData.ExpectedReturnAcrValue)) + { + var acrClaim = externalUser.FindFirst(JwtClaimTypes.AuthenticationContextClassReference); + if (acrClaim?.Value != ssoConfigData.ExpectedReturnAcrValue) + { + throw new Exception(_i18nService.T("AcrMissingOrInvalid")); + } + } + + // Ensure the NameIdentifier used is not a transient name ID, if so, we need a different attribute + // for the user identifier. + static bool nameIdIsNotTransient(Claim c) => c.Type == ClaimTypes.NameIdentifier + && (c.Properties == null + || !c.Properties.ContainsKey(SamlPropertyKeys.ClaimFormat) + || c.Properties[SamlPropertyKeys.ClaimFormat] != SamlNameIdFormats.Transient); + + // Try to determine the unique id of the external user (issued by the provider) + // the most common claim type for that are the sub claim and the NameIdentifier + // depending on the external provider, some other claim type might be used + var customUserIdClaimTypes = ssoConfigData.GetAdditionalUserIdClaimTypes(); + var userIdClaim = externalUser.FindFirst(c => customUserIdClaimTypes.Contains(c.Type)) ?? + externalUser.FindFirst(JwtClaimTypes.Subject) ?? + externalUser.FindFirst(nameIdIsNotTransient) ?? + // Some SAML providers may use the `uid` attribute for this + // where a transient NameID has been sent in the subject + externalUser.FindFirst("uid") ?? + externalUser.FindFirst("upn") ?? + externalUser.FindFirst("eppn") ?? + throw new Exception(_i18nService.T("UnknownUserId")); + + // Remove the user id claim so we don't include it as an extra claim if/when we provision the user + var claims = externalUser.Claims.ToList(); + claims.Remove(userIdClaim); + + // find external user + var providerUserId = userIdClaim.Value; + + var user = await _userRepository.GetBySsoUserAsync(providerUserId, orgId); + + return (user, provider, providerUserId, claims, ssoConfigData); + } + + private async Task AutoProvisionUserAsync(string provider, string providerUserId, + IEnumerable claims, string userIdentifier, SsoConfigurationData config) + { + var name = GetName(claims, config.GetAdditionalNameClaimTypes()); + var email = GetEmailAddress(claims, config.GetAdditionalEmailClaimTypes()); + if (string.IsNullOrWhiteSpace(email) && providerUserId.Contains("@")) + { + email = providerUserId; + } + + if (!Guid.TryParse(provider, out var orgId)) + { + // TODO: support non-org (server-wide) SSO in the future? + throw new Exception(_i18nService.T("SSOProviderIsNotAnOrgId", provider)); + } + + User existingUser = null; + if (string.IsNullOrWhiteSpace(userIdentifier)) + { + if (string.IsNullOrWhiteSpace(email)) + { + throw new Exception(_i18nService.T("CannotFindEmailClaim")); + } + existingUser = await _userRepository.GetByEmailAsync(email); + } + else + { + var split = userIdentifier.Split(","); + if (split.Length < 2) + { + throw new Exception(_i18nService.T("InvalidUserIdentifier")); + } + var userId = split[0]; + var token = split[1]; + + var tokenOptions = new TokenOptions(); + + var claimedUser = await _userService.GetUserByIdAsync(userId); + if (claimedUser != null) + { + var tokenIsValid = await _userManager.VerifyUserTokenAsync( + claimedUser, tokenOptions.PasswordResetTokenProvider, TokenPurposes.LinkSso, token); + if (tokenIsValid) + { + existingUser = claimedUser; + } + else + { + throw new Exception(_i18nService.T("UserIdAndTokenMismatch")); + } + } + } + + OrganizationUser orgUser = null; + var organization = await _organizationRepository.GetByIdAsync(orgId); + if (organization == null) + { + throw new Exception(_i18nService.T("CouldNotFindOrganization", orgId)); + } + + // Try to find OrgUser via existing User Id (accepted/confirmed user) + if (existingUser != null) + { + var orgUsersByUserId = await _organizationUserRepository.GetManyByUserAsync(existingUser.Id); + orgUser = orgUsersByUserId.SingleOrDefault(u => u.OrganizationId == orgId); + } + + // If no Org User found by Existing User Id - search all organization users via email + orgUser ??= await _organizationUserRepository.GetByOrganizationEmailAsync(orgId, email); + + // All Existing User flows handled below + if (existingUser != null) + { + if (existingUser.UsesKeyConnector && + (orgUser == null || orgUser.Status == OrganizationUserStatusType.Invited)) + { + throw new Exception(_i18nService.T("UserAlreadyExistsKeyConnector")); + } + + if (orgUser == null) + { + // Org User is not created - no invite has been sent + throw new Exception(_i18nService.T("UserAlreadyExistsInviteProcess")); + } + + if (orgUser.Status == OrganizationUserStatusType.Invited) + { + // Org User is invited - they must manually accept the invite via email and authenticate with MP + throw new Exception(_i18nService.T("UserAlreadyInvited", email, organization.Name)); + } + + // Accepted or Confirmed - create SSO link and return; + await CreateSsoUserRecord(providerUserId, existingUser.Id, orgId, orgUser); + return existingUser; + } + + // Before any user creation - if Org User doesn't exist at this point - make sure there are enough seats to add one + if (orgUser == null && organization.Seats.HasValue) + { + var userCount = await _organizationUserRepository.GetCountByOrganizationIdAsync(orgId); + var initialSeatCount = organization.Seats.Value; + var availableSeats = initialSeatCount - userCount; + var prorationDate = DateTime.UtcNow; + if (availableSeats < 1) + { + try + { + if (_globalSettings.SelfHosted) + { + throw new Exception("Cannot autoscale on self-hosted instance."); + } + + await _organizationService.AutoAddSeatsAsync(organization, 1, prorationDate); + } + catch (Exception e) + { + if (organization.Seats.Value != initialSeatCount) + { + await _organizationService.AdjustSeatsAsync(orgId, initialSeatCount - organization.Seats.Value, prorationDate); + } + _logger.LogInformation(e, "SSO auto provisioning failed"); + throw new Exception(_i18nService.T("NoSeatsAvailable", organization.Name)); + } + } + } + + // Create user record - all existing user flows are handled above + var user = new User + { + Name = name, + Email = email, + ApiKey = CoreHelpers.SecureRandomString(30) + }; + await _userService.RegisterUserAsync(user); + + // If the organization has 2fa policy enabled, make sure to default jit user 2fa to email + var twoFactorPolicy = + await _policyRepository.GetByOrganizationIdTypeAsync(orgId, PolicyType.TwoFactorAuthentication); + if (twoFactorPolicy != null && twoFactorPolicy.Enabled) + { + user.SetTwoFactorProviders(new Dictionary + { + [TwoFactorProviderType.Email] = new TwoFactorProvider + { + MetaData = new Dictionary { ["Email"] = user.Email.ToLowerInvariant() }, + Enabled = true + } + }); + await _userService.UpdateTwoFactorProviderAsync(user, TwoFactorProviderType.Email); + } + + // Create Org User if null or else update existing Org User if (orgUser == null) { - // Org User is not created - no invite has been sent - throw new Exception(_i18nService.T("UserAlreadyExistsInviteProcess")); + orgUser = new OrganizationUser + { + OrganizationId = orgId, + UserId = user.Id, + Type = OrganizationUserType.User, + Status = OrganizationUserStatusType.Invited + }; + await _organizationUserRepository.CreateAsync(orgUser); } - - if (orgUser.Status == OrganizationUserStatusType.Invited) + else { - // Org User is invited - they must manually accept the invite via email and authenticate with MP - throw new Exception(_i18nService.T("UserAlreadyInvited", email, organization.Name)); + orgUser.UserId = user.Id; + await _organizationUserRepository.ReplaceAsync(orgUser); } - // Accepted or Confirmed - create SSO link and return; - await CreateSsoUserRecord(providerUserId, existingUser.Id, orgId, orgUser); - return existingUser; + // Create sso user record + await CreateSsoUserRecord(providerUserId, user.Id, orgId, orgUser); + + return user; } - // Before any user creation - if Org User doesn't exist at this point - make sure there are enough seats to add one - if (orgUser == null && organization.Seats.HasValue) + private IActionResult InvalidJson(string errorMessageKey, Exception ex = null) { - var userCount = await _organizationUserRepository.GetCountByOrganizationIdAsync(orgId); - var initialSeatCount = organization.Seats.Value; - var availableSeats = initialSeatCount - userCount; - var prorationDate = DateTime.UtcNow; - if (availableSeats < 1) + Response.StatusCode = ex == null ? 400 : 500; + return Json(new ErrorResponseModel(_i18nService.T(errorMessageKey)) { - try - { - if (_globalSettings.SelfHosted) - { - throw new Exception("Cannot autoscale on self-hosted instance."); - } - - await _organizationService.AutoAddSeatsAsync(organization, 1, prorationDate); - } - catch (Exception e) - { - if (organization.Seats.Value != initialSeatCount) - { - await _organizationService.AdjustSeatsAsync(orgId, initialSeatCount - organization.Seats.Value, prorationDate); - } - _logger.LogInformation(e, "SSO auto provisioning failed"); - throw new Exception(_i18nService.T("NoSeatsAvailable", organization.Name)); - } - } - } - - // Create user record - all existing user flows are handled above - var user = new User - { - Name = name, - Email = email, - ApiKey = CoreHelpers.SecureRandomString(30) - }; - await _userService.RegisterUserAsync(user); - - // If the organization has 2fa policy enabled, make sure to default jit user 2fa to email - var twoFactorPolicy = - await _policyRepository.GetByOrganizationIdTypeAsync(orgId, PolicyType.TwoFactorAuthentication); - if (twoFactorPolicy != null && twoFactorPolicy.Enabled) - { - user.SetTwoFactorProviders(new Dictionary - { - [TwoFactorProviderType.Email] = new TwoFactorProvider - { - MetaData = new Dictionary { ["Email"] = user.Email.ToLowerInvariant() }, - Enabled = true - } + ExceptionMessage = ex?.Message, + ExceptionStackTrace = ex?.StackTrace, + InnerExceptionMessage = ex?.InnerException?.Message, }); - await _userService.UpdateTwoFactorProviderAsync(user, TwoFactorProviderType.Email); } - // Create Org User if null or else update existing Org User - if (orgUser == null) + private string GetEmailAddress(IEnumerable claims, IEnumerable additionalClaimTypes) { - orgUser = new OrganizationUser + var filteredClaims = claims.Where(c => !string.IsNullOrWhiteSpace(c.Value) && c.Value.Contains("@")); + + var email = filteredClaims.GetFirstMatch(additionalClaimTypes.ToArray()) ?? + filteredClaims.GetFirstMatch(JwtClaimTypes.Email, ClaimTypes.Email, + SamlClaimTypes.Email, "mail", "emailaddress"); + if (!string.IsNullOrWhiteSpace(email)) { + return email; + } + + var username = filteredClaims.GetFirstMatch(JwtClaimTypes.PreferredUserName, + SamlClaimTypes.UserId, "uid"); + if (!string.IsNullOrWhiteSpace(username)) + { + return username; + } + + return null; + } + + private string GetName(IEnumerable claims, IEnumerable additionalClaimTypes) + { + var filteredClaims = claims.Where(c => !string.IsNullOrWhiteSpace(c.Value)); + + var name = filteredClaims.GetFirstMatch(additionalClaimTypes.ToArray()) ?? + filteredClaims.GetFirstMatch(JwtClaimTypes.Name, ClaimTypes.Name, + SamlClaimTypes.DisplayName, SamlClaimTypes.CommonName, "displayname", "cn"); + if (!string.IsNullOrWhiteSpace(name)) + { + return name; + } + + var givenName = filteredClaims.GetFirstMatch(SamlClaimTypes.GivenName, "givenname", "firstname", + "fn", "fname", "nickname"); + var surname = filteredClaims.GetFirstMatch(SamlClaimTypes.Surname, "sn", "surname", "lastname"); + var nameParts = new[] { givenName, surname }.Where(p => !string.IsNullOrWhiteSpace(p)); + if (nameParts.Any()) + { + return string.Join(' ', nameParts); + } + + return null; + } + + private async Task CreateSsoUserRecord(string providerUserId, Guid userId, Guid orgId, OrganizationUser orgUser) + { + // Delete existing SsoUser (if any) - avoids error if providerId has changed and the sso link is stale + var existingSsoUser = await _ssoUserRepository.GetByUserIdOrganizationIdAsync(orgId, userId); + if (existingSsoUser != null) + { + await _ssoUserRepository.DeleteAsync(userId, orgId); + await _eventService.LogOrganizationUserEventAsync(orgUser, EventType.OrganizationUser_ResetSsoLink); + } + else + { + // If no stale user, this is the user's first Sso login ever + await _eventService.LogOrganizationUserEventAsync(orgUser, EventType.OrganizationUser_FirstSsoLogin); + } + + var ssoUser = new SsoUser + { + ExternalId = providerUserId, + UserId = userId, OrganizationId = orgId, - UserId = user.Id, - Type = OrganizationUserType.User, - Status = OrganizationUserStatusType.Invited }; - await _organizationUserRepository.CreateAsync(orgUser); - } - else - { - orgUser.UserId = user.Id; - await _organizationUserRepository.ReplaceAsync(orgUser); + await _ssoUserRepository.CreateAsync(ssoUser); } - // Create sso user record - await CreateSsoUserRecord(providerUserId, user.Id, orgId, orgUser); - - return user; - } - - private IActionResult InvalidJson(string errorMessageKey, Exception ex = null) - { - Response.StatusCode = ex == null ? 400 : 500; - return Json(new ErrorResponseModel(_i18nService.T(errorMessageKey)) + private void ProcessLoginCallback(AuthenticateResult externalResult, + List localClaims, AuthenticationProperties localSignInProps) { - ExceptionMessage = ex?.Message, - ExceptionStackTrace = ex?.StackTrace, - InnerExceptionMessage = ex?.InnerException?.Message, - }); - } - - private string GetEmailAddress(IEnumerable claims, IEnumerable additionalClaimTypes) - { - var filteredClaims = claims.Where(c => !string.IsNullOrWhiteSpace(c.Value) && c.Value.Contains("@")); - - var email = filteredClaims.GetFirstMatch(additionalClaimTypes.ToArray()) ?? - filteredClaims.GetFirstMatch(JwtClaimTypes.Email, ClaimTypes.Email, - SamlClaimTypes.Email, "mail", "emailaddress"); - if (!string.IsNullOrWhiteSpace(email)) - { - return email; - } - - var username = filteredClaims.GetFirstMatch(JwtClaimTypes.PreferredUserName, - SamlClaimTypes.UserId, "uid"); - if (!string.IsNullOrWhiteSpace(username)) - { - return username; - } - - return null; - } - - private string GetName(IEnumerable claims, IEnumerable additionalClaimTypes) - { - var filteredClaims = claims.Where(c => !string.IsNullOrWhiteSpace(c.Value)); - - var name = filteredClaims.GetFirstMatch(additionalClaimTypes.ToArray()) ?? - filteredClaims.GetFirstMatch(JwtClaimTypes.Name, ClaimTypes.Name, - SamlClaimTypes.DisplayName, SamlClaimTypes.CommonName, "displayname", "cn"); - if (!string.IsNullOrWhiteSpace(name)) - { - return name; - } - - var givenName = filteredClaims.GetFirstMatch(SamlClaimTypes.GivenName, "givenname", "firstname", - "fn", "fname", "nickname"); - var surname = filteredClaims.GetFirstMatch(SamlClaimTypes.Surname, "sn", "surname", "lastname"); - var nameParts = new[] { givenName, surname }.Where(p => !string.IsNullOrWhiteSpace(p)); - if (nameParts.Any()) - { - return string.Join(' ', nameParts); - } - - return null; - } - - private async Task CreateSsoUserRecord(string providerUserId, Guid userId, Guid orgId, OrganizationUser orgUser) - { - // Delete existing SsoUser (if any) - avoids error if providerId has changed and the sso link is stale - var existingSsoUser = await _ssoUserRepository.GetByUserIdOrganizationIdAsync(orgId, userId); - if (existingSsoUser != null) - { - await _ssoUserRepository.DeleteAsync(userId, orgId); - await _eventService.LogOrganizationUserEventAsync(orgUser, EventType.OrganizationUser_ResetSsoLink); - } - else - { - // If no stale user, this is the user's first Sso login ever - await _eventService.LogOrganizationUserEventAsync(orgUser, EventType.OrganizationUser_FirstSsoLogin); - } - - var ssoUser = new SsoUser - { - ExternalId = providerUserId, - UserId = userId, - OrganizationId = orgId, - }; - await _ssoUserRepository.CreateAsync(ssoUser); - } - - private void ProcessLoginCallback(AuthenticateResult externalResult, - List localClaims, AuthenticationProperties localSignInProps) - { - // If the external system sent a session id claim, copy it over - // so we can use it for single sign-out - var sid = externalResult.Principal.Claims.FirstOrDefault(x => x.Type == JwtClaimTypes.SessionId); - if (sid != null) - { - localClaims.Add(new Claim(JwtClaimTypes.SessionId, sid.Value)); - } - - // If the external provider issued an idToken, we'll keep it for signout - var idToken = externalResult.Properties.GetTokenValue("id_token"); - if (idToken != null) - { - localSignInProps.StoreTokens( - new[] { new AuthenticationToken { Name = "id_token", Value = idToken } }); - } - } - - private async Task GetProviderAsync(string returnUrl) - { - var context = await _interaction.GetAuthorizationContextAsync(returnUrl); - if (context?.IdP != null && await _schemeProvider.GetSchemeAsync(context.IdP) != null) - { - return context.IdP; - } - var schemes = await _schemeProvider.GetAllSchemesAsync(); - var providers = schemes.Select(x => x.Name).ToList(); - return providers.FirstOrDefault(); - } - - private async Task<(string, string, string)> GetLoggedOutDataAsync(string logoutId) - { - // Get context information (client name, post logout redirect URI and iframe for federated signout) - var logout = await _interaction.GetLogoutContextAsync(logoutId); - string externalAuthenticationScheme = null; - if (User?.Identity.IsAuthenticated == true) - { - var idp = User.FindFirst(JwtClaimTypes.IdentityProvider)?.Value; - if (idp != null && idp != IdentityServerConstants.LocalIdentityProvider) + // If the external system sent a session id claim, copy it over + // so we can use it for single sign-out + var sid = externalResult.Principal.Claims.FirstOrDefault(x => x.Type == JwtClaimTypes.SessionId); + if (sid != null) { - var providerSupportsSignout = await HttpContext.GetSchemeSupportsSignOutAsync(idp); - if (providerSupportsSignout) - { - if (logoutId == null) - { - // If there's no current logout context, we need to create one - // this captures necessary info from the current logged in user - // before we signout and redirect away to the external IdP for signout - logoutId = await _interaction.CreateLogoutContextAsync(); - } + localClaims.Add(new Claim(JwtClaimTypes.SessionId, sid.Value)); + } - externalAuthenticationScheme = idp; - } + // If the external provider issued an idToken, we'll keep it for signout + var idToken = externalResult.Properties.GetTokenValue("id_token"); + if (idToken != null) + { + localSignInProps.StoreTokens( + new[] { new AuthenticationToken { Name = "id_token", Value = idToken } }); } } - return (logoutId, logout?.PostLogoutRedirectUri, externalAuthenticationScheme); - } + private async Task GetProviderAsync(string returnUrl) + { + var context = await _interaction.GetAuthorizationContextAsync(returnUrl); + if (context?.IdP != null && await _schemeProvider.GetSchemeAsync(context.IdP) != null) + { + return context.IdP; + } + var schemes = await _schemeProvider.GetAllSchemesAsync(); + var providers = schemes.Select(x => x.Name).ToList(); + return providers.FirstOrDefault(); + } - public bool IsNativeClient(IdentityServer4.Models.AuthorizationRequest context) - { - return !context.RedirectUri.StartsWith("https", StringComparison.Ordinal) - && !context.RedirectUri.StartsWith("http", StringComparison.Ordinal); + private async Task<(string, string, string)> GetLoggedOutDataAsync(string logoutId) + { + // Get context information (client name, post logout redirect URI and iframe for federated signout) + var logout = await _interaction.GetLogoutContextAsync(logoutId); + string externalAuthenticationScheme = null; + if (User?.Identity.IsAuthenticated == true) + { + var idp = User.FindFirst(JwtClaimTypes.IdentityProvider)?.Value; + if (idp != null && idp != IdentityServerConstants.LocalIdentityProvider) + { + var providerSupportsSignout = await HttpContext.GetSchemeSupportsSignOutAsync(idp); + if (providerSupportsSignout) + { + if (logoutId == null) + { + // If there's no current logout context, we need to create one + // this captures necessary info from the current logged in user + // before we signout and redirect away to the external IdP for signout + logoutId = await _interaction.CreateLogoutContextAsync(); + } + + externalAuthenticationScheme = idp; + } + } + } + + return (logoutId, logout?.PostLogoutRedirectUri, externalAuthenticationScheme); + } + + public bool IsNativeClient(IdentityServer4.Models.AuthorizationRequest context) + { + return !context.RedirectUri.StartsWith("https", StringComparison.Ordinal) + && !context.RedirectUri.StartsWith("http", StringComparison.Ordinal); + } } } diff --git a/bitwarden_license/src/Sso/Controllers/HomeController.cs b/bitwarden_license/src/Sso/Controllers/HomeController.cs index ee15fefc90..5ce112fa46 100644 --- a/bitwarden_license/src/Sso/Controllers/HomeController.cs +++ b/bitwarden_license/src/Sso/Controllers/HomeController.cs @@ -5,50 +5,51 @@ using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Diagnostics; using Microsoft.AspNetCore.Mvc; -namespace Bit.Sso.Controllers; - -public class HomeController : Controller +namespace Bit.Sso.Controllers { - private readonly IIdentityServerInteractionService _interaction; - - public HomeController(IIdentityServerInteractionService interaction) + public class HomeController : Controller { - _interaction = interaction; - } + private readonly IIdentityServerInteractionService _interaction; - [Route("~/Error")] - [Route("~/Home/Error")] - [AllowAnonymous] - public async Task Error(string errorId) - { - var vm = new ErrorViewModel(); - - // retrieve error details from identityserver - var message = string.IsNullOrWhiteSpace(errorId) ? null : - await _interaction.GetErrorContextAsync(errorId); - if (message != null) + public HomeController(IIdentityServerInteractionService interaction) { - vm.Error = message; + _interaction = interaction; } - else + + [Route("~/Error")] + [Route("~/Home/Error")] + [AllowAnonymous] + public async Task Error(string errorId) { - vm.RequestId = Activity.Current?.Id ?? HttpContext.TraceIdentifier; - var exceptionHandlerPathFeature = HttpContext.Features.Get(); - var exception = exceptionHandlerPathFeature?.Error; - if (exception is InvalidOperationException opEx && opEx.Message.Contains("schemes are: ")) + var vm = new ErrorViewModel(); + + // retrieve error details from identityserver + var message = string.IsNullOrWhiteSpace(errorId) ? null : + await _interaction.GetErrorContextAsync(errorId); + if (message != null) { - // Messages coming from aspnetcore with a message - // similar to "The registered sign-in schemes are: {schemes}." - // will expose other Org IDs and sign-in schemes enabled on - // the server. These errors should be truncated to just the - // scheme impacted (always the first sentence) - var cleanupPoint = opEx.Message.IndexOf(". ") + 1; - var exMessage = opEx.Message.Substring(0, cleanupPoint); - exception = new InvalidOperationException(exMessage, opEx); + vm.Error = message; + } + else + { + vm.RequestId = Activity.Current?.Id ?? HttpContext.TraceIdentifier; + var exceptionHandlerPathFeature = HttpContext.Features.Get(); + var exception = exceptionHandlerPathFeature?.Error; + if (exception is InvalidOperationException opEx && opEx.Message.Contains("schemes are: ")) + { + // Messages coming from aspnetcore with a message + // similar to "The registered sign-in schemes are: {schemes}." + // will expose other Org IDs and sign-in schemes enabled on + // the server. These errors should be truncated to just the + // scheme impacted (always the first sentence) + var cleanupPoint = opEx.Message.IndexOf(". ") + 1; + var exMessage = opEx.Message.Substring(0, cleanupPoint); + exception = new InvalidOperationException(exMessage, opEx); + } + vm.Exception = exception; } - vm.Exception = exception; - } - return View("Error", vm); + return View("Error", vm); + } } } diff --git a/bitwarden_license/src/Sso/Controllers/InfoController.cs b/bitwarden_license/src/Sso/Controllers/InfoController.cs index c3641c4660..d652e8cddc 100644 --- a/bitwarden_license/src/Sso/Controllers/InfoController.cs +++ b/bitwarden_license/src/Sso/Controllers/InfoController.cs @@ -1,20 +1,21 @@ using Bit.Core.Utilities; using Microsoft.AspNetCore.Mvc; -namespace Bit.Sso.Controllers; - -public class InfoController : Controller +namespace Bit.Sso.Controllers { - [HttpGet("~/alive")] - [HttpGet("~/now")] - public DateTime GetAlive() + public class InfoController : Controller { - return DateTime.UtcNow; - } + [HttpGet("~/alive")] + [HttpGet("~/now")] + public DateTime GetAlive() + { + return DateTime.UtcNow; + } - [HttpGet("~/version")] - public JsonResult GetVersion() - { - return Json(CoreHelpers.GetVersion()); + [HttpGet("~/version")] + public JsonResult GetVersion() + { + return Json(CoreHelpers.GetVersion()); + } } } diff --git a/bitwarden_license/src/Sso/Controllers/MetadataController.cs b/bitwarden_license/src/Sso/Controllers/MetadataController.cs index 54f4f8cd44..dbf033e84f 100644 --- a/bitwarden_license/src/Sso/Controllers/MetadataController.cs +++ b/bitwarden_license/src/Sso/Controllers/MetadataController.cs @@ -5,65 +5,66 @@ using Microsoft.AspNetCore.Mvc; using Sustainsys.Saml2.AspNetCore2; using Sustainsys.Saml2.WebSso; -namespace Bit.Sso.Controllers; - -public class MetadataController : Controller +namespace Bit.Sso.Controllers { - private readonly IAuthenticationSchemeProvider _schemeProvider; - - public MetadataController( - IAuthenticationSchemeProvider schemeProvider) + public class MetadataController : Controller { - _schemeProvider = schemeProvider; - } + private readonly IAuthenticationSchemeProvider _schemeProvider; - [HttpGet("saml2/{scheme}")] - public async Task ViewAsync(string scheme) - { - if (string.IsNullOrWhiteSpace(scheme)) + public MetadataController( + IAuthenticationSchemeProvider schemeProvider) { - return NotFound(); + _schemeProvider = schemeProvider; } - var authScheme = await _schemeProvider.GetSchemeAsync(scheme); - if (authScheme == null || - !(authScheme is DynamicAuthenticationScheme dynamicAuthScheme) || - dynamicAuthScheme?.SsoType != SsoType.Saml2) + [HttpGet("saml2/{scheme}")] + public async Task ViewAsync(string scheme) { - return NotFound(); + if (string.IsNullOrWhiteSpace(scheme)) + { + return NotFound(); + } + + var authScheme = await _schemeProvider.GetSchemeAsync(scheme); + if (authScheme == null || + !(authScheme is DynamicAuthenticationScheme dynamicAuthScheme) || + dynamicAuthScheme?.SsoType != SsoType.Saml2) + { + return NotFound(); + } + + if (!(dynamicAuthScheme.Options is Saml2Options options)) + { + return NotFound(); + } + + var uri = new Uri( + Request.Scheme + + "://" + + Request.Host + + Request.Path + + Request.QueryString); + + var pathBase = Request.PathBase.Value; + pathBase = string.IsNullOrEmpty(pathBase) ? "/" : pathBase; + + var requestdata = new HttpRequestData( + Request.Method, + uri, + pathBase, + null, + Request.Cookies, + (data) => data); + + var metadataResult = CommandFactory + .GetCommand(CommandFactory.MetadataCommand) + .Run(requestdata, options); + //Response.Headers.Add("Content-Disposition", $"filename= bitwarden-saml2-meta-{scheme}.xml"); + return new ContentResult + { + Content = metadataResult.Content, + ContentType = "text/xml", + }; } - - if (!(dynamicAuthScheme.Options is Saml2Options options)) - { - return NotFound(); - } - - var uri = new Uri( - Request.Scheme - + "://" - + Request.Host - + Request.Path - + Request.QueryString); - - var pathBase = Request.PathBase.Value; - pathBase = string.IsNullOrEmpty(pathBase) ? "/" : pathBase; - - var requestdata = new HttpRequestData( - Request.Method, - uri, - pathBase, - null, - Request.Cookies, - (data) => data); - - var metadataResult = CommandFactory - .GetCommand(CommandFactory.MetadataCommand) - .Run(requestdata, options); - //Response.Headers.Add("Content-Disposition", $"filename= bitwarden-saml2-meta-{scheme}.xml"); - return new ContentResult - { - Content = metadataResult.Content, - ContentType = "text/xml", - }; } } diff --git a/bitwarden_license/src/Sso/Models/ErrorViewModel.cs b/bitwarden_license/src/Sso/Models/ErrorViewModel.cs index 46ae8edd90..4c0ea8748d 100644 --- a/bitwarden_license/src/Sso/Models/ErrorViewModel.cs +++ b/bitwarden_license/src/Sso/Models/ErrorViewModel.cs @@ -1,26 +1,27 @@ using IdentityServer4.Models; -namespace Bit.Sso.Models; - -public class ErrorViewModel +namespace Bit.Sso.Models { - private string _requestId; - - public ErrorMessage Error { get; set; } - public Exception Exception { get; set; } - - public string Message => Error?.Error; - public string Description => Error?.ErrorDescription ?? Exception?.Message; - public string RedirectUri => Error?.RedirectUri; - public string RequestId + public class ErrorViewModel { - get + private string _requestId; + + public ErrorMessage Error { get; set; } + public Exception Exception { get; set; } + + public string Message => Error?.Error; + public string Description => Error?.ErrorDescription ?? Exception?.Message; + public string RedirectUri => Error?.RedirectUri; + public string RequestId { - return Error?.RequestId ?? _requestId; - } - set - { - _requestId = value; + get + { + return Error?.RequestId ?? _requestId; + } + set + { + _requestId = value; + } } } } diff --git a/bitwarden_license/src/Sso/Models/RedirectViewModel.cs b/bitwarden_license/src/Sso/Models/RedirectViewModel.cs index 9bc294d96c..54b5b4715d 100644 --- a/bitwarden_license/src/Sso/Models/RedirectViewModel.cs +++ b/bitwarden_license/src/Sso/Models/RedirectViewModel.cs @@ -1,6 +1,7 @@ -namespace Bit.Sso.Models; - -public class RedirectViewModel +namespace Bit.Sso.Models { - public string RedirectUrl { get; set; } + public class RedirectViewModel + { + public string RedirectUrl { get; set; } + } } diff --git a/bitwarden_license/src/Sso/Models/SamlEnvironment.cs b/bitwarden_license/src/Sso/Models/SamlEnvironment.cs index 6de718029a..f1890840fa 100644 --- a/bitwarden_license/src/Sso/Models/SamlEnvironment.cs +++ b/bitwarden_license/src/Sso/Models/SamlEnvironment.cs @@ -1,8 +1,9 @@ using System.Security.Cryptography.X509Certificates; -namespace Bit.Sso.Models; - -public class SamlEnvironment +namespace Bit.Sso.Models { - public X509Certificate2 SpSigningCertificate { get; set; } + public class SamlEnvironment + { + public X509Certificate2 SpSigningCertificate { get; set; } + } } diff --git a/bitwarden_license/src/Sso/Models/SsoPreValidateResponseModel.cs b/bitwarden_license/src/Sso/Models/SsoPreValidateResponseModel.cs index f96b387752..9877e1c5ac 100644 --- a/bitwarden_license/src/Sso/Models/SsoPreValidateResponseModel.cs +++ b/bitwarden_license/src/Sso/Models/SsoPreValidateResponseModel.cs @@ -1,12 +1,13 @@ using Microsoft.AspNetCore.Mvc; -namespace Bit.Sso.Models; - -public class SsoPreValidateResponseModel : JsonResult +namespace Bit.Sso.Models { - public SsoPreValidateResponseModel(string token) : base(new + public class SsoPreValidateResponseModel : JsonResult { - token - }) - { } + public SsoPreValidateResponseModel(string token) : base(new + { + token + }) + { } + } } diff --git a/bitwarden_license/src/Sso/Program.cs b/bitwarden_license/src/Sso/Program.cs index 672c73bfb5..910f09332e 100644 --- a/bitwarden_license/src/Sso/Program.cs +++ b/bitwarden_license/src/Sso/Program.cs @@ -2,32 +2,33 @@ using Serilog; using Serilog.Events; -namespace Bit.Sso; - -public class Program +namespace Bit.Sso { - public static void Main(string[] args) + public class Program { - Host - .CreateDefaultBuilder(args) - .ConfigureCustomAppConfiguration(args) - .ConfigureWebHostDefaults(webBuilder => - { - webBuilder.UseStartup(); - webBuilder.ConfigureLogging((hostingContext, logging) => - logging.AddSerilog(hostingContext, e => + public static void Main(string[] args) + { + Host + .CreateDefaultBuilder(args) + .ConfigureCustomAppConfiguration(args) + .ConfigureWebHostDefaults(webBuilder => { - var context = e.Properties["SourceContext"].ToString(); - if (e.Properties.ContainsKey("RequestPath") && - !string.IsNullOrWhiteSpace(e.Properties["RequestPath"]?.ToString()) && - (context.Contains(".Server.Kestrel") || context.Contains(".Core.IISHttpServer"))) + webBuilder.UseStartup(); + webBuilder.ConfigureLogging((hostingContext, logging) => + logging.AddSerilog(hostingContext, e => { - return false; - } - return e.Level >= LogEventLevel.Error; - })); - }) - .Build() - .Run(); + var context = e.Properties["SourceContext"].ToString(); + if (e.Properties.ContainsKey("RequestPath") && + !string.IsNullOrWhiteSpace(e.Properties["RequestPath"]?.ToString()) && + (context.Contains(".Server.Kestrel") || context.Contains(".Core.IISHttpServer"))) + { + return false; + } + return e.Level >= LogEventLevel.Error; + })); + }) + .Build() + .Run(); + } } } diff --git a/bitwarden_license/src/Sso/Startup.cs b/bitwarden_license/src/Sso/Startup.cs index 99aa5961f4..6116d86c26 100644 --- a/bitwarden_license/src/Sso/Startup.cs +++ b/bitwarden_license/src/Sso/Startup.cs @@ -8,147 +8,148 @@ using IdentityServer4.Extensions; using Microsoft.IdentityModel.Logging; using Stripe; -namespace Bit.Sso; - -public class Startup +namespace Bit.Sso { - public Startup(IWebHostEnvironment env, IConfiguration configuration) + public class Startup { - Configuration = configuration; - Environment = env; - } - - public IConfiguration Configuration { get; } - public IWebHostEnvironment Environment { get; set; } - - public void ConfigureServices(IServiceCollection services) - { - // Options - services.AddOptions(); - - // Settings - var globalSettings = services.AddGlobalSettingsServices(Configuration, Environment); - - // Stripe Billing - StripeConfiguration.ApiKey = globalSettings.Stripe.ApiKey; - StripeConfiguration.MaxNetworkRetries = globalSettings.Stripe.MaxNetworkRetries; - - // Data Protection - services.AddCustomDataProtectionServices(Environment, globalSettings); - - // Repositories - services.AddSqlServerRepositories(globalSettings); - - // Context - services.AddScoped(); - - // Caching - services.AddMemoryCache(); - services.AddDistributedCache(globalSettings); - - // Mvc - services.AddControllersWithViews(); - - // Cookies - if (Environment.IsDevelopment()) + public Startup(IWebHostEnvironment env, IConfiguration configuration) { - services.Configure(options => + Configuration = configuration; + Environment = env; + } + + public IConfiguration Configuration { get; } + public IWebHostEnvironment Environment { get; set; } + + public void ConfigureServices(IServiceCollection services) + { + // Options + services.AddOptions(); + + // Settings + var globalSettings = services.AddGlobalSettingsServices(Configuration, Environment); + + // Stripe Billing + StripeConfiguration.ApiKey = globalSettings.Stripe.ApiKey; + StripeConfiguration.MaxNetworkRetries = globalSettings.Stripe.MaxNetworkRetries; + + // Data Protection + services.AddCustomDataProtectionServices(Environment, globalSettings); + + // Repositories + services.AddSqlServerRepositories(globalSettings); + + // Context + services.AddScoped(); + + // Caching + services.AddMemoryCache(); + services.AddDistributedCache(globalSettings); + + // Mvc + services.AddControllersWithViews(); + + // Cookies + if (Environment.IsDevelopment()) { - options.MinimumSameSitePolicy = Microsoft.AspNetCore.Http.SameSiteMode.Unspecified; - options.OnAppendCookie = ctx => + services.Configure(options => { - ctx.CookieOptions.SameSite = Microsoft.AspNetCore.Http.SameSiteMode.Unspecified; - }; - }); + options.MinimumSameSitePolicy = Microsoft.AspNetCore.Http.SameSiteMode.Unspecified; + options.OnAppendCookie = ctx => + { + ctx.CookieOptions.SameSite = Microsoft.AspNetCore.Http.SameSiteMode.Unspecified; + }; + }); + } + + // Authentication + services.AddDistributedIdentityServices(globalSettings); + services.AddAuthentication() + .AddCookie(AuthenticationSchemes.BitwardenExternalCookieAuthenticationScheme); + services.AddSsoServices(globalSettings); + + // IdentityServer + services.AddSsoIdentityServerServices(Environment, globalSettings); + + // Identity + services.AddCustomIdentityServices(globalSettings); + + // Services + services.AddBaseServices(globalSettings); + services.AddDefaultServices(globalSettings); + services.AddCoreLocalizationServices(); } - // Authentication - services.AddDistributedIdentityServices(globalSettings); - services.AddAuthentication() - .AddCookie(AuthenticationSchemes.BitwardenExternalCookieAuthenticationScheme); - services.AddSsoServices(globalSettings); - - // IdentityServer - services.AddSsoIdentityServerServices(Environment, globalSettings); - - // Identity - services.AddCustomIdentityServices(globalSettings); - - // Services - services.AddBaseServices(globalSettings); - services.AddDefaultServices(globalSettings); - services.AddCoreLocalizationServices(); - } - - public void Configure( - IApplicationBuilder app, - IWebHostEnvironment env, - IHostApplicationLifetime appLifetime, - GlobalSettings globalSettings, - ILogger logger) - { - if (env.IsDevelopment() || globalSettings.SelfHosted) + public void Configure( + IApplicationBuilder app, + IWebHostEnvironment env, + IHostApplicationLifetime appLifetime, + GlobalSettings globalSettings, + ILogger logger) { - IdentityModelEventSource.ShowPII = true; - } - - app.UseSerilog(env, appLifetime, globalSettings); - - // Add general security headers - app.UseMiddleware(); - - if (!env.IsDevelopment()) - { - var uri = new Uri(globalSettings.BaseServiceUri.Sso); - app.Use(async (ctx, next) => + if (env.IsDevelopment() || globalSettings.SelfHosted) { - ctx.SetIdentityServerOrigin($"{uri.Scheme}://{uri.Host}"); - await next(); + IdentityModelEventSource.ShowPII = true; + } + + app.UseSerilog(env, appLifetime, globalSettings); + + // Add general security headers + app.UseMiddleware(); + + if (!env.IsDevelopment()) + { + var uri = new Uri(globalSettings.BaseServiceUri.Sso); + app.Use(async (ctx, next) => + { + ctx.SetIdentityServerOrigin($"{uri.Scheme}://{uri.Host}"); + await next(); + }); + } + + if (globalSettings.SelfHosted) + { + app.UsePathBase("/sso"); + app.UseForwardedHeaders(globalSettings); + } + + if (env.IsDevelopment()) + { + app.UseDeveloperExceptionPage(); + app.UseCookiePolicy(); + } + else + { + app.UseExceptionHandler("/Error"); + } + + app.UseCoreLocalization(); + + // Add static files to the request pipeline. + app.UseStaticFiles(); + + // Add routing + app.UseRouting(); + + // Add Cors + app.UseCors(policy => policy.SetIsOriginAllowed(o => CoreHelpers.IsCorsOriginAllowed(o, globalSettings)) + .AllowAnyMethod().AllowAnyHeader().AllowCredentials()); + + // Add current context + app.UseMiddleware(); + + // Add IdentityServer to the request pipeline. + app.UseIdentityServer(new IdentityServerMiddlewareOptions + { + AuthenticationMiddleware = app => app.UseMiddleware() }); + + // Add Mvc stuff + app.UseAuthorization(); + app.UseEndpoints(endpoints => endpoints.MapDefaultControllerRoute()); + + // Log startup + logger.LogInformation(Constants.BypassFiltersEventId, globalSettings.ProjectName + " started."); } - - if (globalSettings.SelfHosted) - { - app.UsePathBase("/sso"); - app.UseForwardedHeaders(globalSettings); - } - - if (env.IsDevelopment()) - { - app.UseDeveloperExceptionPage(); - app.UseCookiePolicy(); - } - else - { - app.UseExceptionHandler("/Error"); - } - - app.UseCoreLocalization(); - - // Add static files to the request pipeline. - app.UseStaticFiles(); - - // Add routing - app.UseRouting(); - - // Add Cors - app.UseCors(policy => policy.SetIsOriginAllowed(o => CoreHelpers.IsCorsOriginAllowed(o, globalSettings)) - .AllowAnyMethod().AllowAnyHeader().AllowCredentials()); - - // Add current context - app.UseMiddleware(); - - // Add IdentityServer to the request pipeline. - app.UseIdentityServer(new IdentityServerMiddlewareOptions - { - AuthenticationMiddleware = app => app.UseMiddleware() - }); - - // Add Mvc stuff - app.UseAuthorization(); - app.UseEndpoints(endpoints => endpoints.MapDefaultControllerRoute()); - - // Log startup - logger.LogInformation(Constants.BypassFiltersEventId, globalSettings.ProjectName + " started."); } } diff --git a/bitwarden_license/src/Sso/Utilities/ClaimsExtensions.cs b/bitwarden_license/src/Sso/Utilities/ClaimsExtensions.cs index 735c7bc0ad..93a6fd146a 100644 --- a/bitwarden_license/src/Sso/Utilities/ClaimsExtensions.cs +++ b/bitwarden_license/src/Sso/Utilities/ClaimsExtensions.cs @@ -1,45 +1,46 @@ using System.Security.Claims; using System.Text.RegularExpressions; -namespace Bit.Sso.Utilities; - -public static class ClaimsExtensions +namespace Bit.Sso.Utilities { - private static readonly Regex _normalizeTextRegEx = - new Regex(@"[^a-zA-Z]", RegexOptions.CultureInvariant | RegexOptions.Singleline); - - public static string GetFirstMatch(this IEnumerable claims, params string[] possibleNames) + public static class ClaimsExtensions { - var normalizedClaims = claims.Select(c => (Normalize(c.Type), c.Value)).ToList(); + private static readonly Regex _normalizeTextRegEx = + new Regex(@"[^a-zA-Z]", RegexOptions.CultureInvariant | RegexOptions.Singleline); - // Order of prescendence is by passed in names - foreach (var name in possibleNames.Select(Normalize)) + public static string GetFirstMatch(this IEnumerable claims, params string[] possibleNames) { - // Second by order of claims (find claim by name) - foreach (var claim in normalizedClaims) + var normalizedClaims = claims.Select(c => (Normalize(c.Type), c.Value)).ToList(); + + // Order of prescendence is by passed in names + foreach (var name in possibleNames.Select(Normalize)) { - if (Equals(claim.Item1, name)) + // Second by order of claims (find claim by name) + foreach (var claim in normalizedClaims) { - return claim.Value; + if (Equals(claim.Item1, name)) + { + return claim.Value; + } } } + return null; } - return null; - } - private static bool Equals(string text, string compare) - { - return text == compare || - (string.IsNullOrWhiteSpace(text) && string.IsNullOrWhiteSpace(compare)) || - string.Equals(Normalize(text), compare, StringComparison.InvariantCultureIgnoreCase); - } - - private static string Normalize(string text) - { - if (string.IsNullOrWhiteSpace(text)) + private static bool Equals(string text, string compare) { - return text; + return text == compare || + (string.IsNullOrWhiteSpace(text) && string.IsNullOrWhiteSpace(compare)) || + string.Equals(Normalize(text), compare, StringComparison.InvariantCultureIgnoreCase); + } + + private static string Normalize(string text) + { + if (string.IsNullOrWhiteSpace(text)) + { + return text; + } + return _normalizeTextRegEx.Replace(text, string.Empty); } - return _normalizeTextRegEx.Replace(text, string.Empty); } } diff --git a/bitwarden_license/src/Sso/Utilities/DiscoveryResponseGenerator.cs b/bitwarden_license/src/Sso/Utilities/DiscoveryResponseGenerator.cs index 7a7f569638..bd58fc6124 100644 --- a/bitwarden_license/src/Sso/Utilities/DiscoveryResponseGenerator.cs +++ b/bitwarden_license/src/Sso/Utilities/DiscoveryResponseGenerator.cs @@ -5,31 +5,32 @@ using IdentityServer4.Services; using IdentityServer4.Stores; using IdentityServer4.Validation; -namespace Bit.Sso.Utilities; - -public class DiscoveryResponseGenerator : IdentityServer4.ResponseHandling.DiscoveryResponseGenerator +namespace Bit.Sso.Utilities { - private readonly GlobalSettings _globalSettings; - - public DiscoveryResponseGenerator( - IdentityServerOptions options, - IResourceStore resourceStore, - IKeyMaterialService keys, - ExtensionGrantValidator extensionGrants, - ISecretsListParser secretParsers, - IResourceOwnerPasswordValidator resourceOwnerValidator, - ILogger logger, - GlobalSettings globalSettings) - : base(options, resourceStore, keys, extensionGrants, secretParsers, resourceOwnerValidator, logger) + public class DiscoveryResponseGenerator : IdentityServer4.ResponseHandling.DiscoveryResponseGenerator { - _globalSettings = globalSettings; - } + private readonly GlobalSettings _globalSettings; - public override async Task> CreateDiscoveryDocumentAsync( - string baseUrl, string issuerUri) - { - var dict = await base.CreateDiscoveryDocumentAsync(baseUrl, issuerUri); - return CoreHelpers.AdjustIdentityServerConfig(dict, _globalSettings.BaseServiceUri.Sso, - _globalSettings.BaseServiceUri.InternalSso); + public DiscoveryResponseGenerator( + IdentityServerOptions options, + IResourceStore resourceStore, + IKeyMaterialService keys, + ExtensionGrantValidator extensionGrants, + ISecretsListParser secretParsers, + IResourceOwnerPasswordValidator resourceOwnerValidator, + ILogger logger, + GlobalSettings globalSettings) + : base(options, resourceStore, keys, extensionGrants, secretParsers, resourceOwnerValidator, logger) + { + _globalSettings = globalSettings; + } + + public override async Task> CreateDiscoveryDocumentAsync( + string baseUrl, string issuerUri) + { + var dict = await base.CreateDiscoveryDocumentAsync(baseUrl, issuerUri); + return CoreHelpers.AdjustIdentityServerConfig(dict, _globalSettings.BaseServiceUri.Sso, + _globalSettings.BaseServiceUri.InternalSso); + } } } diff --git a/bitwarden_license/src/Sso/Utilities/DynamicAuthenticationScheme.cs b/bitwarden_license/src/Sso/Utilities/DynamicAuthenticationScheme.cs index 96a316bc62..5a7ab65235 100644 --- a/bitwarden_license/src/Sso/Utilities/DynamicAuthenticationScheme.cs +++ b/bitwarden_license/src/Sso/Utilities/DynamicAuthenticationScheme.cs @@ -3,87 +3,88 @@ using Microsoft.AspNetCore.Authentication; using Microsoft.AspNetCore.Authentication.OpenIdConnect; using Sustainsys.Saml2.AspNetCore2; -namespace Bit.Sso.Utilities; - -public class DynamicAuthenticationScheme : AuthenticationScheme, IDynamicAuthenticationScheme +namespace Bit.Sso.Utilities { - public DynamicAuthenticationScheme(string name, string displayName, Type handlerType, - AuthenticationSchemeOptions options) - : base(name, displayName, handlerType) + public class DynamicAuthenticationScheme : AuthenticationScheme, IDynamicAuthenticationScheme { - Options = options; - } - public DynamicAuthenticationScheme(string name, string displayName, Type handlerType, - AuthenticationSchemeOptions options, SsoType ssoType) - : this(name, displayName, handlerType, options) - { - SsoType = ssoType; - } + public DynamicAuthenticationScheme(string name, string displayName, Type handlerType, + AuthenticationSchemeOptions options) + : base(name, displayName, handlerType) + { + Options = options; + } + public DynamicAuthenticationScheme(string name, string displayName, Type handlerType, + AuthenticationSchemeOptions options, SsoType ssoType) + : this(name, displayName, handlerType, options) + { + SsoType = ssoType; + } - public AuthenticationSchemeOptions Options { get; set; } - public SsoType SsoType { get; set; } + public AuthenticationSchemeOptions Options { get; set; } + public SsoType SsoType { get; set; } - public async Task Validate() - { - switch (SsoType) + public async Task Validate() { - case SsoType.OpenIdConnect: - await ValidateOpenIdConnectAsync(); - break; - case SsoType.Saml2: - ValidateSaml(); - break; - default: - break; - } - } - - private void ValidateSaml() - { - if (SsoType != SsoType.Saml2) - { - return; - } - if (!(Options is Saml2Options samlOptions)) - { - throw new Exception("InvalidAuthenticationOptionsForSaml2SchemeError"); - } - samlOptions.Validate(Name); - } - - private async Task ValidateOpenIdConnectAsync() - { - if (SsoType != SsoType.OpenIdConnect) - { - return; - } - if (!(Options is OpenIdConnectOptions oidcOptions)) - { - throw new Exception("InvalidAuthenticationOptionsForOidcSchemeError"); - } - oidcOptions.Validate(); - if (oidcOptions.Configuration == null) - { - if (oidcOptions.ConfigurationManager == null) + switch (SsoType) { - throw new Exception("PostConfigurationNotExecutedError"); + case SsoType.OpenIdConnect: + await ValidateOpenIdConnectAsync(); + break; + case SsoType.Saml2: + ValidateSaml(); + break; + default: + break; + } + } + + private void ValidateSaml() + { + if (SsoType != SsoType.Saml2) + { + return; + } + if (!(Options is Saml2Options samlOptions)) + { + throw new Exception("InvalidAuthenticationOptionsForSaml2SchemeError"); + } + samlOptions.Validate(Name); + } + + private async Task ValidateOpenIdConnectAsync() + { + if (SsoType != SsoType.OpenIdConnect) + { + return; + } + if (!(Options is OpenIdConnectOptions oidcOptions)) + { + throw new Exception("InvalidAuthenticationOptionsForOidcSchemeError"); + } + oidcOptions.Validate(); + if (oidcOptions.Configuration == null) + { + if (oidcOptions.ConfigurationManager == null) + { + throw new Exception("PostConfigurationNotExecutedError"); + } + if (oidcOptions.Configuration == null) + { + try + { + oidcOptions.Configuration = await oidcOptions.ConfigurationManager + .GetConfigurationAsync(CancellationToken.None); + } + catch (Exception ex) + { + throw new Exception("ReadingOpenIdConnectMetadataFailedError", ex); + } + } } if (oidcOptions.Configuration == null) { - try - { - oidcOptions.Configuration = await oidcOptions.ConfigurationManager - .GetConfigurationAsync(CancellationToken.None); - } - catch (Exception ex) - { - throw new Exception("ReadingOpenIdConnectMetadataFailedError", ex); - } + throw new Exception("NoOpenIdConnectMetadataError"); } } - if (oidcOptions.Configuration == null) - { - throw new Exception("NoOpenIdConnectMetadataError"); - } } } diff --git a/bitwarden_license/src/Sso/Utilities/DynamicAuthenticationSchemeProvider.cs b/bitwarden_license/src/Sso/Utilities/DynamicAuthenticationSchemeProvider.cs index b02e83deda..22f8979981 100644 --- a/bitwarden_license/src/Sso/Utilities/DynamicAuthenticationSchemeProvider.cs +++ b/bitwarden_license/src/Sso/Utilities/DynamicAuthenticationSchemeProvider.cs @@ -18,440 +18,441 @@ using Sustainsys.Saml2.AspNetCore2; using Sustainsys.Saml2.Configuration; using Sustainsys.Saml2.Saml2P; -namespace Bit.Core.Business.Sso; - -public class DynamicAuthenticationSchemeProvider : AuthenticationSchemeProvider +namespace Bit.Core.Business.Sso { - private readonly IPostConfigureOptions _oidcPostConfigureOptions; - private readonly IExtendedOptionsMonitorCache _extendedOidcOptionsMonitorCache; - private readonly IPostConfigureOptions _saml2PostConfigureOptions; - private readonly IExtendedOptionsMonitorCache _extendedSaml2OptionsMonitorCache; - private readonly ISsoConfigRepository _ssoConfigRepository; - private readonly ILogger _logger; - private readonly GlobalSettings _globalSettings; - private readonly SamlEnvironment _samlEnvironment; - private readonly TimeSpan _schemeCacheLifetime; - private readonly Dictionary _cachedSchemes; - private readonly Dictionary _cachedHandlerSchemes; - private readonly SemaphoreSlim _semaphore; - private readonly IHttpContextAccessor _httpContextAccessor; - - private DateTime? _lastSchemeLoad; - private IEnumerable _schemesCopy = Array.Empty(); - private IEnumerable _handlerSchemesCopy = Array.Empty(); - - public DynamicAuthenticationSchemeProvider( - IOptions options, - IPostConfigureOptions oidcPostConfigureOptions, - IOptionsMonitorCache oidcOptionsMonitorCache, - IPostConfigureOptions saml2PostConfigureOptions, - IOptionsMonitorCache saml2OptionsMonitorCache, - ISsoConfigRepository ssoConfigRepository, - ILogger logger, - GlobalSettings globalSettings, - SamlEnvironment samlEnvironment, - IHttpContextAccessor httpContextAccessor) - : base(options) + public class DynamicAuthenticationSchemeProvider : AuthenticationSchemeProvider { - _oidcPostConfigureOptions = oidcPostConfigureOptions; - _extendedOidcOptionsMonitorCache = oidcOptionsMonitorCache as - IExtendedOptionsMonitorCache; - if (_extendedOidcOptionsMonitorCache == null) + private readonly IPostConfigureOptions _oidcPostConfigureOptions; + private readonly IExtendedOptionsMonitorCache _extendedOidcOptionsMonitorCache; + private readonly IPostConfigureOptions _saml2PostConfigureOptions; + private readonly IExtendedOptionsMonitorCache _extendedSaml2OptionsMonitorCache; + private readonly ISsoConfigRepository _ssoConfigRepository; + private readonly ILogger _logger; + private readonly GlobalSettings _globalSettings; + private readonly SamlEnvironment _samlEnvironment; + private readonly TimeSpan _schemeCacheLifetime; + private readonly Dictionary _cachedSchemes; + private readonly Dictionary _cachedHandlerSchemes; + private readonly SemaphoreSlim _semaphore; + private readonly IHttpContextAccessor _httpContextAccessor; + + private DateTime? _lastSchemeLoad; + private IEnumerable _schemesCopy = Array.Empty(); + private IEnumerable _handlerSchemesCopy = Array.Empty(); + + public DynamicAuthenticationSchemeProvider( + IOptions options, + IPostConfigureOptions oidcPostConfigureOptions, + IOptionsMonitorCache oidcOptionsMonitorCache, + IPostConfigureOptions saml2PostConfigureOptions, + IOptionsMonitorCache saml2OptionsMonitorCache, + ISsoConfigRepository ssoConfigRepository, + ILogger logger, + GlobalSettings globalSettings, + SamlEnvironment samlEnvironment, + IHttpContextAccessor httpContextAccessor) + : base(options) { - throw new ArgumentNullException("_extendedOidcOptionsMonitorCache could not be resolved."); + _oidcPostConfigureOptions = oidcPostConfigureOptions; + _extendedOidcOptionsMonitorCache = oidcOptionsMonitorCache as + IExtendedOptionsMonitorCache; + if (_extendedOidcOptionsMonitorCache == null) + { + throw new ArgumentNullException("_extendedOidcOptionsMonitorCache could not be resolved."); + } + + _saml2PostConfigureOptions = saml2PostConfigureOptions; + _extendedSaml2OptionsMonitorCache = saml2OptionsMonitorCache as + IExtendedOptionsMonitorCache; + if (_extendedSaml2OptionsMonitorCache == null) + { + throw new ArgumentNullException("_extendedSaml2OptionsMonitorCache could not be resolved."); + } + + _ssoConfigRepository = ssoConfigRepository; + _logger = logger; + _globalSettings = globalSettings; + _schemeCacheLifetime = TimeSpan.FromSeconds(_globalSettings.Sso?.CacheLifetimeInSeconds ?? 30); + _samlEnvironment = samlEnvironment; + _cachedSchemes = new Dictionary(); + _cachedHandlerSchemes = new Dictionary(); + _semaphore = new SemaphoreSlim(1); + _httpContextAccessor = httpContextAccessor ?? throw new ArgumentNullException(nameof(httpContextAccessor)); } - _saml2PostConfigureOptions = saml2PostConfigureOptions; - _extendedSaml2OptionsMonitorCache = saml2OptionsMonitorCache as - IExtendedOptionsMonitorCache; - if (_extendedSaml2OptionsMonitorCache == null) + private bool CacheIsValid { - throw new ArgumentNullException("_extendedSaml2OptionsMonitorCache could not be resolved."); + get => _lastSchemeLoad.HasValue + && _lastSchemeLoad.Value.Add(_schemeCacheLifetime) >= DateTime.UtcNow; } - _ssoConfigRepository = ssoConfigRepository; - _logger = logger; - _globalSettings = globalSettings; - _schemeCacheLifetime = TimeSpan.FromSeconds(_globalSettings.Sso?.CacheLifetimeInSeconds ?? 30); - _samlEnvironment = samlEnvironment; - _cachedSchemes = new Dictionary(); - _cachedHandlerSchemes = new Dictionary(); - _semaphore = new SemaphoreSlim(1); - _httpContextAccessor = httpContextAccessor ?? throw new ArgumentNullException(nameof(httpContextAccessor)); - } - - private bool CacheIsValid - { - get => _lastSchemeLoad.HasValue - && _lastSchemeLoad.Value.Add(_schemeCacheLifetime) >= DateTime.UtcNow; - } - - public override async Task GetSchemeAsync(string name) - { - var scheme = await base.GetSchemeAsync(name); - if (scheme != null) + public override async Task GetSchemeAsync(string name) { - return scheme; + var scheme = await base.GetSchemeAsync(name); + if (scheme != null) + { + return scheme; + } + + try + { + var dynamicScheme = await GetDynamicSchemeAsync(name); + return dynamicScheme; + } + catch (Exception ex) + { + _logger.LogError(ex, "Unable to load a dynamic authentication scheme for '{0}'", name); + } + + return null; } - try + public override async Task> GetAllSchemesAsync() { - var dynamicScheme = await GetDynamicSchemeAsync(name); - return dynamicScheme; - } - catch (Exception ex) - { - _logger.LogError(ex, "Unable to load a dynamic authentication scheme for '{0}'", name); + var existingSchemes = await base.GetAllSchemesAsync(); + var schemes = new List(); + schemes.AddRange(existingSchemes); + + await LoadAllDynamicSchemesIntoCacheAsync(); + schemes.AddRange(_schemesCopy); + + return schemes.ToArray(); } - return null; - } - - public override async Task> GetAllSchemesAsync() - { - var existingSchemes = await base.GetAllSchemesAsync(); - var schemes = new List(); - schemes.AddRange(existingSchemes); - - await LoadAllDynamicSchemesIntoCacheAsync(); - schemes.AddRange(_schemesCopy); - - return schemes.ToArray(); - } - - public override async Task> GetRequestHandlerSchemesAsync() - { - var existingSchemes = await base.GetRequestHandlerSchemesAsync(); - var schemes = new List(); - schemes.AddRange(existingSchemes); - - await LoadAllDynamicSchemesIntoCacheAsync(); - schemes.AddRange(_handlerSchemesCopy); - - return schemes.ToArray(); - } - - private async Task LoadAllDynamicSchemesIntoCacheAsync() - { - if (CacheIsValid) + public override async Task> GetRequestHandlerSchemesAsync() { - // Our cache hasn't expired or been invalidated, ignore request - return; + var existingSchemes = await base.GetRequestHandlerSchemesAsync(); + var schemes = new List(); + schemes.AddRange(existingSchemes); + + await LoadAllDynamicSchemesIntoCacheAsync(); + schemes.AddRange(_handlerSchemesCopy); + + return schemes.ToArray(); } - await _semaphore.WaitAsync(); - try + + private async Task LoadAllDynamicSchemesIntoCacheAsync() { if (CacheIsValid) { - // Just in case (double-checked locking pattern) + // Our cache hasn't expired or been invalidated, ignore request return; } - - // Save time just in case the following operation takes longer - var now = DateTime.UtcNow; - var newSchemes = await _ssoConfigRepository.GetManyByRevisionNotBeforeDate(_lastSchemeLoad); - - foreach (var config in newSchemes) + await _semaphore.WaitAsync(); + try { - DynamicAuthenticationScheme scheme; - try + if (CacheIsValid) { - scheme = GetSchemeFromSsoConfig(config); + // Just in case (double-checked locking pattern) + return; } - catch (Exception ex) + + // Save time just in case the following operation takes longer + var now = DateTime.UtcNow; + var newSchemes = await _ssoConfigRepository.GetManyByRevisionNotBeforeDate(_lastSchemeLoad); + + foreach (var config in newSchemes) { - _logger.LogError(ex, "Error converting configuration to scheme for '{0}'", config.Id); - continue; + DynamicAuthenticationScheme scheme; + try + { + scheme = GetSchemeFromSsoConfig(config); + } + catch (Exception ex) + { + _logger.LogError(ex, "Error converting configuration to scheme for '{0}'", config.Id); + continue; + } + if (scheme == null) + { + continue; + } + SetSchemeInCache(scheme); } - if (scheme == null) + + if (newSchemes.Any()) { - continue; + // Maintain "safe" copy for use in enumeration routines + _schemesCopy = _cachedSchemes.Values.ToArray(); + _handlerSchemesCopy = _cachedHandlerSchemes.Values.ToArray(); } - SetSchemeInCache(scheme); + _lastSchemeLoad = now; + } + finally + { + _semaphore.Release(); + } + } + + private DynamicAuthenticationScheme SetSchemeInCache(DynamicAuthenticationScheme scheme) + { + if (!PostConfigureDynamicScheme(scheme)) + { + return null; + } + _cachedSchemes[scheme.Name] = scheme; + if (typeof(IAuthenticationRequestHandler).IsAssignableFrom(scheme.HandlerType)) + { + _cachedHandlerSchemes[scheme.Name] = scheme; + } + return scheme; + } + + private async Task GetDynamicSchemeAsync(string name) + { + if (_cachedSchemes.TryGetValue(name, out var cachedScheme)) + { + return cachedScheme; } - if (newSchemes.Any()) - { - // Maintain "safe" copy for use in enumeration routines - _schemesCopy = _cachedSchemes.Values.ToArray(); - _handlerSchemesCopy = _cachedHandlerSchemes.Values.ToArray(); - } - _lastSchemeLoad = now; - } - finally - { - _semaphore.Release(); - } - } - - private DynamicAuthenticationScheme SetSchemeInCache(DynamicAuthenticationScheme scheme) - { - if (!PostConfigureDynamicScheme(scheme)) - { - return null; - } - _cachedSchemes[scheme.Name] = scheme; - if (typeof(IAuthenticationRequestHandler).IsAssignableFrom(scheme.HandlerType)) - { - _cachedHandlerSchemes[scheme.Name] = scheme; - } - return scheme; - } - - private async Task GetDynamicSchemeAsync(string name) - { - if (_cachedSchemes.TryGetValue(name, out var cachedScheme)) - { - return cachedScheme; - } - - var scheme = await GetSchemeFromSsoConfigAsync(name); - if (scheme == null) - { - return null; - } - - await _semaphore.WaitAsync(); - try - { - scheme = SetSchemeInCache(scheme); + var scheme = await GetSchemeFromSsoConfigAsync(name); if (scheme == null) { return null; } - if (typeof(IAuthenticationRequestHandler).IsAssignableFrom(scheme.HandlerType)) + await _semaphore.WaitAsync(); + try { - _handlerSchemesCopy = _cachedHandlerSchemes.Values.ToArray(); + scheme = SetSchemeInCache(scheme); + if (scheme == null) + { + return null; + } + + if (typeof(IAuthenticationRequestHandler).IsAssignableFrom(scheme.HandlerType)) + { + _handlerSchemesCopy = _cachedHandlerSchemes.Values.ToArray(); + } + _schemesCopy = _cachedSchemes.Values.ToArray(); } - _schemesCopy = _cachedSchemes.Values.ToArray(); - } - finally - { - // Note: _lastSchemeLoad is not set here, this is a one-off - // and should not impact loading further cache updates - _semaphore.Release(); - } - return scheme; - } - - private bool PostConfigureDynamicScheme(DynamicAuthenticationScheme scheme) - { - try - { - if (scheme.SsoType == SsoType.OpenIdConnect && scheme.Options is OpenIdConnectOptions oidcOptions) + finally { - _oidcPostConfigureOptions.PostConfigure(scheme.Name, oidcOptions); - _extendedOidcOptionsMonitorCache.AddOrUpdate(scheme.Name, oidcOptions); + // Note: _lastSchemeLoad is not set here, this is a one-off + // and should not impact loading further cache updates + _semaphore.Release(); } - else if (scheme.SsoType == SsoType.Saml2 && scheme.Options is Saml2Options saml2Options) + return scheme; + } + + private bool PostConfigureDynamicScheme(DynamicAuthenticationScheme scheme) + { + try { - _saml2PostConfigureOptions.PostConfigure(scheme.Name, saml2Options); - _extendedSaml2OptionsMonitorCache.AddOrUpdate(scheme.Name, saml2Options); + if (scheme.SsoType == SsoType.OpenIdConnect && scheme.Options is OpenIdConnectOptions oidcOptions) + { + _oidcPostConfigureOptions.PostConfigure(scheme.Name, oidcOptions); + _extendedOidcOptionsMonitorCache.AddOrUpdate(scheme.Name, oidcOptions); + } + else if (scheme.SsoType == SsoType.Saml2 && scheme.Options is Saml2Options saml2Options) + { + _saml2PostConfigureOptions.PostConfigure(scheme.Name, saml2Options); + _extendedSaml2OptionsMonitorCache.AddOrUpdate(scheme.Name, saml2Options); + } + return true; } - return true; - } - catch (Exception ex) - { - _logger.LogError(ex, "Error performing post configuration for '{0}' ({1})", - scheme.Name, scheme.DisplayName); - } - return false; - } - - private DynamicAuthenticationScheme GetSchemeFromSsoConfig(SsoConfig config) - { - var data = config.GetData(); - return data.ConfigType switch - { - SsoType.OpenIdConnect => GetOidcAuthenticationScheme(config.OrganizationId.ToString(), data), - SsoType.Saml2 => GetSaml2AuthenticationScheme(config.OrganizationId.ToString(), data), - _ => throw new Exception($"SSO Config Type, '{data.ConfigType}', not supported"), - }; - } - - private async Task GetSchemeFromSsoConfigAsync(string name) - { - if (!Guid.TryParse(name, out var organizationId)) - { - _logger.LogWarning("Could not determine organization id from name, '{0}'", name); - return null; - } - var ssoConfig = await _ssoConfigRepository.GetByOrganizationIdAsync(organizationId); - if (ssoConfig == null || !ssoConfig.Enabled) - { - _logger.LogWarning("Could not find SSO config or config was not enabled for '{0}'", name); - return null; - } - - return GetSchemeFromSsoConfig(ssoConfig); - } - - private DynamicAuthenticationScheme GetOidcAuthenticationScheme(string name, SsoConfigurationData config) - { - var oidcOptions = new OpenIdConnectOptions - { - Authority = config.Authority, - ClientId = config.ClientId, - ClientSecret = config.ClientSecret, - ResponseType = "code", - ResponseMode = "form_post", - SignInScheme = AuthenticationSchemes.BitwardenExternalCookieAuthenticationScheme, - SignOutScheme = IdentityServerConstants.SignoutScheme, - SaveTokens = false, // reduce overall request size - TokenValidationParameters = new TokenValidationParameters + catch (Exception ex) { - NameClaimType = JwtClaimTypes.Name, - RoleClaimType = JwtClaimTypes.Role, - }, - CallbackPath = SsoConfigurationData.BuildCallbackPath(), - SignedOutCallbackPath = SsoConfigurationData.BuildSignedOutCallbackPath(), - MetadataAddress = config.MetadataAddress, - // Prevents URLs that go beyond 1024 characters which may break for some servers - AuthenticationMethod = config.RedirectBehavior, - GetClaimsFromUserInfoEndpoint = config.GetClaimsFromUserInfoEndpoint, - }; - oidcOptions.Scope - .AddIfNotExists(OpenIdConnectScopes.OpenId) - .AddIfNotExists(OpenIdConnectScopes.Email) - .AddIfNotExists(OpenIdConnectScopes.Profile); - foreach (var scope in config.GetAdditionalScopes()) - { - oidcOptions.Scope.AddIfNotExists(scope); - } - if (!string.IsNullOrWhiteSpace(config.ExpectedReturnAcrValue)) - { - oidcOptions.Scope.AddIfNotExists(OpenIdConnectScopes.Acr); + _logger.LogError(ex, "Error performing post configuration for '{0}' ({1})", + scheme.Name, scheme.DisplayName); + } + return false; } - oidcOptions.StateDataFormat = new DistributedCacheStateDataFormatter(_httpContextAccessor, name); - - // see: https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest (acr_values) - if (!string.IsNullOrWhiteSpace(config.AcrValues)) + private DynamicAuthenticationScheme GetSchemeFromSsoConfig(SsoConfig config) { - oidcOptions.Events ??= new OpenIdConnectEvents(); - oidcOptions.Events.OnRedirectToIdentityProvider = ctx => + var data = config.GetData(); + return data.ConfigType switch { - ctx.ProtocolMessage.AcrValues = config.AcrValues; - return Task.CompletedTask; + SsoType.OpenIdConnect => GetOidcAuthenticationScheme(config.OrganizationId.ToString(), data), + SsoType.Saml2 => GetSaml2AuthenticationScheme(config.OrganizationId.ToString(), data), + _ => throw new Exception($"SSO Config Type, '{data.ConfigType}', not supported"), }; } - return new DynamicAuthenticationScheme(name, name, typeof(OpenIdConnectHandler), - oidcOptions, SsoType.OpenIdConnect); - } - - private DynamicAuthenticationScheme GetSaml2AuthenticationScheme(string name, SsoConfigurationData config) - { - if (_samlEnvironment == null) + private async Task GetSchemeFromSsoConfigAsync(string name) { - throw new Exception($"SSO SAML2 Service Provider profile is missing for {name}"); + if (!Guid.TryParse(name, out var organizationId)) + { + _logger.LogWarning("Could not determine organization id from name, '{0}'", name); + return null; + } + var ssoConfig = await _ssoConfigRepository.GetByOrganizationIdAsync(organizationId); + if (ssoConfig == null || !ssoConfig.Enabled) + { + _logger.LogWarning("Could not find SSO config or config was not enabled for '{0}'", name); + return null; + } + + return GetSchemeFromSsoConfig(ssoConfig); } - var spEntityId = new Sustainsys.Saml2.Metadata.EntityId( - SsoConfigurationData.BuildSaml2ModulePath(_globalSettings.BaseServiceUri.Sso)); - bool? allowCreate = null; - if (config.SpNameIdFormat != Saml2NameIdFormat.Transient) + private DynamicAuthenticationScheme GetOidcAuthenticationScheme(string name, SsoConfigurationData config) { - allowCreate = true; - } - var spOptions = new SPOptions - { - EntityId = spEntityId, - ModulePath = SsoConfigurationData.BuildSaml2ModulePath(null, name), - NameIdPolicy = new Saml2NameIdPolicy(allowCreate, GetNameIdFormat(config.SpNameIdFormat)), - WantAssertionsSigned = config.SpWantAssertionsSigned, - AuthenticateRequestSigningBehavior = GetSigningBehavior(config.SpSigningBehavior), - ValidateCertificates = config.SpValidateCertificates, - }; - if (!string.IsNullOrWhiteSpace(config.SpMinIncomingSigningAlgorithm)) - { - spOptions.MinIncomingSigningAlgorithm = config.SpMinIncomingSigningAlgorithm; - } - if (!string.IsNullOrWhiteSpace(config.SpOutboundSigningAlgorithm)) - { - spOptions.OutboundSigningAlgorithm = config.SpOutboundSigningAlgorithm; - } - if (_samlEnvironment.SpSigningCertificate != null) - { - spOptions.ServiceCertificates.Add(_samlEnvironment.SpSigningCertificate); + var oidcOptions = new OpenIdConnectOptions + { + Authority = config.Authority, + ClientId = config.ClientId, + ClientSecret = config.ClientSecret, + ResponseType = "code", + ResponseMode = "form_post", + SignInScheme = AuthenticationSchemes.BitwardenExternalCookieAuthenticationScheme, + SignOutScheme = IdentityServerConstants.SignoutScheme, + SaveTokens = false, // reduce overall request size + TokenValidationParameters = new TokenValidationParameters + { + NameClaimType = JwtClaimTypes.Name, + RoleClaimType = JwtClaimTypes.Role, + }, + CallbackPath = SsoConfigurationData.BuildCallbackPath(), + SignedOutCallbackPath = SsoConfigurationData.BuildSignedOutCallbackPath(), + MetadataAddress = config.MetadataAddress, + // Prevents URLs that go beyond 1024 characters which may break for some servers + AuthenticationMethod = config.RedirectBehavior, + GetClaimsFromUserInfoEndpoint = config.GetClaimsFromUserInfoEndpoint, + }; + oidcOptions.Scope + .AddIfNotExists(OpenIdConnectScopes.OpenId) + .AddIfNotExists(OpenIdConnectScopes.Email) + .AddIfNotExists(OpenIdConnectScopes.Profile); + foreach (var scope in config.GetAdditionalScopes()) + { + oidcOptions.Scope.AddIfNotExists(scope); + } + if (!string.IsNullOrWhiteSpace(config.ExpectedReturnAcrValue)) + { + oidcOptions.Scope.AddIfNotExists(OpenIdConnectScopes.Acr); + } + + oidcOptions.StateDataFormat = new DistributedCacheStateDataFormatter(_httpContextAccessor, name); + + // see: https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest (acr_values) + if (!string.IsNullOrWhiteSpace(config.AcrValues)) + { + oidcOptions.Events ??= new OpenIdConnectEvents(); + oidcOptions.Events.OnRedirectToIdentityProvider = ctx => + { + ctx.ProtocolMessage.AcrValues = config.AcrValues; + return Task.CompletedTask; + }; + } + + return new DynamicAuthenticationScheme(name, name, typeof(OpenIdConnectHandler), + oidcOptions, SsoType.OpenIdConnect); } - var idpEntityId = new Sustainsys.Saml2.Metadata.EntityId(config.IdpEntityId); - var idp = new Sustainsys.Saml2.IdentityProvider(idpEntityId, spOptions) + private DynamicAuthenticationScheme GetSaml2AuthenticationScheme(string name, SsoConfigurationData config) { - Binding = GetBindingType(config.IdpBindingType), - AllowUnsolicitedAuthnResponse = config.IdpAllowUnsolicitedAuthnResponse, - DisableOutboundLogoutRequests = config.IdpDisableOutboundLogoutRequests, - WantAuthnRequestsSigned = config.IdpWantAuthnRequestsSigned, - }; - if (!string.IsNullOrWhiteSpace(config.IdpSingleSignOnServiceUrl)) - { - idp.SingleSignOnServiceUrl = new Uri(config.IdpSingleSignOnServiceUrl); + if (_samlEnvironment == null) + { + throw new Exception($"SSO SAML2 Service Provider profile is missing for {name}"); + } + + var spEntityId = new Sustainsys.Saml2.Metadata.EntityId( + SsoConfigurationData.BuildSaml2ModulePath(_globalSettings.BaseServiceUri.Sso)); + bool? allowCreate = null; + if (config.SpNameIdFormat != Saml2NameIdFormat.Transient) + { + allowCreate = true; + } + var spOptions = new SPOptions + { + EntityId = spEntityId, + ModulePath = SsoConfigurationData.BuildSaml2ModulePath(null, name), + NameIdPolicy = new Saml2NameIdPolicy(allowCreate, GetNameIdFormat(config.SpNameIdFormat)), + WantAssertionsSigned = config.SpWantAssertionsSigned, + AuthenticateRequestSigningBehavior = GetSigningBehavior(config.SpSigningBehavior), + ValidateCertificates = config.SpValidateCertificates, + }; + if (!string.IsNullOrWhiteSpace(config.SpMinIncomingSigningAlgorithm)) + { + spOptions.MinIncomingSigningAlgorithm = config.SpMinIncomingSigningAlgorithm; + } + if (!string.IsNullOrWhiteSpace(config.SpOutboundSigningAlgorithm)) + { + spOptions.OutboundSigningAlgorithm = config.SpOutboundSigningAlgorithm; + } + if (_samlEnvironment.SpSigningCertificate != null) + { + spOptions.ServiceCertificates.Add(_samlEnvironment.SpSigningCertificate); + } + + var idpEntityId = new Sustainsys.Saml2.Metadata.EntityId(config.IdpEntityId); + var idp = new Sustainsys.Saml2.IdentityProvider(idpEntityId, spOptions) + { + Binding = GetBindingType(config.IdpBindingType), + AllowUnsolicitedAuthnResponse = config.IdpAllowUnsolicitedAuthnResponse, + DisableOutboundLogoutRequests = config.IdpDisableOutboundLogoutRequests, + WantAuthnRequestsSigned = config.IdpWantAuthnRequestsSigned, + }; + if (!string.IsNullOrWhiteSpace(config.IdpSingleSignOnServiceUrl)) + { + idp.SingleSignOnServiceUrl = new Uri(config.IdpSingleSignOnServiceUrl); + } + if (!string.IsNullOrWhiteSpace(config.IdpSingleLogoutServiceUrl)) + { + idp.SingleLogoutServiceUrl = new Uri(config.IdpSingleLogoutServiceUrl); + } + if (!string.IsNullOrWhiteSpace(config.IdpOutboundSigningAlgorithm)) + { + idp.OutboundSigningAlgorithm = config.IdpOutboundSigningAlgorithm; + } + if (!string.IsNullOrWhiteSpace(config.IdpX509PublicCert)) + { + var cert = CoreHelpers.Base64UrlDecode(config.IdpX509PublicCert); + idp.SigningKeys.AddConfiguredKey(new X509Certificate2(cert)); + } + idp.ArtifactResolutionServiceUrls.Clear(); + // This must happen last since it calls Validate() internally. + idp.LoadMetadata = false; + + var options = new Saml2Options + { + SPOptions = spOptions, + SignInScheme = AuthenticationSchemes.BitwardenExternalCookieAuthenticationScheme, + SignOutScheme = IdentityServerConstants.DefaultCookieAuthenticationScheme, + CookieManager = new IdentityServer.DistributedCacheCookieManager(), + }; + options.IdentityProviders.Add(idp); + + return new DynamicAuthenticationScheme(name, name, typeof(Saml2Handler), options, SsoType.Saml2); } - if (!string.IsNullOrWhiteSpace(config.IdpSingleLogoutServiceUrl)) + + private NameIdFormat GetNameIdFormat(Saml2NameIdFormat format) { - idp.SingleLogoutServiceUrl = new Uri(config.IdpSingleLogoutServiceUrl); + return format switch + { + Saml2NameIdFormat.Unspecified => NameIdFormat.Unspecified, + Saml2NameIdFormat.EmailAddress => NameIdFormat.EmailAddress, + Saml2NameIdFormat.X509SubjectName => NameIdFormat.X509SubjectName, + Saml2NameIdFormat.WindowsDomainQualifiedName => NameIdFormat.WindowsDomainQualifiedName, + Saml2NameIdFormat.KerberosPrincipalName => NameIdFormat.KerberosPrincipalName, + Saml2NameIdFormat.EntityIdentifier => NameIdFormat.EntityIdentifier, + Saml2NameIdFormat.Persistent => NameIdFormat.Persistent, + Saml2NameIdFormat.Transient => NameIdFormat.Transient, + _ => NameIdFormat.NotConfigured, + }; } - if (!string.IsNullOrWhiteSpace(config.IdpOutboundSigningAlgorithm)) + + private SigningBehavior GetSigningBehavior(Saml2SigningBehavior behavior) { - idp.OutboundSigningAlgorithm = config.IdpOutboundSigningAlgorithm; + return behavior switch + { + Saml2SigningBehavior.IfIdpWantAuthnRequestsSigned => SigningBehavior.IfIdpWantAuthnRequestsSigned, + Saml2SigningBehavior.Always => SigningBehavior.Always, + Saml2SigningBehavior.Never => SigningBehavior.Never, + _ => SigningBehavior.IfIdpWantAuthnRequestsSigned, + }; } - if (!string.IsNullOrWhiteSpace(config.IdpX509PublicCert)) + + private Sustainsys.Saml2.WebSso.Saml2BindingType GetBindingType(Saml2BindingType bindingType) { - var cert = CoreHelpers.Base64UrlDecode(config.IdpX509PublicCert); - idp.SigningKeys.AddConfiguredKey(new X509Certificate2(cert)); + return bindingType switch + { + Saml2BindingType.HttpRedirect => Sustainsys.Saml2.WebSso.Saml2BindingType.HttpRedirect, + Saml2BindingType.HttpPost => Sustainsys.Saml2.WebSso.Saml2BindingType.HttpPost, + _ => Sustainsys.Saml2.WebSso.Saml2BindingType.HttpPost, + }; } - idp.ArtifactResolutionServiceUrls.Clear(); - // This must happen last since it calls Validate() internally. - idp.LoadMetadata = false; - - var options = new Saml2Options - { - SPOptions = spOptions, - SignInScheme = AuthenticationSchemes.BitwardenExternalCookieAuthenticationScheme, - SignOutScheme = IdentityServerConstants.DefaultCookieAuthenticationScheme, - CookieManager = new IdentityServer.DistributedCacheCookieManager(), - }; - options.IdentityProviders.Add(idp); - - return new DynamicAuthenticationScheme(name, name, typeof(Saml2Handler), options, SsoType.Saml2); - } - - private NameIdFormat GetNameIdFormat(Saml2NameIdFormat format) - { - return format switch - { - Saml2NameIdFormat.Unspecified => NameIdFormat.Unspecified, - Saml2NameIdFormat.EmailAddress => NameIdFormat.EmailAddress, - Saml2NameIdFormat.X509SubjectName => NameIdFormat.X509SubjectName, - Saml2NameIdFormat.WindowsDomainQualifiedName => NameIdFormat.WindowsDomainQualifiedName, - Saml2NameIdFormat.KerberosPrincipalName => NameIdFormat.KerberosPrincipalName, - Saml2NameIdFormat.EntityIdentifier => NameIdFormat.EntityIdentifier, - Saml2NameIdFormat.Persistent => NameIdFormat.Persistent, - Saml2NameIdFormat.Transient => NameIdFormat.Transient, - _ => NameIdFormat.NotConfigured, - }; - } - - private SigningBehavior GetSigningBehavior(Saml2SigningBehavior behavior) - { - return behavior switch - { - Saml2SigningBehavior.IfIdpWantAuthnRequestsSigned => SigningBehavior.IfIdpWantAuthnRequestsSigned, - Saml2SigningBehavior.Always => SigningBehavior.Always, - Saml2SigningBehavior.Never => SigningBehavior.Never, - _ => SigningBehavior.IfIdpWantAuthnRequestsSigned, - }; - } - - private Sustainsys.Saml2.WebSso.Saml2BindingType GetBindingType(Saml2BindingType bindingType) - { - return bindingType switch - { - Saml2BindingType.HttpRedirect => Sustainsys.Saml2.WebSso.Saml2BindingType.HttpRedirect, - Saml2BindingType.HttpPost => Sustainsys.Saml2.WebSso.Saml2BindingType.HttpPost, - _ => Sustainsys.Saml2.WebSso.Saml2BindingType.HttpPost, - }; } } diff --git a/bitwarden_license/src/Sso/Utilities/ExtendedOptionsMonitorCache.cs b/bitwarden_license/src/Sso/Utilities/ExtendedOptionsMonitorCache.cs index 083417f25b..8e23e1f07f 100644 --- a/bitwarden_license/src/Sso/Utilities/ExtendedOptionsMonitorCache.cs +++ b/bitwarden_license/src/Sso/Utilities/ExtendedOptionsMonitorCache.cs @@ -1,36 +1,37 @@ using System.Collections.Concurrent; using Microsoft.Extensions.Options; -namespace Bit.Sso.Utilities; - -public class ExtendedOptionsMonitorCache : IExtendedOptionsMonitorCache where TOptions : class +namespace Bit.Sso.Utilities { - private readonly ConcurrentDictionary> _cache = - new ConcurrentDictionary>(StringComparer.Ordinal); - - public void AddOrUpdate(string name, TOptions options) + public class ExtendedOptionsMonitorCache : IExtendedOptionsMonitorCache where TOptions : class { - _cache.AddOrUpdate(name ?? Options.DefaultName, new Lazy(() => options), - (string s, Lazy lazy) => new Lazy(() => options)); - } + private readonly ConcurrentDictionary> _cache = + new ConcurrentDictionary>(StringComparer.Ordinal); - public void Clear() - { - _cache.Clear(); - } + public void AddOrUpdate(string name, TOptions options) + { + _cache.AddOrUpdate(name ?? Options.DefaultName, new Lazy(() => options), + (string s, Lazy lazy) => new Lazy(() => options)); + } - public TOptions GetOrAdd(string name, Func createOptions) - { - return _cache.GetOrAdd(name ?? Options.DefaultName, new Lazy(createOptions)).Value; - } + public void Clear() + { + _cache.Clear(); + } - public bool TryAdd(string name, TOptions options) - { - return _cache.TryAdd(name ?? Options.DefaultName, new Lazy(() => options)); - } + public TOptions GetOrAdd(string name, Func createOptions) + { + return _cache.GetOrAdd(name ?? Options.DefaultName, new Lazy(createOptions)).Value; + } - public bool TryRemove(string name) - { - return _cache.TryRemove(name ?? Options.DefaultName, out _); + public bool TryAdd(string name, TOptions options) + { + return _cache.TryAdd(name ?? Options.DefaultName, new Lazy(() => options)); + } + + public bool TryRemove(string name) + { + return _cache.TryRemove(name ?? Options.DefaultName, out _); + } } } diff --git a/bitwarden_license/src/Sso/Utilities/IDynamicAuthenticationScheme.cs b/bitwarden_license/src/Sso/Utilities/IDynamicAuthenticationScheme.cs index 9ebd0f9cfc..7deab54408 100644 --- a/bitwarden_license/src/Sso/Utilities/IDynamicAuthenticationScheme.cs +++ b/bitwarden_license/src/Sso/Utilities/IDynamicAuthenticationScheme.cs @@ -1,12 +1,13 @@ using Bit.Core.Enums; using Microsoft.AspNetCore.Authentication; -namespace Bit.Sso.Utilities; - -public interface IDynamicAuthenticationScheme +namespace Bit.Sso.Utilities { - AuthenticationSchemeOptions Options { get; set; } - SsoType SsoType { get; set; } + public interface IDynamicAuthenticationScheme + { + AuthenticationSchemeOptions Options { get; set; } + SsoType SsoType { get; set; } - Task Validate(); + Task Validate(); + } } diff --git a/bitwarden_license/src/Sso/Utilities/IExtendedOptionsMonitorCache.cs b/bitwarden_license/src/Sso/Utilities/IExtendedOptionsMonitorCache.cs index 0f62843187..73a5352a8d 100644 --- a/bitwarden_license/src/Sso/Utilities/IExtendedOptionsMonitorCache.cs +++ b/bitwarden_license/src/Sso/Utilities/IExtendedOptionsMonitorCache.cs @@ -1,8 +1,9 @@ using Microsoft.Extensions.Options; -namespace Bit.Sso.Utilities; - -public interface IExtendedOptionsMonitorCache : IOptionsMonitorCache where TOptions : class +namespace Bit.Sso.Utilities { - void AddOrUpdate(string name, TOptions options); + public interface IExtendedOptionsMonitorCache : IOptionsMonitorCache where TOptions : class + { + void AddOrUpdate(string name, TOptions options); + } } diff --git a/bitwarden_license/src/Sso/Utilities/OpenIdConnectOptionsExtensions.cs b/bitwarden_license/src/Sso/Utilities/OpenIdConnectOptionsExtensions.cs index 9221877a04..e01ff7111a 100644 --- a/bitwarden_license/src/Sso/Utilities/OpenIdConnectOptionsExtensions.cs +++ b/bitwarden_license/src/Sso/Utilities/OpenIdConnectOptionsExtensions.cs @@ -1,62 +1,63 @@ using Microsoft.AspNetCore.Authentication.OpenIdConnect; using Microsoft.IdentityModel.Protocols.OpenIdConnect; -namespace Bit.Sso.Utilities; - -public static class OpenIdConnectOptionsExtensions +namespace Bit.Sso.Utilities { - public static async Task CouldHandleAsync(this OpenIdConnectOptions options, string scheme, HttpContext context) + public static class OpenIdConnectOptionsExtensions { - // Determine this is a valid request for our handler - if (options.CallbackPath != context.Request.Path && - options.RemoteSignOutPath != context.Request.Path && - options.SignedOutCallbackPath != context.Request.Path) + public static async Task CouldHandleAsync(this OpenIdConnectOptions options, string scheme, HttpContext context) { - return false; - } - - if (context.Request.Query["scheme"].FirstOrDefault() == scheme) - { - return true; - } - - try - { - // Parse out the message - OpenIdConnectMessage message = null; - if (string.Equals(context.Request.Method, "GET", StringComparison.OrdinalIgnoreCase)) + // Determine this is a valid request for our handler + if (options.CallbackPath != context.Request.Path && + options.RemoteSignOutPath != context.Request.Path && + options.SignedOutCallbackPath != context.Request.Path) { - message = new OpenIdConnectMessage(context.Request.Query.Select(pair => new KeyValuePair(pair.Key, pair.Value))); - } - else if (string.Equals(context.Request.Method, "POST", StringComparison.OrdinalIgnoreCase) && - !string.IsNullOrEmpty(context.Request.ContentType) && - context.Request.ContentType.StartsWith("application/x-www-form-urlencoded", StringComparison.OrdinalIgnoreCase) && - context.Request.Body.CanRead) - { - var form = await context.Request.ReadFormAsync(); - message = new OpenIdConnectMessage(form.Select(pair => new KeyValuePair(pair.Key, pair.Value))); - } - - var state = message?.State; - if (string.IsNullOrWhiteSpace(state)) - { - // State is required, it will fail later on for this reason. return false; } - // Handle State if we've gotten that back - var decodedState = options.StateDataFormat.Unprotect(state); - if (decodedState != null && decodedState.Items.ContainsKey("scheme")) + if (context.Request.Query["scheme"].FirstOrDefault() == scheme) { - return decodedState.Items["scheme"] == scheme; + return true; } - } - catch - { + + try + { + // Parse out the message + OpenIdConnectMessage message = null; + if (string.Equals(context.Request.Method, "GET", StringComparison.OrdinalIgnoreCase)) + { + message = new OpenIdConnectMessage(context.Request.Query.Select(pair => new KeyValuePair(pair.Key, pair.Value))); + } + else if (string.Equals(context.Request.Method, "POST", StringComparison.OrdinalIgnoreCase) && + !string.IsNullOrEmpty(context.Request.ContentType) && + context.Request.ContentType.StartsWith("application/x-www-form-urlencoded", StringComparison.OrdinalIgnoreCase) && + context.Request.Body.CanRead) + { + var form = await context.Request.ReadFormAsync(); + message = new OpenIdConnectMessage(form.Select(pair => new KeyValuePair(pair.Key, pair.Value))); + } + + var state = message?.State; + if (string.IsNullOrWhiteSpace(state)) + { + // State is required, it will fail later on for this reason. + return false; + } + + // Handle State if we've gotten that back + var decodedState = options.StateDataFormat.Unprotect(state); + if (decodedState != null && decodedState.Items.ContainsKey("scheme")) + { + return decodedState.Items["scheme"] == scheme; + } + } + catch + { + return false; + } + + // This is likely not an appropriate handler return false; } - - // This is likely not an appropriate handler - return false; } } diff --git a/bitwarden_license/src/Sso/Utilities/OpenIdConnectScopes.cs b/bitwarden_license/src/Sso/Utilities/OpenIdConnectScopes.cs index 3fae7ce4ec..983ce8b33f 100644 --- a/bitwarden_license/src/Sso/Utilities/OpenIdConnectScopes.cs +++ b/bitwarden_license/src/Sso/Utilities/OpenIdConnectScopes.cs @@ -1,63 +1,64 @@ -namespace Bit.Sso.Utilities; - -/// -/// OpenID Connect Clients use scope values as defined in 3.3 of OAuth 2.0 -/// [RFC6749]. These values represent the standard scope values supported -/// by OAuth 2.0 and therefore OIDC. -/// -/// -/// See: https://openid.net/specs/openid-connect-basic-1_0.html#Scopes -/// -public static class OpenIdConnectScopes +namespace Bit.Sso.Utilities { /// - /// REQUIRED. Informs the Authorization Server that the Client is making - /// an OpenID Connect request. If the openid scope value is not present, - /// the behavior is entirely unspecified. - /// - public const string OpenId = "openid"; - - /// - /// OPTIONAL. This scope value requests access to the End-User's default - /// profile Claims, which are: name, family_name, given_name, - /// middle_name, nickname, preferred_username, profile, picture, - /// website, gender, birthdate, zoneinfo, locale, and updated_at. - /// - public const string Profile = "profile"; - - /// - /// OPTIONAL. This scope value requests access to the email and - /// email_verified Claims. - /// - public const string Email = "email"; - - /// - /// OPTIONAL. This scope value requests access to the address Claim. - /// - public const string Address = "address"; - - /// - /// OPTIONAL. This scope value requests access to the phone_number and - /// phone_number_verified Claims. - /// - public const string Phone = "phone"; - - /// - /// OPTIONAL. This scope value requests that an OAuth 2.0 Refresh Token - /// be issued that can be used to obtain an Access Token that grants - /// access to the End-User's UserInfo Endpoint even when the End-User is - /// not present (not logged in). - /// - public const string OfflineAccess = "offline_access"; - - /// - /// OPTIONAL. Authentication Context Class Reference. String specifying - /// an Authentication Context Class Reference value that identifies the - /// Authentication Context Class that the authentication performed - /// satisfied. + /// OpenID Connect Clients use scope values as defined in 3.3 of OAuth 2.0 + /// [RFC6749]. These values represent the standard scope values supported + /// by OAuth 2.0 and therefore OIDC. /// /// - /// See: https://openid.net/specs/openid-connect-core-1_0.html#rfc.section.2 + /// See: https://openid.net/specs/openid-connect-basic-1_0.html#Scopes /// - public const string Acr = "acr"; + public static class OpenIdConnectScopes + { + /// + /// REQUIRED. Informs the Authorization Server that the Client is making + /// an OpenID Connect request. If the openid scope value is not present, + /// the behavior is entirely unspecified. + /// + public const string OpenId = "openid"; + + /// + /// OPTIONAL. This scope value requests access to the End-User's default + /// profile Claims, which are: name, family_name, given_name, + /// middle_name, nickname, preferred_username, profile, picture, + /// website, gender, birthdate, zoneinfo, locale, and updated_at. + /// + public const string Profile = "profile"; + + /// + /// OPTIONAL. This scope value requests access to the email and + /// email_verified Claims. + /// + public const string Email = "email"; + + /// + /// OPTIONAL. This scope value requests access to the address Claim. + /// + public const string Address = "address"; + + /// + /// OPTIONAL. This scope value requests access to the phone_number and + /// phone_number_verified Claims. + /// + public const string Phone = "phone"; + + /// + /// OPTIONAL. This scope value requests that an OAuth 2.0 Refresh Token + /// be issued that can be used to obtain an Access Token that grants + /// access to the End-User's UserInfo Endpoint even when the End-User is + /// not present (not logged in). + /// + public const string OfflineAccess = "offline_access"; + + /// + /// OPTIONAL. Authentication Context Class Reference. String specifying + /// an Authentication Context Class Reference value that identifies the + /// Authentication Context Class that the authentication performed + /// satisfied. + /// + /// + /// See: https://openid.net/specs/openid-connect-core-1_0.html#rfc.section.2 + /// + public const string Acr = "acr"; + } } diff --git a/bitwarden_license/src/Sso/Utilities/Saml2OptionsExtensions.cs b/bitwarden_license/src/Sso/Utilities/Saml2OptionsExtensions.cs index 46a75ca5c2..9d4870bd70 100644 --- a/bitwarden_license/src/Sso/Utilities/Saml2OptionsExtensions.cs +++ b/bitwarden_license/src/Sso/Utilities/Saml2OptionsExtensions.cs @@ -4,101 +4,102 @@ using System.Xml; using Sustainsys.Saml2; using Sustainsys.Saml2.AspNetCore2; -namespace Bit.Sso.Utilities; - -public static class Saml2OptionsExtensions +namespace Bit.Sso.Utilities { - public static async Task CouldHandleAsync(this Saml2Options options, string scheme, HttpContext context) + public static class Saml2OptionsExtensions { - // Determine this is a valid request for our handler - if (!context.Request.Path.StartsWithSegments(options.SPOptions.ModulePath, StringComparison.Ordinal)) + public static async Task CouldHandleAsync(this Saml2Options options, string scheme, HttpContext context) { - return false; - } + // Determine this is a valid request for our handler + if (!context.Request.Path.StartsWithSegments(options.SPOptions.ModulePath, StringComparison.Ordinal)) + { + return false; + } - var idp = options.IdentityProviders.IsEmpty ? null : options.IdentityProviders.Default; - if (idp == null) - { - return false; - } + var idp = options.IdentityProviders.IsEmpty ? null : options.IdentityProviders.Default; + if (idp == null) + { + return false; + } + + if (context.Request.Query["scheme"].FirstOrDefault() == scheme) + { + return true; + } + + // We need to pull out and parse the response or request SAML envelope + XmlElement envelope = null; + try + { + if (string.Equals(context.Request.Method, "POST", StringComparison.OrdinalIgnoreCase) && + context.Request.HasFormContentType) + { + string encodedMessage; + if (context.Request.Form.TryGetValue("SAMLResponse", out var response)) + { + encodedMessage = response.FirstOrDefault(); + } + else + { + encodedMessage = context.Request.Form["SAMLRequest"]; + } + if (string.IsNullOrWhiteSpace(encodedMessage)) + { + return false; + } + envelope = XmlHelpers.XmlDocumentFromString( + Encoding.UTF8.GetString(Convert.FromBase64String(encodedMessage)))?.DocumentElement; + } + else if (string.Equals(context.Request.Method, "GET", StringComparison.OrdinalIgnoreCase)) + { + var encodedPayload = context.Request.Query["SAMLRequest"].FirstOrDefault() ?? + context.Request.Query["SAMLResponse"].FirstOrDefault(); + try + { + var payload = Convert.FromBase64String(encodedPayload); + using var compressed = new MemoryStream(payload); + using var decompressedStream = new DeflateStream(compressed, CompressionMode.Decompress, true); + using var deCompressed = new MemoryStream(); + await decompressedStream.CopyToAsync(deCompressed); + + envelope = XmlHelpers.XmlDocumentFromString( + Encoding.UTF8.GetString(deCompressed.GetBuffer(), 0, (int)deCompressed.Length))?.DocumentElement; + } + catch (FormatException ex) + { + throw new FormatException($"\'{encodedPayload}\' is not a valid Base64 encoded string: {ex.Message}", ex); + } + } + } + catch + { + return false; + } + + if (envelope == null) + { + return false; + } + + // Double check the entity Ids + var entityId = envelope["Issuer", Saml2Namespaces.Saml2Name]?.InnerText.Trim(); + if (!string.Equals(entityId, idp.EntityId.Id, StringComparison.InvariantCultureIgnoreCase)) + { + return false; + } + + if (options.SPOptions.WantAssertionsSigned) + { + var assertion = envelope["Assertion", Saml2Namespaces.Saml2Name]; + var isAssertionSigned = assertion != null && XmlHelpers.IsSignedByAny(assertion, idp.SigningKeys, + options.SPOptions.ValidateCertificates, options.SPOptions.MinIncomingSigningAlgorithm); + if (!isAssertionSigned) + { + throw new Exception("Cannot verify SAML assertion signature."); + } + } - if (context.Request.Query["scheme"].FirstOrDefault() == scheme) - { return true; } - - // We need to pull out and parse the response or request SAML envelope - XmlElement envelope = null; - try - { - if (string.Equals(context.Request.Method, "POST", StringComparison.OrdinalIgnoreCase) && - context.Request.HasFormContentType) - { - string encodedMessage; - if (context.Request.Form.TryGetValue("SAMLResponse", out var response)) - { - encodedMessage = response.FirstOrDefault(); - } - else - { - encodedMessage = context.Request.Form["SAMLRequest"]; - } - if (string.IsNullOrWhiteSpace(encodedMessage)) - { - return false; - } - envelope = XmlHelpers.XmlDocumentFromString( - Encoding.UTF8.GetString(Convert.FromBase64String(encodedMessage)))?.DocumentElement; - } - else if (string.Equals(context.Request.Method, "GET", StringComparison.OrdinalIgnoreCase)) - { - var encodedPayload = context.Request.Query["SAMLRequest"].FirstOrDefault() ?? - context.Request.Query["SAMLResponse"].FirstOrDefault(); - try - { - var payload = Convert.FromBase64String(encodedPayload); - using var compressed = new MemoryStream(payload); - using var decompressedStream = new DeflateStream(compressed, CompressionMode.Decompress, true); - using var deCompressed = new MemoryStream(); - await decompressedStream.CopyToAsync(deCompressed); - - envelope = XmlHelpers.XmlDocumentFromString( - Encoding.UTF8.GetString(deCompressed.GetBuffer(), 0, (int)deCompressed.Length))?.DocumentElement; - } - catch (FormatException ex) - { - throw new FormatException($"\'{encodedPayload}\' is not a valid Base64 encoded string: {ex.Message}", ex); - } - } - } - catch - { - return false; - } - - if (envelope == null) - { - return false; - } - - // Double check the entity Ids - var entityId = envelope["Issuer", Saml2Namespaces.Saml2Name]?.InnerText.Trim(); - if (!string.Equals(entityId, idp.EntityId.Id, StringComparison.InvariantCultureIgnoreCase)) - { - return false; - } - - if (options.SPOptions.WantAssertionsSigned) - { - var assertion = envelope["Assertion", Saml2Namespaces.Saml2Name]; - var isAssertionSigned = assertion != null && XmlHelpers.IsSignedByAny(assertion, idp.SigningKeys, - options.SPOptions.ValidateCertificates, options.SPOptions.MinIncomingSigningAlgorithm); - if (!isAssertionSigned) - { - throw new Exception("Cannot verify SAML assertion signature."); - } - } - - return true; } } diff --git a/bitwarden_license/src/Sso/Utilities/SamlClaimTypes.cs b/bitwarden_license/src/Sso/Utilities/SamlClaimTypes.cs index 2f314c7ef5..f62f5b04b7 100644 --- a/bitwarden_license/src/Sso/Utilities/SamlClaimTypes.cs +++ b/bitwarden_license/src/Sso/Utilities/SamlClaimTypes.cs @@ -1,11 +1,12 @@ -namespace Bit.Sso.Utilities; - -public static class SamlClaimTypes +namespace Bit.Sso.Utilities { - public const string Email = "urn:oid:0.9.2342.19200300.100.1.3"; - public const string GivenName = "urn:oid:2.5.4.42"; - public const string Surname = "urn:oid:2.5.4.4"; - public const string DisplayName = "urn:oid:2.16.840.1.113730.3.1.241"; - public const string CommonName = "urn:oid:2.5.4.3"; - public const string UserId = "urn:oid:0.9.2342.19200300.100.1.1"; + public static class SamlClaimTypes + { + public const string Email = "urn:oid:0.9.2342.19200300.100.1.3"; + public const string GivenName = "urn:oid:2.5.4.42"; + public const string Surname = "urn:oid:2.5.4.4"; + public const string DisplayName = "urn:oid:2.16.840.1.113730.3.1.241"; + public const string CommonName = "urn:oid:2.5.4.3"; + public const string UserId = "urn:oid:0.9.2342.19200300.100.1.1"; + } } diff --git a/bitwarden_license/src/Sso/Utilities/SamlNameIdFormats.cs b/bitwarden_license/src/Sso/Utilities/SamlNameIdFormats.cs index 94c03bd642..18ccc140f6 100644 --- a/bitwarden_license/src/Sso/Utilities/SamlNameIdFormats.cs +++ b/bitwarden_license/src/Sso/Utilities/SamlNameIdFormats.cs @@ -1,17 +1,18 @@ -namespace Bit.Sso.Utilities; - -public static class SamlNameIdFormats +namespace Bit.Sso.Utilities { - // Common - public const string Unspecified = "urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified"; - public const string Email = "urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress"; - public const string Persistent = "urn:oasis:names:tc:SAML:2.0:nameid-format:persistent"; - public const string Transient = "urn:oasis:names:tc:SAML:2.0:nameid-format:transient"; - // Not-so-common - public const string Upn = "http://schemas.xmlsoap.org/claims/UPN"; - public const string CommonName = "http://schemas.xmlsoap.org/claims/CommonName"; - public const string X509SubjectName = "urn:oasis:names:tc:SAML:1.1:nameid-format:X509SubjectName"; - public const string WindowsQualifiedDomainName = "urn:oasis:names:tc:SAML:1.1:nameid-format:WindowsDomainQualifiedName"; - public const string KerberosPrincipalName = "urn:oasis:names:tc:SAML:2.0:nameid-format:kerberos"; - public const string EntityIdentifier = "urn:oasis:names:tc:SAML:2.0:nameid-format:entity"; + public static class SamlNameIdFormats + { + // Common + public const string Unspecified = "urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified"; + public const string Email = "urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress"; + public const string Persistent = "urn:oasis:names:tc:SAML:2.0:nameid-format:persistent"; + public const string Transient = "urn:oasis:names:tc:SAML:2.0:nameid-format:transient"; + // Not-so-common + public const string Upn = "http://schemas.xmlsoap.org/claims/UPN"; + public const string CommonName = "http://schemas.xmlsoap.org/claims/CommonName"; + public const string X509SubjectName = "urn:oasis:names:tc:SAML:1.1:nameid-format:X509SubjectName"; + public const string WindowsQualifiedDomainName = "urn:oasis:names:tc:SAML:1.1:nameid-format:WindowsDomainQualifiedName"; + public const string KerberosPrincipalName = "urn:oasis:names:tc:SAML:2.0:nameid-format:kerberos"; + public const string EntityIdentifier = "urn:oasis:names:tc:SAML:2.0:nameid-format:entity"; + } } diff --git a/bitwarden_license/src/Sso/Utilities/SamlPropertyKeys.cs b/bitwarden_license/src/Sso/Utilities/SamlPropertyKeys.cs index 7be7fb4f66..21d599b7fe 100644 --- a/bitwarden_license/src/Sso/Utilities/SamlPropertyKeys.cs +++ b/bitwarden_license/src/Sso/Utilities/SamlPropertyKeys.cs @@ -1,6 +1,7 @@ -namespace Bit.Sso.Utilities; - -public static class SamlPropertyKeys +namespace Bit.Sso.Utilities { - public const string ClaimFormat = "http://schemas.xmlsoap.org/ws/2005/05/identity/claimproperties/format"; + public static class SamlPropertyKeys + { + public const string ClaimFormat = "http://schemas.xmlsoap.org/ws/2005/05/identity/claimproperties/format"; + } } diff --git a/bitwarden_license/src/Sso/Utilities/ServiceCollectionExtensions.cs b/bitwarden_license/src/Sso/Utilities/ServiceCollectionExtensions.cs index d7a5e3b1b4..444ed6c52a 100644 --- a/bitwarden_license/src/Sso/Utilities/ServiceCollectionExtensions.cs +++ b/bitwarden_license/src/Sso/Utilities/ServiceCollectionExtensions.cs @@ -9,69 +9,70 @@ using IdentityServer4.ResponseHandling; using Microsoft.AspNetCore.Authentication.OpenIdConnect; using Sustainsys.Saml2.AspNetCore2; -namespace Bit.Sso.Utilities; - -public static class ServiceCollectionExtensions +namespace Bit.Sso.Utilities { - public static IServiceCollection AddSsoServices(this IServiceCollection services, - GlobalSettings globalSettings) + public static class ServiceCollectionExtensions { - // SAML SP Configuration - var samlEnvironment = new SamlEnvironment + public static IServiceCollection AddSsoServices(this IServiceCollection services, + GlobalSettings globalSettings) { - SpSigningCertificate = CoreHelpers.GetIdentityServerCertificate(globalSettings), - }; - services.AddSingleton(s => samlEnvironment); - - services.AddSingleton(); - // Oidc - services.AddSingleton, - OpenIdConnectPostConfigureOptions>(); - services.AddSingleton, - ExtendedOptionsMonitorCache>(); - // Saml2 - services.AddSingleton, - PostConfigureSaml2Options>(); - services.AddSingleton, - ExtendedOptionsMonitorCache>(); - - return services; - } - - public static IIdentityServerBuilder AddSsoIdentityServerServices(this IServiceCollection services, - IWebHostEnvironment env, GlobalSettings globalSettings) - { - services.AddTransient(); - - var issuerUri = new Uri(globalSettings.BaseServiceUri.InternalSso); - var identityServerBuilder = services - .AddIdentityServer(options => + // SAML SP Configuration + var samlEnvironment = new SamlEnvironment { - options.IssuerUri = $"{issuerUri.Scheme}://{issuerUri.Host}"; - if (env.IsDevelopment()) + SpSigningCertificate = CoreHelpers.GetIdentityServerCertificate(globalSettings), + }; + services.AddSingleton(s => samlEnvironment); + + services.AddSingleton(); + // Oidc + services.AddSingleton, + OpenIdConnectPostConfigureOptions>(); + services.AddSingleton, + ExtendedOptionsMonitorCache>(); + // Saml2 + services.AddSingleton, + PostConfigureSaml2Options>(); + services.AddSingleton, + ExtendedOptionsMonitorCache>(); + + return services; + } + + public static IIdentityServerBuilder AddSsoIdentityServerServices(this IServiceCollection services, + IWebHostEnvironment env, GlobalSettings globalSettings) + { + services.AddTransient(); + + var issuerUri = new Uri(globalSettings.BaseServiceUri.InternalSso); + var identityServerBuilder = services + .AddIdentityServer(options => { - options.Authentication.CookieSameSiteMode = Microsoft.AspNetCore.Http.SameSiteMode.Unspecified; - } - else + options.IssuerUri = $"{issuerUri.Scheme}://{issuerUri.Host}"; + if (env.IsDevelopment()) + { + options.Authentication.CookieSameSiteMode = Microsoft.AspNetCore.Http.SameSiteMode.Unspecified; + } + else + { + options.UserInteraction.ErrorUrl = "/Error"; + options.UserInteraction.ErrorIdParameter = "errorId"; + } + options.InputLengthRestrictions.UserName = 256; + }) + .AddInMemoryCaching() + .AddInMemoryClients(new List { - options.UserInteraction.ErrorUrl = "/Error"; - options.UserInteraction.ErrorIdParameter = "errorId"; - } - options.InputLengthRestrictions.UserName = 256; - }) - .AddInMemoryCaching() - .AddInMemoryClients(new List - { - new OidcIdentityClient(globalSettings) - }) - .AddInMemoryIdentityResources(new List - { - new IdentityResources.OpenId(), - new IdentityResources.Profile() - }) - .AddIdentityServerCertificate(env, globalSettings); + new OidcIdentityClient(globalSettings) + }) + .AddInMemoryIdentityResources(new List + { + new IdentityResources.OpenId(), + new IdentityResources.Profile() + }) + .AddIdentityServerCertificate(env, globalSettings); - return identityServerBuilder; + return identityServerBuilder; + } } } diff --git a/bitwarden_license/src/Sso/Utilities/SsoAuthenticationMiddleware.cs b/bitwarden_license/src/Sso/Utilities/SsoAuthenticationMiddleware.cs index 9dca7a6909..4a39082f3d 100644 --- a/bitwarden_license/src/Sso/Utilities/SsoAuthenticationMiddleware.cs +++ b/bitwarden_license/src/Sso/Utilities/SsoAuthenticationMiddleware.cs @@ -3,82 +3,83 @@ using Microsoft.AspNetCore.Authentication; using Microsoft.AspNetCore.Authentication.OpenIdConnect; using Sustainsys.Saml2.AspNetCore2; -namespace Bit.Sso.Utilities; - -public class SsoAuthenticationMiddleware +namespace Bit.Sso.Utilities { - private readonly RequestDelegate _next; - - public SsoAuthenticationMiddleware(RequestDelegate next, IAuthenticationSchemeProvider schemes) + public class SsoAuthenticationMiddleware { - _next = next ?? throw new ArgumentNullException(nameof(next)); - Schemes = schemes ?? throw new ArgumentNullException(nameof(schemes)); - } + private readonly RequestDelegate _next; - public IAuthenticationSchemeProvider Schemes { get; set; } - - public async Task Invoke(HttpContext context) - { - if ((context.Request.Method == "GET" && context.Request.Query.ContainsKey("SAMLart")) - || (context.Request.Method == "POST" && context.Request.Form.ContainsKey("SAMLart"))) + public SsoAuthenticationMiddleware(RequestDelegate next, IAuthenticationSchemeProvider schemes) { - throw new Exception("SAMLart parameter detected. SAML Artifact binding is not allowed."); + _next = next ?? throw new ArgumentNullException(nameof(next)); + Schemes = schemes ?? throw new ArgumentNullException(nameof(schemes)); } - context.Features.Set(new AuthenticationFeature - { - OriginalPath = context.Request.Path, - OriginalPathBase = context.Request.PathBase - }); + public IAuthenticationSchemeProvider Schemes { get; set; } - // Give any IAuthenticationRequestHandler schemes a chance to handle the request - var handlers = context.RequestServices.GetRequiredService(); - foreach (var scheme in await Schemes.GetRequestHandlerSchemesAsync()) + public async Task Invoke(HttpContext context) { - // Determine if scheme is appropriate for the current context FIRST - if (scheme is IDynamicAuthenticationScheme dynamicScheme) + if ((context.Request.Method == "GET" && context.Request.Query.ContainsKey("SAMLart")) + || (context.Request.Method == "POST" && context.Request.Form.ContainsKey("SAMLart"))) { - switch (dynamicScheme.SsoType) + throw new Exception("SAMLart parameter detected. SAML Artifact binding is not allowed."); + } + + context.Features.Set(new AuthenticationFeature + { + OriginalPath = context.Request.Path, + OriginalPathBase = context.Request.PathBase + }); + + // Give any IAuthenticationRequestHandler schemes a chance to handle the request + var handlers = context.RequestServices.GetRequiredService(); + foreach (var scheme in await Schemes.GetRequestHandlerSchemesAsync()) + { + // Determine if scheme is appropriate for the current context FIRST + if (scheme is IDynamicAuthenticationScheme dynamicScheme) { - case SsoType.OpenIdConnect: - default: - if (dynamicScheme.Options is OpenIdConnectOptions oidcOptions && - !await oidcOptions.CouldHandleAsync(scheme.Name, context)) - { - // It's OIDC and Dynamic, but not a good fit - continue; - } - break; - case SsoType.Saml2: - if (dynamicScheme.Options is Saml2Options samlOptions && - !await samlOptions.CouldHandleAsync(scheme.Name, context)) - { - // It's SAML and Dynamic, but not a good fit - continue; - } - break; + switch (dynamicScheme.SsoType) + { + case SsoType.OpenIdConnect: + default: + if (dynamicScheme.Options is OpenIdConnectOptions oidcOptions && + !await oidcOptions.CouldHandleAsync(scheme.Name, context)) + { + // It's OIDC and Dynamic, but not a good fit + continue; + } + break; + case SsoType.Saml2: + if (dynamicScheme.Options is Saml2Options samlOptions && + !await samlOptions.CouldHandleAsync(scheme.Name, context)) + { + // It's SAML and Dynamic, but not a good fit + continue; + } + break; + } + } + + // This far it's not dynamic OR it is but "could" be handled + if (await handlers.GetHandlerAsync(context, scheme.Name) is IAuthenticationRequestHandler handler && + await handler.HandleRequestAsync()) + { + return; } } - // This far it's not dynamic OR it is but "could" be handled - if (await handlers.GetHandlerAsync(context, scheme.Name) is IAuthenticationRequestHandler handler && - await handler.HandleRequestAsync()) + // Fallback to the default scheme from the provider + var defaultAuthenticate = await Schemes.GetDefaultAuthenticateSchemeAsync(); + if (defaultAuthenticate != null) { - return; + var result = await context.AuthenticateAsync(defaultAuthenticate.Name); + if (result?.Principal != null) + { + context.User = result.Principal; + } } - } - // Fallback to the default scheme from the provider - var defaultAuthenticate = await Schemes.GetDefaultAuthenticateSchemeAsync(); - if (defaultAuthenticate != null) - { - var result = await context.AuthenticateAsync(defaultAuthenticate.Name); - if (result?.Principal != null) - { - context.User = result.Principal; - } + await _next(context); } - - await _next(context); } } diff --git a/bitwarden_license/test/Commercial.Core.Test/AutoFixture/ProviderUserFixtures.cs b/bitwarden_license/test/Commercial.Core.Test/AutoFixture/ProviderUserFixtures.cs index 48f70c3350..59bc1c59fd 100644 --- a/bitwarden_license/test/Commercial.Core.Test/AutoFixture/ProviderUserFixtures.cs +++ b/bitwarden_license/test/Commercial.Core.Test/AutoFixture/ProviderUserFixtures.cs @@ -3,42 +3,43 @@ using AutoFixture; using AutoFixture.Xunit2; using Bit.Core.Enums.Provider; -namespace Bit.Commercial.Core.Test.AutoFixture; - -internal class ProviderUser : ICustomization +namespace Bit.Commercial.Core.Test.AutoFixture { - public ProviderUserStatusType Status { get; set; } - public ProviderUserType Type { get; set; } - - public ProviderUser(ProviderUserStatusType status, ProviderUserType type) + internal class ProviderUser : ICustomization { - Status = status; - Type = type; + public ProviderUserStatusType Status { get; set; } + public ProviderUserType Type { get; set; } + + public ProviderUser(ProviderUserStatusType status, ProviderUserType type) + { + Status = status; + Type = type; + } + + public void Customize(IFixture fixture) + { + fixture.Customize(composer => composer + .With(o => o.Type, Type) + .With(o => o.Status, Status)); + } } - public void Customize(IFixture fixture) + public class ProviderUserAttribute : CustomizeAttribute { - fixture.Customize(composer => composer - .With(o => o.Type, Type) - .With(o => o.Status, Status)); - } -} - -public class ProviderUserAttribute : CustomizeAttribute -{ - private readonly ProviderUserStatusType _status; - private readonly ProviderUserType _type; - - public ProviderUserAttribute( - ProviderUserStatusType status = ProviderUserStatusType.Confirmed, - ProviderUserType type = ProviderUserType.ProviderAdmin) - { - _status = status; - _type = type; - } - - public override ICustomization GetCustomization(ParameterInfo parameter) - { - return new ProviderUser(_status, _type); + private readonly ProviderUserStatusType _status; + private readonly ProviderUserType _type; + + public ProviderUserAttribute( + ProviderUserStatusType status = ProviderUserStatusType.Confirmed, + ProviderUserType type = ProviderUserType.ProviderAdmin) + { + _status = status; + _type = type; + } + + public override ICustomization GetCustomization(ParameterInfo parameter) + { + return new ProviderUser(_status, _type); + } } } diff --git a/bitwarden_license/test/Commercial.Core.Test/Services/ProviderServiceTests.cs b/bitwarden_license/test/Commercial.Core.Test/Services/ProviderServiceTests.cs index 53911ea06f..a8c08b6320 100644 --- a/bitwarden_license/test/Commercial.Core.Test/Services/ProviderServiceTests.cs +++ b/bitwarden_license/test/Commercial.Core.Test/Services/ProviderServiceTests.cs @@ -19,531 +19,532 @@ using NSubstitute.ReturnsExtensions; using Xunit; using ProviderUser = Bit.Core.Entities.Provider.ProviderUser; -namespace Bit.Commercial.Core.Test.Services; - -public class ProviderServiceTests +namespace Bit.Commercial.Core.Test.Services { - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task CreateAsync_UserIdIsInvalid_Throws(SutProvider sutProvider) + public class ProviderServiceTests { - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.CreateAsync(default)); - Assert.Contains("Invalid owner.", exception.Message); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task CreateAsync_Success(User user, SutProvider sutProvider) - { - var userRepository = sutProvider.GetDependency(); - userRepository.GetByEmailAsync(user.Email).Returns(user); - - await sutProvider.Sut.CreateAsync(user.Email); - - await sutProvider.GetDependency().ReceivedWithAnyArgs().CreateAsync(default); - await sutProvider.GetDependency().ReceivedWithAnyArgs().SendProviderSetupInviteEmailAsync(default, default, default); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task CompleteSetupAsync_UserIdIsInvalid_Throws(SutProvider sutProvider) - { - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.CompleteSetupAsync(default, default, default, default)); - Assert.Contains("Invalid owner.", exception.Message); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task CompleteSetupAsync_TokenIsInvalid_Throws(User user, Provider provider, - SutProvider sutProvider) - { - var userService = sutProvider.GetDependency(); - userService.GetUserByIdAsync(user.Id).Returns(user); - - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.CompleteSetupAsync(provider, user.Id, default, default)); - Assert.Contains("Invalid token.", exception.Message); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task CompleteSetupAsync_Success(User user, Provider provider, string key, - [ProviderUser(ProviderUserStatusType.Confirmed, ProviderUserType.ProviderAdmin)] ProviderUser providerUser, - SutProvider sutProvider) - { - providerUser.ProviderId = provider.Id; - providerUser.UserId = user.Id; - var userService = sutProvider.GetDependency(); - userService.GetUserByIdAsync(user.Id).Returns(user); - - var providerUserRepository = sutProvider.GetDependency(); - providerUserRepository.GetByProviderUserAsync(provider.Id, user.Id).Returns(providerUser); - - var dataProtectionProvider = DataProtectionProvider.Create("ApplicationName"); - var protector = dataProtectionProvider.CreateProtector("ProviderServiceDataProtector"); - sutProvider.GetDependency().CreateProtector("ProviderServiceDataProtector") - .Returns(protector); - sutProvider.Create(); - - var token = protector.Protect($"ProviderSetupInvite {provider.Id} {user.Email} {CoreHelpers.ToEpocMilliseconds(DateTime.UtcNow)}"); - - await sutProvider.Sut.CompleteSetupAsync(provider, user.Id, token, key); - - await sutProvider.GetDependency().Received().UpsertAsync(provider); - await sutProvider.GetDependency().Received() - .ReplaceAsync(Arg.Is(pu => pu.UserId == user.Id && pu.ProviderId == provider.Id && pu.Key == key)); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task UpdateAsync_ProviderIdIsInvalid_Throws(Provider provider, SutProvider sutProvider) - { - provider.Id = default; - - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.UpdateAsync(provider)); - Assert.Contains("Cannot create provider this way.", exception.Message); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task UpdateAsync_Success(Provider provider, SutProvider sutProvider) - { - await sutProvider.Sut.UpdateAsync(provider); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task InviteUserAsync_ProviderIdIsInvalid_Throws(ProviderUserInvite invite, SutProvider sutProvider) - { - sutProvider.GetDependency().ProviderManageUsers(invite.ProviderId).Returns(true); - - await Assert.ThrowsAsync(() => sutProvider.Sut.InviteUserAsync(invite)); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task InviteUserAsync_InvalidPermissions_Throws(ProviderUserInvite invite, SutProvider sutProvider) - { - sutProvider.GetDependency().ProviderManageUsers(invite.ProviderId).Returns(false); - await Assert.ThrowsAsync(() => sutProvider.Sut.InviteUserAsync(invite)); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task InviteUserAsync_EmailsInvalid_Throws(Provider provider, ProviderUserInvite providerUserInvite, - SutProvider sutProvider) - { - var providerRepository = sutProvider.GetDependency(); - providerRepository.GetByIdAsync(providerUserInvite.ProviderId).Returns(provider); - sutProvider.GetDependency().ProviderManageUsers(providerUserInvite.ProviderId).Returns(true); - - providerUserInvite.UserIdentifiers = null; - - await Assert.ThrowsAsync(() => sutProvider.Sut.InviteUserAsync(providerUserInvite)); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task InviteUserAsync_AlreadyInvited(Provider provider, ProviderUserInvite providerUserInvite, - SutProvider sutProvider) - { - var providerRepository = sutProvider.GetDependency(); - providerRepository.GetByIdAsync(providerUserInvite.ProviderId).Returns(provider); - var providerUserRepository = sutProvider.GetDependency(); - providerUserRepository.GetCountByProviderAsync(default, default, default).ReturnsForAnyArgs(1); - sutProvider.GetDependency().ProviderManageUsers(providerUserInvite.ProviderId).Returns(true); - - var result = await sutProvider.Sut.InviteUserAsync(providerUserInvite); - Assert.Empty(result); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task InviteUserAsync_Success(Provider provider, ProviderUserInvite providerUserInvite, - SutProvider sutProvider) - { - var providerRepository = sutProvider.GetDependency(); - providerRepository.GetByIdAsync(providerUserInvite.ProviderId).Returns(provider); - var providerUserRepository = sutProvider.GetDependency(); - providerUserRepository.GetCountByProviderAsync(default, default, default).ReturnsForAnyArgs(0); - sutProvider.GetDependency().ProviderManageUsers(providerUserInvite.ProviderId).Returns(true); - - var result = await sutProvider.Sut.InviteUserAsync(providerUserInvite); - Assert.Equal(providerUserInvite.UserIdentifiers.Count(), result.Count); - Assert.True(result.TrueForAll(pu => pu.Status == ProviderUserStatusType.Invited), "Status must be invited"); - Assert.True(result.TrueForAll(pu => pu.ProviderId == providerUserInvite.ProviderId), "Provider Id must be correct"); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task ResendInviteUserAsync_InvalidPermissions_Throws(ProviderUserInvite invite, SutProvider sutProvider) - { - sutProvider.GetDependency().ProviderManageUsers(invite.ProviderId).Returns(false); - await Assert.ThrowsAsync(() => sutProvider.Sut.ResendInvitesAsync(invite)); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task ResendInvitesAsync_Errors(Provider provider, - [ProviderUser(ProviderUserStatusType.Invited)] ProviderUser pu1, - [ProviderUser(ProviderUserStatusType.Accepted)] ProviderUser pu2, - [ProviderUser(ProviderUserStatusType.Confirmed)] ProviderUser pu3, - [ProviderUser(ProviderUserStatusType.Invited)] ProviderUser pu4, - SutProvider sutProvider) - { - var providerUsers = new[] { pu1, pu2, pu3, pu4 }; - pu1.ProviderId = pu2.ProviderId = pu3.ProviderId = provider.Id; - - var invite = new ProviderUserInvite + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task CreateAsync_UserIdIsInvalid_Throws(SutProvider sutProvider) { - UserIdentifiers = providerUsers.Select(pu => pu.Id), - ProviderId = provider.Id - }; - - var providerRepository = sutProvider.GetDependency(); - providerRepository.GetByIdAsync(provider.Id).Returns(provider); - var providerUserRepository = sutProvider.GetDependency(); - providerUserRepository.GetManyAsync(default).ReturnsForAnyArgs(providerUsers.ToList()); - sutProvider.GetDependency().ProviderManageUsers(invite.ProviderId).Returns(true); - - var result = await sutProvider.Sut.ResendInvitesAsync(invite); - Assert.Equal("", result[0].Item2); - Assert.Equal("User invalid.", result[1].Item2); - Assert.Equal("User invalid.", result[2].Item2); - Assert.Equal("User invalid.", result[3].Item2); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task ResendInvitesAsync_Success(Provider provider, IEnumerable providerUsers, - SutProvider sutProvider) - { - foreach (var providerUser in providerUsers) - { - providerUser.ProviderId = provider.Id; - providerUser.Status = ProviderUserStatusType.Invited; + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.CreateAsync(default)); + Assert.Contains("Invalid owner.", exception.Message); } - var invite = new ProviderUserInvite + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task CreateAsync_Success(User user, SutProvider sutProvider) { - UserIdentifiers = providerUsers.Select(pu => pu.Id), - ProviderId = provider.Id - }; + var userRepository = sutProvider.GetDependency(); + userRepository.GetByEmailAsync(user.Email).Returns(user); - var providerRepository = sutProvider.GetDependency(); - providerRepository.GetByIdAsync(provider.Id).Returns(provider); - var providerUserRepository = sutProvider.GetDependency(); - providerUserRepository.GetManyAsync(default).ReturnsForAnyArgs(providerUsers.ToList()); - sutProvider.GetDependency().ProviderManageUsers(invite.ProviderId).Returns(true); + await sutProvider.Sut.CreateAsync(user.Email); - var result = await sutProvider.Sut.ResendInvitesAsync(invite); - Assert.True(result.All(r => r.Item2 == "")); - } + await sutProvider.GetDependency().ReceivedWithAnyArgs().CreateAsync(default); + await sutProvider.GetDependency().ReceivedWithAnyArgs().SendProviderSetupInviteEmailAsync(default, default, default); + } - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task AcceptUserAsync_UserIsInvalid_Throws(SutProvider sutProvider) - { - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.AcceptUserAsync(default, default, default)); - Assert.Equal("User invalid.", exception.Message); - } + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task CompleteSetupAsync_UserIdIsInvalid_Throws(SutProvider sutProvider) + { + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.CompleteSetupAsync(default, default, default, default)); + Assert.Contains("Invalid owner.", exception.Message); + } - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task AcceptUserAsync_AlreadyAccepted_Throws( - [ProviderUser(ProviderUserStatusType.Accepted)] ProviderUser providerUser, User user, - SutProvider sutProvider) - { - var providerUserRepository = sutProvider.GetDependency(); - providerUserRepository.GetByIdAsync(providerUser.Id).Returns(providerUser); + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task CompleteSetupAsync_TokenIsInvalid_Throws(User user, Provider provider, + SutProvider sutProvider) + { + var userService = sutProvider.GetDependency(); + userService.GetUserByIdAsync(user.Id).Returns(user); - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.AcceptUserAsync(providerUser.Id, user, default)); - Assert.Equal("Already accepted.", exception.Message); - } + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.CompleteSetupAsync(provider, user.Id, default, default)); + Assert.Contains("Invalid token.", exception.Message); + } - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task AcceptUserAsync_TokenIsInvalid_Throws( - [ProviderUser(ProviderUserStatusType.Invited)] ProviderUser providerUser, User user, - SutProvider sutProvider) - { - var providerUserRepository = sutProvider.GetDependency(); - providerUserRepository.GetByIdAsync(providerUser.Id).Returns(providerUser); - - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.AcceptUserAsync(providerUser.Id, user, default)); - Assert.Equal("Invalid token.", exception.Message); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task AcceptUserAsync_WrongEmail_Throws( - [ProviderUser(ProviderUserStatusType.Invited)] ProviderUser providerUser, User user, - SutProvider sutProvider) - { - var providerUserRepository = sutProvider.GetDependency(); - providerUserRepository.GetByIdAsync(providerUser.Id).Returns(providerUser); - - var dataProtectionProvider = DataProtectionProvider.Create("ApplicationName"); - var protector = dataProtectionProvider.CreateProtector("ProviderServiceDataProtector"); - sutProvider.GetDependency().CreateProtector("ProviderServiceDataProtector") - .Returns(protector); - sutProvider.Create(); - - var token = protector.Protect($"ProviderUserInvite {providerUser.Id} {user.Email} {CoreHelpers.ToEpocMilliseconds(DateTime.UtcNow)}"); - - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.AcceptUserAsync(providerUser.Id, user, token)); - Assert.Equal("User email does not match invite.", exception.Message); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task AcceptUserAsync_Success( - [ProviderUser(ProviderUserStatusType.Invited)] ProviderUser providerUser, User user, - SutProvider sutProvider) - { - var providerUserRepository = sutProvider.GetDependency(); - providerUserRepository.GetByIdAsync(providerUser.Id).Returns(providerUser); - - var dataProtectionProvider = DataProtectionProvider.Create("ApplicationName"); - var protector = dataProtectionProvider.CreateProtector("ProviderServiceDataProtector"); - sutProvider.GetDependency().CreateProtector("ProviderServiceDataProtector") - .Returns(protector); - sutProvider.Create(); - - providerUser.Email = user.Email; - var token = protector.Protect($"ProviderUserInvite {providerUser.Id} {user.Email} {CoreHelpers.ToEpocMilliseconds(DateTime.UtcNow)}"); - - var pu = await sutProvider.Sut.AcceptUserAsync(providerUser.Id, user, token); - Assert.Null(pu.Email); - Assert.Equal(ProviderUserStatusType.Accepted, pu.Status); - Assert.Equal(user.Id, pu.UserId); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task ConfirmUsersAsync_NoValid( - [ProviderUser(ProviderUserStatusType.Invited)] ProviderUser pu1, - [ProviderUser(ProviderUserStatusType.Accepted)] ProviderUser pu2, - [ProviderUser(ProviderUserStatusType.Confirmed)] ProviderUser pu3, - SutProvider sutProvider) - { - pu1.ProviderId = pu3.ProviderId; - var providerUsers = new[] { pu1, pu2, pu3 }; - var providerUserRepository = sutProvider.GetDependency(); - providerUserRepository.GetManyAsync(default).ReturnsForAnyArgs(providerUsers); - - var dict = providerUsers.ToDictionary(pu => pu.Id, _ => "key"); - var result = await sutProvider.Sut.ConfirmUsersAsync(pu1.ProviderId, dict, default); - - Assert.Empty(result); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task ConfirmUsersAsync_Success( - [ProviderUser(ProviderUserStatusType.Invited)] ProviderUser pu1, User u1, - [ProviderUser(ProviderUserStatusType.Accepted)] ProviderUser pu2, User u2, - [ProviderUser(ProviderUserStatusType.Confirmed)] ProviderUser pu3, User u3, - Provider provider, User user, SutProvider sutProvider) - { - pu1.ProviderId = pu2.ProviderId = pu3.ProviderId = provider.Id; - pu1.UserId = u1.Id; - pu2.UserId = u2.Id; - pu3.UserId = u3.Id; - var providerUsers = new[] { pu1, pu2, pu3 }; - - var providerUserRepository = sutProvider.GetDependency(); - providerUserRepository.GetManyAsync(default).ReturnsForAnyArgs(providerUsers); - var providerRepository = sutProvider.GetDependency(); - providerRepository.GetByIdAsync(provider.Id).Returns(provider); - var userRepository = sutProvider.GetDependency(); - userRepository.GetManyAsync(default).ReturnsForAnyArgs(new[] { u1, u2, u3 }); - - var dict = providerUsers.ToDictionary(pu => pu.Id, _ => "key"); - var result = await sutProvider.Sut.ConfirmUsersAsync(pu1.ProviderId, dict, user.Id); - - Assert.Equal("Invalid user.", result[0].Item2); - Assert.Equal("", result[1].Item2); - Assert.Equal("Invalid user.", result[2].Item2); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task SaveUserAsync_UserIdIsInvalid_Throws(ProviderUser providerUser, - SutProvider sutProvider) - { - providerUser.Id = default; - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.SaveUserAsync(providerUser, default)); - Assert.Equal("Invite the user first.", exception.Message); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task SaveUserAsync_Success( - [ProviderUser(type: ProviderUserType.ProviderAdmin)] ProviderUser providerUser, User savingUser, - SutProvider sutProvider) - { - var providerUserRepository = sutProvider.GetDependency(); - providerUserRepository.GetByIdAsync(providerUser.Id).Returns(providerUser); - - await sutProvider.Sut.SaveUserAsync(providerUser, savingUser.Id); - await providerUserRepository.Received().ReplaceAsync(providerUser); - await sutProvider.GetDependency().Received() - .LogProviderUserEventAsync(providerUser, EventType.ProviderUser_Updated, null); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task DeleteUsersAsync_NoRemainingOwner_Throws(Provider provider, User deletingUser, - ICollection providerUsers, SutProvider sutProvider) - { - var userIds = providerUsers.Select(pu => pu.Id); - - providerUsers.First().UserId = deletingUser.Id; - foreach (var providerUser in providerUsers) + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task CompleteSetupAsync_Success(User user, Provider provider, string key, + [ProviderUser(ProviderUserStatusType.Confirmed, ProviderUserType.ProviderAdmin)] ProviderUser providerUser, + SutProvider sutProvider) { providerUser.ProviderId = provider.Id; + providerUser.UserId = user.Id; + var userService = sutProvider.GetDependency(); + userService.GetUserByIdAsync(user.Id).Returns(user); + + var providerUserRepository = sutProvider.GetDependency(); + providerUserRepository.GetByProviderUserAsync(provider.Id, user.Id).Returns(providerUser); + + var dataProtectionProvider = DataProtectionProvider.Create("ApplicationName"); + var protector = dataProtectionProvider.CreateProtector("ProviderServiceDataProtector"); + sutProvider.GetDependency().CreateProtector("ProviderServiceDataProtector") + .Returns(protector); + sutProvider.Create(); + + var token = protector.Protect($"ProviderSetupInvite {provider.Id} {user.Email} {CoreHelpers.ToEpocMilliseconds(DateTime.UtcNow)}"); + + await sutProvider.Sut.CompleteSetupAsync(provider, user.Id, token, key); + + await sutProvider.GetDependency().Received().UpsertAsync(provider); + await sutProvider.GetDependency().Received() + .ReplaceAsync(Arg.Is(pu => pu.UserId == user.Id && pu.ProviderId == provider.Id && pu.Key == key)); } - providerUsers.Last().ProviderId = default; - sutProvider.GetDependency().GetByIdAsync(provider.Id).Returns(provider); - var providerUserRepository = sutProvider.GetDependency(); - providerUserRepository.GetManyAsync(default).ReturnsForAnyArgs(providerUsers); - providerUserRepository.GetManyByProviderAsync(default, default).ReturnsForAnyArgs(new ProviderUser[] { }); - - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.DeleteUsersAsync(provider.Id, userIds, deletingUser.Id)); - Assert.Equal("Provider must have at least one confirmed ProviderAdmin.", exception.Message); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task DeleteUsersAsync_Success(Provider provider, User deletingUser, ICollection providerUsers, - [ProviderUser(ProviderUserStatusType.Confirmed, ProviderUserType.ProviderAdmin)] ProviderUser remainingOwner, - SutProvider sutProvider) - { - var userIds = providerUsers.Select(pu => pu.Id); - - providerUsers.First().UserId = deletingUser.Id; - foreach (var providerUser in providerUsers) + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task UpdateAsync_ProviderIdIsInvalid_Throws(Provider provider, SutProvider sutProvider) { - providerUser.ProviderId = provider.Id; + provider.Id = default; + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.UpdateAsync(provider)); + Assert.Contains("Cannot create provider this way.", exception.Message); } - providerUsers.Last().ProviderId = default; - sutProvider.GetDependency().GetByIdAsync(provider.Id).Returns(provider); - var providerUserRepository = sutProvider.GetDependency(); - providerUserRepository.GetManyAsync(default).ReturnsForAnyArgs(providerUsers); - providerUserRepository.GetManyByProviderAsync(default, default).ReturnsForAnyArgs(new[] { remainingOwner }); + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task UpdateAsync_Success(Provider provider, SutProvider sutProvider) + { + await sutProvider.Sut.UpdateAsync(provider); + } - var result = await sutProvider.Sut.DeleteUsersAsync(provider.Id, userIds, deletingUser.Id); + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task InviteUserAsync_ProviderIdIsInvalid_Throws(ProviderUserInvite invite, SutProvider sutProvider) + { + sutProvider.GetDependency().ProviderManageUsers(invite.ProviderId).Returns(true); - Assert.NotEmpty(result); - Assert.Equal("You cannot remove yourself.", result[0].Item2); - Assert.Equal("", result[1].Item2); - Assert.Equal("Invalid user.", result[2].Item2); - } + await Assert.ThrowsAsync(() => sutProvider.Sut.InviteUserAsync(invite)); + } - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task AddOrganization_OrganizationAlreadyBelongsToAProvider_Throws(Provider provider, - Organization organization, ProviderOrganization po, User user, string key, - SutProvider sutProvider) - { - po.OrganizationId = organization.Id; - sutProvider.GetDependency().GetByIdAsync(provider.Id).Returns(provider); - sutProvider.GetDependency().GetByOrganizationId(organization.Id) - .Returns(po); + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task InviteUserAsync_InvalidPermissions_Throws(ProviderUserInvite invite, SutProvider sutProvider) + { + sutProvider.GetDependency().ProviderManageUsers(invite.ProviderId).Returns(false); + await Assert.ThrowsAsync(() => sutProvider.Sut.InviteUserAsync(invite)); + } - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.AddOrganization(provider.Id, organization.Id, user.Id, key)); - Assert.Equal("Organization already belongs to a provider.", exception.Message); - } + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task InviteUserAsync_EmailsInvalid_Throws(Provider provider, ProviderUserInvite providerUserInvite, + SutProvider sutProvider) + { + var providerRepository = sutProvider.GetDependency(); + providerRepository.GetByIdAsync(providerUserInvite.ProviderId).Returns(provider); + sutProvider.GetDependency().ProviderManageUsers(providerUserInvite.ProviderId).Returns(true); - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task AddOrganization_Success(Provider provider, Organization organization, User user, string key, - SutProvider sutProvider) - { - organization.PlanType = PlanType.EnterpriseAnnually; + providerUserInvite.UserIdentifiers = null; - sutProvider.GetDependency().GetByIdAsync(provider.Id).Returns(provider); - var providerOrganizationRepository = sutProvider.GetDependency(); - providerOrganizationRepository.GetByOrganizationId(organization.Id).ReturnsNull(); - sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); + await Assert.ThrowsAsync(() => sutProvider.Sut.InviteUserAsync(providerUserInvite)); + } - await sutProvider.Sut.AddOrganization(provider.Id, organization.Id, user.Id, key); + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task InviteUserAsync_AlreadyInvited(Provider provider, ProviderUserInvite providerUserInvite, + SutProvider sutProvider) + { + var providerRepository = sutProvider.GetDependency(); + providerRepository.GetByIdAsync(providerUserInvite.ProviderId).Returns(provider); + var providerUserRepository = sutProvider.GetDependency(); + providerUserRepository.GetCountByProviderAsync(default, default, default).ReturnsForAnyArgs(1); + sutProvider.GetDependency().ProviderManageUsers(providerUserInvite.ProviderId).Returns(true); - await providerOrganizationRepository.ReceivedWithAnyArgs().CreateAsync(default); - await sutProvider.GetDependency() - .Received().LogProviderOrganizationEventAsync(Arg.Any(), - EventType.ProviderOrganization_Added); - } + var result = await sutProvider.Sut.InviteUserAsync(providerUserInvite); + Assert.Empty(result); + } - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task CreateOrganizationAsync_Success(Provider provider, OrganizationSignup organizationSignup, - Organization organization, string clientOwnerEmail, User user, SutProvider sutProvider) - { - organizationSignup.Plan = PlanType.EnterpriseAnnually; + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task InviteUserAsync_Success(Provider provider, ProviderUserInvite providerUserInvite, + SutProvider sutProvider) + { + var providerRepository = sutProvider.GetDependency(); + providerRepository.GetByIdAsync(providerUserInvite.ProviderId).Returns(provider); + var providerUserRepository = sutProvider.GetDependency(); + providerUserRepository.GetCountByProviderAsync(default, default, default).ReturnsForAnyArgs(0); + sutProvider.GetDependency().ProviderManageUsers(providerUserInvite.ProviderId).Returns(true); - sutProvider.GetDependency().GetByIdAsync(provider.Id).Returns(provider); - var providerOrganizationRepository = sutProvider.GetDependency(); - sutProvider.GetDependency().SignUpAsync(organizationSignup, true) - .Returns(Tuple.Create(organization, null as OrganizationUser)); + var result = await sutProvider.Sut.InviteUserAsync(providerUserInvite); + Assert.Equal(providerUserInvite.UserIdentifiers.Count(), result.Count); + Assert.True(result.TrueForAll(pu => pu.Status == ProviderUserStatusType.Invited), "Status must be invited"); + Assert.True(result.TrueForAll(pu => pu.ProviderId == providerUserInvite.ProviderId), "Provider Id must be correct"); + } - var providerOrganization = - await sutProvider.Sut.CreateOrganizationAsync(provider.Id, organizationSignup, clientOwnerEmail, user); + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task ResendInviteUserAsync_InvalidPermissions_Throws(ProviderUserInvite invite, SutProvider sutProvider) + { + sutProvider.GetDependency().ProviderManageUsers(invite.ProviderId).Returns(false); + await Assert.ThrowsAsync(() => sutProvider.Sut.ResendInvitesAsync(invite)); + } - await providerOrganizationRepository.ReceivedWithAnyArgs().CreateAsync(default); - await sutProvider.GetDependency() - .Received().LogProviderOrganizationEventAsync(providerOrganization, - EventType.ProviderOrganization_Created); - await sutProvider.GetDependency() - .Received().InviteUsersAsync(organization.Id, user.Id, Arg.Is>( - t => t.Count() == 1 && - t.First().Item1.Emails.Count() == 1 && - t.First().Item1.Emails.First() == clientOwnerEmail && - t.First().Item1.Type == OrganizationUserType.Owner && - t.First().Item1.AccessAll && - t.First().Item2 == null)); - } + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task ResendInvitesAsync_Errors(Provider provider, + [ProviderUser(ProviderUserStatusType.Invited)] ProviderUser pu1, + [ProviderUser(ProviderUserStatusType.Accepted)] ProviderUser pu2, + [ProviderUser(ProviderUserStatusType.Confirmed)] ProviderUser pu3, + [ProviderUser(ProviderUserStatusType.Invited)] ProviderUser pu4, + SutProvider sutProvider) + { + var providerUsers = new[] { pu1, pu2, pu3, pu4 }; + pu1.ProviderId = pu2.ProviderId = pu3.ProviderId = provider.Id; - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task RemoveOrganization_ProviderOrganizationIsInvalid_Throws(Provider provider, - ProviderOrganization providerOrganization, User user, SutProvider sutProvider) - { - sutProvider.GetDependency().GetByIdAsync(provider.Id).Returns(provider); - sutProvider.GetDependency().GetByIdAsync(providerOrganization.Id) - .ReturnsNull(); + var invite = new ProviderUserInvite + { + UserIdentifiers = providerUsers.Select(pu => pu.Id), + ProviderId = provider.Id + }; - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.RemoveOrganizationAsync(provider.Id, providerOrganization.Id, user.Id)); - Assert.Equal("Invalid organization.", exception.Message); - } + var providerRepository = sutProvider.GetDependency(); + providerRepository.GetByIdAsync(provider.Id).Returns(provider); + var providerUserRepository = sutProvider.GetDependency(); + providerUserRepository.GetManyAsync(default).ReturnsForAnyArgs(providerUsers.ToList()); + sutProvider.GetDependency().ProviderManageUsers(invite.ProviderId).Returns(true); - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task RemoveOrganization_ProviderOrganizationBelongsToWrongProvider_Throws(Provider provider, - ProviderOrganization providerOrganization, User user, SutProvider sutProvider) - { - sutProvider.GetDependency().GetByIdAsync(provider.Id).Returns(provider); - sutProvider.GetDependency().GetByIdAsync(providerOrganization.Id) - .Returns(providerOrganization); + var result = await sutProvider.Sut.ResendInvitesAsync(invite); + Assert.Equal("", result[0].Item2); + Assert.Equal("User invalid.", result[1].Item2); + Assert.Equal("User invalid.", result[2].Item2); + Assert.Equal("User invalid.", result[3].Item2); + } - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.RemoveOrganizationAsync(provider.Id, providerOrganization.Id, user.Id)); - Assert.Equal("Invalid organization.", exception.Message); - } + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task ResendInvitesAsync_Success(Provider provider, IEnumerable providerUsers, + SutProvider sutProvider) + { + foreach (var providerUser in providerUsers) + { + providerUser.ProviderId = provider.Id; + providerUser.Status = ProviderUserStatusType.Invited; + } - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task RemoveOrganization_HasNoOwners_Throws(Provider provider, - ProviderOrganization providerOrganization, User user, SutProvider sutProvider) - { - providerOrganization.ProviderId = provider.Id; - sutProvider.GetDependency().GetByIdAsync(provider.Id).Returns(provider); - sutProvider.GetDependency().GetByIdAsync(providerOrganization.Id) - .Returns(providerOrganization); - sutProvider.GetDependency().HasConfirmedOwnersExceptAsync(default, default, default) - .ReturnsForAnyArgs(false); + var invite = new ProviderUserInvite + { + UserIdentifiers = providerUsers.Select(pu => pu.Id), + ProviderId = provider.Id + }; - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.RemoveOrganizationAsync(provider.Id, providerOrganization.Id, user.Id)); - Assert.Equal("Organization needs to have at least one confirmed owner.", exception.Message); - } + var providerRepository = sutProvider.GetDependency(); + providerRepository.GetByIdAsync(provider.Id).Returns(provider); + var providerUserRepository = sutProvider.GetDependency(); + providerUserRepository.GetManyAsync(default).ReturnsForAnyArgs(providerUsers.ToList()); + sutProvider.GetDependency().ProviderManageUsers(invite.ProviderId).Returns(true); - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task RemoveOrganization_Success(Provider provider, - ProviderOrganization providerOrganization, User user, SutProvider sutProvider) - { - providerOrganization.ProviderId = provider.Id; - sutProvider.GetDependency().GetByIdAsync(provider.Id).Returns(provider); - var providerOrganizationRepository = sutProvider.GetDependency(); - providerOrganizationRepository.GetByIdAsync(providerOrganization.Id).Returns(providerOrganization); - sutProvider.GetDependency().HasConfirmedOwnersExceptAsync(default, default, default) - .ReturnsForAnyArgs(true); + var result = await sutProvider.Sut.ResendInvitesAsync(invite); + Assert.True(result.All(r => r.Item2 == "")); + } - await sutProvider.Sut.RemoveOrganizationAsync(provider.Id, providerOrganization.Id, user.Id); - await providerOrganizationRepository.Received().DeleteAsync(providerOrganization); - await sutProvider.GetDependency().Received() - .LogProviderOrganizationEventAsync(providerOrganization, EventType.ProviderOrganization_Removed); + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task AcceptUserAsync_UserIsInvalid_Throws(SutProvider sutProvider) + { + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.AcceptUserAsync(default, default, default)); + Assert.Equal("User invalid.", exception.Message); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task AcceptUserAsync_AlreadyAccepted_Throws( + [ProviderUser(ProviderUserStatusType.Accepted)] ProviderUser providerUser, User user, + SutProvider sutProvider) + { + var providerUserRepository = sutProvider.GetDependency(); + providerUserRepository.GetByIdAsync(providerUser.Id).Returns(providerUser); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.AcceptUserAsync(providerUser.Id, user, default)); + Assert.Equal("Already accepted.", exception.Message); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task AcceptUserAsync_TokenIsInvalid_Throws( + [ProviderUser(ProviderUserStatusType.Invited)] ProviderUser providerUser, User user, + SutProvider sutProvider) + { + var providerUserRepository = sutProvider.GetDependency(); + providerUserRepository.GetByIdAsync(providerUser.Id).Returns(providerUser); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.AcceptUserAsync(providerUser.Id, user, default)); + Assert.Equal("Invalid token.", exception.Message); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task AcceptUserAsync_WrongEmail_Throws( + [ProviderUser(ProviderUserStatusType.Invited)] ProviderUser providerUser, User user, + SutProvider sutProvider) + { + var providerUserRepository = sutProvider.GetDependency(); + providerUserRepository.GetByIdAsync(providerUser.Id).Returns(providerUser); + + var dataProtectionProvider = DataProtectionProvider.Create("ApplicationName"); + var protector = dataProtectionProvider.CreateProtector("ProviderServiceDataProtector"); + sutProvider.GetDependency().CreateProtector("ProviderServiceDataProtector") + .Returns(protector); + sutProvider.Create(); + + var token = protector.Protect($"ProviderUserInvite {providerUser.Id} {user.Email} {CoreHelpers.ToEpocMilliseconds(DateTime.UtcNow)}"); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.AcceptUserAsync(providerUser.Id, user, token)); + Assert.Equal("User email does not match invite.", exception.Message); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task AcceptUserAsync_Success( + [ProviderUser(ProviderUserStatusType.Invited)] ProviderUser providerUser, User user, + SutProvider sutProvider) + { + var providerUserRepository = sutProvider.GetDependency(); + providerUserRepository.GetByIdAsync(providerUser.Id).Returns(providerUser); + + var dataProtectionProvider = DataProtectionProvider.Create("ApplicationName"); + var protector = dataProtectionProvider.CreateProtector("ProviderServiceDataProtector"); + sutProvider.GetDependency().CreateProtector("ProviderServiceDataProtector") + .Returns(protector); + sutProvider.Create(); + + providerUser.Email = user.Email; + var token = protector.Protect($"ProviderUserInvite {providerUser.Id} {user.Email} {CoreHelpers.ToEpocMilliseconds(DateTime.UtcNow)}"); + + var pu = await sutProvider.Sut.AcceptUserAsync(providerUser.Id, user, token); + Assert.Null(pu.Email); + Assert.Equal(ProviderUserStatusType.Accepted, pu.Status); + Assert.Equal(user.Id, pu.UserId); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task ConfirmUsersAsync_NoValid( + [ProviderUser(ProviderUserStatusType.Invited)] ProviderUser pu1, + [ProviderUser(ProviderUserStatusType.Accepted)] ProviderUser pu2, + [ProviderUser(ProviderUserStatusType.Confirmed)] ProviderUser pu3, + SutProvider sutProvider) + { + pu1.ProviderId = pu3.ProviderId; + var providerUsers = new[] { pu1, pu2, pu3 }; + var providerUserRepository = sutProvider.GetDependency(); + providerUserRepository.GetManyAsync(default).ReturnsForAnyArgs(providerUsers); + + var dict = providerUsers.ToDictionary(pu => pu.Id, _ => "key"); + var result = await sutProvider.Sut.ConfirmUsersAsync(pu1.ProviderId, dict, default); + + Assert.Empty(result); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task ConfirmUsersAsync_Success( + [ProviderUser(ProviderUserStatusType.Invited)] ProviderUser pu1, User u1, + [ProviderUser(ProviderUserStatusType.Accepted)] ProviderUser pu2, User u2, + [ProviderUser(ProviderUserStatusType.Confirmed)] ProviderUser pu3, User u3, + Provider provider, User user, SutProvider sutProvider) + { + pu1.ProviderId = pu2.ProviderId = pu3.ProviderId = provider.Id; + pu1.UserId = u1.Id; + pu2.UserId = u2.Id; + pu3.UserId = u3.Id; + var providerUsers = new[] { pu1, pu2, pu3 }; + + var providerUserRepository = sutProvider.GetDependency(); + providerUserRepository.GetManyAsync(default).ReturnsForAnyArgs(providerUsers); + var providerRepository = sutProvider.GetDependency(); + providerRepository.GetByIdAsync(provider.Id).Returns(provider); + var userRepository = sutProvider.GetDependency(); + userRepository.GetManyAsync(default).ReturnsForAnyArgs(new[] { u1, u2, u3 }); + + var dict = providerUsers.ToDictionary(pu => pu.Id, _ => "key"); + var result = await sutProvider.Sut.ConfirmUsersAsync(pu1.ProviderId, dict, user.Id); + + Assert.Equal("Invalid user.", result[0].Item2); + Assert.Equal("", result[1].Item2); + Assert.Equal("Invalid user.", result[2].Item2); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task SaveUserAsync_UserIdIsInvalid_Throws(ProviderUser providerUser, + SutProvider sutProvider) + { + providerUser.Id = default; + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.SaveUserAsync(providerUser, default)); + Assert.Equal("Invite the user first.", exception.Message); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task SaveUserAsync_Success( + [ProviderUser(type: ProviderUserType.ProviderAdmin)] ProviderUser providerUser, User savingUser, + SutProvider sutProvider) + { + var providerUserRepository = sutProvider.GetDependency(); + providerUserRepository.GetByIdAsync(providerUser.Id).Returns(providerUser); + + await sutProvider.Sut.SaveUserAsync(providerUser, savingUser.Id); + await providerUserRepository.Received().ReplaceAsync(providerUser); + await sutProvider.GetDependency().Received() + .LogProviderUserEventAsync(providerUser, EventType.ProviderUser_Updated, null); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task DeleteUsersAsync_NoRemainingOwner_Throws(Provider provider, User deletingUser, + ICollection providerUsers, SutProvider sutProvider) + { + var userIds = providerUsers.Select(pu => pu.Id); + + providerUsers.First().UserId = deletingUser.Id; + foreach (var providerUser in providerUsers) + { + providerUser.ProviderId = provider.Id; + } + providerUsers.Last().ProviderId = default; + + sutProvider.GetDependency().GetByIdAsync(provider.Id).Returns(provider); + var providerUserRepository = sutProvider.GetDependency(); + providerUserRepository.GetManyAsync(default).ReturnsForAnyArgs(providerUsers); + providerUserRepository.GetManyByProviderAsync(default, default).ReturnsForAnyArgs(new ProviderUser[] { }); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.DeleteUsersAsync(provider.Id, userIds, deletingUser.Id)); + Assert.Equal("Provider must have at least one confirmed ProviderAdmin.", exception.Message); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task DeleteUsersAsync_Success(Provider provider, User deletingUser, ICollection providerUsers, + [ProviderUser(ProviderUserStatusType.Confirmed, ProviderUserType.ProviderAdmin)] ProviderUser remainingOwner, + SutProvider sutProvider) + { + var userIds = providerUsers.Select(pu => pu.Id); + + providerUsers.First().UserId = deletingUser.Id; + foreach (var providerUser in providerUsers) + { + providerUser.ProviderId = provider.Id; + } + providerUsers.Last().ProviderId = default; + + sutProvider.GetDependency().GetByIdAsync(provider.Id).Returns(provider); + var providerUserRepository = sutProvider.GetDependency(); + providerUserRepository.GetManyAsync(default).ReturnsForAnyArgs(providerUsers); + providerUserRepository.GetManyByProviderAsync(default, default).ReturnsForAnyArgs(new[] { remainingOwner }); + + var result = await sutProvider.Sut.DeleteUsersAsync(provider.Id, userIds, deletingUser.Id); + + Assert.NotEmpty(result); + Assert.Equal("You cannot remove yourself.", result[0].Item2); + Assert.Equal("", result[1].Item2); + Assert.Equal("Invalid user.", result[2].Item2); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task AddOrganization_OrganizationAlreadyBelongsToAProvider_Throws(Provider provider, + Organization organization, ProviderOrganization po, User user, string key, + SutProvider sutProvider) + { + po.OrganizationId = organization.Id; + sutProvider.GetDependency().GetByIdAsync(provider.Id).Returns(provider); + sutProvider.GetDependency().GetByOrganizationId(organization.Id) + .Returns(po); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.AddOrganization(provider.Id, organization.Id, user.Id, key)); + Assert.Equal("Organization already belongs to a provider.", exception.Message); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task AddOrganization_Success(Provider provider, Organization organization, User user, string key, + SutProvider sutProvider) + { + organization.PlanType = PlanType.EnterpriseAnnually; + + sutProvider.GetDependency().GetByIdAsync(provider.Id).Returns(provider); + var providerOrganizationRepository = sutProvider.GetDependency(); + providerOrganizationRepository.GetByOrganizationId(organization.Id).ReturnsNull(); + sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); + + await sutProvider.Sut.AddOrganization(provider.Id, organization.Id, user.Id, key); + + await providerOrganizationRepository.ReceivedWithAnyArgs().CreateAsync(default); + await sutProvider.GetDependency() + .Received().LogProviderOrganizationEventAsync(Arg.Any(), + EventType.ProviderOrganization_Added); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task CreateOrganizationAsync_Success(Provider provider, OrganizationSignup organizationSignup, + Organization organization, string clientOwnerEmail, User user, SutProvider sutProvider) + { + organizationSignup.Plan = PlanType.EnterpriseAnnually; + + sutProvider.GetDependency().GetByIdAsync(provider.Id).Returns(provider); + var providerOrganizationRepository = sutProvider.GetDependency(); + sutProvider.GetDependency().SignUpAsync(organizationSignup, true) + .Returns(Tuple.Create(organization, null as OrganizationUser)); + + var providerOrganization = + await sutProvider.Sut.CreateOrganizationAsync(provider.Id, organizationSignup, clientOwnerEmail, user); + + await providerOrganizationRepository.ReceivedWithAnyArgs().CreateAsync(default); + await sutProvider.GetDependency() + .Received().LogProviderOrganizationEventAsync(providerOrganization, + EventType.ProviderOrganization_Created); + await sutProvider.GetDependency() + .Received().InviteUsersAsync(organization.Id, user.Id, Arg.Is>( + t => t.Count() == 1 && + t.First().Item1.Emails.Count() == 1 && + t.First().Item1.Emails.First() == clientOwnerEmail && + t.First().Item1.Type == OrganizationUserType.Owner && + t.First().Item1.AccessAll && + t.First().Item2 == null)); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task RemoveOrganization_ProviderOrganizationIsInvalid_Throws(Provider provider, + ProviderOrganization providerOrganization, User user, SutProvider sutProvider) + { + sutProvider.GetDependency().GetByIdAsync(provider.Id).Returns(provider); + sutProvider.GetDependency().GetByIdAsync(providerOrganization.Id) + .ReturnsNull(); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.RemoveOrganizationAsync(provider.Id, providerOrganization.Id, user.Id)); + Assert.Equal("Invalid organization.", exception.Message); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task RemoveOrganization_ProviderOrganizationBelongsToWrongProvider_Throws(Provider provider, + ProviderOrganization providerOrganization, User user, SutProvider sutProvider) + { + sutProvider.GetDependency().GetByIdAsync(provider.Id).Returns(provider); + sutProvider.GetDependency().GetByIdAsync(providerOrganization.Id) + .Returns(providerOrganization); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.RemoveOrganizationAsync(provider.Id, providerOrganization.Id, user.Id)); + Assert.Equal("Invalid organization.", exception.Message); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task RemoveOrganization_HasNoOwners_Throws(Provider provider, + ProviderOrganization providerOrganization, User user, SutProvider sutProvider) + { + providerOrganization.ProviderId = provider.Id; + sutProvider.GetDependency().GetByIdAsync(provider.Id).Returns(provider); + sutProvider.GetDependency().GetByIdAsync(providerOrganization.Id) + .Returns(providerOrganization); + sutProvider.GetDependency().HasConfirmedOwnersExceptAsync(default, default, default) + .ReturnsForAnyArgs(false); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.RemoveOrganizationAsync(provider.Id, providerOrganization.Id, user.Id)); + Assert.Equal("Organization needs to have at least one confirmed owner.", exception.Message); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task RemoveOrganization_Success(Provider provider, + ProviderOrganization providerOrganization, User user, SutProvider sutProvider) + { + providerOrganization.ProviderId = provider.Id; + sutProvider.GetDependency().GetByIdAsync(provider.Id).Returns(provider); + var providerOrganizationRepository = sutProvider.GetDependency(); + providerOrganizationRepository.GetByIdAsync(providerOrganization.Id).Returns(providerOrganization); + sutProvider.GetDependency().HasConfirmedOwnersExceptAsync(default, default, default) + .ReturnsForAnyArgs(true); + + await sutProvider.Sut.RemoveOrganizationAsync(provider.Id, providerOrganization.Id, user.Id); + await providerOrganizationRepository.Received().DeleteAsync(providerOrganization); + await sutProvider.GetDependency().Received() + .LogProviderOrganizationEventAsync(providerOrganization, EventType.ProviderOrganization_Removed); + } } } diff --git a/src/Admin/AdminSettings.cs b/src/Admin/AdminSettings.cs index 6941bbc8f6..64de4f0837 100644 --- a/src/Admin/AdminSettings.cs +++ b/src/Admin/AdminSettings.cs @@ -1,15 +1,16 @@ -namespace Bit.Admin; - -public class AdminSettings +namespace Bit.Admin { - public virtual string Admins { get; set; } - public virtual CloudflareSettings Cloudflare { get; set; } - public int? DeleteTrashDaysAgo { get; set; } - - public class CloudflareSettings + public class AdminSettings { - public string ZoneId { get; set; } - public string AuthEmail { get; set; } - public string AuthKey { get; set; } + public virtual string Admins { get; set; } + public virtual CloudflareSettings Cloudflare { get; set; } + public int? DeleteTrashDaysAgo { get; set; } + + public class CloudflareSettings + { + public string ZoneId { get; set; } + public string AuthEmail { get; set; } + public string AuthKey { get; set; } + } } } diff --git a/src/Admin/Controllers/ErrorController.cs b/src/Admin/Controllers/ErrorController.cs index 9216537ff9..af60912045 100644 --- a/src/Admin/Controllers/ErrorController.cs +++ b/src/Admin/Controllers/ErrorController.cs @@ -1,23 +1,24 @@ using Microsoft.AspNetCore.Diagnostics; using Microsoft.AspNetCore.Mvc; -namespace Bit.Admin.Controllers; - -public class ErrorController : Controller +namespace Bit.Admin.Controllers { - [Route("/error")] - public IActionResult Error(int? statusCode = null) + public class ErrorController : Controller { - var exceptionHandlerPathFeature = HttpContext.Features.Get(); - TempData["Error"] = HttpContext.Features.Get()?.Error.Message; + [Route("/error")] + public IActionResult Error(int? statusCode = null) + { + var exceptionHandlerPathFeature = HttpContext.Features.Get(); + TempData["Error"] = HttpContext.Features.Get()?.Error.Message; - if (exceptionHandlerPathFeature != null) - { - return Redirect(exceptionHandlerPathFeature.Path); - } - else - { - return Redirect("/Home"); + if (exceptionHandlerPathFeature != null) + { + return Redirect(exceptionHandlerPathFeature.Path); + } + else + { + return Redirect("/Home"); + } } } } diff --git a/src/Admin/Controllers/HomeController.cs b/src/Admin/Controllers/HomeController.cs index 5e3b76ebb8..fe93eef26c 100644 --- a/src/Admin/Controllers/HomeController.cs +++ b/src/Admin/Controllers/HomeController.cs @@ -6,108 +6,109 @@ using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; using Newtonsoft.Json; -namespace Bit.Admin.Controllers; - -public class HomeController : Controller +namespace Bit.Admin.Controllers { - private readonly GlobalSettings _globalSettings; - private readonly HttpClient _httpClient = new HttpClient(); - private readonly ILogger _logger; - - public HomeController(GlobalSettings globalSettings, ILogger logger) + public class HomeController : Controller { - _globalSettings = globalSettings; - _logger = logger; - } + private readonly GlobalSettings _globalSettings; + private readonly HttpClient _httpClient = new HttpClient(); + private readonly ILogger _logger; - [Authorize] - public IActionResult Index() - { - return View(new HomeModel + public HomeController(GlobalSettings globalSettings, ILogger logger) { - GlobalSettings = _globalSettings, - CurrentVersion = Core.Utilities.CoreHelpers.GetVersion() - }); - } + _globalSettings = globalSettings; + _logger = logger; + } - public IActionResult Error() - { - return View(new ErrorViewModel + [Authorize] + public IActionResult Index() { - RequestId = Activity.Current?.Id ?? HttpContext.TraceIdentifier - }); - } - - - public async Task GetLatestVersion(ProjectType project, CancellationToken cancellationToken) - { - var requestUri = $"https://selfhost.bitwarden.com/version.json"; - try - { - var response = await _httpClient.GetAsync(requestUri, cancellationToken); - if (response.IsSuccessStatusCode) + return View(new HomeModel { - var latestVersions = JsonConvert.DeserializeObject(await response.Content.ReadAsStringAsync()); - return project switch + GlobalSettings = _globalSettings, + CurrentVersion = Core.Utilities.CoreHelpers.GetVersion() + }); + } + + public IActionResult Error() + { + return View(new ErrorViewModel + { + RequestId = Activity.Current?.Id ?? HttpContext.TraceIdentifier + }); + } + + + public async Task GetLatestVersion(ProjectType project, CancellationToken cancellationToken) + { + var requestUri = $"https://selfhost.bitwarden.com/version.json"; + try + { + var response = await _httpClient.GetAsync(requestUri, cancellationToken); + if (response.IsSuccessStatusCode) { - ProjectType.Core => new JsonResult(latestVersions.Versions.CoreVersion), - ProjectType.Web => new JsonResult(latestVersions.Versions.WebVersion), - _ => throw new System.NotImplementedException(), - }; + var latestVersions = JsonConvert.DeserializeObject(await response.Content.ReadAsStringAsync()); + return project switch + { + ProjectType.Core => new JsonResult(latestVersions.Versions.CoreVersion), + ProjectType.Web => new JsonResult(latestVersions.Versions.WebVersion), + _ => throw new System.NotImplementedException(), + }; + } } - } - catch (HttpRequestException e) - { - _logger.LogError(e, $"Error encountered while sending GET request to {requestUri}"); - return new JsonResult("Unable to fetch latest version") { StatusCode = StatusCodes.Status500InternalServerError }; - } - - return new JsonResult("-"); - } - - public async Task GetInstalledWebVersion(CancellationToken cancellationToken) - { - var requestUri = $"{_globalSettings.BaseServiceUri.InternalVault}/version.json"; - try - { - var response = await _httpClient.GetAsync(requestUri, cancellationToken); - if (response.IsSuccessStatusCode) + catch (HttpRequestException e) { - using var jsonDocument = await JsonDocument.ParseAsync(await response.Content.ReadAsStreamAsync(cancellationToken), cancellationToken: cancellationToken); - var root = jsonDocument.RootElement; - return new JsonResult(root.GetProperty("version").GetString()); + _logger.LogError(e, $"Error encountered while sending GET request to {requestUri}"); + return new JsonResult("Unable to fetch latest version") { StatusCode = StatusCodes.Status500InternalServerError }; } + + return new JsonResult("-"); } - catch (HttpRequestException e) + + public async Task GetInstalledWebVersion(CancellationToken cancellationToken) { - _logger.LogError(e, $"Error encountered while sending GET request to {requestUri}"); - return new JsonResult("Unable to fetch installed version") { StatusCode = StatusCodes.Status500InternalServerError }; + var requestUri = $"{_globalSettings.BaseServiceUri.InternalVault}/version.json"; + try + { + var response = await _httpClient.GetAsync(requestUri, cancellationToken); + if (response.IsSuccessStatusCode) + { + using var jsonDocument = await JsonDocument.ParseAsync(await response.Content.ReadAsStreamAsync(cancellationToken), cancellationToken: cancellationToken); + var root = jsonDocument.RootElement; + return new JsonResult(root.GetProperty("version").GetString()); + } + } + catch (HttpRequestException e) + { + _logger.LogError(e, $"Error encountered while sending GET request to {requestUri}"); + return new JsonResult("Unable to fetch installed version") { StatusCode = StatusCodes.Status500InternalServerError }; + } + + return new JsonResult("-"); } - return new JsonResult("-"); + private class LatestVersions + { + [JsonProperty("versions")] + public Versions Versions { get; set; } + } + + private class Versions + { + [JsonProperty("coreVersion")] + public string CoreVersion { get; set; } + + [JsonProperty("webVersion")] + public string WebVersion { get; set; } + + [JsonProperty("keyConnectorVersion")] + public string KeyConnectorVersion { get; set; } + } } - private class LatestVersions + public enum ProjectType { - [JsonProperty("versions")] - public Versions Versions { get; set; } - } - - private class Versions - { - [JsonProperty("coreVersion")] - public string CoreVersion { get; set; } - - [JsonProperty("webVersion")] - public string WebVersion { get; set; } - - [JsonProperty("keyConnectorVersion")] - public string KeyConnectorVersion { get; set; } + Core, + Web, } } - -public enum ProjectType -{ - Core, - Web, -} diff --git a/src/Admin/Controllers/InfoController.cs b/src/Admin/Controllers/InfoController.cs index 0c097fde7d..7f39da6ed8 100644 --- a/src/Admin/Controllers/InfoController.cs +++ b/src/Admin/Controllers/InfoController.cs @@ -1,20 +1,21 @@ using Bit.Core.Utilities; using Microsoft.AspNetCore.Mvc; -namespace Bit.Admin.Controllers; - -public class InfoController : Controller +namespace Bit.Admin.Controllers { - [HttpGet("~/alive")] - [HttpGet("~/now")] - public DateTime GetAlive() + public class InfoController : Controller { - return DateTime.UtcNow; - } + [HttpGet("~/alive")] + [HttpGet("~/now")] + public DateTime GetAlive() + { + return DateTime.UtcNow; + } - [HttpGet("~/version")] - public JsonResult GetVersion() - { - return Json(CoreHelpers.GetVersion()); + [HttpGet("~/version")] + public JsonResult GetVersion() + { + return Json(CoreHelpers.GetVersion()); + } } } diff --git a/src/Admin/Controllers/LoginController.cs b/src/Admin/Controllers/LoginController.cs index 47f9d5b34a..a8e3e9dd0d 100644 --- a/src/Admin/Controllers/LoginController.cs +++ b/src/Admin/Controllers/LoginController.cs @@ -3,90 +3,91 @@ using Bit.Core.Identity; using Microsoft.AspNetCore.Identity; using Microsoft.AspNetCore.Mvc; -namespace Bit.Admin.Controllers; - -public class LoginController : Controller +namespace Bit.Admin.Controllers { - private readonly PasswordlessSignInManager _signInManager; - - public LoginController( - PasswordlessSignInManager signInManager) + public class LoginController : Controller { - _signInManager = signInManager; - } + private readonly PasswordlessSignInManager _signInManager; - public IActionResult Index(string returnUrl = null, int? error = null, int? success = null, - bool accessDenied = false) - { - if (!error.HasValue && accessDenied) + public LoginController( + PasswordlessSignInManager signInManager) { - error = 4; + _signInManager = signInManager; } - return View(new LoginModel + public IActionResult Index(string returnUrl = null, int? error = null, int? success = null, + bool accessDenied = false) { - ReturnUrl = returnUrl, - Error = GetMessage(error), - Success = GetMessage(success) - }); - } - - [HttpPost] - [ValidateAntiForgeryToken] - public async Task Index(LoginModel model) - { - if (ModelState.IsValid) - { - await _signInManager.PasswordlessSignInAsync(model.Email, model.ReturnUrl); - return RedirectToAction("Index", new + if (!error.HasValue && accessDenied) { - success = 3 + error = 4; + } + + return View(new LoginModel + { + ReturnUrl = returnUrl, + Error = GetMessage(error), + Success = GetMessage(success) }); } - return View(model); - } - - public async Task Confirm(string email, string token, string returnUrl) - { - var result = await _signInManager.PasswordlessSignInAsync(email, token, true); - if (!result.Succeeded) + [HttpPost] + [ValidateAntiForgeryToken] + public async Task Index(LoginModel model) { + if (ModelState.IsValid) + { + await _signInManager.PasswordlessSignInAsync(model.Email, model.ReturnUrl); + return RedirectToAction("Index", new + { + success = 3 + }); + } + + return View(model); + } + + public async Task Confirm(string email, string token, string returnUrl) + { + var result = await _signInManager.PasswordlessSignInAsync(email, token, true); + if (!result.Succeeded) + { + return RedirectToAction("Index", new + { + error = 2 + }); + } + + if (!string.IsNullOrWhiteSpace(returnUrl) && Url.IsLocalUrl(returnUrl)) + { + return Redirect(returnUrl); + } + + return RedirectToAction("Index", "Home"); + } + + [HttpPost] + [ValidateAntiForgeryToken] + public async Task Logout() + { + await _signInManager.SignOutAsync(); return RedirectToAction("Index", new { - error = 2 + success = 1 }); } - if (!string.IsNullOrWhiteSpace(returnUrl) && Url.IsLocalUrl(returnUrl)) + private string GetMessage(int? messageCode) { - return Redirect(returnUrl); + return messageCode switch + { + 1 => "You have been logged out.", + 2 => "This login confirmation link is invalid. Try logging in again.", + 3 => "If a valid admin user with this email address exists, " + + "we've sent you an email with a secure link to log in.", + 4 => "Access denied. Please log in.", + _ => null, + }; } - - return RedirectToAction("Index", "Home"); - } - - [HttpPost] - [ValidateAntiForgeryToken] - public async Task Logout() - { - await _signInManager.SignOutAsync(); - return RedirectToAction("Index", new - { - success = 1 - }); - } - - private string GetMessage(int? messageCode) - { - return messageCode switch - { - 1 => "You have been logged out.", - 2 => "This login confirmation link is invalid. Try logging in again.", - 3 => "If a valid admin user with this email address exists, " + - "we've sent you an email with a secure link to log in.", - 4 => "Access denied. Please log in.", - _ => null, - }; } } diff --git a/src/Admin/Controllers/LogsController.cs b/src/Admin/Controllers/LogsController.cs index 449c8cc860..feb1a91b24 100644 --- a/src/Admin/Controllers/LogsController.cs +++ b/src/Admin/Controllers/LogsController.cs @@ -7,86 +7,87 @@ using Microsoft.Azure.Cosmos; using Microsoft.Azure.Cosmos.Linq; using Serilog.Events; -namespace Bit.Admin.Controllers; - -[Authorize] -[SelfHosted(NotSelfHostedOnly = true)] -public class LogsController : Controller +namespace Bit.Admin.Controllers { - private const string Database = "Diagnostics"; - private const string Container = "Logs"; - - private readonly GlobalSettings _globalSettings; - - public LogsController(GlobalSettings globalSettings) + [Authorize] + [SelfHosted(NotSelfHostedOnly = true)] + public class LogsController : Controller { - _globalSettings = globalSettings; - } + private const string Database = "Diagnostics"; + private const string Container = "Logs"; - public async Task Index(string cursor = null, int count = 50, - LogEventLevel? level = null, string project = null, DateTime? start = null, DateTime? end = null) - { - using (var client = new CosmosClient(_globalSettings.DocumentDb.Uri, - _globalSettings.DocumentDb.Key)) + private readonly GlobalSettings _globalSettings; + + public LogsController(GlobalSettings globalSettings) { - var cosmosContainer = client.GetContainer(Database, Container); - var query = cosmosContainer.GetItemLinqQueryable( - requestOptions: new QueryRequestOptions() - { - MaxItemCount = count - }, - continuationToken: cursor - ).AsQueryable(); - - if (level.HasValue) - { - query = query.Where(l => l.Level == level.Value.ToString()); - } - if (!string.IsNullOrWhiteSpace(project)) - { - query = query.Where(l => l.Properties != null && l.Properties["Project"] == (object)project); - } - if (start.HasValue) - { - query = query.Where(l => l.Timestamp >= start.Value); - } - if (end.HasValue) - { - query = query.Where(l => l.Timestamp <= end.Value); - } - var feedIterator = query.OrderByDescending(l => l.Timestamp).ToFeedIterator(); - var response = await feedIterator.ReadNextAsync(); - - return View(new LogsModel - { - Level = level, - Project = project, - Start = start, - End = end, - Items = response.ToList(), - Count = count, - Cursor = cursor, - NextCursor = response.ContinuationToken - }); + _globalSettings = globalSettings; } - } - public async Task View(Guid id) - { - using (var client = new CosmosClient(_globalSettings.DocumentDb.Uri, - _globalSettings.DocumentDb.Key)) + public async Task Index(string cursor = null, int count = 50, + LogEventLevel? level = null, string project = null, DateTime? start = null, DateTime? end = null) { - var cosmosContainer = client.GetContainer(Database, Container); - var query = cosmosContainer.GetItemLinqQueryable() - .AsQueryable() - .Where(l => l.Id == id.ToString()); - - var response = await query.ToFeedIterator().ReadNextAsync(); - if (response == null || response.Count == 0) + using (var client = new CosmosClient(_globalSettings.DocumentDb.Uri, + _globalSettings.DocumentDb.Key)) { - return RedirectToAction("Index"); + var cosmosContainer = client.GetContainer(Database, Container); + var query = cosmosContainer.GetItemLinqQueryable( + requestOptions: new QueryRequestOptions() + { + MaxItemCount = count + }, + continuationToken: cursor + ).AsQueryable(); + + if (level.HasValue) + { + query = query.Where(l => l.Level == level.Value.ToString()); + } + if (!string.IsNullOrWhiteSpace(project)) + { + query = query.Where(l => l.Properties != null && l.Properties["Project"] == (object)project); + } + if (start.HasValue) + { + query = query.Where(l => l.Timestamp >= start.Value); + } + if (end.HasValue) + { + query = query.Where(l => l.Timestamp <= end.Value); + } + var feedIterator = query.OrderByDescending(l => l.Timestamp).ToFeedIterator(); + var response = await feedIterator.ReadNextAsync(); + + return View(new LogsModel + { + Level = level, + Project = project, + Start = start, + End = end, + Items = response.ToList(), + Count = count, + Cursor = cursor, + NextCursor = response.ContinuationToken + }); + } + } + + public async Task View(Guid id) + { + using (var client = new CosmosClient(_globalSettings.DocumentDb.Uri, + _globalSettings.DocumentDb.Key)) + { + var cosmosContainer = client.GetContainer(Database, Container); + var query = cosmosContainer.GetItemLinqQueryable() + .AsQueryable() + .Where(l => l.Id == id.ToString()); + + var response = await query.ToFeedIterator().ReadNextAsync(); + if (response == null || response.Count == 0) + { + return RedirectToAction("Index"); + } + return View(response.First()); } - return View(response.First()); } } } diff --git a/src/Admin/Controllers/OrganizationsController.cs b/src/Admin/Controllers/OrganizationsController.cs index 76c00d025b..eccc7ced65 100644 --- a/src/Admin/Controllers/OrganizationsController.cs +++ b/src/Admin/Controllers/OrganizationsController.cs @@ -11,206 +11,207 @@ using Bit.Core.Utilities; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; -namespace Bit.Admin.Controllers; - -[Authorize] -public class OrganizationsController : Controller +namespace Bit.Admin.Controllers { - private readonly IOrganizationRepository _organizationRepository; - private readonly IOrganizationUserRepository _organizationUserRepository; - private readonly IOrganizationConnectionRepository _organizationConnectionRepository; - private readonly ISelfHostedSyncSponsorshipsCommand _syncSponsorshipsCommand; - private readonly ICipherRepository _cipherRepository; - private readonly ICollectionRepository _collectionRepository; - private readonly IGroupRepository _groupRepository; - private readonly IPolicyRepository _policyRepository; - private readonly IPaymentService _paymentService; - private readonly ILicensingService _licensingService; - private readonly IApplicationCacheService _applicationCacheService; - private readonly GlobalSettings _globalSettings; - private readonly IReferenceEventService _referenceEventService; - private readonly IUserService _userService; - private readonly ILogger _logger; - - public OrganizationsController( - IOrganizationRepository organizationRepository, - IOrganizationUserRepository organizationUserRepository, - IOrganizationConnectionRepository organizationConnectionRepository, - ISelfHostedSyncSponsorshipsCommand syncSponsorshipsCommand, - ICipherRepository cipherRepository, - ICollectionRepository collectionRepository, - IGroupRepository groupRepository, - IPolicyRepository policyRepository, - IPaymentService paymentService, - ILicensingService licensingService, - IApplicationCacheService applicationCacheService, - GlobalSettings globalSettings, - IReferenceEventService referenceEventService, - IUserService userService, - ILogger logger) + [Authorize] + public class OrganizationsController : Controller { - _organizationRepository = organizationRepository; - _organizationUserRepository = organizationUserRepository; - _organizationConnectionRepository = organizationConnectionRepository; - _syncSponsorshipsCommand = syncSponsorshipsCommand; - _cipherRepository = cipherRepository; - _collectionRepository = collectionRepository; - _groupRepository = groupRepository; - _policyRepository = policyRepository; - _paymentService = paymentService; - _licensingService = licensingService; - _applicationCacheService = applicationCacheService; - _globalSettings = globalSettings; - _referenceEventService = referenceEventService; - _userService = userService; - _logger = logger; - } + private readonly IOrganizationRepository _organizationRepository; + private readonly IOrganizationUserRepository _organizationUserRepository; + private readonly IOrganizationConnectionRepository _organizationConnectionRepository; + private readonly ISelfHostedSyncSponsorshipsCommand _syncSponsorshipsCommand; + private readonly ICipherRepository _cipherRepository; + private readonly ICollectionRepository _collectionRepository; + private readonly IGroupRepository _groupRepository; + private readonly IPolicyRepository _policyRepository; + private readonly IPaymentService _paymentService; + private readonly ILicensingService _licensingService; + private readonly IApplicationCacheService _applicationCacheService; + private readonly GlobalSettings _globalSettings; + private readonly IReferenceEventService _referenceEventService; + private readonly IUserService _userService; + private readonly ILogger _logger; - public async Task Index(string name = null, string userEmail = null, bool? paid = null, - int page = 1, int count = 25) - { - if (page < 1) + public OrganizationsController( + IOrganizationRepository organizationRepository, + IOrganizationUserRepository organizationUserRepository, + IOrganizationConnectionRepository organizationConnectionRepository, + ISelfHostedSyncSponsorshipsCommand syncSponsorshipsCommand, + ICipherRepository cipherRepository, + ICollectionRepository collectionRepository, + IGroupRepository groupRepository, + IPolicyRepository policyRepository, + IPaymentService paymentService, + ILicensingService licensingService, + IApplicationCacheService applicationCacheService, + GlobalSettings globalSettings, + IReferenceEventService referenceEventService, + IUserService userService, + ILogger logger) { - page = 1; + _organizationRepository = organizationRepository; + _organizationUserRepository = organizationUserRepository; + _organizationConnectionRepository = organizationConnectionRepository; + _syncSponsorshipsCommand = syncSponsorshipsCommand; + _cipherRepository = cipherRepository; + _collectionRepository = collectionRepository; + _groupRepository = groupRepository; + _policyRepository = policyRepository; + _paymentService = paymentService; + _licensingService = licensingService; + _applicationCacheService = applicationCacheService; + _globalSettings = globalSettings; + _referenceEventService = referenceEventService; + _userService = userService; + _logger = logger; } - if (count < 1) + public async Task Index(string name = null, string userEmail = null, bool? paid = null, + int page = 1, int count = 25) { - count = 1; + if (page < 1) + { + page = 1; + } + + if (count < 1) + { + count = 1; + } + + var skip = (page - 1) * count; + var organizations = await _organizationRepository.SearchAsync(name, userEmail, paid, skip, count); + return View(new OrganizationsModel + { + Items = organizations as List, + Name = string.IsNullOrWhiteSpace(name) ? null : name, + UserEmail = string.IsNullOrWhiteSpace(userEmail) ? null : userEmail, + Paid = paid, + Page = page, + Count = count, + Action = _globalSettings.SelfHosted ? "View" : "Edit", + SelfHosted = _globalSettings.SelfHosted + }); } - var skip = (page - 1) * count; - var organizations = await _organizationRepository.SearchAsync(name, userEmail, paid, skip, count); - return View(new OrganizationsModel + public async Task View(Guid id) { - Items = organizations as List, - Name = string.IsNullOrWhiteSpace(name) ? null : name, - UserEmail = string.IsNullOrWhiteSpace(userEmail) ? null : userEmail, - Paid = paid, - Page = page, - Count = count, - Action = _globalSettings.SelfHosted ? "View" : "Edit", - SelfHosted = _globalSettings.SelfHosted - }); - } + var organization = await _organizationRepository.GetByIdAsync(id); + if (organization == null) + { + return RedirectToAction("Index"); + } - public async Task View(Guid id) - { - var organization = await _organizationRepository.GetByIdAsync(id); - if (organization == null) + var ciphers = await _cipherRepository.GetManyByOrganizationIdAsync(id); + var collections = await _collectionRepository.GetManyByOrganizationIdAsync(id); + IEnumerable groups = null; + if (organization.UseGroups) + { + groups = await _groupRepository.GetManyByOrganizationIdAsync(id); + } + IEnumerable policies = null; + if (organization.UsePolicies) + { + policies = await _policyRepository.GetManyByOrganizationIdAsync(id); + } + var users = await _organizationUserRepository.GetManyDetailsByOrganizationAsync(id); + var billingSyncConnection = _globalSettings.EnableCloudCommunication ? await _organizationConnectionRepository.GetByOrganizationIdTypeAsync(id, OrganizationConnectionType.CloudBillingSync) : null; + return View(new OrganizationViewModel(organization, billingSyncConnection, users, ciphers, collections, groups, policies)); + } + + [SelfHosted(NotSelfHostedOnly = true)] + public async Task Edit(Guid id) { + var organization = await _organizationRepository.GetByIdAsync(id); + if (organization == null) + { + return RedirectToAction("Index"); + } + + var ciphers = await _cipherRepository.GetManyByOrganizationIdAsync(id); + var collections = await _collectionRepository.GetManyByOrganizationIdAsync(id); + IEnumerable groups = null; + if (organization.UseGroups) + { + groups = await _groupRepository.GetManyByOrganizationIdAsync(id); + } + IEnumerable policies = null; + if (organization.UsePolicies) + { + policies = await _policyRepository.GetManyByOrganizationIdAsync(id); + } + var users = await _organizationUserRepository.GetManyDetailsByOrganizationAsync(id); + var billingInfo = await _paymentService.GetBillingAsync(organization); + var billingSyncConnection = _globalSettings.EnableCloudCommunication ? await _organizationConnectionRepository.GetByOrganizationIdTypeAsync(id, OrganizationConnectionType.CloudBillingSync) : null; + return View(new OrganizationEditModel(organization, users, ciphers, collections, groups, policies, + billingInfo, billingSyncConnection, _globalSettings)); + } + + [HttpPost] + [ValidateAntiForgeryToken] + [SelfHosted(NotSelfHostedOnly = true)] + public async Task Edit(Guid id, OrganizationEditModel model) + { + var organization = await _organizationRepository.GetByIdAsync(id); + model.ToOrganization(organization); + await _organizationRepository.ReplaceAsync(organization); + await _applicationCacheService.UpsertOrganizationAbilityAsync(organization); + await _referenceEventService.RaiseEventAsync(new ReferenceEvent(ReferenceEventType.OrganizationEditedByAdmin, organization) + { + EventRaisedByUser = _userService.GetUserName(User), + SalesAssistedTrialStarted = model.SalesAssistedTrialStarted, + }); + return RedirectToAction("Edit", new { id }); + } + + [HttpPost] + [ValidateAntiForgeryToken] + public async Task Delete(Guid id) + { + var organization = await _organizationRepository.GetByIdAsync(id); + if (organization != null) + { + await _organizationRepository.DeleteAsync(organization); + await _applicationCacheService.DeleteOrganizationAbilityAsync(organization.Id); + } + return RedirectToAction("Index"); } - var ciphers = await _cipherRepository.GetManyByOrganizationIdAsync(id); - var collections = await _collectionRepository.GetManyByOrganizationIdAsync(id); - IEnumerable groups = null; - if (organization.UseGroups) + public async Task TriggerBillingSync(Guid id) { - groups = await _groupRepository.GetManyByOrganizationIdAsync(id); - } - IEnumerable policies = null; - if (organization.UsePolicies) - { - policies = await _policyRepository.GetManyByOrganizationIdAsync(id); - } - var users = await _organizationUserRepository.GetManyDetailsByOrganizationAsync(id); - var billingSyncConnection = _globalSettings.EnableCloudCommunication ? await _organizationConnectionRepository.GetByOrganizationIdTypeAsync(id, OrganizationConnectionType.CloudBillingSync) : null; - return View(new OrganizationViewModel(organization, billingSyncConnection, users, ciphers, collections, groups, policies)); - } + var organization = await _organizationRepository.GetByIdAsync(id); + if (organization == null) + { + return RedirectToAction("Index"); + } + var connection = (await _organizationConnectionRepository.GetEnabledByOrganizationIdTypeAsync(id, OrganizationConnectionType.CloudBillingSync)).FirstOrDefault(); + if (connection != null) + { + try + { + var config = connection.GetConfig(); + await _syncSponsorshipsCommand.SyncOrganization(id, config.CloudOrganizationId, connection); + TempData["ConnectionActivated"] = id; + TempData["ConnectionError"] = null; + } + catch (Exception ex) + { + TempData["ConnectionError"] = ex.Message; + _logger.LogWarning(ex, "Error while attempting to do billing sync for organization with id '{OrganizationId}'", id); + } - [SelfHosted(NotSelfHostedOnly = true)] - public async Task Edit(Guid id) - { - var organization = await _organizationRepository.GetByIdAsync(id); - if (organization == null) - { + if (_globalSettings.SelfHosted) + { + return RedirectToAction("View", new { id }); + } + else + { + return RedirectToAction("Edit", new { id }); + } + } return RedirectToAction("Index"); } - var ciphers = await _cipherRepository.GetManyByOrganizationIdAsync(id); - var collections = await _collectionRepository.GetManyByOrganizationIdAsync(id); - IEnumerable groups = null; - if (organization.UseGroups) - { - groups = await _groupRepository.GetManyByOrganizationIdAsync(id); - } - IEnumerable policies = null; - if (organization.UsePolicies) - { - policies = await _policyRepository.GetManyByOrganizationIdAsync(id); - } - var users = await _organizationUserRepository.GetManyDetailsByOrganizationAsync(id); - var billingInfo = await _paymentService.GetBillingAsync(organization); - var billingSyncConnection = _globalSettings.EnableCloudCommunication ? await _organizationConnectionRepository.GetByOrganizationIdTypeAsync(id, OrganizationConnectionType.CloudBillingSync) : null; - return View(new OrganizationEditModel(organization, users, ciphers, collections, groups, policies, - billingInfo, billingSyncConnection, _globalSettings)); } - - [HttpPost] - [ValidateAntiForgeryToken] - [SelfHosted(NotSelfHostedOnly = true)] - public async Task Edit(Guid id, OrganizationEditModel model) - { - var organization = await _organizationRepository.GetByIdAsync(id); - model.ToOrganization(organization); - await _organizationRepository.ReplaceAsync(organization); - await _applicationCacheService.UpsertOrganizationAbilityAsync(organization); - await _referenceEventService.RaiseEventAsync(new ReferenceEvent(ReferenceEventType.OrganizationEditedByAdmin, organization) - { - EventRaisedByUser = _userService.GetUserName(User), - SalesAssistedTrialStarted = model.SalesAssistedTrialStarted, - }); - return RedirectToAction("Edit", new { id }); - } - - [HttpPost] - [ValidateAntiForgeryToken] - public async Task Delete(Guid id) - { - var organization = await _organizationRepository.GetByIdAsync(id); - if (organization != null) - { - await _organizationRepository.DeleteAsync(organization); - await _applicationCacheService.DeleteOrganizationAbilityAsync(organization.Id); - } - - return RedirectToAction("Index"); - } - - public async Task TriggerBillingSync(Guid id) - { - var organization = await _organizationRepository.GetByIdAsync(id); - if (organization == null) - { - return RedirectToAction("Index"); - } - var connection = (await _organizationConnectionRepository.GetEnabledByOrganizationIdTypeAsync(id, OrganizationConnectionType.CloudBillingSync)).FirstOrDefault(); - if (connection != null) - { - try - { - var config = connection.GetConfig(); - await _syncSponsorshipsCommand.SyncOrganization(id, config.CloudOrganizationId, connection); - TempData["ConnectionActivated"] = id; - TempData["ConnectionError"] = null; - } - catch (Exception ex) - { - TempData["ConnectionError"] = ex.Message; - _logger.LogWarning(ex, "Error while attempting to do billing sync for organization with id '{OrganizationId}'", id); - } - - if (_globalSettings.SelfHosted) - { - return RedirectToAction("View", new { id }); - } - else - { - return RedirectToAction("Edit", new { id }); - } - } - return RedirectToAction("Index"); - } - } diff --git a/src/Admin/Controllers/ProvidersController.cs b/src/Admin/Controllers/ProvidersController.cs index a141b9fd02..e0e4484998 100644 --- a/src/Admin/Controllers/ProvidersController.cs +++ b/src/Admin/Controllers/ProvidersController.cs @@ -7,127 +7,128 @@ using Bit.Core.Utilities; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; -namespace Bit.Admin.Controllers; - -[Authorize] -[SelfHosted(NotSelfHostedOnly = true)] -public class ProvidersController : Controller +namespace Bit.Admin.Controllers { - private readonly IProviderRepository _providerRepository; - private readonly IProviderUserRepository _providerUserRepository; - private readonly IProviderOrganizationRepository _providerOrganizationRepository; - private readonly GlobalSettings _globalSettings; - private readonly IApplicationCacheService _applicationCacheService; - private readonly IProviderService _providerService; - - public ProvidersController(IProviderRepository providerRepository, IProviderUserRepository providerUserRepository, - IProviderOrganizationRepository providerOrganizationRepository, IProviderService providerService, - GlobalSettings globalSettings, IApplicationCacheService applicationCacheService) - { - _providerRepository = providerRepository; - _providerUserRepository = providerUserRepository; - _providerOrganizationRepository = providerOrganizationRepository; - _providerService = providerService; - _globalSettings = globalSettings; - _applicationCacheService = applicationCacheService; - } - - public async Task Index(string name = null, string userEmail = null, int page = 1, int count = 25) - { - if (page < 1) - { - page = 1; - } - - if (count < 1) - { - count = 1; - } - - var skip = (page - 1) * count; - var providers = await _providerRepository.SearchAsync(name, userEmail, skip, count); - return View(new ProvidersModel - { - Items = providers as List, - Name = string.IsNullOrWhiteSpace(name) ? null : name, - UserEmail = string.IsNullOrWhiteSpace(userEmail) ? null : userEmail, - Page = page, - Count = count, - Action = _globalSettings.SelfHosted ? "View" : "Edit", - SelfHosted = _globalSettings.SelfHosted - }); - } - - public IActionResult Create(string ownerEmail = null) - { - return View(new CreateProviderModel - { - OwnerEmail = ownerEmail - }); - } - - [HttpPost] - [ValidateAntiForgeryToken] - public async Task Create(CreateProviderModel model) - { - if (!ModelState.IsValid) - { - return View(model); - } - - await _providerService.CreateAsync(model.OwnerEmail); - - return RedirectToAction("Index"); - } - - public async Task View(Guid id) - { - var provider = await _providerRepository.GetByIdAsync(id); - if (provider == null) - { - return RedirectToAction("Index"); - } - - var users = await _providerUserRepository.GetManyDetailsByProviderAsync(id); - var providerOrganizations = await _providerOrganizationRepository.GetManyDetailsByProviderAsync(id); - return View(new ProviderViewModel(provider, users, providerOrganizations)); - } - + [Authorize] [SelfHosted(NotSelfHostedOnly = true)] - public async Task Edit(Guid id) + public class ProvidersController : Controller { - var provider = await _providerRepository.GetByIdAsync(id); - if (provider == null) + private readonly IProviderRepository _providerRepository; + private readonly IProviderUserRepository _providerUserRepository; + private readonly IProviderOrganizationRepository _providerOrganizationRepository; + private readonly GlobalSettings _globalSettings; + private readonly IApplicationCacheService _applicationCacheService; + private readonly IProviderService _providerService; + + public ProvidersController(IProviderRepository providerRepository, IProviderUserRepository providerUserRepository, + IProviderOrganizationRepository providerOrganizationRepository, IProviderService providerService, + GlobalSettings globalSettings, IApplicationCacheService applicationCacheService) { + _providerRepository = providerRepository; + _providerUserRepository = providerUserRepository; + _providerOrganizationRepository = providerOrganizationRepository; + _providerService = providerService; + _globalSettings = globalSettings; + _applicationCacheService = applicationCacheService; + } + + public async Task Index(string name = null, string userEmail = null, int page = 1, int count = 25) + { + if (page < 1) + { + page = 1; + } + + if (count < 1) + { + count = 1; + } + + var skip = (page - 1) * count; + var providers = await _providerRepository.SearchAsync(name, userEmail, skip, count); + return View(new ProvidersModel + { + Items = providers as List, + Name = string.IsNullOrWhiteSpace(name) ? null : name, + UserEmail = string.IsNullOrWhiteSpace(userEmail) ? null : userEmail, + Page = page, + Count = count, + Action = _globalSettings.SelfHosted ? "View" : "Edit", + SelfHosted = _globalSettings.SelfHosted + }); + } + + public IActionResult Create(string ownerEmail = null) + { + return View(new CreateProviderModel + { + OwnerEmail = ownerEmail + }); + } + + [HttpPost] + [ValidateAntiForgeryToken] + public async Task Create(CreateProviderModel model) + { + if (!ModelState.IsValid) + { + return View(model); + } + + await _providerService.CreateAsync(model.OwnerEmail); + return RedirectToAction("Index"); } - var users = await _providerUserRepository.GetManyDetailsByProviderAsync(id); - var providerOrganizations = await _providerOrganizationRepository.GetManyDetailsByProviderAsync(id); - return View(new ProviderEditModel(provider, users, providerOrganizations)); - } - - [HttpPost] - [ValidateAntiForgeryToken] - [SelfHosted(NotSelfHostedOnly = true)] - public async Task Edit(Guid id, ProviderEditModel model) - { - var provider = await _providerRepository.GetByIdAsync(id); - if (provider == null) + public async Task View(Guid id) { - return RedirectToAction("Index"); + var provider = await _providerRepository.GetByIdAsync(id); + if (provider == null) + { + return RedirectToAction("Index"); + } + + var users = await _providerUserRepository.GetManyDetailsByProviderAsync(id); + var providerOrganizations = await _providerOrganizationRepository.GetManyDetailsByProviderAsync(id); + return View(new ProviderViewModel(provider, users, providerOrganizations)); } - model.ToProvider(provider); - await _providerRepository.ReplaceAsync(provider); - await _applicationCacheService.UpsertProviderAbilityAsync(provider); - return RedirectToAction("Edit", new { id }); - } + [SelfHosted(NotSelfHostedOnly = true)] + public async Task Edit(Guid id) + { + var provider = await _providerRepository.GetByIdAsync(id); + if (provider == null) + { + return RedirectToAction("Index"); + } - public async Task ResendInvite(Guid ownerId, Guid providerId) - { - await _providerService.ResendProviderSetupInviteEmailAsync(providerId, ownerId); - TempData["InviteResentTo"] = ownerId; - return RedirectToAction("Edit", new { id = providerId }); + var users = await _providerUserRepository.GetManyDetailsByProviderAsync(id); + var providerOrganizations = await _providerOrganizationRepository.GetManyDetailsByProviderAsync(id); + return View(new ProviderEditModel(provider, users, providerOrganizations)); + } + + [HttpPost] + [ValidateAntiForgeryToken] + [SelfHosted(NotSelfHostedOnly = true)] + public async Task Edit(Guid id, ProviderEditModel model) + { + var provider = await _providerRepository.GetByIdAsync(id); + if (provider == null) + { + return RedirectToAction("Index"); + } + + model.ToProvider(provider); + await _providerRepository.ReplaceAsync(provider); + await _applicationCacheService.UpsertProviderAbilityAsync(provider); + return RedirectToAction("Edit", new { id }); + } + + public async Task ResendInvite(Guid ownerId, Guid providerId) + { + await _providerService.ResendProviderSetupInviteEmailAsync(providerId, ownerId); + TempData["InviteResentTo"] = ownerId; + return RedirectToAction("Edit", new { id = providerId }); + } } } diff --git a/src/Admin/Controllers/ToolsController.cs b/src/Admin/Controllers/ToolsController.cs index 9bd6189b3e..9a483c137b 100644 --- a/src/Admin/Controllers/ToolsController.cs +++ b/src/Admin/Controllers/ToolsController.cs @@ -10,373 +10,410 @@ using Bit.Core.Utilities; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; -namespace Bit.Admin.Controllers; - -[Authorize] -[SelfHosted(NotSelfHostedOnly = true)] -public class ToolsController : Controller +namespace Bit.Admin.Controllers { - private readonly GlobalSettings _globalSettings; - private readonly IOrganizationRepository _organizationRepository; - private readonly IOrganizationService _organizationService; - private readonly IUserService _userService; - private readonly ITransactionRepository _transactionRepository; - private readonly IInstallationRepository _installationRepository; - private readonly IOrganizationUserRepository _organizationUserRepository; - private readonly IPaymentService _paymentService; - private readonly ITaxRateRepository _taxRateRepository; - private readonly IStripeAdapter _stripeAdapter; - - public ToolsController( - GlobalSettings globalSettings, - IOrganizationRepository organizationRepository, - IOrganizationService organizationService, - IUserService userService, - ITransactionRepository transactionRepository, - IInstallationRepository installationRepository, - IOrganizationUserRepository organizationUserRepository, - ITaxRateRepository taxRateRepository, - IPaymentService paymentService, - IStripeAdapter stripeAdapter) + [Authorize] + [SelfHosted(NotSelfHostedOnly = true)] + public class ToolsController : Controller { - _globalSettings = globalSettings; - _organizationRepository = organizationRepository; - _organizationService = organizationService; - _userService = userService; - _transactionRepository = transactionRepository; - _installationRepository = installationRepository; - _organizationUserRepository = organizationUserRepository; - _taxRateRepository = taxRateRepository; - _paymentService = paymentService; - _stripeAdapter = stripeAdapter; - } + private readonly GlobalSettings _globalSettings; + private readonly IOrganizationRepository _organizationRepository; + private readonly IOrganizationService _organizationService; + private readonly IUserService _userService; + private readonly ITransactionRepository _transactionRepository; + private readonly IInstallationRepository _installationRepository; + private readonly IOrganizationUserRepository _organizationUserRepository; + private readonly IPaymentService _paymentService; + private readonly ITaxRateRepository _taxRateRepository; + private readonly IStripeAdapter _stripeAdapter; - public IActionResult ChargeBraintree() - { - return View(new ChargeBraintreeModel()); - } - - [HttpPost] - [ValidateAntiForgeryToken] - public async Task ChargeBraintree(ChargeBraintreeModel model) - { - if (!ModelState.IsValid) + public ToolsController( + GlobalSettings globalSettings, + IOrganizationRepository organizationRepository, + IOrganizationService organizationService, + IUserService userService, + ITransactionRepository transactionRepository, + IInstallationRepository installationRepository, + IOrganizationUserRepository organizationUserRepository, + ITaxRateRepository taxRateRepository, + IPaymentService paymentService, + IStripeAdapter stripeAdapter) { - return View(model); + _globalSettings = globalSettings; + _organizationRepository = organizationRepository; + _organizationService = organizationService; + _userService = userService; + _transactionRepository = transactionRepository; + _installationRepository = installationRepository; + _organizationUserRepository = organizationUserRepository; + _taxRateRepository = taxRateRepository; + _paymentService = paymentService; + _stripeAdapter = stripeAdapter; } - var btGateway = new Braintree.BraintreeGateway + public IActionResult ChargeBraintree() { - Environment = _globalSettings.Braintree.Production ? - Braintree.Environment.PRODUCTION : Braintree.Environment.SANDBOX, - MerchantId = _globalSettings.Braintree.MerchantId, - PublicKey = _globalSettings.Braintree.PublicKey, - PrivateKey = _globalSettings.Braintree.PrivateKey - }; + return View(new ChargeBraintreeModel()); + } - var btObjIdField = model.Id[0] == 'o' ? "organization_id" : "user_id"; - var btObjId = new Guid(model.Id.Substring(1, 32)); - - var transactionResult = await btGateway.Transaction.SaleAsync( - new Braintree.TransactionRequest + [HttpPost] + [ValidateAntiForgeryToken] + public async Task ChargeBraintree(ChargeBraintreeModel model) + { + if (!ModelState.IsValid) { - Amount = model.Amount.Value, - CustomerId = model.Id, - Options = new Braintree.TransactionOptionsRequest + return View(model); + } + + var btGateway = new Braintree.BraintreeGateway + { + Environment = _globalSettings.Braintree.Production ? + Braintree.Environment.PRODUCTION : Braintree.Environment.SANDBOX, + MerchantId = _globalSettings.Braintree.MerchantId, + PublicKey = _globalSettings.Braintree.PublicKey, + PrivateKey = _globalSettings.Braintree.PrivateKey + }; + + var btObjIdField = model.Id[0] == 'o' ? "organization_id" : "user_id"; + var btObjId = new Guid(model.Id.Substring(1, 32)); + + var transactionResult = await btGateway.Transaction.SaleAsync( + new Braintree.TransactionRequest { - SubmitForSettlement = true, - PayPal = new Braintree.TransactionOptionsPayPalRequest + Amount = model.Amount.Value, + CustomerId = model.Id, + Options = new Braintree.TransactionOptionsRequest { - CustomField = $"{btObjIdField}:{btObjId}" + SubmitForSettlement = true, + PayPal = new Braintree.TransactionOptionsPayPalRequest + { + CustomField = $"{btObjIdField}:{btObjId}" + } + }, + CustomFields = new Dictionary + { + [btObjIdField] = btObjId.ToString() } - }, - CustomFields = new Dictionary - { - [btObjIdField] = btObjId.ToString() - } + }); + + if (!transactionResult.IsSuccess()) + { + ModelState.AddModelError(string.Empty, "Charge failed. " + + "Refer to Braintree admin portal for more information."); + } + else + { + model.TransactionId = transactionResult.Target.Id; + model.PayPalTransactionId = transactionResult.Target?.PayPalDetails?.CaptureId; + } + return View(model); + } + + public IActionResult CreateTransaction(Guid? organizationId = null, Guid? userId = null) + { + return View("CreateUpdateTransaction", new CreateUpdateTransactionModel + { + OrganizationId = organizationId, + UserId = userId }); - - if (!transactionResult.IsSuccess()) - { - ModelState.AddModelError(string.Empty, "Charge failed. " + - "Refer to Braintree admin portal for more information."); - } - else - { - model.TransactionId = transactionResult.Target.Id; - model.PayPalTransactionId = transactionResult.Target?.PayPalDetails?.CaptureId; - } - return View(model); - } - - public IActionResult CreateTransaction(Guid? organizationId = null, Guid? userId = null) - { - return View("CreateUpdateTransaction", new CreateUpdateTransactionModel - { - OrganizationId = organizationId, - UserId = userId - }); - } - - [HttpPost] - [ValidateAntiForgeryToken] - public async Task CreateTransaction(CreateUpdateTransactionModel model) - { - if (!ModelState.IsValid) - { - return View("CreateUpdateTransaction", model); } - await _transactionRepository.CreateAsync(model.ToTransaction()); - if (model.UserId.HasValue) + [HttpPost] + [ValidateAntiForgeryToken] + public async Task CreateTransaction(CreateUpdateTransactionModel model) { - return RedirectToAction("Edit", "Users", new { id = model.UserId }); - } - else - { - return RedirectToAction("Edit", "Organizations", new { id = model.OrganizationId }); - } - } + if (!ModelState.IsValid) + { + return View("CreateUpdateTransaction", model); + } - public async Task EditTransaction(Guid id) - { - var transaction = await _transactionRepository.GetByIdAsync(id); - if (transaction == null) - { - return RedirectToAction("Index", "Home"); - } - return View("CreateUpdateTransaction", new CreateUpdateTransactionModel(transaction)); - } - - [HttpPost] - [ValidateAntiForgeryToken] - public async Task EditTransaction(Guid id, CreateUpdateTransactionModel model) - { - if (!ModelState.IsValid) - { - return View("CreateUpdateTransaction", model); - } - await _transactionRepository.ReplaceAsync(model.ToTransaction(id)); - if (model.UserId.HasValue) - { - return RedirectToAction("Edit", "Users", new { id = model.UserId }); - } - else - { - return RedirectToAction("Edit", "Organizations", new { id = model.OrganizationId }); - } - } - - public IActionResult PromoteAdmin() - { - return View(); - } - - [HttpPost] - [ValidateAntiForgeryToken] - public async Task PromoteAdmin(PromoteAdminModel model) - { - if (!ModelState.IsValid) - { - return View(model); + await _transactionRepository.CreateAsync(model.ToTransaction()); + if (model.UserId.HasValue) + { + return RedirectToAction("Edit", "Users", new { id = model.UserId }); + } + else + { + return RedirectToAction("Edit", "Organizations", new { id = model.OrganizationId }); + } } - var orgUsers = await _organizationUserRepository.GetManyByOrganizationAsync( - model.OrganizationId.Value, null); - var user = orgUsers.FirstOrDefault(u => u.UserId == model.UserId.Value); - if (user == null) + public async Task EditTransaction(Guid id) { - ModelState.AddModelError(nameof(model.UserId), "User Id not found in this organization."); - } - else if (user.Type != Core.Enums.OrganizationUserType.Admin) - { - ModelState.AddModelError(nameof(model.UserId), "User is not an admin of this organization."); + var transaction = await _transactionRepository.GetByIdAsync(id); + if (transaction == null) + { + return RedirectToAction("Index", "Home"); + } + return View("CreateUpdateTransaction", new CreateUpdateTransactionModel(transaction)); } - if (!ModelState.IsValid) + [HttpPost] + [ValidateAntiForgeryToken] + public async Task EditTransaction(Guid id, CreateUpdateTransactionModel model) { - return View(model); + if (!ModelState.IsValid) + { + return View("CreateUpdateTransaction", model); + } + await _transactionRepository.ReplaceAsync(model.ToTransaction(id)); + if (model.UserId.HasValue) + { + return RedirectToAction("Edit", "Users", new { id = model.UserId }); + } + else + { + return RedirectToAction("Edit", "Organizations", new { id = model.OrganizationId }); + } } - user.Type = Core.Enums.OrganizationUserType.Owner; - await _organizationUserRepository.ReplaceAsync(user); - return RedirectToAction("Edit", "Organizations", new { id = model.OrganizationId.Value }); - } - - public IActionResult GenerateLicense() - { - return View(); - } - - [HttpPost] - [ValidateAntiForgeryToken] - public async Task GenerateLicense(LicenseModel model) - { - if (!ModelState.IsValid) + public IActionResult PromoteAdmin() { - return View(model); + return View(); } - User user = null; - Organization organization = null; - if (model.UserId.HasValue) + [HttpPost] + [ValidateAntiForgeryToken] + public async Task PromoteAdmin(PromoteAdminModel model) { - user = await _userService.GetUserByIdAsync(model.UserId.Value); + if (!ModelState.IsValid) + { + return View(model); + } + + var orgUsers = await _organizationUserRepository.GetManyByOrganizationAsync( + model.OrganizationId.Value, null); + var user = orgUsers.FirstOrDefault(u => u.UserId == model.UserId.Value); if (user == null) { - ModelState.AddModelError(nameof(model.UserId), "User Id not found."); + ModelState.AddModelError(nameof(model.UserId), "User Id not found in this organization."); } + else if (user.Type != Core.Enums.OrganizationUserType.Admin) + { + ModelState.AddModelError(nameof(model.UserId), "User is not an admin of this organization."); + } + + if (!ModelState.IsValid) + { + return View(model); + } + + user.Type = Core.Enums.OrganizationUserType.Owner; + await _organizationUserRepository.ReplaceAsync(user); + return RedirectToAction("Edit", "Organizations", new { id = model.OrganizationId.Value }); } - else if (model.OrganizationId.HasValue) + + public IActionResult GenerateLicense() { - organization = await _organizationRepository.GetByIdAsync(model.OrganizationId.Value); - if (organization == null) - { - ModelState.AddModelError(nameof(model.OrganizationId), "Organization not found."); - } - else if (!organization.Enabled) - { - ModelState.AddModelError(nameof(model.OrganizationId), "Organization is disabled."); - } + return View(); } - if (model.InstallationId.HasValue) + + [HttpPost] + [ValidateAntiForgeryToken] + public async Task GenerateLicense(LicenseModel model) { - var installation = await _installationRepository.GetByIdAsync(model.InstallationId.Value); - if (installation == null) + if (!ModelState.IsValid) { - ModelState.AddModelError(nameof(model.InstallationId), "Installation not found."); + return View(model); } - else if (!installation.Enabled) + + User user = null; + Organization organization = null; + if (model.UserId.HasValue) { - ModelState.AddModelError(nameof(model.OrganizationId), "Installation is disabled."); + user = await _userService.GetUserByIdAsync(model.UserId.Value); + if (user == null) + { + ModelState.AddModelError(nameof(model.UserId), "User Id not found."); + } + } + else if (model.OrganizationId.HasValue) + { + organization = await _organizationRepository.GetByIdAsync(model.OrganizationId.Value); + if (organization == null) + { + ModelState.AddModelError(nameof(model.OrganizationId), "Organization not found."); + } + else if (!organization.Enabled) + { + ModelState.AddModelError(nameof(model.OrganizationId), "Organization is disabled."); + } + } + if (model.InstallationId.HasValue) + { + var installation = await _installationRepository.GetByIdAsync(model.InstallationId.Value); + if (installation == null) + { + ModelState.AddModelError(nameof(model.InstallationId), "Installation not found."); + } + else if (!installation.Enabled) + { + ModelState.AddModelError(nameof(model.OrganizationId), "Installation is disabled."); + } + } + + if (!ModelState.IsValid) + { + return View(model); + } + + if (organization != null) + { + var license = await _organizationService.GenerateLicenseAsync(organization, + model.InstallationId.Value, model.Version); + var ms = new MemoryStream(); + await JsonSerializer.SerializeAsync(ms, license, JsonHelpers.Indented); + ms.Seek(0, SeekOrigin.Begin); + return File(ms, "text/plain", "bitwarden_organization_license.json"); + } + else if (user != null) + { + var license = await _userService.GenerateLicenseAsync(user, null, model.Version); + var ms = new MemoryStream(); + ms.Seek(0, SeekOrigin.Begin); + await JsonSerializer.SerializeAsync(ms, license, JsonHelpers.Indented); + ms.Seek(0, SeekOrigin.Begin); + return File(ms, "text/plain", "bitwarden_premium_license.json"); + } + else + { + throw new Exception("No license to generate."); } } - if (!ModelState.IsValid) + public async Task TaxRate(int page = 1, int count = 25) { + if (page < 1) + { + page = 1; + } + + if (count < 1) + { + count = 1; + } + + var skip = (page - 1) * count; + var rates = await _taxRateRepository.SearchAsync(skip, count); + return View(new TaxRatesModel + { + Items = rates.ToList(), + Page = page, + Count = count + }); + } + + public async Task TaxRateAddEdit(string stripeTaxRateId = null) + { + if (string.IsNullOrWhiteSpace(stripeTaxRateId)) + { + return View(new TaxRateAddEditModel()); + } + + var rate = await _taxRateRepository.GetByIdAsync(stripeTaxRateId); + var model = new TaxRateAddEditModel() + { + StripeTaxRateId = stripeTaxRateId, + Country = rate.Country, + State = rate.State, + PostalCode = rate.PostalCode, + Rate = rate.Rate + }; + return View(model); } - if (organization != null) + [ValidateAntiForgeryToken] + public async Task TaxRateUpload(IFormFile file) { - var license = await _organizationService.GenerateLicenseAsync(organization, - model.InstallationId.Value, model.Version); - var ms = new MemoryStream(); - await JsonSerializer.SerializeAsync(ms, license, JsonHelpers.Indented); - ms.Seek(0, SeekOrigin.Begin); - return File(ms, "text/plain", "bitwarden_organization_license.json"); - } - else if (user != null) - { - var license = await _userService.GenerateLicenseAsync(user, null, model.Version); - var ms = new MemoryStream(); - ms.Seek(0, SeekOrigin.Begin); - await JsonSerializer.SerializeAsync(ms, license, JsonHelpers.Indented); - ms.Seek(0, SeekOrigin.Begin); - return File(ms, "text/plain", "bitwarden_premium_license.json"); - } - else - { - throw new Exception("No license to generate."); - } - } - - public async Task TaxRate(int page = 1, int count = 25) - { - if (page < 1) - { - page = 1; - } - - if (count < 1) - { - count = 1; - } - - var skip = (page - 1) * count; - var rates = await _taxRateRepository.SearchAsync(skip, count); - return View(new TaxRatesModel - { - Items = rates.ToList(), - Page = page, - Count = count - }); - } - - public async Task TaxRateAddEdit(string stripeTaxRateId = null) - { - if (string.IsNullOrWhiteSpace(stripeTaxRateId)) - { - return View(new TaxRateAddEditModel()); - } - - var rate = await _taxRateRepository.GetByIdAsync(stripeTaxRateId); - var model = new TaxRateAddEditModel() - { - StripeTaxRateId = stripeTaxRateId, - Country = rate.Country, - State = rate.State, - PostalCode = rate.PostalCode, - Rate = rate.Rate - }; - - return View(model); - } - - [ValidateAntiForgeryToken] - public async Task TaxRateUpload(IFormFile file) - { - if (file == null || file.Length == 0) - { - throw new ArgumentNullException(nameof(file)); - } - - // Build rates and validate them first before updating DB & Stripe - var taxRateUpdates = new List(); - var currentTaxRates = await _taxRateRepository.GetAllActiveAsync(); - using var reader = new StreamReader(file.OpenReadStream()); - while (!reader.EndOfStream) - { - var line = await reader.ReadLineAsync(); - if (string.IsNullOrWhiteSpace(line)) + if (file == null || file.Length == 0) { - continue; + throw new ArgumentNullException(nameof(file)); } - var taxParts = line.Split(','); - if (taxParts.Length < 2) + + // Build rates and validate them first before updating DB & Stripe + var taxRateUpdates = new List(); + var currentTaxRates = await _taxRateRepository.GetAllActiveAsync(); + using var reader = new StreamReader(file.OpenReadStream()); + while (!reader.EndOfStream) { - throw new Exception($"This line is not in the format of ,,,: {line}"); - } - var postalCode = taxParts[0].Trim(); - if (string.IsNullOrWhiteSpace(postalCode)) - { - throw new Exception($"'{line}' is not valid, the first element must contain a postal code."); - } - if (!decimal.TryParse(taxParts[1], out var rate) || rate <= 0M || rate > 100) - { - throw new Exception($"{taxParts[1]} is not a valid rate/decimal for {postalCode}"); - } - var state = taxParts.Length > 2 ? taxParts[2] : null; - var country = (taxParts.Length > 3 ? taxParts[3] : null); - if (string.IsNullOrWhiteSpace(country)) - { - country = "US"; - } - var taxRate = currentTaxRates.FirstOrDefault(r => r.Country == country && r.PostalCode == postalCode) ?? - new TaxRate + var line = await reader.ReadLineAsync(); + if (string.IsNullOrWhiteSpace(line)) { - Country = country, - PostalCode = postalCode, - Active = true, - }; - taxRate.Rate = rate; - taxRate.State = state ?? taxRate.State; - taxRateUpdates.Add(taxRate); + continue; + } + var taxParts = line.Split(','); + if (taxParts.Length < 2) + { + throw new Exception($"This line is not in the format of ,,,: {line}"); + } + var postalCode = taxParts[0].Trim(); + if (string.IsNullOrWhiteSpace(postalCode)) + { + throw new Exception($"'{line}' is not valid, the first element must contain a postal code."); + } + if (!decimal.TryParse(taxParts[1], out var rate) || rate <= 0M || rate > 100) + { + throw new Exception($"{taxParts[1]} is not a valid rate/decimal for {postalCode}"); + } + var state = taxParts.Length > 2 ? taxParts[2] : null; + var country = (taxParts.Length > 3 ? taxParts[3] : null); + if (string.IsNullOrWhiteSpace(country)) + { + country = "US"; + } + var taxRate = currentTaxRates.FirstOrDefault(r => r.Country == country && r.PostalCode == postalCode) ?? + new TaxRate + { + Country = country, + PostalCode = postalCode, + Active = true, + }; + taxRate.Rate = rate; + taxRate.State = state ?? taxRate.State; + taxRateUpdates.Add(taxRate); + } + + foreach (var taxRate in taxRateUpdates) + { + if (!string.IsNullOrWhiteSpace(taxRate.Id)) + { + await _paymentService.UpdateTaxRateAsync(taxRate); + } + else + { + await _paymentService.CreateTaxRateAsync(taxRate); + } + } + + return RedirectToAction("TaxRate"); } - foreach (var taxRate in taxRateUpdates) + [HttpPost] + [ValidateAntiForgeryToken] + public async Task TaxRateAddEdit(TaxRateAddEditModel model) { - if (!string.IsNullOrWhiteSpace(taxRate.Id)) + var existingRateCheck = await _taxRateRepository.GetByLocationAsync(new TaxRate() { Country = model.Country, PostalCode = model.PostalCode }); + if (existingRateCheck.Any()) + { + ModelState.AddModelError(nameof(model.PostalCode), "A tax rate already exists for this Country/Postal Code combination."); + } + + if (!ModelState.IsValid) + { + return View(model); + } + + var taxRate = new TaxRate() + { + Id = model.StripeTaxRateId, + Country = model.Country, + State = model.State, + PostalCode = model.PostalCode, + Rate = model.Rate + }; + + if (!string.IsNullOrWhiteSpace(model.StripeTaxRateId)) { await _paymentService.UpdateTaxRateAsync(taxRate); } @@ -384,175 +421,139 @@ public class ToolsController : Controller { await _paymentService.CreateTaxRateAsync(taxRate); } + + return RedirectToAction("TaxRate"); } - return RedirectToAction("TaxRate"); - } - - [HttpPost] - [ValidateAntiForgeryToken] - public async Task TaxRateAddEdit(TaxRateAddEditModel model) - { - var existingRateCheck = await _taxRateRepository.GetByLocationAsync(new TaxRate() { Country = model.Country, PostalCode = model.PostalCode }); - if (existingRateCheck.Any()) + public async Task TaxRateArchive(string stripeTaxRateId) { - ModelState.AddModelError(nameof(model.PostalCode), "A tax rate already exists for this Country/Postal Code combination."); - } - - if (!ModelState.IsValid) - { - return View(model); - } - - var taxRate = new TaxRate() - { - Id = model.StripeTaxRateId, - Country = model.Country, - State = model.State, - PostalCode = model.PostalCode, - Rate = model.Rate - }; - - if (!string.IsNullOrWhiteSpace(model.StripeTaxRateId)) - { - await _paymentService.UpdateTaxRateAsync(taxRate); - } - else - { - await _paymentService.CreateTaxRateAsync(taxRate); - } - - return RedirectToAction("TaxRate"); - } - - public async Task TaxRateArchive(string stripeTaxRateId) - { - if (!string.IsNullOrWhiteSpace(stripeTaxRateId)) - { - await _paymentService.ArchiveTaxRateAsync(new TaxRate() { Id = stripeTaxRateId }); - } - - return RedirectToAction("TaxRate"); - } - - public async Task StripeSubscriptions(StripeSubscriptionListOptions options) - { - options = options ?? new StripeSubscriptionListOptions(); - options.Limit = 10; - options.Expand = new List() { "data.customer", "data.latest_invoice" }; - options.SelectAll = false; - - var subscriptions = await _stripeAdapter.SubscriptionListAsync(options); - - options.StartingAfter = subscriptions.LastOrDefault()?.Id; - options.EndingBefore = await StripeSubscriptionsGetHasPreviousPage(subscriptions, options) ? - subscriptions.FirstOrDefault()?.Id : - null; - - var model = new StripeSubscriptionsModel() - { - Items = subscriptions.Select(s => new StripeSubscriptionRowModel(s)).ToList(), - Prices = (await _stripeAdapter.PriceListAsync(new Stripe.PriceListOptions() { Limit = 100 })).Data, - TestClocks = await _stripeAdapter.TestClockListAsync(), - Filter = options - }; - return View(model); - } - - [HttpPost] - public async Task StripeSubscriptions([FromForm] StripeSubscriptionsModel model) - { - if (!ModelState.IsValid) - { - model.Prices = (await _stripeAdapter.PriceListAsync(new Stripe.PriceListOptions() { Limit = 100 })).Data; - model.TestClocks = await _stripeAdapter.TestClockListAsync(); - return View(model); - } - - if (model.Action == StripeSubscriptionsAction.Export || model.Action == StripeSubscriptionsAction.BulkCancel) - { - var subscriptions = model.Filter.SelectAll ? - await _stripeAdapter.SubscriptionListAsync(model.Filter) : - model.Items.Where(x => x.Selected).Select(x => x.Subscription); - - if (model.Action == StripeSubscriptionsAction.Export) + if (!string.IsNullOrWhiteSpace(stripeTaxRateId)) { - return StripeSubscriptionsExport(subscriptions); + await _paymentService.ArchiveTaxRateAsync(new TaxRate() { Id = stripeTaxRateId }); } - if (model.Action == StripeSubscriptionsAction.BulkCancel) - { - await StripeSubscriptionsCancel(subscriptions); - } - } - else - { - if (model.Action == StripeSubscriptionsAction.PreviousPage || model.Action == StripeSubscriptionsAction.Search) - { - model.Filter.StartingAfter = null; - } - if (model.Action == StripeSubscriptionsAction.NextPage || model.Action == StripeSubscriptionsAction.Search) - { - model.Filter.EndingBefore = null; - } + return RedirectToAction("TaxRate"); } - - return RedirectToAction("StripeSubscriptions", model.Filter); - } - - // This requires a redundant API call to Stripe because of the way they handle pagination. - // The StartingBefore value has to be infered from the list we get, and isn't supplied by Stripe. - private async Task StripeSubscriptionsGetHasPreviousPage(List subscriptions, StripeSubscriptionListOptions options) - { - var hasPreviousPage = false; - if (subscriptions.FirstOrDefault()?.Id != null) + public async Task StripeSubscriptions(StripeSubscriptionListOptions options) { - var previousPageSearchOptions = new StripeSubscriptionListOptions() + options = options ?? new StripeSubscriptionListOptions(); + options.Limit = 10; + options.Expand = new List() { "data.customer", "data.latest_invoice" }; + options.SelectAll = false; + + var subscriptions = await _stripeAdapter.SubscriptionListAsync(options); + + options.StartingAfter = subscriptions.LastOrDefault()?.Id; + options.EndingBefore = await StripeSubscriptionsGetHasPreviousPage(subscriptions, options) ? + subscriptions.FirstOrDefault()?.Id : + null; + + var model = new StripeSubscriptionsModel() { - EndingBefore = subscriptions.FirstOrDefault().Id, - Limit = 1, - Status = options.Status, - CurrentPeriodEndDate = options.CurrentPeriodEndDate, - CurrentPeriodEndRange = options.CurrentPeriodEndRange, - Price = options.Price + Items = subscriptions.Select(s => new StripeSubscriptionRowModel(s)).ToList(), + Prices = (await _stripeAdapter.PriceListAsync(new Stripe.PriceListOptions() { Limit = 100 })).Data, + TestClocks = await _stripeAdapter.TestClockListAsync(), + Filter = options }; - hasPreviousPage = (await _stripeAdapter.SubscriptionListAsync(previousPageSearchOptions)).Count > 0; + return View(model); } - return hasPreviousPage; - } - private async Task StripeSubscriptionsCancel(IEnumerable subscriptions) - { - foreach (var s in subscriptions) + [HttpPost] + public async Task StripeSubscriptions([FromForm] StripeSubscriptionsModel model) { - await _stripeAdapter.SubscriptionCancelAsync(s.Id); - if (s.LatestInvoice?.Status == "open") + if (!ModelState.IsValid) { - await _stripeAdapter.InvoiceVoidInvoiceAsync(s.LatestInvoiceId); + model.Prices = (await _stripeAdapter.PriceListAsync(new Stripe.PriceListOptions() { Limit = 100 })).Data; + model.TestClocks = await _stripeAdapter.TestClockListAsync(); + return View(model); + } + + if (model.Action == StripeSubscriptionsAction.Export || model.Action == StripeSubscriptionsAction.BulkCancel) + { + var subscriptions = model.Filter.SelectAll ? + await _stripeAdapter.SubscriptionListAsync(model.Filter) : + model.Items.Where(x => x.Selected).Select(x => x.Subscription); + + if (model.Action == StripeSubscriptionsAction.Export) + { + return StripeSubscriptionsExport(subscriptions); + } + + if (model.Action == StripeSubscriptionsAction.BulkCancel) + { + await StripeSubscriptionsCancel(subscriptions); + } + } + else + { + if (model.Action == StripeSubscriptionsAction.PreviousPage || model.Action == StripeSubscriptionsAction.Search) + { + model.Filter.StartingAfter = null; + } + if (model.Action == StripeSubscriptionsAction.NextPage || model.Action == StripeSubscriptionsAction.Search) + { + model.Filter.EndingBefore = null; + } + } + + + return RedirectToAction("StripeSubscriptions", model.Filter); + } + + // This requires a redundant API call to Stripe because of the way they handle pagination. + // The StartingBefore value has to be infered from the list we get, and isn't supplied by Stripe. + private async Task StripeSubscriptionsGetHasPreviousPage(List subscriptions, StripeSubscriptionListOptions options) + { + var hasPreviousPage = false; + if (subscriptions.FirstOrDefault()?.Id != null) + { + var previousPageSearchOptions = new StripeSubscriptionListOptions() + { + EndingBefore = subscriptions.FirstOrDefault().Id, + Limit = 1, + Status = options.Status, + CurrentPeriodEndDate = options.CurrentPeriodEndDate, + CurrentPeriodEndRange = options.CurrentPeriodEndRange, + Price = options.Price + }; + hasPreviousPage = (await _stripeAdapter.SubscriptionListAsync(previousPageSearchOptions)).Count > 0; + } + return hasPreviousPage; + } + + private async Task StripeSubscriptionsCancel(IEnumerable subscriptions) + { + foreach (var s in subscriptions) + { + await _stripeAdapter.SubscriptionCancelAsync(s.Id); + if (s.LatestInvoice?.Status == "open") + { + await _stripeAdapter.InvoiceVoidInvoiceAsync(s.LatestInvoiceId); + } } } - } - private FileResult StripeSubscriptionsExport(IEnumerable subscriptions) - { - var fieldsToExport = subscriptions.Select(s => new + private FileResult StripeSubscriptionsExport(IEnumerable subscriptions) { - StripeId = s.Id, - CustomerEmail = s.Customer?.Email, - SubscriptionStatus = s.Status, - InvoiceDueDate = s.CurrentPeriodEnd, - SubscriptionProducts = s.Items?.Data.Select(p => p.Plan.Id) - }); + var fieldsToExport = subscriptions.Select(s => new + { + StripeId = s.Id, + CustomerEmail = s.Customer?.Email, + SubscriptionStatus = s.Status, + InvoiceDueDate = s.CurrentPeriodEnd, + SubscriptionProducts = s.Items?.Data.Select(p => p.Plan.Id) + }); - var options = new JsonSerializerOptions - { - PropertyNamingPolicy = JsonNamingPolicy.CamelCase, - WriteIndented = true - }; + var options = new JsonSerializerOptions + { + PropertyNamingPolicy = JsonNamingPolicy.CamelCase, + WriteIndented = true + }; - var result = System.Text.Json.JsonSerializer.Serialize(fieldsToExport, options); - var bytes = Encoding.UTF8.GetBytes(result); - return File(bytes, "application/json", "StripeSubscriptionsSearch.json"); + var result = System.Text.Json.JsonSerializer.Serialize(fieldsToExport, options); + var bytes = Encoding.UTF8.GetBytes(result); + return File(bytes, "application/json", "StripeSubscriptionsSearch.json"); + } } } diff --git a/src/Admin/Controllers/UsersController.cs b/src/Admin/Controllers/UsersController.cs index 0a4becb697..e8ea2e0cd5 100644 --- a/src/Admin/Controllers/UsersController.cs +++ b/src/Admin/Controllers/UsersController.cs @@ -7,104 +7,105 @@ using Bit.Core.Utilities; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; -namespace Bit.Admin.Controllers; - -[Authorize] -public class UsersController : Controller +namespace Bit.Admin.Controllers { - private readonly IUserRepository _userRepository; - private readonly ICipherRepository _cipherRepository; - private readonly IPaymentService _paymentService; - private readonly GlobalSettings _globalSettings; - - public UsersController( - IUserRepository userRepository, - ICipherRepository cipherRepository, - IPaymentService paymentService, - GlobalSettings globalSettings) + [Authorize] + public class UsersController : Controller { - _userRepository = userRepository; - _cipherRepository = cipherRepository; - _paymentService = paymentService; - _globalSettings = globalSettings; - } + private readonly IUserRepository _userRepository; + private readonly ICipherRepository _cipherRepository; + private readonly IPaymentService _paymentService; + private readonly GlobalSettings _globalSettings; - public async Task Index(string email, int page = 1, int count = 25) - { - if (page < 1) + public UsersController( + IUserRepository userRepository, + ICipherRepository cipherRepository, + IPaymentService paymentService, + GlobalSettings globalSettings) { - page = 1; + _userRepository = userRepository; + _cipherRepository = cipherRepository; + _paymentService = paymentService; + _globalSettings = globalSettings; } - if (count < 1) + public async Task Index(string email, int page = 1, int count = 25) { - count = 1; + if (page < 1) + { + page = 1; + } + + if (count < 1) + { + count = 1; + } + + var skip = (page - 1) * count; + var users = await _userRepository.SearchAsync(email, skip, count); + return View(new UsersModel + { + Items = users as List, + Email = string.IsNullOrWhiteSpace(email) ? null : email, + Page = page, + Count = count, + Action = _globalSettings.SelfHosted ? "View" : "Edit" + }); } - var skip = (page - 1) * count; - var users = await _userRepository.SearchAsync(email, skip, count); - return View(new UsersModel + public async Task View(Guid id) { - Items = users as List, - Email = string.IsNullOrWhiteSpace(email) ? null : email, - Page = page, - Count = count, - Action = _globalSettings.SelfHosted ? "View" : "Edit" - }); - } + var user = await _userRepository.GetByIdAsync(id); + if (user == null) + { + return RedirectToAction("Index"); + } - public async Task View(Guid id) - { - var user = await _userRepository.GetByIdAsync(id); - if (user == null) + var ciphers = await _cipherRepository.GetManyByUserIdAsync(id); + return View(new UserViewModel(user, ciphers)); + } + + [SelfHosted(NotSelfHostedOnly = true)] + public async Task Edit(Guid id) { + var user = await _userRepository.GetByIdAsync(id); + if (user == null) + { + return RedirectToAction("Index"); + } + + var ciphers = await _cipherRepository.GetManyByUserIdAsync(id); + var billingInfo = await _paymentService.GetBillingAsync(user); + return View(new UserEditModel(user, ciphers, billingInfo, _globalSettings)); + } + + [HttpPost] + [ValidateAntiForgeryToken] + [SelfHosted(NotSelfHostedOnly = true)] + public async Task Edit(Guid id, UserEditModel model) + { + var user = await _userRepository.GetByIdAsync(id); + if (user == null) + { + return RedirectToAction("Index"); + } + + model.ToUser(user); + await _userRepository.ReplaceAsync(user); + return RedirectToAction("Edit", new { id }); + } + + [HttpPost] + [ValidateAntiForgeryToken] + public async Task Delete(Guid id) + { + var user = await _userRepository.GetByIdAsync(id); + if (user != null) + { + await _userRepository.DeleteAsync(user); + } + return RedirectToAction("Index"); } - - var ciphers = await _cipherRepository.GetManyByUserIdAsync(id); - return View(new UserViewModel(user, ciphers)); - } - - [SelfHosted(NotSelfHostedOnly = true)] - public async Task Edit(Guid id) - { - var user = await _userRepository.GetByIdAsync(id); - if (user == null) - { - return RedirectToAction("Index"); - } - - var ciphers = await _cipherRepository.GetManyByUserIdAsync(id); - var billingInfo = await _paymentService.GetBillingAsync(user); - return View(new UserEditModel(user, ciphers, billingInfo, _globalSettings)); - } - - [HttpPost] - [ValidateAntiForgeryToken] - [SelfHosted(NotSelfHostedOnly = true)] - public async Task Edit(Guid id, UserEditModel model) - { - var user = await _userRepository.GetByIdAsync(id); - if (user == null) - { - return RedirectToAction("Index"); - } - - model.ToUser(user); - await _userRepository.ReplaceAsync(user); - return RedirectToAction("Edit", new { id }); - } - - [HttpPost] - [ValidateAntiForgeryToken] - public async Task Delete(Guid id) - { - var user = await _userRepository.GetByIdAsync(id); - if (user != null) - { - await _userRepository.DeleteAsync(user); - } - - return RedirectToAction("Index"); } } diff --git a/src/Admin/HostedServices/AmazonSqsBlockIpHostedService.cs b/src/Admin/HostedServices/AmazonSqsBlockIpHostedService.cs index 646da09c5d..b0222d06f7 100644 --- a/src/Admin/HostedServices/AmazonSqsBlockIpHostedService.cs +++ b/src/Admin/HostedServices/AmazonSqsBlockIpHostedService.cs @@ -4,80 +4,81 @@ using Amazon.SQS.Model; using Bit.Core.Settings; using Microsoft.Extensions.Options; -namespace Bit.Admin.HostedServices; - -public class AmazonSqsBlockIpHostedService : BlockIpHostedService +namespace Bit.Admin.HostedServices { - private AmazonSQSClient _client; - - public AmazonSqsBlockIpHostedService( - ILogger logger, - IOptions adminSettings, - GlobalSettings globalSettings) - : base(logger, adminSettings, globalSettings) - { } - - public override void Dispose() + public class AmazonSqsBlockIpHostedService : BlockIpHostedService { - _client?.Dispose(); - } + private AmazonSQSClient _client; - protected override async Task ExecuteAsync(CancellationToken cancellationToken) - { - _client = new AmazonSQSClient(_globalSettings.Amazon.AccessKeyId, - _globalSettings.Amazon.AccessKeySecret, RegionEndpoint.GetBySystemName(_globalSettings.Amazon.Region)); - var blockIpQueue = await _client.GetQueueUrlAsync("block-ip", cancellationToken); - var blockIpQueueUrl = blockIpQueue.QueueUrl; - var unblockIpQueue = await _client.GetQueueUrlAsync("unblock-ip", cancellationToken); - var unblockIpQueueUrl = unblockIpQueue.QueueUrl; + public AmazonSqsBlockIpHostedService( + ILogger logger, + IOptions adminSettings, + GlobalSettings globalSettings) + : base(logger, adminSettings, globalSettings) + { } - while (!cancellationToken.IsCancellationRequested) + public override void Dispose() { - var blockMessageResponse = await _client.ReceiveMessageAsync(new ReceiveMessageRequest - { - QueueUrl = blockIpQueueUrl, - MaxNumberOfMessages = 10, - WaitTimeSeconds = 15 - }, cancellationToken); - if (blockMessageResponse.Messages.Any()) - { - foreach (var message in blockMessageResponse.Messages) - { - try - { - await BlockIpAsync(message.Body, cancellationToken); - } - catch (Exception e) - { - _logger.LogError(e, "Failed to block IP."); - } - await _client.DeleteMessageAsync(blockIpQueueUrl, message.ReceiptHandle, cancellationToken); - } - } + _client?.Dispose(); + } - var unblockMessageResponse = await _client.ReceiveMessageAsync(new ReceiveMessageRequest - { - QueueUrl = unblockIpQueueUrl, - MaxNumberOfMessages = 10, - WaitTimeSeconds = 15 - }, cancellationToken); - if (unblockMessageResponse.Messages.Any()) - { - foreach (var message in unblockMessageResponse.Messages) - { - try - { - await UnblockIpAsync(message.Body, cancellationToken); - } - catch (Exception e) - { - _logger.LogError(e, "Failed to unblock IP."); - } - await _client.DeleteMessageAsync(unblockIpQueueUrl, message.ReceiptHandle, cancellationToken); - } - } + protected override async Task ExecuteAsync(CancellationToken cancellationToken) + { + _client = new AmazonSQSClient(_globalSettings.Amazon.AccessKeyId, + _globalSettings.Amazon.AccessKeySecret, RegionEndpoint.GetBySystemName(_globalSettings.Amazon.Region)); + var blockIpQueue = await _client.GetQueueUrlAsync("block-ip", cancellationToken); + var blockIpQueueUrl = blockIpQueue.QueueUrl; + var unblockIpQueue = await _client.GetQueueUrlAsync("unblock-ip", cancellationToken); + var unblockIpQueueUrl = unblockIpQueue.QueueUrl; - await Task.Delay(TimeSpan.FromSeconds(15)); + while (!cancellationToken.IsCancellationRequested) + { + var blockMessageResponse = await _client.ReceiveMessageAsync(new ReceiveMessageRequest + { + QueueUrl = blockIpQueueUrl, + MaxNumberOfMessages = 10, + WaitTimeSeconds = 15 + }, cancellationToken); + if (blockMessageResponse.Messages.Any()) + { + foreach (var message in blockMessageResponse.Messages) + { + try + { + await BlockIpAsync(message.Body, cancellationToken); + } + catch (Exception e) + { + _logger.LogError(e, "Failed to block IP."); + } + await _client.DeleteMessageAsync(blockIpQueueUrl, message.ReceiptHandle, cancellationToken); + } + } + + var unblockMessageResponse = await _client.ReceiveMessageAsync(new ReceiveMessageRequest + { + QueueUrl = unblockIpQueueUrl, + MaxNumberOfMessages = 10, + WaitTimeSeconds = 15 + }, cancellationToken); + if (unblockMessageResponse.Messages.Any()) + { + foreach (var message in unblockMessageResponse.Messages) + { + try + { + await UnblockIpAsync(message.Body, cancellationToken); + } + catch (Exception e) + { + _logger.LogError(e, "Failed to unblock IP."); + } + await _client.DeleteMessageAsync(unblockIpQueueUrl, message.ReceiptHandle, cancellationToken); + } + } + + await Task.Delay(TimeSpan.FromSeconds(15)); + } } } } diff --git a/src/Admin/HostedServices/AzureQueueBlockIpHostedService.cs b/src/Admin/HostedServices/AzureQueueBlockIpHostedService.cs index f1590377e1..cd96f359a7 100644 --- a/src/Admin/HostedServices/AzureQueueBlockIpHostedService.cs +++ b/src/Admin/HostedServices/AzureQueueBlockIpHostedService.cs @@ -2,62 +2,63 @@ using Bit.Core.Settings; using Microsoft.Extensions.Options; -namespace Bit.Admin.HostedServices; - -public class AzureQueueBlockIpHostedService : BlockIpHostedService +namespace Bit.Admin.HostedServices { - private QueueClient _blockIpQueueClient; - private QueueClient _unblockIpQueueClient; - - public AzureQueueBlockIpHostedService( - ILogger logger, - IOptions adminSettings, - GlobalSettings globalSettings) - : base(logger, adminSettings, globalSettings) - { } - - protected override async Task ExecuteAsync(CancellationToken cancellationToken) + public class AzureQueueBlockIpHostedService : BlockIpHostedService { - _blockIpQueueClient = new QueueClient(_globalSettings.Storage.ConnectionString, "blockip"); - _unblockIpQueueClient = new QueueClient(_globalSettings.Storage.ConnectionString, "unblockip"); + private QueueClient _blockIpQueueClient; + private QueueClient _unblockIpQueueClient; - while (!cancellationToken.IsCancellationRequested) + public AzureQueueBlockIpHostedService( + ILogger logger, + IOptions adminSettings, + GlobalSettings globalSettings) + : base(logger, adminSettings, globalSettings) + { } + + protected override async Task ExecuteAsync(CancellationToken cancellationToken) { - var blockMessages = await _blockIpQueueClient.ReceiveMessagesAsync(maxMessages: 32); - if (blockMessages.Value?.Any() ?? false) - { - foreach (var message in blockMessages.Value) - { - try - { - await BlockIpAsync(message.MessageText, cancellationToken); - } - catch (Exception e) - { - _logger.LogError(e, "Failed to block IP."); - } - await _blockIpQueueClient.DeleteMessageAsync(message.MessageId, message.PopReceipt); - } - } + _blockIpQueueClient = new QueueClient(_globalSettings.Storage.ConnectionString, "blockip"); + _unblockIpQueueClient = new QueueClient(_globalSettings.Storage.ConnectionString, "unblockip"); - var unblockMessages = await _unblockIpQueueClient.ReceiveMessagesAsync(maxMessages: 32); - if (unblockMessages.Value?.Any() ?? false) + while (!cancellationToken.IsCancellationRequested) { - foreach (var message in unblockMessages.Value) + var blockMessages = await _blockIpQueueClient.ReceiveMessagesAsync(maxMessages: 32); + if (blockMessages.Value?.Any() ?? false) { - try + foreach (var message in blockMessages.Value) { - await UnblockIpAsync(message.MessageText, cancellationToken); + try + { + await BlockIpAsync(message.MessageText, cancellationToken); + } + catch (Exception e) + { + _logger.LogError(e, "Failed to block IP."); + } + await _blockIpQueueClient.DeleteMessageAsync(message.MessageId, message.PopReceipt); } - catch (Exception e) - { - _logger.LogError(e, "Failed to unblock IP."); - } - await _unblockIpQueueClient.DeleteMessageAsync(message.MessageId, message.PopReceipt); } - } - await Task.Delay(TimeSpan.FromSeconds(15)); + var unblockMessages = await _unblockIpQueueClient.ReceiveMessagesAsync(maxMessages: 32); + if (unblockMessages.Value?.Any() ?? false) + { + foreach (var message in unblockMessages.Value) + { + try + { + await UnblockIpAsync(message.MessageText, cancellationToken); + } + catch (Exception e) + { + _logger.LogError(e, "Failed to unblock IP."); + } + await _unblockIpQueueClient.DeleteMessageAsync(message.MessageId, message.PopReceipt); + } + } + + await Task.Delay(TimeSpan.FromSeconds(15)); + } } } } diff --git a/src/Admin/HostedServices/AzureQueueMailHostedService.cs b/src/Admin/HostedServices/AzureQueueMailHostedService.cs index b2031a405b..6e976f0b7d 100644 --- a/src/Admin/HostedServices/AzureQueueMailHostedService.cs +++ b/src/Admin/HostedServices/AzureQueueMailHostedService.cs @@ -6,96 +6,97 @@ using Bit.Core.Services; using Bit.Core.Settings; using Bit.Core.Utilities; -namespace Bit.Admin.HostedServices; - -public class AzureQueueMailHostedService : IHostedService +namespace Bit.Admin.HostedServices { - private readonly ILogger _logger; - private readonly GlobalSettings _globalSettings; - private readonly IMailService _mailService; - private CancellationTokenSource _cts; - private Task _executingTask; - - private QueueClient _mailQueueClient; - - public AzureQueueMailHostedService( - ILogger logger, - IMailService mailService, - GlobalSettings globalSettings) + public class AzureQueueMailHostedService : IHostedService { - _logger = logger; - _mailService = mailService; - _globalSettings = globalSettings; - } + private readonly ILogger _logger; + private readonly GlobalSettings _globalSettings; + private readonly IMailService _mailService; + private CancellationTokenSource _cts; + private Task _executingTask; - public Task StartAsync(CancellationToken cancellationToken) - { - _cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); - _executingTask = ExecuteAsync(_cts.Token); - return _executingTask.IsCompleted ? _executingTask : Task.CompletedTask; - } + private QueueClient _mailQueueClient; - public async Task StopAsync(CancellationToken cancellationToken) - { - if (_executingTask == null) + public AzureQueueMailHostedService( + ILogger logger, + IMailService mailService, + GlobalSettings globalSettings) { - return; + _logger = logger; + _mailService = mailService; + _globalSettings = globalSettings; } - _cts.Cancel(); - await Task.WhenAny(_executingTask, Task.Delay(-1, cancellationToken)); - cancellationToken.ThrowIfCancellationRequested(); - } - private async Task ExecuteAsync(CancellationToken cancellationToken) - { - _mailQueueClient = new QueueClient(_globalSettings.Mail.ConnectionString, "mail"); - - QueueMessage[] mailMessages; - while (!cancellationToken.IsCancellationRequested) + public Task StartAsync(CancellationToken cancellationToken) { - if (!(mailMessages = await RetrieveMessagesAsync()).Any()) + _cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); + _executingTask = ExecuteAsync(_cts.Token); + return _executingTask.IsCompleted ? _executingTask : Task.CompletedTask; + } + + public async Task StopAsync(CancellationToken cancellationToken) + { + if (_executingTask == null) { - await Task.Delay(TimeSpan.FromSeconds(15)); + return; } + _cts.Cancel(); + await Task.WhenAny(_executingTask, Task.Delay(-1, cancellationToken)); + cancellationToken.ThrowIfCancellationRequested(); + } - foreach (var message in mailMessages) + private async Task ExecuteAsync(CancellationToken cancellationToken) + { + _mailQueueClient = new QueueClient(_globalSettings.Mail.ConnectionString, "mail"); + + QueueMessage[] mailMessages; + while (!cancellationToken.IsCancellationRequested) { - try + if (!(mailMessages = await RetrieveMessagesAsync()).Any()) { - using var document = JsonDocument.Parse(message.DecodeMessageText()); - var root = document.RootElement; + await Task.Delay(TimeSpan.FromSeconds(15)); + } - if (root.ValueKind == JsonValueKind.Array) + foreach (var message in mailMessages) + { + try { - foreach (var mailQueueMessage in root.ToObject>()) + using var document = JsonDocument.Parse(message.DecodeMessageText()); + var root = document.RootElement; + + if (root.ValueKind == JsonValueKind.Array) { + foreach (var mailQueueMessage in root.ToObject>()) + { + await _mailService.SendEnqueuedMailMessageAsync(mailQueueMessage); + } + } + else if (root.ValueKind == JsonValueKind.Object) + { + var mailQueueMessage = root.ToObject(); await _mailService.SendEnqueuedMailMessageAsync(mailQueueMessage); } } - else if (root.ValueKind == JsonValueKind.Object) + catch (Exception e) { - var mailQueueMessage = root.ToObject(); - await _mailService.SendEnqueuedMailMessageAsync(mailQueueMessage); + _logger.LogError(e, "Failed to send email"); + // TODO: retries? } - } - catch (Exception e) - { - _logger.LogError(e, "Failed to send email"); - // TODO: retries? - } - await _mailQueueClient.DeleteMessageAsync(message.MessageId, message.PopReceipt); + await _mailQueueClient.DeleteMessageAsync(message.MessageId, message.PopReceipt); - if (cancellationToken.IsCancellationRequested) - { - break; + if (cancellationToken.IsCancellationRequested) + { + break; + } } } } - } - private async Task RetrieveMessagesAsync() - { - return (await _mailQueueClient.ReceiveMessagesAsync(maxMessages: 32))?.Value ?? new QueueMessage[] { }; + private async Task RetrieveMessagesAsync() + { + return (await _mailQueueClient.ReceiveMessagesAsync(maxMessages: 32))?.Value ?? new QueueMessage[] { }; + } } } diff --git a/src/Admin/HostedServices/BlockIpHostedService.cs b/src/Admin/HostedServices/BlockIpHostedService.cs index 6a1f58c6b3..17f0c50ce2 100644 --- a/src/Admin/HostedServices/BlockIpHostedService.cs +++ b/src/Admin/HostedServices/BlockIpHostedService.cs @@ -1,105 +1,71 @@ using Bit.Core.Settings; using Microsoft.Extensions.Options; -namespace Bit.Admin.HostedServices; - -public abstract class BlockIpHostedService : IHostedService, IDisposable +namespace Bit.Admin.HostedServices { - protected readonly ILogger _logger; - protected readonly GlobalSettings _globalSettings; - private readonly AdminSettings _adminSettings; - - private Task _executingTask; - private CancellationTokenSource _cts; - private HttpClient _httpClient = new HttpClient(); - - public BlockIpHostedService( - ILogger logger, - IOptions adminSettings, - GlobalSettings globalSettings) + public abstract class BlockIpHostedService : IHostedService, IDisposable { - _logger = logger; - _globalSettings = globalSettings; - _adminSettings = adminSettings?.Value; - } + protected readonly ILogger _logger; + protected readonly GlobalSettings _globalSettings; + private readonly AdminSettings _adminSettings; - public Task StartAsync(CancellationToken cancellationToken) - { - _cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); - _executingTask = ExecuteAsync(_cts.Token); - return _executingTask.IsCompleted ? _executingTask : Task.CompletedTask; - } + private Task _executingTask; + private CancellationTokenSource _cts; + private HttpClient _httpClient = new HttpClient(); - public async Task StopAsync(CancellationToken cancellationToken) - { - if (_executingTask == null) + public BlockIpHostedService( + ILogger logger, + IOptions adminSettings, + GlobalSettings globalSettings) { - return; + _logger = logger; + _globalSettings = globalSettings; + _adminSettings = adminSettings?.Value; } - _cts.Cancel(); - await Task.WhenAny(_executingTask, Task.Delay(-1, cancellationToken)); - cancellationToken.ThrowIfCancellationRequested(); - } - public virtual void Dispose() - { } - - protected abstract Task ExecuteAsync(CancellationToken cancellationToken); - - protected async Task BlockIpAsync(string message, CancellationToken cancellationToken) - { - var request = new HttpRequestMessage(); - request.Headers.Accept.Clear(); - request.Headers.Add("X-Auth-Email", _adminSettings.Cloudflare.AuthEmail); - request.Headers.Add("X-Auth-Key", _adminSettings.Cloudflare.AuthKey); - request.Method = HttpMethod.Post; - request.RequestUri = new Uri("https://api.cloudflare.com/" + - $"client/v4/zones/{_adminSettings.Cloudflare.ZoneId}/firewall/access_rules/rules"); - - request.Content = JsonContent.Create(new + public Task StartAsync(CancellationToken cancellationToken) { - mode = "block", - configuration = new + _cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); + _executingTask = ExecuteAsync(_cts.Token); + return _executingTask.IsCompleted ? _executingTask : Task.CompletedTask; + } + + public async Task StopAsync(CancellationToken cancellationToken) + { + if (_executingTask == null) { - target = "ip", - value = message - }, - notes = $"Rate limit abuse on {DateTime.UtcNow.ToString()}." - }); - - var response = await _httpClient.SendAsync(request, cancellationToken); - if (!response.IsSuccessStatusCode) - { - return; + return; + } + _cts.Cancel(); + await Task.WhenAny(_executingTask, Task.Delay(-1, cancellationToken)); + cancellationToken.ThrowIfCancellationRequested(); } - var accessRuleResponse = await response.Content.ReadFromJsonAsync(cancellationToken: cancellationToken); - if (!accessRuleResponse.Success) - { - return; - } + public virtual void Dispose() + { } - // TODO: Send `accessRuleResponse.Result?.Id` message to unblock queue - } + protected abstract Task ExecuteAsync(CancellationToken cancellationToken); - protected async Task UnblockIpAsync(string message, CancellationToken cancellationToken) - { - if (string.IsNullOrWhiteSpace(message)) + protected async Task BlockIpAsync(string message, CancellationToken cancellationToken) { - return; - } - - if (message.Contains(".") || message.Contains(":")) - { - // IP address messages var request = new HttpRequestMessage(); request.Headers.Accept.Clear(); request.Headers.Add("X-Auth-Email", _adminSettings.Cloudflare.AuthEmail); request.Headers.Add("X-Auth-Key", _adminSettings.Cloudflare.AuthKey); - request.Method = HttpMethod.Get; + request.Method = HttpMethod.Post; request.RequestUri = new Uri("https://api.cloudflare.com/" + - $"client/v4/zones/{_adminSettings.Cloudflare.ZoneId}/firewall/access_rules/rules?" + - $"configuration_target=ip&configuration_value={message}"); + $"client/v4/zones/{_adminSettings.Cloudflare.ZoneId}/firewall/access_rules/rules"); + + request.Content = JsonContent.Create(new + { + mode = "block", + configuration = new + { + target = "ip", + value = message + }, + notes = $"Rate limit abuse on {DateTime.UtcNow.ToString()}." + }); var response = await _httpClient.SendAsync(request, cancellationToken); if (!response.IsSuccessStatusCode) @@ -107,58 +73,93 @@ public abstract class BlockIpHostedService : IHostedService, IDisposable return; } - var listResponse = await response.Content.ReadFromJsonAsync(cancellationToken: cancellationToken); - if (!listResponse.Success) + var accessRuleResponse = await response.Content.ReadFromJsonAsync(cancellationToken: cancellationToken); + if (!accessRuleResponse.Success) { return; } - foreach (var rule in listResponse.Result) + // TODO: Send `accessRuleResponse.Result?.Id` message to unblock queue + } + + protected async Task UnblockIpAsync(string message, CancellationToken cancellationToken) + { + if (string.IsNullOrWhiteSpace(message)) { - await DeleteAccessRuleAsync(rule.Id, cancellationToken); + return; + } + + if (message.Contains(".") || message.Contains(":")) + { + // IP address messages + var request = new HttpRequestMessage(); + request.Headers.Accept.Clear(); + request.Headers.Add("X-Auth-Email", _adminSettings.Cloudflare.AuthEmail); + request.Headers.Add("X-Auth-Key", _adminSettings.Cloudflare.AuthKey); + request.Method = HttpMethod.Get; + request.RequestUri = new Uri("https://api.cloudflare.com/" + + $"client/v4/zones/{_adminSettings.Cloudflare.ZoneId}/firewall/access_rules/rules?" + + $"configuration_target=ip&configuration_value={message}"); + + var response = await _httpClient.SendAsync(request, cancellationToken); + if (!response.IsSuccessStatusCode) + { + return; + } + + var listResponse = await response.Content.ReadFromJsonAsync(cancellationToken: cancellationToken); + if (!listResponse.Success) + { + return; + } + + foreach (var rule in listResponse.Result) + { + await DeleteAccessRuleAsync(rule.Id, cancellationToken); + } + } + else + { + // Rule Id messages + await DeleteAccessRuleAsync(message, cancellationToken); } } - else + + protected async Task DeleteAccessRuleAsync(string ruleId, CancellationToken cancellationToken) { - // Rule Id messages - await DeleteAccessRuleAsync(message, cancellationToken); + var request = new HttpRequestMessage(); + request.Headers.Accept.Clear(); + request.Headers.Add("X-Auth-Email", _adminSettings.Cloudflare.AuthEmail); + request.Headers.Add("X-Auth-Key", _adminSettings.Cloudflare.AuthKey); + request.Method = HttpMethod.Delete; + request.RequestUri = new Uri("https://api.cloudflare.com/" + + $"client/v4/zones/{_adminSettings.Cloudflare.ZoneId}/firewall/access_rules/rules/{ruleId}"); + await _httpClient.SendAsync(request, cancellationToken); } - } - protected async Task DeleteAccessRuleAsync(string ruleId, CancellationToken cancellationToken) - { - var request = new HttpRequestMessage(); - request.Headers.Accept.Clear(); - request.Headers.Add("X-Auth-Email", _adminSettings.Cloudflare.AuthEmail); - request.Headers.Add("X-Auth-Key", _adminSettings.Cloudflare.AuthKey); - request.Method = HttpMethod.Delete; - request.RequestUri = new Uri("https://api.cloudflare.com/" + - $"client/v4/zones/{_adminSettings.Cloudflare.ZoneId}/firewall/access_rules/rules/{ruleId}"); - await _httpClient.SendAsync(request, cancellationToken); - } - - public class ListResponse - { - public bool Success { get; set; } - public List Result { get; set; } - } - - public class AccessRuleResponse - { - public bool Success { get; set; } - public AccessRuleResultResponse Result { get; set; } - } - - public class AccessRuleResultResponse - { - public string Id { get; set; } - public string Notes { get; set; } - public ConfigurationResponse Configuration { get; set; } - - public class ConfigurationResponse + public class ListResponse { - public string Target { get; set; } - public string Value { get; set; } + public bool Success { get; set; } + public List Result { get; set; } + } + + public class AccessRuleResponse + { + public bool Success { get; set; } + public AccessRuleResultResponse Result { get; set; } + } + + public class AccessRuleResultResponse + { + public string Id { get; set; } + public string Notes { get; set; } + public ConfigurationResponse Configuration { get; set; } + + public class ConfigurationResponse + { + public string Target { get; set; } + public string Value { get; set; } + } } } } diff --git a/src/Admin/HostedServices/DatabaseMigrationHostedService.cs b/src/Admin/HostedServices/DatabaseMigrationHostedService.cs index 0f660729e1..06cf014282 100644 --- a/src/Admin/HostedServices/DatabaseMigrationHostedService.cs +++ b/src/Admin/HostedServices/DatabaseMigrationHostedService.cs @@ -3,61 +3,62 @@ using Bit.Core.Jobs; using Bit.Core.Settings; using Bit.Migrator; -namespace Bit.Admin.HostedServices; - -public class DatabaseMigrationHostedService : IHostedService, IDisposable +namespace Bit.Admin.HostedServices { - private readonly GlobalSettings _globalSettings; - private readonly ILogger _logger; - private readonly DbMigrator _dbMigrator; - - public DatabaseMigrationHostedService( - GlobalSettings globalSettings, - ILogger logger, - ILogger migratorLogger, - ILogger listenerLogger) + public class DatabaseMigrationHostedService : IHostedService, IDisposable { - _globalSettings = globalSettings; - _logger = logger; - _dbMigrator = new DbMigrator(globalSettings.SqlServer.ConnectionString, migratorLogger); - } + private readonly GlobalSettings _globalSettings; + private readonly ILogger _logger; + private readonly DbMigrator _dbMigrator; - public virtual async Task StartAsync(CancellationToken cancellationToken) - { - // Wait 20 seconds to allow database to come online - await Task.Delay(20000); - - var maxMigrationAttempts = 10; - for (var i = 1; i <= maxMigrationAttempts; i++) + public DatabaseMigrationHostedService( + GlobalSettings globalSettings, + ILogger logger, + ILogger migratorLogger, + ILogger listenerLogger) { - try + _globalSettings = globalSettings; + _logger = logger; + _dbMigrator = new DbMigrator(globalSettings.SqlServer.ConnectionString, migratorLogger); + } + + public virtual async Task StartAsync(CancellationToken cancellationToken) + { + // Wait 20 seconds to allow database to come online + await Task.Delay(20000); + + var maxMigrationAttempts = 10; + for (var i = 1; i <= maxMigrationAttempts; i++) { - _dbMigrator.MigrateMsSqlDatabase(true, cancellationToken); - // TODO: Maybe flip a flag somewhere to indicate migration is complete?? - break; - } - catch (SqlException e) - { - if (i >= maxMigrationAttempts) + try { - _logger.LogError(e, "Database failed to migrate."); - throw; + _dbMigrator.MigrateMsSqlDatabase(true, cancellationToken); + // TODO: Maybe flip a flag somewhere to indicate migration is complete?? + break; } - else + catch (SqlException e) { - _logger.LogError(e, - "Database unavailable for migration. Trying again (attempt #{0})...", i + 1); - await Task.Delay(20000); + if (i >= maxMigrationAttempts) + { + _logger.LogError(e, "Database failed to migrate."); + throw; + } + else + { + _logger.LogError(e, + "Database unavailable for migration. Trying again (attempt #{0})...", i + 1); + await Task.Delay(20000); + } } } } - } - public virtual Task StopAsync(CancellationToken cancellationToken) - { - return Task.FromResult(0); - } + public virtual Task StopAsync(CancellationToken cancellationToken) + { + return Task.FromResult(0); + } - public virtual void Dispose() - { } + public virtual void Dispose() + { } + } } diff --git a/src/Admin/Jobs/AliveJob.cs b/src/Admin/Jobs/AliveJob.cs index b97d597e58..27d23c3421 100644 --- a/src/Admin/Jobs/AliveJob.cs +++ b/src/Admin/Jobs/AliveJob.cs @@ -3,26 +3,27 @@ using Bit.Core.Jobs; using Bit.Core.Settings; using Quartz; -namespace Bit.Admin.Jobs; - -public class AliveJob : BaseJob +namespace Bit.Admin.Jobs { - private readonly GlobalSettings _globalSettings; - private HttpClient _httpClient = new HttpClient(); - - public AliveJob( - GlobalSettings globalSettings, - ILogger logger) - : base(logger) + public class AliveJob : BaseJob { - _globalSettings = globalSettings; - } + private readonly GlobalSettings _globalSettings; + private HttpClient _httpClient = new HttpClient(); - protected async override Task ExecuteJobAsync(IJobExecutionContext context) - { - _logger.LogInformation(Constants.BypassFiltersEventId, "Execute job task: Keep alive"); - var response = await _httpClient.GetAsync(_globalSettings.BaseServiceUri.Admin); - _logger.LogInformation(Constants.BypassFiltersEventId, "Finished job task: Keep alive, " + - response.StatusCode); + public AliveJob( + GlobalSettings globalSettings, + ILogger logger) + : base(logger) + { + _globalSettings = globalSettings; + } + + protected async override Task ExecuteJobAsync(IJobExecutionContext context) + { + _logger.LogInformation(Constants.BypassFiltersEventId, "Execute job task: Keep alive"); + var response = await _httpClient.GetAsync(_globalSettings.BaseServiceUri.Admin); + _logger.LogInformation(Constants.BypassFiltersEventId, "Finished job task: Keep alive, " + + response.StatusCode); + } } } diff --git a/src/Admin/Jobs/DatabaseExpiredGrantsJob.cs b/src/Admin/Jobs/DatabaseExpiredGrantsJob.cs index 626eb00d52..60ac448281 100644 --- a/src/Admin/Jobs/DatabaseExpiredGrantsJob.cs +++ b/src/Admin/Jobs/DatabaseExpiredGrantsJob.cs @@ -3,24 +3,25 @@ using Bit.Core.Jobs; using Bit.Core.Repositories; using Quartz; -namespace Bit.Admin.Jobs; - -public class DatabaseExpiredGrantsJob : BaseJob +namespace Bit.Admin.Jobs { - private readonly IMaintenanceRepository _maintenanceRepository; - - public DatabaseExpiredGrantsJob( - IMaintenanceRepository maintenanceRepository, - ILogger logger) - : base(logger) + public class DatabaseExpiredGrantsJob : BaseJob { - _maintenanceRepository = maintenanceRepository; - } + private readonly IMaintenanceRepository _maintenanceRepository; - protected async override Task ExecuteJobAsync(IJobExecutionContext context) - { - _logger.LogInformation(Constants.BypassFiltersEventId, "Execute job task: DeleteExpiredGrantsAsync"); - await _maintenanceRepository.DeleteExpiredGrantsAsync(); - _logger.LogInformation(Constants.BypassFiltersEventId, "Finished job task: DeleteExpiredGrantsAsync"); + public DatabaseExpiredGrantsJob( + IMaintenanceRepository maintenanceRepository, + ILogger logger) + : base(logger) + { + _maintenanceRepository = maintenanceRepository; + } + + protected async override Task ExecuteJobAsync(IJobExecutionContext context) + { + _logger.LogInformation(Constants.BypassFiltersEventId, "Execute job task: DeleteExpiredGrantsAsync"); + await _maintenanceRepository.DeleteExpiredGrantsAsync(); + _logger.LogInformation(Constants.BypassFiltersEventId, "Finished job task: DeleteExpiredGrantsAsync"); + } } } diff --git a/src/Admin/Jobs/DatabaseExpiredSponsorshipsJob.cs b/src/Admin/Jobs/DatabaseExpiredSponsorshipsJob.cs index 7a00445fd0..609351e9f2 100644 --- a/src/Admin/Jobs/DatabaseExpiredSponsorshipsJob.cs +++ b/src/Admin/Jobs/DatabaseExpiredSponsorshipsJob.cs @@ -4,35 +4,36 @@ using Bit.Core.Repositories; using Bit.Core.Settings; using Quartz; -namespace Bit.Admin.Jobs; - -public class DatabaseExpiredSponsorshipsJob : BaseJob +namespace Bit.Admin.Jobs { - private GlobalSettings _globalSettings; - private readonly IMaintenanceRepository _maintenanceRepository; - - public DatabaseExpiredSponsorshipsJob( - IMaintenanceRepository maintenanceRepository, - ILogger logger, - GlobalSettings globalSettings) - : base(logger) + public class DatabaseExpiredSponsorshipsJob : BaseJob { - _maintenanceRepository = maintenanceRepository; - _globalSettings = globalSettings; - } + private GlobalSettings _globalSettings; + private readonly IMaintenanceRepository _maintenanceRepository; - protected override async Task ExecuteJobAsync(IJobExecutionContext context) - { - if (_globalSettings.SelfHosted && !_globalSettings.EnableCloudCommunication) + public DatabaseExpiredSponsorshipsJob( + IMaintenanceRepository maintenanceRepository, + ILogger logger, + GlobalSettings globalSettings) + : base(logger) { - return; + _maintenanceRepository = maintenanceRepository; + _globalSettings = globalSettings; } - _logger.LogInformation(Constants.BypassFiltersEventId, "Execute job task: DeleteExpiredSponsorshipsAsync"); - // allow a 90 day grace period before deleting - var deleteDate = DateTime.UtcNow.AddDays(-90); + protected override async Task ExecuteJobAsync(IJobExecutionContext context) + { + if (_globalSettings.SelfHosted && !_globalSettings.EnableCloudCommunication) + { + return; + } + _logger.LogInformation(Constants.BypassFiltersEventId, "Execute job task: DeleteExpiredSponsorshipsAsync"); - await _maintenanceRepository.DeleteExpiredSponsorshipsAsync(deleteDate); - _logger.LogInformation(Constants.BypassFiltersEventId, "Finished job task: DeleteExpiredSponsorshipsAsync"); + // allow a 90 day grace period before deleting + var deleteDate = DateTime.UtcNow.AddDays(-90); + + await _maintenanceRepository.DeleteExpiredSponsorshipsAsync(deleteDate); + _logger.LogInformation(Constants.BypassFiltersEventId, "Finished job task: DeleteExpiredSponsorshipsAsync"); + } } } diff --git a/src/Admin/Jobs/DatabaseRebuildlIndexesJob.cs b/src/Admin/Jobs/DatabaseRebuildlIndexesJob.cs index 78e48bb6f1..24e05043a7 100644 --- a/src/Admin/Jobs/DatabaseRebuildlIndexesJob.cs +++ b/src/Admin/Jobs/DatabaseRebuildlIndexesJob.cs @@ -3,24 +3,25 @@ using Bit.Core.Jobs; using Bit.Core.Repositories; using Quartz; -namespace Bit.Admin.Jobs; - -public class DatabaseRebuildlIndexesJob : BaseJob +namespace Bit.Admin.Jobs { - private readonly IMaintenanceRepository _maintenanceRepository; - - public DatabaseRebuildlIndexesJob( - IMaintenanceRepository maintenanceRepository, - ILogger logger) - : base(logger) + public class DatabaseRebuildlIndexesJob : BaseJob { - _maintenanceRepository = maintenanceRepository; - } + private readonly IMaintenanceRepository _maintenanceRepository; - protected async override Task ExecuteJobAsync(IJobExecutionContext context) - { - _logger.LogInformation(Constants.BypassFiltersEventId, "Execute job task: RebuildIndexesAsync"); - await _maintenanceRepository.RebuildIndexesAsync(); - _logger.LogInformation(Constants.BypassFiltersEventId, "Finished job task: RebuildIndexesAsync"); + public DatabaseRebuildlIndexesJob( + IMaintenanceRepository maintenanceRepository, + ILogger logger) + : base(logger) + { + _maintenanceRepository = maintenanceRepository; + } + + protected async override Task ExecuteJobAsync(IJobExecutionContext context) + { + _logger.LogInformation(Constants.BypassFiltersEventId, "Execute job task: RebuildIndexesAsync"); + await _maintenanceRepository.RebuildIndexesAsync(); + _logger.LogInformation(Constants.BypassFiltersEventId, "Finished job task: RebuildIndexesAsync"); + } } } diff --git a/src/Admin/Jobs/DatabaseUpdateStatisticsJob.cs b/src/Admin/Jobs/DatabaseUpdateStatisticsJob.cs index 14c13918ba..4a03d08fd1 100644 --- a/src/Admin/Jobs/DatabaseUpdateStatisticsJob.cs +++ b/src/Admin/Jobs/DatabaseUpdateStatisticsJob.cs @@ -3,27 +3,28 @@ using Bit.Core.Jobs; using Bit.Core.Repositories; using Quartz; -namespace Bit.Admin.Jobs; - -public class DatabaseUpdateStatisticsJob : BaseJob +namespace Bit.Admin.Jobs { - private readonly IMaintenanceRepository _maintenanceRepository; - - public DatabaseUpdateStatisticsJob( - IMaintenanceRepository maintenanceRepository, - ILogger logger) - : base(logger) + public class DatabaseUpdateStatisticsJob : BaseJob { - _maintenanceRepository = maintenanceRepository; - } + private readonly IMaintenanceRepository _maintenanceRepository; - protected async override Task ExecuteJobAsync(IJobExecutionContext context) - { - _logger.LogInformation(Constants.BypassFiltersEventId, "Execute job task: UpdateStatisticsAsync"); - await _maintenanceRepository.UpdateStatisticsAsync(); - _logger.LogInformation(Constants.BypassFiltersEventId, "Finished job task: UpdateStatisticsAsync"); - _logger.LogInformation(Constants.BypassFiltersEventId, "Execute job task: DisableCipherAutoStatsAsync"); - await _maintenanceRepository.DisableCipherAutoStatsAsync(); - _logger.LogInformation(Constants.BypassFiltersEventId, "Finished job task: DisableCipherAutoStatsAsync"); + public DatabaseUpdateStatisticsJob( + IMaintenanceRepository maintenanceRepository, + ILogger logger) + : base(logger) + { + _maintenanceRepository = maintenanceRepository; + } + + protected async override Task ExecuteJobAsync(IJobExecutionContext context) + { + _logger.LogInformation(Constants.BypassFiltersEventId, "Execute job task: UpdateStatisticsAsync"); + await _maintenanceRepository.UpdateStatisticsAsync(); + _logger.LogInformation(Constants.BypassFiltersEventId, "Finished job task: UpdateStatisticsAsync"); + _logger.LogInformation(Constants.BypassFiltersEventId, "Execute job task: DisableCipherAutoStatsAsync"); + await _maintenanceRepository.DisableCipherAutoStatsAsync(); + _logger.LogInformation(Constants.BypassFiltersEventId, "Finished job task: DisableCipherAutoStatsAsync"); + } } } diff --git a/src/Admin/Jobs/DeleteCiphersJob.cs b/src/Admin/Jobs/DeleteCiphersJob.cs index ecf13401ee..a5a92aa286 100644 --- a/src/Admin/Jobs/DeleteCiphersJob.cs +++ b/src/Admin/Jobs/DeleteCiphersJob.cs @@ -4,33 +4,34 @@ using Bit.Core.Repositories; using Microsoft.Extensions.Options; using Quartz; -namespace Bit.Admin.Jobs; - -public class DeleteCiphersJob : BaseJob +namespace Bit.Admin.Jobs { - private readonly ICipherRepository _cipherRepository; - private readonly AdminSettings _adminSettings; - - public DeleteCiphersJob( - ICipherRepository cipherRepository, - IOptions adminSettings, - ILogger logger) - : base(logger) + public class DeleteCiphersJob : BaseJob { - _cipherRepository = cipherRepository; - _adminSettings = adminSettings?.Value; - } + private readonly ICipherRepository _cipherRepository; + private readonly AdminSettings _adminSettings; - protected async override Task ExecuteJobAsync(IJobExecutionContext context) - { - _logger.LogInformation(Constants.BypassFiltersEventId, "Execute job task: DeleteDeletedAsync"); - var deleteDate = DateTime.UtcNow.AddDays(-30); - var daysAgoSetting = (_adminSettings?.DeleteTrashDaysAgo).GetValueOrDefault(); - if (daysAgoSetting > 0) + public DeleteCiphersJob( + ICipherRepository cipherRepository, + IOptions adminSettings, + ILogger logger) + : base(logger) { - deleteDate = DateTime.UtcNow.AddDays(-1 * daysAgoSetting); + _cipherRepository = cipherRepository; + _adminSettings = adminSettings?.Value; + } + + protected async override Task ExecuteJobAsync(IJobExecutionContext context) + { + _logger.LogInformation(Constants.BypassFiltersEventId, "Execute job task: DeleteDeletedAsync"); + var deleteDate = DateTime.UtcNow.AddDays(-30); + var daysAgoSetting = (_adminSettings?.DeleteTrashDaysAgo).GetValueOrDefault(); + if (daysAgoSetting > 0) + { + deleteDate = DateTime.UtcNow.AddDays(-1 * daysAgoSetting); + } + await _cipherRepository.DeleteDeletedAsync(deleteDate); + _logger.LogInformation(Constants.BypassFiltersEventId, "Finished job task: DeleteDeletedAsync"); } - await _cipherRepository.DeleteDeletedAsync(deleteDate); - _logger.LogInformation(Constants.BypassFiltersEventId, "Finished job task: DeleteDeletedAsync"); } } diff --git a/src/Admin/Jobs/DeleteSendsJob.cs b/src/Admin/Jobs/DeleteSendsJob.cs index 9f3ed96efa..814840fc47 100644 --- a/src/Admin/Jobs/DeleteSendsJob.cs +++ b/src/Admin/Jobs/DeleteSendsJob.cs @@ -4,37 +4,38 @@ using Bit.Core.Repositories; using Bit.Core.Services; using Quartz; -namespace Bit.Admin.Jobs; - -public class DeleteSendsJob : BaseJob +namespace Bit.Admin.Jobs { - private readonly ISendRepository _sendRepository; - private readonly IServiceProvider _serviceProvider; - - public DeleteSendsJob( - ISendRepository sendRepository, - IServiceProvider serviceProvider, - ILogger logger) - : base(logger) + public class DeleteSendsJob : BaseJob { - _sendRepository = sendRepository; - _serviceProvider = serviceProvider; - } + private readonly ISendRepository _sendRepository; + private readonly IServiceProvider _serviceProvider; - protected async override Task ExecuteJobAsync(IJobExecutionContext context) - { - var sends = await _sendRepository.GetManyByDeletionDateAsync(DateTime.UtcNow); - _logger.LogInformation(Constants.BypassFiltersEventId, "Deleting {0} sends.", sends.Count); - if (!sends.Any()) + public DeleteSendsJob( + ISendRepository sendRepository, + IServiceProvider serviceProvider, + ILogger logger) + : base(logger) { - return; + _sendRepository = sendRepository; + _serviceProvider = serviceProvider; } - using (var scope = _serviceProvider.CreateScope()) + + protected async override Task ExecuteJobAsync(IJobExecutionContext context) { - var sendService = scope.ServiceProvider.GetRequiredService(); - foreach (var send in sends) + var sends = await _sendRepository.GetManyByDeletionDateAsync(DateTime.UtcNow); + _logger.LogInformation(Constants.BypassFiltersEventId, "Deleting {0} sends.", sends.Count); + if (!sends.Any()) { - await sendService.DeleteSendAsync(send); + return; + } + using (var scope = _serviceProvider.CreateScope()) + { + var sendService = scope.ServiceProvider.GetRequiredService(); + foreach (var send in sends) + { + await sendService.DeleteSendAsync(send); + } } } } diff --git a/src/Admin/Jobs/JobsHostedService.cs b/src/Admin/Jobs/JobsHostedService.cs index 53b5c05660..01ac66a84d 100644 --- a/src/Admin/Jobs/JobsHostedService.cs +++ b/src/Admin/Jobs/JobsHostedService.cs @@ -3,93 +3,94 @@ using Bit.Core.Jobs; using Bit.Core.Settings; using Quartz; -namespace Bit.Admin.Jobs; - -public class JobsHostedService : BaseJobsHostedService +namespace Bit.Admin.Jobs { - public JobsHostedService( - GlobalSettings globalSettings, - IServiceProvider serviceProvider, - ILogger logger, - ILogger listenerLogger) - : base(globalSettings, serviceProvider, logger, listenerLogger) { } - - public override async Task StartAsync(CancellationToken cancellationToken) + public class JobsHostedService : BaseJobsHostedService { - var timeZone = RuntimeInformation.IsOSPlatform(OSPlatform.Windows) ? - TimeZoneInfo.FindSystemTimeZoneById("Eastern Standard Time") : - TimeZoneInfo.FindSystemTimeZoneById("America/New_York"); - if (_globalSettings.SelfHosted) + public JobsHostedService( + GlobalSettings globalSettings, + IServiceProvider serviceProvider, + ILogger logger, + ILogger listenerLogger) + : base(globalSettings, serviceProvider, logger, listenerLogger) { } + + public override async Task StartAsync(CancellationToken cancellationToken) { - timeZone = TimeZoneInfo.Local; + var timeZone = RuntimeInformation.IsOSPlatform(OSPlatform.Windows) ? + TimeZoneInfo.FindSystemTimeZoneById("Eastern Standard Time") : + TimeZoneInfo.FindSystemTimeZoneById("America/New_York"); + if (_globalSettings.SelfHosted) + { + timeZone = TimeZoneInfo.Local; + } + + var everyTopOfTheHourTrigger = TriggerBuilder.Create() + .WithIdentity("EveryTopOfTheHourTrigger") + .StartNow() + .WithCronSchedule("0 0 * * * ?") + .Build(); + var everyFiveMinutesTrigger = TriggerBuilder.Create() + .WithIdentity("EveryFiveMinutesTrigger") + .StartNow() + .WithCronSchedule("0 */5 * * * ?") + .Build(); + var everyFridayAt10pmTrigger = TriggerBuilder.Create() + .WithIdentity("EveryFridayAt10pmTrigger") + .StartNow() + .WithCronSchedule("0 0 22 ? * FRI", x => x.InTimeZone(timeZone)) + .Build(); + var everySaturdayAtMidnightTrigger = TriggerBuilder.Create() + .WithIdentity("EverySaturdayAtMidnightTrigger") + .StartNow() + .WithCronSchedule("0 0 0 ? * SAT", x => x.InTimeZone(timeZone)) + .Build(); + var everySundayAtMidnightTrigger = TriggerBuilder.Create() + .WithIdentity("EverySundayAtMidnightTrigger") + .StartNow() + .WithCronSchedule("0 0 0 ? * SUN", x => x.InTimeZone(timeZone)) + .Build(); + var everyMondayAtMidnightTrigger = TriggerBuilder.Create() + .WithIdentity("EveryMondayAtMidnightTrigger") + .StartNow() + .WithCronSchedule("0 0 0 ? * MON", x => x.InTimeZone(timeZone)) + .Build(); + var everyDayAtMidnightUtc = TriggerBuilder.Create() + .WithIdentity("EveryDayAtMidnightUtc") + .StartNow() + .WithCronSchedule("0 0 0 * * ?") + .Build(); + + var jobs = new List> + { + new Tuple(typeof(DeleteSendsJob), everyFiveMinutesTrigger), + new Tuple(typeof(DatabaseExpiredGrantsJob), everyFridayAt10pmTrigger), + new Tuple(typeof(DatabaseUpdateStatisticsJob), everySaturdayAtMidnightTrigger), + new Tuple(typeof(DatabaseRebuildlIndexesJob), everySundayAtMidnightTrigger), + new Tuple(typeof(DeleteCiphersJob), everyDayAtMidnightUtc), + new Tuple(typeof(DatabaseExpiredSponsorshipsJob), everyMondayAtMidnightTrigger) + }; + + if (!_globalSettings.SelfHosted) + { + jobs.Add(new Tuple(typeof(AliveJob), everyTopOfTheHourTrigger)); + } + + Jobs = jobs; + await base.StartAsync(cancellationToken); } - var everyTopOfTheHourTrigger = TriggerBuilder.Create() - .WithIdentity("EveryTopOfTheHourTrigger") - .StartNow() - .WithCronSchedule("0 0 * * * ?") - .Build(); - var everyFiveMinutesTrigger = TriggerBuilder.Create() - .WithIdentity("EveryFiveMinutesTrigger") - .StartNow() - .WithCronSchedule("0 */5 * * * ?") - .Build(); - var everyFridayAt10pmTrigger = TriggerBuilder.Create() - .WithIdentity("EveryFridayAt10pmTrigger") - .StartNow() - .WithCronSchedule("0 0 22 ? * FRI", x => x.InTimeZone(timeZone)) - .Build(); - var everySaturdayAtMidnightTrigger = TriggerBuilder.Create() - .WithIdentity("EverySaturdayAtMidnightTrigger") - .StartNow() - .WithCronSchedule("0 0 0 ? * SAT", x => x.InTimeZone(timeZone)) - .Build(); - var everySundayAtMidnightTrigger = TriggerBuilder.Create() - .WithIdentity("EverySundayAtMidnightTrigger") - .StartNow() - .WithCronSchedule("0 0 0 ? * SUN", x => x.InTimeZone(timeZone)) - .Build(); - var everyMondayAtMidnightTrigger = TriggerBuilder.Create() - .WithIdentity("EveryMondayAtMidnightTrigger") - .StartNow() - .WithCronSchedule("0 0 0 ? * MON", x => x.InTimeZone(timeZone)) - .Build(); - var everyDayAtMidnightUtc = TriggerBuilder.Create() - .WithIdentity("EveryDayAtMidnightUtc") - .StartNow() - .WithCronSchedule("0 0 0 * * ?") - .Build(); - - var jobs = new List> + public static void AddJobsServices(IServiceCollection services, bool selfHosted) { - new Tuple(typeof(DeleteSendsJob), everyFiveMinutesTrigger), - new Tuple(typeof(DatabaseExpiredGrantsJob), everyFridayAt10pmTrigger), - new Tuple(typeof(DatabaseUpdateStatisticsJob), everySaturdayAtMidnightTrigger), - new Tuple(typeof(DatabaseRebuildlIndexesJob), everySundayAtMidnightTrigger), - new Tuple(typeof(DeleteCiphersJob), everyDayAtMidnightUtc), - new Tuple(typeof(DatabaseExpiredSponsorshipsJob), everyMondayAtMidnightTrigger) - }; - - if (!_globalSettings.SelfHosted) - { - jobs.Add(new Tuple(typeof(AliveJob), everyTopOfTheHourTrigger)); + if (!selfHosted) + { + services.AddTransient(); + } + services.AddTransient(); + services.AddTransient(); + services.AddTransient(); + services.AddTransient(); + services.AddTransient(); + services.AddTransient(); } - - Jobs = jobs; - await base.StartAsync(cancellationToken); - } - - public static void AddJobsServices(IServiceCollection services, bool selfHosted) - { - if (!selfHosted) - { - services.AddTransient(); - } - services.AddTransient(); - services.AddTransient(); - services.AddTransient(); - services.AddTransient(); - services.AddTransient(); - services.AddTransient(); } } diff --git a/src/Admin/Models/BillingInformationModel.cs b/src/Admin/Models/BillingInformationModel.cs index a90ec7955d..1457a0851d 100644 --- a/src/Admin/Models/BillingInformationModel.cs +++ b/src/Admin/Models/BillingInformationModel.cs @@ -1,10 +1,11 @@ using Bit.Core.Models.Business; -namespace Bit.Admin.Models; - -public class BillingInformationModel +namespace Bit.Admin.Models { - public BillingInfo BillingInfo { get; set; } - public Guid? UserId { get; set; } - public Guid? OrganizationId { get; set; } + public class BillingInformationModel + { + public BillingInfo BillingInfo { get; set; } + public Guid? UserId { get; set; } + public Guid? OrganizationId { get; set; } + } } diff --git a/src/Admin/Models/ChargeBraintreeModel.cs b/src/Admin/Models/ChargeBraintreeModel.cs index 2ba06cb980..b7adba8f13 100644 --- a/src/Admin/Models/ChargeBraintreeModel.cs +++ b/src/Admin/Models/ChargeBraintreeModel.cs @@ -1,26 +1,27 @@ using System.ComponentModel.DataAnnotations; -namespace Bit.Admin.Models; - -public class ChargeBraintreeModel : IValidatableObject +namespace Bit.Admin.Models { - [Required] - [Display(Name = "Braintree Customer Id")] - public string Id { get; set; } - [Required] - [Display(Name = "Amount")] - public decimal? Amount { get; set; } - public string TransactionId { get; set; } - public string PayPalTransactionId { get; set; } - - public IEnumerable Validate(ValidationContext validationContext) + public class ChargeBraintreeModel : IValidatableObject { - if (Id != null) + [Required] + [Display(Name = "Braintree Customer Id")] + public string Id { get; set; } + [Required] + [Display(Name = "Amount")] + public decimal? Amount { get; set; } + public string TransactionId { get; set; } + public string PayPalTransactionId { get; set; } + + public IEnumerable Validate(ValidationContext validationContext) { - if (Id.Length != 36 || (Id[0] != 'o' && Id[0] != 'u') || - !Guid.TryParse(Id.Substring(1, 32), out var guid)) + if (Id != null) { - yield return new ValidationResult("Customer Id is not a valid format."); + if (Id.Length != 36 || (Id[0] != 'o' && Id[0] != 'u') || + !Guid.TryParse(Id.Substring(1, 32), out var guid)) + { + yield return new ValidationResult("Customer Id is not a valid format."); + } } } } diff --git a/src/Admin/Models/CreateProviderModel.cs b/src/Admin/Models/CreateProviderModel.cs index 9bcbf1f75b..582c388af2 100644 --- a/src/Admin/Models/CreateProviderModel.cs +++ b/src/Admin/Models/CreateProviderModel.cs @@ -1,12 +1,13 @@ using System.ComponentModel.DataAnnotations; -namespace Bit.Admin.Models; - -public class CreateProviderModel +namespace Bit.Admin.Models { - public CreateProviderModel() { } + public class CreateProviderModel + { + public CreateProviderModel() { } - [Display(Name = "Owner Email")] - [Required] - public string OwnerEmail { get; set; } + [Display(Name = "Owner Email")] + [Required] + public string OwnerEmail { get; set; } + } } diff --git a/src/Admin/Models/CreateUpdateTransactionModel.cs b/src/Admin/Models/CreateUpdateTransactionModel.cs index 8004546f9e..0ab1f0dc8f 100644 --- a/src/Admin/Models/CreateUpdateTransactionModel.cs +++ b/src/Admin/Models/CreateUpdateTransactionModel.cs @@ -2,76 +2,77 @@ using Bit.Core.Entities; using Bit.Core.Enums; -namespace Bit.Admin.Models; - -public class CreateUpdateTransactionModel : IValidatableObject +namespace Bit.Admin.Models { - public CreateUpdateTransactionModel() { } - - public CreateUpdateTransactionModel(Transaction transaction) + public class CreateUpdateTransactionModel : IValidatableObject { - Edit = true; - UserId = transaction.UserId; - OrganizationId = transaction.OrganizationId; - Amount = transaction.Amount; - RefundedAmount = transaction.RefundedAmount; - Refunded = transaction.Refunded.GetValueOrDefault(); - Details = transaction.Details; - Date = transaction.CreationDate; - PaymentMethod = transaction.PaymentMethodType; - Gateway = transaction.Gateway; - GatewayId = transaction.GatewayId; - Type = transaction.Type; - } + public CreateUpdateTransactionModel() { } - public bool Edit { get; set; } - - [Display(Name = "User Id")] - public Guid? UserId { get; set; } - [Display(Name = "Organization Id")] - public Guid? OrganizationId { get; set; } - [Required] - public decimal? Amount { get; set; } - [Display(Name = "Refunded Amount")] - public decimal? RefundedAmount { get; set; } - public bool Refunded { get; set; } - [Required] - public string Details { get; set; } - [Required] - public DateTime? Date { get; set; } - [Display(Name = "Payment Method")] - public PaymentMethodType? PaymentMethod { get; set; } - public GatewayType? Gateway { get; set; } - [Display(Name = "Gateway Id")] - public string GatewayId { get; set; } - [Required] - public TransactionType? Type { get; set; } - - - public IEnumerable Validate(ValidationContext validationContext) - { - if ((!UserId.HasValue && !OrganizationId.HasValue) || (UserId.HasValue && OrganizationId.HasValue)) + public CreateUpdateTransactionModel(Transaction transaction) { - yield return new ValidationResult("Must provide either User Id, or Organization Id."); + Edit = true; + UserId = transaction.UserId; + OrganizationId = transaction.OrganizationId; + Amount = transaction.Amount; + RefundedAmount = transaction.RefundedAmount; + Refunded = transaction.Refunded.GetValueOrDefault(); + Details = transaction.Details; + Date = transaction.CreationDate; + PaymentMethod = transaction.PaymentMethodType; + Gateway = transaction.Gateway; + GatewayId = transaction.GatewayId; + Type = transaction.Type; + } + + public bool Edit { get; set; } + + [Display(Name = "User Id")] + public Guid? UserId { get; set; } + [Display(Name = "Organization Id")] + public Guid? OrganizationId { get; set; } + [Required] + public decimal? Amount { get; set; } + [Display(Name = "Refunded Amount")] + public decimal? RefundedAmount { get; set; } + public bool Refunded { get; set; } + [Required] + public string Details { get; set; } + [Required] + public DateTime? Date { get; set; } + [Display(Name = "Payment Method")] + public PaymentMethodType? PaymentMethod { get; set; } + public GatewayType? Gateway { get; set; } + [Display(Name = "Gateway Id")] + public string GatewayId { get; set; } + [Required] + public TransactionType? Type { get; set; } + + + public IEnumerable Validate(ValidationContext validationContext) + { + if ((!UserId.HasValue && !OrganizationId.HasValue) || (UserId.HasValue && OrganizationId.HasValue)) + { + yield return new ValidationResult("Must provide either User Id, or Organization Id."); + } + } + + public Transaction ToTransaction(Guid? id = null) + { + return new Transaction + { + Id = id.GetValueOrDefault(), + UserId = UserId, + OrganizationId = OrganizationId, + Amount = Amount.Value, + RefundedAmount = RefundedAmount, + Refunded = Refunded ? true : (bool?)null, + Details = Details, + CreationDate = Date.Value, + PaymentMethodType = PaymentMethod, + Gateway = Gateway, + GatewayId = GatewayId, + Type = Type.Value + }; } } - - public Transaction ToTransaction(Guid? id = null) - { - return new Transaction - { - Id = id.GetValueOrDefault(), - UserId = UserId, - OrganizationId = OrganizationId, - Amount = Amount.Value, - RefundedAmount = RefundedAmount, - Refunded = Refunded ? true : (bool?)null, - Details = Details, - CreationDate = Date.Value, - PaymentMethodType = PaymentMethod, - Gateway = Gateway, - GatewayId = GatewayId, - Type = Type.Value - }; - } } diff --git a/src/Admin/Models/CursorPagedModel.cs b/src/Admin/Models/CursorPagedModel.cs index 35a4de922a..59d13e268f 100644 --- a/src/Admin/Models/CursorPagedModel.cs +++ b/src/Admin/Models/CursorPagedModel.cs @@ -1,9 +1,10 @@ -namespace Bit.Admin.Models; - -public class CursorPagedModel +namespace Bit.Admin.Models { - public List Items { get; set; } - public int Count { get; set; } - public string Cursor { get; set; } - public string NextCursor { get; set; } + public class CursorPagedModel + { + public List Items { get; set; } + public int Count { get; set; } + public string Cursor { get; set; } + public string NextCursor { get; set; } + } } diff --git a/src/Admin/Models/ErrorViewModel.cs b/src/Admin/Models/ErrorViewModel.cs index 3b24a1ece7..7a448776de 100644 --- a/src/Admin/Models/ErrorViewModel.cs +++ b/src/Admin/Models/ErrorViewModel.cs @@ -1,8 +1,9 @@ -namespace Bit.Admin.Models; - -public class ErrorViewModel +namespace Bit.Admin.Models { - public string RequestId { get; set; } + public class ErrorViewModel + { + public string RequestId { get; set; } - public bool ShowRequestId => !string.IsNullOrEmpty(RequestId); + public bool ShowRequestId => !string.IsNullOrEmpty(RequestId); + } } diff --git a/src/Admin/Models/HomeModel.cs b/src/Admin/Models/HomeModel.cs index 900a04e41a..1bdebbe02e 100644 --- a/src/Admin/Models/HomeModel.cs +++ b/src/Admin/Models/HomeModel.cs @@ -1,9 +1,10 @@ using Bit.Core.Settings; -namespace Bit.Admin.Models; - -public class HomeModel +namespace Bit.Admin.Models { - public string CurrentVersion { get; set; } - public GlobalSettings GlobalSettings { get; set; } + public class HomeModel + { + public string CurrentVersion { get; set; } + public GlobalSettings GlobalSettings { get; set; } + } } diff --git a/src/Admin/Models/LicenseModel.cs b/src/Admin/Models/LicenseModel.cs index b0fd912018..47d34ad18f 100644 --- a/src/Admin/Models/LicenseModel.cs +++ b/src/Admin/Models/LicenseModel.cs @@ -1,34 +1,35 @@ using System.ComponentModel.DataAnnotations; -namespace Bit.Admin.Models; - -public class LicenseModel : IValidatableObject +namespace Bit.Admin.Models { - [Display(Name = "User Id")] - public Guid? UserId { get; set; } - [Display(Name = "Organization Id")] - public Guid? OrganizationId { get; set; } - [Display(Name = "Installation Id")] - public Guid? InstallationId { get; set; } - [Required] - [Display(Name = "Version")] - public int Version { get; set; } - - public IEnumerable Validate(ValidationContext validationContext) + public class LicenseModel : IValidatableObject { - if (UserId.HasValue && OrganizationId.HasValue) - { - yield return new ValidationResult("Use either User Id or Organization Id. Not both."); - } + [Display(Name = "User Id")] + public Guid? UserId { get; set; } + [Display(Name = "Organization Id")] + public Guid? OrganizationId { get; set; } + [Display(Name = "Installation Id")] + public Guid? InstallationId { get; set; } + [Required] + [Display(Name = "Version")] + public int Version { get; set; } - if (!UserId.HasValue && !OrganizationId.HasValue) + public IEnumerable Validate(ValidationContext validationContext) { - yield return new ValidationResult("User Id or Organization Id is required."); - } + if (UserId.HasValue && OrganizationId.HasValue) + { + yield return new ValidationResult("Use either User Id or Organization Id. Not both."); + } - if (OrganizationId.HasValue && !InstallationId.HasValue) - { - yield return new ValidationResult("Installation Id is required for organization licenses."); + if (!UserId.HasValue && !OrganizationId.HasValue) + { + yield return new ValidationResult("User Id or Organization Id is required."); + } + + if (OrganizationId.HasValue && !InstallationId.HasValue) + { + yield return new ValidationResult("Installation Id is required for organization licenses."); + } } } } diff --git a/src/Admin/Models/LogModel.cs b/src/Admin/Models/LogModel.cs index 8967025d12..3e0437998a 100644 --- a/src/Admin/Models/LogModel.cs +++ b/src/Admin/Models/LogModel.cs @@ -1,54 +1,55 @@ using Microsoft.Azure.Documents; using Newtonsoft.Json.Linq; -namespace Bit.Admin.Models; - -public class LogModel : Resource +namespace Bit.Admin.Models { - public long EventIdHash { get; set; } - public string Level { get; set; } - public string Message { get; set; } - public string MessageTruncated => Message.Length > 200 ? $"{Message.Substring(0, 200)}..." : Message; - public string MessageTemplate { get; set; } - public IDictionary Properties { get; set; } - public string Project => Properties?.ContainsKey("Project") ?? false ? Properties["Project"].ToString() : null; -} - -public class LogDetailsModel : LogModel -{ - public JObject Exception { get; set; } - - public string ExceptionToString(JObject e) + public class LogModel : Resource { - if (e == null) - { - return null; - } + public long EventIdHash { get; set; } + public string Level { get; set; } + public string Message { get; set; } + public string MessageTruncated => Message.Length > 200 ? $"{Message.Substring(0, 200)}..." : Message; + public string MessageTemplate { get; set; } + public IDictionary Properties { get; set; } + public string Project => Properties?.ContainsKey("Project") ?? false ? Properties["Project"].ToString() : null; + } - var val = string.Empty; - if (e["Message"] != null && e["Message"].ToObject() != null) - { - val += "Message:\n"; - val += e["Message"] + "\n"; - } + public class LogDetailsModel : LogModel + { + public JObject Exception { get; set; } - if (e["StackTrace"] != null && e["StackTrace"].ToObject() != null) + public string ExceptionToString(JObject e) { - val += "\nStack Trace:\n"; - val += e["StackTrace"]; - } - else if (e["StackTraceString"] != null && e["StackTraceString"].ToObject() != null) - { - val += "\nStack Trace String:\n"; - val += e["StackTraceString"]; - } + if (e == null) + { + return null; + } - if (e["InnerException"] != null && e["InnerException"].ToObject() != null) - { - val += "\n\n=== Inner Exception ===\n\n"; - val += ExceptionToString(e["InnerException"].ToObject()); - } + var val = string.Empty; + if (e["Message"] != null && e["Message"].ToObject() != null) + { + val += "Message:\n"; + val += e["Message"] + "\n"; + } - return val; + if (e["StackTrace"] != null && e["StackTrace"].ToObject() != null) + { + val += "\nStack Trace:\n"; + val += e["StackTrace"]; + } + else if (e["StackTraceString"] != null && e["StackTraceString"].ToObject() != null) + { + val += "\nStack Trace String:\n"; + val += e["StackTraceString"]; + } + + if (e["InnerException"] != null && e["InnerException"].ToObject() != null) + { + val += "\n\n=== Inner Exception ===\n\n"; + val += ExceptionToString(e["InnerException"].ToObject()); + } + + return val; + } } } diff --git a/src/Admin/Models/LoginModel.cs b/src/Admin/Models/LoginModel.cs index 7f147874bb..fa77ddfe11 100644 --- a/src/Admin/Models/LoginModel.cs +++ b/src/Admin/Models/LoginModel.cs @@ -1,13 +1,14 @@ using System.ComponentModel.DataAnnotations; -namespace Bit.Admin.Models; - -public class LoginModel +namespace Bit.Admin.Models { - [Required] - [EmailAddress] - public string Email { get; set; } - public string ReturnUrl { get; set; } - public string Error { get; set; } - public string Success { get; set; } + public class LoginModel + { + [Required] + [EmailAddress] + public string Email { get; set; } + public string ReturnUrl { get; set; } + public string Error { get; set; } + public string Success { get; set; } + } } diff --git a/src/Admin/Models/LogsModel.cs b/src/Admin/Models/LogsModel.cs index c5527a3191..d274aa9be3 100644 --- a/src/Admin/Models/LogsModel.cs +++ b/src/Admin/Models/LogsModel.cs @@ -1,11 +1,12 @@ using Serilog.Events; -namespace Bit.Admin.Models; - -public class LogsModel : CursorPagedModel +namespace Bit.Admin.Models { - public LogEventLevel? Level { get; set; } - public string Project { get; set; } - public DateTime? Start { get; set; } - public DateTime? End { get; set; } + public class LogsModel : CursorPagedModel + { + public LogEventLevel? Level { get; set; } + public string Project { get; set; } + public DateTime? Start { get; set; } + public DateTime? End { get; set; } + } } diff --git a/src/Admin/Models/OrganizationEditModel.cs b/src/Admin/Models/OrganizationEditModel.cs index 4a6fdde5ee..bf0d6c8d5c 100644 --- a/src/Admin/Models/OrganizationEditModel.cs +++ b/src/Admin/Models/OrganizationEditModel.cs @@ -6,147 +6,148 @@ using Bit.Core.Models.Data.Organizations.OrganizationUsers; using Bit.Core.Settings; using Bit.Core.Utilities; -namespace Bit.Admin.Models; - -public class OrganizationEditModel : OrganizationViewModel +namespace Bit.Admin.Models { - public OrganizationEditModel() { } - - public OrganizationEditModel(Organization org, IEnumerable orgUsers, - IEnumerable ciphers, IEnumerable collections, IEnumerable groups, - IEnumerable policies, BillingInfo billingInfo, IEnumerable connections, - GlobalSettings globalSettings) - : base(org, connections, orgUsers, ciphers, collections, groups, policies) + public class OrganizationEditModel : OrganizationViewModel { - BillingInfo = billingInfo; - BraintreeMerchantId = globalSettings.Braintree.MerchantId; + public OrganizationEditModel() { } - Name = org.Name; - BusinessName = org.BusinessName; - BillingEmail = org.BillingEmail; - PlanType = org.PlanType; - Plan = org.Plan; - Seats = org.Seats; - MaxAutoscaleSeats = org.MaxAutoscaleSeats; - MaxCollections = org.MaxCollections; - UsePolicies = org.UsePolicies; - UseSso = org.UseSso; - UseKeyConnector = org.UseKeyConnector; - UseScim = org.UseScim; - UseGroups = org.UseGroups; - UseDirectory = org.UseDirectory; - UseEvents = org.UseEvents; - UseTotp = org.UseTotp; - Use2fa = org.Use2fa; - UseApi = org.UseApi; - UseResetPassword = org.UseResetPassword; - SelfHost = org.SelfHost; - UsersGetPremium = org.UsersGetPremium; - MaxStorageGb = org.MaxStorageGb; - Gateway = org.Gateway; - GatewayCustomerId = org.GatewayCustomerId; - GatewaySubscriptionId = org.GatewaySubscriptionId; - Enabled = org.Enabled; - LicenseKey = org.LicenseKey; - ExpirationDate = org.ExpirationDate; - } + public OrganizationEditModel(Organization org, IEnumerable orgUsers, + IEnumerable ciphers, IEnumerable collections, IEnumerable groups, + IEnumerable policies, BillingInfo billingInfo, IEnumerable connections, + GlobalSettings globalSettings) + : base(org, connections, orgUsers, ciphers, collections, groups, policies) + { + BillingInfo = billingInfo; + BraintreeMerchantId = globalSettings.Braintree.MerchantId; - public BillingInfo BillingInfo { get; set; } - public string RandomLicenseKey => CoreHelpers.SecureRandomString(20); - public string FourteenDayExpirationDate => DateTime.Now.AddDays(14).ToString("yyyy-MM-ddTHH:mm"); - public string BraintreeMerchantId { get; set; } + Name = org.Name; + BusinessName = org.BusinessName; + BillingEmail = org.BillingEmail; + PlanType = org.PlanType; + Plan = org.Plan; + Seats = org.Seats; + MaxAutoscaleSeats = org.MaxAutoscaleSeats; + MaxCollections = org.MaxCollections; + UsePolicies = org.UsePolicies; + UseSso = org.UseSso; + UseKeyConnector = org.UseKeyConnector; + UseScim = org.UseScim; + UseGroups = org.UseGroups; + UseDirectory = org.UseDirectory; + UseEvents = org.UseEvents; + UseTotp = org.UseTotp; + Use2fa = org.Use2fa; + UseApi = org.UseApi; + UseResetPassword = org.UseResetPassword; + SelfHost = org.SelfHost; + UsersGetPremium = org.UsersGetPremium; + MaxStorageGb = org.MaxStorageGb; + Gateway = org.Gateway; + GatewayCustomerId = org.GatewayCustomerId; + GatewaySubscriptionId = org.GatewaySubscriptionId; + Enabled = org.Enabled; + LicenseKey = org.LicenseKey; + ExpirationDate = org.ExpirationDate; + } - [Required] - [Display(Name = "Name")] - public string Name { get; set; } - [Display(Name = "Business Name")] - public string BusinessName { get; set; } - [Display(Name = "Billing Email")] - public string BillingEmail { get; set; } - [Required] - [Display(Name = "Plan")] - public PlanType? PlanType { get; set; } - [Required] - [Display(Name = "Plan Name")] - public string Plan { get; set; } - [Display(Name = "Seats")] - public int? Seats { get; set; } - [Display(Name = "Max. Autoscale Seats")] - public int? MaxAutoscaleSeats { get; set; } - [Display(Name = "Max. Collections")] - public short? MaxCollections { get; set; } - [Display(Name = "Policies")] - public bool UsePolicies { get; set; } - [Display(Name = "SSO")] - public bool UseSso { get; set; } - [Display(Name = "Key Connector with Customer Encryption")] - public bool UseKeyConnector { get; set; } - [Display(Name = "Groups")] - public bool UseGroups { get; set; } - [Display(Name = "Directory")] - public bool UseDirectory { get; set; } - [Display(Name = "Events")] - public bool UseEvents { get; set; } - [Display(Name = "TOTP")] - public bool UseTotp { get; set; } - [Display(Name = "2FA")] - public bool Use2fa { get; set; } - [Display(Name = "API")] - public bool UseApi { get; set; } - [Display(Name = "Reset Password")] - public bool UseResetPassword { get; set; } - [Display(Name = "SCIM")] - public bool UseScim { get; set; } - [Display(Name = "Self Host")] - public bool SelfHost { get; set; } - [Display(Name = "Users Get Premium")] - public bool UsersGetPremium { get; set; } - [Display(Name = "Max. Storage GB")] - public short? MaxStorageGb { get; set; } - [Display(Name = "Gateway")] - public GatewayType? Gateway { get; set; } - [Display(Name = "Gateway Customer Id")] - public string GatewayCustomerId { get; set; } - [Display(Name = "Gateway Subscription Id")] - public string GatewaySubscriptionId { get; set; } - [Display(Name = "Enabled")] - public bool Enabled { get; set; } - [Display(Name = "License Key")] - public string LicenseKey { get; set; } - [Display(Name = "Expiration Date")] - public DateTime? ExpirationDate { get; set; } - public bool SalesAssistedTrialStarted { get; set; } + public BillingInfo BillingInfo { get; set; } + public string RandomLicenseKey => CoreHelpers.SecureRandomString(20); + public string FourteenDayExpirationDate => DateTime.Now.AddDays(14).ToString("yyyy-MM-ddTHH:mm"); + public string BraintreeMerchantId { get; set; } - public Organization ToOrganization(Organization existingOrganization) - { - existingOrganization.Name = Name; - existingOrganization.BusinessName = BusinessName; - existingOrganization.BillingEmail = BillingEmail?.ToLowerInvariant()?.Trim(); - existingOrganization.PlanType = PlanType.Value; - existingOrganization.Plan = Plan; - existingOrganization.Seats = Seats; - existingOrganization.MaxCollections = MaxCollections; - existingOrganization.UsePolicies = UsePolicies; - existingOrganization.UseSso = UseSso; - existingOrganization.UseKeyConnector = UseKeyConnector; - existingOrganization.UseScim = UseScim; - existingOrganization.UseGroups = UseGroups; - existingOrganization.UseDirectory = UseDirectory; - existingOrganization.UseEvents = UseEvents; - existingOrganization.UseTotp = UseTotp; - existingOrganization.Use2fa = Use2fa; - existingOrganization.UseApi = UseApi; - existingOrganization.UseResetPassword = UseResetPassword; - existingOrganization.SelfHost = SelfHost; - existingOrganization.UsersGetPremium = UsersGetPremium; - existingOrganization.MaxStorageGb = MaxStorageGb; - existingOrganization.Gateway = Gateway; - existingOrganization.GatewayCustomerId = GatewayCustomerId; - existingOrganization.GatewaySubscriptionId = GatewaySubscriptionId; - existingOrganization.Enabled = Enabled; - existingOrganization.LicenseKey = LicenseKey; - existingOrganization.ExpirationDate = ExpirationDate; - existingOrganization.MaxAutoscaleSeats = MaxAutoscaleSeats; - return existingOrganization; + [Required] + [Display(Name = "Name")] + public string Name { get; set; } + [Display(Name = "Business Name")] + public string BusinessName { get; set; } + [Display(Name = "Billing Email")] + public string BillingEmail { get; set; } + [Required] + [Display(Name = "Plan")] + public PlanType? PlanType { get; set; } + [Required] + [Display(Name = "Plan Name")] + public string Plan { get; set; } + [Display(Name = "Seats")] + public int? Seats { get; set; } + [Display(Name = "Max. Autoscale Seats")] + public int? MaxAutoscaleSeats { get; set; } + [Display(Name = "Max. Collections")] + public short? MaxCollections { get; set; } + [Display(Name = "Policies")] + public bool UsePolicies { get; set; } + [Display(Name = "SSO")] + public bool UseSso { get; set; } + [Display(Name = "Key Connector with Customer Encryption")] + public bool UseKeyConnector { get; set; } + [Display(Name = "Groups")] + public bool UseGroups { get; set; } + [Display(Name = "Directory")] + public bool UseDirectory { get; set; } + [Display(Name = "Events")] + public bool UseEvents { get; set; } + [Display(Name = "TOTP")] + public bool UseTotp { get; set; } + [Display(Name = "2FA")] + public bool Use2fa { get; set; } + [Display(Name = "API")] + public bool UseApi { get; set; } + [Display(Name = "Reset Password")] + public bool UseResetPassword { get; set; } + [Display(Name = "SCIM")] + public bool UseScim { get; set; } + [Display(Name = "Self Host")] + public bool SelfHost { get; set; } + [Display(Name = "Users Get Premium")] + public bool UsersGetPremium { get; set; } + [Display(Name = "Max. Storage GB")] + public short? MaxStorageGb { get; set; } + [Display(Name = "Gateway")] + public GatewayType? Gateway { get; set; } + [Display(Name = "Gateway Customer Id")] + public string GatewayCustomerId { get; set; } + [Display(Name = "Gateway Subscription Id")] + public string GatewaySubscriptionId { get; set; } + [Display(Name = "Enabled")] + public bool Enabled { get; set; } + [Display(Name = "License Key")] + public string LicenseKey { get; set; } + [Display(Name = "Expiration Date")] + public DateTime? ExpirationDate { get; set; } + public bool SalesAssistedTrialStarted { get; set; } + + public Organization ToOrganization(Organization existingOrganization) + { + existingOrganization.Name = Name; + existingOrganization.BusinessName = BusinessName; + existingOrganization.BillingEmail = BillingEmail?.ToLowerInvariant()?.Trim(); + existingOrganization.PlanType = PlanType.Value; + existingOrganization.Plan = Plan; + existingOrganization.Seats = Seats; + existingOrganization.MaxCollections = MaxCollections; + existingOrganization.UsePolicies = UsePolicies; + existingOrganization.UseSso = UseSso; + existingOrganization.UseKeyConnector = UseKeyConnector; + existingOrganization.UseScim = UseScim; + existingOrganization.UseGroups = UseGroups; + existingOrganization.UseDirectory = UseDirectory; + existingOrganization.UseEvents = UseEvents; + existingOrganization.UseTotp = UseTotp; + existingOrganization.Use2fa = Use2fa; + existingOrganization.UseApi = UseApi; + existingOrganization.UseResetPassword = UseResetPassword; + existingOrganization.SelfHost = SelfHost; + existingOrganization.UsersGetPremium = UsersGetPremium; + existingOrganization.MaxStorageGb = MaxStorageGb; + existingOrganization.Gateway = Gateway; + existingOrganization.GatewayCustomerId = GatewayCustomerId; + existingOrganization.GatewaySubscriptionId = GatewaySubscriptionId; + existingOrganization.Enabled = Enabled; + existingOrganization.LicenseKey = LicenseKey; + existingOrganization.ExpirationDate = ExpirationDate; + existingOrganization.MaxAutoscaleSeats = MaxAutoscaleSeats; + return existingOrganization; + } } } diff --git a/src/Admin/Models/OrganizationViewModel.cs b/src/Admin/Models/OrganizationViewModel.cs index 5a487cd03e..c17f273a7f 100644 --- a/src/Admin/Models/OrganizationViewModel.cs +++ b/src/Admin/Models/OrganizationViewModel.cs @@ -2,48 +2,49 @@ using Bit.Core.Enums; using Bit.Core.Models.Data.Organizations.OrganizationUsers; -namespace Bit.Admin.Models; - -public class OrganizationViewModel +namespace Bit.Admin.Models { - public OrganizationViewModel() { } - - public OrganizationViewModel(Organization org, IEnumerable connections, - IEnumerable orgUsers, IEnumerable ciphers, IEnumerable collections, - IEnumerable groups, IEnumerable policies) + public class OrganizationViewModel { - Organization = org; - Connections = connections ?? Enumerable.Empty(); - HasPublicPrivateKeys = org.PublicKey != null && org.PrivateKey != null; - UserInvitedCount = orgUsers.Count(u => u.Status == OrganizationUserStatusType.Invited); - UserAcceptedCount = orgUsers.Count(u => u.Status == OrganizationUserStatusType.Accepted); - UserConfirmedCount = orgUsers.Count(u => u.Status == OrganizationUserStatusType.Confirmed); - UserCount = orgUsers.Count(); - CipherCount = ciphers.Count(); - CollectionCount = collections.Count(); - GroupCount = groups?.Count() ?? 0; - PolicyCount = policies?.Count() ?? 0; - Owners = string.Join(", ", - orgUsers - .Where(u => u.Type == OrganizationUserType.Owner && u.Status == OrganizationUserStatusType.Confirmed) - .Select(u => u.Email)); - Admins = string.Join(", ", - orgUsers - .Where(u => u.Type == OrganizationUserType.Admin && u.Status == OrganizationUserStatusType.Confirmed) - .Select(u => u.Email)); - } + public OrganizationViewModel() { } - public Organization Organization { get; set; } - public IEnumerable Connections { get; set; } - public string Owners { get; set; } - public string Admins { get; set; } - public int UserInvitedCount { get; set; } - public int UserConfirmedCount { get; set; } - public int UserAcceptedCount { get; set; } - public int UserCount { get; set; } - public int CipherCount { get; set; } - public int CollectionCount { get; set; } - public int GroupCount { get; set; } - public int PolicyCount { get; set; } - public bool HasPublicPrivateKeys { get; set; } + public OrganizationViewModel(Organization org, IEnumerable connections, + IEnumerable orgUsers, IEnumerable ciphers, IEnumerable collections, + IEnumerable groups, IEnumerable policies) + { + Organization = org; + Connections = connections ?? Enumerable.Empty(); + HasPublicPrivateKeys = org.PublicKey != null && org.PrivateKey != null; + UserInvitedCount = orgUsers.Count(u => u.Status == OrganizationUserStatusType.Invited); + UserAcceptedCount = orgUsers.Count(u => u.Status == OrganizationUserStatusType.Accepted); + UserConfirmedCount = orgUsers.Count(u => u.Status == OrganizationUserStatusType.Confirmed); + UserCount = orgUsers.Count(); + CipherCount = ciphers.Count(); + CollectionCount = collections.Count(); + GroupCount = groups?.Count() ?? 0; + PolicyCount = policies?.Count() ?? 0; + Owners = string.Join(", ", + orgUsers + .Where(u => u.Type == OrganizationUserType.Owner && u.Status == OrganizationUserStatusType.Confirmed) + .Select(u => u.Email)); + Admins = string.Join(", ", + orgUsers + .Where(u => u.Type == OrganizationUserType.Admin && u.Status == OrganizationUserStatusType.Confirmed) + .Select(u => u.Email)); + } + + public Organization Organization { get; set; } + public IEnumerable Connections { get; set; } + public string Owners { get; set; } + public string Admins { get; set; } + public int UserInvitedCount { get; set; } + public int UserConfirmedCount { get; set; } + public int UserAcceptedCount { get; set; } + public int UserCount { get; set; } + public int CipherCount { get; set; } + public int CollectionCount { get; set; } + public int GroupCount { get; set; } + public int PolicyCount { get; set; } + public bool HasPublicPrivateKeys { get; set; } + } } diff --git a/src/Admin/Models/OrganizationsModel.cs b/src/Admin/Models/OrganizationsModel.cs index 706377f8e2..da2eb20d6c 100644 --- a/src/Admin/Models/OrganizationsModel.cs +++ b/src/Admin/Models/OrganizationsModel.cs @@ -1,12 +1,13 @@ using Bit.Core.Entities; -namespace Bit.Admin.Models; - -public class OrganizationsModel : PagedModel +namespace Bit.Admin.Models { - public string Name { get; set; } - public string UserEmail { get; set; } - public bool? Paid { get; set; } - public string Action { get; set; } - public bool SelfHosted { get; set; } + public class OrganizationsModel : PagedModel + { + public string Name { get; set; } + public string UserEmail { get; set; } + public bool? Paid { get; set; } + public string Action { get; set; } + public bool SelfHosted { get; set; } + } } diff --git a/src/Admin/Models/PagedModel.cs b/src/Admin/Models/PagedModel.cs index 4c9c8e1713..ac4f2e84db 100644 --- a/src/Admin/Models/PagedModel.cs +++ b/src/Admin/Models/PagedModel.cs @@ -1,10 +1,11 @@ -namespace Bit.Admin.Models; - -public abstract class PagedModel +namespace Bit.Admin.Models { - public List Items { get; set; } - public int Page { get; set; } - public int Count { get; set; } - public int? PreviousPage => Page < 2 ? (int?)null : Page - 1; - public int? NextPage => Items.Count < Count ? (int?)null : Page + 1; + public abstract class PagedModel + { + public List Items { get; set; } + public int Page { get; set; } + public int Count { get; set; } + public int? PreviousPage => Page < 2 ? (int?)null : Page - 1; + public int? NextPage => Items.Count < Count ? (int?)null : Page + 1; + } } diff --git a/src/Admin/Models/PromoteAdminModel.cs b/src/Admin/Models/PromoteAdminModel.cs index bc076d6ab1..0beae6bd8b 100644 --- a/src/Admin/Models/PromoteAdminModel.cs +++ b/src/Admin/Models/PromoteAdminModel.cs @@ -1,13 +1,14 @@ using System.ComponentModel.DataAnnotations; -namespace Bit.Admin.Models; - -public class PromoteAdminModel +namespace Bit.Admin.Models { - [Required] - [Display(Name = "Admin User Id")] - public Guid? UserId { get; set; } - [Required] - [Display(Name = "Organization Id")] - public Guid? OrganizationId { get; set; } + public class PromoteAdminModel + { + [Required] + [Display(Name = "Admin User Id")] + public Guid? UserId { get; set; } + [Required] + [Display(Name = "Organization Id")] + public Guid? OrganizationId { get; set; } + } } diff --git a/src/Admin/Models/ProviderEditModel.cs b/src/Admin/Models/ProviderEditModel.cs index 92b2f89e98..578d0ff226 100644 --- a/src/Admin/Models/ProviderEditModel.cs +++ b/src/Admin/Models/ProviderEditModel.cs @@ -2,32 +2,33 @@ using Bit.Core.Entities.Provider; using Bit.Core.Models.Data; -namespace Bit.Admin.Models; - -public class ProviderEditModel : ProviderViewModel +namespace Bit.Admin.Models { - public ProviderEditModel() { } - - public ProviderEditModel(Provider provider, IEnumerable providerUsers, IEnumerable organizations) - : base(provider, providerUsers, organizations) + public class ProviderEditModel : ProviderViewModel { - Name = provider.Name; - BusinessName = provider.BusinessName; - BillingEmail = provider.BillingEmail; - } + public ProviderEditModel() { } - [Display(Name = "Billing Email")] - public string BillingEmail { get; set; } - [Display(Name = "Business Name")] - public string BusinessName { get; set; } - public string Name { get; set; } - [Display(Name = "Events")] + public ProviderEditModel(Provider provider, IEnumerable providerUsers, IEnumerable organizations) + : base(provider, providerUsers, organizations) + { + Name = provider.Name; + BusinessName = provider.BusinessName; + BillingEmail = provider.BillingEmail; + } - public Provider ToProvider(Provider existingProvider) - { - existingProvider.Name = Name; - existingProvider.BusinessName = BusinessName; - existingProvider.BillingEmail = BillingEmail?.ToLowerInvariant()?.Trim(); - return existingProvider; + [Display(Name = "Billing Email")] + public string BillingEmail { get; set; } + [Display(Name = "Business Name")] + public string BusinessName { get; set; } + public string Name { get; set; } + [Display(Name = "Events")] + + public Provider ToProvider(Provider existingProvider) + { + existingProvider.Name = Name; + existingProvider.BusinessName = BusinessName; + existingProvider.BillingEmail = BillingEmail?.ToLowerInvariant()?.Trim(); + return existingProvider; + } } } diff --git a/src/Admin/Models/ProviderViewModel.cs b/src/Admin/Models/ProviderViewModel.cs index 766101e884..05fae3c9c1 100644 --- a/src/Admin/Models/ProviderViewModel.cs +++ b/src/Admin/Models/ProviderViewModel.cs @@ -2,23 +2,24 @@ using Bit.Core.Enums.Provider; using Bit.Core.Models.Data; -namespace Bit.Admin.Models; - -public class ProviderViewModel +namespace Bit.Admin.Models { - public ProviderViewModel() { } - - public ProviderViewModel(Provider provider, IEnumerable providerUsers, IEnumerable organizations) + public class ProviderViewModel { - Provider = provider; - UserCount = providerUsers.Count(); - ProviderAdmins = providerUsers.Where(u => u.Type == ProviderUserType.ProviderAdmin); + public ProviderViewModel() { } - ProviderOrganizations = organizations.Where(o => o.ProviderId == provider.Id); + public ProviderViewModel(Provider provider, IEnumerable providerUsers, IEnumerable organizations) + { + Provider = provider; + UserCount = providerUsers.Count(); + ProviderAdmins = providerUsers.Where(u => u.Type == ProviderUserType.ProviderAdmin); + + ProviderOrganizations = organizations.Where(o => o.ProviderId == provider.Id); + } + + public int UserCount { get; set; } + public Provider Provider { get; set; } + public IEnumerable ProviderAdmins { get; set; } + public IEnumerable ProviderOrganizations { get; set; } } - - public int UserCount { get; set; } - public Provider Provider { get; set; } - public IEnumerable ProviderAdmins { get; set; } - public IEnumerable ProviderOrganizations { get; set; } } diff --git a/src/Admin/Models/ProvidersModel.cs b/src/Admin/Models/ProvidersModel.cs index dccf4a4d76..02509593d9 100644 --- a/src/Admin/Models/ProvidersModel.cs +++ b/src/Admin/Models/ProvidersModel.cs @@ -1,12 +1,13 @@ using Bit.Core.Entities.Provider; -namespace Bit.Admin.Models; - -public class ProvidersModel : PagedModel +namespace Bit.Admin.Models { - public string Name { get; set; } - public string UserEmail { get; set; } - public bool? Paid { get; set; } - public string Action { get; set; } - public bool SelfHosted { get; set; } + public class ProvidersModel : PagedModel + { + public string Name { get; set; } + public string UserEmail { get; set; } + public bool? Paid { get; set; } + public string Action { get; set; } + public bool SelfHosted { get; set; } + } } diff --git a/src/Admin/Models/StripeSubscriptionsModel.cs b/src/Admin/Models/StripeSubscriptionsModel.cs index 99e9c5b77a..3e30d63d50 100644 --- a/src/Admin/Models/StripeSubscriptionsModel.cs +++ b/src/Admin/Models/StripeSubscriptionsModel.cs @@ -1,42 +1,43 @@ using System.ComponentModel.DataAnnotations; using Bit.Core.Models.BitStripe; -namespace Bit.Admin.Models; - -public class StripeSubscriptionRowModel +namespace Bit.Admin.Models { - public Stripe.Subscription Subscription { get; set; } - public bool Selected { get; set; } - - public StripeSubscriptionRowModel() { } - public StripeSubscriptionRowModel(Stripe.Subscription subscription) + public class StripeSubscriptionRowModel { - Subscription = subscription; - } -} + public Stripe.Subscription Subscription { get; set; } + public bool Selected { get; set; } -public enum StripeSubscriptionsAction -{ - Search, - PreviousPage, - NextPage, - Export, - BulkCancel -} - -public class StripeSubscriptionsModel : IValidatableObject -{ - public List Items { get; set; } - public StripeSubscriptionsAction Action { get; set; } = StripeSubscriptionsAction.Search; - public string Message { get; set; } - public List Prices { get; set; } - public List TestClocks { get; set; } - public StripeSubscriptionListOptions Filter { get; set; } = new StripeSubscriptionListOptions(); - public IEnumerable Validate(ValidationContext validationContext) - { - if (Action == StripeSubscriptionsAction.BulkCancel && Filter.Status != "unpaid") + public StripeSubscriptionRowModel() { } + public StripeSubscriptionRowModel(Stripe.Subscription subscription) { - yield return new ValidationResult("Bulk cancel is currently only supported for unpaid subscriptions"); + Subscription = subscription; + } + } + + public enum StripeSubscriptionsAction + { + Search, + PreviousPage, + NextPage, + Export, + BulkCancel + } + + public class StripeSubscriptionsModel : IValidatableObject + { + public List Items { get; set; } + public StripeSubscriptionsAction Action { get; set; } = StripeSubscriptionsAction.Search; + public string Message { get; set; } + public List Prices { get; set; } + public List TestClocks { get; set; } + public StripeSubscriptionListOptions Filter { get; set; } = new StripeSubscriptionListOptions(); + public IEnumerable Validate(ValidationContext validationContext) + { + if (Action == StripeSubscriptionsAction.BulkCancel && Filter.Status != "unpaid") + { + yield return new ValidationResult("Bulk cancel is currently only supported for unpaid subscriptions"); + } } } } diff --git a/src/Admin/Models/TaxRateAddEditModel.cs b/src/Admin/Models/TaxRateAddEditModel.cs index bfa87d7cc8..e55ec87c68 100644 --- a/src/Admin/Models/TaxRateAddEditModel.cs +++ b/src/Admin/Models/TaxRateAddEditModel.cs @@ -1,10 +1,11 @@ -namespace Bit.Admin.Models; - -public class TaxRateAddEditModel +namespace Bit.Admin.Models { - public string StripeTaxRateId { get; set; } - public string Country { get; set; } - public string State { get; set; } - public string PostalCode { get; set; } - public decimal Rate { get; set; } + public class TaxRateAddEditModel + { + public string StripeTaxRateId { get; set; } + public string Country { get; set; } + public string State { get; set; } + public string PostalCode { get; set; } + public decimal Rate { get; set; } + } } diff --git a/src/Admin/Models/TaxRatesModel.cs b/src/Admin/Models/TaxRatesModel.cs index 0af073f384..92564d82f7 100644 --- a/src/Admin/Models/TaxRatesModel.cs +++ b/src/Admin/Models/TaxRatesModel.cs @@ -1,8 +1,9 @@ using Bit.Core.Entities; -namespace Bit.Admin.Models; - -public class TaxRatesModel : PagedModel +namespace Bit.Admin.Models { - public string Message { get; set; } + public class TaxRatesModel : PagedModel + { + public string Message { get; set; } + } } diff --git a/src/Admin/Models/UserEditModel.cs b/src/Admin/Models/UserEditModel.cs index d7ef56f085..5b789c73d9 100644 --- a/src/Admin/Models/UserEditModel.cs +++ b/src/Admin/Models/UserEditModel.cs @@ -4,70 +4,71 @@ using Bit.Core.Models.Business; using Bit.Core.Settings; using Bit.Core.Utilities; -namespace Bit.Admin.Models; - -public class UserEditModel : UserViewModel +namespace Bit.Admin.Models { - public UserEditModel() { } - - public UserEditModel(User user, IEnumerable ciphers, BillingInfo billingInfo, - GlobalSettings globalSettings) - : base(user, ciphers) + public class UserEditModel : UserViewModel { - BillingInfo = billingInfo; - BraintreeMerchantId = globalSettings.Braintree.MerchantId; + public UserEditModel() { } - Name = user.Name; - Email = user.Email; - EmailVerified = user.EmailVerified; - Premium = user.Premium; - MaxStorageGb = user.MaxStorageGb; - Gateway = user.Gateway; - GatewayCustomerId = user.GatewayCustomerId; - GatewaySubscriptionId = user.GatewaySubscriptionId; - LicenseKey = user.LicenseKey; - PremiumExpirationDate = user.PremiumExpirationDate; - } + public UserEditModel(User user, IEnumerable ciphers, BillingInfo billingInfo, + GlobalSettings globalSettings) + : base(user, ciphers) + { + BillingInfo = billingInfo; + BraintreeMerchantId = globalSettings.Braintree.MerchantId; - public BillingInfo BillingInfo { get; set; } - public string RandomLicenseKey => CoreHelpers.SecureRandomString(20); - public string OneYearExpirationDate => DateTime.Now.AddYears(1).ToString("yyyy-MM-ddTHH:mm"); - public string BraintreeMerchantId { get; set; } + Name = user.Name; + Email = user.Email; + EmailVerified = user.EmailVerified; + Premium = user.Premium; + MaxStorageGb = user.MaxStorageGb; + Gateway = user.Gateway; + GatewayCustomerId = user.GatewayCustomerId; + GatewaySubscriptionId = user.GatewaySubscriptionId; + LicenseKey = user.LicenseKey; + PremiumExpirationDate = user.PremiumExpirationDate; + } - [Display(Name = "Name")] - public string Name { get; set; } - [Required] - [Display(Name = "Email")] - public string Email { get; set; } - [Display(Name = "Email Verified")] - public bool EmailVerified { get; set; } - [Display(Name = "Premium")] - public bool Premium { get; set; } - [Display(Name = "Max. Storage GB")] - public short? MaxStorageGb { get; set; } - [Display(Name = "Gateway")] - public Core.Enums.GatewayType? Gateway { get; set; } - [Display(Name = "Gateway Customer Id")] - public string GatewayCustomerId { get; set; } - [Display(Name = "Gateway Subscription Id")] - public string GatewaySubscriptionId { get; set; } - [Display(Name = "License Key")] - public string LicenseKey { get; set; } - [Display(Name = "Premium Expiration Date")] - public DateTime? PremiumExpirationDate { get; set; } + public BillingInfo BillingInfo { get; set; } + public string RandomLicenseKey => CoreHelpers.SecureRandomString(20); + public string OneYearExpirationDate => DateTime.Now.AddYears(1).ToString("yyyy-MM-ddTHH:mm"); + public string BraintreeMerchantId { get; set; } - public User ToUser(User existingUser) - { - existingUser.Name = Name; - existingUser.Email = Email; - existingUser.EmailVerified = EmailVerified; - existingUser.Premium = Premium; - existingUser.MaxStorageGb = MaxStorageGb; - existingUser.Gateway = Gateway; - existingUser.GatewayCustomerId = GatewayCustomerId; - existingUser.GatewaySubscriptionId = GatewaySubscriptionId; - existingUser.LicenseKey = LicenseKey; - existingUser.PremiumExpirationDate = PremiumExpirationDate; - return existingUser; + [Display(Name = "Name")] + public string Name { get; set; } + [Required] + [Display(Name = "Email")] + public string Email { get; set; } + [Display(Name = "Email Verified")] + public bool EmailVerified { get; set; } + [Display(Name = "Premium")] + public bool Premium { get; set; } + [Display(Name = "Max. Storage GB")] + public short? MaxStorageGb { get; set; } + [Display(Name = "Gateway")] + public Core.Enums.GatewayType? Gateway { get; set; } + [Display(Name = "Gateway Customer Id")] + public string GatewayCustomerId { get; set; } + [Display(Name = "Gateway Subscription Id")] + public string GatewaySubscriptionId { get; set; } + [Display(Name = "License Key")] + public string LicenseKey { get; set; } + [Display(Name = "Premium Expiration Date")] + public DateTime? PremiumExpirationDate { get; set; } + + public User ToUser(User existingUser) + { + existingUser.Name = Name; + existingUser.Email = Email; + existingUser.EmailVerified = EmailVerified; + existingUser.Premium = Premium; + existingUser.MaxStorageGb = MaxStorageGb; + existingUser.Gateway = Gateway; + existingUser.GatewayCustomerId = GatewayCustomerId; + existingUser.GatewaySubscriptionId = GatewaySubscriptionId; + existingUser.LicenseKey = LicenseKey; + existingUser.PremiumExpirationDate = PremiumExpirationDate; + return existingUser; + } } } diff --git a/src/Admin/Models/UserViewModel.cs b/src/Admin/Models/UserViewModel.cs index f493f68f2c..adc8fb2689 100644 --- a/src/Admin/Models/UserViewModel.cs +++ b/src/Admin/Models/UserViewModel.cs @@ -1,17 +1,18 @@ using Bit.Core.Entities; -namespace Bit.Admin.Models; - -public class UserViewModel +namespace Bit.Admin.Models { - public UserViewModel() { } - - public UserViewModel(User user, IEnumerable ciphers) + public class UserViewModel { - User = user; - CipherCount = ciphers.Count(); - } + public UserViewModel() { } - public User User { get; set; } - public int CipherCount { get; set; } + public UserViewModel(User user, IEnumerable ciphers) + { + User = user; + CipherCount = ciphers.Count(); + } + + public User User { get; set; } + public int CipherCount { get; set; } + } } diff --git a/src/Admin/Models/UsersModel.cs b/src/Admin/Models/UsersModel.cs index 0a54e318db..1215a95558 100644 --- a/src/Admin/Models/UsersModel.cs +++ b/src/Admin/Models/UsersModel.cs @@ -1,9 +1,10 @@ using Bit.Core.Entities; -namespace Bit.Admin.Models; - -public class UsersModel : PagedModel +namespace Bit.Admin.Models { - public string Email { get; set; } - public string Action { get; set; } + public class UsersModel : PagedModel + { + public string Email { get; set; } + public string Action { get; set; } + } } diff --git a/src/Admin/Program.cs b/src/Admin/Program.cs index f5bc877ab9..d8a55e7b6c 100644 --- a/src/Admin/Program.cs +++ b/src/Admin/Program.cs @@ -1,36 +1,37 @@ using Bit.Core.Utilities; using Serilog.Events; -namespace Bit.Admin; - -public class Program +namespace Bit.Admin { - public static void Main(string[] args) + public class Program { - Host - .CreateDefaultBuilder(args) - .ConfigureCustomAppConfiguration(args) - .ConfigureWebHostDefaults(webBuilder => - { - webBuilder.ConfigureKestrel(o => + public static void Main(string[] args) + { + Host + .CreateDefaultBuilder(args) + .ConfigureCustomAppConfiguration(args) + .ConfigureWebHostDefaults(webBuilder => { - o.Limits.MaxRequestLineSize = 20_000; - }); - webBuilder.UseStartup(); - webBuilder.ConfigureLogging((hostingContext, logging) => - logging.AddSerilog(hostingContext, e => - { - var context = e.Properties["SourceContext"].ToString(); - if (e.Properties.ContainsKey("RequestPath") && - !string.IsNullOrWhiteSpace(e.Properties["RequestPath"]?.ToString()) && - (context.Contains(".Server.Kestrel") || context.Contains(".Core.IISHttpServer"))) + webBuilder.ConfigureKestrel(o => { - return false; - } - return e.Level >= LogEventLevel.Error; - })); - }) - .Build() - .Run(); + o.Limits.MaxRequestLineSize = 20_000; + }); + webBuilder.UseStartup(); + webBuilder.ConfigureLogging((hostingContext, logging) => + logging.AddSerilog(hostingContext, e => + { + var context = e.Properties["SourceContext"].ToString(); + if (e.Properties.ContainsKey("RequestPath") && + !string.IsNullOrWhiteSpace(e.Properties["RequestPath"]?.ToString()) && + (context.Contains(".Server.Kestrel") || context.Contains(".Core.IISHttpServer"))) + { + return false; + } + return e.Level >= LogEventLevel.Error; + })); + }) + .Build() + .Run(); + } } } diff --git a/src/Admin/Startup.cs b/src/Admin/Startup.cs index 37645873eb..ea8485c79f 100644 --- a/src/Admin/Startup.cs +++ b/src/Admin/Startup.cs @@ -11,127 +11,128 @@ using Stripe; using Bit.Commercial.Core.Utilities; #endif -namespace Bit.Admin; - -public class Startup +namespace Bit.Admin { - public Startup(IWebHostEnvironment env, IConfiguration configuration) + public class Startup { - CultureInfo.DefaultThreadCurrentCulture = new CultureInfo("en-US"); - Configuration = configuration; - Environment = env; - } - - public IConfiguration Configuration { get; private set; } - public IWebHostEnvironment Environment { get; set; } - - public void ConfigureServices(IServiceCollection services) - { - // Options - services.AddOptions(); - - // Settings - var globalSettings = services.AddGlobalSettingsServices(Configuration, Environment); - services.Configure(Configuration.GetSection("AdminSettings")); - - // Data Protection - services.AddCustomDataProtectionServices(Environment, globalSettings); - - // Stripe Billing - StripeConfiguration.ApiKey = globalSettings.Stripe.ApiKey; - StripeConfiguration.MaxNetworkRetries = globalSettings.Stripe.MaxNetworkRetries; - - // Repositories - services.AddSqlServerRepositories(globalSettings); - - // Context - services.AddScoped(); - - // Identity - services.AddPasswordlessIdentityServices(globalSettings); - services.Configure(options => + public Startup(IWebHostEnvironment env, IConfiguration configuration) { - options.ValidationInterval = TimeSpan.FromMinutes(5); - }); - if (globalSettings.SelfHosted) - { - services.ConfigureApplicationCookie(options => - { - options.Cookie.Path = "/admin"; - }); + CultureInfo.DefaultThreadCurrentCulture = new CultureInfo("en-US"); + Configuration = configuration; + Environment = env; } - // Services - services.AddBaseServices(globalSettings); - services.AddDefaultServices(globalSettings); + public IConfiguration Configuration { get; private set; } + public IWebHostEnvironment Environment { get; set; } + + public void ConfigureServices(IServiceCollection services) + { + // Options + services.AddOptions(); + + // Settings + var globalSettings = services.AddGlobalSettingsServices(Configuration, Environment); + services.Configure(Configuration.GetSection("AdminSettings")); + + // Data Protection + services.AddCustomDataProtectionServices(Environment, globalSettings); + + // Stripe Billing + StripeConfiguration.ApiKey = globalSettings.Stripe.ApiKey; + StripeConfiguration.MaxNetworkRetries = globalSettings.Stripe.MaxNetworkRetries; + + // Repositories + services.AddSqlServerRepositories(globalSettings); + + // Context + services.AddScoped(); + + // Identity + services.AddPasswordlessIdentityServices(globalSettings); + services.Configure(options => + { + options.ValidationInterval = TimeSpan.FromMinutes(5); + }); + if (globalSettings.SelfHosted) + { + services.ConfigureApplicationCookie(options => + { + options.Cookie.Path = "/admin"; + }); + } + + // Services + services.AddBaseServices(globalSettings); + services.AddDefaultServices(globalSettings); #if OSS - services.AddOosServices(); + services.AddOosServices(); #else - services.AddCommCoreServices(); + services.AddCommCoreServices(); #endif - // Mvc - services.AddMvc(config => - { - config.Filters.Add(new LoggingExceptionHandlerFilterAttribute()); - }); - services.Configure(options => options.LowercaseUrls = true); + // Mvc + services.AddMvc(config => + { + config.Filters.Add(new LoggingExceptionHandlerFilterAttribute()); + }); + services.Configure(options => options.LowercaseUrls = true); - // Jobs service - Jobs.JobsHostedService.AddJobsServices(services, globalSettings.SelfHosted); - services.AddHostedService(); - if (globalSettings.SelfHosted) - { - services.AddHostedService(); - } - else - { - if (CoreHelpers.SettingHasValue(globalSettings.Storage.ConnectionString)) + // Jobs service + Jobs.JobsHostedService.AddJobsServices(services, globalSettings.SelfHosted); + services.AddHostedService(); + if (globalSettings.SelfHosted) { - services.AddHostedService(); + services.AddHostedService(); } - else if (CoreHelpers.SettingHasValue(globalSettings.Amazon?.AccessKeySecret)) + else { - services.AddHostedService(); - } - if (CoreHelpers.SettingHasValue(globalSettings.Mail.ConnectionString)) - { - services.AddHostedService(); + if (CoreHelpers.SettingHasValue(globalSettings.Storage.ConnectionString)) + { + services.AddHostedService(); + } + else if (CoreHelpers.SettingHasValue(globalSettings.Amazon?.AccessKeySecret)) + { + services.AddHostedService(); + } + if (CoreHelpers.SettingHasValue(globalSettings.Mail.ConnectionString)) + { + services.AddHostedService(); + } } } - } - public void Configure( - IApplicationBuilder app, - IWebHostEnvironment env, - IHostApplicationLifetime appLifetime, - GlobalSettings globalSettings) - { - app.UseSerilog(env, appLifetime, globalSettings); - - // Add general security headers - app.UseMiddleware(); - - if (globalSettings.SelfHosted) + public void Configure( + IApplicationBuilder app, + IWebHostEnvironment env, + IHostApplicationLifetime appLifetime, + GlobalSettings globalSettings) { - app.UsePathBase("/admin"); - app.UseForwardedHeaders(globalSettings); - } + app.UseSerilog(env, appLifetime, globalSettings); - if (env.IsDevelopment()) - { - app.UseDeveloperExceptionPage(); - } - else - { - app.UseExceptionHandler("/error"); - } + // Add general security headers + app.UseMiddleware(); - app.UseStaticFiles(); - app.UseRouting(); - app.UseAuthentication(); - app.UseAuthorization(); - app.UseEndpoints(endpoints => endpoints.MapDefaultControllerRoute()); + if (globalSettings.SelfHosted) + { + app.UsePathBase("/admin"); + app.UseForwardedHeaders(globalSettings); + } + + if (env.IsDevelopment()) + { + app.UseDeveloperExceptionPage(); + } + else + { + app.UseExceptionHandler("/error"); + } + + app.UseStaticFiles(); + app.UseRouting(); + app.UseAuthentication(); + app.UseAuthorization(); + app.UseEndpoints(endpoints => endpoints.MapDefaultControllerRoute()); + } } } diff --git a/src/Admin/TagHelpers/ActivePageTagHelper.cs b/src/Admin/TagHelpers/ActivePageTagHelper.cs index a148e3cdf7..6e400383d1 100644 --- a/src/Admin/TagHelpers/ActivePageTagHelper.cs +++ b/src/Admin/TagHelpers/ActivePageTagHelper.cs @@ -3,71 +3,72 @@ using Microsoft.AspNetCore.Mvc.Rendering; using Microsoft.AspNetCore.Mvc.ViewFeatures; using Microsoft.AspNetCore.Razor.TagHelpers; -namespace Bit.Admin.TagHelpers; - -[HtmlTargetElement("li", Attributes = ActiveControllerName)] -[HtmlTargetElement("li", Attributes = ActiveActionName)] -public class ActivePageTagHelper : TagHelper +namespace Bit.Admin.TagHelpers { - private const string ActiveControllerName = "active-controller"; - private const string ActiveActionName = "active-action"; - - private readonly IHtmlGenerator _generator; - - public ActivePageTagHelper(IHtmlGenerator generator) + [HtmlTargetElement("li", Attributes = ActiveControllerName)] + [HtmlTargetElement("li", Attributes = ActiveActionName)] + public class ActivePageTagHelper : TagHelper { - _generator = generator; - } + private const string ActiveControllerName = "active-controller"; + private const string ActiveActionName = "active-action"; - [HtmlAttributeNotBound] - [ViewContext] - public ViewContext ViewContext { get; set; } - [HtmlAttributeName(ActiveControllerName)] - public string ActiveController { get; set; } - [HtmlAttributeName(ActiveActionName)] - public string ActiveAction { get; set; } + private readonly IHtmlGenerator _generator; - public override void Process(TagHelperContext context, TagHelperOutput output) - { - if (context == null) + public ActivePageTagHelper(IHtmlGenerator generator) { - throw new ArgumentNullException(nameof(context)); + _generator = generator; } - if (output == null) - { - throw new ArgumentNullException(nameof(output)); - } + [HtmlAttributeNotBound] + [ViewContext] + public ViewContext ViewContext { get; set; } + [HtmlAttributeName(ActiveControllerName)] + public string ActiveController { get; set; } + [HtmlAttributeName(ActiveActionName)] + public string ActiveAction { get; set; } - if (ActiveAction == null && ActiveController == null) + public override void Process(TagHelperContext context, TagHelperOutput output) { - return; - } - - var descriptor = ViewContext.ActionDescriptor as ControllerActionDescriptor; - if (descriptor == null) - { - return; - } - - var controllerMatch = ActiveMatch(ActiveController, descriptor.ControllerName); - var actionMatch = ActiveMatch(ActiveAction, descriptor.ActionName); - if (controllerMatch && actionMatch) - { - var classValue = "active"; - if (output.Attributes["class"] != null) + if (context == null) { - classValue += " " + output.Attributes["class"].Value; - output.Attributes.Remove(output.Attributes["class"]); + throw new ArgumentNullException(nameof(context)); } - output.Attributes.Add("class", classValue); + if (output == null) + { + throw new ArgumentNullException(nameof(output)); + } + + if (ActiveAction == null && ActiveController == null) + { + return; + } + + var descriptor = ViewContext.ActionDescriptor as ControllerActionDescriptor; + if (descriptor == null) + { + return; + } + + var controllerMatch = ActiveMatch(ActiveController, descriptor.ControllerName); + var actionMatch = ActiveMatch(ActiveAction, descriptor.ActionName); + if (controllerMatch && actionMatch) + { + var classValue = "active"; + if (output.Attributes["class"] != null) + { + classValue += " " + output.Attributes["class"].Value; + output.Attributes.Remove(output.Attributes["class"]); + } + + output.Attributes.Add("class", classValue); + } + } + + private bool ActiveMatch(string route, string descriptor) + { + return route == null || route == "*" || + route.Split(',').Any(c => c.Trim().ToLower() == descriptor.ToLower()); } } - - private bool ActiveMatch(string route, string descriptor) - { - return route == null || route == "*" || - route.Split(',').Any(c => c.Trim().ToLower() == descriptor.ToLower()); - } } diff --git a/src/Admin/TagHelpers/OptionSelectedTagHelper.cs b/src/Admin/TagHelpers/OptionSelectedTagHelper.cs index 3dc9562a06..190d3d1cc5 100644 --- a/src/Admin/TagHelpers/OptionSelectedTagHelper.cs +++ b/src/Admin/TagHelpers/OptionSelectedTagHelper.cs @@ -1,42 +1,43 @@ using Microsoft.AspNetCore.Mvc.ViewFeatures; using Microsoft.AspNetCore.Razor.TagHelpers; -namespace Bit.Admin.TagHelpers; - -[HtmlTargetElement("option", Attributes = SelectedName)] -public class OptionSelectedTagHelper : TagHelper +namespace Bit.Admin.TagHelpers { - private const string SelectedName = "asp-selected"; - - private readonly IHtmlGenerator _generator; - - public OptionSelectedTagHelper(IHtmlGenerator generator) + [HtmlTargetElement("option", Attributes = SelectedName)] + public class OptionSelectedTagHelper : TagHelper { - _generator = generator; - } + private const string SelectedName = "asp-selected"; - [HtmlAttributeName(SelectedName)] - public bool Selected { get; set; } + private readonly IHtmlGenerator _generator; - public override void Process(TagHelperContext context, TagHelperOutput output) - { - if (context == null) + public OptionSelectedTagHelper(IHtmlGenerator generator) { - throw new ArgumentNullException(nameof(context)); + _generator = generator; } - if (output == null) - { - throw new ArgumentNullException(nameof(output)); - } + [HtmlAttributeName(SelectedName)] + public bool Selected { get; set; } - if (Selected) + public override void Process(TagHelperContext context, TagHelperOutput output) { - output.Attributes.Add("selected", "selected"); - } - else - { - output.Attributes.RemoveAll("selected"); + if (context == null) + { + throw new ArgumentNullException(nameof(context)); + } + + if (output == null) + { + throw new ArgumentNullException(nameof(output)); + } + + if (Selected) + { + output.Attributes.Add("selected", "selected"); + } + else + { + output.Attributes.RemoveAll("selected"); + } } } } diff --git a/src/Api/Controllers/AccountsBillingController.cs b/src/Api/Controllers/AccountsBillingController.cs index 9e480301f2..bc012e7b37 100644 --- a/src/Api/Controllers/AccountsBillingController.cs +++ b/src/Api/Controllers/AccountsBillingController.cs @@ -4,48 +4,49 @@ using Bit.Core.Utilities; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; -namespace Bit.Api.Controllers; - -[Route("accounts/billing")] -[Authorize("Application")] -public class AccountsBillingController : Controller +namespace Bit.Api.Controllers { - private readonly IPaymentService _paymentService; - private readonly IUserService _userService; - - public AccountsBillingController( - IPaymentService paymentService, - IUserService userService) + [Route("accounts/billing")] + [Authorize("Application")] + public class AccountsBillingController : Controller { - _paymentService = paymentService; - _userService = userService; - } + private readonly IPaymentService _paymentService; + private readonly IUserService _userService; - [HttpGet("history")] - [SelfHosted(NotSelfHostedOnly = true)] - public async Task GetBillingHistory() - { - var user = await _userService.GetUserByPrincipalAsync(User); - if (user == null) + public AccountsBillingController( + IPaymentService paymentService, + IUserService userService) { - throw new UnauthorizedAccessException(); + _paymentService = paymentService; + _userService = userService; } - var billingInfo = await _paymentService.GetBillingHistoryAsync(user); - return new BillingHistoryResponseModel(billingInfo); - } - - [HttpGet("payment-method")] - [SelfHosted(NotSelfHostedOnly = true)] - public async Task GetPaymentMethod() - { - var user = await _userService.GetUserByPrincipalAsync(User); - if (user == null) + [HttpGet("history")] + [SelfHosted(NotSelfHostedOnly = true)] + public async Task GetBillingHistory() { - throw new UnauthorizedAccessException(); + var user = await _userService.GetUserByPrincipalAsync(User); + if (user == null) + { + throw new UnauthorizedAccessException(); + } + + var billingInfo = await _paymentService.GetBillingHistoryAsync(user); + return new BillingHistoryResponseModel(billingInfo); } - var billingInfo = await _paymentService.GetBillingBalanceAndSourceAsync(user); - return new BillingPaymentResponseModel(billingInfo); + [HttpGet("payment-method")] + [SelfHosted(NotSelfHostedOnly = true)] + public async Task GetPaymentMethod() + { + var user = await _userService.GetUserByPrincipalAsync(User); + if (user == null) + { + throw new UnauthorizedAccessException(); + } + + var billingInfo = await _paymentService.GetBillingBalanceAndSourceAsync(user); + return new BillingPaymentResponseModel(billingInfo); + } } } diff --git a/src/Api/Controllers/AccountsController.cs b/src/Api/Controllers/AccountsController.cs index 74aa469c98..41708d3d2f 100644 --- a/src/Api/Controllers/AccountsController.cs +++ b/src/Api/Controllers/AccountsController.cs @@ -18,511 +18,139 @@ using Bit.Core.Utilities; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; -namespace Bit.Api.Controllers; - -[Route("accounts")] -[Authorize("Application")] -public class AccountsController : Controller +namespace Bit.Api.Controllers { - private readonly GlobalSettings _globalSettings; - private readonly ICipherRepository _cipherRepository; - private readonly IFolderRepository _folderRepository; - private readonly IOrganizationService _organizationService; - private readonly IOrganizationUserRepository _organizationUserRepository; - private readonly IProviderUserRepository _providerUserRepository; - private readonly IPaymentService _paymentService; - private readonly IUserRepository _userRepository; - private readonly IUserService _userService; - private readonly ISendRepository _sendRepository; - private readonly ISendService _sendService; - - public AccountsController( - GlobalSettings globalSettings, - ICipherRepository cipherRepository, - IFolderRepository folderRepository, - IOrganizationService organizationService, - IOrganizationUserRepository organizationUserRepository, - IProviderUserRepository providerUserRepository, - IPaymentService paymentService, - IUserRepository userRepository, - IUserService userService, - ISendRepository sendRepository, - ISendService sendService) + [Route("accounts")] + [Authorize("Application")] + public class AccountsController : Controller { - _cipherRepository = cipherRepository; - _folderRepository = folderRepository; - _globalSettings = globalSettings; - _organizationService = organizationService; - _organizationUserRepository = organizationUserRepository; - _providerUserRepository = providerUserRepository; - _paymentService = paymentService; - _userRepository = userRepository; - _userService = userService; - _sendRepository = sendRepository; - _sendService = sendService; - } + private readonly GlobalSettings _globalSettings; + private readonly ICipherRepository _cipherRepository; + private readonly IFolderRepository _folderRepository; + private readonly IOrganizationService _organizationService; + private readonly IOrganizationUserRepository _organizationUserRepository; + private readonly IProviderUserRepository _providerUserRepository; + private readonly IPaymentService _paymentService; + private readonly IUserRepository _userRepository; + private readonly IUserService _userService; + private readonly ISendRepository _sendRepository; + private readonly ISendService _sendService; - #region DEPRECATED (Moved to Identity Service) - - [Obsolete("2022-01-12 Moved to Identity, left for backwards compatability with older clients")] - [HttpPost("prelogin")] - [AllowAnonymous] - public async Task PostPrelogin([FromBody] PreloginRequestModel model) - { - var kdfInformation = await _userRepository.GetKdfInformationByEmailAsync(model.Email); - if (kdfInformation == null) + public AccountsController( + GlobalSettings globalSettings, + ICipherRepository cipherRepository, + IFolderRepository folderRepository, + IOrganizationService organizationService, + IOrganizationUserRepository organizationUserRepository, + IProviderUserRepository providerUserRepository, + IPaymentService paymentService, + IUserRepository userRepository, + IUserService userService, + ISendRepository sendRepository, + ISendService sendService) { - kdfInformation = new UserKdfInformation + _cipherRepository = cipherRepository; + _folderRepository = folderRepository; + _globalSettings = globalSettings; + _organizationService = organizationService; + _organizationUserRepository = organizationUserRepository; + _providerUserRepository = providerUserRepository; + _paymentService = paymentService; + _userRepository = userRepository; + _userService = userService; + _sendRepository = sendRepository; + _sendService = sendService; + } + + #region DEPRECATED (Moved to Identity Service) + + [Obsolete("2022-01-12 Moved to Identity, left for backwards compatability with older clients")] + [HttpPost("prelogin")] + [AllowAnonymous] + public async Task PostPrelogin([FromBody] PreloginRequestModel model) + { + var kdfInformation = await _userRepository.GetKdfInformationByEmailAsync(model.Email); + if (kdfInformation == null) { - Kdf = KdfType.PBKDF2_SHA256, - KdfIterations = 100000, - }; - } - return new PreloginResponseModel(kdfInformation); - } - - [Obsolete("2022-01-12 Moved to Identity, left for backwards compatability with older clients")] - [HttpPost("register")] - [AllowAnonymous] - [CaptchaProtected] - public async Task PostRegister([FromBody] RegisterRequestModel model) - { - var result = await _userService.RegisterUserAsync(model.ToUser(), model.MasterPasswordHash, - model.Token, model.OrganizationUserId); - if (result.Succeeded) - { - return; + kdfInformation = new UserKdfInformation + { + Kdf = KdfType.PBKDF2_SHA256, + KdfIterations = 100000, + }; + } + return new PreloginResponseModel(kdfInformation); } - foreach (var error in result.Errors.Where(e => e.Code != "DuplicateUserName")) + [Obsolete("2022-01-12 Moved to Identity, left for backwards compatability with older clients")] + [HttpPost("register")] + [AllowAnonymous] + [CaptchaProtected] + public async Task PostRegister([FromBody] RegisterRequestModel model) { - ModelState.AddModelError(string.Empty, error.Description); - } + var result = await _userService.RegisterUserAsync(model.ToUser(), model.MasterPasswordHash, + model.Token, model.OrganizationUserId); + if (result.Succeeded) + { + return; + } - await Task.Delay(2000); - throw new BadRequestException(ModelState); - } + foreach (var error in result.Errors.Where(e => e.Code != "DuplicateUserName")) + { + ModelState.AddModelError(string.Empty, error.Description); + } - #endregion - - [HttpPost("password-hint")] - [AllowAnonymous] - public async Task PostPasswordHint([FromBody] PasswordHintRequestModel model) - { - await _userService.SendMasterPasswordHintAsync(model.Email); - } - - [HttpPost("email-token")] - public async Task PostEmailToken([FromBody] EmailTokenRequestModel model) - { - var user = await _userService.GetUserByPrincipalAsync(User); - if (user == null) - { - throw new UnauthorizedAccessException(); - } - - if (user.UsesKeyConnector) - { - throw new BadRequestException("You cannot change your email when using Key Connector."); - } - - if (!await _userService.CheckPasswordAsync(user, model.MasterPasswordHash)) - { await Task.Delay(2000); - throw new BadRequestException("MasterPasswordHash", "Invalid password."); + throw new BadRequestException(ModelState); } - await _userService.InitiateEmailChangeAsync(user, model.NewEmail); - } + #endregion - [HttpPost("email")] - public async Task PostEmail([FromBody] EmailRequestModel model) - { - var user = await _userService.GetUserByPrincipalAsync(User); - if (user == null) + [HttpPost("password-hint")] + [AllowAnonymous] + public async Task PostPasswordHint([FromBody] PasswordHintRequestModel model) { - throw new UnauthorizedAccessException(); + await _userService.SendMasterPasswordHintAsync(model.Email); } - if (user.UsesKeyConnector) + [HttpPost("email-token")] + public async Task PostEmailToken([FromBody] EmailTokenRequestModel model) { - throw new BadRequestException("You cannot change your email when using Key Connector."); + var user = await _userService.GetUserByPrincipalAsync(User); + if (user == null) + { + throw new UnauthorizedAccessException(); + } + + if (user.UsesKeyConnector) + { + throw new BadRequestException("You cannot change your email when using Key Connector."); + } + + if (!await _userService.CheckPasswordAsync(user, model.MasterPasswordHash)) + { + await Task.Delay(2000); + throw new BadRequestException("MasterPasswordHash", "Invalid password."); + } + + await _userService.InitiateEmailChangeAsync(user, model.NewEmail); } - var result = await _userService.ChangeEmailAsync(user, model.MasterPasswordHash, model.NewEmail, - model.NewMasterPasswordHash, model.Token, model.Key); - if (result.Succeeded) + [HttpPost("email")] + public async Task PostEmail([FromBody] EmailRequestModel model) { - return; - } + var user = await _userService.GetUserByPrincipalAsync(User); + if (user == null) + { + throw new UnauthorizedAccessException(); + } - foreach (var error in result.Errors) - { - ModelState.AddModelError(string.Empty, error.Description); - } + if (user.UsesKeyConnector) + { + throw new BadRequestException("You cannot change your email when using Key Connector."); + } - await Task.Delay(2000); - throw new BadRequestException(ModelState); - } - - [HttpPost("verify-email")] - public async Task PostVerifyEmail() - { - var user = await _userService.GetUserByPrincipalAsync(User); - if (user == null) - { - throw new UnauthorizedAccessException(); - } - - await _userService.SendEmailVerificationAsync(user); - } - - [HttpPost("verify-email-token")] - [AllowAnonymous] - public async Task PostVerifyEmailToken([FromBody] VerifyEmailRequestModel model) - { - var user = await _userService.GetUserByIdAsync(new Guid(model.UserId)); - if (user == null) - { - throw new UnauthorizedAccessException(); - } - var result = await _userService.ConfirmEmailAsync(user, model.Token); - if (result.Succeeded) - { - return; - } - - foreach (var error in result.Errors) - { - ModelState.AddModelError(string.Empty, error.Description); - } - - await Task.Delay(2000); - throw new BadRequestException(ModelState); - } - - [HttpPost("password")] - public async Task PostPassword([FromBody] PasswordRequestModel model) - { - var user = await _userService.GetUserByPrincipalAsync(User); - if (user == null) - { - throw new UnauthorizedAccessException(); - } - - var result = await _userService.ChangePasswordAsync(user, model.MasterPasswordHash, - model.NewMasterPasswordHash, model.MasterPasswordHint, model.Key); - if (result.Succeeded) - { - return; - } - - foreach (var error in result.Errors) - { - ModelState.AddModelError(string.Empty, error.Description); - } - - await Task.Delay(2000); - throw new BadRequestException(ModelState); - } - - [HttpPost("set-password")] - public async Task PostSetPasswordAsync([FromBody] SetPasswordRequestModel model) - { - var user = await _userService.GetUserByPrincipalAsync(User); - if (user == null) - { - throw new UnauthorizedAccessException(); - } - - var result = await _userService.SetPasswordAsync(model.ToUser(user), model.MasterPasswordHash, model.Key, - model.OrgIdentifier); - if (result.Succeeded) - { - return; - } - - foreach (var error in result.Errors) - { - ModelState.AddModelError(string.Empty, error.Description); - } - - throw new BadRequestException(ModelState); - } - - [HttpPost("verify-password")] - public async Task PostVerifyPassword([FromBody] SecretVerificationRequestModel model) - { - var user = await _userService.GetUserByPrincipalAsync(User); - if (user == null) - { - throw new UnauthorizedAccessException(); - } - - if (await _userService.CheckPasswordAsync(user, model.MasterPasswordHash)) - { - return; - } - - ModelState.AddModelError(nameof(model.MasterPasswordHash), "Invalid password."); - await Task.Delay(2000); - throw new BadRequestException(ModelState); - } - - [HttpPost("set-key-connector-key")] - public async Task PostSetKeyConnectorKeyAsync([FromBody] SetKeyConnectorKeyRequestModel model) - { - var user = await _userService.GetUserByPrincipalAsync(User); - if (user == null) - { - throw new UnauthorizedAccessException(); - } - - var result = await _userService.SetKeyConnectorKeyAsync(model.ToUser(user), model.Key, model.OrgIdentifier); - if (result.Succeeded) - { - return; - } - - foreach (var error in result.Errors) - { - ModelState.AddModelError(string.Empty, error.Description); - } - - throw new BadRequestException(ModelState); - } - - [HttpPost("convert-to-key-connector")] - public async Task PostConvertToKeyConnector() - { - var user = await _userService.GetUserByPrincipalAsync(User); - if (user == null) - { - throw new UnauthorizedAccessException(); - } - - var result = await _userService.ConvertToKeyConnectorAsync(user); - if (result.Succeeded) - { - return; - } - - foreach (var error in result.Errors) - { - ModelState.AddModelError(string.Empty, error.Description); - } - - throw new BadRequestException(ModelState); - } - - [HttpPost("kdf")] - public async Task PostKdf([FromBody] KdfRequestModel model) - { - var user = await _userService.GetUserByPrincipalAsync(User); - if (user == null) - { - throw new UnauthorizedAccessException(); - } - - var result = await _userService.ChangeKdfAsync(user, model.MasterPasswordHash, - model.NewMasterPasswordHash, model.Key, model.Kdf.Value, model.KdfIterations.Value); - if (result.Succeeded) - { - return; - } - - foreach (var error in result.Errors) - { - ModelState.AddModelError(string.Empty, error.Description); - } - - await Task.Delay(2000); - throw new BadRequestException(ModelState); - } - - [HttpPost("key")] - public async Task PostKey([FromBody] UpdateKeyRequestModel model) - { - var user = await _userService.GetUserByPrincipalAsync(User); - if (user == null) - { - throw new UnauthorizedAccessException(); - } - - var ciphers = new List(); - if (model.Ciphers.Any()) - { - var existingCiphers = await _cipherRepository.GetManyByUserIdAsync(user.Id); - ciphers.AddRange(existingCiphers - .Join(model.Ciphers, c => c.Id, c => c.Id, (existing, c) => c.ToCipher(existing))); - } - - var folders = new List(); - if (model.Folders.Any()) - { - var existingFolders = await _folderRepository.GetManyByUserIdAsync(user.Id); - folders.AddRange(existingFolders - .Join(model.Folders, f => f.Id, f => f.Id, (existing, f) => f.ToFolder(existing))); - } - - var sends = new List(); - if (model.Sends?.Any() == true) - { - var existingSends = await _sendRepository.GetManyByUserIdAsync(user.Id); - sends.AddRange(existingSends - .Join(model.Sends, s => s.Id, s => s.Id, (existing, s) => s.ToSend(existing, _sendService))); - } - - var result = await _userService.UpdateKeyAsync( - user, - model.MasterPasswordHash, - model.Key, - model.PrivateKey, - ciphers, - folders, - sends); - - if (result.Succeeded) - { - return; - } - - foreach (var error in result.Errors) - { - ModelState.AddModelError(string.Empty, error.Description); - } - - await Task.Delay(2000); - throw new BadRequestException(ModelState); - } - - [HttpPost("security-stamp")] - public async Task PostSecurityStamp([FromBody] SecretVerificationRequestModel model) - { - var user = await _userService.GetUserByPrincipalAsync(User); - if (user == null) - { - throw new UnauthorizedAccessException(); - } - - var result = await _userService.RefreshSecurityStampAsync(user, model.Secret); - if (result.Succeeded) - { - return; - } - - foreach (var error in result.Errors) - { - ModelState.AddModelError(string.Empty, error.Description); - } - - await Task.Delay(2000); - throw new BadRequestException(ModelState); - } - - [HttpGet("profile")] - public async Task GetProfile() - { - var user = await _userService.GetUserByPrincipalAsync(User); - if (user == null) - { - throw new UnauthorizedAccessException(); - } - - var organizationUserDetails = await _organizationUserRepository.GetManyDetailsByUserAsync(user.Id, - OrganizationUserStatusType.Confirmed); - var providerUserDetails = await _providerUserRepository.GetManyDetailsByUserAsync(user.Id, - ProviderUserStatusType.Confirmed); - var providerUserOrganizationDetails = - await _providerUserRepository.GetManyOrganizationDetailsByUserAsync(user.Id, - ProviderUserStatusType.Confirmed); - var response = new ProfileResponseModel(user, organizationUserDetails, providerUserDetails, - providerUserOrganizationDetails, await _userService.TwoFactorIsEnabledAsync(user), await _userService.HasPremiumFromOrganization(user)); - return response; - } - - [HttpGet("organizations")] - public async Task> GetOrganizations() - { - var userId = _userService.GetProperUserId(User); - var organizationUserDetails = await _organizationUserRepository.GetManyDetailsByUserAsync(userId.Value, - OrganizationUserStatusType.Confirmed); - var responseData = organizationUserDetails.Select(o => new ProfileOrganizationResponseModel(o)); - return new ListResponseModel(responseData); - } - - [HttpPut("profile")] - [HttpPost("profile")] - public async Task PutProfile([FromBody] UpdateProfileRequestModel model) - { - var user = await _userService.GetUserByPrincipalAsync(User); - if (user == null) - { - throw new UnauthorizedAccessException(); - } - - await _userService.SaveUserAsync(model.ToUser(user)); - var response = new ProfileResponseModel(user, null, null, null, await _userService.TwoFactorIsEnabledAsync(user), await _userService.HasPremiumFromOrganization(user)); - return response; - } - - [HttpGet("revision-date")] - public async Task GetAccountRevisionDate() - { - var userId = _userService.GetProperUserId(User); - long? revisionDate = null; - if (userId.HasValue) - { - var date = await _userService.GetAccountRevisionDateByIdAsync(userId.Value); - revisionDate = CoreHelpers.ToEpocMilliseconds(date); - } - - return revisionDate; - } - - [HttpPost("keys")] - public async Task PostKeys([FromBody] KeysRequestModel model) - { - var user = await _userService.GetUserByPrincipalAsync(User); - if (user == null) - { - throw new UnauthorizedAccessException(); - } - - await _userService.SaveUserAsync(model.ToUser(user)); - return new KeysResponseModel(user); - } - - [HttpGet("keys")] - public async Task GetKeys() - { - var user = await _userService.GetUserByPrincipalAsync(User); - if (user == null) - { - throw new UnauthorizedAccessException(); - } - - return new KeysResponseModel(user); - } - - [HttpDelete] - [HttpPost("delete")] - public async Task Delete([FromBody] SecretVerificationRequestModel model) - { - var user = await _userService.GetUserByPrincipalAsync(User); - if (user == null) - { - throw new UnauthorizedAccessException(); - } - - if (!await _userService.VerifySecretAsync(user, model.Secret)) - { - ModelState.AddModelError(string.Empty, "User verification failed."); - await Task.Delay(2000); - } - else - { - var result = await _userService.DeleteAsync(user); + var result = await _userService.ChangeEmailAsync(user, model.MasterPasswordHash, model.NewEmail, + model.NewMasterPasswordHash, model.Token, model.Key); if (result.Succeeded) { return; @@ -532,357 +160,730 @@ public class AccountsController : Controller { ModelState.AddModelError(string.Empty, error.Description); } + + await Task.Delay(2000); + throw new BadRequestException(ModelState); } - throw new BadRequestException(ModelState); - } - - [AllowAnonymous] - [HttpPost("delete-recover")] - public async Task PostDeleteRecover([FromBody] DeleteRecoverRequestModel model) - { - await _userService.SendDeleteConfirmationAsync(model.Email); - } - - [HttpPost("delete-recover-token")] - [AllowAnonymous] - public async Task PostDeleteRecoverToken([FromBody] VerifyDeleteRecoverRequestModel model) - { - var user = await _userService.GetUserByIdAsync(new Guid(model.UserId)); - if (user == null) + [HttpPost("verify-email")] + public async Task PostVerifyEmail() { - throw new UnauthorizedAccessException(); - } - - var result = await _userService.DeleteAsync(user, model.Token); - if (result.Succeeded) - { - return; - } - - foreach (var error in result.Errors) - { - ModelState.AddModelError(string.Empty, error.Description); - } - - await Task.Delay(2000); - throw new BadRequestException(ModelState); - } - - [HttpPost("iap-check")] - public async Task PostIapCheck([FromBody] IapCheckRequestModel model) - { - var user = await _userService.GetUserByPrincipalAsync(User); - if (user == null) - { - throw new UnauthorizedAccessException(); - } - await _userService.IapCheckAsync(user, model.PaymentMethodType.Value); - } - - [HttpPost("premium")] - public async Task PostPremium(PremiumRequestModel model) - { - var user = await _userService.GetUserByPrincipalAsync(User); - if (user == null) - { - throw new UnauthorizedAccessException(); - } - - var valid = model.Validate(_globalSettings); - UserLicense license = null; - if (valid && _globalSettings.SelfHosted) - { - license = await ApiHelpers.ReadJsonFileFromBody(HttpContext, model.License); - } - - if (!valid && !_globalSettings.SelfHosted && string.IsNullOrWhiteSpace(model.Country)) - { - throw new BadRequestException("Country is required."); - } - - if (!valid || (_globalSettings.SelfHosted && license == null)) - { - throw new BadRequestException("Invalid license."); - } - - var result = await _userService.SignUpPremiumAsync(user, model.PaymentToken, - model.PaymentMethodType.Value, model.AdditionalStorageGb.GetValueOrDefault(0), license, - new TaxInfo + var user = await _userService.GetUserByPrincipalAsync(User); + if (user == null) { - BillingAddressCountry = model.Country, - BillingAddressPostalCode = model.PostalCode, - }); - var profile = new ProfileResponseModel(user, null, null, null, await _userService.TwoFactorIsEnabledAsync(user), await _userService.HasPremiumFromOrganization(user)); - return new PaymentResponseModel - { - UserProfile = profile, - PaymentIntentClientSecret = result.Item2, - Success = result.Item1 - }; - } + throw new UnauthorizedAccessException(); + } - [Obsolete("2022-04-01 Use separate Billing History/Payment APIs, left for backwards compatability with older clients")] - [HttpGet("billing")] - [SelfHosted(NotSelfHostedOnly = true)] - public async Task GetBilling() - { - var user = await _userService.GetUserByPrincipalAsync(User); - if (user == null) - { - throw new UnauthorizedAccessException(); + await _userService.SendEmailVerificationAsync(user); } - var billingInfo = await _paymentService.GetBillingAsync(user); - return new BillingResponseModel(billingInfo); - } - - [HttpGet("subscription")] - public async Task GetSubscription() - { - var user = await _userService.GetUserByPrincipalAsync(User); - if (user == null) + [HttpPost("verify-email-token")] + [AllowAnonymous] + public async Task PostVerifyEmailToken([FromBody] VerifyEmailRequestModel model) { - throw new UnauthorizedAccessException(); - } - - if (!_globalSettings.SelfHosted && user.Gateway != null) - { - var subscriptionInfo = await _paymentService.GetSubscriptionAsync(user); - var license = await _userService.GenerateLicenseAsync(user, subscriptionInfo); - return new SubscriptionResponseModel(user, subscriptionInfo, license); - } - else if (!_globalSettings.SelfHosted) - { - var license = await _userService.GenerateLicenseAsync(user); - return new SubscriptionResponseModel(user, license); - } - else - { - return new SubscriptionResponseModel(user); - } - } - - [HttpPost("payment")] - [SelfHosted(NotSelfHostedOnly = true)] - public async Task PostPayment([FromBody] PaymentRequestModel model) - { - var user = await _userService.GetUserByPrincipalAsync(User); - if (user == null) - { - throw new UnauthorizedAccessException(); - } - - await _userService.ReplacePaymentMethodAsync(user, model.PaymentToken, model.PaymentMethodType.Value, - new TaxInfo + var user = await _userService.GetUserByIdAsync(new Guid(model.UserId)); + if (user == null) + { + throw new UnauthorizedAccessException(); + } + var result = await _userService.ConfirmEmailAsync(user, model.Token); + if (result.Succeeded) + { + return; + } + + foreach (var error in result.Errors) + { + ModelState.AddModelError(string.Empty, error.Description); + } + + await Task.Delay(2000); + throw new BadRequestException(ModelState); + } + + [HttpPost("password")] + public async Task PostPassword([FromBody] PasswordRequestModel model) + { + var user = await _userService.GetUserByPrincipalAsync(User); + if (user == null) + { + throw new UnauthorizedAccessException(); + } + + var result = await _userService.ChangePasswordAsync(user, model.MasterPasswordHash, + model.NewMasterPasswordHash, model.MasterPasswordHint, model.Key); + if (result.Succeeded) + { + return; + } + + foreach (var error in result.Errors) + { + ModelState.AddModelError(string.Empty, error.Description); + } + + await Task.Delay(2000); + throw new BadRequestException(ModelState); + } + + [HttpPost("set-password")] + public async Task PostSetPasswordAsync([FromBody] SetPasswordRequestModel model) + { + var user = await _userService.GetUserByPrincipalAsync(User); + if (user == null) + { + throw new UnauthorizedAccessException(); + } + + var result = await _userService.SetPasswordAsync(model.ToUser(user), model.MasterPasswordHash, model.Key, + model.OrgIdentifier); + if (result.Succeeded) + { + return; + } + + foreach (var error in result.Errors) + { + ModelState.AddModelError(string.Empty, error.Description); + } + + throw new BadRequestException(ModelState); + } + + [HttpPost("verify-password")] + public async Task PostVerifyPassword([FromBody] SecretVerificationRequestModel model) + { + var user = await _userService.GetUserByPrincipalAsync(User); + if (user == null) + { + throw new UnauthorizedAccessException(); + } + + if (await _userService.CheckPasswordAsync(user, model.MasterPasswordHash)) + { + return; + } + + ModelState.AddModelError(nameof(model.MasterPasswordHash), "Invalid password."); + await Task.Delay(2000); + throw new BadRequestException(ModelState); + } + + [HttpPost("set-key-connector-key")] + public async Task PostSetKeyConnectorKeyAsync([FromBody] SetKeyConnectorKeyRequestModel model) + { + var user = await _userService.GetUserByPrincipalAsync(User); + if (user == null) + { + throw new UnauthorizedAccessException(); + } + + var result = await _userService.SetKeyConnectorKeyAsync(model.ToUser(user), model.Key, model.OrgIdentifier); + if (result.Succeeded) + { + return; + } + + foreach (var error in result.Errors) + { + ModelState.AddModelError(string.Empty, error.Description); + } + + throw new BadRequestException(ModelState); + } + + [HttpPost("convert-to-key-connector")] + public async Task PostConvertToKeyConnector() + { + var user = await _userService.GetUserByPrincipalAsync(User); + if (user == null) + { + throw new UnauthorizedAccessException(); + } + + var result = await _userService.ConvertToKeyConnectorAsync(user); + if (result.Succeeded) + { + return; + } + + foreach (var error in result.Errors) + { + ModelState.AddModelError(string.Empty, error.Description); + } + + throw new BadRequestException(ModelState); + } + + [HttpPost("kdf")] + public async Task PostKdf([FromBody] KdfRequestModel model) + { + var user = await _userService.GetUserByPrincipalAsync(User); + if (user == null) + { + throw new UnauthorizedAccessException(); + } + + var result = await _userService.ChangeKdfAsync(user, model.MasterPasswordHash, + model.NewMasterPasswordHash, model.Key, model.Kdf.Value, model.KdfIterations.Value); + if (result.Succeeded) + { + return; + } + + foreach (var error in result.Errors) + { + ModelState.AddModelError(string.Empty, error.Description); + } + + await Task.Delay(2000); + throw new BadRequestException(ModelState); + } + + [HttpPost("key")] + public async Task PostKey([FromBody] UpdateKeyRequestModel model) + { + var user = await _userService.GetUserByPrincipalAsync(User); + if (user == null) + { + throw new UnauthorizedAccessException(); + } + + var ciphers = new List(); + if (model.Ciphers.Any()) + { + var existingCiphers = await _cipherRepository.GetManyByUserIdAsync(user.Id); + ciphers.AddRange(existingCiphers + .Join(model.Ciphers, c => c.Id, c => c.Id, (existing, c) => c.ToCipher(existing))); + } + + var folders = new List(); + if (model.Folders.Any()) + { + var existingFolders = await _folderRepository.GetManyByUserIdAsync(user.Id); + folders.AddRange(existingFolders + .Join(model.Folders, f => f.Id, f => f.Id, (existing, f) => f.ToFolder(existing))); + } + + var sends = new List(); + if (model.Sends?.Any() == true) + { + var existingSends = await _sendRepository.GetManyByUserIdAsync(user.Id); + sends.AddRange(existingSends + .Join(model.Sends, s => s.Id, s => s.Id, (existing, s) => s.ToSend(existing, _sendService))); + } + + var result = await _userService.UpdateKeyAsync( + user, + model.MasterPasswordHash, + model.Key, + model.PrivateKey, + ciphers, + folders, + sends); + + if (result.Succeeded) + { + return; + } + + foreach (var error in result.Errors) + { + ModelState.AddModelError(string.Empty, error.Description); + } + + await Task.Delay(2000); + throw new BadRequestException(ModelState); + } + + [HttpPost("security-stamp")] + public async Task PostSecurityStamp([FromBody] SecretVerificationRequestModel model) + { + var user = await _userService.GetUserByPrincipalAsync(User); + if (user == null) + { + throw new UnauthorizedAccessException(); + } + + var result = await _userService.RefreshSecurityStampAsync(user, model.Secret); + if (result.Succeeded) + { + return; + } + + foreach (var error in result.Errors) + { + ModelState.AddModelError(string.Empty, error.Description); + } + + await Task.Delay(2000); + throw new BadRequestException(ModelState); + } + + [HttpGet("profile")] + public async Task GetProfile() + { + var user = await _userService.GetUserByPrincipalAsync(User); + if (user == null) + { + throw new UnauthorizedAccessException(); + } + + var organizationUserDetails = await _organizationUserRepository.GetManyDetailsByUserAsync(user.Id, + OrganizationUserStatusType.Confirmed); + var providerUserDetails = await _providerUserRepository.GetManyDetailsByUserAsync(user.Id, + ProviderUserStatusType.Confirmed); + var providerUserOrganizationDetails = + await _providerUserRepository.GetManyOrganizationDetailsByUserAsync(user.Id, + ProviderUserStatusType.Confirmed); + var response = new ProfileResponseModel(user, organizationUserDetails, providerUserDetails, + providerUserOrganizationDetails, await _userService.TwoFactorIsEnabledAsync(user), await _userService.HasPremiumFromOrganization(user)); + return response; + } + + [HttpGet("organizations")] + public async Task> GetOrganizations() + { + var userId = _userService.GetProperUserId(User); + var organizationUserDetails = await _organizationUserRepository.GetManyDetailsByUserAsync(userId.Value, + OrganizationUserStatusType.Confirmed); + var responseData = organizationUserDetails.Select(o => new ProfileOrganizationResponseModel(o)); + return new ListResponseModel(responseData); + } + + [HttpPut("profile")] + [HttpPost("profile")] + public async Task PutProfile([FromBody] UpdateProfileRequestModel model) + { + var user = await _userService.GetUserByPrincipalAsync(User); + if (user == null) + { + throw new UnauthorizedAccessException(); + } + + await _userService.SaveUserAsync(model.ToUser(user)); + var response = new ProfileResponseModel(user, null, null, null, await _userService.TwoFactorIsEnabledAsync(user), await _userService.HasPremiumFromOrganization(user)); + return response; + } + + [HttpGet("revision-date")] + public async Task GetAccountRevisionDate() + { + var userId = _userService.GetProperUserId(User); + long? revisionDate = null; + if (userId.HasValue) + { + var date = await _userService.GetAccountRevisionDateByIdAsync(userId.Value); + revisionDate = CoreHelpers.ToEpocMilliseconds(date); + } + + return revisionDate; + } + + [HttpPost("keys")] + public async Task PostKeys([FromBody] KeysRequestModel model) + { + var user = await _userService.GetUserByPrincipalAsync(User); + if (user == null) + { + throw new UnauthorizedAccessException(); + } + + await _userService.SaveUserAsync(model.ToUser(user)); + return new KeysResponseModel(user); + } + + [HttpGet("keys")] + public async Task GetKeys() + { + var user = await _userService.GetUserByPrincipalAsync(User); + if (user == null) + { + throw new UnauthorizedAccessException(); + } + + return new KeysResponseModel(user); + } + + [HttpDelete] + [HttpPost("delete")] + public async Task Delete([FromBody] SecretVerificationRequestModel model) + { + var user = await _userService.GetUserByPrincipalAsync(User); + if (user == null) + { + throw new UnauthorizedAccessException(); + } + + if (!await _userService.VerifySecretAsync(user, model.Secret)) + { + ModelState.AddModelError(string.Empty, "User verification failed."); + await Task.Delay(2000); + } + else + { + var result = await _userService.DeleteAsync(user); + if (result.Succeeded) + { + return; + } + + foreach (var error in result.Errors) + { + ModelState.AddModelError(string.Empty, error.Description); + } + } + + throw new BadRequestException(ModelState); + } + + [AllowAnonymous] + [HttpPost("delete-recover")] + public async Task PostDeleteRecover([FromBody] DeleteRecoverRequestModel model) + { + await _userService.SendDeleteConfirmationAsync(model.Email); + } + + [HttpPost("delete-recover-token")] + [AllowAnonymous] + public async Task PostDeleteRecoverToken([FromBody] VerifyDeleteRecoverRequestModel model) + { + var user = await _userService.GetUserByIdAsync(new Guid(model.UserId)); + if (user == null) + { + throw new UnauthorizedAccessException(); + } + + var result = await _userService.DeleteAsync(user, model.Token); + if (result.Succeeded) + { + return; + } + + foreach (var error in result.Errors) + { + ModelState.AddModelError(string.Empty, error.Description); + } + + await Task.Delay(2000); + throw new BadRequestException(ModelState); + } + + [HttpPost("iap-check")] + public async Task PostIapCheck([FromBody] IapCheckRequestModel model) + { + var user = await _userService.GetUserByPrincipalAsync(User); + if (user == null) + { + throw new UnauthorizedAccessException(); + } + await _userService.IapCheckAsync(user, model.PaymentMethodType.Value); + } + + [HttpPost("premium")] + public async Task PostPremium(PremiumRequestModel model) + { + var user = await _userService.GetUserByPrincipalAsync(User); + if (user == null) + { + throw new UnauthorizedAccessException(); + } + + var valid = model.Validate(_globalSettings); + UserLicense license = null; + if (valid && _globalSettings.SelfHosted) + { + license = await ApiHelpers.ReadJsonFileFromBody(HttpContext, model.License); + } + + if (!valid && !_globalSettings.SelfHosted && string.IsNullOrWhiteSpace(model.Country)) + { + throw new BadRequestException("Country is required."); + } + + if (!valid || (_globalSettings.SelfHosted && license == null)) + { + throw new BadRequestException("Invalid license."); + } + + var result = await _userService.SignUpPremiumAsync(user, model.PaymentToken, + model.PaymentMethodType.Value, model.AdditionalStorageGb.GetValueOrDefault(0), license, + new TaxInfo + { + BillingAddressCountry = model.Country, + BillingAddressPostalCode = model.PostalCode, + }); + var profile = new ProfileResponseModel(user, null, null, null, await _userService.TwoFactorIsEnabledAsync(user), await _userService.HasPremiumFromOrganization(user)); + return new PaymentResponseModel + { + UserProfile = profile, + PaymentIntentClientSecret = result.Item2, + Success = result.Item1 + }; + } + + [Obsolete("2022-04-01 Use separate Billing History/Payment APIs, left for backwards compatability with older clients")] + [HttpGet("billing")] + [SelfHosted(NotSelfHostedOnly = true)] + public async Task GetBilling() + { + var user = await _userService.GetUserByPrincipalAsync(User); + if (user == null) + { + throw new UnauthorizedAccessException(); + } + + var billingInfo = await _paymentService.GetBillingAsync(user); + return new BillingResponseModel(billingInfo); + } + + [HttpGet("subscription")] + public async Task GetSubscription() + { + var user = await _userService.GetUserByPrincipalAsync(User); + if (user == null) + { + throw new UnauthorizedAccessException(); + } + + if (!_globalSettings.SelfHosted && user.Gateway != null) + { + var subscriptionInfo = await _paymentService.GetSubscriptionAsync(user); + var license = await _userService.GenerateLicenseAsync(user, subscriptionInfo); + return new SubscriptionResponseModel(user, subscriptionInfo, license); + } + else if (!_globalSettings.SelfHosted) + { + var license = await _userService.GenerateLicenseAsync(user); + return new SubscriptionResponseModel(user, license); + } + else + { + return new SubscriptionResponseModel(user); + } + } + + [HttpPost("payment")] + [SelfHosted(NotSelfHostedOnly = true)] + public async Task PostPayment([FromBody] PaymentRequestModel model) + { + var user = await _userService.GetUserByPrincipalAsync(User); + if (user == null) + { + throw new UnauthorizedAccessException(); + } + + await _userService.ReplacePaymentMethodAsync(user, model.PaymentToken, model.PaymentMethodType.Value, + new TaxInfo + { + BillingAddressCountry = model.Country, + BillingAddressPostalCode = model.PostalCode, + }); + } + + [HttpPost("storage")] + [SelfHosted(NotSelfHostedOnly = true)] + public async Task PostStorage([FromBody] StorageRequestModel model) + { + var user = await _userService.GetUserByPrincipalAsync(User); + if (user == null) + { + throw new UnauthorizedAccessException(); + } + + var result = await _userService.AdjustStorageAsync(user, model.StorageGbAdjustment.Value); + return new PaymentResponseModel + { + Success = true, + PaymentIntentClientSecret = result + }; + } + + [HttpPost("license")] + [SelfHosted(SelfHostedOnly = true)] + public async Task PostLicense(LicenseRequestModel model) + { + var user = await _userService.GetUserByPrincipalAsync(User); + if (user == null) + { + throw new UnauthorizedAccessException(); + } + + var license = await ApiHelpers.ReadJsonFileFromBody(HttpContext, model.License); + if (license == null) + { + throw new BadRequestException("Invalid license"); + } + + await _userService.UpdateLicenseAsync(user, license); + } + + [HttpPost("cancel-premium")] + [SelfHosted(NotSelfHostedOnly = true)] + public async Task PostCancel() + { + var user = await _userService.GetUserByPrincipalAsync(User); + if (user == null) + { + throw new UnauthorizedAccessException(); + } + + await _userService.CancelPremiumAsync(user); + } + + [HttpPost("reinstate-premium")] + [SelfHosted(NotSelfHostedOnly = true)] + public async Task PostReinstate() + { + var user = await _userService.GetUserByPrincipalAsync(User); + if (user == null) + { + throw new UnauthorizedAccessException(); + } + + await _userService.ReinstatePremiumAsync(user); + } + + [HttpGet("tax")] + [SelfHosted(NotSelfHostedOnly = true)] + public async Task GetTaxInfo() + { + var user = await _userService.GetUserByPrincipalAsync(User); + if (user == null) + { + throw new UnauthorizedAccessException(); + } + + var taxInfo = await _paymentService.GetTaxInfoAsync(user); + return new TaxInfoResponseModel(taxInfo); + } + + [HttpPut("tax")] + [SelfHosted(NotSelfHostedOnly = true)] + public async Task PutTaxInfo([FromBody] TaxInfoUpdateRequestModel model) + { + var user = await _userService.GetUserByPrincipalAsync(User); + if (user == null) + { + throw new UnauthorizedAccessException(); + } + + var taxInfo = new TaxInfo { - BillingAddressCountry = model.Country, BillingAddressPostalCode = model.PostalCode, - }); - } - - [HttpPost("storage")] - [SelfHosted(NotSelfHostedOnly = true)] - public async Task PostStorage([FromBody] StorageRequestModel model) - { - var user = await _userService.GetUserByPrincipalAsync(User); - if (user == null) - { - throw new UnauthorizedAccessException(); + BillingAddressCountry = model.Country, + }; + await _paymentService.SaveTaxInfoAsync(user, taxInfo); } - var result = await _userService.AdjustStorageAsync(user, model.StorageGbAdjustment.Value); - return new PaymentResponseModel + [HttpDelete("sso/{organizationId}")] + public async Task DeleteSsoUser(string organizationId) { - Success = true, - PaymentIntentClientSecret = result - }; - } + var userId = _userService.GetProperUserId(User); + if (!userId.HasValue) + { + throw new NotFoundException(); + } - [HttpPost("license")] - [SelfHosted(SelfHostedOnly = true)] - public async Task PostLicense(LicenseRequestModel model) - { - var user = await _userService.GetUserByPrincipalAsync(User); - if (user == null) - { - throw new UnauthorizedAccessException(); + await _organizationService.DeleteSsoUserAsync(userId.Value, new Guid(organizationId)); } - var license = await ApiHelpers.ReadJsonFileFromBody(HttpContext, model.License); - if (license == null) + [HttpGet("sso/user-identifier")] + public async Task GetSsoUserIdentifier() { - throw new BadRequestException("Invalid license"); + var user = await _userService.GetUserByPrincipalAsync(User); + var token = await _userService.GenerateSignInTokenAsync(user, TokenPurposes.LinkSso); + var userIdentifier = $"{user.Id},{token}"; + return userIdentifier; } - await _userService.UpdateLicenseAsync(user, license); - } - - [HttpPost("cancel-premium")] - [SelfHosted(NotSelfHostedOnly = true)] - public async Task PostCancel() - { - var user = await _userService.GetUserByPrincipalAsync(User); - if (user == null) + [HttpPost("api-key")] + public async Task ApiKey([FromBody] SecretVerificationRequestModel model) { - throw new UnauthorizedAccessException(); + var user = await _userService.GetUserByPrincipalAsync(User); + if (user == null) + { + throw new UnauthorizedAccessException(); + } + + if (!await _userService.VerifySecretAsync(user, model.Secret)) + { + await Task.Delay(2000); + throw new BadRequestException(string.Empty, "User verification failed."); + } + + return new ApiKeyResponseModel(user); } - await _userService.CancelPremiumAsync(user); - } - - [HttpPost("reinstate-premium")] - [SelfHosted(NotSelfHostedOnly = true)] - public async Task PostReinstate() - { - var user = await _userService.GetUserByPrincipalAsync(User); - if (user == null) + [HttpPost("rotate-api-key")] + public async Task RotateApiKey([FromBody] SecretVerificationRequestModel model) { - throw new UnauthorizedAccessException(); + var user = await _userService.GetUserByPrincipalAsync(User); + if (user == null) + { + throw new UnauthorizedAccessException(); + } + + if (!await _userService.VerifySecretAsync(user, model.Secret)) + { + await Task.Delay(2000); + throw new BadRequestException(string.Empty, "User verification failed."); + } + + await _userService.RotateApiKeyAsync(user); + var response = new ApiKeyResponseModel(user); + return response; } - await _userService.ReinstatePremiumAsync(user); - } - - [HttpGet("tax")] - [SelfHosted(NotSelfHostedOnly = true)] - public async Task GetTaxInfo() - { - var user = await _userService.GetUserByPrincipalAsync(User); - if (user == null) + [HttpPut("update-temp-password")] + public async Task PutUpdateTempPasswordAsync([FromBody] UpdateTempPasswordRequestModel model) { - throw new UnauthorizedAccessException(); + var user = await _userService.GetUserByPrincipalAsync(User); + if (user == null) + { + throw new UnauthorizedAccessException(); + } + + var result = await _userService.UpdateTempPasswordAsync(user, model.NewMasterPasswordHash, model.Key, model.MasterPasswordHint); + if (result.Succeeded) + { + return; + } + + foreach (var error in result.Errors) + { + ModelState.AddModelError(string.Empty, error.Description); + } + + throw new BadRequestException(ModelState); } - var taxInfo = await _paymentService.GetTaxInfoAsync(user); - return new TaxInfoResponseModel(taxInfo); - } - - [HttpPut("tax")] - [SelfHosted(NotSelfHostedOnly = true)] - public async Task PutTaxInfo([FromBody] TaxInfoUpdateRequestModel model) - { - var user = await _userService.GetUserByPrincipalAsync(User); - if (user == null) + [HttpPost("request-otp")] + public async Task PostRequestOTP() { - throw new UnauthorizedAccessException(); + var user = await _userService.GetUserByPrincipalAsync(User); + if (user is not { UsesKeyConnector: true }) + { + throw new UnauthorizedAccessException(); + } + + await _userService.SendOTPAsync(user); } - var taxInfo = new TaxInfo + [HttpPost("verify-otp")] + public async Task VerifyOTP([FromBody] VerifyOTPRequestModel model) { - BillingAddressPostalCode = model.PostalCode, - BillingAddressCountry = model.Country, - }; - await _paymentService.SaveTaxInfoAsync(user, taxInfo); - } + var user = await _userService.GetUserByPrincipalAsync(User); + if (user is not { UsesKeyConnector: true }) + { + throw new UnauthorizedAccessException(); + } - [HttpDelete("sso/{organizationId}")] - public async Task DeleteSsoUser(string organizationId) - { - var userId = _userService.GetProperUserId(User); - if (!userId.HasValue) - { - throw new NotFoundException(); - } - - await _organizationService.DeleteSsoUserAsync(userId.Value, new Guid(organizationId)); - } - - [HttpGet("sso/user-identifier")] - public async Task GetSsoUserIdentifier() - { - var user = await _userService.GetUserByPrincipalAsync(User); - var token = await _userService.GenerateSignInTokenAsync(user, TokenPurposes.LinkSso); - var userIdentifier = $"{user.Id},{token}"; - return userIdentifier; - } - - [HttpPost("api-key")] - public async Task ApiKey([FromBody] SecretVerificationRequestModel model) - { - var user = await _userService.GetUserByPrincipalAsync(User); - if (user == null) - { - throw new UnauthorizedAccessException(); - } - - if (!await _userService.VerifySecretAsync(user, model.Secret)) - { - await Task.Delay(2000); - throw new BadRequestException(string.Empty, "User verification failed."); - } - - return new ApiKeyResponseModel(user); - } - - [HttpPost("rotate-api-key")] - public async Task RotateApiKey([FromBody] SecretVerificationRequestModel model) - { - var user = await _userService.GetUserByPrincipalAsync(User); - if (user == null) - { - throw new UnauthorizedAccessException(); - } - - if (!await _userService.VerifySecretAsync(user, model.Secret)) - { - await Task.Delay(2000); - throw new BadRequestException(string.Empty, "User verification failed."); - } - - await _userService.RotateApiKeyAsync(user); - var response = new ApiKeyResponseModel(user); - return response; - } - - [HttpPut("update-temp-password")] - public async Task PutUpdateTempPasswordAsync([FromBody] UpdateTempPasswordRequestModel model) - { - var user = await _userService.GetUserByPrincipalAsync(User); - if (user == null) - { - throw new UnauthorizedAccessException(); - } - - var result = await _userService.UpdateTempPasswordAsync(user, model.NewMasterPasswordHash, model.Key, model.MasterPasswordHint); - if (result.Succeeded) - { - return; - } - - foreach (var error in result.Errors) - { - ModelState.AddModelError(string.Empty, error.Description); - } - - throw new BadRequestException(ModelState); - } - - [HttpPost("request-otp")] - public async Task PostRequestOTP() - { - var user = await _userService.GetUserByPrincipalAsync(User); - if (user is not { UsesKeyConnector: true }) - { - throw new UnauthorizedAccessException(); - } - - await _userService.SendOTPAsync(user); - } - - [HttpPost("verify-otp")] - public async Task VerifyOTP([FromBody] VerifyOTPRequestModel model) - { - var user = await _userService.GetUserByPrincipalAsync(User); - if (user is not { UsesKeyConnector: true }) - { - throw new UnauthorizedAccessException(); - } - - if (!await _userService.VerifyOTPAsync(user, model.OTP)) - { - await Task.Delay(2000); - throw new BadRequestException("Token", "Invalid token"); + if (!await _userService.VerifyOTPAsync(user, model.OTP)) + { + await Task.Delay(2000); + throw new BadRequestException("Token", "Invalid token"); + } } } } diff --git a/src/Api/Controllers/CiphersController.cs b/src/Api/Controllers/CiphersController.cs index 5b059a332b..f5831acaa4 100644 --- a/src/Api/Controllers/CiphersController.cs +++ b/src/Api/Controllers/CiphersController.cs @@ -18,788 +18,789 @@ using Core.Models.Data; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; -namespace Bit.Api.Controllers; - -[Route("ciphers")] -[Authorize("Application")] -public class CiphersController : Controller +namespace Bit.Api.Controllers { - private readonly ICipherRepository _cipherRepository; - private readonly ICollectionCipherRepository _collectionCipherRepository; - private readonly ICipherService _cipherService; - private readonly IUserService _userService; - private readonly IAttachmentStorageService _attachmentStorageService; - private readonly IProviderService _providerService; - private readonly ICurrentContext _currentContext; - private readonly ILogger _logger; - private readonly GlobalSettings _globalSettings; - - public CiphersController( - ICipherRepository cipherRepository, - ICollectionCipherRepository collectionCipherRepository, - ICipherService cipherService, - IUserService userService, - IAttachmentStorageService attachmentStorageService, - IProviderService providerService, - ICurrentContext currentContext, - ILogger logger, - GlobalSettings globalSettings) + [Route("ciphers")] + [Authorize("Application")] + public class CiphersController : Controller { - _cipherRepository = cipherRepository; - _collectionCipherRepository = collectionCipherRepository; - _cipherService = cipherService; - _userService = userService; - _attachmentStorageService = attachmentStorageService; - _providerService = providerService; - _currentContext = currentContext; - _logger = logger; - _globalSettings = globalSettings; - } + private readonly ICipherRepository _cipherRepository; + private readonly ICollectionCipherRepository _collectionCipherRepository; + private readonly ICipherService _cipherService; + private readonly IUserService _userService; + private readonly IAttachmentStorageService _attachmentStorageService; + private readonly IProviderService _providerService; + private readonly ICurrentContext _currentContext; + private readonly ILogger _logger; + private readonly GlobalSettings _globalSettings; - [HttpGet("{id}")] - public async Task Get(string id) - { - var userId = _userService.GetProperUserId(User).Value; - var cipher = await _cipherRepository.GetByIdAsync(new Guid(id), userId); - if (cipher == null) + public CiphersController( + ICipherRepository cipherRepository, + ICollectionCipherRepository collectionCipherRepository, + ICipherService cipherService, + IUserService userService, + IAttachmentStorageService attachmentStorageService, + IProviderService providerService, + ICurrentContext currentContext, + ILogger logger, + GlobalSettings globalSettings) { - throw new NotFoundException(); + _cipherRepository = cipherRepository; + _collectionCipherRepository = collectionCipherRepository; + _cipherService = cipherService; + _userService = userService; + _attachmentStorageService = attachmentStorageService; + _providerService = providerService; + _currentContext = currentContext; + _logger = logger; + _globalSettings = globalSettings; } - return new CipherResponseModel(cipher, _globalSettings); - } - - [HttpGet("{id}/admin")] - public async Task GetAdmin(string id) - { - var cipher = await _cipherRepository.GetOrganizationDetailsByIdAsync(new Guid(id)); - if (cipher == null || !cipher.OrganizationId.HasValue || - !await _currentContext.ViewAllCollections(cipher.OrganizationId.Value)) + [HttpGet("{id}")] + public async Task Get(string id) { - throw new NotFoundException(); - } - - return new CipherMiniResponseModel(cipher, _globalSettings, cipher.OrganizationUseTotp); - } - - [HttpGet("{id}/full-details")] - [HttpGet("{id}/details")] - public async Task GetDetails(string id) - { - var userId = _userService.GetProperUserId(User).Value; - var cipherId = new Guid(id); - var cipher = await _cipherRepository.GetByIdAsync(cipherId, userId); - if (cipher == null) - { - throw new NotFoundException(); - } - - var collectionCiphers = await _collectionCipherRepository.GetManyByUserIdCipherIdAsync(userId, cipherId); - return new CipherDetailsResponseModel(cipher, _globalSettings, collectionCiphers); - } - - [HttpGet("")] - public async Task> Get() - { - var userId = _userService.GetProperUserId(User).Value; - var hasOrgs = _currentContext.Organizations?.Any() ?? false; - // TODO: Use hasOrgs proper for cipher listing here? - var ciphers = await _cipherRepository.GetManyByUserIdAsync(userId, true || hasOrgs); - Dictionary> collectionCiphersGroupDict = null; - if (hasOrgs) - { - var collectionCiphers = await _collectionCipherRepository.GetManyByUserIdAsync(userId); - collectionCiphersGroupDict = collectionCiphers.GroupBy(c => c.CipherId).ToDictionary(s => s.Key); - } - - var responses = ciphers.Select(c => new CipherDetailsResponseModel(c, _globalSettings, - collectionCiphersGroupDict)).ToList(); - return new ListResponseModel(responses); - } - - [HttpPost("")] - public async Task Post([FromBody] CipherRequestModel model) - { - var userId = _userService.GetProperUserId(User).Value; - var cipher = model.ToCipherDetails(userId); - if (cipher.OrganizationId.HasValue && !await _currentContext.OrganizationUser(cipher.OrganizationId.Value)) - { - throw new NotFoundException(); - } - - await _cipherService.SaveDetailsAsync(cipher, userId, model.LastKnownRevisionDate, null, cipher.OrganizationId.HasValue); - var response = new CipherResponseModel(cipher, _globalSettings); - return response; - } - - [HttpPost("create")] - public async Task PostCreate([FromBody] CipherCreateRequestModel model) - { - var userId = _userService.GetProperUserId(User).Value; - var cipher = model.Cipher.ToCipherDetails(userId); - if (cipher.OrganizationId.HasValue && !await _currentContext.OrganizationUser(cipher.OrganizationId.Value)) - { - throw new NotFoundException(); - } - - await _cipherService.SaveDetailsAsync(cipher, userId, model.Cipher.LastKnownRevisionDate, model.CollectionIds, cipher.OrganizationId.HasValue); - var response = new CipherResponseModel(cipher, _globalSettings); - return response; - } - - [HttpPost("admin")] - public async Task PostAdmin([FromBody] CipherCreateRequestModel model) - { - var cipher = model.Cipher.ToOrganizationCipher(); - if (!await _currentContext.EditAnyCollection(cipher.OrganizationId.Value)) - { - throw new NotFoundException(); - } - - var userId = _userService.GetProperUserId(User).Value; - await _cipherService.SaveAsync(cipher, userId, model.Cipher.LastKnownRevisionDate, model.CollectionIds, true, false); - - var response = new CipherMiniResponseModel(cipher, _globalSettings, false); - return response; - } - - [HttpPut("{id}")] - [HttpPost("{id}")] - public async Task Put(Guid id, [FromBody] CipherRequestModel model) - { - var userId = _userService.GetProperUserId(User).Value; - var cipher = await _cipherRepository.GetByIdAsync(id, userId); - if (cipher == null) - { - throw new NotFoundException(); - } - - var collectionIds = (await _collectionCipherRepository.GetManyByUserIdCipherIdAsync(userId, id)).Select(c => c.CollectionId).ToList(); - var modelOrgId = string.IsNullOrWhiteSpace(model.OrganizationId) ? - (Guid?)null : new Guid(model.OrganizationId); - if (cipher.OrganizationId != modelOrgId) - { - throw new BadRequestException("Organization mismatch. Re-sync if you recently moved this item, " + - "then try again."); - } - - await _cipherService.SaveDetailsAsync(model.ToCipherDetails(cipher), userId, model.LastKnownRevisionDate, collectionIds); - - var response = new CipherResponseModel(cipher, _globalSettings); - return response; - } - - [HttpPut("{id}/admin")] - [HttpPost("{id}/admin")] - public async Task PutAdmin(Guid id, [FromBody] CipherRequestModel model) - { - var userId = _userService.GetProperUserId(User).Value; - var cipher = await _cipherRepository.GetOrganizationDetailsByIdAsync(id); - if (cipher == null || !cipher.OrganizationId.HasValue || - !await _currentContext.EditAnyCollection(cipher.OrganizationId.Value)) - { - throw new NotFoundException(); - } - - var collectionIds = (await _collectionCipherRepository.GetManyByUserIdCipherIdAsync(userId, id)).Select(c => c.CollectionId).ToList(); - // object cannot be a descendant of CipherDetails, so let's clone it. - var cipherClone = model.ToCipher(cipher).Clone(); - await _cipherService.SaveAsync(cipherClone, userId, model.LastKnownRevisionDate, collectionIds, true, false); - - var response = new CipherMiniResponseModel(cipherClone, _globalSettings, cipher.OrganizationUseTotp); - return response; - } - - [HttpGet("organization-details")] - public async Task> GetOrganizationCollections( - string organizationId) - { - var userId = _userService.GetProperUserId(User).Value; - var orgIdGuid = new Guid(organizationId); - - (IEnumerable orgCiphers, Dictionary> collectionCiphersGroupDict) = await _cipherService.GetOrganizationCiphers(userId, orgIdGuid); - - var responses = orgCiphers.Select(c => new CipherMiniDetailsResponseModel(c, _globalSettings, - collectionCiphersGroupDict, c.OrganizationUseTotp)); - - return new ListResponseModel(responses); - } - - [HttpPost("import")] - public async Task PostImport([FromBody] ImportCiphersRequestModel model) - { - if (!_globalSettings.SelfHosted && - (model.Ciphers.Count() > 6000 || model.FolderRelationships.Count() > 6000 || - model.Folders.Count() > 1000)) - { - throw new BadRequestException("You cannot import this much data at once."); - } - - var userId = _userService.GetProperUserId(User).Value; - var folders = model.Folders.Select(f => f.ToFolder(userId)).ToList(); - var ciphers = model.Ciphers.Select(c => c.ToCipherDetails(userId, false)).ToList(); - await _cipherService.ImportCiphersAsync(folders, ciphers, model.FolderRelationships); - } - - [HttpPost("import-organization")] - public async Task PostImport([FromQuery] string organizationId, - [FromBody] ImportOrganizationCiphersRequestModel model) - { - if (!_globalSettings.SelfHosted && - (model.Ciphers.Count() > 6000 || model.CollectionRelationships.Count() > 12000 || - model.Collections.Count() > 1000)) - { - throw new BadRequestException("You cannot import this much data at once."); - } - - var orgId = new Guid(organizationId); - if (!await _currentContext.AccessImportExport(orgId)) - { - throw new NotFoundException(); - } - - var userId = _userService.GetProperUserId(User).Value; - var collections = model.Collections.Select(c => c.ToCollection(orgId)).ToList(); - var ciphers = model.Ciphers.Select(l => l.ToOrganizationCipherDetails(orgId)).ToList(); - await _cipherService.ImportCiphersAsync(collections, ciphers, model.CollectionRelationships, userId); - } - - [HttpPut("{id}/partial")] - [HttpPost("{id}/partial")] - public async Task PutPartial(string id, [FromBody] CipherPartialRequestModel model) - { - var userId = _userService.GetProperUserId(User).Value; - var folderId = string.IsNullOrWhiteSpace(model.FolderId) ? null : (Guid?)new Guid(model.FolderId); - await _cipherRepository.UpdatePartialAsync(new Guid(id), userId, folderId, model.Favorite); - } - - [HttpPut("{id}/share")] - [HttpPost("{id}/share")] - public async Task PutShare(string id, [FromBody] CipherShareRequestModel model) - { - var userId = _userService.GetProperUserId(User).Value; - var cipherId = new Guid(id); - var cipher = await _cipherRepository.GetByIdAsync(cipherId); - if (cipher == null || cipher.UserId != userId || - !await _currentContext.OrganizationUser(new Guid(model.Cipher.OrganizationId))) - { - throw new NotFoundException(); - } - - var original = cipher.Clone(); - await _cipherService.ShareAsync(original, model.Cipher.ToCipher(cipher), new Guid(model.Cipher.OrganizationId), - model.CollectionIds.Select(c => new Guid(c)), userId, model.Cipher.LastKnownRevisionDate); - - var sharedCipher = await _cipherRepository.GetByIdAsync(cipherId, userId); - var response = new CipherResponseModel(sharedCipher, _globalSettings); - return response; - } - - [HttpPut("{id}/collections")] - [HttpPost("{id}/collections")] - public async Task PutCollections(string id, [FromBody] CipherCollectionsRequestModel model) - { - var userId = _userService.GetProperUserId(User).Value; - var cipher = await _cipherRepository.GetByIdAsync(new Guid(id), userId); - if (cipher == null || !cipher.OrganizationId.HasValue || - !await _currentContext.OrganizationUser(cipher.OrganizationId.Value)) - { - throw new NotFoundException(); - } - - await _cipherService.SaveCollectionsAsync(cipher, - model.CollectionIds.Select(c => new Guid(c)), userId, false); - } - - [HttpPut("{id}/collections-admin")] - [HttpPost("{id}/collections-admin")] - public async Task PutCollectionsAdmin(string id, [FromBody] CipherCollectionsRequestModel model) - { - var userId = _userService.GetProperUserId(User).Value; - var cipher = await _cipherRepository.GetByIdAsync(new Guid(id)); - if (cipher == null || !cipher.OrganizationId.HasValue || - !await _currentContext.EditAnyCollection(cipher.OrganizationId.Value)) - { - throw new NotFoundException(); - } - - await _cipherService.SaveCollectionsAsync(cipher, - model.CollectionIds.Select(c => new Guid(c)), userId, true); - } - - [HttpDelete("{id}")] - [HttpPost("{id}/delete")] - public async Task Delete(string id) - { - var userId = _userService.GetProperUserId(User).Value; - var cipher = await _cipherRepository.GetByIdAsync(new Guid(id), userId); - if (cipher == null) - { - throw new NotFoundException(); - } - - await _cipherService.DeleteAsync(cipher, userId); - } - - [HttpDelete("{id}/admin")] - [HttpPost("{id}/delete-admin")] - public async Task DeleteAdmin(string id) - { - var userId = _userService.GetProperUserId(User).Value; - var cipher = await _cipherRepository.GetByIdAsync(new Guid(id)); - if (cipher == null || !cipher.OrganizationId.HasValue || - !await _currentContext.EditAnyCollection(cipher.OrganizationId.Value)) - { - throw new NotFoundException(); - } - - await _cipherService.DeleteAsync(cipher, userId, true); - } - - [HttpDelete("")] - [HttpPost("delete")] - public async Task DeleteMany([FromBody] CipherBulkDeleteRequestModel model) - { - if (!_globalSettings.SelfHosted && model.Ids.Count() > 500) - { - throw new BadRequestException("You can only delete up to 500 items at a time. " + - "Consider using the \"Purge Vault\" option instead."); - } - - var userId = _userService.GetProperUserId(User).Value; - await _cipherService.DeleteManyAsync(model.Ids.Select(i => new Guid(i)), userId); - } - - [HttpDelete("admin")] - [HttpPost("delete-admin")] - public async Task DeleteManyAdmin([FromBody] CipherBulkDeleteRequestModel model) - { - if (!_globalSettings.SelfHosted && model.Ids.Count() > 500) - { - throw new BadRequestException("You can only delete up to 500 items at a time. " + - "Consider using the \"Purge Vault\" option instead."); - } - - if (model == null || string.IsNullOrWhiteSpace(model.OrganizationId) || - !await _currentContext.EditAnyCollection(new Guid(model.OrganizationId))) - { - throw new NotFoundException(); - } - - var userId = _userService.GetProperUserId(User).Value; - await _cipherService.DeleteManyAsync(model.Ids.Select(i => new Guid(i)), userId, new Guid(model.OrganizationId), true); - } - - [HttpPut("{id}/delete")] - public async Task PutDelete(string id) - { - var userId = _userService.GetProperUserId(User).Value; - var cipher = await _cipherRepository.GetByIdAsync(new Guid(id), userId); - if (cipher == null) - { - throw new NotFoundException(); - } - await _cipherService.SoftDeleteAsync(cipher, userId); - } - - [HttpPut("{id}/delete-admin")] - public async Task PutDeleteAdmin(string id) - { - var userId = _userService.GetProperUserId(User).Value; - var cipher = await _cipherRepository.GetByIdAsync(new Guid(id)); - if (cipher == null || !cipher.OrganizationId.HasValue || - !await _currentContext.EditAnyCollection(cipher.OrganizationId.Value)) - { - throw new NotFoundException(); - } - - await _cipherService.SoftDeleteAsync(cipher, userId, true); - } - - [HttpPut("delete")] - public async Task PutDeleteMany([FromBody] CipherBulkDeleteRequestModel model) - { - if (!_globalSettings.SelfHosted && model.Ids.Count() > 500) - { - throw new BadRequestException("You can only delete up to 500 items at a time."); - } - - var userId = _userService.GetProperUserId(User).Value; - await _cipherService.SoftDeleteManyAsync(model.Ids.Select(i => new Guid(i)), userId); - } - - [HttpPut("delete-admin")] - public async Task PutDeleteManyAdmin([FromBody] CipherBulkDeleteRequestModel model) - { - if (!_globalSettings.SelfHosted && model.Ids.Count() > 500) - { - throw new BadRequestException("You can only delete up to 500 items at a time."); - } - - if (model == null || string.IsNullOrWhiteSpace(model.OrganizationId) || - !await _currentContext.EditAnyCollection(new Guid(model.OrganizationId))) - { - throw new NotFoundException(); - } - - var userId = _userService.GetProperUserId(User).Value; - await _cipherService.SoftDeleteManyAsync(model.Ids.Select(i => new Guid(i)), userId, new Guid(model.OrganizationId), true); - } - - [HttpPut("{id}/restore")] - public async Task PutRestore(string id) - { - var userId = _userService.GetProperUserId(User).Value; - var cipher = await _cipherRepository.GetByIdAsync(new Guid(id), userId); - if (cipher == null) - { - throw new NotFoundException(); - } - - await _cipherService.RestoreAsync(cipher, userId); - return new CipherResponseModel(cipher, _globalSettings); - } - - [HttpPut("{id}/restore-admin")] - public async Task PutRestoreAdmin(string id) - { - var userId = _userService.GetProperUserId(User).Value; - var cipher = await _cipherRepository.GetOrganizationDetailsByIdAsync(new Guid(id)); - if (cipher == null || !cipher.OrganizationId.HasValue || - !await _currentContext.EditAnyCollection(cipher.OrganizationId.Value)) - { - throw new NotFoundException(); - } - - await _cipherService.RestoreAsync(cipher, userId, true); - return new CipherMiniResponseModel(cipher, _globalSettings, cipher.OrganizationUseTotp); - } - - [HttpPut("restore")] - public async Task> PutRestoreMany([FromBody] CipherBulkRestoreRequestModel model) - { - if (!_globalSettings.SelfHosted && model.Ids.Count() > 500) - { - throw new BadRequestException("You can only restore up to 500 items at a time."); - } - - var userId = _userService.GetProperUserId(User).Value; - var cipherIdsToRestore = new HashSet(model.Ids.Select(i => new Guid(i))); - - var ciphers = await _cipherRepository.GetManyByUserIdAsync(userId); - var restoringCiphers = ciphers.Where(c => cipherIdsToRestore.Contains(c.Id) && c.Edit); - - await _cipherService.RestoreManyAsync(restoringCiphers, userId); - var responses = restoringCiphers.Select(c => new CipherResponseModel(c, _globalSettings)); - return new ListResponseModel(responses); - } - - [HttpPut("move")] - [HttpPost("move")] - public async Task MoveMany([FromBody] CipherBulkMoveRequestModel model) - { - if (!_globalSettings.SelfHosted && model.Ids.Count() > 500) - { - throw new BadRequestException("You can only move up to 500 items at a time."); - } - - var userId = _userService.GetProperUserId(User).Value; - await _cipherService.MoveManyAsync(model.Ids.Select(i => new Guid(i)), - string.IsNullOrWhiteSpace(model.FolderId) ? (Guid?)null : new Guid(model.FolderId), userId); - } - - [HttpPut("share")] - [HttpPost("share")] - public async Task PutShareMany([FromBody] CipherBulkShareRequestModel model) - { - var organizationId = new Guid(model.Ciphers.First().OrganizationId); - if (!await _currentContext.OrganizationUser(organizationId)) - { - throw new NotFoundException(); - } - - var userId = _userService.GetProperUserId(User).Value; - var ciphers = await _cipherRepository.GetManyByUserIdAsync(userId, false); - var ciphersDict = ciphers.ToDictionary(c => c.Id); - - var shareCiphers = new List<(Cipher, DateTime?)>(); - foreach (var cipher in model.Ciphers) - { - if (!ciphersDict.ContainsKey(cipher.Id.Value)) - { - throw new BadRequestException("Trying to move ciphers that you do not own."); - } - - shareCiphers.Add((cipher.ToCipher(ciphersDict[cipher.Id.Value]), cipher.LastKnownRevisionDate)); - } - - await _cipherService.ShareManyAsync(shareCiphers, organizationId, - model.CollectionIds.Select(c => new Guid(c)), userId); - } - - [HttpPost("purge")] - public async Task PostPurge([FromBody] SecretVerificationRequestModel model, string organizationId = null) - { - var user = await _userService.GetUserByPrincipalAsync(User); - if (user == null) - { - throw new UnauthorizedAccessException(); - } - - if (!await _userService.VerifySecretAsync(user, model.Secret)) - { - ModelState.AddModelError(string.Empty, "User verification failed."); - await Task.Delay(2000); - throw new BadRequestException(ModelState); - } - - if (string.IsNullOrWhiteSpace(organizationId)) - { - await _cipherRepository.DeleteByUserIdAsync(user.Id); - } - else - { - var orgId = new Guid(organizationId); - if (!await _currentContext.EditAnyCollection(orgId)) + var userId = _userService.GetProperUserId(User).Value; + var cipher = await _cipherRepository.GetByIdAsync(new Guid(id), userId); + if (cipher == null) { throw new NotFoundException(); } - await _cipherService.PurgeAsync(orgId); - } - } - [HttpPost("{id}/attachment/v2")] - public async Task PostAttachment(string id, [FromBody] AttachmentRequestModel request) - { - var idGuid = new Guid(id); - var userId = _userService.GetProperUserId(User).Value; - var cipher = request.AdminRequest ? - await _cipherRepository.GetOrganizationDetailsByIdAsync(idGuid) : - await _cipherRepository.GetByIdAsync(idGuid, userId); - - if (cipher == null || (request.AdminRequest && (!cipher.OrganizationId.HasValue || - !await _currentContext.EditAnyCollection(cipher.OrganizationId.Value)))) - { - throw new NotFoundException(); + return new CipherResponseModel(cipher, _globalSettings); } - if (request.FileSize > CipherService.MAX_FILE_SIZE) - { - throw new BadRequestException($"Max file size is {CipherService.MAX_FILE_SIZE_READABLE}."); - } - - var (attachmentId, uploadUrl) = await _cipherService.CreateAttachmentForDelayedUploadAsync(cipher, - request.Key, request.FileName, request.FileSize, request.AdminRequest, userId); - return new AttachmentUploadDataResponseModel - { - AttachmentId = attachmentId, - Url = uploadUrl, - FileUploadType = _attachmentStorageService.FileUploadType, - CipherResponse = request.AdminRequest ? null : new CipherResponseModel((CipherDetails)cipher, _globalSettings), - CipherMiniResponse = request.AdminRequest ? new CipherMiniResponseModel(cipher, _globalSettings, cipher.OrganizationUseTotp) : null, - }; - } - - [HttpGet("{id}/attachment/{attachmentId}/renew")] - public async Task RenewFileUploadUrl(string id, string attachmentId) - { - var userId = _userService.GetProperUserId(User).Value; - var cipherId = new Guid(id); - var cipher = await _cipherRepository.GetByIdAsync(cipherId, userId); - var attachments = cipher?.GetAttachments(); - - if (attachments == null || !attachments.ContainsKey(attachmentId) || attachments[attachmentId].Validated) - { - throw new NotFoundException(); - } - - return new AttachmentUploadDataResponseModel - { - Url = await _attachmentStorageService.GetAttachmentUploadUrlAsync(cipher, attachments[attachmentId]), - FileUploadType = _attachmentStorageService.FileUploadType, - }; - } - - [HttpPost("{id}/attachment/{attachmentId}")] - [SelfHosted(SelfHostedOnly = true)] - [RequestSizeLimit(Constants.FileSize501mb)] - [DisableFormValueModelBinding] - public async Task PostFileForExistingAttachment(string id, string attachmentId) - { - if (!Request?.ContentType.Contains("multipart/") ?? true) - { - throw new BadRequestException("Invalid content."); - } - - var userId = _userService.GetProperUserId(User).Value; - var cipher = await _cipherRepository.GetByIdAsync(new Guid(id), userId); - var attachments = cipher?.GetAttachments(); - if (attachments == null || !attachments.ContainsKey(attachmentId)) - { - throw new NotFoundException(); - } - var attachmentData = attachments[attachmentId]; - - await Request.GetFileAsync(async (stream) => - { - await _cipherService.UploadFileForExistingAttachmentAsync(stream, cipher, attachmentData); - }); - } - - [HttpPost("{id}/attachment")] - [Obsolete("Deprecated Attachments API", false)] - [RequestSizeLimit(Constants.FileSize101mb)] - [DisableFormValueModelBinding] - public async Task PostAttachment(string id) - { - ValidateAttachment(); - - var idGuid = new Guid(id); - var userId = _userService.GetProperUserId(User).Value; - var cipher = await _cipherRepository.GetByIdAsync(idGuid, userId); - if (cipher == null) - { - throw new NotFoundException(); - } - - await Request.GetFileAsync(async (stream, fileName, key) => - { - await _cipherService.CreateAttachmentAsync(cipher, stream, fileName, key, - Request.ContentLength.GetValueOrDefault(0), userId); - }); - - return new CipherResponseModel(cipher, _globalSettings); - } - - [HttpPost("{id}/attachment-admin")] - [RequestSizeLimit(Constants.FileSize101mb)] - [DisableFormValueModelBinding] - public async Task PostAttachmentAdmin(string id) - { - ValidateAttachment(); - - var idGuid = new Guid(id); - var userId = _userService.GetProperUserId(User).Value; - var cipher = await _cipherRepository.GetOrganizationDetailsByIdAsync(idGuid); - if (cipher == null || !cipher.OrganizationId.HasValue || - !await _currentContext.EditAnyCollection(cipher.OrganizationId.Value)) - { - throw new NotFoundException(); - } - - await Request.GetFileAsync(async (stream, fileName, key) => - { - await _cipherService.CreateAttachmentAsync(cipher, stream, fileName, key, - Request.ContentLength.GetValueOrDefault(0), userId, true); - }); - - return new CipherMiniResponseModel(cipher, _globalSettings, cipher.OrganizationUseTotp); - } - - [HttpGet("{id}/attachment/{attachmentId}")] - public async Task GetAttachmentData(string id, string attachmentId) - { - var userId = _userService.GetProperUserId(User).Value; - var cipher = await _cipherRepository.GetByIdAsync(new Guid(id), userId); - var result = await _cipherService.GetAttachmentDownloadDataAsync(cipher, attachmentId); - return new AttachmentResponseModel(result); - } - - [HttpPost("{id}/attachment/{attachmentId}/share")] - [RequestSizeLimit(Constants.FileSize101mb)] - [DisableFormValueModelBinding] - public async Task PostAttachmentShare(string id, string attachmentId, Guid organizationId) - { - ValidateAttachment(); - - var userId = _userService.GetProperUserId(User).Value; - var cipher = await _cipherRepository.GetByIdAsync(new Guid(id)); - if (cipher == null || cipher.UserId != userId || !await _currentContext.OrganizationUser(organizationId)) - { - throw new NotFoundException(); - } - - await Request.GetFileAsync(async (stream, fileName, key) => - { - await _cipherService.CreateAttachmentShareAsync(cipher, stream, - Request.ContentLength.GetValueOrDefault(0), attachmentId, organizationId); - }); - } - - [HttpDelete("{id}/attachment/{attachmentId}")] - [HttpPost("{id}/attachment/{attachmentId}/delete")] - public async Task DeleteAttachment(string id, string attachmentId) - { - var idGuid = new Guid(id); - var userId = _userService.GetProperUserId(User).Value; - var cipher = await _cipherRepository.GetByIdAsync(idGuid, userId); - if (cipher == null) - { - throw new NotFoundException(); - } - - await _cipherService.DeleteAttachmentAsync(cipher, attachmentId, userId, false); - } - - [HttpDelete("{id}/attachment/{attachmentId}/admin")] - [HttpPost("{id}/attachment/{attachmentId}/delete-admin")] - public async Task DeleteAttachmentAdmin(string id, string attachmentId) - { - var idGuid = new Guid(id); - var userId = _userService.GetProperUserId(User).Value; - var cipher = await _cipherRepository.GetByIdAsync(idGuid); - if (cipher == null || !cipher.OrganizationId.HasValue || - !await _currentContext.EditAnyCollection(cipher.OrganizationId.Value)) - { - throw new NotFoundException(); - } - - await _cipherService.DeleteAttachmentAsync(cipher, attachmentId, userId, true); - } - - [AllowAnonymous] - [HttpPost("attachment/validate/azure")] - public async Task AzureValidateFile() - { - return await ApiHelpers.HandleAzureEvents(Request, new Dictionary> + [HttpGet("{id}/admin")] + public async Task GetAdmin(string id) { + var cipher = await _cipherRepository.GetOrganizationDetailsByIdAsync(new Guid(id)); + if (cipher == null || !cipher.OrganizationId.HasValue || + !await _currentContext.ViewAllCollections(cipher.OrganizationId.Value)) { - "Microsoft.Storage.BlobCreated", async (eventGridEvent) => - { - try - { - var blobName = eventGridEvent.Subject.Split($"{AzureAttachmentStorageService.EventGridEnabledContainerName}/blobs/")[1]; - var (cipherId, organizationId, attachmentId) = AzureAttachmentStorageService.IdentifiersFromBlobName(blobName); - var cipher = await _cipherRepository.GetByIdAsync(new Guid(cipherId)); - var attachments = cipher?.GetAttachments() ?? new Dictionary(); + throw new NotFoundException(); + } - if (cipher == null || !attachments.ContainsKey(attachmentId) || attachments[attachmentId].Validated) + return new CipherMiniResponseModel(cipher, _globalSettings, cipher.OrganizationUseTotp); + } + + [HttpGet("{id}/full-details")] + [HttpGet("{id}/details")] + public async Task GetDetails(string id) + { + var userId = _userService.GetProperUserId(User).Value; + var cipherId = new Guid(id); + var cipher = await _cipherRepository.GetByIdAsync(cipherId, userId); + if (cipher == null) + { + throw new NotFoundException(); + } + + var collectionCiphers = await _collectionCipherRepository.GetManyByUserIdCipherIdAsync(userId, cipherId); + return new CipherDetailsResponseModel(cipher, _globalSettings, collectionCiphers); + } + + [HttpGet("")] + public async Task> Get() + { + var userId = _userService.GetProperUserId(User).Value; + var hasOrgs = _currentContext.Organizations?.Any() ?? false; + // TODO: Use hasOrgs proper for cipher listing here? + var ciphers = await _cipherRepository.GetManyByUserIdAsync(userId, true || hasOrgs); + Dictionary> collectionCiphersGroupDict = null; + if (hasOrgs) + { + var collectionCiphers = await _collectionCipherRepository.GetManyByUserIdAsync(userId); + collectionCiphersGroupDict = collectionCiphers.GroupBy(c => c.CipherId).ToDictionary(s => s.Key); + } + + var responses = ciphers.Select(c => new CipherDetailsResponseModel(c, _globalSettings, + collectionCiphersGroupDict)).ToList(); + return new ListResponseModel(responses); + } + + [HttpPost("")] + public async Task Post([FromBody] CipherRequestModel model) + { + var userId = _userService.GetProperUserId(User).Value; + var cipher = model.ToCipherDetails(userId); + if (cipher.OrganizationId.HasValue && !await _currentContext.OrganizationUser(cipher.OrganizationId.Value)) + { + throw new NotFoundException(); + } + + await _cipherService.SaveDetailsAsync(cipher, userId, model.LastKnownRevisionDate, null, cipher.OrganizationId.HasValue); + var response = new CipherResponseModel(cipher, _globalSettings); + return response; + } + + [HttpPost("create")] + public async Task PostCreate([FromBody] CipherCreateRequestModel model) + { + var userId = _userService.GetProperUserId(User).Value; + var cipher = model.Cipher.ToCipherDetails(userId); + if (cipher.OrganizationId.HasValue && !await _currentContext.OrganizationUser(cipher.OrganizationId.Value)) + { + throw new NotFoundException(); + } + + await _cipherService.SaveDetailsAsync(cipher, userId, model.Cipher.LastKnownRevisionDate, model.CollectionIds, cipher.OrganizationId.HasValue); + var response = new CipherResponseModel(cipher, _globalSettings); + return response; + } + + [HttpPost("admin")] + public async Task PostAdmin([FromBody] CipherCreateRequestModel model) + { + var cipher = model.Cipher.ToOrganizationCipher(); + if (!await _currentContext.EditAnyCollection(cipher.OrganizationId.Value)) + { + throw new NotFoundException(); + } + + var userId = _userService.GetProperUserId(User).Value; + await _cipherService.SaveAsync(cipher, userId, model.Cipher.LastKnownRevisionDate, model.CollectionIds, true, false); + + var response = new CipherMiniResponseModel(cipher, _globalSettings, false); + return response; + } + + [HttpPut("{id}")] + [HttpPost("{id}")] + public async Task Put(Guid id, [FromBody] CipherRequestModel model) + { + var userId = _userService.GetProperUserId(User).Value; + var cipher = await _cipherRepository.GetByIdAsync(id, userId); + if (cipher == null) + { + throw new NotFoundException(); + } + + var collectionIds = (await _collectionCipherRepository.GetManyByUserIdCipherIdAsync(userId, id)).Select(c => c.CollectionId).ToList(); + var modelOrgId = string.IsNullOrWhiteSpace(model.OrganizationId) ? + (Guid?)null : new Guid(model.OrganizationId); + if (cipher.OrganizationId != modelOrgId) + { + throw new BadRequestException("Organization mismatch. Re-sync if you recently moved this item, " + + "then try again."); + } + + await _cipherService.SaveDetailsAsync(model.ToCipherDetails(cipher), userId, model.LastKnownRevisionDate, collectionIds); + + var response = new CipherResponseModel(cipher, _globalSettings); + return response; + } + + [HttpPut("{id}/admin")] + [HttpPost("{id}/admin")] + public async Task PutAdmin(Guid id, [FromBody] CipherRequestModel model) + { + var userId = _userService.GetProperUserId(User).Value; + var cipher = await _cipherRepository.GetOrganizationDetailsByIdAsync(id); + if (cipher == null || !cipher.OrganizationId.HasValue || + !await _currentContext.EditAnyCollection(cipher.OrganizationId.Value)) + { + throw new NotFoundException(); + } + + var collectionIds = (await _collectionCipherRepository.GetManyByUserIdCipherIdAsync(userId, id)).Select(c => c.CollectionId).ToList(); + // object cannot be a descendant of CipherDetails, so let's clone it. + var cipherClone = model.ToCipher(cipher).Clone(); + await _cipherService.SaveAsync(cipherClone, userId, model.LastKnownRevisionDate, collectionIds, true, false); + + var response = new CipherMiniResponseModel(cipherClone, _globalSettings, cipher.OrganizationUseTotp); + return response; + } + + [HttpGet("organization-details")] + public async Task> GetOrganizationCollections( + string organizationId) + { + var userId = _userService.GetProperUserId(User).Value; + var orgIdGuid = new Guid(organizationId); + + (IEnumerable orgCiphers, Dictionary> collectionCiphersGroupDict) = await _cipherService.GetOrganizationCiphers(userId, orgIdGuid); + + var responses = orgCiphers.Select(c => new CipherMiniDetailsResponseModel(c, _globalSettings, + collectionCiphersGroupDict, c.OrganizationUseTotp)); + + return new ListResponseModel(responses); + } + + [HttpPost("import")] + public async Task PostImport([FromBody] ImportCiphersRequestModel model) + { + if (!_globalSettings.SelfHosted && + (model.Ciphers.Count() > 6000 || model.FolderRelationships.Count() > 6000 || + model.Folders.Count() > 1000)) + { + throw new BadRequestException("You cannot import this much data at once."); + } + + var userId = _userService.GetProperUserId(User).Value; + var folders = model.Folders.Select(f => f.ToFolder(userId)).ToList(); + var ciphers = model.Ciphers.Select(c => c.ToCipherDetails(userId, false)).ToList(); + await _cipherService.ImportCiphersAsync(folders, ciphers, model.FolderRelationships); + } + + [HttpPost("import-organization")] + public async Task PostImport([FromQuery] string organizationId, + [FromBody] ImportOrganizationCiphersRequestModel model) + { + if (!_globalSettings.SelfHosted && + (model.Ciphers.Count() > 6000 || model.CollectionRelationships.Count() > 12000 || + model.Collections.Count() > 1000)) + { + throw new BadRequestException("You cannot import this much data at once."); + } + + var orgId = new Guid(organizationId); + if (!await _currentContext.AccessImportExport(orgId)) + { + throw new NotFoundException(); + } + + var userId = _userService.GetProperUserId(User).Value; + var collections = model.Collections.Select(c => c.ToCollection(orgId)).ToList(); + var ciphers = model.Ciphers.Select(l => l.ToOrganizationCipherDetails(orgId)).ToList(); + await _cipherService.ImportCiphersAsync(collections, ciphers, model.CollectionRelationships, userId); + } + + [HttpPut("{id}/partial")] + [HttpPost("{id}/partial")] + public async Task PutPartial(string id, [FromBody] CipherPartialRequestModel model) + { + var userId = _userService.GetProperUserId(User).Value; + var folderId = string.IsNullOrWhiteSpace(model.FolderId) ? null : (Guid?)new Guid(model.FolderId); + await _cipherRepository.UpdatePartialAsync(new Guid(id), userId, folderId, model.Favorite); + } + + [HttpPut("{id}/share")] + [HttpPost("{id}/share")] + public async Task PutShare(string id, [FromBody] CipherShareRequestModel model) + { + var userId = _userService.GetProperUserId(User).Value; + var cipherId = new Guid(id); + var cipher = await _cipherRepository.GetByIdAsync(cipherId); + if (cipher == null || cipher.UserId != userId || + !await _currentContext.OrganizationUser(new Guid(model.Cipher.OrganizationId))) + { + throw new NotFoundException(); + } + + var original = cipher.Clone(); + await _cipherService.ShareAsync(original, model.Cipher.ToCipher(cipher), new Guid(model.Cipher.OrganizationId), + model.CollectionIds.Select(c => new Guid(c)), userId, model.Cipher.LastKnownRevisionDate); + + var sharedCipher = await _cipherRepository.GetByIdAsync(cipherId, userId); + var response = new CipherResponseModel(sharedCipher, _globalSettings); + return response; + } + + [HttpPut("{id}/collections")] + [HttpPost("{id}/collections")] + public async Task PutCollections(string id, [FromBody] CipherCollectionsRequestModel model) + { + var userId = _userService.GetProperUserId(User).Value; + var cipher = await _cipherRepository.GetByIdAsync(new Guid(id), userId); + if (cipher == null || !cipher.OrganizationId.HasValue || + !await _currentContext.OrganizationUser(cipher.OrganizationId.Value)) + { + throw new NotFoundException(); + } + + await _cipherService.SaveCollectionsAsync(cipher, + model.CollectionIds.Select(c => new Guid(c)), userId, false); + } + + [HttpPut("{id}/collections-admin")] + [HttpPost("{id}/collections-admin")] + public async Task PutCollectionsAdmin(string id, [FromBody] CipherCollectionsRequestModel model) + { + var userId = _userService.GetProperUserId(User).Value; + var cipher = await _cipherRepository.GetByIdAsync(new Guid(id)); + if (cipher == null || !cipher.OrganizationId.HasValue || + !await _currentContext.EditAnyCollection(cipher.OrganizationId.Value)) + { + throw new NotFoundException(); + } + + await _cipherService.SaveCollectionsAsync(cipher, + model.CollectionIds.Select(c => new Guid(c)), userId, true); + } + + [HttpDelete("{id}")] + [HttpPost("{id}/delete")] + public async Task Delete(string id) + { + var userId = _userService.GetProperUserId(User).Value; + var cipher = await _cipherRepository.GetByIdAsync(new Guid(id), userId); + if (cipher == null) + { + throw new NotFoundException(); + } + + await _cipherService.DeleteAsync(cipher, userId); + } + + [HttpDelete("{id}/admin")] + [HttpPost("{id}/delete-admin")] + public async Task DeleteAdmin(string id) + { + var userId = _userService.GetProperUserId(User).Value; + var cipher = await _cipherRepository.GetByIdAsync(new Guid(id)); + if (cipher == null || !cipher.OrganizationId.HasValue || + !await _currentContext.EditAnyCollection(cipher.OrganizationId.Value)) + { + throw new NotFoundException(); + } + + await _cipherService.DeleteAsync(cipher, userId, true); + } + + [HttpDelete("")] + [HttpPost("delete")] + public async Task DeleteMany([FromBody] CipherBulkDeleteRequestModel model) + { + if (!_globalSettings.SelfHosted && model.Ids.Count() > 500) + { + throw new BadRequestException("You can only delete up to 500 items at a time. " + + "Consider using the \"Purge Vault\" option instead."); + } + + var userId = _userService.GetProperUserId(User).Value; + await _cipherService.DeleteManyAsync(model.Ids.Select(i => new Guid(i)), userId); + } + + [HttpDelete("admin")] + [HttpPost("delete-admin")] + public async Task DeleteManyAdmin([FromBody] CipherBulkDeleteRequestModel model) + { + if (!_globalSettings.SelfHosted && model.Ids.Count() > 500) + { + throw new BadRequestException("You can only delete up to 500 items at a time. " + + "Consider using the \"Purge Vault\" option instead."); + } + + if (model == null || string.IsNullOrWhiteSpace(model.OrganizationId) || + !await _currentContext.EditAnyCollection(new Guid(model.OrganizationId))) + { + throw new NotFoundException(); + } + + var userId = _userService.GetProperUserId(User).Value; + await _cipherService.DeleteManyAsync(model.Ids.Select(i => new Guid(i)), userId, new Guid(model.OrganizationId), true); + } + + [HttpPut("{id}/delete")] + public async Task PutDelete(string id) + { + var userId = _userService.GetProperUserId(User).Value; + var cipher = await _cipherRepository.GetByIdAsync(new Guid(id), userId); + if (cipher == null) + { + throw new NotFoundException(); + } + await _cipherService.SoftDeleteAsync(cipher, userId); + } + + [HttpPut("{id}/delete-admin")] + public async Task PutDeleteAdmin(string id) + { + var userId = _userService.GetProperUserId(User).Value; + var cipher = await _cipherRepository.GetByIdAsync(new Guid(id)); + if (cipher == null || !cipher.OrganizationId.HasValue || + !await _currentContext.EditAnyCollection(cipher.OrganizationId.Value)) + { + throw new NotFoundException(); + } + + await _cipherService.SoftDeleteAsync(cipher, userId, true); + } + + [HttpPut("delete")] + public async Task PutDeleteMany([FromBody] CipherBulkDeleteRequestModel model) + { + if (!_globalSettings.SelfHosted && model.Ids.Count() > 500) + { + throw new BadRequestException("You can only delete up to 500 items at a time."); + } + + var userId = _userService.GetProperUserId(User).Value; + await _cipherService.SoftDeleteManyAsync(model.Ids.Select(i => new Guid(i)), userId); + } + + [HttpPut("delete-admin")] + public async Task PutDeleteManyAdmin([FromBody] CipherBulkDeleteRequestModel model) + { + if (!_globalSettings.SelfHosted && model.Ids.Count() > 500) + { + throw new BadRequestException("You can only delete up to 500 items at a time."); + } + + if (model == null || string.IsNullOrWhiteSpace(model.OrganizationId) || + !await _currentContext.EditAnyCollection(new Guid(model.OrganizationId))) + { + throw new NotFoundException(); + } + + var userId = _userService.GetProperUserId(User).Value; + await _cipherService.SoftDeleteManyAsync(model.Ids.Select(i => new Guid(i)), userId, new Guid(model.OrganizationId), true); + } + + [HttpPut("{id}/restore")] + public async Task PutRestore(string id) + { + var userId = _userService.GetProperUserId(User).Value; + var cipher = await _cipherRepository.GetByIdAsync(new Guid(id), userId); + if (cipher == null) + { + throw new NotFoundException(); + } + + await _cipherService.RestoreAsync(cipher, userId); + return new CipherResponseModel(cipher, _globalSettings); + } + + [HttpPut("{id}/restore-admin")] + public async Task PutRestoreAdmin(string id) + { + var userId = _userService.GetProperUserId(User).Value; + var cipher = await _cipherRepository.GetOrganizationDetailsByIdAsync(new Guid(id)); + if (cipher == null || !cipher.OrganizationId.HasValue || + !await _currentContext.EditAnyCollection(cipher.OrganizationId.Value)) + { + throw new NotFoundException(); + } + + await _cipherService.RestoreAsync(cipher, userId, true); + return new CipherMiniResponseModel(cipher, _globalSettings, cipher.OrganizationUseTotp); + } + + [HttpPut("restore")] + public async Task> PutRestoreMany([FromBody] CipherBulkRestoreRequestModel model) + { + if (!_globalSettings.SelfHosted && model.Ids.Count() > 500) + { + throw new BadRequestException("You can only restore up to 500 items at a time."); + } + + var userId = _userService.GetProperUserId(User).Value; + var cipherIdsToRestore = new HashSet(model.Ids.Select(i => new Guid(i))); + + var ciphers = await _cipherRepository.GetManyByUserIdAsync(userId); + var restoringCiphers = ciphers.Where(c => cipherIdsToRestore.Contains(c.Id) && c.Edit); + + await _cipherService.RestoreManyAsync(restoringCiphers, userId); + var responses = restoringCiphers.Select(c => new CipherResponseModel(c, _globalSettings)); + return new ListResponseModel(responses); + } + + [HttpPut("move")] + [HttpPost("move")] + public async Task MoveMany([FromBody] CipherBulkMoveRequestModel model) + { + if (!_globalSettings.SelfHosted && model.Ids.Count() > 500) + { + throw new BadRequestException("You can only move up to 500 items at a time."); + } + + var userId = _userService.GetProperUserId(User).Value; + await _cipherService.MoveManyAsync(model.Ids.Select(i => new Guid(i)), + string.IsNullOrWhiteSpace(model.FolderId) ? (Guid?)null : new Guid(model.FolderId), userId); + } + + [HttpPut("share")] + [HttpPost("share")] + public async Task PutShareMany([FromBody] CipherBulkShareRequestModel model) + { + var organizationId = new Guid(model.Ciphers.First().OrganizationId); + if (!await _currentContext.OrganizationUser(organizationId)) + { + throw new NotFoundException(); + } + + var userId = _userService.GetProperUserId(User).Value; + var ciphers = await _cipherRepository.GetManyByUserIdAsync(userId, false); + var ciphersDict = ciphers.ToDictionary(c => c.Id); + + var shareCiphers = new List<(Cipher, DateTime?)>(); + foreach (var cipher in model.Ciphers) + { + if (!ciphersDict.ContainsKey(cipher.Id.Value)) + { + throw new BadRequestException("Trying to move ciphers that you do not own."); + } + + shareCiphers.Add((cipher.ToCipher(ciphersDict[cipher.Id.Value]), cipher.LastKnownRevisionDate)); + } + + await _cipherService.ShareManyAsync(shareCiphers, organizationId, + model.CollectionIds.Select(c => new Guid(c)), userId); + } + + [HttpPost("purge")] + public async Task PostPurge([FromBody] SecretVerificationRequestModel model, string organizationId = null) + { + var user = await _userService.GetUserByPrincipalAsync(User); + if (user == null) + { + throw new UnauthorizedAccessException(); + } + + if (!await _userService.VerifySecretAsync(user, model.Secret)) + { + ModelState.AddModelError(string.Empty, "User verification failed."); + await Task.Delay(2000); + throw new BadRequestException(ModelState); + } + + if (string.IsNullOrWhiteSpace(organizationId)) + { + await _cipherRepository.DeleteByUserIdAsync(user.Id); + } + else + { + var orgId = new Guid(organizationId); + if (!await _currentContext.EditAnyCollection(orgId)) + { + throw new NotFoundException(); + } + await _cipherService.PurgeAsync(orgId); + } + } + + [HttpPost("{id}/attachment/v2")] + public async Task PostAttachment(string id, [FromBody] AttachmentRequestModel request) + { + var idGuid = new Guid(id); + var userId = _userService.GetProperUserId(User).Value; + var cipher = request.AdminRequest ? + await _cipherRepository.GetOrganizationDetailsByIdAsync(idGuid) : + await _cipherRepository.GetByIdAsync(idGuid, userId); + + if (cipher == null || (request.AdminRequest && (!cipher.OrganizationId.HasValue || + !await _currentContext.EditAnyCollection(cipher.OrganizationId.Value)))) + { + throw new NotFoundException(); + } + + if (request.FileSize > CipherService.MAX_FILE_SIZE) + { + throw new BadRequestException($"Max file size is {CipherService.MAX_FILE_SIZE_READABLE}."); + } + + var (attachmentId, uploadUrl) = await _cipherService.CreateAttachmentForDelayedUploadAsync(cipher, + request.Key, request.FileName, request.FileSize, request.AdminRequest, userId); + return new AttachmentUploadDataResponseModel + { + AttachmentId = attachmentId, + Url = uploadUrl, + FileUploadType = _attachmentStorageService.FileUploadType, + CipherResponse = request.AdminRequest ? null : new CipherResponseModel((CipherDetails)cipher, _globalSettings), + CipherMiniResponse = request.AdminRequest ? new CipherMiniResponseModel(cipher, _globalSettings, cipher.OrganizationUseTotp) : null, + }; + } + + [HttpGet("{id}/attachment/{attachmentId}/renew")] + public async Task RenewFileUploadUrl(string id, string attachmentId) + { + var userId = _userService.GetProperUserId(User).Value; + var cipherId = new Guid(id); + var cipher = await _cipherRepository.GetByIdAsync(cipherId, userId); + var attachments = cipher?.GetAttachments(); + + if (attachments == null || !attachments.ContainsKey(attachmentId) || attachments[attachmentId].Validated) + { + throw new NotFoundException(); + } + + return new AttachmentUploadDataResponseModel + { + Url = await _attachmentStorageService.GetAttachmentUploadUrlAsync(cipher, attachments[attachmentId]), + FileUploadType = _attachmentStorageService.FileUploadType, + }; + } + + [HttpPost("{id}/attachment/{attachmentId}")] + [SelfHosted(SelfHostedOnly = true)] + [RequestSizeLimit(Constants.FileSize501mb)] + [DisableFormValueModelBinding] + public async Task PostFileForExistingAttachment(string id, string attachmentId) + { + if (!Request?.ContentType.Contains("multipart/") ?? true) + { + throw new BadRequestException("Invalid content."); + } + + var userId = _userService.GetProperUserId(User).Value; + var cipher = await _cipherRepository.GetByIdAsync(new Guid(id), userId); + var attachments = cipher?.GetAttachments(); + if (attachments == null || !attachments.ContainsKey(attachmentId)) + { + throw new NotFoundException(); + } + var attachmentData = attachments[attachmentId]; + + await Request.GetFileAsync(async (stream) => + { + await _cipherService.UploadFileForExistingAttachmentAsync(stream, cipher, attachmentData); + }); + } + + [HttpPost("{id}/attachment")] + [Obsolete("Deprecated Attachments API", false)] + [RequestSizeLimit(Constants.FileSize101mb)] + [DisableFormValueModelBinding] + public async Task PostAttachment(string id) + { + ValidateAttachment(); + + var idGuid = new Guid(id); + var userId = _userService.GetProperUserId(User).Value; + var cipher = await _cipherRepository.GetByIdAsync(idGuid, userId); + if (cipher == null) + { + throw new NotFoundException(); + } + + await Request.GetFileAsync(async (stream, fileName, key) => + { + await _cipherService.CreateAttachmentAsync(cipher, stream, fileName, key, + Request.ContentLength.GetValueOrDefault(0), userId); + }); + + return new CipherResponseModel(cipher, _globalSettings); + } + + [HttpPost("{id}/attachment-admin")] + [RequestSizeLimit(Constants.FileSize101mb)] + [DisableFormValueModelBinding] + public async Task PostAttachmentAdmin(string id) + { + ValidateAttachment(); + + var idGuid = new Guid(id); + var userId = _userService.GetProperUserId(User).Value; + var cipher = await _cipherRepository.GetOrganizationDetailsByIdAsync(idGuid); + if (cipher == null || !cipher.OrganizationId.HasValue || + !await _currentContext.EditAnyCollection(cipher.OrganizationId.Value)) + { + throw new NotFoundException(); + } + + await Request.GetFileAsync(async (stream, fileName, key) => + { + await _cipherService.CreateAttachmentAsync(cipher, stream, fileName, key, + Request.ContentLength.GetValueOrDefault(0), userId, true); + }); + + return new CipherMiniResponseModel(cipher, _globalSettings, cipher.OrganizationUseTotp); + } + + [HttpGet("{id}/attachment/{attachmentId}")] + public async Task GetAttachmentData(string id, string attachmentId) + { + var userId = _userService.GetProperUserId(User).Value; + var cipher = await _cipherRepository.GetByIdAsync(new Guid(id), userId); + var result = await _cipherService.GetAttachmentDownloadDataAsync(cipher, attachmentId); + return new AttachmentResponseModel(result); + } + + [HttpPost("{id}/attachment/{attachmentId}/share")] + [RequestSizeLimit(Constants.FileSize101mb)] + [DisableFormValueModelBinding] + public async Task PostAttachmentShare(string id, string attachmentId, Guid organizationId) + { + ValidateAttachment(); + + var userId = _userService.GetProperUserId(User).Value; + var cipher = await _cipherRepository.GetByIdAsync(new Guid(id)); + if (cipher == null || cipher.UserId != userId || !await _currentContext.OrganizationUser(organizationId)) + { + throw new NotFoundException(); + } + + await Request.GetFileAsync(async (stream, fileName, key) => + { + await _cipherService.CreateAttachmentShareAsync(cipher, stream, + Request.ContentLength.GetValueOrDefault(0), attachmentId, organizationId); + }); + } + + [HttpDelete("{id}/attachment/{attachmentId}")] + [HttpPost("{id}/attachment/{attachmentId}/delete")] + public async Task DeleteAttachment(string id, string attachmentId) + { + var idGuid = new Guid(id); + var userId = _userService.GetProperUserId(User).Value; + var cipher = await _cipherRepository.GetByIdAsync(idGuid, userId); + if (cipher == null) + { + throw new NotFoundException(); + } + + await _cipherService.DeleteAttachmentAsync(cipher, attachmentId, userId, false); + } + + [HttpDelete("{id}/attachment/{attachmentId}/admin")] + [HttpPost("{id}/attachment/{attachmentId}/delete-admin")] + public async Task DeleteAttachmentAdmin(string id, string attachmentId) + { + var idGuid = new Guid(id); + var userId = _userService.GetProperUserId(User).Value; + var cipher = await _cipherRepository.GetByIdAsync(idGuid); + if (cipher == null || !cipher.OrganizationId.HasValue || + !await _currentContext.EditAnyCollection(cipher.OrganizationId.Value)) + { + throw new NotFoundException(); + } + + await _cipherService.DeleteAttachmentAsync(cipher, attachmentId, userId, true); + } + + [AllowAnonymous] + [HttpPost("attachment/validate/azure")] + public async Task AzureValidateFile() + { + return await ApiHelpers.HandleAzureEvents(Request, new Dictionary> + { + { + "Microsoft.Storage.BlobCreated", async (eventGridEvent) => + { + try { - if (_attachmentStorageService is AzureSendFileStorageService azureFileStorageService) + var blobName = eventGridEvent.Subject.Split($"{AzureAttachmentStorageService.EventGridEnabledContainerName}/blobs/")[1]; + var (cipherId, organizationId, attachmentId) = AzureAttachmentStorageService.IdentifiersFromBlobName(blobName); + var cipher = await _cipherRepository.GetByIdAsync(new Guid(cipherId)); + var attachments = cipher?.GetAttachments() ?? new Dictionary(); + + if (cipher == null || !attachments.ContainsKey(attachmentId) || attachments[attachmentId].Validated) { - await azureFileStorageService.DeleteBlobAsync(blobName); + if (_attachmentStorageService is AzureSendFileStorageService azureFileStorageService) + { + await azureFileStorageService.DeleteBlobAsync(blobName); + } + + return; } + await _cipherService.ValidateCipherAttachmentFile(cipher, attachments[attachmentId]); + } + catch (Exception e) + { + _logger.LogError(e, $"Uncaught exception occurred while handling event grid event: {JsonSerializer.Serialize(eventGridEvent)}"); return; } - - await _cipherService.ValidateCipherAttachmentFile(cipher, attachments[attachmentId]); - } - catch (Exception e) - { - _logger.LogError(e, $"Uncaught exception occurred while handling event grid event: {JsonSerializer.Serialize(eventGridEvent)}"); - return; } } - } - }); - } + }); + } - private void ValidateAttachment() - { - if (!Request?.ContentType.Contains("multipart/") ?? true) + private void ValidateAttachment() { - throw new BadRequestException("Invalid content."); + if (!Request?.ContentType.Contains("multipart/") ?? true) + { + throw new BadRequestException("Invalid content."); + } } } } diff --git a/src/Api/Controllers/CollectionsController.cs b/src/Api/Controllers/CollectionsController.cs index 14d4e95b27..548ff80d40 100644 --- a/src/Api/Controllers/CollectionsController.cs +++ b/src/Api/Controllers/CollectionsController.cs @@ -8,260 +8,261 @@ using Bit.Core.Services; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; -namespace Bit.Api.Controllers; - -[Route("organizations/{orgId}/collections")] -[Authorize("Application")] -public class CollectionsController : Controller +namespace Bit.Api.Controllers { - private readonly ICollectionRepository _collectionRepository; - private readonly ICollectionService _collectionService; - private readonly IUserService _userService; - private readonly ICurrentContext _currentContext; - - public CollectionsController( - ICollectionRepository collectionRepository, - ICollectionService collectionService, - IUserService userService, - ICurrentContext currentContext) + [Route("organizations/{orgId}/collections")] + [Authorize("Application")] + public class CollectionsController : Controller { - _collectionRepository = collectionRepository; - _collectionService = collectionService; - _userService = userService; - _currentContext = currentContext; - } + private readonly ICollectionRepository _collectionRepository; + private readonly ICollectionService _collectionService; + private readonly IUserService _userService; + private readonly ICurrentContext _currentContext; - [HttpGet("{id}")] - public async Task Get(Guid orgId, Guid id) - { - if (!await CanViewCollectionAsync(orgId, id)) + public CollectionsController( + ICollectionRepository collectionRepository, + ICollectionService collectionService, + IUserService userService, + ICurrentContext currentContext) { - throw new NotFoundException(); + _collectionRepository = collectionRepository; + _collectionService = collectionService; + _userService = userService; + _currentContext = currentContext; } - var collection = await GetCollectionAsync(id, orgId); - return new CollectionResponseModel(collection); - } - - [HttpGet("{id}/details")] - public async Task GetDetails(Guid orgId, Guid id) - { - if (!await ViewAtLeastOneCollectionAsync(orgId) && !await _currentContext.ManageUsers(orgId)) + [HttpGet("{id}")] + public async Task Get(Guid orgId, Guid id) { - throw new NotFoundException(); - } - - if (await _currentContext.ViewAllCollections(orgId)) - { - var collectionDetails = await _collectionRepository.GetByIdWithGroupsAsync(id); - if (collectionDetails?.Item1 == null || collectionDetails.Item1.OrganizationId != orgId) + if (!await CanViewCollectionAsync(orgId, id)) { throw new NotFoundException(); } - return new CollectionGroupDetailsResponseModel(collectionDetails.Item1, collectionDetails.Item2); + + var collection = await GetCollectionAsync(id, orgId); + return new CollectionResponseModel(collection); } - else + + [HttpGet("{id}/details")] + public async Task GetDetails(Guid orgId, Guid id) { - var collectionDetails = await _collectionRepository.GetByIdWithGroupsAsync(id, - _currentContext.UserId.Value); - if (collectionDetails?.Item1 == null || collectionDetails.Item1.OrganizationId != orgId) + if (!await ViewAtLeastOneCollectionAsync(orgId) && !await _currentContext.ManageUsers(orgId)) { throw new NotFoundException(); } - return new CollectionGroupDetailsResponseModel(collectionDetails.Item1, collectionDetails.Item2); - } - } - [HttpGet("")] - public async Task> Get(Guid orgId) - { - IEnumerable orgCollections = await _collectionService.GetOrganizationCollections(orgId); - - var responses = orgCollections.Select(c => new CollectionResponseModel(c)); - return new ListResponseModel(responses); - } - - [HttpGet("~/collections")] - public async Task> GetUser() - { - var collections = await _collectionRepository.GetManyByUserIdAsync( - _userService.GetProperUserId(User).Value); - var responses = collections.Select(c => new CollectionDetailsResponseModel(c)); - return new ListResponseModel(responses); - } - - [HttpGet("{id}/users")] - public async Task> GetUsers(Guid orgId, Guid id) - { - var collection = await GetCollectionAsync(id, orgId); - var collectionUsers = await _collectionRepository.GetManyUsersByIdAsync(collection.Id); - var responses = collectionUsers.Select(cu => new SelectionReadOnlyResponseModel(cu)); - return responses; - } - - [HttpPost("")] - public async Task Post(Guid orgId, [FromBody] CollectionRequestModel model) - { - var collection = model.ToCollection(orgId); - - if (!await CanCreateCollection(orgId, collection.Id) && - !await CanEditCollectionAsync(orgId, collection.Id)) - { - throw new NotFoundException(); + if (await _currentContext.ViewAllCollections(orgId)) + { + var collectionDetails = await _collectionRepository.GetByIdWithGroupsAsync(id); + if (collectionDetails?.Item1 == null || collectionDetails.Item1.OrganizationId != orgId) + { + throw new NotFoundException(); + } + return new CollectionGroupDetailsResponseModel(collectionDetails.Item1, collectionDetails.Item2); + } + else + { + var collectionDetails = await _collectionRepository.GetByIdWithGroupsAsync(id, + _currentContext.UserId.Value); + if (collectionDetails?.Item1 == null || collectionDetails.Item1.OrganizationId != orgId) + { + throw new NotFoundException(); + } + return new CollectionGroupDetailsResponseModel(collectionDetails.Item1, collectionDetails.Item2); + } } - var assignUserToCollection = !(await _currentContext.EditAnyCollection(orgId)) && - await _currentContext.EditAssignedCollections(orgId); - - await _collectionService.SaveAsync(collection, model.Groups?.Select(g => g.ToSelectionReadOnly()), - assignUserToCollection ? _currentContext.UserId : null); - return new CollectionResponseModel(collection); - } - - [HttpPut("{id}")] - [HttpPost("{id}")] - public async Task Put(Guid orgId, Guid id, [FromBody] CollectionRequestModel model) - { - if (!await CanEditCollectionAsync(orgId, id)) + [HttpGet("")] + public async Task> Get(Guid orgId) { - throw new NotFoundException(); + IEnumerable orgCollections = await _collectionService.GetOrganizationCollections(orgId); + + var responses = orgCollections.Select(c => new CollectionResponseModel(c)); + return new ListResponseModel(responses); } - var collection = await GetCollectionAsync(id, orgId); - await _collectionService.SaveAsync(model.ToCollection(collection), - model.Groups?.Select(g => g.ToSelectionReadOnly())); - return new CollectionResponseModel(collection); - } - - [HttpPut("{id}/users")] - public async Task PutUsers(Guid orgId, Guid id, [FromBody] IEnumerable model) - { - if (!await CanEditCollectionAsync(orgId, id)) + [HttpGet("~/collections")] + public async Task> GetUser() { - throw new NotFoundException(); + var collections = await _collectionRepository.GetManyByUserIdAsync( + _userService.GetProperUserId(User).Value); + var responses = collections.Select(c => new CollectionDetailsResponseModel(c)); + return new ListResponseModel(responses); } - var collection = await GetCollectionAsync(id, orgId); - await _collectionRepository.UpdateUsersAsync(collection.Id, model?.Select(g => g.ToSelectionReadOnly())); - } - - [HttpDelete("{id}")] - [HttpPost("{id}/delete")] - public async Task Delete(Guid orgId, Guid id) - { - if (!await CanDeleteCollectionAsync(orgId, id)) + [HttpGet("{id}/users")] + public async Task> GetUsers(Guid orgId, Guid id) { - throw new NotFoundException(); + var collection = await GetCollectionAsync(id, orgId); + var collectionUsers = await _collectionRepository.GetManyUsersByIdAsync(collection.Id); + var responses = collectionUsers.Select(cu => new SelectionReadOnlyResponseModel(cu)); + return responses; } - var collection = await GetCollectionAsync(id, orgId); - await _collectionService.DeleteAsync(collection); - } - - [HttpDelete("{id}/user/{orgUserId}")] - [HttpPost("{id}/delete-user/{orgUserId}")] - public async Task Delete(string orgId, string id, string orgUserId) - { - var collection = await GetCollectionAsync(new Guid(id), new Guid(orgId)); - await _collectionService.DeleteUserAsync(collection, new Guid(orgUserId)); - } - - private async Task GetCollectionAsync(Guid id, Guid orgId) - { - Collection collection = default; - if (await _currentContext.ViewAllCollections(orgId)) + [HttpPost("")] + public async Task Post(Guid orgId, [FromBody] CollectionRequestModel model) { - collection = await _collectionRepository.GetByIdAsync(id); - } - else if (await _currentContext.ViewAssignedCollections(orgId)) - { - collection = await _collectionRepository.GetByIdAsync(id, _currentContext.UserId.Value); + var collection = model.ToCollection(orgId); + + if (!await CanCreateCollection(orgId, collection.Id) && + !await CanEditCollectionAsync(orgId, collection.Id)) + { + throw new NotFoundException(); + } + + var assignUserToCollection = !(await _currentContext.EditAnyCollection(orgId)) && + await _currentContext.EditAssignedCollections(orgId); + + await _collectionService.SaveAsync(collection, model.Groups?.Select(g => g.ToSelectionReadOnly()), + assignUserToCollection ? _currentContext.UserId : null); + return new CollectionResponseModel(collection); } - if (collection == null || collection.OrganizationId != orgId) + [HttpPut("{id}")] + [HttpPost("{id}")] + public async Task Put(Guid orgId, Guid id, [FromBody] CollectionRequestModel model) { - throw new NotFoundException(); + if (!await CanEditCollectionAsync(orgId, id)) + { + throw new NotFoundException(); + } + + var collection = await GetCollectionAsync(id, orgId); + await _collectionService.SaveAsync(model.ToCollection(collection), + model.Groups?.Select(g => g.ToSelectionReadOnly())); + return new CollectionResponseModel(collection); } - return collection; - } - - - private async Task CanCreateCollection(Guid orgId, Guid collectionId) - { - if (collectionId != default) + [HttpPut("{id}/users")] + public async Task PutUsers(Guid orgId, Guid id, [FromBody] IEnumerable model) { + if (!await CanEditCollectionAsync(orgId, id)) + { + throw new NotFoundException(); + } + + var collection = await GetCollectionAsync(id, orgId); + await _collectionRepository.UpdateUsersAsync(collection.Id, model?.Select(g => g.ToSelectionReadOnly())); + } + + [HttpDelete("{id}")] + [HttpPost("{id}/delete")] + public async Task Delete(Guid orgId, Guid id) + { + if (!await CanDeleteCollectionAsync(orgId, id)) + { + throw new NotFoundException(); + } + + var collection = await GetCollectionAsync(id, orgId); + await _collectionService.DeleteAsync(collection); + } + + [HttpDelete("{id}/user/{orgUserId}")] + [HttpPost("{id}/delete-user/{orgUserId}")] + public async Task Delete(string orgId, string id, string orgUserId) + { + var collection = await GetCollectionAsync(new Guid(id), new Guid(orgId)); + await _collectionService.DeleteUserAsync(collection, new Guid(orgUserId)); + } + + private async Task GetCollectionAsync(Guid id, Guid orgId) + { + Collection collection = default; + if (await _currentContext.ViewAllCollections(orgId)) + { + collection = await _collectionRepository.GetByIdAsync(id); + } + else if (await _currentContext.ViewAssignedCollections(orgId)) + { + collection = await _collectionRepository.GetByIdAsync(id, _currentContext.UserId.Value); + } + + if (collection == null || collection.OrganizationId != orgId) + { + throw new NotFoundException(); + } + + return collection; + } + + + private async Task CanCreateCollection(Guid orgId, Guid collectionId) + { + if (collectionId != default) + { + return false; + } + + return await _currentContext.CreateNewCollections(orgId); + } + + private async Task CanEditCollectionAsync(Guid orgId, Guid collectionId) + { + if (collectionId == default) + { + return false; + } + + if (await _currentContext.EditAnyCollection(orgId)) + { + return true; + } + + if (await _currentContext.EditAssignedCollections(orgId)) + { + var collectionDetails = await _collectionRepository.GetByIdAsync(collectionId, _currentContext.UserId.Value); + return collectionDetails != null; + } + return false; } - return await _currentContext.CreateNewCollections(orgId); - } - - private async Task CanEditCollectionAsync(Guid orgId, Guid collectionId) - { - if (collectionId == default) + private async Task CanDeleteCollectionAsync(Guid orgId, Guid collectionId) { + if (collectionId == default) + { + return false; + } + + if (await _currentContext.DeleteAnyCollection(orgId)) + { + return true; + } + + if (await _currentContext.DeleteAssignedCollections(orgId)) + { + var collectionDetails = await _collectionRepository.GetByIdAsync(collectionId, _currentContext.UserId.Value); + return collectionDetails != null; + } + return false; } - if (await _currentContext.EditAnyCollection(orgId)) + private async Task CanViewCollectionAsync(Guid orgId, Guid collectionId) { - return true; - } + if (collectionId == default) + { + return false; + } - if (await _currentContext.EditAssignedCollections(orgId)) - { - var collectionDetails = await _collectionRepository.GetByIdAsync(collectionId, _currentContext.UserId.Value); - return collectionDetails != null; - } + if (await _currentContext.ViewAllCollections(orgId)) + { + return true; + } - return false; - } + if (await _currentContext.ViewAssignedCollections(orgId)) + { + var collectionDetails = await _collectionRepository.GetByIdAsync(collectionId, _currentContext.UserId.Value); + return collectionDetails != null; + } - private async Task CanDeleteCollectionAsync(Guid orgId, Guid collectionId) - { - if (collectionId == default) - { return false; } - if (await _currentContext.DeleteAnyCollection(orgId)) + private async Task ViewAtLeastOneCollectionAsync(Guid orgId) { - return true; + return await _currentContext.ViewAllCollections(orgId) || await _currentContext.ViewAssignedCollections(orgId); } - - if (await _currentContext.DeleteAssignedCollections(orgId)) - { - var collectionDetails = await _collectionRepository.GetByIdAsync(collectionId, _currentContext.UserId.Value); - return collectionDetails != null; - } - - return false; - } - - private async Task CanViewCollectionAsync(Guid orgId, Guid collectionId) - { - if (collectionId == default) - { - return false; - } - - if (await _currentContext.ViewAllCollections(orgId)) - { - return true; - } - - if (await _currentContext.ViewAssignedCollections(orgId)) - { - var collectionDetails = await _collectionRepository.GetByIdAsync(collectionId, _currentContext.UserId.Value); - return collectionDetails != null; - } - - return false; - } - - private async Task ViewAtLeastOneCollectionAsync(Guid orgId) - { - return await _currentContext.ViewAllCollections(orgId) || await _currentContext.ViewAssignedCollections(orgId); } } diff --git a/src/Api/Controllers/DevicesController.cs b/src/Api/Controllers/DevicesController.cs index 77fb34c648..8bfa5d7b08 100644 --- a/src/Api/Controllers/DevicesController.cs +++ b/src/Api/Controllers/DevicesController.cs @@ -7,123 +7,124 @@ using Bit.Core.Services; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; -namespace Bit.Api.Controllers; - -[Route("devices")] -[Authorize("Application")] -public class DevicesController : Controller +namespace Bit.Api.Controllers { - private readonly IDeviceRepository _deviceRepository; - private readonly IDeviceService _deviceService; - private readonly IUserService _userService; - - public DevicesController( - IDeviceRepository deviceRepository, - IDeviceService deviceService, - IUserService userService) + [Route("devices")] + [Authorize("Application")] + public class DevicesController : Controller { - _deviceRepository = deviceRepository; - _deviceService = deviceService; - _userService = userService; - } + private readonly IDeviceRepository _deviceRepository; + private readonly IDeviceService _deviceService; + private readonly IUserService _userService; - [HttpGet("{id}")] - public async Task Get(string id) - { - var device = await _deviceRepository.GetByIdAsync(new Guid(id), _userService.GetProperUserId(User).Value); - if (device == null) + public DevicesController( + IDeviceRepository deviceRepository, + IDeviceService deviceService, + IUserService userService) { - throw new NotFoundException(); + _deviceRepository = deviceRepository; + _deviceService = deviceService; + _userService = userService; } - var response = new DeviceResponseModel(device); - return response; - } - - [HttpGet("identifier/{identifier}")] - public async Task GetByIdentifier(string identifier) - { - var device = await _deviceRepository.GetByIdentifierAsync(identifier, _userService.GetProperUserId(User).Value); - if (device == null) + [HttpGet("{id}")] + public async Task Get(string id) { - throw new NotFoundException(); + var device = await _deviceRepository.GetByIdAsync(new Guid(id), _userService.GetProperUserId(User).Value); + if (device == null) + { + throw new NotFoundException(); + } + + var response = new DeviceResponseModel(device); + return response; } - var response = new DeviceResponseModel(device); - return response; - } - - [HttpGet("")] - public async Task> Get() - { - ICollection devices = await _deviceRepository.GetManyByUserIdAsync(_userService.GetProperUserId(User).Value); - var responses = devices.Select(d => new DeviceResponseModel(d)); - return new ListResponseModel(responses); - } - - [HttpPost("")] - public async Task Post([FromBody] DeviceRequestModel model) - { - var device = model.ToDevice(_userService.GetProperUserId(User)); - await _deviceService.SaveAsync(device); - - var response = new DeviceResponseModel(device); - return response; - } - - [HttpPut("{id}")] - [HttpPost("{id}")] - public async Task Put(string id, [FromBody] DeviceRequestModel model) - { - var device = await _deviceRepository.GetByIdAsync(new Guid(id), _userService.GetProperUserId(User).Value); - if (device == null) + [HttpGet("identifier/{identifier}")] + public async Task GetByIdentifier(string identifier) { - throw new NotFoundException(); + var device = await _deviceRepository.GetByIdentifierAsync(identifier, _userService.GetProperUserId(User).Value); + if (device == null) + { + throw new NotFoundException(); + } + + var response = new DeviceResponseModel(device); + return response; } - await _deviceService.SaveAsync(model.ToDevice(device)); - - var response = new DeviceResponseModel(device); - return response; - } - - [HttpPut("identifier/{identifier}/token")] - [HttpPost("identifier/{identifier}/token")] - public async Task PutToken(string identifier, [FromBody] DeviceTokenRequestModel model) - { - var device = await _deviceRepository.GetByIdentifierAsync(identifier, _userService.GetProperUserId(User).Value); - if (device == null) + [HttpGet("")] + public async Task> Get() { - throw new NotFoundException(); + ICollection devices = await _deviceRepository.GetManyByUserIdAsync(_userService.GetProperUserId(User).Value); + var responses = devices.Select(d => new DeviceResponseModel(d)); + return new ListResponseModel(responses); } - await _deviceService.SaveAsync(model.ToDevice(device)); - } - - [AllowAnonymous] - [HttpPut("identifier/{identifier}/clear-token")] - [HttpPost("identifier/{identifier}/clear-token")] - public async Task PutClearToken(string identifier) - { - var device = await _deviceRepository.GetByIdentifierAsync(identifier); - if (device == null) + [HttpPost("")] + public async Task Post([FromBody] DeviceRequestModel model) { - throw new NotFoundException(); + var device = model.ToDevice(_userService.GetProperUserId(User)); + await _deviceService.SaveAsync(device); + + var response = new DeviceResponseModel(device); + return response; } - await _deviceService.ClearTokenAsync(device); - } - - [HttpDelete("{id}")] - [HttpPost("{id}/delete")] - public async Task Delete(string id) - { - var device = await _deviceRepository.GetByIdAsync(new Guid(id), _userService.GetProperUserId(User).Value); - if (device == null) + [HttpPut("{id}")] + [HttpPost("{id}")] + public async Task Put(string id, [FromBody] DeviceRequestModel model) { - throw new NotFoundException(); + var device = await _deviceRepository.GetByIdAsync(new Guid(id), _userService.GetProperUserId(User).Value); + if (device == null) + { + throw new NotFoundException(); + } + + await _deviceService.SaveAsync(model.ToDevice(device)); + + var response = new DeviceResponseModel(device); + return response; } - await _deviceService.DeleteAsync(device); + [HttpPut("identifier/{identifier}/token")] + [HttpPost("identifier/{identifier}/token")] + public async Task PutToken(string identifier, [FromBody] DeviceTokenRequestModel model) + { + var device = await _deviceRepository.GetByIdentifierAsync(identifier, _userService.GetProperUserId(User).Value); + if (device == null) + { + throw new NotFoundException(); + } + + await _deviceService.SaveAsync(model.ToDevice(device)); + } + + [AllowAnonymous] + [HttpPut("identifier/{identifier}/clear-token")] + [HttpPost("identifier/{identifier}/clear-token")] + public async Task PutClearToken(string identifier) + { + var device = await _deviceRepository.GetByIdentifierAsync(identifier); + if (device == null) + { + throw new NotFoundException(); + } + + await _deviceService.ClearTokenAsync(device); + } + + [HttpDelete("{id}")] + [HttpPost("{id}/delete")] + public async Task Delete(string id) + { + var device = await _deviceRepository.GetByIdAsync(new Guid(id), _userService.GetProperUserId(User).Value); + if (device == null) + { + throw new NotFoundException(); + } + + await _deviceService.DeleteAsync(device); + } } } diff --git a/src/Api/Controllers/EmergencyAccessController.cs b/src/Api/Controllers/EmergencyAccessController.cs index b2eb997b41..4e8ac834d4 100644 --- a/src/Api/Controllers/EmergencyAccessController.cs +++ b/src/Api/Controllers/EmergencyAccessController.cs @@ -9,169 +9,170 @@ using Bit.Core.Settings; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; -namespace Bit.Api.Controllers; - -[Route("emergency-access")] -[Authorize("Application")] -public class EmergencyAccessController : Controller +namespace Bit.Api.Controllers { - private readonly IUserService _userService; - private readonly IEmergencyAccessRepository _emergencyAccessRepository; - private readonly IEmergencyAccessService _emergencyAccessService; - private readonly IGlobalSettings _globalSettings; - - public EmergencyAccessController( - IUserService userService, - IEmergencyAccessRepository emergencyAccessRepository, - IEmergencyAccessService emergencyAccessService, - IGlobalSettings globalSettings) + [Route("emergency-access")] + [Authorize("Application")] + public class EmergencyAccessController : Controller { - _userService = userService; - _emergencyAccessRepository = emergencyAccessRepository; - _emergencyAccessService = emergencyAccessService; - _globalSettings = globalSettings; - } + private readonly IUserService _userService; + private readonly IEmergencyAccessRepository _emergencyAccessRepository; + private readonly IEmergencyAccessService _emergencyAccessService; + private readonly IGlobalSettings _globalSettings; - [HttpGet("trusted")] - public async Task> GetContacts() - { - var userId = _userService.GetProperUserId(User); - var granteeDetails = await _emergencyAccessRepository.GetManyDetailsByGrantorIdAsync(userId.Value); - - var responses = granteeDetails.Select(d => - new EmergencyAccessGranteeDetailsResponseModel(d)); - - return new ListResponseModel(responses); - } - - [HttpGet("granted")] - public async Task> GetGrantees() - { - var userId = _userService.GetProperUserId(User); - var granteeDetails = await _emergencyAccessRepository.GetManyDetailsByGranteeIdAsync(userId.Value); - - var responses = granteeDetails.Select(d => new EmergencyAccessGrantorDetailsResponseModel(d)); - - return new ListResponseModel(responses); - } - - [HttpGet("{id}")] - public async Task Get(Guid id) - { - var userId = _userService.GetProperUserId(User); - var result = await _emergencyAccessService.GetAsync(id, userId.Value); - return new EmergencyAccessGranteeDetailsResponseModel(result); - } - - [HttpGet("{id}/policies")] - public async Task> Policies(Guid id) - { - var user = await _userService.GetUserByPrincipalAsync(User); - var policies = await _emergencyAccessService.GetPoliciesAsync(id, user); - var responses = policies.Select(policy => new PolicyResponseModel(policy)); - return new ListResponseModel(responses); - } - - [HttpPut("{id}")] - [HttpPost("{id}")] - public async Task Put(Guid id, [FromBody] EmergencyAccessUpdateRequestModel model) - { - var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(id); - if (emergencyAccess == null) + public EmergencyAccessController( + IUserService userService, + IEmergencyAccessRepository emergencyAccessRepository, + IEmergencyAccessService emergencyAccessService, + IGlobalSettings globalSettings) { - throw new NotFoundException(); + _userService = userService; + _emergencyAccessRepository = emergencyAccessRepository; + _emergencyAccessService = emergencyAccessService; + _globalSettings = globalSettings; } - var user = await _userService.GetUserByPrincipalAsync(User); - await _emergencyAccessService.SaveAsync(model.ToEmergencyAccess(emergencyAccess), user); - } + [HttpGet("trusted")] + public async Task> GetContacts() + { + var userId = _userService.GetProperUserId(User); + var granteeDetails = await _emergencyAccessRepository.GetManyDetailsByGrantorIdAsync(userId.Value); - [HttpDelete("{id}")] - [HttpPost("{id}/delete")] - public async Task Delete(Guid id) - { - var userId = _userService.GetProperUserId(User); - await _emergencyAccessService.DeleteAsync(id, userId.Value); - } + var responses = granteeDetails.Select(d => + new EmergencyAccessGranteeDetailsResponseModel(d)); - [HttpPost("invite")] - public async Task Invite([FromBody] EmergencyAccessInviteRequestModel model) - { - var user = await _userService.GetUserByPrincipalAsync(User); - await _emergencyAccessService.InviteAsync(user, model.Email, model.Type.Value, model.WaitTimeDays); - } + return new ListResponseModel(responses); + } - [HttpPost("{id}/reinvite")] - public async Task Reinvite(Guid id) - { - var user = await _userService.GetUserByPrincipalAsync(User); - await _emergencyAccessService.ResendInviteAsync(user, id); - } + [HttpGet("granted")] + public async Task> GetGrantees() + { + var userId = _userService.GetProperUserId(User); + var granteeDetails = await _emergencyAccessRepository.GetManyDetailsByGranteeIdAsync(userId.Value); - [HttpPost("{id}/accept")] - public async Task Accept(Guid id, [FromBody] OrganizationUserAcceptRequestModel model) - { - var user = await _userService.GetUserByPrincipalAsync(User); - await _emergencyAccessService.AcceptUserAsync(id, user, model.Token, _userService); - } + var responses = granteeDetails.Select(d => new EmergencyAccessGrantorDetailsResponseModel(d)); - [HttpPost("{id}/confirm")] - public async Task Confirm(Guid id, [FromBody] OrganizationUserConfirmRequestModel model) - { - var userId = _userService.GetProperUserId(User); - await _emergencyAccessService.ConfirmUserAsync(id, model.Key, userId.Value); - } + return new ListResponseModel(responses); + } - [HttpPost("{id}/initiate")] - public async Task Initiate(Guid id) - { - var user = await _userService.GetUserByPrincipalAsync(User); - await _emergencyAccessService.InitiateAsync(id, user); - } + [HttpGet("{id}")] + public async Task Get(Guid id) + { + var userId = _userService.GetProperUserId(User); + var result = await _emergencyAccessService.GetAsync(id, userId.Value); + return new EmergencyAccessGranteeDetailsResponseModel(result); + } - [HttpPost("{id}/approve")] - public async Task Accept(Guid id) - { - var user = await _userService.GetUserByPrincipalAsync(User); - await _emergencyAccessService.ApproveAsync(id, user); - } + [HttpGet("{id}/policies")] + public async Task> Policies(Guid id) + { + var user = await _userService.GetUserByPrincipalAsync(User); + var policies = await _emergencyAccessService.GetPoliciesAsync(id, user); + var responses = policies.Select(policy => new PolicyResponseModel(policy)); + return new ListResponseModel(responses); + } - [HttpPost("{id}/reject")] - public async Task Reject(Guid id) - { - var user = await _userService.GetUserByPrincipalAsync(User); - await _emergencyAccessService.RejectAsync(id, user); - } + [HttpPut("{id}")] + [HttpPost("{id}")] + public async Task Put(Guid id, [FromBody] EmergencyAccessUpdateRequestModel model) + { + var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(id); + if (emergencyAccess == null) + { + throw new NotFoundException(); + } - [HttpPost("{id}/takeover")] - public async Task Takeover(Guid id) - { - var user = await _userService.GetUserByPrincipalAsync(User); - var (result, grantor) = await _emergencyAccessService.TakeoverAsync(id, user); - return new EmergencyAccessTakeoverResponseModel(result, grantor); - } + var user = await _userService.GetUserByPrincipalAsync(User); + await _emergencyAccessService.SaveAsync(model.ToEmergencyAccess(emergencyAccess), user); + } - [HttpPost("{id}/password")] - public async Task Password(Guid id, [FromBody] EmergencyAccessPasswordRequestModel model) - { - var user = await _userService.GetUserByPrincipalAsync(User); - await _emergencyAccessService.PasswordAsync(id, user, model.NewMasterPasswordHash, model.Key); - } + [HttpDelete("{id}")] + [HttpPost("{id}/delete")] + public async Task Delete(Guid id) + { + var userId = _userService.GetProperUserId(User); + await _emergencyAccessService.DeleteAsync(id, userId.Value); + } - [HttpPost("{id}/view")] - public async Task ViewCiphers(Guid id) - { - var user = await _userService.GetUserByPrincipalAsync(User); - var viewResult = await _emergencyAccessService.ViewAsync(id, user); - return new EmergencyAccessViewResponseModel(_globalSettings, viewResult.EmergencyAccess, viewResult.Ciphers); - } + [HttpPost("invite")] + public async Task Invite([FromBody] EmergencyAccessInviteRequestModel model) + { + var user = await _userService.GetUserByPrincipalAsync(User); + await _emergencyAccessService.InviteAsync(user, model.Email, model.Type.Value, model.WaitTimeDays); + } - [HttpGet("{id}/{cipherId}/attachment/{attachmentId}")] - public async Task GetAttachmentData(Guid id, Guid cipherId, string attachmentId) - { - var user = await _userService.GetUserByPrincipalAsync(User); - var result = - await _emergencyAccessService.GetAttachmentDownloadAsync(id, cipherId, attachmentId, user); - return new AttachmentResponseModel(result); + [HttpPost("{id}/reinvite")] + public async Task Reinvite(Guid id) + { + var user = await _userService.GetUserByPrincipalAsync(User); + await _emergencyAccessService.ResendInviteAsync(user, id); + } + + [HttpPost("{id}/accept")] + public async Task Accept(Guid id, [FromBody] OrganizationUserAcceptRequestModel model) + { + var user = await _userService.GetUserByPrincipalAsync(User); + await _emergencyAccessService.AcceptUserAsync(id, user, model.Token, _userService); + } + + [HttpPost("{id}/confirm")] + public async Task Confirm(Guid id, [FromBody] OrganizationUserConfirmRequestModel model) + { + var userId = _userService.GetProperUserId(User); + await _emergencyAccessService.ConfirmUserAsync(id, model.Key, userId.Value); + } + + [HttpPost("{id}/initiate")] + public async Task Initiate(Guid id) + { + var user = await _userService.GetUserByPrincipalAsync(User); + await _emergencyAccessService.InitiateAsync(id, user); + } + + [HttpPost("{id}/approve")] + public async Task Accept(Guid id) + { + var user = await _userService.GetUserByPrincipalAsync(User); + await _emergencyAccessService.ApproveAsync(id, user); + } + + [HttpPost("{id}/reject")] + public async Task Reject(Guid id) + { + var user = await _userService.GetUserByPrincipalAsync(User); + await _emergencyAccessService.RejectAsync(id, user); + } + + [HttpPost("{id}/takeover")] + public async Task Takeover(Guid id) + { + var user = await _userService.GetUserByPrincipalAsync(User); + var (result, grantor) = await _emergencyAccessService.TakeoverAsync(id, user); + return new EmergencyAccessTakeoverResponseModel(result, grantor); + } + + [HttpPost("{id}/password")] + public async Task Password(Guid id, [FromBody] EmergencyAccessPasswordRequestModel model) + { + var user = await _userService.GetUserByPrincipalAsync(User); + await _emergencyAccessService.PasswordAsync(id, user, model.NewMasterPasswordHash, model.Key); + } + + [HttpPost("{id}/view")] + public async Task ViewCiphers(Guid id) + { + var user = await _userService.GetUserByPrincipalAsync(User); + var viewResult = await _emergencyAccessService.ViewAsync(id, user); + return new EmergencyAccessViewResponseModel(_globalSettings, viewResult.EmergencyAccess, viewResult.Ciphers); + } + + [HttpGet("{id}/{cipherId}/attachment/{attachmentId}")] + public async Task GetAttachmentData(Guid id, Guid cipherId, string attachmentId) + { + var user = await _userService.GetUserByPrincipalAsync(User); + var result = + await _emergencyAccessService.GetAttachmentDownloadAsync(id, cipherId, attachmentId, user); + return new AttachmentResponseModel(result); + } } } diff --git a/src/Api/Controllers/EventsController.cs b/src/Api/Controllers/EventsController.cs index 4fd1496b0c..ad657692e1 100644 --- a/src/Api/Controllers/EventsController.cs +++ b/src/Api/Controllers/EventsController.cs @@ -7,170 +7,171 @@ using Bit.Core.Services; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; -namespace Bit.Api.Controllers; - -[Route("events")] -[Authorize("Application")] -public class EventsController : Controller +namespace Bit.Api.Controllers { - private readonly IUserService _userService; - private readonly ICipherRepository _cipherRepository; - private readonly IOrganizationUserRepository _organizationUserRepository; - private readonly IProviderUserRepository _providerUserRepository; - private readonly IEventRepository _eventRepository; - private readonly ICurrentContext _currentContext; - - public EventsController( - IUserService userService, - ICipherRepository cipherRepository, - IOrganizationUserRepository organizationUserRepository, - IProviderUserRepository providerUserRepository, - IEventRepository eventRepository, - ICurrentContext currentContext) + [Route("events")] + [Authorize("Application")] + public class EventsController : Controller { - _userService = userService; - _cipherRepository = cipherRepository; - _organizationUserRepository = organizationUserRepository; - _providerUserRepository = providerUserRepository; - _eventRepository = eventRepository; - _currentContext = currentContext; - } + private readonly IUserService _userService; + private readonly ICipherRepository _cipherRepository; + private readonly IOrganizationUserRepository _organizationUserRepository; + private readonly IProviderUserRepository _providerUserRepository; + private readonly IEventRepository _eventRepository; + private readonly ICurrentContext _currentContext; - [HttpGet("")] - public async Task> GetUser( - [FromQuery] DateTime? start = null, [FromQuery] DateTime? end = null, [FromQuery] string continuationToken = null) - { - var dateRange = GetDateRange(start, end); - var userId = _userService.GetProperUserId(User).Value; - var result = await _eventRepository.GetManyByUserAsync(userId, dateRange.Item1, dateRange.Item2, - new PageOptions { ContinuationToken = continuationToken }); - var responses = result.Data.Select(e => new EventResponseModel(e)); - return new ListResponseModel(responses, result.ContinuationToken); - } - - [HttpGet("~/ciphers/{id}/events")] - public async Task> GetCipher(string id, - [FromQuery] DateTime? start = null, [FromQuery] DateTime? end = null, [FromQuery] string continuationToken = null) - { - var cipher = await _cipherRepository.GetByIdAsync(new Guid(id)); - if (cipher == null) + public EventsController( + IUserService userService, + ICipherRepository cipherRepository, + IOrganizationUserRepository organizationUserRepository, + IProviderUserRepository providerUserRepository, + IEventRepository eventRepository, + ICurrentContext currentContext) { - throw new NotFoundException(); + _userService = userService; + _cipherRepository = cipherRepository; + _organizationUserRepository = organizationUserRepository; + _providerUserRepository = providerUserRepository; + _eventRepository = eventRepository; + _currentContext = currentContext; } - var canView = false; - if (cipher.OrganizationId.HasValue) - { - canView = await _currentContext.AccessEventLogs(cipher.OrganizationId.Value); - } - else if (cipher.UserId.HasValue) + [HttpGet("")] + public async Task> GetUser( + [FromQuery] DateTime? start = null, [FromQuery] DateTime? end = null, [FromQuery] string continuationToken = null) { + var dateRange = GetDateRange(start, end); var userId = _userService.GetProperUserId(User).Value; - canView = userId == cipher.UserId.Value; + var result = await _eventRepository.GetManyByUserAsync(userId, dateRange.Item1, dateRange.Item2, + new PageOptions { ContinuationToken = continuationToken }); + var responses = result.Data.Select(e => new EventResponseModel(e)); + return new ListResponseModel(responses, result.ContinuationToken); } - if (!canView) + [HttpGet("~/ciphers/{id}/events")] + public async Task> GetCipher(string id, + [FromQuery] DateTime? start = null, [FromQuery] DateTime? end = null, [FromQuery] string continuationToken = null) { - throw new NotFoundException(); + var cipher = await _cipherRepository.GetByIdAsync(new Guid(id)); + if (cipher == null) + { + throw new NotFoundException(); + } + + var canView = false; + if (cipher.OrganizationId.HasValue) + { + canView = await _currentContext.AccessEventLogs(cipher.OrganizationId.Value); + } + else if (cipher.UserId.HasValue) + { + var userId = _userService.GetProperUserId(User).Value; + canView = userId == cipher.UserId.Value; + } + + if (!canView) + { + throw new NotFoundException(); + } + + var dateRange = GetDateRange(start, end); + var result = await _eventRepository.GetManyByCipherAsync(cipher, dateRange.Item1, dateRange.Item2, + new PageOptions { ContinuationToken = continuationToken }); + var responses = result.Data.Select(e => new EventResponseModel(e)); + return new ListResponseModel(responses, result.ContinuationToken); } - var dateRange = GetDateRange(start, end); - var result = await _eventRepository.GetManyByCipherAsync(cipher, dateRange.Item1, dateRange.Item2, - new PageOptions { ContinuationToken = continuationToken }); - var responses = result.Data.Select(e => new EventResponseModel(e)); - return new ListResponseModel(responses, result.ContinuationToken); - } - - [HttpGet("~/organizations/{id}/events")] - public async Task> GetOrganization(string id, - [FromQuery] DateTime? start = null, [FromQuery] DateTime? end = null, [FromQuery] string continuationToken = null) - { - var orgId = new Guid(id); - if (!await _currentContext.AccessEventLogs(orgId)) + [HttpGet("~/organizations/{id}/events")] + public async Task> GetOrganization(string id, + [FromQuery] DateTime? start = null, [FromQuery] DateTime? end = null, [FromQuery] string continuationToken = null) { - throw new NotFoundException(); + var orgId = new Guid(id); + if (!await _currentContext.AccessEventLogs(orgId)) + { + throw new NotFoundException(); + } + + var dateRange = GetDateRange(start, end); + var result = await _eventRepository.GetManyByOrganizationAsync(orgId, dateRange.Item1, dateRange.Item2, + new PageOptions { ContinuationToken = continuationToken }); + var responses = result.Data.Select(e => new EventResponseModel(e)); + return new ListResponseModel(responses, result.ContinuationToken); } - var dateRange = GetDateRange(start, end); - var result = await _eventRepository.GetManyByOrganizationAsync(orgId, dateRange.Item1, dateRange.Item2, - new PageOptions { ContinuationToken = continuationToken }); - var responses = result.Data.Select(e => new EventResponseModel(e)); - return new ListResponseModel(responses, result.ContinuationToken); - } - - [HttpGet("~/organizations/{orgId}/users/{id}/events")] - public async Task> GetOrganizationUser(string orgId, string id, - [FromQuery] DateTime? start = null, [FromQuery] DateTime? end = null, [FromQuery] string continuationToken = null) - { - var organizationUser = await _organizationUserRepository.GetByIdAsync(new Guid(id)); - if (organizationUser == null || !organizationUser.UserId.HasValue || - !await _currentContext.AccessEventLogs(organizationUser.OrganizationId)) + [HttpGet("~/organizations/{orgId}/users/{id}/events")] + public async Task> GetOrganizationUser(string orgId, string id, + [FromQuery] DateTime? start = null, [FromQuery] DateTime? end = null, [FromQuery] string continuationToken = null) { - throw new NotFoundException(); + var organizationUser = await _organizationUserRepository.GetByIdAsync(new Guid(id)); + if (organizationUser == null || !organizationUser.UserId.HasValue || + !await _currentContext.AccessEventLogs(organizationUser.OrganizationId)) + { + throw new NotFoundException(); + } + + var dateRange = GetDateRange(start, end); + var result = await _eventRepository.GetManyByOrganizationActingUserAsync(organizationUser.OrganizationId, + organizationUser.UserId.Value, dateRange.Item1, dateRange.Item2, + new PageOptions { ContinuationToken = continuationToken }); + var responses = result.Data.Select(e => new EventResponseModel(e)); + return new ListResponseModel(responses, result.ContinuationToken); } - var dateRange = GetDateRange(start, end); - var result = await _eventRepository.GetManyByOrganizationActingUserAsync(organizationUser.OrganizationId, - organizationUser.UserId.Value, dateRange.Item1, dateRange.Item2, - new PageOptions { ContinuationToken = continuationToken }); - var responses = result.Data.Select(e => new EventResponseModel(e)); - return new ListResponseModel(responses, result.ContinuationToken); - } - - [HttpGet("~/providers/{providerId:guid}/events")] - public async Task> GetProvider(Guid providerId, - [FromQuery] DateTime? start = null, [FromQuery] DateTime? end = null, [FromQuery] string continuationToken = null) - { - if (!_currentContext.ProviderAccessEventLogs(providerId)) + [HttpGet("~/providers/{providerId:guid}/events")] + public async Task> GetProvider(Guid providerId, + [FromQuery] DateTime? start = null, [FromQuery] DateTime? end = null, [FromQuery] string continuationToken = null) { - throw new NotFoundException(); + if (!_currentContext.ProviderAccessEventLogs(providerId)) + { + throw new NotFoundException(); + } + + var dateRange = GetDateRange(start, end); + var result = await _eventRepository.GetManyByProviderAsync(providerId, dateRange.Item1, dateRange.Item2, + new PageOptions { ContinuationToken = continuationToken }); + var responses = result.Data.Select(e => new EventResponseModel(e)); + return new ListResponseModel(responses, result.ContinuationToken); } - var dateRange = GetDateRange(start, end); - var result = await _eventRepository.GetManyByProviderAsync(providerId, dateRange.Item1, dateRange.Item2, - new PageOptions { ContinuationToken = continuationToken }); - var responses = result.Data.Select(e => new EventResponseModel(e)); - return new ListResponseModel(responses, result.ContinuationToken); - } - - [HttpGet("~/providers/{providerId:guid}/users/{id:guid}/events")] - public async Task> GetProviderUser(Guid providerId, Guid id, - [FromQuery] DateTime? start = null, [FromQuery] DateTime? end = null, [FromQuery] string continuationToken = null) - { - var providerUser = await _providerUserRepository.GetByIdAsync(id); - if (providerUser == null || !providerUser.UserId.HasValue || - !_currentContext.ProviderAccessEventLogs(providerUser.ProviderId)) + [HttpGet("~/providers/{providerId:guid}/users/{id:guid}/events")] + public async Task> GetProviderUser(Guid providerId, Guid id, + [FromQuery] DateTime? start = null, [FromQuery] DateTime? end = null, [FromQuery] string continuationToken = null) { - throw new NotFoundException(); + var providerUser = await _providerUserRepository.GetByIdAsync(id); + if (providerUser == null || !providerUser.UserId.HasValue || + !_currentContext.ProviderAccessEventLogs(providerUser.ProviderId)) + { + throw new NotFoundException(); + } + + var dateRange = GetDateRange(start, end); + var result = await _eventRepository.GetManyByProviderActingUserAsync(providerUser.ProviderId, + providerUser.UserId.Value, dateRange.Item1, dateRange.Item2, + new PageOptions { ContinuationToken = continuationToken }); + var responses = result.Data.Select(e => new EventResponseModel(e)); + return new ListResponseModel(responses, result.ContinuationToken); } - var dateRange = GetDateRange(start, end); - var result = await _eventRepository.GetManyByProviderActingUserAsync(providerUser.ProviderId, - providerUser.UserId.Value, dateRange.Item1, dateRange.Item2, - new PageOptions { ContinuationToken = continuationToken }); - var responses = result.Data.Select(e => new EventResponseModel(e)); - return new ListResponseModel(responses, result.ContinuationToken); - } - - private Tuple GetDateRange(DateTime? start, DateTime? end) - { - if (!end.HasValue || !start.HasValue) + private Tuple GetDateRange(DateTime? start, DateTime? end) { - end = DateTime.UtcNow.Date.AddDays(1).AddMilliseconds(-1); - start = DateTime.UtcNow.Date.AddDays(-30); - } - else if (start.Value > end.Value) - { - var newEnd = start; - start = end; - end = newEnd; - } + if (!end.HasValue || !start.HasValue) + { + end = DateTime.UtcNow.Date.AddDays(1).AddMilliseconds(-1); + start = DateTime.UtcNow.Date.AddDays(-30); + } + else if (start.Value > end.Value) + { + var newEnd = start; + start = end; + end = newEnd; + } - if ((end.Value - start.Value) > TimeSpan.FromDays(367)) - { - throw new BadRequestException("Range too large."); - } + if ((end.Value - start.Value) > TimeSpan.FromDays(367)) + { + throw new BadRequestException("Range too large."); + } - return new Tuple(start.Value, end.Value); + return new Tuple(start.Value, end.Value); + } } } diff --git a/src/Api/Controllers/FoldersController.cs b/src/Api/Controllers/FoldersController.cs index b387809ecd..752856302a 100644 --- a/src/Api/Controllers/FoldersController.cs +++ b/src/Api/Controllers/FoldersController.cs @@ -6,83 +6,84 @@ using Bit.Core.Services; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; -namespace Bit.Api.Controllers; - -[Route("folders")] -[Authorize("Application")] -public class FoldersController : Controller +namespace Bit.Api.Controllers { - private readonly IFolderRepository _folderRepository; - private readonly ICipherService _cipherService; - private readonly IUserService _userService; - - public FoldersController( - IFolderRepository folderRepository, - ICipherService cipherService, - IUserService userService) + [Route("folders")] + [Authorize("Application")] + public class FoldersController : Controller { - _folderRepository = folderRepository; - _cipherService = cipherService; - _userService = userService; - } + private readonly IFolderRepository _folderRepository; + private readonly ICipherService _cipherService; + private readonly IUserService _userService; - [HttpGet("{id}")] - public async Task Get(string id) - { - var userId = _userService.GetProperUserId(User).Value; - var folder = await _folderRepository.GetByIdAsync(new Guid(id), userId); - if (folder == null) + public FoldersController( + IFolderRepository folderRepository, + ICipherService cipherService, + IUserService userService) { - throw new NotFoundException(); + _folderRepository = folderRepository; + _cipherService = cipherService; + _userService = userService; } - return new FolderResponseModel(folder); - } - - [HttpGet("")] - public async Task> Get() - { - var userId = _userService.GetProperUserId(User).Value; - var folders = await _folderRepository.GetManyByUserIdAsync(userId); - var responses = folders.Select(f => new FolderResponseModel(f)); - return new ListResponseModel(responses); - } - - [HttpPost("")] - public async Task Post([FromBody] FolderRequestModel model) - { - var userId = _userService.GetProperUserId(User).Value; - var folder = model.ToFolder(_userService.GetProperUserId(User).Value); - await _cipherService.SaveFolderAsync(folder); - return new FolderResponseModel(folder); - } - - [HttpPut("{id}")] - [HttpPost("{id}")] - public async Task Put(string id, [FromBody] FolderRequestModel model) - { - var userId = _userService.GetProperUserId(User).Value; - var folder = await _folderRepository.GetByIdAsync(new Guid(id), userId); - if (folder == null) + [HttpGet("{id}")] + public async Task Get(string id) { - throw new NotFoundException(); + var userId = _userService.GetProperUserId(User).Value; + var folder = await _folderRepository.GetByIdAsync(new Guid(id), userId); + if (folder == null) + { + throw new NotFoundException(); + } + + return new FolderResponseModel(folder); } - await _cipherService.SaveFolderAsync(model.ToFolder(folder)); - return new FolderResponseModel(folder); - } - - [HttpDelete("{id}")] - [HttpPost("{id}/delete")] - public async Task Delete(string id) - { - var userId = _userService.GetProperUserId(User).Value; - var folder = await _folderRepository.GetByIdAsync(new Guid(id), userId); - if (folder == null) + [HttpGet("")] + public async Task> Get() { - throw new NotFoundException(); + var userId = _userService.GetProperUserId(User).Value; + var folders = await _folderRepository.GetManyByUserIdAsync(userId); + var responses = folders.Select(f => new FolderResponseModel(f)); + return new ListResponseModel(responses); } - await _cipherService.DeleteFolderAsync(folder); + [HttpPost("")] + public async Task Post([FromBody] FolderRequestModel model) + { + var userId = _userService.GetProperUserId(User).Value; + var folder = model.ToFolder(_userService.GetProperUserId(User).Value); + await _cipherService.SaveFolderAsync(folder); + return new FolderResponseModel(folder); + } + + [HttpPut("{id}")] + [HttpPost("{id}")] + public async Task Put(string id, [FromBody] FolderRequestModel model) + { + var userId = _userService.GetProperUserId(User).Value; + var folder = await _folderRepository.GetByIdAsync(new Guid(id), userId); + if (folder == null) + { + throw new NotFoundException(); + } + + await _cipherService.SaveFolderAsync(model.ToFolder(folder)); + return new FolderResponseModel(folder); + } + + [HttpDelete("{id}")] + [HttpPost("{id}/delete")] + public async Task Delete(string id) + { + var userId = _userService.GetProperUserId(User).Value; + var folder = await _folderRepository.GetByIdAsync(new Guid(id), userId); + if (folder == null) + { + throw new NotFoundException(); + } + + await _cipherService.DeleteFolderAsync(folder); + } } } diff --git a/src/Api/Controllers/GroupsController.cs b/src/Api/Controllers/GroupsController.cs index d38ba03bc1..fc226c4d11 100644 --- a/src/Api/Controllers/GroupsController.cs +++ b/src/Api/Controllers/GroupsController.cs @@ -7,145 +7,146 @@ using Bit.Core.Services; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; -namespace Bit.Api.Controllers; - -[Route("organizations/{orgId}/groups")] -[Authorize("Application")] -public class GroupsController : Controller +namespace Bit.Api.Controllers { - private readonly IGroupRepository _groupRepository; - private readonly IGroupService _groupService; - private readonly ICurrentContext _currentContext; - - public GroupsController( - IGroupRepository groupRepository, - IGroupService groupService, - ICurrentContext currentContext) + [Route("organizations/{orgId}/groups")] + [Authorize("Application")] + public class GroupsController : Controller { - _groupRepository = groupRepository; - _groupService = groupService; - _currentContext = currentContext; - } + private readonly IGroupRepository _groupRepository; + private readonly IGroupService _groupService; + private readonly ICurrentContext _currentContext; - [HttpGet("{id}")] - public async Task Get(string orgId, string id) - { - var group = await _groupRepository.GetByIdAsync(new Guid(id)); - if (group == null || !await _currentContext.ManageGroups(group.OrganizationId)) + public GroupsController( + IGroupRepository groupRepository, + IGroupService groupService, + ICurrentContext currentContext) { - throw new NotFoundException(); + _groupRepository = groupRepository; + _groupService = groupService; + _currentContext = currentContext; } - return new GroupResponseModel(group); - } - - [HttpGet("{id}/details")] - public async Task GetDetails(string orgId, string id) - { - var groupDetails = await _groupRepository.GetByIdWithCollectionsAsync(new Guid(id)); - if (groupDetails?.Item1 == null || !await _currentContext.ManageGroups(groupDetails.Item1.OrganizationId)) + [HttpGet("{id}")] + public async Task Get(string orgId, string id) { - throw new NotFoundException(); + var group = await _groupRepository.GetByIdAsync(new Guid(id)); + if (group == null || !await _currentContext.ManageGroups(group.OrganizationId)) + { + throw new NotFoundException(); + } + + return new GroupResponseModel(group); } - return new GroupDetailsResponseModel(groupDetails.Item1, groupDetails.Item2); - } - - [HttpGet("")] - public async Task> Get(string orgId) - { - var orgIdGuid = new Guid(orgId); - var canAccess = await _currentContext.ManageGroups(orgIdGuid) || - await _currentContext.ViewAssignedCollections(orgIdGuid) || - await _currentContext.ViewAllCollections(orgIdGuid) || - await _currentContext.ManageUsers(orgIdGuid); - - if (!canAccess) + [HttpGet("{id}/details")] + public async Task GetDetails(string orgId, string id) { - throw new NotFoundException(); + var groupDetails = await _groupRepository.GetByIdWithCollectionsAsync(new Guid(id)); + if (groupDetails?.Item1 == null || !await _currentContext.ManageGroups(groupDetails.Item1.OrganizationId)) + { + throw new NotFoundException(); + } + + return new GroupDetailsResponseModel(groupDetails.Item1, groupDetails.Item2); } - var groups = await _groupRepository.GetManyByOrganizationIdAsync(orgIdGuid); - var responses = groups.Select(g => new GroupResponseModel(g)); - return new ListResponseModel(responses); - } - - [HttpGet("{id}/users")] - public async Task> GetUsers(string orgId, string id) - { - var idGuid = new Guid(id); - var group = await _groupRepository.GetByIdAsync(idGuid); - if (group == null || !await _currentContext.ManageGroups(group.OrganizationId)) + [HttpGet("")] + public async Task> Get(string orgId) { - throw new NotFoundException(); + var orgIdGuid = new Guid(orgId); + var canAccess = await _currentContext.ManageGroups(orgIdGuid) || + await _currentContext.ViewAssignedCollections(orgIdGuid) || + await _currentContext.ViewAllCollections(orgIdGuid) || + await _currentContext.ManageUsers(orgIdGuid); + + if (!canAccess) + { + throw new NotFoundException(); + } + + var groups = await _groupRepository.GetManyByOrganizationIdAsync(orgIdGuid); + var responses = groups.Select(g => new GroupResponseModel(g)); + return new ListResponseModel(responses); } - var groupIds = await _groupRepository.GetManyUserIdsByIdAsync(idGuid); - return groupIds; - } - - [HttpPost("")] - public async Task Post(string orgId, [FromBody] GroupRequestModel model) - { - var orgIdGuid = new Guid(orgId); - if (!await _currentContext.ManageGroups(orgIdGuid)) + [HttpGet("{id}/users")] + public async Task> GetUsers(string orgId, string id) { - throw new NotFoundException(); + var idGuid = new Guid(id); + var group = await _groupRepository.GetByIdAsync(idGuid); + if (group == null || !await _currentContext.ManageGroups(group.OrganizationId)) + { + throw new NotFoundException(); + } + + var groupIds = await _groupRepository.GetManyUserIdsByIdAsync(idGuid); + return groupIds; } - var group = model.ToGroup(orgIdGuid); - await _groupService.SaveAsync(group, model.Collections?.Select(c => c.ToSelectionReadOnly())); - return new GroupResponseModel(group); - } - - [HttpPut("{id}")] - [HttpPost("{id}")] - public async Task Put(string orgId, string id, [FromBody] GroupRequestModel model) - { - var group = await _groupRepository.GetByIdAsync(new Guid(id)); - if (group == null || !await _currentContext.ManageGroups(group.OrganizationId)) + [HttpPost("")] + public async Task Post(string orgId, [FromBody] GroupRequestModel model) { - throw new NotFoundException(); + var orgIdGuid = new Guid(orgId); + if (!await _currentContext.ManageGroups(orgIdGuid)) + { + throw new NotFoundException(); + } + + var group = model.ToGroup(orgIdGuid); + await _groupService.SaveAsync(group, model.Collections?.Select(c => c.ToSelectionReadOnly())); + return new GroupResponseModel(group); } - await _groupService.SaveAsync(model.ToGroup(group), model.Collections?.Select(c => c.ToSelectionReadOnly())); - return new GroupResponseModel(group); - } - - [HttpPut("{id}/users")] - public async Task PutUsers(string orgId, string id, [FromBody] IEnumerable model) - { - var group = await _groupRepository.GetByIdAsync(new Guid(id)); - if (group == null || !await _currentContext.ManageGroups(group.OrganizationId)) + [HttpPut("{id}")] + [HttpPost("{id}")] + public async Task Put(string orgId, string id, [FromBody] GroupRequestModel model) { - throw new NotFoundException(); - } - await _groupRepository.UpdateUsersAsync(group.Id, model); - } + var group = await _groupRepository.GetByIdAsync(new Guid(id)); + if (group == null || !await _currentContext.ManageGroups(group.OrganizationId)) + { + throw new NotFoundException(); + } - [HttpDelete("{id}")] - [HttpPost("{id}/delete")] - public async Task Delete(string orgId, string id) - { - var group = await _groupRepository.GetByIdAsync(new Guid(id)); - if (group == null || !await _currentContext.ManageGroups(group.OrganizationId)) - { - throw new NotFoundException(); + await _groupService.SaveAsync(model.ToGroup(group), model.Collections?.Select(c => c.ToSelectionReadOnly())); + return new GroupResponseModel(group); } - await _groupService.DeleteAsync(group); - } - - [HttpDelete("{id}/user/{orgUserId}")] - [HttpPost("{id}/delete-user/{orgUserId}")] - public async Task Delete(string orgId, string id, string orgUserId) - { - var group = await _groupRepository.GetByIdAsync(new Guid(id)); - if (group == null || !await _currentContext.ManageGroups(group.OrganizationId)) + [HttpPut("{id}/users")] + public async Task PutUsers(string orgId, string id, [FromBody] IEnumerable model) { - throw new NotFoundException(); + var group = await _groupRepository.GetByIdAsync(new Guid(id)); + if (group == null || !await _currentContext.ManageGroups(group.OrganizationId)) + { + throw new NotFoundException(); + } + await _groupRepository.UpdateUsersAsync(group.Id, model); } - await _groupService.DeleteUserAsync(group, new Guid(orgUserId)); + [HttpDelete("{id}")] + [HttpPost("{id}/delete")] + public async Task Delete(string orgId, string id) + { + var group = await _groupRepository.GetByIdAsync(new Guid(id)); + if (group == null || !await _currentContext.ManageGroups(group.OrganizationId)) + { + throw new NotFoundException(); + } + + await _groupService.DeleteAsync(group); + } + + [HttpDelete("{id}/user/{orgUserId}")] + [HttpPost("{id}/delete-user/{orgUserId}")] + public async Task Delete(string orgId, string id, string orgUserId) + { + var group = await _groupRepository.GetByIdAsync(new Guid(id)); + if (group == null || !await _currentContext.ManageGroups(group.OrganizationId)) + { + throw new NotFoundException(); + } + + await _groupService.DeleteUserAsync(group, new Guid(orgUserId)); + } } } diff --git a/src/Api/Controllers/HibpController.cs b/src/Api/Controllers/HibpController.cs index 517ffb5ef9..3b94901b01 100644 --- a/src/Api/Controllers/HibpController.cs +++ b/src/Api/Controllers/HibpController.cs @@ -8,90 +8,91 @@ using Bit.Core.Utilities; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; -namespace Bit.Api.Controllers; - -[Route("hibp")] -[Authorize("Application")] -public class HibpController : Controller +namespace Bit.Api.Controllers { - private const string HibpBreachApi = "https://haveibeenpwned.com/api/v3/breachedaccount/{0}" + - "?truncateResponse=false&includeUnverified=false"; - private static HttpClient _httpClient; - - private readonly IUserService _userService; - private readonly ICurrentContext _currentContext; - private readonly GlobalSettings _globalSettings; - private readonly string _userAgent; - - static HibpController() + [Route("hibp")] + [Authorize("Application")] + public class HibpController : Controller { - _httpClient = new HttpClient(); - } + private const string HibpBreachApi = "https://haveibeenpwned.com/api/v3/breachedaccount/{0}" + + "?truncateResponse=false&includeUnverified=false"; + private static HttpClient _httpClient; - public HibpController( - IUserService userService, - ICurrentContext currentContext, - GlobalSettings globalSettings) - { - _userService = userService; - _currentContext = currentContext; - _globalSettings = globalSettings; - _userAgent = _globalSettings.SelfHosted ? "Bitwarden Self-Hosted" : "Bitwarden"; - } + private readonly IUserService _userService; + private readonly ICurrentContext _currentContext; + private readonly GlobalSettings _globalSettings; + private readonly string _userAgent; - [HttpGet("breach")] - public async Task Get(string username) - { - return await SendAsync(WebUtility.UrlEncode(username), true); - } - - private async Task SendAsync(string username, bool retry) - { - if (!CoreHelpers.SettingHasValue(_globalSettings.HibpApiKey)) + static HibpController() { - throw new BadRequestException("HaveIBeenPwned API key not set."); + _httpClient = new HttpClient(); } - var request = new HttpRequestMessage(HttpMethod.Get, string.Format(HibpBreachApi, username)); - request.Headers.Add("hibp-api-key", _globalSettings.HibpApiKey); - request.Headers.Add("hibp-client-id", GetClientId()); - request.Headers.Add("User-Agent", _userAgent); - var response = await _httpClient.SendAsync(request); - if (response.IsSuccessStatusCode) + + public HibpController( + IUserService userService, + ICurrentContext currentContext, + GlobalSettings globalSettings) { - var data = await response.Content.ReadAsStringAsync(); - return Content(data, "application/json"); + _userService = userService; + _currentContext = currentContext; + _globalSettings = globalSettings; + _userAgent = _globalSettings.SelfHosted ? "Bitwarden Self-Hosted" : "Bitwarden"; } - else if (response.StatusCode == HttpStatusCode.NotFound) + + [HttpGet("breach")] + public async Task Get(string username) { - return new NotFoundResult(); + return await SendAsync(WebUtility.UrlEncode(username), true); } - else if (response.StatusCode == HttpStatusCode.TooManyRequests && retry) + + private async Task SendAsync(string username, bool retry) { - var delay = 2000; - if (response.Headers.Contains("retry-after")) + if (!CoreHelpers.SettingHasValue(_globalSettings.HibpApiKey)) { - var vals = response.Headers.GetValues("retry-after"); - if (vals.Any() && int.TryParse(vals.FirstOrDefault(), out var secDelay)) - { - delay = (secDelay * 1000) + 200; - } + throw new BadRequestException("HaveIBeenPwned API key not set."); + } + var request = new HttpRequestMessage(HttpMethod.Get, string.Format(HibpBreachApi, username)); + request.Headers.Add("hibp-api-key", _globalSettings.HibpApiKey); + request.Headers.Add("hibp-client-id", GetClientId()); + request.Headers.Add("User-Agent", _userAgent); + var response = await _httpClient.SendAsync(request); + if (response.IsSuccessStatusCode) + { + var data = await response.Content.ReadAsStringAsync(); + return Content(data, "application/json"); + } + else if (response.StatusCode == HttpStatusCode.NotFound) + { + return new NotFoundResult(); + } + else if (response.StatusCode == HttpStatusCode.TooManyRequests && retry) + { + var delay = 2000; + if (response.Headers.Contains("retry-after")) + { + var vals = response.Headers.GetValues("retry-after"); + if (vals.Any() && int.TryParse(vals.FirstOrDefault(), out var secDelay)) + { + delay = (secDelay * 1000) + 200; + } + } + await Task.Delay(delay); + return await SendAsync(username, false); + } + else + { + throw new BadRequestException("Request failed. Status code: " + response.StatusCode); } - await Task.Delay(delay); - return await SendAsync(username, false); } - else - { - throw new BadRequestException("Request failed. Status code: " + response.StatusCode); - } - } - private string GetClientId() - { - var userId = _userService.GetProperUserId(User).Value; - using (var sha256 = SHA256.Create()) + private string GetClientId() { - var hash = sha256.ComputeHash(userId.ToByteArray()); - return Convert.ToBase64String(hash); + var userId = _userService.GetProperUserId(User).Value; + using (var sha256 = SHA256.Create()) + { + var hash = sha256.ComputeHash(userId.ToByteArray()); + return Convert.ToBase64String(hash); + } } } } diff --git a/src/Api/Controllers/InfoController.cs b/src/Api/Controllers/InfoController.cs index 739f9f4257..206ba68109 100644 --- a/src/Api/Controllers/InfoController.cs +++ b/src/Api/Controllers/InfoController.cs @@ -1,34 +1,35 @@ using Bit.Core.Utilities; using Microsoft.AspNetCore.Mvc; -namespace Bit.Api.Controllers; - -public class InfoController : Controller +namespace Bit.Api.Controllers { - [HttpGet("~/alive")] - [HttpGet("~/now")] - public DateTime GetAlive() + public class InfoController : Controller { - return DateTime.UtcNow; - } - - [HttpGet("~/version")] - public JsonResult GetVersion() - { - return Json(CoreHelpers.GetVersion()); - } - - [HttpGet("~/ip")] - public JsonResult Ip() - { - var headerSet = new HashSet { "x-forwarded-for", "cf-connecting-ip", "client-ip" }; - var headers = HttpContext.Request?.Headers - .Where(h => headerSet.Contains(h.Key.ToLower())) - .ToDictionary(h => h.Key); - return new JsonResult(new + [HttpGet("~/alive")] + [HttpGet("~/now")] + public DateTime GetAlive() { - Ip = HttpContext.Connection?.RemoteIpAddress?.ToString(), - Headers = headers, - }); + return DateTime.UtcNow; + } + + [HttpGet("~/version")] + public JsonResult GetVersion() + { + return Json(CoreHelpers.GetVersion()); + } + + [HttpGet("~/ip")] + public JsonResult Ip() + { + var headerSet = new HashSet { "x-forwarded-for", "cf-connecting-ip", "client-ip" }; + var headers = HttpContext.Request?.Headers + .Where(h => headerSet.Contains(h.Key.ToLower())) + .ToDictionary(h => h.Key); + return new JsonResult(new + { + Ip = HttpContext.Connection?.RemoteIpAddress?.ToString(), + Headers = headers, + }); + } } } diff --git a/src/Api/Controllers/InstallationsController.cs b/src/Api/Controllers/InstallationsController.cs index a2eeebab37..c75468b476 100644 --- a/src/Api/Controllers/InstallationsController.cs +++ b/src/Api/Controllers/InstallationsController.cs @@ -6,39 +6,40 @@ using Bit.Core.Utilities; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; -namespace Bit.Api.Controllers; - -[Route("installations")] -[SelfHosted(NotSelfHostedOnly = true)] -public class InstallationsController : Controller +namespace Bit.Api.Controllers { - private readonly IInstallationRepository _installationRepository; - - public InstallationsController( - IInstallationRepository installationRepository) + [Route("installations")] + [SelfHosted(NotSelfHostedOnly = true)] + public class InstallationsController : Controller { - _installationRepository = installationRepository; - } + private readonly IInstallationRepository _installationRepository; - [HttpGet("{id}")] - [AllowAnonymous] - public async Task Get(Guid id) - { - var installation = await _installationRepository.GetByIdAsync(id); - if (installation == null) + public InstallationsController( + IInstallationRepository installationRepository) { - throw new NotFoundException(); + _installationRepository = installationRepository; } - return new InstallationResponseModel(installation, false); - } + [HttpGet("{id}")] + [AllowAnonymous] + public async Task Get(Guid id) + { + var installation = await _installationRepository.GetByIdAsync(id); + if (installation == null) + { + throw new NotFoundException(); + } - [HttpPost("")] - [AllowAnonymous] - public async Task Post([FromBody] InstallationRequestModel model) - { - var installation = model.ToInstallation(); - await _installationRepository.CreateAsync(installation); - return new InstallationResponseModel(installation, true); + return new InstallationResponseModel(installation, false); + } + + [HttpPost("")] + [AllowAnonymous] + public async Task Post([FromBody] InstallationRequestModel model) + { + var installation = model.ToInstallation(); + await _installationRepository.CreateAsync(installation); + return new InstallationResponseModel(installation, true); + } } } diff --git a/src/Api/Controllers/LicensesController.cs b/src/Api/Controllers/LicensesController.cs index 63ed824795..4de0798850 100644 --- a/src/Api/Controllers/LicensesController.cs +++ b/src/Api/Controllers/LicensesController.cs @@ -7,69 +7,70 @@ using Bit.Core.Utilities; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; -namespace Bit.Api.Controllers; - -[Route("licenses")] -[Authorize("Licensing")] -[SelfHosted(NotSelfHostedOnly = true)] -public class LicensesController : Controller +namespace Bit.Api.Controllers { - private readonly ILicensingService _licensingService; - private readonly IUserRepository _userRepository; - private readonly IUserService _userService; - private readonly IOrganizationRepository _organizationRepository; - private readonly IOrganizationService _organizationService; - private readonly ICurrentContext _currentContext; - - public LicensesController( - ILicensingService licensingService, - IUserRepository userRepository, - IUserService userService, - IOrganizationRepository organizationRepository, - IOrganizationService organizationService, - ICurrentContext currentContext) + [Route("licenses")] + [Authorize("Licensing")] + [SelfHosted(NotSelfHostedOnly = true)] + public class LicensesController : Controller { - _licensingService = licensingService; - _userRepository = userRepository; - _userService = userService; - _organizationRepository = organizationRepository; - _organizationService = organizationService; - _currentContext = currentContext; - } + private readonly ILicensingService _licensingService; + private readonly IUserRepository _userRepository; + private readonly IUserService _userService; + private readonly IOrganizationRepository _organizationRepository; + private readonly IOrganizationService _organizationService; + private readonly ICurrentContext _currentContext; - [HttpGet("user/{id}")] - public async Task GetUser(string id, [FromQuery] string key) - { - var user = await _userRepository.GetByIdAsync(new Guid(id)); - if (user == null) + public LicensesController( + ILicensingService licensingService, + IUserRepository userRepository, + IUserService userService, + IOrganizationRepository organizationRepository, + IOrganizationService organizationService, + ICurrentContext currentContext) { - return null; - } - else if (!user.LicenseKey.Equals(key)) - { - await Task.Delay(2000); - throw new BadRequestException("Invalid license key."); + _licensingService = licensingService; + _userRepository = userRepository; + _userService = userService; + _organizationRepository = organizationRepository; + _organizationService = organizationService; + _currentContext = currentContext; } - var license = await _userService.GenerateLicenseAsync(user, null); - return license; - } + [HttpGet("user/{id}")] + public async Task GetUser(string id, [FromQuery] string key) + { + var user = await _userRepository.GetByIdAsync(new Guid(id)); + if (user == null) + { + return null; + } + else if (!user.LicenseKey.Equals(key)) + { + await Task.Delay(2000); + throw new BadRequestException("Invalid license key."); + } - [HttpGet("organization/{id}")] - public async Task GetOrganization(string id, [FromQuery] string key) - { - var org = await _organizationRepository.GetByIdAsync(new Guid(id)); - if (org == null) - { - return null; - } - else if (!org.LicenseKey.Equals(key)) - { - await Task.Delay(2000); - throw new BadRequestException("Invalid license key."); + var license = await _userService.GenerateLicenseAsync(user, null); + return license; } - var license = await _organizationService.GenerateLicenseAsync(org, _currentContext.InstallationId.Value); - return license; + [HttpGet("organization/{id}")] + public async Task GetOrganization(string id, [FromQuery] string key) + { + var org = await _organizationRepository.GetByIdAsync(new Guid(id)); + if (org == null) + { + return null; + } + else if (!org.LicenseKey.Equals(key)) + { + await Task.Delay(2000); + throw new BadRequestException("Invalid license key."); + } + + var license = await _organizationService.GenerateLicenseAsync(org, _currentContext.InstallationId.Value); + return license; + } } } diff --git a/src/Api/Controllers/MiscController.cs b/src/Api/Controllers/MiscController.cs index 6f23a27fbf..edd4dfde4c 100644 --- a/src/Api/Controllers/MiscController.cs +++ b/src/Api/Controllers/MiscController.cs @@ -5,41 +5,42 @@ using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; using Stripe; -namespace Bit.Api.Controllers; - -public class MiscController : Controller +namespace Bit.Api.Controllers { - private readonly BitPayClient _bitPayClient; - private readonly GlobalSettings _globalSettings; - - public MiscController( - BitPayClient bitPayClient, - GlobalSettings globalSettings) + public class MiscController : Controller { - _bitPayClient = bitPayClient; - _globalSettings = globalSettings; - } + private readonly BitPayClient _bitPayClient; + private readonly GlobalSettings _globalSettings; - [Authorize("Application")] - [HttpPost("~/bitpay-invoice")] - [SelfHosted(NotSelfHostedOnly = true)] - public async Task PostBitPayInvoice([FromBody] BitPayInvoiceRequestModel model) - { - var invoice = await _bitPayClient.CreateInvoiceAsync(model.ToBitpayInvoice(_globalSettings)); - return invoice.Url; - } - - [Authorize("Application")] - [HttpPost("~/setup-payment")] - [SelfHosted(NotSelfHostedOnly = true)] - public async Task PostSetupPayment() - { - var options = new SetupIntentCreateOptions + public MiscController( + BitPayClient bitPayClient, + GlobalSettings globalSettings) { - Usage = "off_session" - }; - var service = new SetupIntentService(); - var setupIntent = await service.CreateAsync(options); - return setupIntent.ClientSecret; + _bitPayClient = bitPayClient; + _globalSettings = globalSettings; + } + + [Authorize("Application")] + [HttpPost("~/bitpay-invoice")] + [SelfHosted(NotSelfHostedOnly = true)] + public async Task PostBitPayInvoice([FromBody] BitPayInvoiceRequestModel model) + { + var invoice = await _bitPayClient.CreateInvoiceAsync(model.ToBitpayInvoice(_globalSettings)); + return invoice.Url; + } + + [Authorize("Application")] + [HttpPost("~/setup-payment")] + [SelfHosted(NotSelfHostedOnly = true)] + public async Task PostSetupPayment() + { + var options = new SetupIntentCreateOptions + { + Usage = "off_session" + }; + var service = new SetupIntentService(); + var setupIntent = await service.CreateAsync(options); + return setupIntent.ClientSecret; + } } } diff --git a/src/Api/Controllers/OrganizationConnectionsController.cs b/src/Api/Controllers/OrganizationConnectionsController.cs index 73754dba76..83f7a6ed73 100644 --- a/src/Api/Controllers/OrganizationConnectionsController.cs +++ b/src/Api/Controllers/OrganizationConnectionsController.cs @@ -12,198 +12,199 @@ using Bit.Core.Settings; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; -namespace Bit.Api.Controllers; - -[Authorize("Application")] -[Route("organizations/connections")] -public class OrganizationConnectionsController : Controller +namespace Bit.Api.Controllers { - private readonly ICreateOrganizationConnectionCommand _createOrganizationConnectionCommand; - private readonly IUpdateOrganizationConnectionCommand _updateOrganizationConnectionCommand; - private readonly IDeleteOrganizationConnectionCommand _deleteOrganizationConnectionCommand; - private readonly IOrganizationConnectionRepository _organizationConnectionRepository; - private readonly ICurrentContext _currentContext; - private readonly IGlobalSettings _globalSettings; - private readonly ILicensingService _licensingService; - - public OrganizationConnectionsController( - ICreateOrganizationConnectionCommand createOrganizationConnectionCommand, - IUpdateOrganizationConnectionCommand updateOrganizationConnectionCommand, - IDeleteOrganizationConnectionCommand deleteOrganizationConnectionCommand, - IOrganizationConnectionRepository organizationConnectionRepository, - ICurrentContext currentContext, - IGlobalSettings globalSettings, - ILicensingService licensingService) + [Authorize("Application")] + [Route("organizations/connections")] + public class OrganizationConnectionsController : Controller { - _createOrganizationConnectionCommand = createOrganizationConnectionCommand; - _updateOrganizationConnectionCommand = updateOrganizationConnectionCommand; - _deleteOrganizationConnectionCommand = deleteOrganizationConnectionCommand; - _organizationConnectionRepository = organizationConnectionRepository; - _currentContext = currentContext; - _globalSettings = globalSettings; - _licensingService = licensingService; - } + private readonly ICreateOrganizationConnectionCommand _createOrganizationConnectionCommand; + private readonly IUpdateOrganizationConnectionCommand _updateOrganizationConnectionCommand; + private readonly IDeleteOrganizationConnectionCommand _deleteOrganizationConnectionCommand; + private readonly IOrganizationConnectionRepository _organizationConnectionRepository; + private readonly ICurrentContext _currentContext; + private readonly IGlobalSettings _globalSettings; + private readonly ILicensingService _licensingService; - [HttpGet("enabled")] - public bool ConnectionsEnabled() - { - return _globalSettings.SelfHosted && _globalSettings.EnableCloudCommunication; - } - - [HttpPost] - public async Task CreateConnection([FromBody] OrganizationConnectionRequestModel model) - { - if (!await HasPermissionAsync(model?.OrganizationId)) + public OrganizationConnectionsController( + ICreateOrganizationConnectionCommand createOrganizationConnectionCommand, + IUpdateOrganizationConnectionCommand updateOrganizationConnectionCommand, + IDeleteOrganizationConnectionCommand deleteOrganizationConnectionCommand, + IOrganizationConnectionRepository organizationConnectionRepository, + ICurrentContext currentContext, + IGlobalSettings globalSettings, + ILicensingService licensingService) { - throw new BadRequestException($"You do not have permission to create a connection of type {model.Type}."); + _createOrganizationConnectionCommand = createOrganizationConnectionCommand; + _updateOrganizationConnectionCommand = updateOrganizationConnectionCommand; + _deleteOrganizationConnectionCommand = deleteOrganizationConnectionCommand; + _organizationConnectionRepository = organizationConnectionRepository; + _currentContext = currentContext; + _globalSettings = globalSettings; + _licensingService = licensingService; } - if (await HasConnectionTypeAsync(model, null, model.Type)) + [HttpGet("enabled")] + public bool ConnectionsEnabled() { - throw new BadRequestException($"The requested organization already has a connection of type {model.Type}. Only one of each connection type may exist per organization."); + return _globalSettings.SelfHosted && _globalSettings.EnableCloudCommunication; } - switch (model.Type) + [HttpPost] + public async Task CreateConnection([FromBody] OrganizationConnectionRequestModel model) { - case OrganizationConnectionType.CloudBillingSync: - return await CreateOrUpdateOrganizationConnectionAsync(null, model, ValidateBillingSyncConfig); - case OrganizationConnectionType.Scim: - return await CreateOrUpdateOrganizationConnectionAsync(null, model); - default: - throw new BadRequestException($"Unknown Organization connection Type: {model.Type}"); - } - } + if (!await HasPermissionAsync(model?.OrganizationId)) + { + throw new BadRequestException($"You do not have permission to create a connection of type {model.Type}."); + } - [HttpPut("{organizationConnectionId}")] - public async Task UpdateConnection(Guid organizationConnectionId, [FromBody] OrganizationConnectionRequestModel model) - { - var existingOrganizationConnection = await _organizationConnectionRepository.GetByIdAsync(organizationConnectionId); - if (existingOrganizationConnection == null) - { - throw new NotFoundException(); + if (await HasConnectionTypeAsync(model, null, model.Type)) + { + throw new BadRequestException($"The requested organization already has a connection of type {model.Type}. Only one of each connection type may exist per organization."); + } + + switch (model.Type) + { + case OrganizationConnectionType.CloudBillingSync: + return await CreateOrUpdateOrganizationConnectionAsync(null, model, ValidateBillingSyncConfig); + case OrganizationConnectionType.Scim: + return await CreateOrUpdateOrganizationConnectionAsync(null, model); + default: + throw new BadRequestException($"Unknown Organization connection Type: {model.Type}"); + } } - if (!await HasPermissionAsync(model?.OrganizationId, model?.Type)) + [HttpPut("{organizationConnectionId}")] + public async Task UpdateConnection(Guid organizationConnectionId, [FromBody] OrganizationConnectionRequestModel model) { - throw new BadRequestException("You do not have permission to update this connection."); + var existingOrganizationConnection = await _organizationConnectionRepository.GetByIdAsync(organizationConnectionId); + if (existingOrganizationConnection == null) + { + throw new NotFoundException(); + } + + if (!await HasPermissionAsync(model?.OrganizationId, model?.Type)) + { + throw new BadRequestException("You do not have permission to update this connection."); + } + + if (await HasConnectionTypeAsync(model, organizationConnectionId, model.Type)) + { + throw new BadRequestException($"The requested organization already has a connection of type {model.Type}. Only one of each connection type may exist per organization."); + } + + switch (model.Type) + { + case OrganizationConnectionType.CloudBillingSync: + return await CreateOrUpdateOrganizationConnectionAsync(organizationConnectionId, model); + case OrganizationConnectionType.Scim: + return await CreateOrUpdateOrganizationConnectionAsync(organizationConnectionId, model); + default: + throw new BadRequestException($"Unkown Organization connection Type: {model.Type}"); + } } - if (await HasConnectionTypeAsync(model, organizationConnectionId, model.Type)) + [HttpGet("{organizationId}/{type}")] + public async Task GetConnection(Guid organizationId, OrganizationConnectionType type) { - throw new BadRequestException($"The requested organization already has a connection of type {model.Type}. Only one of each connection type may exist per organization."); + if (!await HasPermissionAsync(organizationId, type)) + { + throw new BadRequestException($"You do not have permission to retrieve a connection of type {type}."); + } + + var connections = await GetConnectionsAsync(organizationId, type); + var connection = connections.FirstOrDefault(c => c.Type == type); + + switch (type) + { + case OrganizationConnectionType.CloudBillingSync: + if (!_globalSettings.SelfHosted) + { + throw new BadRequestException($"Cannot get a {type} connection outside of a self-hosted instance."); + } + return new OrganizationConnectionResponseModel(connection, typeof(BillingSyncConfig)); + case OrganizationConnectionType.Scim: + return new OrganizationConnectionResponseModel(connection, typeof(ScimConfig)); + default: + throw new BadRequestException($"Unkown Organization connection Type: {type}"); + } } - switch (model.Type) + [HttpDelete("{organizationConnectionId}")] + [HttpPost("{organizationConnectionId}/delete")] + public async Task DeleteConnection(Guid organizationConnectionId) { - case OrganizationConnectionType.CloudBillingSync: - return await CreateOrUpdateOrganizationConnectionAsync(organizationConnectionId, model); - case OrganizationConnectionType.Scim: - return await CreateOrUpdateOrganizationConnectionAsync(organizationConnectionId, model); - default: - throw new BadRequestException($"Unkown Organization connection Type: {model.Type}"); - } - } + var connection = await _organizationConnectionRepository.GetByIdAsync(organizationConnectionId); - [HttpGet("{organizationId}/{type}")] - public async Task GetConnection(Guid organizationId, OrganizationConnectionType type) - { - if (!await HasPermissionAsync(organizationId, type)) - { - throw new BadRequestException($"You do not have permission to retrieve a connection of type {type}."); + if (connection == null) + { + throw new NotFoundException(); + } + + if (!await HasPermissionAsync(connection.OrganizationId, connection.Type)) + { + throw new BadRequestException($"You do not have permission to remove this connection of type {connection.Type}."); + } + + await _deleteOrganizationConnectionCommand.DeleteAsync(connection); } - var connections = await GetConnectionsAsync(organizationId, type); - var connection = connections.FirstOrDefault(c => c.Type == type); + private async Task> GetConnectionsAsync(Guid organizationId, OrganizationConnectionType type) => + await _organizationConnectionRepository.GetByOrganizationIdTypeAsync(organizationId, type); - switch (type) + private async Task HasConnectionTypeAsync(OrganizationConnectionRequestModel model, Guid? connectionId, + OrganizationConnectionType type) { - case OrganizationConnectionType.CloudBillingSync: - if (!_globalSettings.SelfHosted) - { - throw new BadRequestException($"Cannot get a {type} connection outside of a self-hosted instance."); - } - return new OrganizationConnectionResponseModel(connection, typeof(BillingSyncConfig)); - case OrganizationConnectionType.Scim: - return new OrganizationConnectionResponseModel(connection, typeof(ScimConfig)); - default: - throw new BadRequestException($"Unkown Organization connection Type: {type}"); - } - } + var existingConnections = await GetConnectionsAsync(model.OrganizationId, type); - [HttpDelete("{organizationConnectionId}")] - [HttpPost("{organizationConnectionId}/delete")] - public async Task DeleteConnection(Guid organizationConnectionId) - { - var connection = await _organizationConnectionRepository.GetByIdAsync(organizationConnectionId); - - if (connection == null) - { - throw new NotFoundException(); + return existingConnections.Any(c => c.Type == model.Type && (!connectionId.HasValue || c.Id != connectionId.Value)); } - if (!await HasPermissionAsync(connection.OrganizationId, connection.Type)) + private async Task HasPermissionAsync(Guid? organizationId, OrganizationConnectionType? type = null) { - throw new BadRequestException($"You do not have permission to remove this connection of type {connection.Type}."); + if (!organizationId.HasValue) + { + return false; + } + return type switch + { + OrganizationConnectionType.Scim => await _currentContext.ManageScim(organizationId.Value), + _ => await _currentContext.OrganizationOwner(organizationId.Value), + }; } - await _deleteOrganizationConnectionCommand.DeleteAsync(connection); - } - - private async Task> GetConnectionsAsync(Guid organizationId, OrganizationConnectionType type) => - await _organizationConnectionRepository.GetByOrganizationIdTypeAsync(organizationId, type); - - private async Task HasConnectionTypeAsync(OrganizationConnectionRequestModel model, Guid? connectionId, - OrganizationConnectionType type) - { - var existingConnections = await GetConnectionsAsync(model.OrganizationId, type); - - return existingConnections.Any(c => c.Type == model.Type && (!connectionId.HasValue || c.Id != connectionId.Value)); - } - - private async Task HasPermissionAsync(Guid? organizationId, OrganizationConnectionType? type = null) - { - if (!organizationId.HasValue) + private async Task ValidateBillingSyncConfig(OrganizationConnectionRequestModel typedModel) { - return false; - } - return type switch - { - OrganizationConnectionType.Scim => await _currentContext.ManageScim(organizationId.Value), - _ => await _currentContext.OrganizationOwner(organizationId.Value), - }; - } - - private async Task ValidateBillingSyncConfig(OrganizationConnectionRequestModel typedModel) - { - if (!_globalSettings.SelfHosted) - { - throw new BadRequestException($"Cannot create a {typedModel.Type} connection outside of a self-hosted instance."); - } - var license = await _licensingService.ReadOrganizationLicenseAsync(typedModel.OrganizationId); - if (!_licensingService.VerifyLicense(license)) - { - throw new BadRequestException("Cannot verify license file."); - } - typedModel.ParsedConfig.CloudOrganizationId = license.Id; - } - - private async Task CreateOrUpdateOrganizationConnectionAsync( - Guid? organizationConnectionId, - OrganizationConnectionRequestModel model, - Func, Task> validateAction = null) - where T : new() - { - var typedModel = new OrganizationConnectionRequestModel(model); - if (validateAction != null) - { - await validateAction(typedModel); + if (!_globalSettings.SelfHosted) + { + throw new BadRequestException($"Cannot create a {typedModel.Type} connection outside of a self-hosted instance."); + } + var license = await _licensingService.ReadOrganizationLicenseAsync(typedModel.OrganizationId); + if (!_licensingService.VerifyLicense(license)) + { + throw new BadRequestException("Cannot verify license file."); + } + typedModel.ParsedConfig.CloudOrganizationId = license.Id; } - var data = typedModel.ToData(organizationConnectionId); - var connection = organizationConnectionId.HasValue - ? await _updateOrganizationConnectionCommand.UpdateAsync(data) - : await _createOrganizationConnectionCommand.CreateAsync(data); + private async Task CreateOrUpdateOrganizationConnectionAsync( + Guid? organizationConnectionId, + OrganizationConnectionRequestModel model, + Func, Task> validateAction = null) + where T : new() + { + var typedModel = new OrganizationConnectionRequestModel(model); + if (validateAction != null) + { + await validateAction(typedModel); + } - return new OrganizationConnectionResponseModel(connection, typeof(T)); + var data = typedModel.ToData(organizationConnectionId); + var connection = organizationConnectionId.HasValue + ? await _updateOrganizationConnectionCommand.UpdateAsync(data) + : await _createOrganizationConnectionCommand.CreateAsync(data); + + return new OrganizationConnectionResponseModel(connection, typeof(T)); + } } } diff --git a/src/Api/Controllers/OrganizationExportController.cs b/src/Api/Controllers/OrganizationExportController.cs index f2a2265f9d..dd04fe0098 100644 --- a/src/Api/Controllers/OrganizationExportController.cs +++ b/src/Api/Controllers/OrganizationExportController.cs @@ -6,58 +6,59 @@ using Core.Models.Data; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; -namespace Bit.Api.Controllers; - -[Route("organizations/{organizationId}")] -[Authorize("Application")] -public class OrganizationExportController : Controller +namespace Bit.Api.Controllers { - private readonly IUserService _userService; - private readonly ICollectionService _collectionService; - private readonly ICipherService _cipherService; - private readonly GlobalSettings _globalSettings; - - public OrganizationExportController( - ICipherService cipherService, - ICollectionService collectionService, - IUserService userService, - GlobalSettings globalSettings) + [Route("organizations/{organizationId}")] + [Authorize("Application")] + public class OrganizationExportController : Controller { - _cipherService = cipherService; - _collectionService = collectionService; - _userService = userService; - _globalSettings = globalSettings; - } + private readonly IUserService _userService; + private readonly ICollectionService _collectionService; + private readonly ICipherService _cipherService; + private readonly GlobalSettings _globalSettings; - [HttpGet("export")] - public async Task Export(Guid organizationId) - { - var userId = _userService.GetProperUserId(User).Value; - - IEnumerable orgCollections = await _collectionService.GetOrganizationCollections(organizationId); - (IEnumerable orgCiphers, Dictionary> collectionCiphersGroupDict) = await _cipherService.GetOrganizationCiphers(userId, organizationId); - - var result = new OrganizationExportResponseModel + public OrganizationExportController( + ICipherService cipherService, + ICollectionService collectionService, + IUserService userService, + GlobalSettings globalSettings) { - Collections = GetOrganizationCollectionsResponse(orgCollections), - Ciphers = GetOrganizationCiphersResponse(orgCiphers, collectionCiphersGroupDict) - }; + _cipherService = cipherService; + _collectionService = collectionService; + _userService = userService; + _globalSettings = globalSettings; + } - return result; - } + [HttpGet("export")] + public async Task Export(Guid organizationId) + { + var userId = _userService.GetProperUserId(User).Value; - private ListResponseModel GetOrganizationCollectionsResponse(IEnumerable orgCollections) - { - var collections = orgCollections.Select(c => new CollectionResponseModel(c)); - return new ListResponseModel(collections); - } + IEnumerable orgCollections = await _collectionService.GetOrganizationCollections(organizationId); + (IEnumerable orgCiphers, Dictionary> collectionCiphersGroupDict) = await _cipherService.GetOrganizationCiphers(userId, organizationId); - private ListResponseModel GetOrganizationCiphersResponse(IEnumerable orgCiphers, - Dictionary> collectionCiphersGroupDict) - { - var responses = orgCiphers.Select(c => new CipherMiniDetailsResponseModel(c, _globalSettings, - collectionCiphersGroupDict, c.OrganizationUseTotp)); + var result = new OrganizationExportResponseModel + { + Collections = GetOrganizationCollectionsResponse(orgCollections), + Ciphers = GetOrganizationCiphersResponse(orgCiphers, collectionCiphersGroupDict) + }; - return new ListResponseModel(responses); + return result; + } + + private ListResponseModel GetOrganizationCollectionsResponse(IEnumerable orgCollections) + { + var collections = orgCollections.Select(c => new CollectionResponseModel(c)); + return new ListResponseModel(collections); + } + + private ListResponseModel GetOrganizationCiphersResponse(IEnumerable orgCiphers, + Dictionary> collectionCiphersGroupDict) + { + var responses = orgCiphers.Select(c => new CipherMiniDetailsResponseModel(c, _globalSettings, + collectionCiphersGroupDict, c.OrganizationUseTotp)); + + return new ListResponseModel(responses); + } } } diff --git a/src/Api/Controllers/OrganizationSponsorshipsController.cs b/src/Api/Controllers/OrganizationSponsorshipsController.cs index fc5d38db1c..ae7be386cd 100644 --- a/src/Api/Controllers/OrganizationSponsorshipsController.cs +++ b/src/Api/Controllers/OrganizationSponsorshipsController.cs @@ -12,179 +12,180 @@ using Bit.Core.Utilities; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; -namespace Bit.Api.Controllers; - -[Route("organization/sponsorship")] -public class OrganizationSponsorshipsController : Controller +namespace Bit.Api.Controllers { - private readonly IOrganizationSponsorshipRepository _organizationSponsorshipRepository; - private readonly IOrganizationRepository _organizationRepository; - private readonly IOrganizationUserRepository _organizationUserRepository; - private readonly IValidateRedemptionTokenCommand _validateRedemptionTokenCommand; - private readonly IValidateBillingSyncKeyCommand _validateBillingSyncKeyCommand; - private readonly ICreateSponsorshipCommand _createSponsorshipCommand; - private readonly ISendSponsorshipOfferCommand _sendSponsorshipOfferCommand; - private readonly ISetUpSponsorshipCommand _setUpSponsorshipCommand; - private readonly IRevokeSponsorshipCommand _revokeSponsorshipCommand; - private readonly IRemoveSponsorshipCommand _removeSponsorshipCommand; - private readonly ICloudSyncSponsorshipsCommand _syncSponsorshipsCommand; - private readonly ICurrentContext _currentContext; - private readonly IUserService _userService; - - public OrganizationSponsorshipsController( - IOrganizationSponsorshipRepository organizationSponsorshipRepository, - IOrganizationRepository organizationRepository, - IOrganizationUserRepository organizationUserRepository, - IValidateRedemptionTokenCommand validateRedemptionTokenCommand, - IValidateBillingSyncKeyCommand validateBillingSyncKeyCommand, - ICreateSponsorshipCommand createSponsorshipCommand, - ISendSponsorshipOfferCommand sendSponsorshipOfferCommand, - ISetUpSponsorshipCommand setUpSponsorshipCommand, - IRevokeSponsorshipCommand revokeSponsorshipCommand, - IRemoveSponsorshipCommand removeSponsorshipCommand, - ICloudSyncSponsorshipsCommand syncSponsorshipsCommand, - IUserService userService, - ICurrentContext currentContext) + [Route("organization/sponsorship")] + public class OrganizationSponsorshipsController : Controller { - _organizationSponsorshipRepository = organizationSponsorshipRepository; - _organizationRepository = organizationRepository; - _organizationUserRepository = organizationUserRepository; - _validateRedemptionTokenCommand = validateRedemptionTokenCommand; - _validateBillingSyncKeyCommand = validateBillingSyncKeyCommand; - _createSponsorshipCommand = createSponsorshipCommand; - _sendSponsorshipOfferCommand = sendSponsorshipOfferCommand; - _setUpSponsorshipCommand = setUpSponsorshipCommand; - _revokeSponsorshipCommand = revokeSponsorshipCommand; - _removeSponsorshipCommand = removeSponsorshipCommand; - _syncSponsorshipsCommand = syncSponsorshipsCommand; - _userService = userService; - _currentContext = currentContext; - } + private readonly IOrganizationSponsorshipRepository _organizationSponsorshipRepository; + private readonly IOrganizationRepository _organizationRepository; + private readonly IOrganizationUserRepository _organizationUserRepository; + private readonly IValidateRedemptionTokenCommand _validateRedemptionTokenCommand; + private readonly IValidateBillingSyncKeyCommand _validateBillingSyncKeyCommand; + private readonly ICreateSponsorshipCommand _createSponsorshipCommand; + private readonly ISendSponsorshipOfferCommand _sendSponsorshipOfferCommand; + private readonly ISetUpSponsorshipCommand _setUpSponsorshipCommand; + private readonly IRevokeSponsorshipCommand _revokeSponsorshipCommand; + private readonly IRemoveSponsorshipCommand _removeSponsorshipCommand; + private readonly ICloudSyncSponsorshipsCommand _syncSponsorshipsCommand; + private readonly ICurrentContext _currentContext; + private readonly IUserService _userService; - [Authorize("Application")] - [HttpPost("{sponsoringOrgId}/families-for-enterprise")] - [SelfHosted(NotSelfHostedOnly = true)] - public async Task CreateSponsorship(Guid sponsoringOrgId, [FromBody] OrganizationSponsorshipCreateRequestModel model) - { - var sponsoringOrg = await _organizationRepository.GetByIdAsync(sponsoringOrgId); - - var sponsorship = await _createSponsorshipCommand.CreateSponsorshipAsync( - sponsoringOrg, - await _organizationUserRepository.GetByOrganizationAsync(sponsoringOrgId, _currentContext.UserId ?? default), - model.PlanSponsorshipType, model.SponsoredEmail, model.FriendlyName); - await _sendSponsorshipOfferCommand.SendSponsorshipOfferAsync(sponsorship, sponsoringOrg.Name); - } - - [Authorize("Application")] - [HttpPost("{sponsoringOrgId}/families-for-enterprise/resend")] - [SelfHosted(NotSelfHostedOnly = true)] - public async Task ResendSponsorshipOffer(Guid sponsoringOrgId) - { - var sponsoringOrgUser = await _organizationUserRepository - .GetByOrganizationAsync(sponsoringOrgId, _currentContext.UserId ?? default); - - await _sendSponsorshipOfferCommand.SendSponsorshipOfferAsync( - await _organizationRepository.GetByIdAsync(sponsoringOrgId), - sponsoringOrgUser, - await _organizationSponsorshipRepository - .GetBySponsoringOrganizationUserIdAsync(sponsoringOrgUser.Id)); - } - - [Authorize("Application")] - [HttpPost("validate-token")] - [SelfHosted(NotSelfHostedOnly = true)] - public async Task PreValidateSponsorshipToken([FromQuery] string sponsorshipToken) - { - return (await _validateRedemptionTokenCommand.ValidateRedemptionTokenAsync(sponsorshipToken, (await CurrentUser).Email)).valid; - } - - [Authorize("Application")] - [HttpPost("redeem")] - [SelfHosted(NotSelfHostedOnly = true)] - public async Task RedeemSponsorship([FromQuery] string sponsorshipToken, [FromBody] OrganizationSponsorshipRedeemRequestModel model) - { - var (valid, sponsorship) = await _validateRedemptionTokenCommand.ValidateRedemptionTokenAsync(sponsorshipToken, (await CurrentUser).Email); - - if (!valid) + public OrganizationSponsorshipsController( + IOrganizationSponsorshipRepository organizationSponsorshipRepository, + IOrganizationRepository organizationRepository, + IOrganizationUserRepository organizationUserRepository, + IValidateRedemptionTokenCommand validateRedemptionTokenCommand, + IValidateBillingSyncKeyCommand validateBillingSyncKeyCommand, + ICreateSponsorshipCommand createSponsorshipCommand, + ISendSponsorshipOfferCommand sendSponsorshipOfferCommand, + ISetUpSponsorshipCommand setUpSponsorshipCommand, + IRevokeSponsorshipCommand revokeSponsorshipCommand, + IRemoveSponsorshipCommand removeSponsorshipCommand, + ICloudSyncSponsorshipsCommand syncSponsorshipsCommand, + IUserService userService, + ICurrentContext currentContext) { - throw new BadRequestException("Failed to parse sponsorship token."); + _organizationSponsorshipRepository = organizationSponsorshipRepository; + _organizationRepository = organizationRepository; + _organizationUserRepository = organizationUserRepository; + _validateRedemptionTokenCommand = validateRedemptionTokenCommand; + _validateBillingSyncKeyCommand = validateBillingSyncKeyCommand; + _createSponsorshipCommand = createSponsorshipCommand; + _sendSponsorshipOfferCommand = sendSponsorshipOfferCommand; + _setUpSponsorshipCommand = setUpSponsorshipCommand; + _revokeSponsorshipCommand = revokeSponsorshipCommand; + _removeSponsorshipCommand = removeSponsorshipCommand; + _syncSponsorshipsCommand = syncSponsorshipsCommand; + _userService = userService; + _currentContext = currentContext; } - if (!await _currentContext.OrganizationOwner(model.SponsoredOrganizationId)) + [Authorize("Application")] + [HttpPost("{sponsoringOrgId}/families-for-enterprise")] + [SelfHosted(NotSelfHostedOnly = true)] + public async Task CreateSponsorship(Guid sponsoringOrgId, [FromBody] OrganizationSponsorshipCreateRequestModel model) { - throw new BadRequestException("Can only redeem sponsorship for an organization you own."); + var sponsoringOrg = await _organizationRepository.GetByIdAsync(sponsoringOrgId); + + var sponsorship = await _createSponsorshipCommand.CreateSponsorshipAsync( + sponsoringOrg, + await _organizationUserRepository.GetByOrganizationAsync(sponsoringOrgId, _currentContext.UserId ?? default), + model.PlanSponsorshipType, model.SponsoredEmail, model.FriendlyName); + await _sendSponsorshipOfferCommand.SendSponsorshipOfferAsync(sponsorship, sponsoringOrg.Name); } - await _setUpSponsorshipCommand.SetUpSponsorshipAsync( - sponsorship, - await _organizationRepository.GetByIdAsync(model.SponsoredOrganizationId)); - } - - [Authorize("Installation")] - [HttpPost("sync")] - public async Task Sync([FromBody] OrganizationSponsorshipSyncRequestModel model) - { - var sponsoringOrg = await _organizationRepository.GetByIdAsync(model.SponsoringOrganizationCloudId); - if (!await _validateBillingSyncKeyCommand.ValidateBillingSyncKeyAsync(sponsoringOrg, model.BillingSyncKey)) + [Authorize("Application")] + [HttpPost("{sponsoringOrgId}/families-for-enterprise/resend")] + [SelfHosted(NotSelfHostedOnly = true)] + public async Task ResendSponsorshipOffer(Guid sponsoringOrgId) { - throw new BadRequestException("Invalid Billing Sync Key"); + var sponsoringOrgUser = await _organizationUserRepository + .GetByOrganizationAsync(sponsoringOrgId, _currentContext.UserId ?? default); + + await _sendSponsorshipOfferCommand.SendSponsorshipOfferAsync( + await _organizationRepository.GetByIdAsync(sponsoringOrgId), + sponsoringOrgUser, + await _organizationSponsorshipRepository + .GetBySponsoringOrganizationUserIdAsync(sponsoringOrgUser.Id)); } - var (syncResponseData, offersToSend) = await _syncSponsorshipsCommand.SyncOrganization(sponsoringOrg, model.ToOrganizationSponsorshipSync().SponsorshipsBatch); - await _sendSponsorshipOfferCommand.BulkSendSponsorshipOfferAsync(sponsoringOrg.Name, offersToSend); - return new OrganizationSponsorshipSyncResponseModel(syncResponseData); - } - - [Authorize("Application")] - [HttpDelete("{sponsoringOrganizationId}")] - [HttpPost("{sponsoringOrganizationId}/delete")] - [SelfHosted(NotSelfHostedOnly = true)] - public async Task RevokeSponsorship(Guid sponsoringOrganizationId) - { - - var orgUser = await _organizationUserRepository.GetByOrganizationAsync(sponsoringOrganizationId, _currentContext.UserId ?? default); - if (_currentContext.UserId != orgUser?.UserId) + [Authorize("Application")] + [HttpPost("validate-token")] + [SelfHosted(NotSelfHostedOnly = true)] + public async Task PreValidateSponsorshipToken([FromQuery] string sponsorshipToken) { - throw new BadRequestException("Can only revoke a sponsorship you granted."); + return (await _validateRedemptionTokenCommand.ValidateRedemptionTokenAsync(sponsorshipToken, (await CurrentUser).Email)).valid; } - var existingOrgSponsorship = await _organizationSponsorshipRepository - .GetBySponsoringOrganizationUserIdAsync(orgUser.Id); - - await _revokeSponsorshipCommand.RevokeSponsorshipAsync(existingOrgSponsorship); - } - - [Authorize("Application")] - [HttpDelete("sponsored/{sponsoredOrgId}")] - [HttpPost("sponsored/{sponsoredOrgId}/remove")] - [SelfHosted(NotSelfHostedOnly = true)] - public async Task RemoveSponsorship(Guid sponsoredOrgId) - { - - if (!await _currentContext.OrganizationOwner(sponsoredOrgId)) + [Authorize("Application")] + [HttpPost("redeem")] + [SelfHosted(NotSelfHostedOnly = true)] + public async Task RedeemSponsorship([FromQuery] string sponsorshipToken, [FromBody] OrganizationSponsorshipRedeemRequestModel model) { - throw new BadRequestException("Only the owner of an organization can remove sponsorship."); + var (valid, sponsorship) = await _validateRedemptionTokenCommand.ValidateRedemptionTokenAsync(sponsorshipToken, (await CurrentUser).Email); + + if (!valid) + { + throw new BadRequestException("Failed to parse sponsorship token."); + } + + if (!await _currentContext.OrganizationOwner(model.SponsoredOrganizationId)) + { + throw new BadRequestException("Can only redeem sponsorship for an organization you own."); + } + + await _setUpSponsorshipCommand.SetUpSponsorshipAsync( + sponsorship, + await _organizationRepository.GetByIdAsync(model.SponsoredOrganizationId)); } - var existingOrgSponsorship = await _organizationSponsorshipRepository - .GetBySponsoredOrganizationIdAsync(sponsoredOrgId); - - await _removeSponsorshipCommand.RemoveSponsorshipAsync(existingOrgSponsorship); - } - - [HttpGet("{sponsoringOrgId}/sync-status")] - public async Task GetSyncStatus(Guid sponsoringOrgId) - { - var sponsoringOrg = await _organizationRepository.GetByIdAsync(sponsoringOrgId); - - if (!await _currentContext.OrganizationOwner(sponsoringOrg.Id)) + [Authorize("Installation")] + [HttpPost("sync")] + public async Task Sync([FromBody] OrganizationSponsorshipSyncRequestModel model) { - throw new NotFoundException(); + var sponsoringOrg = await _organizationRepository.GetByIdAsync(model.SponsoringOrganizationCloudId); + if (!await _validateBillingSyncKeyCommand.ValidateBillingSyncKeyAsync(sponsoringOrg, model.BillingSyncKey)) + { + throw new BadRequestException("Invalid Billing Sync Key"); + } + + var (syncResponseData, offersToSend) = await _syncSponsorshipsCommand.SyncOrganization(sponsoringOrg, model.ToOrganizationSponsorshipSync().SponsorshipsBatch); + await _sendSponsorshipOfferCommand.BulkSendSponsorshipOfferAsync(sponsoringOrg.Name, offersToSend); + return new OrganizationSponsorshipSyncResponseModel(syncResponseData); } - var lastSyncDate = await _organizationSponsorshipRepository.GetLatestSyncDateBySponsoringOrganizationIdAsync(sponsoringOrg.Id); - return new OrganizationSponsorshipSyncStatusResponseModel(lastSyncDate); - } + [Authorize("Application")] + [HttpDelete("{sponsoringOrganizationId}")] + [HttpPost("{sponsoringOrganizationId}/delete")] + [SelfHosted(NotSelfHostedOnly = true)] + public async Task RevokeSponsorship(Guid sponsoringOrganizationId) + { - private Task CurrentUser => _userService.GetUserByIdAsync(_currentContext.UserId.Value); + var orgUser = await _organizationUserRepository.GetByOrganizationAsync(sponsoringOrganizationId, _currentContext.UserId ?? default); + if (_currentContext.UserId != orgUser?.UserId) + { + throw new BadRequestException("Can only revoke a sponsorship you granted."); + } + + var existingOrgSponsorship = await _organizationSponsorshipRepository + .GetBySponsoringOrganizationUserIdAsync(orgUser.Id); + + await _revokeSponsorshipCommand.RevokeSponsorshipAsync(existingOrgSponsorship); + } + + [Authorize("Application")] + [HttpDelete("sponsored/{sponsoredOrgId}")] + [HttpPost("sponsored/{sponsoredOrgId}/remove")] + [SelfHosted(NotSelfHostedOnly = true)] + public async Task RemoveSponsorship(Guid sponsoredOrgId) + { + + if (!await _currentContext.OrganizationOwner(sponsoredOrgId)) + { + throw new BadRequestException("Only the owner of an organization can remove sponsorship."); + } + + var existingOrgSponsorship = await _organizationSponsorshipRepository + .GetBySponsoredOrganizationIdAsync(sponsoredOrgId); + + await _removeSponsorshipCommand.RemoveSponsorshipAsync(existingOrgSponsorship); + } + + [HttpGet("{sponsoringOrgId}/sync-status")] + public async Task GetSyncStatus(Guid sponsoringOrgId) + { + var sponsoringOrg = await _organizationRepository.GetByIdAsync(sponsoringOrgId); + + if (!await _currentContext.OrganizationOwner(sponsoringOrg.Id)) + { + throw new NotFoundException(); + } + + var lastSyncDate = await _organizationSponsorshipRepository.GetLatestSyncDateBySponsoringOrganizationIdAsync(sponsoringOrg.Id); + return new OrganizationSponsorshipSyncStatusResponseModel(lastSyncDate); + } + + private Task CurrentUser => _userService.GetUserByIdAsync(_currentContext.UserId.Value); + } } diff --git a/src/Api/Controllers/OrganizationUsersController.cs b/src/Api/Controllers/OrganizationUsersController.cs index 64340e3ede..b1e5451ebc 100644 --- a/src/Api/Controllers/OrganizationUsersController.cs +++ b/src/Api/Controllers/OrganizationUsersController.cs @@ -13,459 +13,460 @@ using Bit.Core.Services; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; -namespace Bit.Api.Controllers; - -[Route("organizations/{orgId}/users")] -[Authorize("Application")] -public class OrganizationUsersController : Controller +namespace Bit.Api.Controllers { - private readonly IOrganizationRepository _organizationRepository; - private readonly IOrganizationUserRepository _organizationUserRepository; - private readonly IOrganizationService _organizationService; - private readonly ICollectionRepository _collectionRepository; - private readonly IGroupRepository _groupRepository; - private readonly IUserService _userService; - private readonly IPolicyRepository _policyRepository; - private readonly ICurrentContext _currentContext; - - public OrganizationUsersController( - IOrganizationRepository organizationRepository, - IOrganizationUserRepository organizationUserRepository, - IOrganizationService organizationService, - ICollectionRepository collectionRepository, - IGroupRepository groupRepository, - IUserService userService, - IPolicyRepository policyRepository, - ICurrentContext currentContext) + [Route("organizations/{orgId}/users")] + [Authorize("Application")] + public class OrganizationUsersController : Controller { - _organizationRepository = organizationRepository; - _organizationUserRepository = organizationUserRepository; - _organizationService = organizationService; - _collectionRepository = collectionRepository; - _groupRepository = groupRepository; - _userService = userService; - _policyRepository = policyRepository; - _currentContext = currentContext; - } + private readonly IOrganizationRepository _organizationRepository; + private readonly IOrganizationUserRepository _organizationUserRepository; + private readonly IOrganizationService _organizationService; + private readonly ICollectionRepository _collectionRepository; + private readonly IGroupRepository _groupRepository; + private readonly IUserService _userService; + private readonly IPolicyRepository _policyRepository; + private readonly ICurrentContext _currentContext; - [HttpGet("{id}")] - public async Task Get(string orgId, string id) - { - var organizationUser = await _organizationUserRepository.GetByIdWithCollectionsAsync(new Guid(id)); - if (organizationUser == null || !await _currentContext.ManageUsers(organizationUser.Item1.OrganizationId)) + public OrganizationUsersController( + IOrganizationRepository organizationRepository, + IOrganizationUserRepository organizationUserRepository, + IOrganizationService organizationService, + ICollectionRepository collectionRepository, + IGroupRepository groupRepository, + IUserService userService, + IPolicyRepository policyRepository, + ICurrentContext currentContext) { - throw new NotFoundException(); + _organizationRepository = organizationRepository; + _organizationUserRepository = organizationUserRepository; + _organizationService = organizationService; + _collectionRepository = collectionRepository; + _groupRepository = groupRepository; + _userService = userService; + _policyRepository = policyRepository; + _currentContext = currentContext; } - return new OrganizationUserDetailsResponseModel(organizationUser.Item1, organizationUser.Item2); - } - - [HttpGet("")] - public async Task> Get(string orgId) - { - var orgGuidId = new Guid(orgId); - if (!await _currentContext.ViewAllCollections(orgGuidId) && - !await _currentContext.ViewAssignedCollections(orgGuidId) && - !await _currentContext.ManageGroups(orgGuidId) && - !await _currentContext.ManageUsers(orgGuidId)) + [HttpGet("{id}")] + public async Task Get(string orgId, string id) { - throw new NotFoundException(); + var organizationUser = await _organizationUserRepository.GetByIdWithCollectionsAsync(new Guid(id)); + if (organizationUser == null || !await _currentContext.ManageUsers(organizationUser.Item1.OrganizationId)) + { + throw new NotFoundException(); + } + + return new OrganizationUserDetailsResponseModel(organizationUser.Item1, organizationUser.Item2); } - var organizationUsers = await _organizationUserRepository.GetManyDetailsByOrganizationAsync(orgGuidId); - var responseTasks = organizationUsers.Select(async o => new OrganizationUserUserDetailsResponseModel(o, - await _userService.TwoFactorIsEnabledAsync(o))); - var responses = await Task.WhenAll(responseTasks); - return new ListResponseModel(responses); - } - - [HttpGet("{id}/groups")] - public async Task> GetGroups(string orgId, string id) - { - var organizationUser = await _organizationUserRepository.GetByIdAsync(new Guid(id)); - if (organizationUser == null || (!await _currentContext.ManageGroups(organizationUser.OrganizationId) && - !await _currentContext.ManageUsers(organizationUser.OrganizationId))) + [HttpGet("")] + public async Task> Get(string orgId) { - throw new NotFoundException(); + var orgGuidId = new Guid(orgId); + if (!await _currentContext.ViewAllCollections(orgGuidId) && + !await _currentContext.ViewAssignedCollections(orgGuidId) && + !await _currentContext.ManageGroups(orgGuidId) && + !await _currentContext.ManageUsers(orgGuidId)) + { + throw new NotFoundException(); + } + + var organizationUsers = await _organizationUserRepository.GetManyDetailsByOrganizationAsync(orgGuidId); + var responseTasks = organizationUsers.Select(async o => new OrganizationUserUserDetailsResponseModel(o, + await _userService.TwoFactorIsEnabledAsync(o))); + var responses = await Task.WhenAll(responseTasks); + return new ListResponseModel(responses); } - var groupIds = await _groupRepository.GetManyIdsByUserIdAsync(organizationUser.Id); - var responses = groupIds.Select(g => g.ToString()); - return responses; - } - - [HttpGet("{id}/reset-password-details")] - public async Task GetResetPasswordDetails(string orgId, string id) - { - // Make sure the calling user can reset passwords for this org - var orgGuidId = new Guid(orgId); - if (!await _currentContext.ManageResetPassword(orgGuidId)) + [HttpGet("{id}/groups")] + public async Task> GetGroups(string orgId, string id) { - throw new NotFoundException(); + var organizationUser = await _organizationUserRepository.GetByIdAsync(new Guid(id)); + if (organizationUser == null || (!await _currentContext.ManageGroups(organizationUser.OrganizationId) && + !await _currentContext.ManageUsers(organizationUser.OrganizationId))) + { + throw new NotFoundException(); + } + + var groupIds = await _groupRepository.GetManyIdsByUserIdAsync(organizationUser.Id); + var responses = groupIds.Select(g => g.ToString()); + return responses; } - var organizationUser = await _organizationUserRepository.GetByIdAsync(new Guid(id)); - if (organizationUser == null || !organizationUser.UserId.HasValue) + [HttpGet("{id}/reset-password-details")] + public async Task GetResetPasswordDetails(string orgId, string id) { - throw new NotFoundException(); + // Make sure the calling user can reset passwords for this org + var orgGuidId = new Guid(orgId); + if (!await _currentContext.ManageResetPassword(orgGuidId)) + { + throw new NotFoundException(); + } + + var organizationUser = await _organizationUserRepository.GetByIdAsync(new Guid(id)); + if (organizationUser == null || !organizationUser.UserId.HasValue) + { + throw new NotFoundException(); + } + + // Retrieve data necessary for response (KDF, KDF Iterations, ResetPasswordKey) + // TODO Reset Password - Revisit this and create SPROC to reduce DB calls + var user = await _userService.GetUserByIdAsync(organizationUser.UserId.Value); + if (user == null) + { + throw new NotFoundException(); + } + + // Retrieve Encrypted Private Key from organization + var org = await _organizationRepository.GetByIdAsync(orgGuidId); + if (org == null) + { + throw new NotFoundException(); + } + + return new OrganizationUserResetPasswordDetailsResponseModel(new OrganizationUserResetPasswordDetails(organizationUser, user, org)); } - // Retrieve data necessary for response (KDF, KDF Iterations, ResetPasswordKey) - // TODO Reset Password - Revisit this and create SPROC to reduce DB calls - var user = await _userService.GetUserByIdAsync(organizationUser.UserId.Value); - if (user == null) + [HttpPost("invite")] + public async Task Invite(string orgId, [FromBody] OrganizationUserInviteRequestModel model) { - throw new NotFoundException(); + var orgGuidId = new Guid(orgId); + if (!await _currentContext.ManageUsers(orgGuidId)) + { + throw new NotFoundException(); + } + + var userId = _userService.GetProperUserId(User); + var result = await _organizationService.InviteUsersAsync(orgGuidId, userId.Value, + new (OrganizationUserInvite, string)[] { (new OrganizationUserInvite(model.ToData()), null) }); } - // Retrieve Encrypted Private Key from organization - var org = await _organizationRepository.GetByIdAsync(orgGuidId); - if (org == null) + [HttpPost("reinvite")] + public async Task> BulkReinvite(string orgId, [FromBody] OrganizationUserBulkRequestModel model) { - throw new NotFoundException(); + var orgGuidId = new Guid(orgId); + if (!await _currentContext.ManageUsers(orgGuidId)) + { + throw new NotFoundException(); + } + + var userId = _userService.GetProperUserId(User); + var result = await _organizationService.ResendInvitesAsync(orgGuidId, userId.Value, model.Ids); + return new ListResponseModel( + result.Select(t => new OrganizationUserBulkResponseModel(t.Item1.Id, t.Item2))); } - return new OrganizationUserResetPasswordDetailsResponseModel(new OrganizationUserResetPasswordDetails(organizationUser, user, org)); - } - - [HttpPost("invite")] - public async Task Invite(string orgId, [FromBody] OrganizationUserInviteRequestModel model) - { - var orgGuidId = new Guid(orgId); - if (!await _currentContext.ManageUsers(orgGuidId)) + [HttpPost("{id}/reinvite")] + public async Task Reinvite(string orgId, string id) { - throw new NotFoundException(); + var orgGuidId = new Guid(orgId); + if (!await _currentContext.ManageUsers(orgGuidId)) + { + throw new NotFoundException(); + } + + var userId = _userService.GetProperUserId(User); + await _organizationService.ResendInviteAsync(orgGuidId, userId.Value, new Guid(id)); } - var userId = _userService.GetProperUserId(User); - var result = await _organizationService.InviteUsersAsync(orgGuidId, userId.Value, - new (OrganizationUserInvite, string)[] { (new OrganizationUserInvite(model.ToData()), null) }); - } - - [HttpPost("reinvite")] - public async Task> BulkReinvite(string orgId, [FromBody] OrganizationUserBulkRequestModel model) - { - var orgGuidId = new Guid(orgId); - if (!await _currentContext.ManageUsers(orgGuidId)) + [HttpPost("{organizationUserId}/accept")] + public async Task Accept(Guid orgId, Guid organizationUserId, [FromBody] OrganizationUserAcceptRequestModel model) { - throw new NotFoundException(); + var user = await _userService.GetUserByPrincipalAsync(User); + if (user == null) + { + throw new UnauthorizedAccessException(); + } + + var masterPasswordPolicy = await _policyRepository.GetByOrganizationIdTypeAsync(orgId, PolicyType.ResetPassword); + var useMasterPasswordPolicy = masterPasswordPolicy != null && + masterPasswordPolicy.Enabled && + masterPasswordPolicy.GetDataModel().AutoEnrollEnabled; + + if (useMasterPasswordPolicy && + string.IsNullOrWhiteSpace(model.ResetPasswordKey)) + { + throw new BadRequestException(string.Empty, "Master Password reset is required, but not provided."); + } + + await _organizationService.AcceptUserAsync(organizationUserId, user, model.Token, _userService); + + if (useMasterPasswordPolicy) + { + await _organizationService.UpdateUserResetPasswordEnrollmentAsync(orgId, user.Id, model.ResetPasswordKey, user.Id); + } } - var userId = _userService.GetProperUserId(User); - var result = await _organizationService.ResendInvitesAsync(orgGuidId, userId.Value, model.Ids); - return new ListResponseModel( - result.Select(t => new OrganizationUserBulkResponseModel(t.Item1.Id, t.Item2))); - } - - [HttpPost("{id}/reinvite")] - public async Task Reinvite(string orgId, string id) - { - var orgGuidId = new Guid(orgId); - if (!await _currentContext.ManageUsers(orgGuidId)) + [HttpPost("{id}/confirm")] + public async Task Confirm(string orgId, string id, [FromBody] OrganizationUserConfirmRequestModel model) { - throw new NotFoundException(); + var orgGuidId = new Guid(orgId); + if (!await _currentContext.ManageUsers(orgGuidId)) + { + throw new NotFoundException(); + } + + var userId = _userService.GetProperUserId(User); + var result = await _organizationService.ConfirmUserAsync(orgGuidId, new Guid(id), model.Key, userId.Value, + _userService); } - var userId = _userService.GetProperUserId(User); - await _organizationService.ResendInviteAsync(orgGuidId, userId.Value, new Guid(id)); - } - - [HttpPost("{organizationUserId}/accept")] - public async Task Accept(Guid orgId, Guid organizationUserId, [FromBody] OrganizationUserAcceptRequestModel model) - { - var user = await _userService.GetUserByPrincipalAsync(User); - if (user == null) + [HttpPost("confirm")] + public async Task> BulkConfirm(string orgId, + [FromBody] OrganizationUserBulkConfirmRequestModel model) { - throw new UnauthorizedAccessException(); + var orgGuidId = new Guid(orgId); + if (!await _currentContext.ManageUsers(orgGuidId)) + { + throw new NotFoundException(); + } + + var userId = _userService.GetProperUserId(User); + var results = await _organizationService.ConfirmUsersAsync(orgGuidId, model.ToDictionary(), userId.Value, + _userService); + + return new ListResponseModel(results.Select(r => + new OrganizationUserBulkResponseModel(r.Item1.Id, r.Item2))); } - var masterPasswordPolicy = await _policyRepository.GetByOrganizationIdTypeAsync(orgId, PolicyType.ResetPassword); - var useMasterPasswordPolicy = masterPasswordPolicy != null && - masterPasswordPolicy.Enabled && - masterPasswordPolicy.GetDataModel().AutoEnrollEnabled; - - if (useMasterPasswordPolicy && - string.IsNullOrWhiteSpace(model.ResetPasswordKey)) + [HttpPost("public-keys")] + public async Task> UserPublicKeys(string orgId, [FromBody] OrganizationUserBulkRequestModel model) { - throw new BadRequestException(string.Empty, "Master Password reset is required, but not provided."); + var orgGuidId = new Guid(orgId); + if (!await _currentContext.ManageUsers(orgGuidId)) + { + throw new NotFoundException(); + } + + var result = await _organizationUserRepository.GetManyPublicKeysByOrganizationUserAsync(orgGuidId, model.Ids); + var responses = result.Select(r => new OrganizationUserPublicKeyResponseModel(r.Id, r.UserId, r.PublicKey)).ToList(); + return new ListResponseModel(responses); } - await _organizationService.AcceptUserAsync(organizationUserId, user, model.Token, _userService); - - if (useMasterPasswordPolicy) + [HttpPut("{id}")] + [HttpPost("{id}")] + public async Task Put(string orgId, string id, [FromBody] OrganizationUserUpdateRequestModel model) { - await _organizationService.UpdateUserResetPasswordEnrollmentAsync(orgId, user.Id, model.ResetPasswordKey, user.Id); - } - } + var orgGuidId = new Guid(orgId); + if (!await _currentContext.ManageUsers(orgGuidId)) + { + throw new NotFoundException(); + } - [HttpPost("{id}/confirm")] - public async Task Confirm(string orgId, string id, [FromBody] OrganizationUserConfirmRequestModel model) - { - var orgGuidId = new Guid(orgId); - if (!await _currentContext.ManageUsers(orgGuidId)) - { - throw new NotFoundException(); + var organizationUser = await _organizationUserRepository.GetByIdAsync(new Guid(id)); + if (organizationUser == null || organizationUser.OrganizationId != orgGuidId) + { + throw new NotFoundException(); + } + + var userId = _userService.GetProperUserId(User); + await _organizationService.SaveUserAsync(model.ToOrganizationUser(organizationUser), userId.Value, + model.Collections?.Select(c => c.ToSelectionReadOnly())); } - var userId = _userService.GetProperUserId(User); - var result = await _organizationService.ConfirmUserAsync(orgGuidId, new Guid(id), model.Key, userId.Value, - _userService); - } - - [HttpPost("confirm")] - public async Task> BulkConfirm(string orgId, - [FromBody] OrganizationUserBulkConfirmRequestModel model) - { - var orgGuidId = new Guid(orgId); - if (!await _currentContext.ManageUsers(orgGuidId)) + [HttpPut("{id}/groups")] + [HttpPost("{id}/groups")] + public async Task PutGroups(string orgId, string id, [FromBody] OrganizationUserUpdateGroupsRequestModel model) { - throw new NotFoundException(); + var orgGuidId = new Guid(orgId); + if (!await _currentContext.ManageUsers(orgGuidId)) + { + throw new NotFoundException(); + } + + var organizationUser = await _organizationUserRepository.GetByIdAsync(new Guid(id)); + if (organizationUser == null || organizationUser.OrganizationId != orgGuidId) + { + throw new NotFoundException(); + } + + var loggedInUserId = _userService.GetProperUserId(User); + await _organizationService.UpdateUserGroupsAsync(organizationUser, model.GroupIds.Select(g => new Guid(g)), loggedInUserId); } - var userId = _userService.GetProperUserId(User); - var results = await _organizationService.ConfirmUsersAsync(orgGuidId, model.ToDictionary(), userId.Value, - _userService); - - return new ListResponseModel(results.Select(r => - new OrganizationUserBulkResponseModel(r.Item1.Id, r.Item2))); - } - - [HttpPost("public-keys")] - public async Task> UserPublicKeys(string orgId, [FromBody] OrganizationUserBulkRequestModel model) - { - var orgGuidId = new Guid(orgId); - if (!await _currentContext.ManageUsers(orgGuidId)) + [HttpPut("{userId}/reset-password-enrollment")] + public async Task PutResetPasswordEnrollment(string orgId, string userId, [FromBody] OrganizationUserResetPasswordEnrollmentRequestModel model) { - throw new NotFoundException(); + var user = await _userService.GetUserByPrincipalAsync(User); + if (user == null) + { + throw new UnauthorizedAccessException(); + } + + if (model.ResetPasswordKey != null && !await _userService.VerifySecretAsync(user, model.Secret)) + { + await Task.Delay(2000); + throw new BadRequestException("MasterPasswordHash", "Invalid password."); + } + else + { + var callingUserId = user.Id; + await _organizationService.UpdateUserResetPasswordEnrollmentAsync( + new Guid(orgId), new Guid(userId), model.ResetPasswordKey, callingUserId); + } } - var result = await _organizationUserRepository.GetManyPublicKeysByOrganizationUserAsync(orgGuidId, model.Ids); - var responses = result.Select(r => new OrganizationUserPublicKeyResponseModel(r.Id, r.UserId, r.PublicKey)).ToList(); - return new ListResponseModel(responses); - } - - [HttpPut("{id}")] - [HttpPost("{id}")] - public async Task Put(string orgId, string id, [FromBody] OrganizationUserUpdateRequestModel model) - { - var orgGuidId = new Guid(orgId); - if (!await _currentContext.ManageUsers(orgGuidId)) + [HttpPut("{id}/reset-password")] + public async Task PutResetPassword(string orgId, string id, [FromBody] OrganizationUserResetPasswordRequestModel model) { - throw new NotFoundException(); - } - var organizationUser = await _organizationUserRepository.GetByIdAsync(new Guid(id)); - if (organizationUser == null || organizationUser.OrganizationId != orgGuidId) - { - throw new NotFoundException(); - } + var orgGuidId = new Guid(orgId); - var userId = _userService.GetProperUserId(User); - await _organizationService.SaveUserAsync(model.ToOrganizationUser(organizationUser), userId.Value, - model.Collections?.Select(c => c.ToSelectionReadOnly())); - } + // Calling user must have Manage Reset Password permission + if (!await _currentContext.ManageResetPassword(orgGuidId)) + { + throw new NotFoundException(); + } - [HttpPut("{id}/groups")] - [HttpPost("{id}/groups")] - public async Task PutGroups(string orgId, string id, [FromBody] OrganizationUserUpdateGroupsRequestModel model) - { - var orgGuidId = new Guid(orgId); - if (!await _currentContext.ManageUsers(orgGuidId)) - { - throw new NotFoundException(); - } + // Get the users role, since provider users aren't a member of the organization we use the owner check + var orgUserType = await _currentContext.OrganizationOwner(orgGuidId) + ? OrganizationUserType.Owner + : _currentContext.Organizations?.FirstOrDefault(o => o.Id == orgGuidId)?.Type; + if (orgUserType == null) + { + throw new NotFoundException(); + } - var organizationUser = await _organizationUserRepository.GetByIdAsync(new Guid(id)); - if (organizationUser == null || organizationUser.OrganizationId != orgGuidId) - { - throw new NotFoundException(); - } + var result = await _userService.AdminResetPasswordAsync(orgUserType.Value, orgGuidId, new Guid(id), model.NewMasterPasswordHash, model.Key); + if (result.Succeeded) + { + return; + } - var loggedInUserId = _userService.GetProperUserId(User); - await _organizationService.UpdateUserGroupsAsync(organizationUser, model.GroupIds.Select(g => new Guid(g)), loggedInUserId); - } + foreach (var error in result.Errors) + { + ModelState.AddModelError(string.Empty, error.Description); + } - [HttpPut("{userId}/reset-password-enrollment")] - public async Task PutResetPasswordEnrollment(string orgId, string userId, [FromBody] OrganizationUserResetPasswordEnrollmentRequestModel model) - { - var user = await _userService.GetUserByPrincipalAsync(User); - if (user == null) - { - throw new UnauthorizedAccessException(); - } - - if (model.ResetPasswordKey != null && !await _userService.VerifySecretAsync(user, model.Secret)) - { await Task.Delay(2000); - throw new BadRequestException("MasterPasswordHash", "Invalid password."); + throw new BadRequestException(ModelState); } - else + + [HttpDelete("{id}")] + [HttpPost("{id}/delete")] + public async Task Delete(string orgId, string id) { - var callingUserId = user.Id; - await _organizationService.UpdateUserResetPasswordEnrollmentAsync( - new Guid(orgId), new Guid(userId), model.ResetPasswordKey, callingUserId); + var orgGuidId = new Guid(orgId); + if (!await _currentContext.ManageUsers(orgGuidId)) + { + throw new NotFoundException(); + } + + var userId = _userService.GetProperUserId(User); + await _organizationService.DeleteUserAsync(orgGuidId, new Guid(id), userId.Value); } - } - [HttpPut("{id}/reset-password")] - public async Task PutResetPassword(string orgId, string id, [FromBody] OrganizationUserResetPasswordRequestModel model) - { - - var orgGuidId = new Guid(orgId); - - // Calling user must have Manage Reset Password permission - if (!await _currentContext.ManageResetPassword(orgGuidId)) + [HttpDelete("")] + [HttpPost("delete")] + public async Task> BulkDelete(string orgId, [FromBody] OrganizationUserBulkRequestModel model) { - throw new NotFoundException(); + var orgGuidId = new Guid(orgId); + if (!await _currentContext.ManageUsers(orgGuidId)) + { + throw new NotFoundException(); + } + + var userId = _userService.GetProperUserId(User); + var result = await _organizationService.DeleteUsersAsync(orgGuidId, model.Ids, userId.Value); + return new ListResponseModel(result.Select(r => + new OrganizationUserBulkResponseModel(r.Item1.Id, r.Item2))); } - // Get the users role, since provider users aren't a member of the organization we use the owner check - var orgUserType = await _currentContext.OrganizationOwner(orgGuidId) - ? OrganizationUserType.Owner - : _currentContext.Organizations?.FirstOrDefault(o => o.Id == orgGuidId)?.Type; - if (orgUserType == null) + [Obsolete("2022-07-22 Moved to {id}/revoke endpoint")] + [HttpPatch("{id}/deactivate")] + [HttpPut("{id}/deactivate")] + public async Task Deactivate(Guid orgId, Guid id) { - throw new NotFoundException(); + await RevokeAsync(orgId, id); } - var result = await _userService.AdminResetPasswordAsync(orgUserType.Value, orgGuidId, new Guid(id), model.NewMasterPasswordHash, model.Key); - if (result.Succeeded) + [Obsolete("2022-07-22 Moved to /revoke endpoint")] + [HttpPatch("deactivate")] + [HttpPut("deactivate")] + public async Task> BulkDeactivate(Guid orgId, [FromBody] OrganizationUserBulkRequestModel model) { - return; + return await BulkRevokeAsync(orgId, model); } - foreach (var error in result.Errors) + [Obsolete("2022-07-22 Moved to {id}/restore endpoint")] + [HttpPatch("{id}/activate")] + [HttpPut("{id}/activate")] + public async Task Activate(Guid orgId, Guid id) { - ModelState.AddModelError(string.Empty, error.Description); + await RestoreAsync(orgId, id); } - await Task.Delay(2000); - throw new BadRequestException(ModelState); - } - - [HttpDelete("{id}")] - [HttpPost("{id}/delete")] - public async Task Delete(string orgId, string id) - { - var orgGuidId = new Guid(orgId); - if (!await _currentContext.ManageUsers(orgGuidId)) + [Obsolete("2022-07-22 Moved to /restore endpoint")] + [HttpPatch("activate")] + [HttpPut("activate")] + public async Task> BulkActivate(Guid orgId, [FromBody] OrganizationUserBulkRequestModel model) { - throw new NotFoundException(); + return await BulkRestoreAsync(orgId, model); } - var userId = _userService.GetProperUserId(User); - await _organizationService.DeleteUserAsync(orgGuidId, new Guid(id), userId.Value); - } - - [HttpDelete("")] - [HttpPost("delete")] - public async Task> BulkDelete(string orgId, [FromBody] OrganizationUserBulkRequestModel model) - { - var orgGuidId = new Guid(orgId); - if (!await _currentContext.ManageUsers(orgGuidId)) + [HttpPatch("{id}/revoke")] + [HttpPut("{id}/revoke")] + public async Task RevokeAsync(Guid orgId, Guid id) { - throw new NotFoundException(); + await RestoreOrRevokeUserAsync(orgId, id, _organizationService.RevokeUserAsync); } - var userId = _userService.GetProperUserId(User); - var result = await _organizationService.DeleteUsersAsync(orgGuidId, model.Ids, userId.Value); - return new ListResponseModel(result.Select(r => - new OrganizationUserBulkResponseModel(r.Item1.Id, r.Item2))); - } - - [Obsolete("2022-07-22 Moved to {id}/revoke endpoint")] - [HttpPatch("{id}/deactivate")] - [HttpPut("{id}/deactivate")] - public async Task Deactivate(Guid orgId, Guid id) - { - await RevokeAsync(orgId, id); - } - - [Obsolete("2022-07-22 Moved to /revoke endpoint")] - [HttpPatch("deactivate")] - [HttpPut("deactivate")] - public async Task> BulkDeactivate(Guid orgId, [FromBody] OrganizationUserBulkRequestModel model) - { - return await BulkRevokeAsync(orgId, model); - } - - [Obsolete("2022-07-22 Moved to {id}/restore endpoint")] - [HttpPatch("{id}/activate")] - [HttpPut("{id}/activate")] - public async Task Activate(Guid orgId, Guid id) - { - await RestoreAsync(orgId, id); - } - - [Obsolete("2022-07-22 Moved to /restore endpoint")] - [HttpPatch("activate")] - [HttpPut("activate")] - public async Task> BulkActivate(Guid orgId, [FromBody] OrganizationUserBulkRequestModel model) - { - return await BulkRestoreAsync(orgId, model); - } - - [HttpPatch("{id}/revoke")] - [HttpPut("{id}/revoke")] - public async Task RevokeAsync(Guid orgId, Guid id) - { - await RestoreOrRevokeUserAsync(orgId, id, _organizationService.RevokeUserAsync); - } - - [HttpPatch("revoke")] - [HttpPut("revoke")] - public async Task> BulkRevokeAsync(Guid orgId, [FromBody] OrganizationUserBulkRequestModel model) - { - return await RestoreOrRevokeUsersAsync(orgId, model, _organizationService.RevokeUsersAsync); - } - - [HttpPatch("{id}/restore")] - [HttpPut("{id}/restore")] - public async Task RestoreAsync(Guid orgId, Guid id) - { - await RestoreOrRevokeUserAsync(orgId, id, (orgUser, userId) => _organizationService.RestoreUserAsync(orgUser, userId, _userService)); - } - - [HttpPatch("restore")] - [HttpPut("restore")] - public async Task> BulkRestoreAsync(Guid orgId, [FromBody] OrganizationUserBulkRequestModel model) - { - return await RestoreOrRevokeUsersAsync(orgId, model, (orgId, orgUserIds, restoringUserId) => _organizationService.RestoreUsersAsync(orgId, orgUserIds, restoringUserId, _userService)); - } - - private async Task RestoreOrRevokeUserAsync( - Guid orgId, - Guid id, - Func statusAction) - { - if (!await _currentContext.ManageUsers(orgId)) + [HttpPatch("revoke")] + [HttpPut("revoke")] + public async Task> BulkRevokeAsync(Guid orgId, [FromBody] OrganizationUserBulkRequestModel model) { - throw new NotFoundException(); + return await RestoreOrRevokeUsersAsync(orgId, model, _organizationService.RevokeUsersAsync); } - var userId = _userService.GetProperUserId(User); - var orgUser = await _organizationUserRepository.GetByIdAsync(id); - if (orgUser == null || orgUser.OrganizationId != orgId) + [HttpPatch("{id}/restore")] + [HttpPut("{id}/restore")] + public async Task RestoreAsync(Guid orgId, Guid id) { - throw new NotFoundException(); + await RestoreOrRevokeUserAsync(orgId, id, (orgUser, userId) => _organizationService.RestoreUserAsync(orgUser, userId, _userService)); } - await statusAction(orgUser, userId); - } - - private async Task> RestoreOrRevokeUsersAsync( - Guid orgId, - OrganizationUserBulkRequestModel model, - Func, Guid?, Task>>> statusAction) - { - if (!await _currentContext.ManageUsers(orgId)) + [HttpPatch("restore")] + [HttpPut("restore")] + public async Task> BulkRestoreAsync(Guid orgId, [FromBody] OrganizationUserBulkRequestModel model) { - throw new NotFoundException(); + return await RestoreOrRevokeUsersAsync(orgId, model, (orgId, orgUserIds, restoringUserId) => _organizationService.RestoreUsersAsync(orgId, orgUserIds, restoringUserId, _userService)); } - var userId = _userService.GetProperUserId(User); - var result = await statusAction(orgId, model.Ids, userId.Value); - return new ListResponseModel(result.Select(r => - new OrganizationUserBulkResponseModel(r.Item1.Id, r.Item2))); + private async Task RestoreOrRevokeUserAsync( + Guid orgId, + Guid id, + Func statusAction) + { + if (!await _currentContext.ManageUsers(orgId)) + { + throw new NotFoundException(); + } + + var userId = _userService.GetProperUserId(User); + var orgUser = await _organizationUserRepository.GetByIdAsync(id); + if (orgUser == null || orgUser.OrganizationId != orgId) + { + throw new NotFoundException(); + } + + await statusAction(orgUser, userId); + } + + private async Task> RestoreOrRevokeUsersAsync( + Guid orgId, + OrganizationUserBulkRequestModel model, + Func, Guid?, Task>>> statusAction) + { + if (!await _currentContext.ManageUsers(orgId)) + { + throw new NotFoundException(); + } + + var userId = _userService.GetProperUserId(User); + var result = await statusAction(orgId, model.Ids, userId.Value); + return new ListResponseModel(result.Select(r => + new OrganizationUserBulkResponseModel(r.Item1.Id, r.Item2))); + } } } diff --git a/src/Api/Controllers/OrganizationsController.cs b/src/Api/Controllers/OrganizationsController.cs index f38b0dbc3c..7a5b26d9ea 100644 --- a/src/Api/Controllers/OrganizationsController.cs +++ b/src/Api/Controllers/OrganizationsController.cs @@ -18,697 +18,698 @@ using Bit.Core.Utilities; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; -namespace Bit.Api.Controllers; - -[Route("organizations")] -[Authorize("Application")] -public class OrganizationsController : Controller +namespace Bit.Api.Controllers { - private readonly IOrganizationRepository _organizationRepository; - private readonly IOrganizationUserRepository _organizationUserRepository; - private readonly IPolicyRepository _policyRepository; - private readonly IOrganizationService _organizationService; - private readonly IUserService _userService; - private readonly IPaymentService _paymentService; - private readonly ICurrentContext _currentContext; - private readonly ISsoConfigRepository _ssoConfigRepository; - private readonly ISsoConfigService _ssoConfigService; - private readonly IGetOrganizationApiKeyCommand _getOrganizationApiKeyCommand; - private readonly IRotateOrganizationApiKeyCommand _rotateOrganizationApiKeyCommand; - private readonly IOrganizationApiKeyRepository _organizationApiKeyRepository; - private readonly GlobalSettings _globalSettings; - - public OrganizationsController( - IOrganizationRepository organizationRepository, - IOrganizationUserRepository organizationUserRepository, - IPolicyRepository policyRepository, - IOrganizationService organizationService, - IUserService userService, - IPaymentService paymentService, - ICurrentContext currentContext, - ISsoConfigRepository ssoConfigRepository, - ISsoConfigService ssoConfigService, - IGetOrganizationApiKeyCommand getOrganizationApiKeyCommand, - IRotateOrganizationApiKeyCommand rotateOrganizationApiKeyCommand, - IOrganizationApiKeyRepository organizationApiKeyRepository, - GlobalSettings globalSettings) + [Route("organizations")] + [Authorize("Application")] + public class OrganizationsController : Controller { - _organizationRepository = organizationRepository; - _organizationUserRepository = organizationUserRepository; - _policyRepository = policyRepository; - _organizationService = organizationService; - _userService = userService; - _paymentService = paymentService; - _currentContext = currentContext; - _ssoConfigRepository = ssoConfigRepository; - _ssoConfigService = ssoConfigService; - _getOrganizationApiKeyCommand = getOrganizationApiKeyCommand; - _rotateOrganizationApiKeyCommand = rotateOrganizationApiKeyCommand; - _organizationApiKeyRepository = organizationApiKeyRepository; - _globalSettings = globalSettings; - } + private readonly IOrganizationRepository _organizationRepository; + private readonly IOrganizationUserRepository _organizationUserRepository; + private readonly IPolicyRepository _policyRepository; + private readonly IOrganizationService _organizationService; + private readonly IUserService _userService; + private readonly IPaymentService _paymentService; + private readonly ICurrentContext _currentContext; + private readonly ISsoConfigRepository _ssoConfigRepository; + private readonly ISsoConfigService _ssoConfigService; + private readonly IGetOrganizationApiKeyCommand _getOrganizationApiKeyCommand; + private readonly IRotateOrganizationApiKeyCommand _rotateOrganizationApiKeyCommand; + private readonly IOrganizationApiKeyRepository _organizationApiKeyRepository; + private readonly GlobalSettings _globalSettings; - [HttpGet("{id}")] - public async Task Get(string id) - { - var orgIdGuid = new Guid(id); - if (!await _currentContext.OrganizationOwner(orgIdGuid)) + public OrganizationsController( + IOrganizationRepository organizationRepository, + IOrganizationUserRepository organizationUserRepository, + IPolicyRepository policyRepository, + IOrganizationService organizationService, + IUserService userService, + IPaymentService paymentService, + ICurrentContext currentContext, + ISsoConfigRepository ssoConfigRepository, + ISsoConfigService ssoConfigService, + IGetOrganizationApiKeyCommand getOrganizationApiKeyCommand, + IRotateOrganizationApiKeyCommand rotateOrganizationApiKeyCommand, + IOrganizationApiKeyRepository organizationApiKeyRepository, + GlobalSettings globalSettings) { - throw new NotFoundException(); + _organizationRepository = organizationRepository; + _organizationUserRepository = organizationUserRepository; + _policyRepository = policyRepository; + _organizationService = organizationService; + _userService = userService; + _paymentService = paymentService; + _currentContext = currentContext; + _ssoConfigRepository = ssoConfigRepository; + _ssoConfigService = ssoConfigService; + _getOrganizationApiKeyCommand = getOrganizationApiKeyCommand; + _rotateOrganizationApiKeyCommand = rotateOrganizationApiKeyCommand; + _organizationApiKeyRepository = organizationApiKeyRepository; + _globalSettings = globalSettings; } - var organization = await _organizationRepository.GetByIdAsync(orgIdGuid); - if (organization == null) + [HttpGet("{id}")] + public async Task Get(string id) { - throw new NotFoundException(); - } - - return new OrganizationResponseModel(organization); - } - - [HttpGet("{id}/billing")] - [SelfHosted(NotSelfHostedOnly = true)] - public async Task GetBilling(string id) - { - var orgIdGuid = new Guid(id); - if (!await _currentContext.ManageBilling(orgIdGuid)) - { - throw new NotFoundException(); - } - - var organization = await _organizationRepository.GetByIdAsync(orgIdGuid); - if (organization == null) - { - throw new NotFoundException(); - } - - var billingInfo = await _paymentService.GetBillingAsync(organization); - return new BillingResponseModel(billingInfo); - } - - [HttpGet("{id}/subscription")] - public async Task GetSubscription(string id) - { - var orgIdGuid = new Guid(id); - if (!await _currentContext.ManageBilling(orgIdGuid)) - { - throw new NotFoundException(); - } - - var organization = await _organizationRepository.GetByIdAsync(orgIdGuid); - if (organization == null) - { - throw new NotFoundException(); - } - - if (!_globalSettings.SelfHosted && organization.Gateway != null) - { - var subscriptionInfo = await _paymentService.GetSubscriptionAsync(organization); - if (subscriptionInfo == null) + var orgIdGuid = new Guid(id); + if (!await _currentContext.OrganizationOwner(orgIdGuid)) { throw new NotFoundException(); } - return new OrganizationSubscriptionResponseModel(organization, subscriptionInfo); - } - else - { - return new OrganizationSubscriptionResponseModel(organization); - } - } - [HttpGet("{id}/license")] - [SelfHosted(NotSelfHostedOnly = true)] - public async Task GetLicense(string id, [FromQuery] Guid installationId) - { - var orgIdGuid = new Guid(id); - if (!await _currentContext.OrganizationOwner(orgIdGuid)) - { - throw new NotFoundException(); - } - - var license = await _organizationService.GenerateLicenseAsync(orgIdGuid, installationId); - if (license == null) - { - throw new NotFoundException(); - } - - return license; - } - - [HttpGet("")] - public async Task> GetUser() - { - var userId = _userService.GetProperUserId(User).Value; - var organizations = await _organizationUserRepository.GetManyDetailsByUserAsync(userId, - OrganizationUserStatusType.Confirmed); - var responses = organizations.Select(o => new ProfileOrganizationResponseModel(o)); - return new ListResponseModel(responses); - } - - [HttpGet("{identifier}/auto-enroll-status")] - public async Task GetAutoEnrollStatus(string identifier) - { - var user = await _userService.GetUserByPrincipalAsync(User); - if (user == null) - { - throw new UnauthorizedAccessException(); - } - - var organization = await _organizationRepository.GetByIdentifierAsync(identifier); - if (organization == null) - { - throw new NotFoundException(); - } - - var organizationUser = await _organizationUserRepository.GetByOrganizationAsync(organization.Id, user.Id); - if (organizationUser == null) - { - throw new NotFoundException(); - } - - var resetPasswordPolicy = - await _policyRepository.GetByOrganizationIdTypeAsync(organization.Id, PolicyType.ResetPassword); - if (resetPasswordPolicy == null || !resetPasswordPolicy.Enabled || resetPasswordPolicy.Data == null) - { - return new OrganizationAutoEnrollStatusResponseModel(organization.Id, false); - } - - var data = JsonSerializer.Deserialize(resetPasswordPolicy.Data, JsonHelpers.IgnoreCase); - return new OrganizationAutoEnrollStatusResponseModel(organization.Id, data?.AutoEnrollEnabled ?? false); - } - - [HttpPost("")] - [SelfHosted(NotSelfHostedOnly = true)] - public async Task Post([FromBody] OrganizationCreateRequestModel model) - { - var user = await _userService.GetUserByPrincipalAsync(User); - if (user == null) - { - throw new UnauthorizedAccessException(); - } - - var organizationSignup = model.ToOrganizationSignup(user); - var result = await _organizationService.SignUpAsync(organizationSignup); - return new OrganizationResponseModel(result.Item1); - } - - [HttpPost("license")] - [SelfHosted(SelfHostedOnly = true)] - public async Task PostLicense(OrganizationCreateLicenseRequestModel model) - { - var user = await _userService.GetUserByPrincipalAsync(User); - if (user == null) - { - throw new UnauthorizedAccessException(); - } - - var license = await ApiHelpers.ReadJsonFileFromBody(HttpContext, model.License); - if (license == null) - { - throw new BadRequestException("Invalid license"); - } - - var result = await _organizationService.SignUpAsync(license, user, model.Key, - model.CollectionName, model.Keys?.PublicKey, model.Keys?.EncryptedPrivateKey); - return new OrganizationResponseModel(result.Item1); - } - - [HttpPut("{id}")] - [HttpPost("{id}")] - public async Task Put(string id, [FromBody] OrganizationUpdateRequestModel model) - { - var orgIdGuid = new Guid(id); - - var organization = await _organizationRepository.GetByIdAsync(orgIdGuid); - if (organization == null) - { - throw new NotFoundException(); - } - - var updateBilling = !_globalSettings.SelfHosted && (model.BusinessName != organization.BusinessName || - model.BillingEmail != organization.BillingEmail); - - var hasRequiredPermissions = updateBilling - ? await _currentContext.ManageBilling(orgIdGuid) - : await _currentContext.OrganizationOwner(orgIdGuid); - - if (!hasRequiredPermissions) - { - throw new NotFoundException(); - } - - await _organizationService.UpdateAsync(model.ToOrganization(organization, _globalSettings), updateBilling); - return new OrganizationResponseModel(organization); - } - - [HttpPost("{id}/payment")] - [SelfHosted(NotSelfHostedOnly = true)] - public async Task PostPayment(string id, [FromBody] PaymentRequestModel model) - { - var orgIdGuid = new Guid(id); - if (!await _currentContext.ManageBilling(orgIdGuid)) - { - throw new NotFoundException(); - } - - await _organizationService.ReplacePaymentMethodAsync(orgIdGuid, model.PaymentToken, - model.PaymentMethodType.Value, new TaxInfo + var organization = await _organizationRepository.GetByIdAsync(orgIdGuid); + if (organization == null) { + throw new NotFoundException(); + } + + return new OrganizationResponseModel(organization); + } + + [HttpGet("{id}/billing")] + [SelfHosted(NotSelfHostedOnly = true)] + public async Task GetBilling(string id) + { + var orgIdGuid = new Guid(id); + if (!await _currentContext.ManageBilling(orgIdGuid)) + { + throw new NotFoundException(); + } + + var organization = await _organizationRepository.GetByIdAsync(orgIdGuid); + if (organization == null) + { + throw new NotFoundException(); + } + + var billingInfo = await _paymentService.GetBillingAsync(organization); + return new BillingResponseModel(billingInfo); + } + + [HttpGet("{id}/subscription")] + public async Task GetSubscription(string id) + { + var orgIdGuid = new Guid(id); + if (!await _currentContext.ManageBilling(orgIdGuid)) + { + throw new NotFoundException(); + } + + var organization = await _organizationRepository.GetByIdAsync(orgIdGuid); + if (organization == null) + { + throw new NotFoundException(); + } + + if (!_globalSettings.SelfHosted && organization.Gateway != null) + { + var subscriptionInfo = await _paymentService.GetSubscriptionAsync(organization); + if (subscriptionInfo == null) + { + throw new NotFoundException(); + } + return new OrganizationSubscriptionResponseModel(organization, subscriptionInfo); + } + else + { + return new OrganizationSubscriptionResponseModel(organization); + } + } + + [HttpGet("{id}/license")] + [SelfHosted(NotSelfHostedOnly = true)] + public async Task GetLicense(string id, [FromQuery] Guid installationId) + { + var orgIdGuid = new Guid(id); + if (!await _currentContext.OrganizationOwner(orgIdGuid)) + { + throw new NotFoundException(); + } + + var license = await _organizationService.GenerateLicenseAsync(orgIdGuid, installationId); + if (license == null) + { + throw new NotFoundException(); + } + + return license; + } + + [HttpGet("")] + public async Task> GetUser() + { + var userId = _userService.GetProperUserId(User).Value; + var organizations = await _organizationUserRepository.GetManyDetailsByUserAsync(userId, + OrganizationUserStatusType.Confirmed); + var responses = organizations.Select(o => new ProfileOrganizationResponseModel(o)); + return new ListResponseModel(responses); + } + + [HttpGet("{identifier}/auto-enroll-status")] + public async Task GetAutoEnrollStatus(string identifier) + { + var user = await _userService.GetUserByPrincipalAsync(User); + if (user == null) + { + throw new UnauthorizedAccessException(); + } + + var organization = await _organizationRepository.GetByIdentifierAsync(identifier); + if (organization == null) + { + throw new NotFoundException(); + } + + var organizationUser = await _organizationUserRepository.GetByOrganizationAsync(organization.Id, user.Id); + if (organizationUser == null) + { + throw new NotFoundException(); + } + + var resetPasswordPolicy = + await _policyRepository.GetByOrganizationIdTypeAsync(organization.Id, PolicyType.ResetPassword); + if (resetPasswordPolicy == null || !resetPasswordPolicy.Enabled || resetPasswordPolicy.Data == null) + { + return new OrganizationAutoEnrollStatusResponseModel(organization.Id, false); + } + + var data = JsonSerializer.Deserialize(resetPasswordPolicy.Data, JsonHelpers.IgnoreCase); + return new OrganizationAutoEnrollStatusResponseModel(organization.Id, data?.AutoEnrollEnabled ?? false); + } + + [HttpPost("")] + [SelfHosted(NotSelfHostedOnly = true)] + public async Task Post([FromBody] OrganizationCreateRequestModel model) + { + var user = await _userService.GetUserByPrincipalAsync(User); + if (user == null) + { + throw new UnauthorizedAccessException(); + } + + var organizationSignup = model.ToOrganizationSignup(user); + var result = await _organizationService.SignUpAsync(organizationSignup); + return new OrganizationResponseModel(result.Item1); + } + + [HttpPost("license")] + [SelfHosted(SelfHostedOnly = true)] + public async Task PostLicense(OrganizationCreateLicenseRequestModel model) + { + var user = await _userService.GetUserByPrincipalAsync(User); + if (user == null) + { + throw new UnauthorizedAccessException(); + } + + var license = await ApiHelpers.ReadJsonFileFromBody(HttpContext, model.License); + if (license == null) + { + throw new BadRequestException("Invalid license"); + } + + var result = await _organizationService.SignUpAsync(license, user, model.Key, + model.CollectionName, model.Keys?.PublicKey, model.Keys?.EncryptedPrivateKey); + return new OrganizationResponseModel(result.Item1); + } + + [HttpPut("{id}")] + [HttpPost("{id}")] + public async Task Put(string id, [FromBody] OrganizationUpdateRequestModel model) + { + var orgIdGuid = new Guid(id); + + var organization = await _organizationRepository.GetByIdAsync(orgIdGuid); + if (organization == null) + { + throw new NotFoundException(); + } + + var updateBilling = !_globalSettings.SelfHosted && (model.BusinessName != organization.BusinessName || + model.BillingEmail != organization.BillingEmail); + + var hasRequiredPermissions = updateBilling + ? await _currentContext.ManageBilling(orgIdGuid) + : await _currentContext.OrganizationOwner(orgIdGuid); + + if (!hasRequiredPermissions) + { + throw new NotFoundException(); + } + + await _organizationService.UpdateAsync(model.ToOrganization(organization, _globalSettings), updateBilling); + return new OrganizationResponseModel(organization); + } + + [HttpPost("{id}/payment")] + [SelfHosted(NotSelfHostedOnly = true)] + public async Task PostPayment(string id, [FromBody] PaymentRequestModel model) + { + var orgIdGuid = new Guid(id); + if (!await _currentContext.ManageBilling(orgIdGuid)) + { + throw new NotFoundException(); + } + + await _organizationService.ReplacePaymentMethodAsync(orgIdGuid, model.PaymentToken, + model.PaymentMethodType.Value, new TaxInfo + { + BillingAddressLine1 = model.Line1, + BillingAddressLine2 = model.Line2, + BillingAddressState = model.State, + BillingAddressCity = model.City, + BillingAddressPostalCode = model.PostalCode, + BillingAddressCountry = model.Country, + TaxIdNumber = model.TaxId, + }); + } + + [HttpPost("{id}/upgrade")] + [SelfHosted(NotSelfHostedOnly = true)] + public async Task PostUpgrade(string id, [FromBody] OrganizationUpgradeRequestModel model) + { + var orgIdGuid = new Guid(id); + if (!await _currentContext.ManageBilling(orgIdGuid)) + { + throw new NotFoundException(); + } + + var result = await _organizationService.UpgradePlanAsync(orgIdGuid, model.ToOrganizationUpgrade()); + return new PaymentResponseModel + { + Success = result.Item1, + PaymentIntentClientSecret = result.Item2 + }; + } + + [HttpPost("{id}/subscription")] + [SelfHosted(NotSelfHostedOnly = true)] + public async Task PostSubscription(string id, [FromBody] OrganizationSubscriptionUpdateRequestModel model) + { + var orgIdGuid = new Guid(id); + if (!await _currentContext.ManageBilling(orgIdGuid)) + { + throw new NotFoundException(); + } + + await _organizationService.UpdateSubscription(orgIdGuid, model.SeatAdjustment, model.MaxAutoscaleSeats); + } + + [HttpPost("{id}/seat")] + [SelfHosted(NotSelfHostedOnly = true)] + public async Task PostSeat(string id, [FromBody] OrganizationSeatRequestModel model) + { + var orgIdGuid = new Guid(id); + if (!await _currentContext.ManageBilling(orgIdGuid)) + { + throw new NotFoundException(); + } + + var result = await _organizationService.AdjustSeatsAsync(orgIdGuid, model.SeatAdjustment.Value); + return new PaymentResponseModel + { + Success = true, + PaymentIntentClientSecret = result + }; + } + + [HttpPost("{id}/storage")] + [SelfHosted(NotSelfHostedOnly = true)] + public async Task PostStorage(string id, [FromBody] StorageRequestModel model) + { + var orgIdGuid = new Guid(id); + if (!await _currentContext.ManageBilling(orgIdGuid)) + { + throw new NotFoundException(); + } + + var result = await _organizationService.AdjustStorageAsync(orgIdGuid, model.StorageGbAdjustment.Value); + return new PaymentResponseModel + { + Success = true, + PaymentIntentClientSecret = result + }; + } + + [HttpPost("{id}/verify-bank")] + [SelfHosted(NotSelfHostedOnly = true)] + public async Task PostVerifyBank(string id, [FromBody] OrganizationVerifyBankRequestModel model) + { + var orgIdGuid = new Guid(id); + if (!await _currentContext.ManageBilling(orgIdGuid)) + { + throw new NotFoundException(); + } + + await _organizationService.VerifyBankAsync(orgIdGuid, model.Amount1.Value, model.Amount2.Value); + } + + [HttpPost("{id}/cancel")] + [SelfHosted(NotSelfHostedOnly = true)] + public async Task PostCancel(string id) + { + var orgIdGuid = new Guid(id); + if (!await _currentContext.ManageBilling(orgIdGuid)) + { + throw new NotFoundException(); + } + + await _organizationService.CancelSubscriptionAsync(orgIdGuid); + } + + [HttpPost("{id}/reinstate")] + [SelfHosted(NotSelfHostedOnly = true)] + public async Task PostReinstate(string id) + { + var orgIdGuid = new Guid(id); + if (!await _currentContext.ManageBilling(orgIdGuid)) + { + throw new NotFoundException(); + } + + await _organizationService.ReinstateSubscriptionAsync(orgIdGuid); + } + + [HttpPost("{id}/leave")] + public async Task Leave(string id) + { + var orgGuidId = new Guid(id); + if (!await _currentContext.OrganizationUser(orgGuidId)) + { + throw new NotFoundException(); + } + + var user = await _userService.GetUserByPrincipalAsync(User); + + var ssoConfig = await _ssoConfigRepository.GetByOrganizationIdAsync(orgGuidId); + if (ssoConfig?.GetData()?.KeyConnectorEnabled == true && + user.UsesKeyConnector) + { + throw new BadRequestException("Your organization's Single Sign-On settings prevent you from leaving."); + } + + + await _organizationService.DeleteUserAsync(orgGuidId, user.Id); + } + + [HttpDelete("{id}")] + [HttpPost("{id}/delete")] + public async Task Delete(string id, [FromBody] SecretVerificationRequestModel model) + { + var orgIdGuid = new Guid(id); + if (!await _currentContext.OrganizationOwner(orgIdGuid)) + { + throw new NotFoundException(); + } + + var organization = await _organizationRepository.GetByIdAsync(orgIdGuid); + if (organization == null) + { + throw new NotFoundException(); + } + + var user = await _userService.GetUserByPrincipalAsync(User); + if (user == null) + { + throw new UnauthorizedAccessException(); + } + + if (!await _userService.VerifySecretAsync(user, model.Secret)) + { + await Task.Delay(2000); + throw new BadRequestException(string.Empty, "User verification failed."); + } + else + { + await _organizationService.DeleteAsync(organization); + } + } + + [HttpPost("{id}/license")] + [SelfHosted(SelfHostedOnly = true)] + public async Task PostLicense(string id, LicenseRequestModel model) + { + var orgIdGuid = new Guid(id); + if (!await _currentContext.OrganizationOwner(orgIdGuid)) + { + throw new NotFoundException(); + } + + var license = await ApiHelpers.ReadJsonFileFromBody(HttpContext, model.License); + if (license == null) + { + throw new BadRequestException("Invalid license"); + } + + await _organizationService.UpdateLicenseAsync(new Guid(id), license); + } + + [HttpPost("{id}/import")] + public async Task Import(string id, [FromBody] ImportOrganizationUsersRequestModel model) + { + if (!_globalSettings.SelfHosted && !model.LargeImport && + (model.Groups.Count() > 2000 || model.Users.Count(u => !u.Deleted) > 2000)) + { + throw new BadRequestException("You cannot import this much data at once."); + } + + var orgIdGuid = new Guid(id); + if (!await _currentContext.OrganizationAdmin(orgIdGuid)) + { + throw new NotFoundException(); + } + + var userId = _userService.GetProperUserId(User); + await _organizationService.ImportAsync( + orgIdGuid, + userId.Value, + model.Groups.Select(g => g.ToImportedGroup(orgIdGuid)), + model.Users.Where(u => !u.Deleted).Select(u => u.ToImportedOrganizationUser()), + model.Users.Where(u => u.Deleted).Select(u => u.ExternalId), + model.OverwriteExisting); + } + + [HttpPost("{id}/api-key")] + public async Task ApiKey(string id, [FromBody] OrganizationApiKeyRequestModel model) + { + var orgIdGuid = new Guid(id); + if (!await HasApiKeyAccessAsync(orgIdGuid, model.Type)) + { + throw new NotFoundException(); + } + + var organization = await _organizationRepository.GetByIdAsync(orgIdGuid); + if (organization == null) + { + throw new NotFoundException(); + } + + if (model.Type == OrganizationApiKeyType.BillingSync || model.Type == OrganizationApiKeyType.Scim) + { + // Non-enterprise orgs should not be able to create or view an apikey of billing sync/scim key types + var plan = StaticStore.GetPlan(organization.PlanType); + if (plan.Product != ProductType.Enterprise) + { + throw new NotFoundException(); + } + } + + var organizationApiKey = await _getOrganizationApiKeyCommand + .GetOrganizationApiKeyAsync(organization.Id, model.Type); + + var user = await _userService.GetUserByPrincipalAsync(User); + if (user == null) + { + throw new UnauthorizedAccessException(); + } + + if (model.Type != OrganizationApiKeyType.Scim + && !await _userService.VerifySecretAsync(user, model.Secret)) + { + await Task.Delay(2000); + throw new BadRequestException("MasterPasswordHash", "Invalid password."); + } + else + { + var response = new ApiKeyResponseModel(organizationApiKey); + return response; + } + } + + [HttpGet("{id}/api-key-information/{type?}")] + public async Task> ApiKeyInformation(Guid id, OrganizationApiKeyType? type) + { + if (!await HasApiKeyAccessAsync(id, type)) + { + throw new NotFoundException(); + } + + var apiKeys = await _organizationApiKeyRepository.GetManyByOrganizationIdTypeAsync(id, type); + + return new ListResponseModel( + apiKeys.Select(k => new OrganizationApiKeyInformation(k))); + } + + [HttpPost("{id}/rotate-api-key")] + public async Task RotateApiKey(string id, [FromBody] OrganizationApiKeyRequestModel model) + { + var orgIdGuid = new Guid(id); + if (!await HasApiKeyAccessAsync(orgIdGuid, model.Type)) + { + throw new NotFoundException(); + } + + var organization = await _organizationRepository.GetByIdAsync(orgIdGuid); + if (organization == null) + { + throw new NotFoundException(); + } + + var organizationApiKey = await _getOrganizationApiKeyCommand + .GetOrganizationApiKeyAsync(organization.Id, model.Type); + + var user = await _userService.GetUserByPrincipalAsync(User); + if (user == null) + { + throw new UnauthorizedAccessException(); + } + + if (model.Type != OrganizationApiKeyType.Scim + && !await _userService.VerifySecretAsync(user, model.Secret)) + { + await Task.Delay(2000); + throw new BadRequestException("MasterPasswordHash", "Invalid password."); + } + else + { + await _rotateOrganizationApiKeyCommand.RotateApiKeyAsync(organizationApiKey); + var response = new ApiKeyResponseModel(organizationApiKey); + return response; + } + } + + private async Task HasApiKeyAccessAsync(Guid orgId, OrganizationApiKeyType? type) + { + return type switch + { + OrganizationApiKeyType.Scim => await _currentContext.ManageScim(orgId), + _ => await _currentContext.OrganizationOwner(orgId), + }; + } + + [HttpGet("{id}/tax")] + [SelfHosted(NotSelfHostedOnly = true)] + public async Task GetTaxInfo(string id) + { + var orgIdGuid = new Guid(id); + if (!await _currentContext.OrganizationOwner(orgIdGuid)) + { + throw new NotFoundException(); + } + + var organization = await _organizationRepository.GetByIdAsync(orgIdGuid); + if (organization == null) + { + throw new NotFoundException(); + } + + var taxInfo = await _paymentService.GetTaxInfoAsync(organization); + return new TaxInfoResponseModel(taxInfo); + } + + [HttpPut("{id}/tax")] + [SelfHosted(NotSelfHostedOnly = true)] + public async Task PutTaxInfo(string id, [FromBody] OrganizationTaxInfoUpdateRequestModel model) + { + var orgIdGuid = new Guid(id); + if (!await _currentContext.OrganizationOwner(orgIdGuid)) + { + throw new NotFoundException(); + } + + var organization = await _organizationRepository.GetByIdAsync(orgIdGuid); + if (organization == null) + { + throw new NotFoundException(); + } + + var taxInfo = new TaxInfo + { + TaxIdNumber = model.TaxId, BillingAddressLine1 = model.Line1, BillingAddressLine2 = model.Line2, - BillingAddressState = model.State, BillingAddressCity = model.City, + BillingAddressState = model.State, BillingAddressPostalCode = model.PostalCode, BillingAddressCountry = model.Country, - TaxIdNumber = model.TaxId, - }); - } - - [HttpPost("{id}/upgrade")] - [SelfHosted(NotSelfHostedOnly = true)] - public async Task PostUpgrade(string id, [FromBody] OrganizationUpgradeRequestModel model) - { - var orgIdGuid = new Guid(id); - if (!await _currentContext.ManageBilling(orgIdGuid)) - { - throw new NotFoundException(); + }; + await _paymentService.SaveTaxInfoAsync(organization, taxInfo); } - var result = await _organizationService.UpgradePlanAsync(orgIdGuid, model.ToOrganizationUpgrade()); - return new PaymentResponseModel + [HttpGet("{id}/keys")] + public async Task GetKeys(string id) { - Success = result.Item1, - PaymentIntentClientSecret = result.Item2 - }; - } - - [HttpPost("{id}/subscription")] - [SelfHosted(NotSelfHostedOnly = true)] - public async Task PostSubscription(string id, [FromBody] OrganizationSubscriptionUpdateRequestModel model) - { - var orgIdGuid = new Guid(id); - if (!await _currentContext.ManageBilling(orgIdGuid)) - { - throw new NotFoundException(); - } - - await _organizationService.UpdateSubscription(orgIdGuid, model.SeatAdjustment, model.MaxAutoscaleSeats); - } - - [HttpPost("{id}/seat")] - [SelfHosted(NotSelfHostedOnly = true)] - public async Task PostSeat(string id, [FromBody] OrganizationSeatRequestModel model) - { - var orgIdGuid = new Guid(id); - if (!await _currentContext.ManageBilling(orgIdGuid)) - { - throw new NotFoundException(); - } - - var result = await _organizationService.AdjustSeatsAsync(orgIdGuid, model.SeatAdjustment.Value); - return new PaymentResponseModel - { - Success = true, - PaymentIntentClientSecret = result - }; - } - - [HttpPost("{id}/storage")] - [SelfHosted(NotSelfHostedOnly = true)] - public async Task PostStorage(string id, [FromBody] StorageRequestModel model) - { - var orgIdGuid = new Guid(id); - if (!await _currentContext.ManageBilling(orgIdGuid)) - { - throw new NotFoundException(); - } - - var result = await _organizationService.AdjustStorageAsync(orgIdGuid, model.StorageGbAdjustment.Value); - return new PaymentResponseModel - { - Success = true, - PaymentIntentClientSecret = result - }; - } - - [HttpPost("{id}/verify-bank")] - [SelfHosted(NotSelfHostedOnly = true)] - public async Task PostVerifyBank(string id, [FromBody] OrganizationVerifyBankRequestModel model) - { - var orgIdGuid = new Guid(id); - if (!await _currentContext.ManageBilling(orgIdGuid)) - { - throw new NotFoundException(); - } - - await _organizationService.VerifyBankAsync(orgIdGuid, model.Amount1.Value, model.Amount2.Value); - } - - [HttpPost("{id}/cancel")] - [SelfHosted(NotSelfHostedOnly = true)] - public async Task PostCancel(string id) - { - var orgIdGuid = new Guid(id); - if (!await _currentContext.ManageBilling(orgIdGuid)) - { - throw new NotFoundException(); - } - - await _organizationService.CancelSubscriptionAsync(orgIdGuid); - } - - [HttpPost("{id}/reinstate")] - [SelfHosted(NotSelfHostedOnly = true)] - public async Task PostReinstate(string id) - { - var orgIdGuid = new Guid(id); - if (!await _currentContext.ManageBilling(orgIdGuid)) - { - throw new NotFoundException(); - } - - await _organizationService.ReinstateSubscriptionAsync(orgIdGuid); - } - - [HttpPost("{id}/leave")] - public async Task Leave(string id) - { - var orgGuidId = new Guid(id); - if (!await _currentContext.OrganizationUser(orgGuidId)) - { - throw new NotFoundException(); - } - - var user = await _userService.GetUserByPrincipalAsync(User); - - var ssoConfig = await _ssoConfigRepository.GetByOrganizationIdAsync(orgGuidId); - if (ssoConfig?.GetData()?.KeyConnectorEnabled == true && - user.UsesKeyConnector) - { - throw new BadRequestException("Your organization's Single Sign-On settings prevent you from leaving."); - } - - - await _organizationService.DeleteUserAsync(orgGuidId, user.Id); - } - - [HttpDelete("{id}")] - [HttpPost("{id}/delete")] - public async Task Delete(string id, [FromBody] SecretVerificationRequestModel model) - { - var orgIdGuid = new Guid(id); - if (!await _currentContext.OrganizationOwner(orgIdGuid)) - { - throw new NotFoundException(); - } - - var organization = await _organizationRepository.GetByIdAsync(orgIdGuid); - if (organization == null) - { - throw new NotFoundException(); - } - - var user = await _userService.GetUserByPrincipalAsync(User); - if (user == null) - { - throw new UnauthorizedAccessException(); - } - - if (!await _userService.VerifySecretAsync(user, model.Secret)) - { - await Task.Delay(2000); - throw new BadRequestException(string.Empty, "User verification failed."); - } - else - { - await _organizationService.DeleteAsync(organization); - } - } - - [HttpPost("{id}/license")] - [SelfHosted(SelfHostedOnly = true)] - public async Task PostLicense(string id, LicenseRequestModel model) - { - var orgIdGuid = new Guid(id); - if (!await _currentContext.OrganizationOwner(orgIdGuid)) - { - throw new NotFoundException(); - } - - var license = await ApiHelpers.ReadJsonFileFromBody(HttpContext, model.License); - if (license == null) - { - throw new BadRequestException("Invalid license"); - } - - await _organizationService.UpdateLicenseAsync(new Guid(id), license); - } - - [HttpPost("{id}/import")] - public async Task Import(string id, [FromBody] ImportOrganizationUsersRequestModel model) - { - if (!_globalSettings.SelfHosted && !model.LargeImport && - (model.Groups.Count() > 2000 || model.Users.Count(u => !u.Deleted) > 2000)) - { - throw new BadRequestException("You cannot import this much data at once."); - } - - var orgIdGuid = new Guid(id); - if (!await _currentContext.OrganizationAdmin(orgIdGuid)) - { - throw new NotFoundException(); - } - - var userId = _userService.GetProperUserId(User); - await _organizationService.ImportAsync( - orgIdGuid, - userId.Value, - model.Groups.Select(g => g.ToImportedGroup(orgIdGuid)), - model.Users.Where(u => !u.Deleted).Select(u => u.ToImportedOrganizationUser()), - model.Users.Where(u => u.Deleted).Select(u => u.ExternalId), - model.OverwriteExisting); - } - - [HttpPost("{id}/api-key")] - public async Task ApiKey(string id, [FromBody] OrganizationApiKeyRequestModel model) - { - var orgIdGuid = new Guid(id); - if (!await HasApiKeyAccessAsync(orgIdGuid, model.Type)) - { - throw new NotFoundException(); - } - - var organization = await _organizationRepository.GetByIdAsync(orgIdGuid); - if (organization == null) - { - throw new NotFoundException(); - } - - if (model.Type == OrganizationApiKeyType.BillingSync || model.Type == OrganizationApiKeyType.Scim) - { - // Non-enterprise orgs should not be able to create or view an apikey of billing sync/scim key types - var plan = StaticStore.GetPlan(organization.PlanType); - if (plan.Product != ProductType.Enterprise) + var org = await _organizationRepository.GetByIdAsync(new Guid(id)); + if (org == null) { throw new NotFoundException(); } + + return new OrganizationKeysResponseModel(org); } - var organizationApiKey = await _getOrganizationApiKeyCommand - .GetOrganizationApiKeyAsync(organization.Id, model.Type); - - var user = await _userService.GetUserByPrincipalAsync(User); - if (user == null) + [HttpPost("{id}/keys")] + public async Task PostKeys(string id, [FromBody] OrganizationKeysRequestModel model) { - throw new UnauthorizedAccessException(); + var user = await _userService.GetUserByPrincipalAsync(User); + if (user == null) + { + throw new UnauthorizedAccessException(); + } + + var org = await _organizationService.UpdateOrganizationKeysAsync(new Guid(id), model.PublicKey, model.EncryptedPrivateKey); + return new OrganizationKeysResponseModel(org); } - if (model.Type != OrganizationApiKeyType.Scim - && !await _userService.VerifySecretAsync(user, model.Secret)) + [HttpGet("{id:guid}/sso")] + public async Task GetSso(Guid id) { - await Task.Delay(2000); - throw new BadRequestException("MasterPasswordHash", "Invalid password."); - } - else - { - var response = new ApiKeyResponseModel(organizationApiKey); - return response; - } - } + if (!await _currentContext.ManageSso(id)) + { + throw new NotFoundException(); + } - [HttpGet("{id}/api-key-information/{type?}")] - public async Task> ApiKeyInformation(Guid id, OrganizationApiKeyType? type) - { - if (!await HasApiKeyAccessAsync(id, type)) - { - throw new NotFoundException(); + var organization = await _organizationRepository.GetByIdAsync(id); + if (organization == null) + { + throw new NotFoundException(); + } + + var ssoConfig = await _ssoConfigRepository.GetByOrganizationIdAsync(id); + + return new OrganizationSsoResponseModel(organization, _globalSettings, ssoConfig); } - var apiKeys = await _organizationApiKeyRepository.GetManyByOrganizationIdTypeAsync(id, type); - - return new ListResponseModel( - apiKeys.Select(k => new OrganizationApiKeyInformation(k))); - } - - [HttpPost("{id}/rotate-api-key")] - public async Task RotateApiKey(string id, [FromBody] OrganizationApiKeyRequestModel model) - { - var orgIdGuid = new Guid(id); - if (!await HasApiKeyAccessAsync(orgIdGuid, model.Type)) + [HttpPost("{id:guid}/sso")] + public async Task PostSso(Guid id, [FromBody] OrganizationSsoRequestModel model) { - throw new NotFoundException(); + if (!await _currentContext.ManageSso(id)) + { + throw new NotFoundException(); + } + + var organization = await _organizationRepository.GetByIdAsync(id); + if (organization == null) + { + throw new NotFoundException(); + } + + var ssoConfig = await _ssoConfigRepository.GetByOrganizationIdAsync(id); + ssoConfig = ssoConfig == null ? model.ToSsoConfig(id) : model.ToSsoConfig(ssoConfig); + + await _ssoConfigService.SaveAsync(ssoConfig, organization); + + return new OrganizationSsoResponseModel(organization, _globalSettings, ssoConfig); } - - var organization = await _organizationRepository.GetByIdAsync(orgIdGuid); - if (organization == null) - { - throw new NotFoundException(); - } - - var organizationApiKey = await _getOrganizationApiKeyCommand - .GetOrganizationApiKeyAsync(organization.Id, model.Type); - - var user = await _userService.GetUserByPrincipalAsync(User); - if (user == null) - { - throw new UnauthorizedAccessException(); - } - - if (model.Type != OrganizationApiKeyType.Scim - && !await _userService.VerifySecretAsync(user, model.Secret)) - { - await Task.Delay(2000); - throw new BadRequestException("MasterPasswordHash", "Invalid password."); - } - else - { - await _rotateOrganizationApiKeyCommand.RotateApiKeyAsync(organizationApiKey); - var response = new ApiKeyResponseModel(organizationApiKey); - return response; - } - } - - private async Task HasApiKeyAccessAsync(Guid orgId, OrganizationApiKeyType? type) - { - return type switch - { - OrganizationApiKeyType.Scim => await _currentContext.ManageScim(orgId), - _ => await _currentContext.OrganizationOwner(orgId), - }; - } - - [HttpGet("{id}/tax")] - [SelfHosted(NotSelfHostedOnly = true)] - public async Task GetTaxInfo(string id) - { - var orgIdGuid = new Guid(id); - if (!await _currentContext.OrganizationOwner(orgIdGuid)) - { - throw new NotFoundException(); - } - - var organization = await _organizationRepository.GetByIdAsync(orgIdGuid); - if (organization == null) - { - throw new NotFoundException(); - } - - var taxInfo = await _paymentService.GetTaxInfoAsync(organization); - return new TaxInfoResponseModel(taxInfo); - } - - [HttpPut("{id}/tax")] - [SelfHosted(NotSelfHostedOnly = true)] - public async Task PutTaxInfo(string id, [FromBody] OrganizationTaxInfoUpdateRequestModel model) - { - var orgIdGuid = new Guid(id); - if (!await _currentContext.OrganizationOwner(orgIdGuid)) - { - throw new NotFoundException(); - } - - var organization = await _organizationRepository.GetByIdAsync(orgIdGuid); - if (organization == null) - { - throw new NotFoundException(); - } - - var taxInfo = new TaxInfo - { - TaxIdNumber = model.TaxId, - BillingAddressLine1 = model.Line1, - BillingAddressLine2 = model.Line2, - BillingAddressCity = model.City, - BillingAddressState = model.State, - BillingAddressPostalCode = model.PostalCode, - BillingAddressCountry = model.Country, - }; - await _paymentService.SaveTaxInfoAsync(organization, taxInfo); - } - - [HttpGet("{id}/keys")] - public async Task GetKeys(string id) - { - var org = await _organizationRepository.GetByIdAsync(new Guid(id)); - if (org == null) - { - throw new NotFoundException(); - } - - return new OrganizationKeysResponseModel(org); - } - - [HttpPost("{id}/keys")] - public async Task PostKeys(string id, [FromBody] OrganizationKeysRequestModel model) - { - var user = await _userService.GetUserByPrincipalAsync(User); - if (user == null) - { - throw new UnauthorizedAccessException(); - } - - var org = await _organizationService.UpdateOrganizationKeysAsync(new Guid(id), model.PublicKey, model.EncryptedPrivateKey); - return new OrganizationKeysResponseModel(org); - } - - [HttpGet("{id:guid}/sso")] - public async Task GetSso(Guid id) - { - if (!await _currentContext.ManageSso(id)) - { - throw new NotFoundException(); - } - - var organization = await _organizationRepository.GetByIdAsync(id); - if (organization == null) - { - throw new NotFoundException(); - } - - var ssoConfig = await _ssoConfigRepository.GetByOrganizationIdAsync(id); - - return new OrganizationSsoResponseModel(organization, _globalSettings, ssoConfig); - } - - [HttpPost("{id:guid}/sso")] - public async Task PostSso(Guid id, [FromBody] OrganizationSsoRequestModel model) - { - if (!await _currentContext.ManageSso(id)) - { - throw new NotFoundException(); - } - - var organization = await _organizationRepository.GetByIdAsync(id); - if (organization == null) - { - throw new NotFoundException(); - } - - var ssoConfig = await _ssoConfigRepository.GetByOrganizationIdAsync(id); - ssoConfig = ssoConfig == null ? model.ToSsoConfig(id) : model.ToSsoConfig(ssoConfig); - - await _ssoConfigService.SaveAsync(ssoConfig, organization); - - return new OrganizationSsoResponseModel(organization, _globalSettings, ssoConfig); } } diff --git a/src/Api/Controllers/PlansController.cs b/src/Api/Controllers/PlansController.cs index d738e60cfb..5f5d44c337 100644 --- a/src/Api/Controllers/PlansController.cs +++ b/src/Api/Controllers/PlansController.cs @@ -4,32 +4,33 @@ using Bit.Core.Utilities; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; -namespace Bit.Api.Controllers; - -[Route("plans")] -[Authorize("Web")] -public class PlansController : Controller +namespace Bit.Api.Controllers { - private readonly ITaxRateRepository _taxRateRepository; - public PlansController(ITaxRateRepository taxRateRepository) + [Route("plans")] + [Authorize("Web")] + public class PlansController : Controller { - _taxRateRepository = taxRateRepository; - } + private readonly ITaxRateRepository _taxRateRepository; + public PlansController(ITaxRateRepository taxRateRepository) + { + _taxRateRepository = taxRateRepository; + } - [HttpGet("")] - [AllowAnonymous] - public ListResponseModel Get() - { - var data = StaticStore.Plans; - var responses = data.Select(plan => new PlanResponseModel(plan)); - return new ListResponseModel(responses); - } + [HttpGet("")] + [AllowAnonymous] + public ListResponseModel Get() + { + var data = StaticStore.Plans; + var responses = data.Select(plan => new PlanResponseModel(plan)); + return new ListResponseModel(responses); + } - [HttpGet("sales-tax-rates")] - public async Task> GetTaxRates() - { - var data = await _taxRateRepository.GetAllActiveAsync(); - var responses = data.Select(x => new TaxRateResponseModel(x)); - return new ListResponseModel(responses); + [HttpGet("sales-tax-rates")] + public async Task> GetTaxRates() + { + var data = await _taxRateRepository.GetAllActiveAsync(); + var responses = data.Select(x => new TaxRateResponseModel(x)); + return new ListResponseModel(responses); + } } } diff --git a/src/Api/Controllers/PoliciesController.cs b/src/Api/Controllers/PoliciesController.cs index 175e1d6a8d..756b8a9d35 100644 --- a/src/Api/Controllers/PoliciesController.cs +++ b/src/Api/Controllers/PoliciesController.cs @@ -11,144 +11,145 @@ using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.DataProtection; using Microsoft.AspNetCore.Mvc; -namespace Bit.Api.Controllers; - -[Route("organizations/{orgId}/policies")] -[Authorize("Application")] -public class PoliciesController : Controller +namespace Bit.Api.Controllers { - private readonly IPolicyRepository _policyRepository; - private readonly IPolicyService _policyService; - private readonly IOrganizationService _organizationService; - private readonly IOrganizationUserRepository _organizationUserRepository; - private readonly IUserService _userService; - private readonly ICurrentContext _currentContext; - private readonly GlobalSettings _globalSettings; - private readonly IDataProtector _organizationServiceDataProtector; - - public PoliciesController( - IPolicyRepository policyRepository, - IPolicyService policyService, - IOrganizationService organizationService, - IOrganizationUserRepository organizationUserRepository, - IUserService userService, - ICurrentContext currentContext, - GlobalSettings globalSettings, - IDataProtectionProvider dataProtectionProvider) + [Route("organizations/{orgId}/policies")] + [Authorize("Application")] + public class PoliciesController : Controller { - _policyRepository = policyRepository; - _policyService = policyService; - _organizationService = organizationService; - _organizationUserRepository = organizationUserRepository; - _userService = userService; - _currentContext = currentContext; - _globalSettings = globalSettings; - _organizationServiceDataProtector = dataProtectionProvider.CreateProtector( - "OrganizationServiceDataProtector"); - } + private readonly IPolicyRepository _policyRepository; + private readonly IPolicyService _policyService; + private readonly IOrganizationService _organizationService; + private readonly IOrganizationUserRepository _organizationUserRepository; + private readonly IUserService _userService; + private readonly ICurrentContext _currentContext; + private readonly GlobalSettings _globalSettings; + private readonly IDataProtector _organizationServiceDataProtector; - [HttpGet("{type}")] - public async Task Get(string orgId, int type) - { - var orgIdGuid = new Guid(orgId); - if (!await _currentContext.ManagePolicies(orgIdGuid)) + public PoliciesController( + IPolicyRepository policyRepository, + IPolicyService policyService, + IOrganizationService organizationService, + IOrganizationUserRepository organizationUserRepository, + IUserService userService, + ICurrentContext currentContext, + GlobalSettings globalSettings, + IDataProtectionProvider dataProtectionProvider) { - throw new NotFoundException(); - } - var policy = await _policyRepository.GetByOrganizationIdTypeAsync(orgIdGuid, (PolicyType)type); - if (policy == null) - { - throw new NotFoundException(); + _policyRepository = policyRepository; + _policyService = policyService; + _organizationService = organizationService; + _organizationUserRepository = organizationUserRepository; + _userService = userService; + _currentContext = currentContext; + _globalSettings = globalSettings; + _organizationServiceDataProtector = dataProtectionProvider.CreateProtector( + "OrganizationServiceDataProtector"); } - return new PolicyResponseModel(policy); - } - - [HttpGet("")] - public async Task> Get(string orgId) - { - var orgIdGuid = new Guid(orgId); - if (!await _currentContext.ManagePolicies(orgIdGuid)) + [HttpGet("{type}")] + public async Task Get(string orgId, int type) { - throw new NotFoundException(); + var orgIdGuid = new Guid(orgId); + if (!await _currentContext.ManagePolicies(orgIdGuid)) + { + throw new NotFoundException(); + } + var policy = await _policyRepository.GetByOrganizationIdTypeAsync(orgIdGuid, (PolicyType)type); + if (policy == null) + { + throw new NotFoundException(); + } + + return new PolicyResponseModel(policy); } - var policies = await _policyRepository.GetManyByOrganizationIdAsync(orgIdGuid); - var responses = policies.Select(p => new PolicyResponseModel(p)); - return new ListResponseModel(responses); - } - - [AllowAnonymous] - [HttpGet("token")] - public async Task> GetByToken(string orgId, [FromQuery] string email, - [FromQuery] string token, [FromQuery] string organizationUserId) - { - var orgUserId = new Guid(organizationUserId); - var tokenValid = CoreHelpers.UserInviteTokenIsValid(_organizationServiceDataProtector, token, - email, orgUserId, _globalSettings); - if (!tokenValid) + [HttpGet("")] + public async Task> Get(string orgId) { - throw new NotFoundException(); + var orgIdGuid = new Guid(orgId); + if (!await _currentContext.ManagePolicies(orgIdGuid)) + { + throw new NotFoundException(); + } + + var policies = await _policyRepository.GetManyByOrganizationIdAsync(orgIdGuid); + var responses = policies.Select(p => new PolicyResponseModel(p)); + return new ListResponseModel(responses); } - var orgIdGuid = new Guid(orgId); - var orgUser = await _organizationUserRepository.GetByIdAsync(orgUserId); - if (orgUser == null || orgUser.OrganizationId != orgIdGuid) + [AllowAnonymous] + [HttpGet("token")] + public async Task> GetByToken(string orgId, [FromQuery] string email, + [FromQuery] string token, [FromQuery] string organizationUserId) { - throw new NotFoundException(); + var orgUserId = new Guid(organizationUserId); + var tokenValid = CoreHelpers.UserInviteTokenIsValid(_organizationServiceDataProtector, token, + email, orgUserId, _globalSettings); + if (!tokenValid) + { + throw new NotFoundException(); + } + + var orgIdGuid = new Guid(orgId); + var orgUser = await _organizationUserRepository.GetByIdAsync(orgUserId); + if (orgUser == null || orgUser.OrganizationId != orgIdGuid) + { + throw new NotFoundException(); + } + + var policies = await _policyRepository.GetManyByOrganizationIdAsync(orgIdGuid); + var responses = policies.Where(p => p.Enabled).Select(p => new PolicyResponseModel(p)); + return new ListResponseModel(responses); } - var policies = await _policyRepository.GetManyByOrganizationIdAsync(orgIdGuid); - var responses = policies.Where(p => p.Enabled).Select(p => new PolicyResponseModel(p)); - return new ListResponseModel(responses); - } + [AllowAnonymous] + [HttpGet("invited-user")] + public async Task> GetByInvitedUser(string orgId, [FromQuery] string userId) + { + var user = await _userService.GetUserByIdAsync(new Guid(userId)); + if (user == null) + { + throw new UnauthorizedAccessException(); + } + var orgIdGuid = new Guid(orgId); + var orgUsersByUserId = await _organizationUserRepository.GetManyByUserAsync(user.Id); + var orgUser = orgUsersByUserId.SingleOrDefault(u => u.OrganizationId == orgIdGuid); + if (orgUser == null) + { + throw new NotFoundException(); + } + if (orgUser.Status != OrganizationUserStatusType.Invited) + { + throw new UnauthorizedAccessException(); + } - [AllowAnonymous] - [HttpGet("invited-user")] - public async Task> GetByInvitedUser(string orgId, [FromQuery] string userId) - { - var user = await _userService.GetUserByIdAsync(new Guid(userId)); - if (user == null) - { - throw new UnauthorizedAccessException(); - } - var orgIdGuid = new Guid(orgId); - var orgUsersByUserId = await _organizationUserRepository.GetManyByUserAsync(user.Id); - var orgUser = orgUsersByUserId.SingleOrDefault(u => u.OrganizationId == orgIdGuid); - if (orgUser == null) - { - throw new NotFoundException(); - } - if (orgUser.Status != OrganizationUserStatusType.Invited) - { - throw new UnauthorizedAccessException(); + var policies = await _policyRepository.GetManyByOrganizationIdAsync(orgIdGuid); + var responses = policies.Where(p => p.Enabled).Select(p => new PolicyResponseModel(p)); + return new ListResponseModel(responses); } - var policies = await _policyRepository.GetManyByOrganizationIdAsync(orgIdGuid); - var responses = policies.Where(p => p.Enabled).Select(p => new PolicyResponseModel(p)); - return new ListResponseModel(responses); - } + [HttpPut("{type}")] + public async Task Put(string orgId, int type, [FromBody] PolicyRequestModel model) + { + var orgIdGuid = new Guid(orgId); + if (!await _currentContext.ManagePolicies(orgIdGuid)) + { + throw new NotFoundException(); + } + var policy = await _policyRepository.GetByOrganizationIdTypeAsync(new Guid(orgId), (PolicyType)type); + if (policy == null) + { + policy = model.ToPolicy(orgIdGuid); + } + else + { + policy = model.ToPolicy(policy); + } - [HttpPut("{type}")] - public async Task Put(string orgId, int type, [FromBody] PolicyRequestModel model) - { - var orgIdGuid = new Guid(orgId); - if (!await _currentContext.ManagePolicies(orgIdGuid)) - { - throw new NotFoundException(); + var userId = _userService.GetProperUserId(User); + await _policyService.SaveAsync(policy, _userService, _organizationService, userId); + return new PolicyResponseModel(policy); } - var policy = await _policyRepository.GetByOrganizationIdTypeAsync(new Guid(orgId), (PolicyType)type); - if (policy == null) - { - policy = model.ToPolicy(orgIdGuid); - } - else - { - policy = model.ToPolicy(policy); - } - - var userId = _userService.GetProperUserId(User); - await _policyService.SaveAsync(policy, _userService, _organizationService, userId); - return new PolicyResponseModel(policy); } } diff --git a/src/Api/Controllers/ProviderOrganizationsController.cs b/src/Api/Controllers/ProviderOrganizationsController.cs index 222d11302e..f4772fbe22 100644 --- a/src/Api/Controllers/ProviderOrganizationsController.cs +++ b/src/Api/Controllers/ProviderOrganizationsController.cs @@ -9,86 +9,87 @@ using Bit.Core.Utilities; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; -namespace Bit.Api.Controllers; - -[Route("providers/{providerId:guid}/organizations")] -[Authorize("Application")] -public class ProviderOrganizationsController : Controller +namespace Bit.Api.Controllers { - - private readonly IProviderOrganizationRepository _providerOrganizationRepository; - private readonly IProviderService _providerService; - private readonly IUserService _userService; - private readonly ICurrentContext _currentContext; - - public ProviderOrganizationsController( - IProviderOrganizationRepository providerOrganizationRepository, - IProviderService providerService, - IUserService userService, - ICurrentContext currentContext) + [Route("providers/{providerId:guid}/organizations")] + [Authorize("Application")] + public class ProviderOrganizationsController : Controller { - _providerOrganizationRepository = providerOrganizationRepository; - _providerService = providerService; - _userService = userService; - _currentContext = currentContext; - } - [HttpGet("")] - public async Task> Get(Guid providerId) - { - if (!_currentContext.AccessProviderOrganizations(providerId)) + private readonly IProviderOrganizationRepository _providerOrganizationRepository; + private readonly IProviderService _providerService; + private readonly IUserService _userService; + private readonly ICurrentContext _currentContext; + + public ProviderOrganizationsController( + IProviderOrganizationRepository providerOrganizationRepository, + IProviderService providerService, + IUserService userService, + ICurrentContext currentContext) { - throw new NotFoundException(); + _providerOrganizationRepository = providerOrganizationRepository; + _providerService = providerService; + _userService = userService; + _currentContext = currentContext; } - var providerOrganizations = await _providerOrganizationRepository.GetManyDetailsByProviderAsync(providerId); - var responses = providerOrganizations.Select(o => new ProviderOrganizationOrganizationDetailsResponseModel(o)); - return new ListResponseModel(responses); - } - - [HttpPost("add")] - public async Task Add(Guid providerId, [FromBody] ProviderOrganizationAddRequestModel model) - { - if (!_currentContext.ManageProviderOrganizations(providerId)) + [HttpGet("")] + public async Task> Get(Guid providerId) { - throw new NotFoundException(); + if (!_currentContext.AccessProviderOrganizations(providerId)) + { + throw new NotFoundException(); + } + + var providerOrganizations = await _providerOrganizationRepository.GetManyDetailsByProviderAsync(providerId); + var responses = providerOrganizations.Select(o => new ProviderOrganizationOrganizationDetailsResponseModel(o)); + return new ListResponseModel(responses); } - var userId = _userService.GetProperUserId(User).Value; - - await _providerService.AddOrganization(providerId, model.OrganizationId, userId, model.Key); - } - - [HttpPost("")] - [SelfHosted(NotSelfHostedOnly = true)] - public async Task Post(Guid providerId, [FromBody] ProviderOrganizationCreateRequestModel model) - { - var user = await _userService.GetUserByPrincipalAsync(User); - if (user == null) + [HttpPost("add")] + public async Task Add(Guid providerId, [FromBody] ProviderOrganizationAddRequestModel model) { - throw new UnauthorizedAccessException(); + if (!_currentContext.ManageProviderOrganizations(providerId)) + { + throw new NotFoundException(); + } + + var userId = _userService.GetProperUserId(User).Value; + + await _providerService.AddOrganization(providerId, model.OrganizationId, userId, model.Key); } - if (!_currentContext.ManageProviderOrganizations(providerId)) + [HttpPost("")] + [SelfHosted(NotSelfHostedOnly = true)] + public async Task Post(Guid providerId, [FromBody] ProviderOrganizationCreateRequestModel model) { - throw new NotFoundException(); + var user = await _userService.GetUserByPrincipalAsync(User); + if (user == null) + { + throw new UnauthorizedAccessException(); + } + + if (!_currentContext.ManageProviderOrganizations(providerId)) + { + throw new NotFoundException(); + } + + var organizationSignup = model.OrganizationCreateRequest.ToOrganizationSignup(user); + var result = await _providerService.CreateOrganizationAsync(providerId, organizationSignup, model.ClientOwnerEmail, user); + return new ProviderOrganizationResponseModel(result); } - var organizationSignup = model.OrganizationCreateRequest.ToOrganizationSignup(user); - var result = await _providerService.CreateOrganizationAsync(providerId, organizationSignup, model.ClientOwnerEmail, user); - return new ProviderOrganizationResponseModel(result); - } - - [HttpDelete("{id:guid}")] - [HttpPost("{id:guid}/delete")] - public async Task Delete(Guid providerId, Guid id) - { - if (!_currentContext.ManageProviderOrganizations(providerId)) + [HttpDelete("{id:guid}")] + [HttpPost("{id:guid}/delete")] + public async Task Delete(Guid providerId, Guid id) { - throw new NotFoundException(); - } + if (!_currentContext.ManageProviderOrganizations(providerId)) + { + throw new NotFoundException(); + } - var userId = _userService.GetProperUserId(User); - await _providerService.RemoveOrganizationAsync(providerId, id, userId.Value); + var userId = _userService.GetProperUserId(User); + await _providerService.RemoveOrganizationAsync(providerId, id, userId.Value); + } } } diff --git a/src/Api/Controllers/ProviderUsersController.cs b/src/Api/Controllers/ProviderUsersController.cs index f88394c0b6..ad9dec639f 100644 --- a/src/Api/Controllers/ProviderUsersController.cs +++ b/src/Api/Controllers/ProviderUsersController.cs @@ -9,191 +9,192 @@ using Bit.Core.Services; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; -namespace Bit.Api.Controllers; - -[Route("providers/{providerId:guid}/users")] -[Authorize("Application")] -public class ProviderUsersController : Controller +namespace Bit.Api.Controllers { - private readonly IProviderUserRepository _providerUserRepository; - private readonly IProviderService _providerService; - private readonly IUserService _userService; - private readonly ICurrentContext _currentContext; - - public ProviderUsersController( - IProviderUserRepository providerUserRepository, - IProviderService providerService, - IUserService userService, - ICurrentContext currentContext) + [Route("providers/{providerId:guid}/users")] + [Authorize("Application")] + public class ProviderUsersController : Controller { - _providerUserRepository = providerUserRepository; - _providerService = providerService; - _userService = userService; - _currentContext = currentContext; - } + private readonly IProviderUserRepository _providerUserRepository; + private readonly IProviderService _providerService; + private readonly IUserService _userService; + private readonly ICurrentContext _currentContext; - [HttpGet("{id:guid}")] - public async Task Get(Guid providerId, Guid id) - { - var providerUser = await _providerUserRepository.GetByIdAsync(id); - if (providerUser == null || !_currentContext.ProviderManageUsers(providerUser.ProviderId)) + public ProviderUsersController( + IProviderUserRepository providerUserRepository, + IProviderService providerService, + IUserService userService, + ICurrentContext currentContext) { - throw new NotFoundException(); + _providerUserRepository = providerUserRepository; + _providerService = providerService; + _userService = userService; + _currentContext = currentContext; } - return new ProviderUserResponseModel(providerUser); - } - - [HttpGet("")] - public async Task> Get(Guid providerId) - { - if (!_currentContext.ProviderManageUsers(providerId)) + [HttpGet("{id:guid}")] + public async Task Get(Guid providerId, Guid id) { - throw new NotFoundException(); + var providerUser = await _providerUserRepository.GetByIdAsync(id); + if (providerUser == null || !_currentContext.ProviderManageUsers(providerUser.ProviderId)) + { + throw new NotFoundException(); + } + + return new ProviderUserResponseModel(providerUser); } - var providerUsers = await _providerUserRepository.GetManyDetailsByProviderAsync(providerId); - var responses = providerUsers.Select(o => new ProviderUserUserDetailsResponseModel(o)); - return new ListResponseModel(responses); - } - - [HttpPost("invite")] - public async Task Invite(Guid providerId, [FromBody] ProviderUserInviteRequestModel model) - { - if (!_currentContext.ProviderManageUsers(providerId)) + [HttpGet("")] + public async Task> Get(Guid providerId) { - throw new NotFoundException(); + if (!_currentContext.ProviderManageUsers(providerId)) + { + throw new NotFoundException(); + } + + var providerUsers = await _providerUserRepository.GetManyDetailsByProviderAsync(providerId); + var responses = providerUsers.Select(o => new ProviderUserUserDetailsResponseModel(o)); + return new ListResponseModel(responses); } - var invite = ProviderUserInviteFactory.CreateIntialInvite(model.Emails, model.Type.Value, - _userService.GetProperUserId(User).Value, providerId); - await _providerService.InviteUserAsync(invite); - } - - [HttpPost("reinvite")] - public async Task> BulkReinvite(Guid providerId, [FromBody] ProviderUserBulkRequestModel model) - { - if (!_currentContext.ProviderManageUsers(providerId)) + [HttpPost("invite")] + public async Task Invite(Guid providerId, [FromBody] ProviderUserInviteRequestModel model) { - throw new NotFoundException(); + if (!_currentContext.ProviderManageUsers(providerId)) + { + throw new NotFoundException(); + } + + var invite = ProviderUserInviteFactory.CreateIntialInvite(model.Emails, model.Type.Value, + _userService.GetProperUserId(User).Value, providerId); + await _providerService.InviteUserAsync(invite); } - var invite = ProviderUserInviteFactory.CreateReinvite(model.Ids, _userService.GetProperUserId(User).Value, providerId); - var result = await _providerService.ResendInvitesAsync(invite); - return new ListResponseModel( - result.Select(t => new ProviderUserBulkResponseModel(t.Item1.Id, t.Item2))); - } - - [HttpPost("{id:guid}/reinvite")] - public async Task Reinvite(Guid providerId, Guid id) - { - if (!_currentContext.ProviderManageUsers(providerId)) + [HttpPost("reinvite")] + public async Task> BulkReinvite(Guid providerId, [FromBody] ProviderUserBulkRequestModel model) { - throw new NotFoundException(); + if (!_currentContext.ProviderManageUsers(providerId)) + { + throw new NotFoundException(); + } + + var invite = ProviderUserInviteFactory.CreateReinvite(model.Ids, _userService.GetProperUserId(User).Value, providerId); + var result = await _providerService.ResendInvitesAsync(invite); + return new ListResponseModel( + result.Select(t => new ProviderUserBulkResponseModel(t.Item1.Id, t.Item2))); } - var invite = ProviderUserInviteFactory.CreateReinvite(new[] { id }, - _userService.GetProperUserId(User).Value, providerId); - await _providerService.ResendInvitesAsync(invite); - } - - [HttpPost("{id:guid}/accept")] - public async Task Accept(Guid providerId, Guid id, [FromBody] ProviderUserAcceptRequestModel model) - { - var user = await _userService.GetUserByPrincipalAsync(User); - if (user == null) + [HttpPost("{id:guid}/reinvite")] + public async Task Reinvite(Guid providerId, Guid id) { - throw new UnauthorizedAccessException(); + if (!_currentContext.ProviderManageUsers(providerId)) + { + throw new NotFoundException(); + } + + var invite = ProviderUserInviteFactory.CreateReinvite(new[] { id }, + _userService.GetProperUserId(User).Value, providerId); + await _providerService.ResendInvitesAsync(invite); } - await _providerService.AcceptUserAsync(id, user, model.Token); - } - - [HttpPost("{id:guid}/confirm")] - public async Task Confirm(Guid providerId, Guid id, [FromBody] ProviderUserConfirmRequestModel model) - { - if (!_currentContext.ProviderManageUsers(providerId)) + [HttpPost("{id:guid}/accept")] + public async Task Accept(Guid providerId, Guid id, [FromBody] ProviderUserAcceptRequestModel model) { - throw new NotFoundException(); + var user = await _userService.GetUserByPrincipalAsync(User); + if (user == null) + { + throw new UnauthorizedAccessException(); + } + + await _providerService.AcceptUserAsync(id, user, model.Token); } - var userId = _userService.GetProperUserId(User); - await _providerService.ConfirmUsersAsync(providerId, new Dictionary { [id] = model.Key }, userId.Value); - } - - [HttpPost("confirm")] - public async Task> BulkConfirm(Guid providerId, - [FromBody] ProviderUserBulkConfirmRequestModel model) - { - if (!_currentContext.ProviderManageUsers(providerId)) + [HttpPost("{id:guid}/confirm")] + public async Task Confirm(Guid providerId, Guid id, [FromBody] ProviderUserConfirmRequestModel model) { - throw new NotFoundException(); + if (!_currentContext.ProviderManageUsers(providerId)) + { + throw new NotFoundException(); + } + + var userId = _userService.GetProperUserId(User); + await _providerService.ConfirmUsersAsync(providerId, new Dictionary { [id] = model.Key }, userId.Value); } - var userId = _userService.GetProperUserId(User); - var results = await _providerService.ConfirmUsersAsync(providerId, model.ToDictionary(), userId.Value); - - return new ListResponseModel(results.Select(r => - new ProviderUserBulkResponseModel(r.Item1.Id, r.Item2))); - } - - [HttpPost("public-keys")] - public async Task> UserPublicKeys(Guid providerId, [FromBody] ProviderUserBulkRequestModel model) - { - if (!_currentContext.ProviderManageUsers(providerId)) + [HttpPost("confirm")] + public async Task> BulkConfirm(Guid providerId, + [FromBody] ProviderUserBulkConfirmRequestModel model) { - throw new NotFoundException(); + if (!_currentContext.ProviderManageUsers(providerId)) + { + throw new NotFoundException(); + } + + var userId = _userService.GetProperUserId(User); + var results = await _providerService.ConfirmUsersAsync(providerId, model.ToDictionary(), userId.Value); + + return new ListResponseModel(results.Select(r => + new ProviderUserBulkResponseModel(r.Item1.Id, r.Item2))); } - var result = await _providerUserRepository.GetManyPublicKeysByProviderUserAsync(providerId, model.Ids); - var responses = result.Select(r => new ProviderUserPublicKeyResponseModel(r.Id, r.UserId, r.PublicKey)).ToList(); - return new ListResponseModel(responses); - } - - [HttpPut("{id:guid}")] - [HttpPost("{id:guid}")] - public async Task Put(Guid providerId, Guid id, [FromBody] ProviderUserUpdateRequestModel model) - { - if (!_currentContext.ProviderManageUsers(providerId)) + [HttpPost("public-keys")] + public async Task> UserPublicKeys(Guid providerId, [FromBody] ProviderUserBulkRequestModel model) { - throw new NotFoundException(); + if (!_currentContext.ProviderManageUsers(providerId)) + { + throw new NotFoundException(); + } + + var result = await _providerUserRepository.GetManyPublicKeysByProviderUserAsync(providerId, model.Ids); + var responses = result.Select(r => new ProviderUserPublicKeyResponseModel(r.Id, r.UserId, r.PublicKey)).ToList(); + return new ListResponseModel(responses); } - var providerUser = await _providerUserRepository.GetByIdAsync(id); - if (providerUser == null || providerUser.ProviderId != providerId) + [HttpPut("{id:guid}")] + [HttpPost("{id:guid}")] + public async Task Put(Guid providerId, Guid id, [FromBody] ProviderUserUpdateRequestModel model) { - throw new NotFoundException(); + if (!_currentContext.ProviderManageUsers(providerId)) + { + throw new NotFoundException(); + } + + var providerUser = await _providerUserRepository.GetByIdAsync(id); + if (providerUser == null || providerUser.ProviderId != providerId) + { + throw new NotFoundException(); + } + + var userId = _userService.GetProperUserId(User); + await _providerService.SaveUserAsync(model.ToProviderUser(providerUser), userId.Value); } - var userId = _userService.GetProperUserId(User); - await _providerService.SaveUserAsync(model.ToProviderUser(providerUser), userId.Value); - } - - [HttpDelete("{id:guid}")] - [HttpPost("{id:guid}/delete")] - public async Task Delete(Guid providerId, Guid id) - { - if (!_currentContext.ProviderManageUsers(providerId)) + [HttpDelete("{id:guid}")] + [HttpPost("{id:guid}/delete")] + public async Task Delete(Guid providerId, Guid id) { - throw new NotFoundException(); + if (!_currentContext.ProviderManageUsers(providerId)) + { + throw new NotFoundException(); + } + + var userId = _userService.GetProperUserId(User); + await _providerService.DeleteUsersAsync(providerId, new[] { id }, userId.Value); } - var userId = _userService.GetProperUserId(User); - await _providerService.DeleteUsersAsync(providerId, new[] { id }, userId.Value); - } - - [HttpDelete("")] - [HttpPost("delete")] - public async Task> BulkDelete(Guid providerId, [FromBody] ProviderUserBulkRequestModel model) - { - if (!_currentContext.ProviderManageUsers(providerId)) + [HttpDelete("")] + [HttpPost("delete")] + public async Task> BulkDelete(Guid providerId, [FromBody] ProviderUserBulkRequestModel model) { - throw new NotFoundException(); - } + if (!_currentContext.ProviderManageUsers(providerId)) + { + throw new NotFoundException(); + } - var userId = _userService.GetProperUserId(User); - var result = await _providerService.DeleteUsersAsync(providerId, model.Ids, userId.Value); - return new ListResponseModel(result.Select(r => - new ProviderUserBulkResponseModel(r.Item1.Id, r.Item2))); + var userId = _userService.GetProperUserId(User); + var result = await _providerService.DeleteUsersAsync(providerId, model.Ids, userId.Value); + return new ListResponseModel(result.Select(r => + new ProviderUserBulkResponseModel(r.Item1.Id, r.Item2))); + } } } diff --git a/src/Api/Controllers/ProvidersController.cs b/src/Api/Controllers/ProvidersController.cs index 5daf9ce491..5969c0c6f0 100644 --- a/src/Api/Controllers/ProvidersController.cs +++ b/src/Api/Controllers/ProvidersController.cs @@ -8,83 +8,84 @@ using Bit.Core.Settings; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; -namespace Bit.Api.Controllers; - -[Route("providers")] -[Authorize("Application")] -public class ProvidersController : Controller +namespace Bit.Api.Controllers { - private readonly IUserService _userService; - private readonly IProviderRepository _providerRepository; - private readonly IProviderService _providerService; - private readonly ICurrentContext _currentContext; - private readonly GlobalSettings _globalSettings; - - public ProvidersController(IUserService userService, IProviderRepository providerRepository, - IProviderService providerService, ICurrentContext currentContext, GlobalSettings globalSettings) + [Route("providers")] + [Authorize("Application")] + public class ProvidersController : Controller { - _userService = userService; - _providerRepository = providerRepository; - _providerService = providerService; - _currentContext = currentContext; - _globalSettings = globalSettings; - } + private readonly IUserService _userService; + private readonly IProviderRepository _providerRepository; + private readonly IProviderService _providerService; + private readonly ICurrentContext _currentContext; + private readonly GlobalSettings _globalSettings; - [HttpGet("{id:guid}")] - public async Task Get(Guid id) - { - if (!_currentContext.ProviderUser(id)) + public ProvidersController(IUserService userService, IProviderRepository providerRepository, + IProviderService providerService, ICurrentContext currentContext, GlobalSettings globalSettings) { - throw new NotFoundException(); + _userService = userService; + _providerRepository = providerRepository; + _providerService = providerService; + _currentContext = currentContext; + _globalSettings = globalSettings; } - var provider = await _providerRepository.GetByIdAsync(id); - if (provider == null) + [HttpGet("{id:guid}")] + public async Task Get(Guid id) { - throw new NotFoundException(); + if (!_currentContext.ProviderUser(id)) + { + throw new NotFoundException(); + } + + var provider = await _providerRepository.GetByIdAsync(id); + if (provider == null) + { + throw new NotFoundException(); + } + + return new ProviderResponseModel(provider); } - return new ProviderResponseModel(provider); - } - - [HttpPut("{id:guid}")] - [HttpPost("{id:guid}")] - public async Task Put(Guid id, [FromBody] ProviderUpdateRequestModel model) - { - if (!_currentContext.ProviderProviderAdmin(id)) + [HttpPut("{id:guid}")] + [HttpPost("{id:guid}")] + public async Task Put(Guid id, [FromBody] ProviderUpdateRequestModel model) { - throw new NotFoundException(); + if (!_currentContext.ProviderProviderAdmin(id)) + { + throw new NotFoundException(); + } + + var provider = await _providerRepository.GetByIdAsync(id); + if (provider == null) + { + throw new NotFoundException(); + } + + await _providerService.UpdateAsync(model.ToProvider(provider, _globalSettings)); + return new ProviderResponseModel(provider); } - var provider = await _providerRepository.GetByIdAsync(id); - if (provider == null) + [HttpPost("{id:guid}/setup")] + public async Task Setup(Guid id, [FromBody] ProviderSetupRequestModel model) { - throw new NotFoundException(); + if (!_currentContext.ProviderProviderAdmin(id)) + { + throw new NotFoundException(); + } + + var provider = await _providerRepository.GetByIdAsync(id); + if (provider == null) + { + throw new NotFoundException(); + } + + var userId = _userService.GetProperUserId(User).Value; + + var response = + await _providerService.CompleteSetupAsync(model.ToProvider(provider), userId, model.Token, model.Key); + + return new ProviderResponseModel(response); } - - await _providerService.UpdateAsync(model.ToProvider(provider, _globalSettings)); - return new ProviderResponseModel(provider); - } - - [HttpPost("{id:guid}/setup")] - public async Task Setup(Guid id, [FromBody] ProviderSetupRequestModel model) - { - if (!_currentContext.ProviderProviderAdmin(id)) - { - throw new NotFoundException(); - } - - var provider = await _providerRepository.GetByIdAsync(id); - if (provider == null) - { - throw new NotFoundException(); - } - - var userId = _userService.GetProperUserId(User).Value; - - var response = - await _providerService.CompleteSetupAsync(model.ToProvider(provider), userId, model.Token, model.Key); - - return new ProviderResponseModel(response); } } diff --git a/src/Api/Controllers/PushController.cs b/src/Api/Controllers/PushController.cs index 7312cb7b85..afeaf92f7a 100644 --- a/src/Api/Controllers/PushController.cs +++ b/src/Api/Controllers/PushController.cs @@ -7,108 +7,109 @@ using Bit.Core.Utilities; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; -namespace Bit.Api.Controllers; - -[Route("push")] -[Authorize("Push")] -[SelfHosted(NotSelfHostedOnly = true)] -public class PushController : Controller +namespace Bit.Api.Controllers { - private readonly IPushRegistrationService _pushRegistrationService; - private readonly IPushNotificationService _pushNotificationService; - private readonly IWebHostEnvironment _environment; - private readonly ICurrentContext _currentContext; - private readonly GlobalSettings _globalSettings; - - public PushController( - IPushRegistrationService pushRegistrationService, - IPushNotificationService pushNotificationService, - IWebHostEnvironment environment, - ICurrentContext currentContext, - GlobalSettings globalSettings) + [Route("push")] + [Authorize("Push")] + [SelfHosted(NotSelfHostedOnly = true)] + public class PushController : Controller { - _currentContext = currentContext; - _environment = environment; - _pushRegistrationService = pushRegistrationService; - _pushNotificationService = pushNotificationService; - _globalSettings = globalSettings; - } + private readonly IPushRegistrationService _pushRegistrationService; + private readonly IPushNotificationService _pushNotificationService; + private readonly IWebHostEnvironment _environment; + private readonly ICurrentContext _currentContext; + private readonly GlobalSettings _globalSettings; - [HttpPost("register")] - public async Task PostRegister([FromBody] PushRegistrationRequestModel model) - { - CheckUsage(); - await _pushRegistrationService.CreateOrUpdateRegistrationAsync(model.PushToken, Prefix(model.DeviceId), - Prefix(model.UserId), Prefix(model.Identifier), model.Type); - } - - [HttpDelete("{id}")] - public async Task Delete(string id) - { - CheckUsage(); - await _pushRegistrationService.DeleteRegistrationAsync(Prefix(id)); - } - - [HttpPut("add-organization")] - public async Task PutAddOrganization([FromBody] PushUpdateRequestModel model) - { - CheckUsage(); - await _pushRegistrationService.AddUserRegistrationOrganizationAsync( - model.DeviceIds.Select(d => Prefix(d)), Prefix(model.OrganizationId)); - } - - [HttpPut("delete-organization")] - public async Task PutDeleteOrganization([FromBody] PushUpdateRequestModel model) - { - CheckUsage(); - await _pushRegistrationService.DeleteUserRegistrationOrganizationAsync( - model.DeviceIds.Select(d => Prefix(d)), Prefix(model.OrganizationId)); - } - - [HttpPost("send")] - public async Task PostSend([FromBody] PushSendRequestModel model) - { - CheckUsage(); - - if (!string.IsNullOrWhiteSpace(model.UserId)) + public PushController( + IPushRegistrationService pushRegistrationService, + IPushNotificationService pushNotificationService, + IWebHostEnvironment environment, + ICurrentContext currentContext, + GlobalSettings globalSettings) { - await _pushNotificationService.SendPayloadToUserAsync(Prefix(model.UserId), - model.Type.Value, model.Payload, Prefix(model.Identifier), Prefix(model.DeviceId)); - } - else if (!string.IsNullOrWhiteSpace(model.OrganizationId)) - { - await _pushNotificationService.SendPayloadToOrganizationAsync(Prefix(model.OrganizationId), - model.Type.Value, model.Payload, Prefix(model.Identifier), Prefix(model.DeviceId)); - } - } - - private string Prefix(string value) - { - if (string.IsNullOrWhiteSpace(value)) - { - return null; + _currentContext = currentContext; + _environment = environment; + _pushRegistrationService = pushRegistrationService; + _pushNotificationService = pushNotificationService; + _globalSettings = globalSettings; } - return $"{_currentContext.InstallationId.Value}_{value}"; - } - - private void CheckUsage() - { - if (CanUse()) + [HttpPost("register")] + public async Task PostRegister([FromBody] PushRegistrationRequestModel model) { - return; + CheckUsage(); + await _pushRegistrationService.CreateOrUpdateRegistrationAsync(model.PushToken, Prefix(model.DeviceId), + Prefix(model.UserId), Prefix(model.Identifier), model.Type); } - throw new BadRequestException("Not correctly configured for push relays."); - } - - private bool CanUse() - { - if (_environment.IsDevelopment()) + [HttpDelete("{id}")] + public async Task Delete(string id) { - return true; + CheckUsage(); + await _pushRegistrationService.DeleteRegistrationAsync(Prefix(id)); } - return _currentContext.InstallationId.HasValue && !_globalSettings.SelfHosted; + [HttpPut("add-organization")] + public async Task PutAddOrganization([FromBody] PushUpdateRequestModel model) + { + CheckUsage(); + await _pushRegistrationService.AddUserRegistrationOrganizationAsync( + model.DeviceIds.Select(d => Prefix(d)), Prefix(model.OrganizationId)); + } + + [HttpPut("delete-organization")] + public async Task PutDeleteOrganization([FromBody] PushUpdateRequestModel model) + { + CheckUsage(); + await _pushRegistrationService.DeleteUserRegistrationOrganizationAsync( + model.DeviceIds.Select(d => Prefix(d)), Prefix(model.OrganizationId)); + } + + [HttpPost("send")] + public async Task PostSend([FromBody] PushSendRequestModel model) + { + CheckUsage(); + + if (!string.IsNullOrWhiteSpace(model.UserId)) + { + await _pushNotificationService.SendPayloadToUserAsync(Prefix(model.UserId), + model.Type.Value, model.Payload, Prefix(model.Identifier), Prefix(model.DeviceId)); + } + else if (!string.IsNullOrWhiteSpace(model.OrganizationId)) + { + await _pushNotificationService.SendPayloadToOrganizationAsync(Prefix(model.OrganizationId), + model.Type.Value, model.Payload, Prefix(model.Identifier), Prefix(model.DeviceId)); + } + } + + private string Prefix(string value) + { + if (string.IsNullOrWhiteSpace(value)) + { + return null; + } + + return $"{_currentContext.InstallationId.Value}_{value}"; + } + + private void CheckUsage() + { + if (CanUse()) + { + return; + } + + throw new BadRequestException("Not correctly configured for push relays."); + } + + private bool CanUse() + { + if (_environment.IsDevelopment()) + { + return true; + } + + return _currentContext.InstallationId.HasValue && !_globalSettings.SelfHosted; + } } } diff --git a/src/Api/Controllers/SelfHosted/SelfHostedOrganizationSponsorshipsController.cs b/src/Api/Controllers/SelfHosted/SelfHostedOrganizationSponsorshipsController.cs index ffb5c7bb98..b741929836 100644 --- a/src/Api/Controllers/SelfHosted/SelfHostedOrganizationSponsorshipsController.cs +++ b/src/Api/Controllers/SelfHosted/SelfHostedOrganizationSponsorshipsController.cs @@ -7,60 +7,61 @@ using Bit.Core.Utilities; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; -namespace Bit.Api.Controllers.SelfHosted; - -[Route("organization/sponsorship/self-hosted")] -[Authorize("Application")] -[SelfHosted(SelfHostedOnly = true)] -public class SelfHostedOrganizationSponsorshipsController : Controller +namespace Bit.Api.Controllers.SelfHosted { - private readonly IOrganizationRepository _organizationRepository; - private readonly IOrganizationUserRepository _organizationUserRepository; - private readonly IOrganizationSponsorshipRepository _organizationSponsorshipRepository; - private readonly ICreateSponsorshipCommand _offerSponsorshipCommand; - private readonly IRevokeSponsorshipCommand _revokeSponsorshipCommand; - private readonly ICurrentContext _currentContext; - - public SelfHostedOrganizationSponsorshipsController( - ICreateSponsorshipCommand offerSponsorshipCommand, - IRevokeSponsorshipCommand revokeSponsorshipCommand, - IOrganizationRepository organizationRepository, - IOrganizationSponsorshipRepository organizationSponsorshipRepository, - IOrganizationUserRepository organizationUserRepository, - ICurrentContext currentContext - ) + [Route("organization/sponsorship/self-hosted")] + [Authorize("Application")] + [SelfHosted(SelfHostedOnly = true)] + public class SelfHostedOrganizationSponsorshipsController : Controller { - _offerSponsorshipCommand = offerSponsorshipCommand; - _revokeSponsorshipCommand = revokeSponsorshipCommand; - _organizationRepository = organizationRepository; - _organizationSponsorshipRepository = organizationSponsorshipRepository; - _organizationUserRepository = organizationUserRepository; - _currentContext = currentContext; - } + private readonly IOrganizationRepository _organizationRepository; + private readonly IOrganizationUserRepository _organizationUserRepository; + private readonly IOrganizationSponsorshipRepository _organizationSponsorshipRepository; + private readonly ICreateSponsorshipCommand _offerSponsorshipCommand; + private readonly IRevokeSponsorshipCommand _revokeSponsorshipCommand; + private readonly ICurrentContext _currentContext; - [HttpPost("{sponsoringOrgId}/families-for-enterprise")] - public async Task CreateSponsorship(Guid sponsoringOrgId, [FromBody] OrganizationSponsorshipCreateRequestModel model) - { - await _offerSponsorshipCommand.CreateSponsorshipAsync( - await _organizationRepository.GetByIdAsync(sponsoringOrgId), - await _organizationUserRepository.GetByOrganizationAsync(sponsoringOrgId, _currentContext.UserId ?? default), - model.PlanSponsorshipType, model.SponsoredEmail, model.FriendlyName); - } - - [HttpDelete("{sponsoringOrgId}")] - [HttpPost("{sponsoringOrgId}/delete")] - public async Task RevokeSponsorship(Guid sponsoringOrgId) - { - var orgUser = await _organizationUserRepository.GetByOrganizationAsync(sponsoringOrgId, _currentContext.UserId ?? default); - - if (orgUser == null) + public SelfHostedOrganizationSponsorshipsController( + ICreateSponsorshipCommand offerSponsorshipCommand, + IRevokeSponsorshipCommand revokeSponsorshipCommand, + IOrganizationRepository organizationRepository, + IOrganizationSponsorshipRepository organizationSponsorshipRepository, + IOrganizationUserRepository organizationUserRepository, + ICurrentContext currentContext + ) { - throw new BadRequestException("Unknown Organization User"); + _offerSponsorshipCommand = offerSponsorshipCommand; + _revokeSponsorshipCommand = revokeSponsorshipCommand; + _organizationRepository = organizationRepository; + _organizationSponsorshipRepository = organizationSponsorshipRepository; + _organizationUserRepository = organizationUserRepository; + _currentContext = currentContext; } - var existingOrgSponsorship = await _organizationSponsorshipRepository - .GetBySponsoringOrganizationUserIdAsync(orgUser.Id); + [HttpPost("{sponsoringOrgId}/families-for-enterprise")] + public async Task CreateSponsorship(Guid sponsoringOrgId, [FromBody] OrganizationSponsorshipCreateRequestModel model) + { + await _offerSponsorshipCommand.CreateSponsorshipAsync( + await _organizationRepository.GetByIdAsync(sponsoringOrgId), + await _organizationUserRepository.GetByOrganizationAsync(sponsoringOrgId, _currentContext.UserId ?? default), + model.PlanSponsorshipType, model.SponsoredEmail, model.FriendlyName); + } - await _revokeSponsorshipCommand.RevokeSponsorshipAsync(existingOrgSponsorship); + [HttpDelete("{sponsoringOrgId}")] + [HttpPost("{sponsoringOrgId}/delete")] + public async Task RevokeSponsorship(Guid sponsoringOrgId) + { + var orgUser = await _organizationUserRepository.GetByOrganizationAsync(sponsoringOrgId, _currentContext.UserId ?? default); + + if (orgUser == null) + { + throw new BadRequestException("Unknown Organization User"); + } + + var existingOrgSponsorship = await _organizationSponsorshipRepository + .GetBySponsoringOrganizationUserIdAsync(orgUser.Id); + + await _revokeSponsorshipCommand.RevokeSponsorshipAsync(existingOrgSponsorship); + } } } diff --git a/src/Api/Controllers/SendsController.cs b/src/Api/Controllers/SendsController.cs index 5f1d7527ac..405f5c6593 100644 --- a/src/Api/Controllers/SendsController.cs +++ b/src/Api/Controllers/SendsController.cs @@ -16,322 +16,323 @@ using Bit.Core.Utilities; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; -namespace Bit.Api.Controllers; - -[Route("sends")] -[Authorize("Application")] -public class SendsController : Controller +namespace Bit.Api.Controllers { - private readonly ISendRepository _sendRepository; - private readonly IUserService _userService; - private readonly ISendService _sendService; - private readonly ISendFileStorageService _sendFileStorageService; - private readonly ILogger _logger; - private readonly GlobalSettings _globalSettings; - private readonly ICurrentContext _currentContext; - - public SendsController( - ISendRepository sendRepository, - IUserService userService, - ISendService sendService, - ISendFileStorageService sendFileStorageService, - ILogger logger, - GlobalSettings globalSettings, - ICurrentContext currentContext) + [Route("sends")] + [Authorize("Application")] + public class SendsController : Controller { - _sendRepository = sendRepository; - _userService = userService; - _sendService = sendService; - _sendFileStorageService = sendFileStorageService; - _logger = logger; - _globalSettings = globalSettings; - _currentContext = currentContext; - } + private readonly ISendRepository _sendRepository; + private readonly IUserService _userService; + private readonly ISendService _sendService; + private readonly ISendFileStorageService _sendFileStorageService; + private readonly ILogger _logger; + private readonly GlobalSettings _globalSettings; + private readonly ICurrentContext _currentContext; - [AllowAnonymous] - [HttpPost("access/{id}")] - public async Task Access(string id, [FromBody] SendAccessRequestModel model) - { - // Uncomment whenever we want to require the `send-id` header - //if (!_currentContext.HttpContext.Request.Headers.ContainsKey("Send-Id") || - // _currentContext.HttpContext.Request.Headers["Send-Id"] != id) - //{ - // throw new BadRequestException("Invalid Send-Id header."); - //} - - var guid = new Guid(CoreHelpers.Base64UrlDecode(id)); - var (send, passwordRequired, passwordInvalid) = - await _sendService.AccessAsync(guid, model.Password); - if (passwordRequired) + public SendsController( + ISendRepository sendRepository, + IUserService userService, + ISendService sendService, + ISendFileStorageService sendFileStorageService, + ILogger logger, + GlobalSettings globalSettings, + ICurrentContext currentContext) { - return new UnauthorizedResult(); - } - if (passwordInvalid) - { - await Task.Delay(2000); - throw new BadRequestException("Invalid password."); - } - if (send == null) - { - throw new NotFoundException(); + _sendRepository = sendRepository; + _userService = userService; + _sendService = sendService; + _sendFileStorageService = sendFileStorageService; + _logger = logger; + _globalSettings = globalSettings; + _currentContext = currentContext; } - var sendResponse = new SendAccessResponseModel(send, _globalSettings); - if (send.UserId.HasValue && !send.HideEmail.GetValueOrDefault()) + [AllowAnonymous] + [HttpPost("access/{id}")] + public async Task Access(string id, [FromBody] SendAccessRequestModel model) { - var creator = await _userService.GetUserByIdAsync(send.UserId.Value); - sendResponse.CreatorIdentifier = creator.Email; - } - return new ObjectResult(sendResponse); - } + // Uncomment whenever we want to require the `send-id` header + //if (!_currentContext.HttpContext.Request.Headers.ContainsKey("Send-Id") || + // _currentContext.HttpContext.Request.Headers["Send-Id"] != id) + //{ + // throw new BadRequestException("Invalid Send-Id header."); + //} - [AllowAnonymous] - [HttpPost("{encodedSendId}/access/file/{fileId}")] - public async Task GetSendFileDownloadData(string encodedSendId, - string fileId, [FromBody] SendAccessRequestModel model) - { - // Uncomment whenever we want to require the `send-id` header - //if (!_currentContext.HttpContext.Request.Headers.ContainsKey("Send-Id") || - // _currentContext.HttpContext.Request.Headers["Send-Id"] != encodedSendId) - //{ - // throw new BadRequestException("Invalid Send-Id header."); - //} + var guid = new Guid(CoreHelpers.Base64UrlDecode(id)); + var (send, passwordRequired, passwordInvalid) = + await _sendService.AccessAsync(guid, model.Password); + if (passwordRequired) + { + return new UnauthorizedResult(); + } + if (passwordInvalid) + { + await Task.Delay(2000); + throw new BadRequestException("Invalid password."); + } + if (send == null) + { + throw new NotFoundException(); + } - var sendId = new Guid(CoreHelpers.Base64UrlDecode(encodedSendId)); - var send = await _sendRepository.GetByIdAsync(sendId); - - if (send == null) - { - throw new BadRequestException("Could not locate send"); + var sendResponse = new SendAccessResponseModel(send, _globalSettings); + if (send.UserId.HasValue && !send.HideEmail.GetValueOrDefault()) + { + var creator = await _userService.GetUserByIdAsync(send.UserId.Value); + sendResponse.CreatorIdentifier = creator.Email; + } + return new ObjectResult(sendResponse); } - var (url, passwordRequired, passwordInvalid) = await _sendService.GetSendFileDownloadUrlAsync(send, fileId, - model.Password); + [AllowAnonymous] + [HttpPost("{encodedSendId}/access/file/{fileId}")] + public async Task GetSendFileDownloadData(string encodedSendId, + string fileId, [FromBody] SendAccessRequestModel model) + { + // Uncomment whenever we want to require the `send-id` header + //if (!_currentContext.HttpContext.Request.Headers.ContainsKey("Send-Id") || + // _currentContext.HttpContext.Request.Headers["Send-Id"] != encodedSendId) + //{ + // throw new BadRequestException("Invalid Send-Id header."); + //} - if (passwordRequired) - { - return new UnauthorizedResult(); - } - if (passwordInvalid) - { - await Task.Delay(2000); - throw new BadRequestException("Invalid password."); - } - if (send == null) - { - throw new NotFoundException(); + var sendId = new Guid(CoreHelpers.Base64UrlDecode(encodedSendId)); + var send = await _sendRepository.GetByIdAsync(sendId); + + if (send == null) + { + throw new BadRequestException("Could not locate send"); + } + + var (url, passwordRequired, passwordInvalid) = await _sendService.GetSendFileDownloadUrlAsync(send, fileId, + model.Password); + + if (passwordRequired) + { + return new UnauthorizedResult(); + } + if (passwordInvalid) + { + await Task.Delay(2000); + throw new BadRequestException("Invalid password."); + } + if (send == null) + { + throw new NotFoundException(); + } + + return new ObjectResult(new SendFileDownloadDataResponseModel() + { + Id = fileId, + Url = url, + }); } - return new ObjectResult(new SendFileDownloadDataResponseModel() + [HttpGet("{id}")] + public async Task Get(string id) { - Id = fileId, - Url = url, - }); - } + var userId = _userService.GetProperUserId(User).Value; + var send = await _sendRepository.GetByIdAsync(new Guid(id)); + if (send == null || send.UserId != userId) + { + throw new NotFoundException(); + } - [HttpGet("{id}")] - public async Task Get(string id) - { - var userId = _userService.GetProperUserId(User).Value; - var send = await _sendRepository.GetByIdAsync(new Guid(id)); - if (send == null || send.UserId != userId) - { - throw new NotFoundException(); + return new SendResponseModel(send, _globalSettings); } - return new SendResponseModel(send, _globalSettings); - } - - [HttpGet("")] - public async Task> Get() - { - var userId = _userService.GetProperUserId(User).Value; - var sends = await _sendRepository.GetManyByUserIdAsync(userId); - var responses = sends.Select(s => new SendResponseModel(s, _globalSettings)); - return new ListResponseModel(responses); - } - - [HttpPost("")] - public async Task Post([FromBody] SendRequestModel model) - { - model.ValidateCreation(); - var userId = _userService.GetProperUserId(User).Value; - var send = model.ToSend(userId, _sendService); - await _sendService.SaveSendAsync(send); - return new SendResponseModel(send, _globalSettings); - } - - [HttpPost("file")] - [Obsolete("Deprecated File Send API", false)] - [RequestSizeLimit(Constants.FileSize101mb)] - [DisableFormValueModelBinding] - public async Task PostFile() - { - if (!Request?.ContentType.Contains("multipart/") ?? true) + [HttpGet("")] + public async Task> Get() { - throw new BadRequestException("Invalid content."); + var userId = _userService.GetProperUserId(User).Value; + var sends = await _sendRepository.GetManyByUserIdAsync(userId); + var responses = sends.Select(s => new SendResponseModel(s, _globalSettings)); + return new ListResponseModel(responses); } - Send send = null; - await Request.GetSendFileAsync(async (stream, fileName, model) => + [HttpPost("")] + public async Task Post([FromBody] SendRequestModel model) { model.ValidateCreation(); var userId = _userService.GetProperUserId(User).Value; - var (madeSend, madeData) = model.ToSend(userId, fileName, _sendService); - send = madeSend; - await _sendService.SaveFileSendAsync(send, madeData, model.FileLength.GetValueOrDefault(0)); - await _sendService.UploadFileToExistingSendAsync(stream, send); - }); - - return new SendResponseModel(send, _globalSettings); - } - - - [HttpPost("file/v2")] - public async Task PostFile([FromBody] SendRequestModel model) - { - if (model.Type != SendType.File) - { - throw new BadRequestException("Invalid content."); + var send = model.ToSend(userId, _sendService); + await _sendService.SaveSendAsync(send); + return new SendResponseModel(send, _globalSettings); } - if (!model.FileLength.HasValue) - { - throw new BadRequestException("Invalid content. File size hint is required."); - } - - if (model.FileLength.Value > SendService.MAX_FILE_SIZE) - { - throw new BadRequestException($"Max file size is {SendService.MAX_FILE_SIZE_READABLE}."); - } - - var userId = _userService.GetProperUserId(User).Value; - var (send, data) = model.ToSend(userId, model.File.FileName, _sendService); - var uploadUrl = await _sendService.SaveFileSendAsync(send, data, model.FileLength.Value); - return new SendFileUploadDataResponseModel - { - Url = uploadUrl, - FileUploadType = _sendFileStorageService.FileUploadType, - SendResponse = new SendResponseModel(send, _globalSettings) - }; - } - - [HttpGet("{id}/file/{fileId}")] - public async Task RenewFileUpload(string id, string fileId) - { - var userId = _userService.GetProperUserId(User).Value; - var sendId = new Guid(id); - var send = await _sendRepository.GetByIdAsync(sendId); - var fileData = JsonSerializer.Deserialize(send?.Data); - - if (send == null || send.Type != SendType.File || (send.UserId.HasValue && send.UserId.Value != userId) || - !send.UserId.HasValue || fileData.Id != fileId || fileData.Validated) - { - // Not found if Send isn't found, user doesn't have access, request is faulty, - // or we've already validated the file. This last is to emulate create-only blob permissions for Azure - throw new NotFoundException(); - } - - return new SendFileUploadDataResponseModel - { - Url = await _sendFileStorageService.GetSendFileUploadUrlAsync(send, fileId), - FileUploadType = _sendFileStorageService.FileUploadType, - SendResponse = new SendResponseModel(send, _globalSettings), - }; - } - - [HttpPost("{id}/file/{fileId}")] - [SelfHosted(SelfHostedOnly = true)] - [RequestSizeLimit(Constants.FileSize501mb)] - [DisableFormValueModelBinding] - public async Task PostFileForExistingSend(string id, string fileId) - { - if (!Request?.ContentType.Contains("multipart/") ?? true) - { - throw new BadRequestException("Invalid content."); - } - - var send = await _sendRepository.GetByIdAsync(new Guid(id)); - await Request.GetFileAsync(async (stream) => - { - await _sendService.UploadFileToExistingSendAsync(stream, send); - }); - } - - [AllowAnonymous] - [HttpPost("file/validate/azure")] - public async Task AzureValidateFile() - { - return await ApiHelpers.HandleAzureEvents(Request, new Dictionary> + [HttpPost("file")] + [Obsolete("Deprecated File Send API", false)] + [RequestSizeLimit(Constants.FileSize101mb)] + [DisableFormValueModelBinding] + public async Task PostFile() { + if (!Request?.ContentType.Contains("multipart/") ?? true) + { + throw new BadRequestException("Invalid content."); + } + + Send send = null; + await Request.GetSendFileAsync(async (stream, fileName, model) => + { + model.ValidateCreation(); + var userId = _userService.GetProperUserId(User).Value; + var (madeSend, madeData) = model.ToSend(userId, fileName, _sendService); + send = madeSend; + await _sendService.SaveFileSendAsync(send, madeData, model.FileLength.GetValueOrDefault(0)); + await _sendService.UploadFileToExistingSendAsync(stream, send); + }); + + return new SendResponseModel(send, _globalSettings); + } + + + [HttpPost("file/v2")] + public async Task PostFile([FromBody] SendRequestModel model) + { + if (model.Type != SendType.File) + { + throw new BadRequestException("Invalid content."); + } + + if (!model.FileLength.HasValue) + { + throw new BadRequestException("Invalid content. File size hint is required."); + } + + if (model.FileLength.Value > SendService.MAX_FILE_SIZE) + { + throw new BadRequestException($"Max file size is {SendService.MAX_FILE_SIZE_READABLE}."); + } + + var userId = _userService.GetProperUserId(User).Value; + var (send, data) = model.ToSend(userId, model.File.FileName, _sendService); + var uploadUrl = await _sendService.SaveFileSendAsync(send, data, model.FileLength.Value); + return new SendFileUploadDataResponseModel + { + Url = uploadUrl, + FileUploadType = _sendFileStorageService.FileUploadType, + SendResponse = new SendResponseModel(send, _globalSettings) + }; + } + + [HttpGet("{id}/file/{fileId}")] + public async Task RenewFileUpload(string id, string fileId) + { + var userId = _userService.GetProperUserId(User).Value; + var sendId = new Guid(id); + var send = await _sendRepository.GetByIdAsync(sendId); + var fileData = JsonSerializer.Deserialize(send?.Data); + + if (send == null || send.Type != SendType.File || (send.UserId.HasValue && send.UserId.Value != userId) || + !send.UserId.HasValue || fileData.Id != fileId || fileData.Validated) + { + // Not found if Send isn't found, user doesn't have access, request is faulty, + // or we've already validated the file. This last is to emulate create-only blob permissions for Azure + throw new NotFoundException(); + } + + return new SendFileUploadDataResponseModel + { + Url = await _sendFileStorageService.GetSendFileUploadUrlAsync(send, fileId), + FileUploadType = _sendFileStorageService.FileUploadType, + SendResponse = new SendResponseModel(send, _globalSettings), + }; + } + + [HttpPost("{id}/file/{fileId}")] + [SelfHosted(SelfHostedOnly = true)] + [RequestSizeLimit(Constants.FileSize501mb)] + [DisableFormValueModelBinding] + public async Task PostFileForExistingSend(string id, string fileId) + { + if (!Request?.ContentType.Contains("multipart/") ?? true) + { + throw new BadRequestException("Invalid content."); + } + + var send = await _sendRepository.GetByIdAsync(new Guid(id)); + await Request.GetFileAsync(async (stream) => + { + await _sendService.UploadFileToExistingSendAsync(stream, send); + }); + } + + [AllowAnonymous] + [HttpPost("file/validate/azure")] + public async Task AzureValidateFile() + { + return await ApiHelpers.HandleAzureEvents(Request, new Dictionary> { - "Microsoft.Storage.BlobCreated", async (eventGridEvent) => { - try + "Microsoft.Storage.BlobCreated", async (eventGridEvent) => { - var blobName = eventGridEvent.Subject.Split($"{AzureSendFileStorageService.FilesContainerName}/blobs/")[1]; - var sendId = AzureSendFileStorageService.SendIdFromBlobName(blobName); - var send = await _sendRepository.GetByIdAsync(new Guid(sendId)); - if (send == null) + try { - if (_sendFileStorageService is AzureSendFileStorageService azureSendFileStorageService) + var blobName = eventGridEvent.Subject.Split($"{AzureSendFileStorageService.FilesContainerName}/blobs/")[1]; + var sendId = AzureSendFileStorageService.SendIdFromBlobName(blobName); + var send = await _sendRepository.GetByIdAsync(new Guid(sendId)); + if (send == null) { - await azureSendFileStorageService.DeleteBlobAsync(blobName); + if (_sendFileStorageService is AzureSendFileStorageService azureSendFileStorageService) + { + await azureSendFileStorageService.DeleteBlobAsync(blobName); + } + return; } + await _sendService.ValidateSendFile(send); + } + catch (Exception e) + { + _logger.LogError(e, $"Uncaught exception occurred while handling event grid event: {JsonSerializer.Serialize(eventGridEvent)}"); return; } - await _sendService.ValidateSendFile(send); - } - catch (Exception e) - { - _logger.LogError(e, $"Uncaught exception occurred while handling event grid event: {JsonSerializer.Serialize(eventGridEvent)}"); - return; } } + }); + } + + [HttpPut("{id}")] + public async Task Put(string id, [FromBody] SendRequestModel model) + { + model.ValidateEdit(); + var userId = _userService.GetProperUserId(User).Value; + var send = await _sendRepository.GetByIdAsync(new Guid(id)); + if (send == null || send.UserId != userId) + { + throw new NotFoundException(); } - }); - } - [HttpPut("{id}")] - public async Task Put(string id, [FromBody] SendRequestModel model) - { - model.ValidateEdit(); - var userId = _userService.GetProperUserId(User).Value; - var send = await _sendRepository.GetByIdAsync(new Guid(id)); - if (send == null || send.UserId != userId) - { - throw new NotFoundException(); + await _sendService.SaveSendAsync(model.ToSend(send, _sendService)); + return new SendResponseModel(send, _globalSettings); } - await _sendService.SaveSendAsync(model.ToSend(send, _sendService)); - return new SendResponseModel(send, _globalSettings); - } - - [HttpPut("{id}/remove-password")] - public async Task PutRemovePassword(string id) - { - var userId = _userService.GetProperUserId(User).Value; - var send = await _sendRepository.GetByIdAsync(new Guid(id)); - if (send == null || send.UserId != userId) + [HttpPut("{id}/remove-password")] + public async Task PutRemovePassword(string id) { - throw new NotFoundException(); + var userId = _userService.GetProperUserId(User).Value; + var send = await _sendRepository.GetByIdAsync(new Guid(id)); + if (send == null || send.UserId != userId) + { + throw new NotFoundException(); + } + + send.Password = null; + await _sendService.SaveSendAsync(send); + return new SendResponseModel(send, _globalSettings); } - send.Password = null; - await _sendService.SaveSendAsync(send); - return new SendResponseModel(send, _globalSettings); - } - - [HttpDelete("{id}")] - public async Task Delete(string id) - { - var userId = _userService.GetProperUserId(User).Value; - var send = await _sendRepository.GetByIdAsync(new Guid(id)); - if (send == null || send.UserId != userId) + [HttpDelete("{id}")] + public async Task Delete(string id) { - throw new NotFoundException(); - } + var userId = _userService.GetProperUserId(User).Value; + var send = await _sendRepository.GetByIdAsync(new Guid(id)); + if (send == null || send.UserId != userId) + { + throw new NotFoundException(); + } - await _sendService.DeleteSendAsync(send); + await _sendService.DeleteSendAsync(send); + } } } diff --git a/src/Api/Controllers/SettingsController.cs b/src/Api/Controllers/SettingsController.cs index 8489b137e8..2db70b0179 100644 --- a/src/Api/Controllers/SettingsController.cs +++ b/src/Api/Controllers/SettingsController.cs @@ -4,46 +4,47 @@ using Bit.Core.Services; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; -namespace Bit.Api.Controllers; - -[Route("settings")] -[Authorize("Application")] -public class SettingsController : Controller +namespace Bit.Api.Controllers { - private readonly IUserService _userService; - - public SettingsController( - IUserService userService) + [Route("settings")] + [Authorize("Application")] + public class SettingsController : Controller { - _userService = userService; - } + private readonly IUserService _userService; - [HttpGet("domains")] - public async Task GetDomains(bool excluded = true) - { - var user = await _userService.GetUserByPrincipalAsync(User); - if (user == null) + public SettingsController( + IUserService userService) { - throw new UnauthorizedAccessException(); + _userService = userService; } - var response = new DomainsResponseModel(user, excluded); - return response; - } - - [HttpPut("domains")] - [HttpPost("domains")] - public async Task PutDomains([FromBody] UpdateDomainsRequestModel model) - { - var user = await _userService.GetUserByPrincipalAsync(User); - if (user == null) + [HttpGet("domains")] + public async Task GetDomains(bool excluded = true) { - throw new UnauthorizedAccessException(); + var user = await _userService.GetUserByPrincipalAsync(User); + if (user == null) + { + throw new UnauthorizedAccessException(); + } + + var response = new DomainsResponseModel(user, excluded); + return response; } - await _userService.SaveUserAsync(model.ToUser(user), true); + [HttpPut("domains")] + [HttpPost("domains")] + public async Task PutDomains([FromBody] UpdateDomainsRequestModel model) + { + var user = await _userService.GetUserByPrincipalAsync(User); + if (user == null) + { + throw new UnauthorizedAccessException(); + } - var response = new DomainsResponseModel(user); - return response; + await _userService.SaveUserAsync(model.ToUser(user), true); + + var response = new DomainsResponseModel(user); + return response; + } } } diff --git a/src/Api/Controllers/SyncController.cs b/src/Api/Controllers/SyncController.cs index 49ccfeacf8..c855543863 100644 --- a/src/Api/Controllers/SyncController.cs +++ b/src/Api/Controllers/SyncController.cs @@ -10,84 +10,85 @@ using Bit.Core.Settings; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; -namespace Bit.Api.Controllers; - -[Route("sync")] -[Authorize("Application")] -public class SyncController : Controller +namespace Bit.Api.Controllers { - private readonly IUserService _userService; - private readonly IFolderRepository _folderRepository; - private readonly ICipherRepository _cipherRepository; - private readonly ICollectionRepository _collectionRepository; - private readonly ICollectionCipherRepository _collectionCipherRepository; - private readonly IOrganizationUserRepository _organizationUserRepository; - private readonly IProviderUserRepository _providerUserRepository; - private readonly IPolicyRepository _policyRepository; - private readonly ISendRepository _sendRepository; - private readonly GlobalSettings _globalSettings; - - public SyncController( - IUserService userService, - IFolderRepository folderRepository, - ICipherRepository cipherRepository, - ICollectionRepository collectionRepository, - ICollectionCipherRepository collectionCipherRepository, - IOrganizationUserRepository organizationUserRepository, - IProviderUserRepository providerUserRepository, - IPolicyRepository policyRepository, - ISendRepository sendRepository, - GlobalSettings globalSettings) + [Route("sync")] + [Authorize("Application")] + public class SyncController : Controller { - _userService = userService; - _folderRepository = folderRepository; - _cipherRepository = cipherRepository; - _collectionRepository = collectionRepository; - _collectionCipherRepository = collectionCipherRepository; - _organizationUserRepository = organizationUserRepository; - _providerUserRepository = providerUserRepository; - _policyRepository = policyRepository; - _sendRepository = sendRepository; - _globalSettings = globalSettings; - } + private readonly IUserService _userService; + private readonly IFolderRepository _folderRepository; + private readonly ICipherRepository _cipherRepository; + private readonly ICollectionRepository _collectionRepository; + private readonly ICollectionCipherRepository _collectionCipherRepository; + private readonly IOrganizationUserRepository _organizationUserRepository; + private readonly IProviderUserRepository _providerUserRepository; + private readonly IPolicyRepository _policyRepository; + private readonly ISendRepository _sendRepository; + private readonly GlobalSettings _globalSettings; - [HttpGet("")] - public async Task Get([FromQuery] bool excludeDomains = false) - { - var user = await _userService.GetUserByPrincipalAsync(User); - if (user == null) + public SyncController( + IUserService userService, + IFolderRepository folderRepository, + ICipherRepository cipherRepository, + ICollectionRepository collectionRepository, + ICollectionCipherRepository collectionCipherRepository, + IOrganizationUserRepository organizationUserRepository, + IProviderUserRepository providerUserRepository, + IPolicyRepository policyRepository, + ISendRepository sendRepository, + GlobalSettings globalSettings) { - throw new BadRequestException("User not found."); + _userService = userService; + _folderRepository = folderRepository; + _cipherRepository = cipherRepository; + _collectionRepository = collectionRepository; + _collectionCipherRepository = collectionCipherRepository; + _organizationUserRepository = organizationUserRepository; + _providerUserRepository = providerUserRepository; + _policyRepository = policyRepository; + _sendRepository = sendRepository; + _globalSettings = globalSettings; } - var organizationUserDetails = await _organizationUserRepository.GetManyDetailsByUserAsync(user.Id, - OrganizationUserStatusType.Confirmed); - var providerUserDetails = await _providerUserRepository.GetManyDetailsByUserAsync(user.Id, - ProviderUserStatusType.Confirmed); - var providerUserOrganizationDetails = - await _providerUserRepository.GetManyOrganizationDetailsByUserAsync(user.Id, + [HttpGet("")] + public async Task Get([FromQuery] bool excludeDomains = false) + { + var user = await _userService.GetUserByPrincipalAsync(User); + if (user == null) + { + throw new BadRequestException("User not found."); + } + + var organizationUserDetails = await _organizationUserRepository.GetManyDetailsByUserAsync(user.Id, + OrganizationUserStatusType.Confirmed); + var providerUserDetails = await _providerUserRepository.GetManyDetailsByUserAsync(user.Id, ProviderUserStatusType.Confirmed); - var hasEnabledOrgs = organizationUserDetails.Any(o => o.Enabled); - var folders = await _folderRepository.GetManyByUserIdAsync(user.Id); - var ciphers = await _cipherRepository.GetManyByUserIdAsync(user.Id, hasEnabledOrgs); - var sends = await _sendRepository.GetManyByUserIdAsync(user.Id); + var providerUserOrganizationDetails = + await _providerUserRepository.GetManyOrganizationDetailsByUserAsync(user.Id, + ProviderUserStatusType.Confirmed); + var hasEnabledOrgs = organizationUserDetails.Any(o => o.Enabled); + var folders = await _folderRepository.GetManyByUserIdAsync(user.Id); + var ciphers = await _cipherRepository.GetManyByUserIdAsync(user.Id, hasEnabledOrgs); + var sends = await _sendRepository.GetManyByUserIdAsync(user.Id); - IEnumerable collections = null; - IDictionary> collectionCiphersGroupDict = null; - IEnumerable policies = null; - if (hasEnabledOrgs) - { - collections = await _collectionRepository.GetManyByUserIdAsync(user.Id); - var collectionCiphers = await _collectionCipherRepository.GetManyByUserIdAsync(user.Id); - collectionCiphersGroupDict = collectionCiphers.GroupBy(c => c.CipherId).ToDictionary(s => s.Key); - policies = await _policyRepository.GetManyByUserIdAsync(user.Id); + IEnumerable collections = null; + IDictionary> collectionCiphersGroupDict = null; + IEnumerable policies = null; + if (hasEnabledOrgs) + { + collections = await _collectionRepository.GetManyByUserIdAsync(user.Id); + var collectionCiphers = await _collectionCipherRepository.GetManyByUserIdAsync(user.Id); + collectionCiphersGroupDict = collectionCiphers.GroupBy(c => c.CipherId).ToDictionary(s => s.Key); + policies = await _policyRepository.GetManyByUserIdAsync(user.Id); + } + + var userTwoFactorEnabled = await _userService.TwoFactorIsEnabledAsync(user); + var userHasPremiumFromOrganization = await _userService.HasPremiumFromOrganization(user); + var response = new SyncResponseModel(_globalSettings, user, userTwoFactorEnabled, userHasPremiumFromOrganization, organizationUserDetails, + providerUserDetails, providerUserOrganizationDetails, folders, collections, ciphers, + collectionCiphersGroupDict, excludeDomains, policies, sends); + return response; } - - var userTwoFactorEnabled = await _userService.TwoFactorIsEnabledAsync(user); - var userHasPremiumFromOrganization = await _userService.HasPremiumFromOrganization(user); - var response = new SyncResponseModel(_globalSettings, user, userTwoFactorEnabled, userHasPremiumFromOrganization, organizationUserDetails, - providerUserDetails, providerUserOrganizationDetails, folders, collections, ciphers, - collectionCiphersGroupDict, excludeDomains, policies, sends); - return response; } } diff --git a/src/Api/Controllers/TwoFactorController.cs b/src/Api/Controllers/TwoFactorController.cs index 6ed2b87961..b8e71f5dfa 100644 --- a/src/Api/Controllers/TwoFactorController.cs +++ b/src/Api/Controllers/TwoFactorController.cs @@ -16,442 +16,443 @@ using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Identity; using Microsoft.AspNetCore.Mvc; -namespace Bit.Api.Controllers; - -[Route("two-factor")] -[Authorize("Web")] -public class TwoFactorController : Controller +namespace Bit.Api.Controllers { - private readonly IUserService _userService; - private readonly IOrganizationRepository _organizationRepository; - private readonly IOrganizationService _organizationService; - private readonly GlobalSettings _globalSettings; - private readonly UserManager _userManager; - private readonly ICurrentContext _currentContext; - - public TwoFactorController( - IUserService userService, - IOrganizationRepository organizationRepository, - IOrganizationService organizationService, - GlobalSettings globalSettings, - UserManager userManager, - ICurrentContext currentContext) + [Route("two-factor")] + [Authorize("Web")] + public class TwoFactorController : Controller { - _userService = userService; - _organizationRepository = organizationRepository; - _organizationService = organizationService; - _globalSettings = globalSettings; - _userManager = userManager; - _currentContext = currentContext; - } + private readonly IUserService _userService; + private readonly IOrganizationRepository _organizationRepository; + private readonly IOrganizationService _organizationService; + private readonly GlobalSettings _globalSettings; + private readonly UserManager _userManager; + private readonly ICurrentContext _currentContext; - [HttpGet("")] - public async Task> Get() - { - var user = await _userService.GetUserByPrincipalAsync(User); - if (user == null) + public TwoFactorController( + IUserService userService, + IOrganizationRepository organizationRepository, + IOrganizationService organizationService, + GlobalSettings globalSettings, + UserManager userManager, + ICurrentContext currentContext) { - throw new UnauthorizedAccessException(); + _userService = userService; + _organizationRepository = organizationRepository; + _organizationService = organizationService; + _globalSettings = globalSettings; + _userManager = userManager; + _currentContext = currentContext; } - var providers = user.GetTwoFactorProviders()?.Select( - p => new TwoFactorProviderResponseModel(p.Key, p.Value)); - return new ListResponseModel(providers); - } - - [HttpGet("~/organizations/{id}/two-factor")] - public async Task> GetOrganization(string id) - { - var orgIdGuid = new Guid(id); - if (!await _currentContext.OrganizationAdmin(orgIdGuid)) + [HttpGet("")] + public async Task> Get() { - throw new NotFoundException(); - } - - var organization = await _organizationRepository.GetByIdAsync(orgIdGuid); - if (organization == null) - { - throw new NotFoundException(); - } - - var providers = organization.GetTwoFactorProviders()?.Select( - p => new TwoFactorProviderResponseModel(p.Key, p.Value)); - return new ListResponseModel(providers); - } - - [HttpPost("get-authenticator")] - public async Task GetAuthenticator([FromBody] SecretVerificationRequestModel model) - { - var user = await CheckAsync(model, false); - var response = new TwoFactorAuthenticatorResponseModel(user); - return response; - } - - [HttpPut("authenticator")] - [HttpPost("authenticator")] - public async Task PutAuthenticator( - [FromBody] UpdateTwoFactorAuthenticatorRequestModel model) - { - var user = await CheckAsync(model, false); - model.ToUser(user); - - if (!await _userManager.VerifyTwoFactorTokenAsync(user, - CoreHelpers.CustomProviderName(TwoFactorProviderType.Authenticator), model.Token)) - { - await Task.Delay(2000); - throw new BadRequestException("Token", "Invalid token."); - } - - await _userService.UpdateTwoFactorProviderAsync(user, TwoFactorProviderType.Authenticator); - var response = new TwoFactorAuthenticatorResponseModel(user); - return response; - } - - [HttpPost("get-yubikey")] - public async Task GetYubiKey([FromBody] SecretVerificationRequestModel model) - { - var user = await CheckAsync(model, true); - var response = new TwoFactorYubiKeyResponseModel(user); - return response; - } - - [HttpPut("yubikey")] - [HttpPost("yubikey")] - public async Task PutYubiKey([FromBody] UpdateTwoFactorYubicoOtpRequestModel model) - { - var user = await CheckAsync(model, true); - model.ToUser(user); - - await ValidateYubiKeyAsync(user, nameof(model.Key1), model.Key1); - await ValidateYubiKeyAsync(user, nameof(model.Key2), model.Key2); - await ValidateYubiKeyAsync(user, nameof(model.Key3), model.Key3); - await ValidateYubiKeyAsync(user, nameof(model.Key4), model.Key4); - await ValidateYubiKeyAsync(user, nameof(model.Key5), model.Key5); - - await _userService.UpdateTwoFactorProviderAsync(user, TwoFactorProviderType.YubiKey); - var response = new TwoFactorYubiKeyResponseModel(user); - return response; - } - - [HttpPost("get-duo")] - public async Task GetDuo([FromBody] SecretVerificationRequestModel model) - { - var user = await CheckAsync(model, true); - var response = new TwoFactorDuoResponseModel(user); - return response; - } - - [HttpPut("duo")] - [HttpPost("duo")] - public async Task PutDuo([FromBody] UpdateTwoFactorDuoRequestModel model) - { - var user = await CheckAsync(model, true); - try - { - var duoApi = new DuoApi(model.IntegrationKey, model.SecretKey, model.Host); - duoApi.JSONApiCall("GET", "/auth/v2/check"); - } - catch (DuoException) - { - throw new BadRequestException("Duo configuration settings are not valid. Please re-check the Duo Admin panel."); - } - - model.ToUser(user); - await _userService.UpdateTwoFactorProviderAsync(user, TwoFactorProviderType.Duo); - var response = new TwoFactorDuoResponseModel(user); - return response; - } - - [HttpPost("~/organizations/{id}/two-factor/get-duo")] - public async Task GetOrganizationDuo(string id, - [FromBody] SecretVerificationRequestModel model) - { - var user = await CheckAsync(model, false); - - var orgIdGuid = new Guid(id); - if (!await _currentContext.ManagePolicies(orgIdGuid)) - { - throw new NotFoundException(); - } - - var organization = await _organizationRepository.GetByIdAsync(orgIdGuid); - if (organization == null) - { - throw new NotFoundException(); - } - - var response = new TwoFactorDuoResponseModel(organization); - return response; - } - - [HttpPut("~/organizations/{id}/two-factor/duo")] - [HttpPost("~/organizations/{id}/two-factor/duo")] - public async Task PutOrganizationDuo(string id, - [FromBody] UpdateTwoFactorDuoRequestModel model) - { - var user = await CheckAsync(model, false); - - var orgIdGuid = new Guid(id); - if (!await _currentContext.ManagePolicies(orgIdGuid)) - { - throw new NotFoundException(); - } - - var organization = await _organizationRepository.GetByIdAsync(orgIdGuid); - if (organization == null) - { - throw new NotFoundException(); - } - - try - { - var duoApi = new DuoApi(model.IntegrationKey, model.SecretKey, model.Host); - duoApi.JSONApiCall("GET", "/auth/v2/check"); - } - catch (DuoException) - { - throw new BadRequestException("Duo configuration settings are not valid. Please re-check the Duo Admin panel."); - } - - model.ToOrganization(organization); - await _organizationService.UpdateTwoFactorProviderAsync(organization, - TwoFactorProviderType.OrganizationDuo); - var response = new TwoFactorDuoResponseModel(organization); - return response; - } - - [HttpPost("get-webauthn")] - public async Task GetWebAuthn([FromBody] SecretVerificationRequestModel model) - { - var user = await CheckAsync(model, true); - var response = new TwoFactorWebAuthnResponseModel(user); - return response; - } - - [HttpPost("get-webauthn-challenge")] - public async Task GetWebAuthnChallenge([FromBody] SecretVerificationRequestModel model) - { - var user = await CheckAsync(model, true); - var reg = await _userService.StartWebAuthnRegistrationAsync(user); - return reg; - } - - [HttpPut("webauthn")] - [HttpPost("webauthn")] - public async Task PutWebAuthn([FromBody] TwoFactorWebAuthnRequestModel model) - { - var user = await CheckAsync(model, true); - - var success = await _userService.CompleteWebAuthRegistrationAsync( - user, model.Id.Value, model.Name, model.DeviceResponse); - if (!success) - { - throw new BadRequestException("Unable to complete WebAuthn registration."); - } - var response = new TwoFactorWebAuthnResponseModel(user); - return response; - } - - [HttpDelete("webauthn")] - public async Task DeleteWebAuthn([FromBody] TwoFactorWebAuthnDeleteRequestModel model) - { - var user = await CheckAsync(model, true); - await _userService.DeleteWebAuthnKeyAsync(user, model.Id.Value); - var response = new TwoFactorWebAuthnResponseModel(user); - return response; - } - - [HttpPost("get-email")] - public async Task GetEmail([FromBody] SecretVerificationRequestModel model) - { - var user = await CheckAsync(model, false); - var response = new TwoFactorEmailResponseModel(user); - return response; - } - - [HttpPost("send-email")] - public async Task SendEmail([FromBody] TwoFactorEmailRequestModel model) - { - var user = await CheckAsync(model, false); - model.ToUser(user); - await _userService.SendTwoFactorEmailAsync(user); - } - - [AllowAnonymous] - [HttpPost("send-email-login")] - public async Task SendEmailLogin([FromBody] TwoFactorEmailRequestModel model) - { - var user = await _userManager.FindByEmailAsync(model.Email.ToLowerInvariant()); - if (user != null) - { - if (await _userService.VerifySecretAsync(user, model.Secret)) + var user = await _userService.GetUserByPrincipalAsync(User); + if (user == null) { - var isBecauseNewDeviceLogin = false; - if (user.GetTwoFactorProvider(TwoFactorProviderType.Email) is null - && - await _userService.Needs2FABecauseNewDeviceAsync(user, model.DeviceIdentifier, null)) - { - model.ToUser(user); - isBecauseNewDeviceLogin = true; - } + throw new UnauthorizedAccessException(); + } - await _userService.SendTwoFactorEmailAsync(user, isBecauseNewDeviceLogin); - return; + var providers = user.GetTwoFactorProviders()?.Select( + p => new TwoFactorProviderResponseModel(p.Key, p.Value)); + return new ListResponseModel(providers); + } + + [HttpGet("~/organizations/{id}/two-factor")] + public async Task> GetOrganization(string id) + { + var orgIdGuid = new Guid(id); + if (!await _currentContext.OrganizationAdmin(orgIdGuid)) + { + throw new NotFoundException(); + } + + var organization = await _organizationRepository.GetByIdAsync(orgIdGuid); + if (organization == null) + { + throw new NotFoundException(); + } + + var providers = organization.GetTwoFactorProviders()?.Select( + p => new TwoFactorProviderResponseModel(p.Key, p.Value)); + return new ListResponseModel(providers); + } + + [HttpPost("get-authenticator")] + public async Task GetAuthenticator([FromBody] SecretVerificationRequestModel model) + { + var user = await CheckAsync(model, false); + var response = new TwoFactorAuthenticatorResponseModel(user); + return response; + } + + [HttpPut("authenticator")] + [HttpPost("authenticator")] + public async Task PutAuthenticator( + [FromBody] UpdateTwoFactorAuthenticatorRequestModel model) + { + var user = await CheckAsync(model, false); + model.ToUser(user); + + if (!await _userManager.VerifyTwoFactorTokenAsync(user, + CoreHelpers.CustomProviderName(TwoFactorProviderType.Authenticator), model.Token)) + { + await Task.Delay(2000); + throw new BadRequestException("Token", "Invalid token."); + } + + await _userService.UpdateTwoFactorProviderAsync(user, TwoFactorProviderType.Authenticator); + var response = new TwoFactorAuthenticatorResponseModel(user); + return response; + } + + [HttpPost("get-yubikey")] + public async Task GetYubiKey([FromBody] SecretVerificationRequestModel model) + { + var user = await CheckAsync(model, true); + var response = new TwoFactorYubiKeyResponseModel(user); + return response; + } + + [HttpPut("yubikey")] + [HttpPost("yubikey")] + public async Task PutYubiKey([FromBody] UpdateTwoFactorYubicoOtpRequestModel model) + { + var user = await CheckAsync(model, true); + model.ToUser(user); + + await ValidateYubiKeyAsync(user, nameof(model.Key1), model.Key1); + await ValidateYubiKeyAsync(user, nameof(model.Key2), model.Key2); + await ValidateYubiKeyAsync(user, nameof(model.Key3), model.Key3); + await ValidateYubiKeyAsync(user, nameof(model.Key4), model.Key4); + await ValidateYubiKeyAsync(user, nameof(model.Key5), model.Key5); + + await _userService.UpdateTwoFactorProviderAsync(user, TwoFactorProviderType.YubiKey); + var response = new TwoFactorYubiKeyResponseModel(user); + return response; + } + + [HttpPost("get-duo")] + public async Task GetDuo([FromBody] SecretVerificationRequestModel model) + { + var user = await CheckAsync(model, true); + var response = new TwoFactorDuoResponseModel(user); + return response; + } + + [HttpPut("duo")] + [HttpPost("duo")] + public async Task PutDuo([FromBody] UpdateTwoFactorDuoRequestModel model) + { + var user = await CheckAsync(model, true); + try + { + var duoApi = new DuoApi(model.IntegrationKey, model.SecretKey, model.Host); + duoApi.JSONApiCall("GET", "/auth/v2/check"); + } + catch (DuoException) + { + throw new BadRequestException("Duo configuration settings are not valid. Please re-check the Duo Admin panel."); + } + + model.ToUser(user); + await _userService.UpdateTwoFactorProviderAsync(user, TwoFactorProviderType.Duo); + var response = new TwoFactorDuoResponseModel(user); + return response; + } + + [HttpPost("~/organizations/{id}/two-factor/get-duo")] + public async Task GetOrganizationDuo(string id, + [FromBody] SecretVerificationRequestModel model) + { + var user = await CheckAsync(model, false); + + var orgIdGuid = new Guid(id); + if (!await _currentContext.ManagePolicies(orgIdGuid)) + { + throw new NotFoundException(); + } + + var organization = await _organizationRepository.GetByIdAsync(orgIdGuid); + if (organization == null) + { + throw new NotFoundException(); + } + + var response = new TwoFactorDuoResponseModel(organization); + return response; + } + + [HttpPut("~/organizations/{id}/two-factor/duo")] + [HttpPost("~/organizations/{id}/two-factor/duo")] + public async Task PutOrganizationDuo(string id, + [FromBody] UpdateTwoFactorDuoRequestModel model) + { + var user = await CheckAsync(model, false); + + var orgIdGuid = new Guid(id); + if (!await _currentContext.ManagePolicies(orgIdGuid)) + { + throw new NotFoundException(); + } + + var organization = await _organizationRepository.GetByIdAsync(orgIdGuid); + if (organization == null) + { + throw new NotFoundException(); + } + + try + { + var duoApi = new DuoApi(model.IntegrationKey, model.SecretKey, model.Host); + duoApi.JSONApiCall("GET", "/auth/v2/check"); + } + catch (DuoException) + { + throw new BadRequestException("Duo configuration settings are not valid. Please re-check the Duo Admin panel."); + } + + model.ToOrganization(organization); + await _organizationService.UpdateTwoFactorProviderAsync(organization, + TwoFactorProviderType.OrganizationDuo); + var response = new TwoFactorDuoResponseModel(organization); + return response; + } + + [HttpPost("get-webauthn")] + public async Task GetWebAuthn([FromBody] SecretVerificationRequestModel model) + { + var user = await CheckAsync(model, true); + var response = new TwoFactorWebAuthnResponseModel(user); + return response; + } + + [HttpPost("get-webauthn-challenge")] + public async Task GetWebAuthnChallenge([FromBody] SecretVerificationRequestModel model) + { + var user = await CheckAsync(model, true); + var reg = await _userService.StartWebAuthnRegistrationAsync(user); + return reg; + } + + [HttpPut("webauthn")] + [HttpPost("webauthn")] + public async Task PutWebAuthn([FromBody] TwoFactorWebAuthnRequestModel model) + { + var user = await CheckAsync(model, true); + + var success = await _userService.CompleteWebAuthRegistrationAsync( + user, model.Id.Value, model.Name, model.DeviceResponse); + if (!success) + { + throw new BadRequestException("Unable to complete WebAuthn registration."); + } + var response = new TwoFactorWebAuthnResponseModel(user); + return response; + } + + [HttpDelete("webauthn")] + public async Task DeleteWebAuthn([FromBody] TwoFactorWebAuthnDeleteRequestModel model) + { + var user = await CheckAsync(model, true); + await _userService.DeleteWebAuthnKeyAsync(user, model.Id.Value); + var response = new TwoFactorWebAuthnResponseModel(user); + return response; + } + + [HttpPost("get-email")] + public async Task GetEmail([FromBody] SecretVerificationRequestModel model) + { + var user = await CheckAsync(model, false); + var response = new TwoFactorEmailResponseModel(user); + return response; + } + + [HttpPost("send-email")] + public async Task SendEmail([FromBody] TwoFactorEmailRequestModel model) + { + var user = await CheckAsync(model, false); + model.ToUser(user); + await _userService.SendTwoFactorEmailAsync(user); + } + + [AllowAnonymous] + [HttpPost("send-email-login")] + public async Task SendEmailLogin([FromBody] TwoFactorEmailRequestModel model) + { + var user = await _userManager.FindByEmailAsync(model.Email.ToLowerInvariant()); + if (user != null) + { + if (await _userService.VerifySecretAsync(user, model.Secret)) + { + var isBecauseNewDeviceLogin = false; + if (user.GetTwoFactorProvider(TwoFactorProviderType.Email) is null + && + await _userService.Needs2FABecauseNewDeviceAsync(user, model.DeviceIdentifier, null)) + { + model.ToUser(user); + isBecauseNewDeviceLogin = true; + } + + await _userService.SendTwoFactorEmailAsync(user, isBecauseNewDeviceLogin); + return; + } + } + + await Task.Delay(2000); + throw new BadRequestException("Cannot send two-factor email."); + } + + [HttpPut("email")] + [HttpPost("email")] + public async Task PutEmail([FromBody] UpdateTwoFactorEmailRequestModel model) + { + var user = await CheckAsync(model, false); + model.ToUser(user); + + if (!await _userManager.VerifyTwoFactorTokenAsync(user, + CoreHelpers.CustomProviderName(TwoFactorProviderType.Email), model.Token)) + { + await Task.Delay(2000); + throw new BadRequestException("Token", "Invalid token."); + } + + await _userService.UpdateTwoFactorProviderAsync(user, TwoFactorProviderType.Email); + var response = new TwoFactorEmailResponseModel(user); + return response; + } + + [HttpPut("disable")] + [HttpPost("disable")] + public async Task PutDisable([FromBody] TwoFactorProviderRequestModel model) + { + var user = await CheckAsync(model, false); + await _userService.DisableTwoFactorProviderAsync(user, model.Type.Value, _organizationService); + var response = new TwoFactorProviderResponseModel(model.Type.Value, user); + return response; + } + + [HttpPut("~/organizations/{id}/two-factor/disable")] + [HttpPost("~/organizations/{id}/two-factor/disable")] + public async Task PutOrganizationDisable(string id, + [FromBody] TwoFactorProviderRequestModel model) + { + var user = await CheckAsync(model, false); + + var orgIdGuid = new Guid(id); + if (!await _currentContext.ManagePolicies(orgIdGuid)) + { + throw new NotFoundException(); + } + + var organization = await _organizationRepository.GetByIdAsync(orgIdGuid); + if (organization == null) + { + throw new NotFoundException(); + } + + await _organizationService.DisableTwoFactorProviderAsync(organization, model.Type.Value); + var response = new TwoFactorProviderResponseModel(model.Type.Value, organization); + return response; + } + + [HttpPost("get-recover")] + public async Task GetRecover([FromBody] SecretVerificationRequestModel model) + { + var user = await CheckAsync(model, false); + var response = new TwoFactorRecoverResponseModel(user); + return response; + } + + [HttpPost("recover")] + [AllowAnonymous] + public async Task PostRecover([FromBody] TwoFactorRecoveryRequestModel model) + { + if (!await _userService.RecoverTwoFactorAsync(model.Email, model.MasterPasswordHash, model.RecoveryCode, + _organizationService)) + { + await Task.Delay(2000); + throw new BadRequestException(string.Empty, "Invalid information. Try again."); } } - await Task.Delay(2000); - throw new BadRequestException("Cannot send two-factor email."); - } - - [HttpPut("email")] - [HttpPost("email")] - public async Task PutEmail([FromBody] UpdateTwoFactorEmailRequestModel model) - { - var user = await CheckAsync(model, false); - model.ToUser(user); - - if (!await _userManager.VerifyTwoFactorTokenAsync(user, - CoreHelpers.CustomProviderName(TwoFactorProviderType.Email), model.Token)) + [HttpGet("get-device-verification-settings")] + public async Task GetDeviceVerificationSettings() { - await Task.Delay(2000); - throw new BadRequestException("Token", "Invalid token."); + var user = await _userService.GetUserByPrincipalAsync(User); + if (user == null) + { + throw new UnauthorizedAccessException(); + } + + if (User.Claims.HasSsoIdP()) + { + return new DeviceVerificationResponseModel(false, false); + } + + var canUserEditDeviceVerificationSettings = _userService.CanEditDeviceVerificationSettings(user); + return new DeviceVerificationResponseModel(canUserEditDeviceVerificationSettings, canUserEditDeviceVerificationSettings && user.UnknownDeviceVerificationEnabled); } - await _userService.UpdateTwoFactorProviderAsync(user, TwoFactorProviderType.Email); - var response = new TwoFactorEmailResponseModel(user); - return response; - } - - [HttpPut("disable")] - [HttpPost("disable")] - public async Task PutDisable([FromBody] TwoFactorProviderRequestModel model) - { - var user = await CheckAsync(model, false); - await _userService.DisableTwoFactorProviderAsync(user, model.Type.Value, _organizationService); - var response = new TwoFactorProviderResponseModel(model.Type.Value, user); - return response; - } - - [HttpPut("~/organizations/{id}/two-factor/disable")] - [HttpPost("~/organizations/{id}/two-factor/disable")] - public async Task PutOrganizationDisable(string id, - [FromBody] TwoFactorProviderRequestModel model) - { - var user = await CheckAsync(model, false); - - var orgIdGuid = new Guid(id); - if (!await _currentContext.ManagePolicies(orgIdGuid)) + [HttpPut("device-verification-settings")] + public async Task PutDeviceVerificationSettings([FromBody] DeviceVerificationRequestModel model) { - throw new NotFoundException(); + var user = await _userService.GetUserByPrincipalAsync(User); + if (user == null) + { + throw new UnauthorizedAccessException(); + } + if (!_userService.CanEditDeviceVerificationSettings(user) + || User.Claims.HasSsoIdP()) + { + throw new InvalidOperationException("Can't update device verification settings"); + } + + model.ToUser(user); + await _userService.SaveUserAsync(user); + return new DeviceVerificationResponseModel(true, user.UnknownDeviceVerificationEnabled); } - var organization = await _organizationRepository.GetByIdAsync(orgIdGuid); - if (organization == null) + private async Task CheckAsync(SecretVerificationRequestModel model, bool premium) { - throw new NotFoundException(); + var user = await _userService.GetUserByPrincipalAsync(User); + if (user == null) + { + throw new UnauthorizedAccessException(); + } + + if (!await _userService.VerifySecretAsync(user, model.Secret)) + { + await Task.Delay(2000); + throw new BadRequestException(string.Empty, "User verification failed."); + } + + if (premium && !(await _userService.CanAccessPremium(user))) + { + throw new BadRequestException("Premium status is required."); + } + + return user; } - await _organizationService.DisableTwoFactorProviderAsync(organization, model.Type.Value); - var response = new TwoFactorProviderResponseModel(model.Type.Value, organization); - return response; - } - - [HttpPost("get-recover")] - public async Task GetRecover([FromBody] SecretVerificationRequestModel model) - { - var user = await CheckAsync(model, false); - var response = new TwoFactorRecoverResponseModel(user); - return response; - } - - [HttpPost("recover")] - [AllowAnonymous] - public async Task PostRecover([FromBody] TwoFactorRecoveryRequestModel model) - { - if (!await _userService.RecoverTwoFactorAsync(model.Email, model.MasterPasswordHash, model.RecoveryCode, - _organizationService)) + private async Task ValidateYubiKeyAsync(User user, string name, string value) { - await Task.Delay(2000); - throw new BadRequestException(string.Empty, "Invalid information. Try again."); - } - } + if (string.IsNullOrWhiteSpace(value) || value.Length == 12) + { + return; + } - [HttpGet("get-device-verification-settings")] - public async Task GetDeviceVerificationSettings() - { - var user = await _userService.GetUserByPrincipalAsync(User); - if (user == null) - { - throw new UnauthorizedAccessException(); - } - - if (User.Claims.HasSsoIdP()) - { - return new DeviceVerificationResponseModel(false, false); - } - - var canUserEditDeviceVerificationSettings = _userService.CanEditDeviceVerificationSettings(user); - return new DeviceVerificationResponseModel(canUserEditDeviceVerificationSettings, canUserEditDeviceVerificationSettings && user.UnknownDeviceVerificationEnabled); - } - - [HttpPut("device-verification-settings")] - public async Task PutDeviceVerificationSettings([FromBody] DeviceVerificationRequestModel model) - { - var user = await _userService.GetUserByPrincipalAsync(User); - if (user == null) - { - throw new UnauthorizedAccessException(); - } - if (!_userService.CanEditDeviceVerificationSettings(user) - || User.Claims.HasSsoIdP()) - { - throw new InvalidOperationException("Can't update device verification settings"); - } - - model.ToUser(user); - await _userService.SaveUserAsync(user); - return new DeviceVerificationResponseModel(true, user.UnknownDeviceVerificationEnabled); - } - - private async Task CheckAsync(SecretVerificationRequestModel model, bool premium) - { - var user = await _userService.GetUserByPrincipalAsync(User); - if (user == null) - { - throw new UnauthorizedAccessException(); - } - - if (!await _userService.VerifySecretAsync(user, model.Secret)) - { - await Task.Delay(2000); - throw new BadRequestException(string.Empty, "User verification failed."); - } - - if (premium && !(await _userService.CanAccessPremium(user))) - { - throw new BadRequestException("Premium status is required."); - } - - return user; - } - - private async Task ValidateYubiKeyAsync(User user, string name, string value) - { - if (string.IsNullOrWhiteSpace(value) || value.Length == 12) - { - return; - } - - if (!await _userManager.VerifyTwoFactorTokenAsync(user, - CoreHelpers.CustomProviderName(TwoFactorProviderType.YubiKey), value)) - { - await Task.Delay(2000); - throw new BadRequestException(name, $"{name} is invalid."); - } - else - { - await Task.Delay(500); + if (!await _userManager.VerifyTwoFactorTokenAsync(user, + CoreHelpers.CustomProviderName(TwoFactorProviderType.YubiKey), value)) + { + await Task.Delay(2000); + throw new BadRequestException(name, $"{name} is invalid."); + } + else + { + await Task.Delay(500); + } } } } diff --git a/src/Api/Controllers/UsersController.cs b/src/Api/Controllers/UsersController.cs index 4dfd047d37..eeb50301eb 100644 --- a/src/Api/Controllers/UsersController.cs +++ b/src/Api/Controllers/UsersController.cs @@ -4,30 +4,31 @@ using Bit.Core.Repositories; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; -namespace Bit.Api.Controllers; - -[Route("users")] -[Authorize("Application")] -public class UsersController : Controller +namespace Bit.Api.Controllers { - private readonly IUserRepository _userRepository; - - public UsersController( - IUserRepository userRepository) + [Route("users")] + [Authorize("Application")] + public class UsersController : Controller { - _userRepository = userRepository; - } + private readonly IUserRepository _userRepository; - [HttpGet("{id}/public-key")] - public async Task Get(string id) - { - var guidId = new Guid(id); - var key = await _userRepository.GetPublicKeyAsync(guidId); - if (key == null) + public UsersController( + IUserRepository userRepository) { - throw new NotFoundException(); + _userRepository = userRepository; } - return new UserKeyResponseModel(guidId, key); + [HttpGet("{id}/public-key")] + public async Task Get(string id) + { + var guidId = new Guid(id); + var key = await _userRepository.GetPublicKeyAsync(guidId); + if (key == null) + { + throw new NotFoundException(); + } + + return new UserKeyResponseModel(guidId, key); + } } } diff --git a/src/Api/Jobs/AliveJob.cs b/src/Api/Jobs/AliveJob.cs index 71136ef7c1..354b8206e0 100644 --- a/src/Api/Jobs/AliveJob.cs +++ b/src/Api/Jobs/AliveJob.cs @@ -2,16 +2,17 @@ using Bit.Core.Jobs; using Quartz; -namespace Bit.Api.Jobs; - -public class AliveJob : BaseJob +namespace Bit.Api.Jobs { - public AliveJob(ILogger logger) - : base(logger) { } - - protected override Task ExecuteJobAsync(IJobExecutionContext context) + public class AliveJob : BaseJob { - _logger.LogInformation(Constants.BypassFiltersEventId, null, "It's alive!"); - return Task.FromResult(0); + public AliveJob(ILogger logger) + : base(logger) { } + + protected override Task ExecuteJobAsync(IJobExecutionContext context) + { + _logger.LogInformation(Constants.BypassFiltersEventId, null, "It's alive!"); + return Task.FromResult(0); + } } } diff --git a/src/Api/Jobs/EmergencyAccessNotificationJob.cs b/src/Api/Jobs/EmergencyAccessNotificationJob.cs index 6520de3522..4851ef38c0 100644 --- a/src/Api/Jobs/EmergencyAccessNotificationJob.cs +++ b/src/Api/Jobs/EmergencyAccessNotificationJob.cs @@ -2,22 +2,23 @@ using Bit.Core.Services; using Quartz; -namespace Bit.Api.Jobs; - -public class EmergencyAccessNotificationJob : BaseJob +namespace Bit.Api.Jobs { - private readonly IServiceScopeFactory _serviceScopeFactory; - - public EmergencyAccessNotificationJob(IServiceScopeFactory serviceScopeFactory, ILogger logger) - : base(logger) + public class EmergencyAccessNotificationJob : BaseJob { - _serviceScopeFactory = serviceScopeFactory; - } + private readonly IServiceScopeFactory _serviceScopeFactory; - protected override async Task ExecuteJobAsync(IJobExecutionContext context) - { - using var scope = _serviceScopeFactory.CreateScope(); - var emergencyAccessService = scope.ServiceProvider.GetService(typeof(IEmergencyAccessService)) as IEmergencyAccessService; - await emergencyAccessService.SendNotificationsAsync(); + public EmergencyAccessNotificationJob(IServiceScopeFactory serviceScopeFactory, ILogger logger) + : base(logger) + { + _serviceScopeFactory = serviceScopeFactory; + } + + protected override async Task ExecuteJobAsync(IJobExecutionContext context) + { + using var scope = _serviceScopeFactory.CreateScope(); + var emergencyAccessService = scope.ServiceProvider.GetService(typeof(IEmergencyAccessService)) as IEmergencyAccessService; + await emergencyAccessService.SendNotificationsAsync(); + } } } diff --git a/src/Api/Jobs/EmergencyAccessTimeoutJob.cs b/src/Api/Jobs/EmergencyAccessTimeoutJob.cs index 642f4173c3..7e7e85c6d5 100644 --- a/src/Api/Jobs/EmergencyAccessTimeoutJob.cs +++ b/src/Api/Jobs/EmergencyAccessTimeoutJob.cs @@ -2,22 +2,23 @@ using Bit.Core.Services; using Quartz; -namespace Bit.Api.Jobs; - -public class EmergencyAccessTimeoutJob : BaseJob +namespace Bit.Api.Jobs { - private readonly IServiceScopeFactory _serviceScopeFactory; - - public EmergencyAccessTimeoutJob(IServiceScopeFactory serviceScopeFactory, ILogger logger) - : base(logger) + public class EmergencyAccessTimeoutJob : BaseJob { - _serviceScopeFactory = serviceScopeFactory; - } + private readonly IServiceScopeFactory _serviceScopeFactory; - protected override async Task ExecuteJobAsync(IJobExecutionContext context) - { - using var scope = _serviceScopeFactory.CreateScope(); - var emergencyAccessService = scope.ServiceProvider.GetService(typeof(IEmergencyAccessService)) as IEmergencyAccessService; - await emergencyAccessService.HandleTimedOutRequestsAsync(); + public EmergencyAccessTimeoutJob(IServiceScopeFactory serviceScopeFactory, ILogger logger) + : base(logger) + { + _serviceScopeFactory = serviceScopeFactory; + } + + protected override async Task ExecuteJobAsync(IJobExecutionContext context) + { + using var scope = _serviceScopeFactory.CreateScope(); + var emergencyAccessService = scope.ServiceProvider.GetService(typeof(IEmergencyAccessService)) as IEmergencyAccessService; + await emergencyAccessService.HandleTimedOutRequestsAsync(); + } } } diff --git a/src/Api/Jobs/JobsHostedService.cs b/src/Api/Jobs/JobsHostedService.cs index 241a012428..99adbb0e2c 100644 --- a/src/Api/Jobs/JobsHostedService.cs +++ b/src/Api/Jobs/JobsHostedService.cs @@ -2,81 +2,82 @@ using Bit.Core.Settings; using Quartz; -namespace Bit.Api.Jobs; - -public class JobsHostedService : BaseJobsHostedService +namespace Bit.Api.Jobs { - public JobsHostedService( - GlobalSettings globalSettings, - IServiceProvider serviceProvider, - ILogger logger, - ILogger listenerLogger) - : base(globalSettings, serviceProvider, logger, listenerLogger) { } - - public override async Task StartAsync(CancellationToken cancellationToken) + public class JobsHostedService : BaseJobsHostedService { - var everyTopOfTheHourTrigger = TriggerBuilder.Create() - .WithIdentity("EveryTopOfTheHourTrigger") - .StartNow() - .WithCronSchedule("0 0 * * * ?") - .Build(); - var emergencyAccessNotificationTrigger = TriggerBuilder.Create() - .WithIdentity("EmergencyAccessNotificationTrigger") - .StartNow() - .WithCronSchedule("0 0 * * * ?") - .Build(); - var emergencyAccessTimeoutTrigger = TriggerBuilder.Create() - .WithIdentity("EmergencyAccessTimeoutTrigger") - .StartNow() - .WithCronSchedule("0 0 * * * ?") - .Build(); - var everyTopOfTheSixthHourTrigger = TriggerBuilder.Create() - .WithIdentity("EveryTopOfTheSixthHourTrigger") - .StartNow() - .WithCronSchedule("0 0 */6 * * ?") - .Build(); - var everyTwelfthHourAndThirtyMinutesTrigger = TriggerBuilder.Create() - .WithIdentity("EveryTwelfthHourAndThirtyMinutesTrigger") - .StartNow() - .WithCronSchedule("0 30 */12 * * ?") - .Build(); - var randomDailySponsorshipSyncTrigger = TriggerBuilder.Create() - .WithIdentity("RandomDailySponsorshipSyncTrigger") - .StartAt(DateBuilder.FutureDate(new Random().Next(24), IntervalUnit.Hour)) - .WithSimpleSchedule(x => x - .WithIntervalInHours(24) - .RepeatForever()) - .Build(); + public JobsHostedService( + GlobalSettings globalSettings, + IServiceProvider serviceProvider, + ILogger logger, + ILogger listenerLogger) + : base(globalSettings, serviceProvider, logger, listenerLogger) { } - var jobs = new List> + public override async Task StartAsync(CancellationToken cancellationToken) { - new Tuple(typeof(AliveJob), everyTopOfTheHourTrigger), - new Tuple(typeof(EmergencyAccessNotificationJob), emergencyAccessNotificationTrigger), - new Tuple(typeof(EmergencyAccessTimeoutJob), emergencyAccessTimeoutTrigger), - new Tuple(typeof(ValidateUsersJob), everyTopOfTheSixthHourTrigger), - new Tuple(typeof(ValidateOrganizationsJob), everyTwelfthHourAndThirtyMinutesTrigger) - }; + var everyTopOfTheHourTrigger = TriggerBuilder.Create() + .WithIdentity("EveryTopOfTheHourTrigger") + .StartNow() + .WithCronSchedule("0 0 * * * ?") + .Build(); + var emergencyAccessNotificationTrigger = TriggerBuilder.Create() + .WithIdentity("EmergencyAccessNotificationTrigger") + .StartNow() + .WithCronSchedule("0 0 * * * ?") + .Build(); + var emergencyAccessTimeoutTrigger = TriggerBuilder.Create() + .WithIdentity("EmergencyAccessTimeoutTrigger") + .StartNow() + .WithCronSchedule("0 0 * * * ?") + .Build(); + var everyTopOfTheSixthHourTrigger = TriggerBuilder.Create() + .WithIdentity("EveryTopOfTheSixthHourTrigger") + .StartNow() + .WithCronSchedule("0 0 */6 * * ?") + .Build(); + var everyTwelfthHourAndThirtyMinutesTrigger = TriggerBuilder.Create() + .WithIdentity("EveryTwelfthHourAndThirtyMinutesTrigger") + .StartNow() + .WithCronSchedule("0 30 */12 * * ?") + .Build(); + var randomDailySponsorshipSyncTrigger = TriggerBuilder.Create() + .WithIdentity("RandomDailySponsorshipSyncTrigger") + .StartAt(DateBuilder.FutureDate(new Random().Next(24), IntervalUnit.Hour)) + .WithSimpleSchedule(x => x + .WithIntervalInHours(24) + .RepeatForever()) + .Build(); - if (_globalSettings.SelfHosted && _globalSettings.EnableCloudCommunication) - { - jobs.Add(new Tuple(typeof(SelfHostedSponsorshipSyncJob), randomDailySponsorshipSyncTrigger)); + var jobs = new List> + { + new Tuple(typeof(AliveJob), everyTopOfTheHourTrigger), + new Tuple(typeof(EmergencyAccessNotificationJob), emergencyAccessNotificationTrigger), + new Tuple(typeof(EmergencyAccessTimeoutJob), emergencyAccessTimeoutTrigger), + new Tuple(typeof(ValidateUsersJob), everyTopOfTheSixthHourTrigger), + new Tuple(typeof(ValidateOrganizationsJob), everyTwelfthHourAndThirtyMinutesTrigger) + }; + + if (_globalSettings.SelfHosted && _globalSettings.EnableCloudCommunication) + { + jobs.Add(new Tuple(typeof(SelfHostedSponsorshipSyncJob), randomDailySponsorshipSyncTrigger)); + } + + Jobs = jobs; + + await base.StartAsync(cancellationToken); } - Jobs = jobs; - - await base.StartAsync(cancellationToken); - } - - public static void AddJobsServices(IServiceCollection services, bool selfHosted) - { - if (selfHosted) + public static void AddJobsServices(IServiceCollection services, bool selfHosted) { - services.AddTransient(); + if (selfHosted) + { + services.AddTransient(); + } + services.AddTransient(); + services.AddTransient(); + services.AddTransient(); + services.AddTransient(); + services.AddTransient(); } - services.AddTransient(); - services.AddTransient(); - services.AddTransient(); - services.AddTransient(); - services.AddTransient(); } } diff --git a/src/Api/Jobs/SelfHostedSponsorshipSyncJob.cs b/src/Api/Jobs/SelfHostedSponsorshipSyncJob.cs index d217598242..7ffb32d3b2 100644 --- a/src/Api/Jobs/SelfHostedSponsorshipSyncJob.cs +++ b/src/Api/Jobs/SelfHostedSponsorshipSyncJob.cs @@ -7,58 +7,59 @@ using Bit.Core.Services; using Bit.Core.Settings; using Quartz; -namespace Bit.Api.Jobs; - -public class SelfHostedSponsorshipSyncJob : BaseJob +namespace Bit.Api.Jobs { - private readonly IServiceProvider _serviceProvider; - private IOrganizationRepository _organizationRepository; - private IOrganizationConnectionRepository _organizationConnectionRepository; - private readonly ILicensingService _licensingService; - private GlobalSettings _globalSettings; - - public SelfHostedSponsorshipSyncJob( - IServiceProvider serviceProvider, - IOrganizationRepository organizationRepository, - IOrganizationConnectionRepository organizationConnectionRepository, - ILicensingService licensingService, - ILogger logger, - GlobalSettings globalSettings) - : base(logger) + public class SelfHostedSponsorshipSyncJob : BaseJob { - _serviceProvider = serviceProvider; - _organizationRepository = organizationRepository; - _organizationConnectionRepository = organizationConnectionRepository; - _licensingService = licensingService; - _globalSettings = globalSettings; - } + private readonly IServiceProvider _serviceProvider; + private IOrganizationRepository _organizationRepository; + private IOrganizationConnectionRepository _organizationConnectionRepository; + private readonly ILicensingService _licensingService; + private GlobalSettings _globalSettings; - protected override async Task ExecuteJobAsync(IJobExecutionContext context) - { - if (!_globalSettings.EnableCloudCommunication) + public SelfHostedSponsorshipSyncJob( + IServiceProvider serviceProvider, + IOrganizationRepository organizationRepository, + IOrganizationConnectionRepository organizationConnectionRepository, + ILicensingService licensingService, + ILogger logger, + GlobalSettings globalSettings) + : base(logger) { - _logger.LogInformation("Skipping Organization sync with cloud - Cloud communication is disabled in global settings"); - return; + _serviceProvider = serviceProvider; + _organizationRepository = organizationRepository; + _organizationConnectionRepository = organizationConnectionRepository; + _licensingService = licensingService; + _globalSettings = globalSettings; } - var organizations = await _organizationRepository.GetManyByEnabledAsync(); - - using (var scope = _serviceProvider.CreateScope()) + protected override async Task ExecuteJobAsync(IJobExecutionContext context) { - var syncCommand = scope.ServiceProvider.GetRequiredService(); - foreach (var org in organizations) + if (!_globalSettings.EnableCloudCommunication) { - var connection = (await _organizationConnectionRepository.GetEnabledByOrganizationIdTypeAsync(org.Id, OrganizationConnectionType.CloudBillingSync)).FirstOrDefault(); - if (connection != null) + _logger.LogInformation("Skipping Organization sync with cloud - Cloud communication is disabled in global settings"); + return; + } + + var organizations = await _organizationRepository.GetManyByEnabledAsync(); + + using (var scope = _serviceProvider.CreateScope()) + { + var syncCommand = scope.ServiceProvider.GetRequiredService(); + foreach (var org in organizations) { - try + var connection = (await _organizationConnectionRepository.GetEnabledByOrganizationIdTypeAsync(org.Id, OrganizationConnectionType.CloudBillingSync)).FirstOrDefault(); + if (connection != null) { - var config = connection.GetConfig(); - await syncCommand.SyncOrganization(org.Id, config.CloudOrganizationId, connection); - } - catch (Exception ex) - { - _logger.LogError(ex, $"Sponsorship sync for organization {org.Name} Failed"); + try + { + var config = connection.GetConfig(); + await syncCommand.SyncOrganization(org.Id, config.CloudOrganizationId, connection); + } + catch (Exception ex) + { + _logger.LogError(ex, $"Sponsorship sync for organization {org.Name} Failed"); + } } } } diff --git a/src/Api/Jobs/ValidateOrganizationsJob.cs b/src/Api/Jobs/ValidateOrganizationsJob.cs index 8c4225a015..d3ec2dad5a 100644 --- a/src/Api/Jobs/ValidateOrganizationsJob.cs +++ b/src/Api/Jobs/ValidateOrganizationsJob.cs @@ -2,22 +2,23 @@ using Bit.Core.Services; using Quartz; -namespace Bit.Api.Jobs; - -public class ValidateOrganizationsJob : BaseJob +namespace Bit.Api.Jobs { - private readonly ILicensingService _licensingService; - - public ValidateOrganizationsJob( - ILicensingService licensingService, - ILogger logger) - : base(logger) + public class ValidateOrganizationsJob : BaseJob { - _licensingService = licensingService; - } + private readonly ILicensingService _licensingService; - protected async override Task ExecuteJobAsync(IJobExecutionContext context) - { - await _licensingService.ValidateOrganizationsAsync(); + public ValidateOrganizationsJob( + ILicensingService licensingService, + ILogger logger) + : base(logger) + { + _licensingService = licensingService; + } + + protected async override Task ExecuteJobAsync(IJobExecutionContext context) + { + await _licensingService.ValidateOrganizationsAsync(); + } } } diff --git a/src/Api/Jobs/ValidateUsersJob.cs b/src/Api/Jobs/ValidateUsersJob.cs index be531b47de..1261624276 100644 --- a/src/Api/Jobs/ValidateUsersJob.cs +++ b/src/Api/Jobs/ValidateUsersJob.cs @@ -2,22 +2,23 @@ using Bit.Core.Services; using Quartz; -namespace Bit.Api.Jobs; - -public class ValidateUsersJob : BaseJob +namespace Bit.Api.Jobs { - private readonly ILicensingService _licensingService; - - public ValidateUsersJob( - ILicensingService licensingService, - ILogger logger) - : base(logger) + public class ValidateUsersJob : BaseJob { - _licensingService = licensingService; - } + private readonly ILicensingService _licensingService; - protected async override Task ExecuteJobAsync(IJobExecutionContext context) - { - await _licensingService.ValidateUsersAsync(); + public ValidateUsersJob( + ILicensingService licensingService, + ILogger logger) + : base(logger) + { + _licensingService = licensingService; + } + + protected async override Task ExecuteJobAsync(IJobExecutionContext context) + { + await _licensingService.ValidateUsersAsync(); + } } } diff --git a/src/Api/Models/CipherAttachmentModel.cs b/src/Api/Models/CipherAttachmentModel.cs index c1ae197187..a510806588 100644 --- a/src/Api/Models/CipherAttachmentModel.cs +++ b/src/Api/Models/CipherAttachmentModel.cs @@ -1,20 +1,21 @@ using Bit.Core.Models.Data; using Bit.Core.Utilities; -namespace Bit.Api.Models; - -public class CipherAttachmentModel +namespace Bit.Api.Models { - public CipherAttachmentModel() { } - - public CipherAttachmentModel(CipherAttachment.MetaData data) + public class CipherAttachmentModel { - FileName = data.FileName; - Key = data.Key; - } + public CipherAttachmentModel() { } - [EncryptedStringLength(1000)] - public string FileName { get; set; } - [EncryptedStringLength(1000)] - public string Key { get; set; } + public CipherAttachmentModel(CipherAttachment.MetaData data) + { + FileName = data.FileName; + Key = data.Key; + } + + [EncryptedStringLength(1000)] + public string FileName { get; set; } + [EncryptedStringLength(1000)] + public string Key { get; set; } + } } diff --git a/src/Api/Models/CipherCardModel.cs b/src/Api/Models/CipherCardModel.cs index 07ea4d1e69..d95123e324 100644 --- a/src/Api/Models/CipherCardModel.cs +++ b/src/Api/Models/CipherCardModel.cs @@ -2,38 +2,39 @@ using Bit.Core.Models.Data; using Bit.Core.Utilities; -namespace Bit.Api.Models; - -public class CipherCardModel +namespace Bit.Api.Models { - public CipherCardModel() { } - - public CipherCardModel(CipherCardData data) + public class CipherCardModel { - CardholderName = data.CardholderName; - Brand = data.Brand; - Number = data.Number; - ExpMonth = data.ExpMonth; - ExpYear = data.ExpYear; - Code = data.Code; - } + public CipherCardModel() { } - [EncryptedString] - [EncryptedStringLength(1000)] - public string CardholderName { get; set; } - [EncryptedString] - [EncryptedStringLength(1000)] - public string Brand { get; set; } - [EncryptedString] - [EncryptedStringLength(1000)] - public string Number { get; set; } - [EncryptedString] - [EncryptedStringLength(1000)] - public string ExpMonth { get; set; } - [EncryptedString] - [StringLength(1000)] - public string ExpYear { get; set; } - [EncryptedString] - [EncryptedStringLength(1000)] - public string Code { get; set; } + public CipherCardModel(CipherCardData data) + { + CardholderName = data.CardholderName; + Brand = data.Brand; + Number = data.Number; + ExpMonth = data.ExpMonth; + ExpYear = data.ExpYear; + Code = data.Code; + } + + [EncryptedString] + [EncryptedStringLength(1000)] + public string CardholderName { get; set; } + [EncryptedString] + [EncryptedStringLength(1000)] + public string Brand { get; set; } + [EncryptedString] + [EncryptedStringLength(1000)] + public string Number { get; set; } + [EncryptedString] + [EncryptedStringLength(1000)] + public string ExpMonth { get; set; } + [EncryptedString] + [StringLength(1000)] + public string ExpYear { get; set; } + [EncryptedString] + [EncryptedStringLength(1000)] + public string Code { get; set; } + } } diff --git a/src/Api/Models/CipherFieldModel.cs b/src/Api/Models/CipherFieldModel.cs index 675dcfce07..5ade6e883d 100644 --- a/src/Api/Models/CipherFieldModel.cs +++ b/src/Api/Models/CipherFieldModel.cs @@ -2,35 +2,36 @@ using Bit.Core.Models.Data; using Bit.Core.Utilities; -namespace Bit.Api.Models; - -public class CipherFieldModel +namespace Bit.Api.Models { - public CipherFieldModel() { } - - public CipherFieldModel(CipherFieldData data) + public class CipherFieldModel { - Type = data.Type; - Name = data.Name; - Value = data.Value; - LinkedId = data.LinkedId ?? null; - } + public CipherFieldModel() { } - public FieldType Type { get; set; } - [EncryptedStringLength(1000)] - public string Name { get; set; } - [EncryptedStringLength(5000)] - public string Value { get; set; } - public int? LinkedId { get; set; } - - public CipherFieldData ToCipherFieldData() - { - return new CipherFieldData + public CipherFieldModel(CipherFieldData data) { - Type = Type, - Name = Name, - Value = Value, - LinkedId = LinkedId ?? null, - }; + Type = data.Type; + Name = data.Name; + Value = data.Value; + LinkedId = data.LinkedId ?? null; + } + + public FieldType Type { get; set; } + [EncryptedStringLength(1000)] + public string Name { get; set; } + [EncryptedStringLength(5000)] + public string Value { get; set; } + public int? LinkedId { get; set; } + + public CipherFieldData ToCipherFieldData() + { + return new CipherFieldData + { + Type = Type, + Name = Name, + Value = Value, + LinkedId = LinkedId ?? null, + }; + } } } diff --git a/src/Api/Models/CipherIdentityModel.cs b/src/Api/Models/CipherIdentityModel.cs index 7c1fed164f..ce50166192 100644 --- a/src/Api/Models/CipherIdentityModel.cs +++ b/src/Api/Models/CipherIdentityModel.cs @@ -2,86 +2,87 @@ using Bit.Core.Models.Data; using Bit.Core.Utilities; -namespace Bit.Api.Models; - -public class CipherIdentityModel +namespace Bit.Api.Models { - public CipherIdentityModel() { } - - public CipherIdentityModel(CipherIdentityData data) + public class CipherIdentityModel { - Title = data.Title; - FirstName = data.FirstName; - MiddleName = data.MiddleName; - LastName = data.LastName; - Address1 = data.Address1; - Address2 = data.Address2; - Address3 = data.Address3; - City = data.City; - State = data.State; - PostalCode = data.PostalCode; - Country = data.Country; - Company = data.Company; - Email = data.Email; - Phone = data.Phone; - SSN = data.SSN; - Username = data.Username; - PassportNumber = data.PassportNumber; - LicenseNumber = data.LicenseNumber; - } + public CipherIdentityModel() { } - [EncryptedString] - [EncryptedStringLength(1000)] - public string Title { get; set; } - [EncryptedString] - [EncryptedStringLength(1000)] - public string FirstName { get; set; } - [EncryptedString] - [EncryptedStringLength(1000)] - public string MiddleName { get; set; } - [EncryptedString] - [EncryptedStringLength(1000)] - public string LastName { get; set; } - [EncryptedString] - [StringLength(1000)] - public string Address1 { get; set; } - [EncryptedString] - [EncryptedStringLength(1000)] - public string Address2 { get; set; } - [EncryptedString] - [EncryptedStringLength(1000)] - public string Address3 { get; set; } - [EncryptedString] - [EncryptedStringLength(1000)] - public string City { get; set; } - [EncryptedString] - [EncryptedStringLength(1000)] - public string State { get; set; } - [EncryptedString] - [EncryptedStringLength(1000)] - public string PostalCode { get; set; } - [EncryptedString] - [EncryptedStringLength(1000)] - public string Country { get; set; } - [EncryptedString] - [EncryptedStringLength(1000)] - public string Company { get; set; } - [EncryptedString] - [EncryptedStringLength(1000)] - public string Email { get; set; } - [EncryptedString] - [EncryptedStringLength(1000)] - public string Phone { get; set; } - [EncryptedString] - [EncryptedStringLength(1000)] - public string SSN { get; set; } - [EncryptedString] - [EncryptedStringLength(1000)] - public string Username { get; set; } - [EncryptedString] - [EncryptedStringLength(1000)] - public string PassportNumber { get; set; } - [EncryptedString] - [EncryptedStringLength(1000)] - public string LicenseNumber { get; set; } + public CipherIdentityModel(CipherIdentityData data) + { + Title = data.Title; + FirstName = data.FirstName; + MiddleName = data.MiddleName; + LastName = data.LastName; + Address1 = data.Address1; + Address2 = data.Address2; + Address3 = data.Address3; + City = data.City; + State = data.State; + PostalCode = data.PostalCode; + Country = data.Country; + Company = data.Company; + Email = data.Email; + Phone = data.Phone; + SSN = data.SSN; + Username = data.Username; + PassportNumber = data.PassportNumber; + LicenseNumber = data.LicenseNumber; + } + + [EncryptedString] + [EncryptedStringLength(1000)] + public string Title { get; set; } + [EncryptedString] + [EncryptedStringLength(1000)] + public string FirstName { get; set; } + [EncryptedString] + [EncryptedStringLength(1000)] + public string MiddleName { get; set; } + [EncryptedString] + [EncryptedStringLength(1000)] + public string LastName { get; set; } + [EncryptedString] + [StringLength(1000)] + public string Address1 { get; set; } + [EncryptedString] + [EncryptedStringLength(1000)] + public string Address2 { get; set; } + [EncryptedString] + [EncryptedStringLength(1000)] + public string Address3 { get; set; } + [EncryptedString] + [EncryptedStringLength(1000)] + public string City { get; set; } + [EncryptedString] + [EncryptedStringLength(1000)] + public string State { get; set; } + [EncryptedString] + [EncryptedStringLength(1000)] + public string PostalCode { get; set; } + [EncryptedString] + [EncryptedStringLength(1000)] + public string Country { get; set; } + [EncryptedString] + [EncryptedStringLength(1000)] + public string Company { get; set; } + [EncryptedString] + [EncryptedStringLength(1000)] + public string Email { get; set; } + [EncryptedString] + [EncryptedStringLength(1000)] + public string Phone { get; set; } + [EncryptedString] + [EncryptedStringLength(1000)] + public string SSN { get; set; } + [EncryptedString] + [EncryptedStringLength(1000)] + public string Username { get; set; } + [EncryptedString] + [EncryptedStringLength(1000)] + public string PassportNumber { get; set; } + [EncryptedString] + [EncryptedStringLength(1000)] + public string LicenseNumber { get; set; } + } } diff --git a/src/Api/Models/CipherLoginModel.cs b/src/Api/Models/CipherLoginModel.cs index 134ca09cb5..156da6ba73 100644 --- a/src/Api/Models/CipherLoginModel.cs +++ b/src/Api/Models/CipherLoginModel.cs @@ -2,83 +2,84 @@ using Bit.Core.Models.Data; using Bit.Core.Utilities; -namespace Bit.Api.Models; - -public class CipherLoginModel +namespace Bit.Api.Models { - public CipherLoginModel() { } - - public CipherLoginModel(CipherLoginData data) + public class CipherLoginModel { - Uris = data.Uris?.Select(u => new CipherLoginUriModel(u))?.ToList(); - if (!Uris?.Any() ?? true) - { - Uri = data.Uri; - } + public CipherLoginModel() { } - Username = data.Username; - Password = data.Password; - PasswordRevisionDate = data.PasswordRevisionDate; - Totp = data.Totp; - AutofillOnPageLoad = data.AutofillOnPageLoad; - } - - [EncryptedString] - [EncryptedStringLength(10000)] - public string Uri - { - get => Uris?.FirstOrDefault()?.Uri; - set + public CipherLoginModel(CipherLoginData data) { - if (string.IsNullOrWhiteSpace(value)) + Uris = data.Uris?.Select(u => new CipherLoginUriModel(u))?.ToList(); + if (!Uris?.Any() ?? true) { - return; + Uri = data.Uri; } - if (Uris == null) - { - Uris = new List(); - } - - Uris.Add(new CipherLoginUriModel(value)); - } - } - public List Uris { get; set; } - [EncryptedString] - [EncryptedStringLength(1000)] - public string Username { get; set; } - [EncryptedString] - [EncryptedStringLength(5000)] - public string Password { get; set; } - public DateTime? PasswordRevisionDate { get; set; } - [EncryptedString] - [EncryptedStringLength(1000)] - public string Totp { get; set; } - public bool? AutofillOnPageLoad { get; set; } - - public class CipherLoginUriModel - { - public CipherLoginUriModel() { } - - public CipherLoginUriModel(string uri) - { - Uri = uri; - } - - public CipherLoginUriModel(CipherLoginData.CipherLoginUriData uri) - { - Uri = uri.Uri; - Match = uri.Match; + Username = data.Username; + Password = data.Password; + PasswordRevisionDate = data.PasswordRevisionDate; + Totp = data.Totp; + AutofillOnPageLoad = data.AutofillOnPageLoad; } [EncryptedString] [EncryptedStringLength(10000)] - public string Uri { get; set; } - public UriMatchType? Match { get; set; } = null; - - public CipherLoginData.CipherLoginUriData ToCipherLoginUriData() + public string Uri { - return new CipherLoginData.CipherLoginUriData { Uri = Uri, Match = Match, }; + get => Uris?.FirstOrDefault()?.Uri; + set + { + if (string.IsNullOrWhiteSpace(value)) + { + return; + } + + if (Uris == null) + { + Uris = new List(); + } + + Uris.Add(new CipherLoginUriModel(value)); + } + } + public List Uris { get; set; } + [EncryptedString] + [EncryptedStringLength(1000)] + public string Username { get; set; } + [EncryptedString] + [EncryptedStringLength(5000)] + public string Password { get; set; } + public DateTime? PasswordRevisionDate { get; set; } + [EncryptedString] + [EncryptedStringLength(1000)] + public string Totp { get; set; } + public bool? AutofillOnPageLoad { get; set; } + + public class CipherLoginUriModel + { + public CipherLoginUriModel() { } + + public CipherLoginUriModel(string uri) + { + Uri = uri; + } + + public CipherLoginUriModel(CipherLoginData.CipherLoginUriData uri) + { + Uri = uri.Uri; + Match = uri.Match; + } + + [EncryptedString] + [EncryptedStringLength(10000)] + public string Uri { get; set; } + public UriMatchType? Match { get; set; } = null; + + public CipherLoginData.CipherLoginUriData ToCipherLoginUriData() + { + return new CipherLoginData.CipherLoginUriData { Uri = Uri, Match = Match, }; + } } } } diff --git a/src/Api/Models/CipherPasswordHistoryModel.cs b/src/Api/Models/CipherPasswordHistoryModel.cs index 329c2cf272..bd9eb296f8 100644 --- a/src/Api/Models/CipherPasswordHistoryModel.cs +++ b/src/Api/Models/CipherPasswordHistoryModel.cs @@ -2,27 +2,28 @@ using Bit.Core.Models.Data; using Bit.Core.Utilities; -namespace Bit.Api.Models; - -public class CipherPasswordHistoryModel +namespace Bit.Api.Models { - public CipherPasswordHistoryModel() { } - - public CipherPasswordHistoryModel(CipherPasswordHistoryData data) + public class CipherPasswordHistoryModel { - Password = data.Password; - LastUsedDate = data.LastUsedDate; - } + public CipherPasswordHistoryModel() { } - [EncryptedString] - [EncryptedStringLength(5000)] - [Required] - public string Password { get; set; } - [Required] - public DateTime? LastUsedDate { get; set; } + public CipherPasswordHistoryModel(CipherPasswordHistoryData data) + { + Password = data.Password; + LastUsedDate = data.LastUsedDate; + } - public CipherPasswordHistoryData ToCipherPasswordHistoryData() - { - return new CipherPasswordHistoryData { Password = Password, LastUsedDate = LastUsedDate.Value, }; + [EncryptedString] + [EncryptedStringLength(5000)] + [Required] + public string Password { get; set; } + [Required] + public DateTime? LastUsedDate { get; set; } + + public CipherPasswordHistoryData ToCipherPasswordHistoryData() + { + return new CipherPasswordHistoryData { Password = Password, LastUsedDate = LastUsedDate.Value, }; + } } } diff --git a/src/Api/Models/CipherSecureNoteModel.cs b/src/Api/Models/CipherSecureNoteModel.cs index 5ab35d1e84..6ea63d299c 100644 --- a/src/Api/Models/CipherSecureNoteModel.cs +++ b/src/Api/Models/CipherSecureNoteModel.cs @@ -1,16 +1,17 @@ using Bit.Core.Enums; using Bit.Core.Models.Data; -namespace Bit.Api.Models; - -public class CipherSecureNoteModel +namespace Bit.Api.Models { - public CipherSecureNoteModel() { } - - public CipherSecureNoteModel(CipherSecureNoteData data) + public class CipherSecureNoteModel { - Type = data.Type; - } + public CipherSecureNoteModel() { } - public SecureNoteType Type { get; set; } + public CipherSecureNoteModel(CipherSecureNoteData data) + { + Type = data.Type; + } + + public SecureNoteType Type { get; set; } + } } diff --git a/src/Api/Models/Public/AssociationWithPermissionsBaseModel.cs b/src/Api/Models/Public/AssociationWithPermissionsBaseModel.cs index 014f67a04b..54a0a204f2 100644 --- a/src/Api/Models/Public/AssociationWithPermissionsBaseModel.cs +++ b/src/Api/Models/Public/AssociationWithPermissionsBaseModel.cs @@ -1,18 +1,19 @@ using System.ComponentModel.DataAnnotations; -namespace Bit.Api.Models.Public; - -public abstract class AssociationWithPermissionsBaseModel +namespace Bit.Api.Models.Public { - /// - /// The associated object's unique identifier. - /// - /// bfbc8338-e329-4dc0-b0c9-317c2ebf1a09 - [Required] - public Guid? Id { get; set; } - /// - /// When true, the read only permission will not allow the user or group to make changes to items. - /// - [Required] - public bool? ReadOnly { get; set; } + public abstract class AssociationWithPermissionsBaseModel + { + /// + /// The associated object's unique identifier. + /// + /// bfbc8338-e329-4dc0-b0c9-317c2ebf1a09 + [Required] + public Guid? Id { get; set; } + /// + /// When true, the read only permission will not allow the user or group to make changes to items. + /// + [Required] + public bool? ReadOnly { get; set; } + } } diff --git a/src/Api/Models/Public/CollectionBaseModel.cs b/src/Api/Models/Public/CollectionBaseModel.cs index 0dd4b6ce85..5c36ef9b4e 100644 --- a/src/Api/Models/Public/CollectionBaseModel.cs +++ b/src/Api/Models/Public/CollectionBaseModel.cs @@ -1,13 +1,14 @@ using System.ComponentModel.DataAnnotations; -namespace Bit.Api.Models.Public; - -public abstract class CollectionBaseModel +namespace Bit.Api.Models.Public { - /// - /// External identifier for reference or linking this collection to another system. - /// - /// external_id_123456 - [StringLength(300)] - public string ExternalId { get; set; } + public abstract class CollectionBaseModel + { + /// + /// External identifier for reference or linking this collection to another system. + /// + /// external_id_123456 + [StringLength(300)] + public string ExternalId { get; set; } + } } diff --git a/src/Api/Models/Public/GroupBaseModel.cs b/src/Api/Models/Public/GroupBaseModel.cs index 2b09e2952b..28b5ebe088 100644 --- a/src/Api/Models/Public/GroupBaseModel.cs +++ b/src/Api/Models/Public/GroupBaseModel.cs @@ -1,26 +1,27 @@ using System.ComponentModel.DataAnnotations; -namespace Bit.Api.Models.Public; - -public abstract class GroupBaseModel +namespace Bit.Api.Models.Public { - /// - /// The name of the group. - /// - /// Development Team - [Required] - [StringLength(100)] - public string Name { get; set; } - /// - /// Determines if this group can access all collections within the organization, or only the associated - /// collections. If set to true, this option overrides any collection assignments. - /// - [Required] - public bool? AccessAll { get; set; } - /// - /// External identifier for reference or linking this group to another system, such as a user directory. - /// - /// external_id_123456 - [StringLength(300)] - public string ExternalId { get; set; } + public abstract class GroupBaseModel + { + /// + /// The name of the group. + /// + /// Development Team + [Required] + [StringLength(100)] + public string Name { get; set; } + /// + /// Determines if this group can access all collections within the organization, or only the associated + /// collections. If set to true, this option overrides any collection assignments. + /// + [Required] + public bool? AccessAll { get; set; } + /// + /// External identifier for reference or linking this group to another system, such as a user directory. + /// + /// external_id_123456 + [StringLength(300)] + public string ExternalId { get; set; } + } } diff --git a/src/Api/Models/Public/MemberBaseModel.cs b/src/Api/Models/Public/MemberBaseModel.cs index af57d80645..47621cf180 100644 --- a/src/Api/Models/Public/MemberBaseModel.cs +++ b/src/Api/Models/Public/MemberBaseModel.cs @@ -3,58 +3,59 @@ using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Models.Data.Organizations.OrganizationUsers; -namespace Bit.Api.Models.Public; - -public abstract class MemberBaseModel +namespace Bit.Api.Models.Public { - public MemberBaseModel() { } - - public MemberBaseModel(OrganizationUser user) + public abstract class MemberBaseModel { - if (user == null) + public MemberBaseModel() { } + + public MemberBaseModel(OrganizationUser user) { - throw new ArgumentNullException(nameof(user)); + if (user == null) + { + throw new ArgumentNullException(nameof(user)); + } + + Type = user.Type; + AccessAll = user.AccessAll; + ExternalId = user.ExternalId; + ResetPasswordEnrolled = user.ResetPasswordKey != null; } - Type = user.Type; - AccessAll = user.AccessAll; - ExternalId = user.ExternalId; - ResetPasswordEnrolled = user.ResetPasswordKey != null; - } - - public MemberBaseModel(OrganizationUserUserDetails user) - { - if (user == null) + public MemberBaseModel(OrganizationUserUserDetails user) { - throw new ArgumentNullException(nameof(user)); + if (user == null) + { + throw new ArgumentNullException(nameof(user)); + } + + Type = user.Type; + AccessAll = user.AccessAll; + ExternalId = user.ExternalId; + ResetPasswordEnrolled = user.ResetPasswordKey != null; } - Type = user.Type; - AccessAll = user.AccessAll; - ExternalId = user.ExternalId; - ResetPasswordEnrolled = user.ResetPasswordKey != null; + /// + /// The member's type (or role) within the organization. + /// + [Required] + public OrganizationUserType? Type { get; set; } + /// + /// Determines if this member can access all collections within the organization, or only the associated + /// collections. If set to true, this option overrides any collection assignments. + /// + [Required] + public bool? AccessAll { get; set; } + /// + /// External identifier for reference or linking this member to another system, such as a user directory. + /// + /// external_id_123456 + [StringLength(300)] + public string ExternalId { get; set; } + /// + /// Returns true if the member has enrolled in Password Reset assistance within the organization + /// + [Required] + public bool ResetPasswordEnrolled { get; set; } } - - /// - /// The member's type (or role) within the organization. - /// - [Required] - public OrganizationUserType? Type { get; set; } - /// - /// Determines if this member can access all collections within the organization, or only the associated - /// collections. If set to true, this option overrides any collection assignments. - /// - [Required] - public bool? AccessAll { get; set; } - /// - /// External identifier for reference or linking this member to another system, such as a user directory. - /// - /// external_id_123456 - [StringLength(300)] - public string ExternalId { get; set; } - /// - /// Returns true if the member has enrolled in Password Reset assistance within the organization - /// - [Required] - public bool ResetPasswordEnrolled { get; set; } } diff --git a/src/Api/Models/Public/PolicyBaseModel.cs b/src/Api/Models/Public/PolicyBaseModel.cs index 2ad8e76005..1814c9f4a3 100644 --- a/src/Api/Models/Public/PolicyBaseModel.cs +++ b/src/Api/Models/Public/PolicyBaseModel.cs @@ -1,16 +1,17 @@ using System.ComponentModel.DataAnnotations; -namespace Bit.Api.Models.Public; - -public abstract class PolicyBaseModel +namespace Bit.Api.Models.Public { - /// - /// Determines if this policy is enabled and enforced. - /// - [Required] - public bool? Enabled { get; set; } - /// - /// Data for the policy. - /// - public Dictionary Data { get; set; } + public abstract class PolicyBaseModel + { + /// + /// Determines if this policy is enabled and enforced. + /// + [Required] + public bool? Enabled { get; set; } + /// + /// Data for the policy. + /// + public Dictionary Data { get; set; } + } } diff --git a/src/Api/Models/Public/Request/AssociationWithPermissionsRequestModel.cs b/src/Api/Models/Public/Request/AssociationWithPermissionsRequestModel.cs index b93b16e599..9a87760b92 100644 --- a/src/Api/Models/Public/Request/AssociationWithPermissionsRequestModel.cs +++ b/src/Api/Models/Public/Request/AssociationWithPermissionsRequestModel.cs @@ -1,15 +1,16 @@ using Bit.Core.Models.Data; -namespace Bit.Api.Models.Public.Request; - -public class AssociationWithPermissionsRequestModel : AssociationWithPermissionsBaseModel +namespace Bit.Api.Models.Public.Request { - public SelectionReadOnly ToSelectionReadOnly() + public class AssociationWithPermissionsRequestModel : AssociationWithPermissionsBaseModel { - return new SelectionReadOnly + public SelectionReadOnly ToSelectionReadOnly() { - Id = Id.Value, - ReadOnly = ReadOnly.Value - }; + return new SelectionReadOnly + { + Id = Id.Value, + ReadOnly = ReadOnly.Value + }; + } } } diff --git a/src/Api/Models/Public/Request/CollectionUpdateRequestModel.cs b/src/Api/Models/Public/Request/CollectionUpdateRequestModel.cs index f38d1fec7d..36b77137dd 100644 --- a/src/Api/Models/Public/Request/CollectionUpdateRequestModel.cs +++ b/src/Api/Models/Public/Request/CollectionUpdateRequestModel.cs @@ -1,17 +1,18 @@ using Bit.Core.Entities; -namespace Bit.Api.Models.Public.Request; - -public class CollectionUpdateRequestModel : CollectionBaseModel +namespace Bit.Api.Models.Public.Request { - /// - /// The associated groups that this collection is assigned to. - /// - public IEnumerable Groups { get; set; } - - public Collection ToCollection(Collection existingCollection) + public class CollectionUpdateRequestModel : CollectionBaseModel { - existingCollection.ExternalId = ExternalId; - return existingCollection; + /// + /// The associated groups that this collection is assigned to. + /// + public IEnumerable Groups { get; set; } + + public Collection ToCollection(Collection existingCollection) + { + existingCollection.ExternalId = ExternalId; + return existingCollection; + } } } diff --git a/src/Api/Models/Public/Request/EventFilterRequestModel.cs b/src/Api/Models/Public/Request/EventFilterRequestModel.cs index 852076eebc..74a1700a75 100644 --- a/src/Api/Models/Public/Request/EventFilterRequestModel.cs +++ b/src/Api/Models/Public/Request/EventFilterRequestModel.cs @@ -1,49 +1,50 @@ using Bit.Core.Exceptions; -namespace Bit.Api.Models.Public.Request; - -public class EventFilterRequestModel +namespace Bit.Api.Models.Public.Request { - /// - /// The start date. Must be less than the end date. - /// - public DateTime? Start { get; set; } - /// - /// The end date. Must be greater than the start date. - /// - public DateTime? End { get; set; } - /// - /// The unique identifier of the user that performed the event. - /// - public Guid? ActingUserId { get; set; } - /// - /// The unique identifier of the related item that the event describes. - /// - public Guid? ItemId { get; set; } - /// - /// A cursor for use in pagination. - /// - public string ContinuationToken { get; set; } - - public Tuple ToDateRange() + public class EventFilterRequestModel { - if (!End.HasValue || !Start.HasValue) - { - End = DateTime.UtcNow.Date.AddDays(1).AddMilliseconds(-1); - Start = DateTime.UtcNow.Date.AddDays(-30); - } - else if (Start.Value > End.Value) - { - var newEnd = Start; - Start = End; - End = newEnd; - } + /// + /// The start date. Must be less than the end date. + /// + public DateTime? Start { get; set; } + /// + /// The end date. Must be greater than the start date. + /// + public DateTime? End { get; set; } + /// + /// The unique identifier of the user that performed the event. + /// + public Guid? ActingUserId { get; set; } + /// + /// The unique identifier of the related item that the event describes. + /// + public Guid? ItemId { get; set; } + /// + /// A cursor for use in pagination. + /// + public string ContinuationToken { get; set; } - if ((End.Value - Start.Value) > TimeSpan.FromDays(367)) + public Tuple ToDateRange() { - throw new BadRequestException("Date range must be < 367 days."); - } + if (!End.HasValue || !Start.HasValue) + { + End = DateTime.UtcNow.Date.AddDays(1).AddMilliseconds(-1); + Start = DateTime.UtcNow.Date.AddDays(-30); + } + else if (Start.Value > End.Value) + { + var newEnd = Start; + Start = End; + End = newEnd; + } - return new Tuple(Start.Value, End.Value); + if ((End.Value - Start.Value) > TimeSpan.FromDays(367)) + { + throw new BadRequestException("Date range must be < 367 days."); + } + + return new Tuple(Start.Value, End.Value); + } } } diff --git a/src/Api/Models/Public/Request/GroupCreateUpdateRequestModel.cs b/src/Api/Models/Public/Request/GroupCreateUpdateRequestModel.cs index 9b8193b07e..12e7d4489d 100644 --- a/src/Api/Models/Public/Request/GroupCreateUpdateRequestModel.cs +++ b/src/Api/Models/Public/Request/GroupCreateUpdateRequestModel.cs @@ -1,27 +1,28 @@ using Bit.Core.Entities; -namespace Bit.Api.Models.Public.Request; - -public class GroupCreateUpdateRequestModel : GroupBaseModel +namespace Bit.Api.Models.Public.Request { - /// - /// The associated collections that this group can access. - /// - public IEnumerable Collections { get; set; } - - public Group ToGroup(Guid orgId) + public class GroupCreateUpdateRequestModel : GroupBaseModel { - return ToGroup(new Group + /// + /// The associated collections that this group can access. + /// + public IEnumerable Collections { get; set; } + + public Group ToGroup(Guid orgId) { - OrganizationId = orgId - }); - } + return ToGroup(new Group + { + OrganizationId = orgId + }); + } - public Group ToGroup(Group existingGroup) - { - existingGroup.Name = Name; - existingGroup.AccessAll = AccessAll.Value; - existingGroup.ExternalId = ExternalId; - return existingGroup; + public Group ToGroup(Group existingGroup) + { + existingGroup.Name = Name; + existingGroup.AccessAll = AccessAll.Value; + existingGroup.ExternalId = ExternalId; + return existingGroup; + } } } diff --git a/src/Api/Models/Public/Request/MemberCreateRequestModel.cs b/src/Api/Models/Public/Request/MemberCreateRequestModel.cs index 447434e470..1845fee22f 100644 --- a/src/Api/Models/Public/Request/MemberCreateRequestModel.cs +++ b/src/Api/Models/Public/Request/MemberCreateRequestModel.cs @@ -2,21 +2,22 @@ using Bit.Core.Entities; using Bit.Core.Utilities; -namespace Bit.Api.Models.Public.Request; - -public class MemberCreateRequestModel : MemberUpdateRequestModel +namespace Bit.Api.Models.Public.Request { - /// - /// The member's email address. - /// - /// jsmith@example.com - [Required] - [StringLength(256)] - [StrictEmailAddress] - public string Email { get; set; } - - public override OrganizationUser ToOrganizationUser(OrganizationUser existingUser) + public class MemberCreateRequestModel : MemberUpdateRequestModel { - throw new NotImplementedException(); + /// + /// The member's email address. + /// + /// jsmith@example.com + [Required] + [StringLength(256)] + [StrictEmailAddress] + public string Email { get; set; } + + public override OrganizationUser ToOrganizationUser(OrganizationUser existingUser) + { + throw new NotImplementedException(); + } } } diff --git a/src/Api/Models/Public/Request/MemberUpdateRequestModel.cs b/src/Api/Models/Public/Request/MemberUpdateRequestModel.cs index 6b5881186c..44a07f5262 100644 --- a/src/Api/Models/Public/Request/MemberUpdateRequestModel.cs +++ b/src/Api/Models/Public/Request/MemberUpdateRequestModel.cs @@ -1,19 +1,20 @@ using Bit.Core.Entities; -namespace Bit.Api.Models.Public.Request; - -public class MemberUpdateRequestModel : MemberBaseModel +namespace Bit.Api.Models.Public.Request { - /// - /// The associated collections that this member can access. - /// - public IEnumerable Collections { get; set; } - - public virtual OrganizationUser ToOrganizationUser(OrganizationUser existingUser) + public class MemberUpdateRequestModel : MemberBaseModel { - existingUser.Type = Type.Value; - existingUser.AccessAll = AccessAll.Value; - existingUser.ExternalId = ExternalId; - return existingUser; + /// + /// The associated collections that this member can access. + /// + public IEnumerable Collections { get; set; } + + public virtual OrganizationUser ToOrganizationUser(OrganizationUser existingUser) + { + existingUser.Type = Type.Value; + existingUser.AccessAll = AccessAll.Value; + existingUser.ExternalId = ExternalId; + return existingUser; + } } } diff --git a/src/Api/Models/Public/Request/OrganizationImportRequestModel.cs b/src/Api/Models/Public/Request/OrganizationImportRequestModel.cs index 70bf649a25..2b2177b48d 100644 --- a/src/Api/Models/Public/Request/OrganizationImportRequestModel.cs +++ b/src/Api/Models/Public/Request/OrganizationImportRequestModel.cs @@ -4,107 +4,108 @@ using Bit.Core.Entities; using Bit.Core.Models.Business; using Bit.Core.Utilities; -namespace Bit.Api.Models.Public.Request; - -public class OrganizationImportRequestModel +namespace Bit.Api.Models.Public.Request { - /// - /// Groups to import. - /// - public OrganizationImportGroupRequestModel[] Groups { get; set; } - /// - /// Members to import. - /// - public OrganizationImportMemberRequestModel[] Members { get; set; } - /// - /// Determines if the data in this request should overwrite or append to the existing organization data. - /// - [Required] - public bool? OverwriteExisting { get; set; } - /// - /// Indicates an import of over 2000 users and/or groups is expected - /// - public bool LargeImport { get; set; } = false; - - public class OrganizationImportGroupRequestModel + public class OrganizationImportRequestModel { /// - /// The name of the group. + /// Groups to import. /// - /// Development Team - [Required] - [StringLength(100)] - public string Name { get; set; } + public OrganizationImportGroupRequestModel[] Groups { get; set; } /// - /// External identifier for reference or linking this group to another system, such as a user directory. + /// Members to import. /// - /// external_id_123456 - [Required] - [StringLength(300)] - [JsonConverter(typeof(PermissiveStringConverter))] - public string ExternalId { get; set; } + public OrganizationImportMemberRequestModel[] Members { get; set; } /// - /// The associated external ids for members in this group. + /// Determines if the data in this request should overwrite or append to the existing organization data. /// - [JsonConverter(typeof(PermissiveStringEnumerableConverter))] - public IEnumerable MemberExternalIds { get; set; } + [Required] + public bool? OverwriteExisting { get; set; } + /// + /// Indicates an import of over 2000 users and/or groups is expected + /// + public bool LargeImport { get; set; } = false; - public ImportedGroup ToImportedGroup(Guid organizationId) + public class OrganizationImportGroupRequestModel { - var importedGroup = new ImportedGroup + /// + /// The name of the group. + /// + /// Development Team + [Required] + [StringLength(100)] + public string Name { get; set; } + /// + /// External identifier for reference or linking this group to another system, such as a user directory. + /// + /// external_id_123456 + [Required] + [StringLength(300)] + [JsonConverter(typeof(PermissiveStringConverter))] + public string ExternalId { get; set; } + /// + /// The associated external ids for members in this group. + /// + [JsonConverter(typeof(PermissiveStringEnumerableConverter))] + public IEnumerable MemberExternalIds { get; set; } + + public ImportedGroup ToImportedGroup(Guid organizationId) { - Group = new Group + var importedGroup = new ImportedGroup { - OrganizationId = organizationId, - Name = Name, + Group = new Group + { + OrganizationId = organizationId, + Name = Name, + ExternalId = ExternalId + }, + ExternalUserIds = new HashSet(MemberExternalIds) + }; + + return importedGroup; + } + } + + public class OrganizationImportMemberRequestModel : IValidatableObject + { + /// + /// The member's email address. Required for non-deleted users. + /// + /// jsmith@example.com + [EmailAddress] + [StringLength(256)] + public string Email { get; set; } + /// + /// External identifier for reference or linking this member to another system, such as a user directory. + /// + /// external_id_123456 + [Required] + [StringLength(300)] + [JsonConverter(typeof(PermissiveStringConverter))] + public string ExternalId { get; set; } + /// + /// Determines if this member should be removed from the organization during import. + /// + public bool Deleted { get; set; } + + public ImportedOrganizationUser ToImportedOrganizationUser() + { + var importedUser = new ImportedOrganizationUser + { + Email = Email.ToLowerInvariant(), ExternalId = ExternalId - }, - ExternalUserIds = new HashSet(MemberExternalIds) - }; + }; - return importedGroup; - } - } + return importedUser; + } - public class OrganizationImportMemberRequestModel : IValidatableObject - { - /// - /// The member's email address. Required for non-deleted users. - /// - /// jsmith@example.com - [EmailAddress] - [StringLength(256)] - public string Email { get; set; } - /// - /// External identifier for reference or linking this member to another system, such as a user directory. - /// - /// external_id_123456 - [Required] - [StringLength(300)] - [JsonConverter(typeof(PermissiveStringConverter))] - public string ExternalId { get; set; } - /// - /// Determines if this member should be removed from the organization during import. - /// - public bool Deleted { get; set; } - - public ImportedOrganizationUser ToImportedOrganizationUser() - { - var importedUser = new ImportedOrganizationUser + public IEnumerable Validate(ValidationContext validationContext) { - Email = Email.ToLowerInvariant(), - ExternalId = ExternalId - }; - - return importedUser; - } - - public IEnumerable Validate(ValidationContext validationContext) - { - if (string.IsNullOrWhiteSpace(Email) && !Deleted) - { - yield return new ValidationResult("Email is required for enabled members.", - new string[] { nameof(Email) }); + if (string.IsNullOrWhiteSpace(Email) && !Deleted) + { + yield return new ValidationResult("Email is required for enabled members.", + new string[] { nameof(Email) }); + } } } } diff --git a/src/Api/Models/Public/Request/PolicyUpdateRequestModel.cs b/src/Api/Models/Public/Request/PolicyUpdateRequestModel.cs index 251b9358d1..c563ca9d65 100644 --- a/src/Api/Models/Public/Request/PolicyUpdateRequestModel.cs +++ b/src/Api/Models/Public/Request/PolicyUpdateRequestModel.cs @@ -1,22 +1,23 @@ using System.Text.Json; using Bit.Core.Entities; -namespace Bit.Api.Models.Public.Request; - -public class PolicyUpdateRequestModel : PolicyBaseModel +namespace Bit.Api.Models.Public.Request { - public Policy ToPolicy(Guid orgId) + public class PolicyUpdateRequestModel : PolicyBaseModel { - return ToPolicy(new Policy + public Policy ToPolicy(Guid orgId) { - OrganizationId = orgId - }); - } + return ToPolicy(new Policy + { + OrganizationId = orgId + }); + } - public virtual Policy ToPolicy(Policy existingPolicy) - { - existingPolicy.Enabled = Enabled.GetValueOrDefault(); - existingPolicy.Data = Data != null ? JsonSerializer.Serialize(Data) : null; - return existingPolicy; + public virtual Policy ToPolicy(Policy existingPolicy) + { + existingPolicy.Enabled = Enabled.GetValueOrDefault(); + existingPolicy.Data = Data != null ? JsonSerializer.Serialize(Data) : null; + return existingPolicy; + } } } diff --git a/src/Api/Models/Public/Request/UpdateGroupIdsRequestModel.cs b/src/Api/Models/Public/Request/UpdateGroupIdsRequestModel.cs index 7a818e5bbc..a691777aac 100644 --- a/src/Api/Models/Public/Request/UpdateGroupIdsRequestModel.cs +++ b/src/Api/Models/Public/Request/UpdateGroupIdsRequestModel.cs @@ -1,9 +1,10 @@ -namespace Bit.Api.Models.Public.Request; - -public class UpdateGroupIdsRequestModel +namespace Bit.Api.Models.Public.Request { - /// - /// The associated group ids that this object can access. - /// - public IEnumerable GroupIds { get; set; } + public class UpdateGroupIdsRequestModel + { + /// + /// The associated group ids that this object can access. + /// + public IEnumerable GroupIds { get; set; } + } } diff --git a/src/Api/Models/Public/Request/UpdateMemberIdsRequestModel.cs b/src/Api/Models/Public/Request/UpdateMemberIdsRequestModel.cs index 87a2418318..03ea89eacf 100644 --- a/src/Api/Models/Public/Request/UpdateMemberIdsRequestModel.cs +++ b/src/Api/Models/Public/Request/UpdateMemberIdsRequestModel.cs @@ -1,9 +1,10 @@ -namespace Bit.Api.Models.Public.Request; - -public class UpdateMemberIdsRequestModel +namespace Bit.Api.Models.Public.Request { - /// - /// The associated member ids that have access to this object. - /// - public IEnumerable MemberIds { get; set; } + public class UpdateMemberIdsRequestModel + { + /// + /// The associated member ids that have access to this object. + /// + public IEnumerable MemberIds { get; set; } + } } diff --git a/src/Api/Models/Public/Response/AssociationWithPermissionsResponseModel.cs b/src/Api/Models/Public/Response/AssociationWithPermissionsResponseModel.cs index 04863d9b46..823b359046 100644 --- a/src/Api/Models/Public/Response/AssociationWithPermissionsResponseModel.cs +++ b/src/Api/Models/Public/Response/AssociationWithPermissionsResponseModel.cs @@ -1,16 +1,17 @@ using Bit.Core.Models.Data; -namespace Bit.Api.Models.Public.Response; - -public class AssociationWithPermissionsResponseModel : AssociationWithPermissionsBaseModel +namespace Bit.Api.Models.Public.Response { - public AssociationWithPermissionsResponseModel(SelectionReadOnly selection) + public class AssociationWithPermissionsResponseModel : AssociationWithPermissionsBaseModel { - if (selection == null) + public AssociationWithPermissionsResponseModel(SelectionReadOnly selection) { - throw new ArgumentNullException(nameof(selection)); + if (selection == null) + { + throw new ArgumentNullException(nameof(selection)); + } + Id = selection.Id; + ReadOnly = selection.ReadOnly; } - Id = selection.Id; - ReadOnly = selection.ReadOnly; } } diff --git a/src/Api/Models/Public/Response/CollectionResponseModel.cs b/src/Api/Models/Public/Response/CollectionResponseModel.cs index 93e484801d..8e318e585f 100644 --- a/src/Api/Models/Public/Response/CollectionResponseModel.cs +++ b/src/Api/Models/Public/Response/CollectionResponseModel.cs @@ -2,39 +2,40 @@ using Bit.Core.Entities; using Bit.Core.Models.Data; -namespace Bit.Api.Models.Public.Response; - -/// -/// A collection. -/// -public class CollectionResponseModel : CollectionBaseModel, IResponseModel +namespace Bit.Api.Models.Public.Response { - public CollectionResponseModel(Collection collection, IEnumerable groups) + /// + /// A collection. + /// + public class CollectionResponseModel : CollectionBaseModel, IResponseModel { - if (collection == null) + public CollectionResponseModel(Collection collection, IEnumerable groups) { - throw new ArgumentNullException(nameof(collection)); + if (collection == null) + { + throw new ArgumentNullException(nameof(collection)); + } + + Id = collection.Id; + ExternalId = collection.ExternalId; + Groups = groups?.Select(c => new AssociationWithPermissionsResponseModel(c)); } - Id = collection.Id; - ExternalId = collection.ExternalId; - Groups = groups?.Select(c => new AssociationWithPermissionsResponseModel(c)); + /// + /// String representing the object's type. Objects of the same type share the same properties. + /// + /// collection + [Required] + public string Object => "collection"; + /// + /// The collection's unique identifier. + /// + /// 539a36c5-e0d2-4cf9-979e-51ecf5cf6593 + [Required] + public Guid Id { get; set; } + /// + /// The associated groups that this collection is assigned to. + /// + public IEnumerable Groups { get; set; } } - - /// - /// String representing the object's type. Objects of the same type share the same properties. - /// - /// collection - [Required] - public string Object => "collection"; - /// - /// The collection's unique identifier. - /// - /// 539a36c5-e0d2-4cf9-979e-51ecf5cf6593 - [Required] - public Guid Id { get; set; } - /// - /// The associated groups that this collection is assigned to. - /// - public IEnumerable Groups { get; set; } } diff --git a/src/Api/Models/Public/Response/ErrorResponseModel.cs b/src/Api/Models/Public/Response/ErrorResponseModel.cs index 4a4887a0e7..dd2f2ba0ee 100644 --- a/src/Api/Models/Public/Response/ErrorResponseModel.cs +++ b/src/Api/Models/Public/Response/ErrorResponseModel.cs @@ -1,76 +1,77 @@ using System.ComponentModel.DataAnnotations; using Microsoft.AspNetCore.Mvc.ModelBinding; -namespace Bit.Api.Models.Public.Response; - -public class ErrorResponseModel : IResponseModel +namespace Bit.Api.Models.Public.Response { - public ErrorResponseModel(string message) + public class ErrorResponseModel : IResponseModel { - Message = message; - } - - public ErrorResponseModel(ModelStateDictionary modelState) - { - Message = "The request's model state is invalid."; - Errors = new Dictionary>(); - - var keys = modelState.Keys.ToList(); - var values = modelState.Values.ToList(); - - for (var i = 0; i < values.Count; i++) + public ErrorResponseModel(string message) { - var value = values[i]; - if (keys.Count <= i) - { - // Keys not available for some reason. - break; - } - - var key = keys[i]; - if (value.ValidationState != ModelValidationState.Invalid || value.Errors.Count == 0) - { - continue; - } - - var errors = value.Errors.Select(e => e.ErrorMessage); - Errors.Add(key, errors); + Message = message; } + + public ErrorResponseModel(ModelStateDictionary modelState) + { + Message = "The request's model state is invalid."; + Errors = new Dictionary>(); + + var keys = modelState.Keys.ToList(); + var values = modelState.Values.ToList(); + + for (var i = 0; i < values.Count; i++) + { + var value = values[i]; + if (keys.Count <= i) + { + // Keys not available for some reason. + break; + } + + var key = keys[i]; + if (value.ValidationState != ModelValidationState.Invalid || value.Errors.Count == 0) + { + continue; + } + + var errors = value.Errors.Select(e => e.ErrorMessage); + Errors.Add(key, errors); + } + } + + public ErrorResponseModel(Dictionary> errors) + : this("Errors have occurred.", errors) + { } + + public ErrorResponseModel(string errorKey, string errorValue) + : this(errorKey, new string[] { errorValue }) + { } + + public ErrorResponseModel(string errorKey, IEnumerable errorValues) + : this(new Dictionary> { { errorKey, errorValues } }) + { } + + public ErrorResponseModel(string message, Dictionary> errors) + { + Message = message; + Errors = errors; + } + + /// + /// String representing the object's type. Objects of the same type share the same properties. + /// + /// error + [Required] + public string Object => "error"; + /// + /// A human-readable message providing details about the error. + /// + /// The request model is invalid. + [Required] + public string Message { get; set; } + /// + /// If multiple errors occurred, they are listed in dictionary. Errors related to a specific + /// request parameter will include a dictionary key describing that parameter. + /// + public Dictionary> Errors { get; set; } } - - public ErrorResponseModel(Dictionary> errors) - : this("Errors have occurred.", errors) - { } - - public ErrorResponseModel(string errorKey, string errorValue) - : this(errorKey, new string[] { errorValue }) - { } - - public ErrorResponseModel(string errorKey, IEnumerable errorValues) - : this(new Dictionary> { { errorKey, errorValues } }) - { } - - public ErrorResponseModel(string message, Dictionary> errors) - { - Message = message; - Errors = errors; - } - - /// - /// String representing the object's type. Objects of the same type share the same properties. - /// - /// error - [Required] - public string Object => "error"; - /// - /// A human-readable message providing details about the error. - /// - /// The request model is invalid. - [Required] - public string Message { get; set; } - /// - /// If multiple errors occurred, they are listed in dictionary. Errors related to a specific - /// request parameter will include a dictionary key describing that parameter. - /// - public Dictionary> Errors { get; set; } } diff --git a/src/Api/Models/Public/Response/EventResponseModel.cs b/src/Api/Models/Public/Response/EventResponseModel.cs index bc8b77e491..4a5f9f6527 100644 --- a/src/Api/Models/Public/Response/EventResponseModel.cs +++ b/src/Api/Models/Public/Response/EventResponseModel.cs @@ -2,91 +2,92 @@ using Bit.Core.Enums; using Bit.Core.Models.Data; -namespace Bit.Api.Models.Public.Response; - -/// -/// An event log. -/// -public class EventResponseModel : IResponseModel +namespace Bit.Api.Models.Public.Response { - public EventResponseModel(IEvent ev) + /// + /// An event log. + /// + public class EventResponseModel : IResponseModel { - if (ev == null) + public EventResponseModel(IEvent ev) { - throw new ArgumentNullException(nameof(ev)); + if (ev == null) + { + throw new ArgumentNullException(nameof(ev)); + } + + Type = ev.Type; + ItemId = ev.CipherId; + CollectionId = ev.CollectionId; + GroupId = ev.GroupId; + PolicyId = ev.PolicyId; + MemberId = ev.OrganizationUserId; + ActingUserId = ev.ActingUserId; + Date = ev.Date; + Device = ev.DeviceType; + IpAddress = ev.IpAddress; + InstallationId = ev.InstallationId; } - Type = ev.Type; - ItemId = ev.CipherId; - CollectionId = ev.CollectionId; - GroupId = ev.GroupId; - PolicyId = ev.PolicyId; - MemberId = ev.OrganizationUserId; - ActingUserId = ev.ActingUserId; - Date = ev.Date; - Device = ev.DeviceType; - IpAddress = ev.IpAddress; - InstallationId = ev.InstallationId; + /// + /// String representing the object's type. Objects of the same type share the same properties. + /// + /// event + [Required] + public string Object => "event"; + /// + /// The type of event. + /// + [Required] + public EventType Type { get; set; } + /// + /// The unique identifier of the related item that the event describes. + /// + /// 3767a302-8208-4dc6-b842-030428a1cfad + public Guid? ItemId { get; set; } + /// + /// The unique identifier of the related collection that the event describes. + /// + /// bce212a4-25f3-4888-8a0a-4c5736d851e0 + public Guid? CollectionId { get; set; } + /// + /// The unique identifier of the related group that the event describes. + /// + /// f29a2515-91d2-4452-b49b-5e8040e6b0f4 + public Guid? GroupId { get; set; } + /// + /// The unique identifier of the related policy that the event describes. + /// + /// f29a2515-91d2-4452-b49b-5e8040e6b0f4 + public Guid? PolicyId { get; set; } + /// + /// The unique identifier of the related member that the event describes. + /// + /// e68b8629-85eb-4929-92c0-b84464976ba4 + public Guid? MemberId { get; set; } + /// + /// The unique identifier of the user that performed the event. + /// + /// a2549f79-a71f-4eb9-9234-eb7247333f94 + public Guid? ActingUserId { get; set; } + /// + /// The Unique identifier of the Installation that performed the event. + /// + /// + public Guid? InstallationId { get; set; } + /// + /// The date/timestamp when the event occurred. + /// + [Required] + public DateTime Date { get; set; } + /// + /// The type of device used by the acting user when the event occurred. + /// + public DeviceType? Device { get; set; } + /// + /// The IP address of the acting user. + /// + /// 172.16.254.1 + public string IpAddress { get; set; } } - - /// - /// String representing the object's type. Objects of the same type share the same properties. - /// - /// event - [Required] - public string Object => "event"; - /// - /// The type of event. - /// - [Required] - public EventType Type { get; set; } - /// - /// The unique identifier of the related item that the event describes. - /// - /// 3767a302-8208-4dc6-b842-030428a1cfad - public Guid? ItemId { get; set; } - /// - /// The unique identifier of the related collection that the event describes. - /// - /// bce212a4-25f3-4888-8a0a-4c5736d851e0 - public Guid? CollectionId { get; set; } - /// - /// The unique identifier of the related group that the event describes. - /// - /// f29a2515-91d2-4452-b49b-5e8040e6b0f4 - public Guid? GroupId { get; set; } - /// - /// The unique identifier of the related policy that the event describes. - /// - /// f29a2515-91d2-4452-b49b-5e8040e6b0f4 - public Guid? PolicyId { get; set; } - /// - /// The unique identifier of the related member that the event describes. - /// - /// e68b8629-85eb-4929-92c0-b84464976ba4 - public Guid? MemberId { get; set; } - /// - /// The unique identifier of the user that performed the event. - /// - /// a2549f79-a71f-4eb9-9234-eb7247333f94 - public Guid? ActingUserId { get; set; } - /// - /// The Unique identifier of the Installation that performed the event. - /// - /// - public Guid? InstallationId { get; set; } - /// - /// The date/timestamp when the event occurred. - /// - [Required] - public DateTime Date { get; set; } - /// - /// The type of device used by the acting user when the event occurred. - /// - public DeviceType? Device { get; set; } - /// - /// The IP address of the acting user. - /// - /// 172.16.254.1 - public string IpAddress { get; set; } } diff --git a/src/Api/Models/Public/Response/GroupResponseModel.cs b/src/Api/Models/Public/Response/GroupResponseModel.cs index c2e8df4bee..4c6a76c8fc 100644 --- a/src/Api/Models/Public/Response/GroupResponseModel.cs +++ b/src/Api/Models/Public/Response/GroupResponseModel.cs @@ -2,41 +2,42 @@ using Bit.Core.Entities; using Bit.Core.Models.Data; -namespace Bit.Api.Models.Public.Response; - -/// -/// A user group. -/// -public class GroupResponseModel : GroupBaseModel, IResponseModel +namespace Bit.Api.Models.Public.Response { - public GroupResponseModel(Group group, IEnumerable collections) + /// + /// A user group. + /// + public class GroupResponseModel : GroupBaseModel, IResponseModel { - if (group == null) + public GroupResponseModel(Group group, IEnumerable collections) { - throw new ArgumentNullException(nameof(group)); + if (group == null) + { + throw new ArgumentNullException(nameof(group)); + } + + Id = group.Id; + Name = group.Name; + AccessAll = group.AccessAll; + ExternalId = group.ExternalId; + Collections = collections?.Select(c => new AssociationWithPermissionsResponseModel(c)); } - Id = group.Id; - Name = group.Name; - AccessAll = group.AccessAll; - ExternalId = group.ExternalId; - Collections = collections?.Select(c => new AssociationWithPermissionsResponseModel(c)); + /// + /// String representing the object's type. Objects of the same type share the same properties. + /// + /// group + [Required] + public string Object => "group"; + /// + /// The group's unique identifier. + /// + /// 539a36c5-e0d2-4cf9-979e-51ecf5cf6593 + [Required] + public Guid Id { get; set; } + /// + /// The associated collections that this group can access. + /// + public IEnumerable Collections { get; set; } } - - /// - /// String representing the object's type. Objects of the same type share the same properties. - /// - /// group - [Required] - public string Object => "group"; - /// - /// The group's unique identifier. - /// - /// 539a36c5-e0d2-4cf9-979e-51ecf5cf6593 - [Required] - public Guid Id { get; set; } - /// - /// The associated collections that this group can access. - /// - public IEnumerable Collections { get; set; } } diff --git a/src/Api/Models/Public/Response/IResponseModel.cs b/src/Api/Models/Public/Response/IResponseModel.cs index 1032f52767..3e33330733 100644 --- a/src/Api/Models/Public/Response/IResponseModel.cs +++ b/src/Api/Models/Public/Response/IResponseModel.cs @@ -1,6 +1,7 @@ -namespace Bit.Api.Models.Public.Response; - -public interface IResponseModel +namespace Bit.Api.Models.Public.Response { - string Object { get; } + public interface IResponseModel + { + string Object { get; } + } } diff --git a/src/Api/Models/Public/Response/ListResponseModel.cs b/src/Api/Models/Public/Response/ListResponseModel.cs index 0865be3e8e..78328c3e1e 100644 --- a/src/Api/Models/Public/Response/ListResponseModel.cs +++ b/src/Api/Models/Public/Response/ListResponseModel.cs @@ -1,28 +1,29 @@ using System.ComponentModel.DataAnnotations; -namespace Bit.Api.Models.Public.Response; - -public class ListResponseModel : IResponseModel where T : IResponseModel +namespace Bit.Api.Models.Public.Response { - public ListResponseModel(IEnumerable data, string continuationToken = null) + public class ListResponseModel : IResponseModel where T : IResponseModel { - Data = data; - ContinuationToken = continuationToken; - } + public ListResponseModel(IEnumerable data, string continuationToken = null) + { + Data = data; + ContinuationToken = continuationToken; + } - /// - /// String representing the object's type. Objects of the same type share the same properties. - /// - /// list - [Required] - public string Object => "list"; - /// - /// An array containing the actual response elements, paginated by any request parameters. - /// - [Required] - public IEnumerable Data { get; set; } - /// - /// A cursor for use in pagination. - /// - public string ContinuationToken { get; set; } + /// + /// String representing the object's type. Objects of the same type share the same properties. + /// + /// list + [Required] + public string Object => "list"; + /// + /// An array containing the actual response elements, paginated by any request parameters. + /// + [Required] + public IEnumerable Data { get; set; } + /// + /// A cursor for use in pagination. + /// + public string ContinuationToken { get; set; } + } } diff --git a/src/Api/Models/Public/Response/MemberResponseModel.cs b/src/Api/Models/Public/Response/MemberResponseModel.cs index ccb8a8c953..ceac9fca25 100644 --- a/src/Api/Models/Public/Response/MemberResponseModel.cs +++ b/src/Api/Models/Public/Response/MemberResponseModel.cs @@ -4,90 +4,91 @@ using Bit.Core.Enums; using Bit.Core.Models.Data; using Bit.Core.Models.Data.Organizations.OrganizationUsers; -namespace Bit.Api.Models.Public.Response; - -/// -/// An organization member. -/// -public class MemberResponseModel : MemberBaseModel, IResponseModel +namespace Bit.Api.Models.Public.Response { - public MemberResponseModel(OrganizationUser user, IEnumerable collections) - : base(user) + /// + /// An organization member. + /// + public class MemberResponseModel : MemberBaseModel, IResponseModel { - if (user == null) + public MemberResponseModel(OrganizationUser user, IEnumerable collections) + : base(user) { - throw new ArgumentNullException(nameof(user)); + if (user == null) + { + throw new ArgumentNullException(nameof(user)); + } + + Id = user.Id; + UserId = user.UserId; + Email = user.Email; + Status = user.Status; + Collections = collections?.Select(c => new AssociationWithPermissionsResponseModel(c)); } - Id = user.Id; - UserId = user.UserId; - Email = user.Email; - Status = user.Status; - Collections = collections?.Select(c => new AssociationWithPermissionsResponseModel(c)); - } - - public MemberResponseModel(OrganizationUserUserDetails user, bool twoFactorEnabled, - IEnumerable collections) - : base(user) - { - if (user == null) + public MemberResponseModel(OrganizationUserUserDetails user, bool twoFactorEnabled, + IEnumerable collections) + : base(user) { - throw new ArgumentNullException(nameof(user)); + if (user == null) + { + throw new ArgumentNullException(nameof(user)); + } + + Id = user.Id; + UserId = user.UserId; + Name = user.Name; + Email = user.Email; + TwoFactorEnabled = twoFactorEnabled; + Status = user.Status; + Collections = collections?.Select(c => new AssociationWithPermissionsResponseModel(c)); } - Id = user.Id; - UserId = user.UserId; - Name = user.Name; - Email = user.Email; - TwoFactorEnabled = twoFactorEnabled; - Status = user.Status; - Collections = collections?.Select(c => new AssociationWithPermissionsResponseModel(c)); + /// + /// String representing the object's type. Objects of the same type share the same properties. + /// + /// member + [Required] + public string Object => "member"; + /// + /// The member's unique identifier within the organization. + /// + /// 539a36c5-e0d2-4cf9-979e-51ecf5cf6593 + [Required] + public Guid Id { get; set; } + /// + /// The member's unique identifier across Bitwarden. + /// + /// 48b47ee1-493e-4c67-aef7-014996c40eca + [Required] + public Guid? UserId { get; set; } + /// + /// The member's name, set from their user account profile. + /// + /// John Smith + public string Name { get; set; } + /// + /// The member's email address. + /// + /// jsmith@example.com + [Required] + public string Email { get; set; } + /// + /// Returns true if the member has a two-step login method enabled on their user account. + /// + [Required] + public bool TwoFactorEnabled { get; set; } + /// + /// The member's status within the organization. All created members start with a status of "Invited". + /// Once a member accept's their invitation to join the organization, their status changes to "Accepted". + /// Accepted members are then "Confirmed" by an organization administrator. Once a member is "Confirmed", + /// their status can no longer change. + /// + [Required] + public OrganizationUserStatusType Status { get; set; } + /// + /// The associated collections that this member can access. + /// + public IEnumerable Collections { get; set; } } - - /// - /// String representing the object's type. Objects of the same type share the same properties. - /// - /// member - [Required] - public string Object => "member"; - /// - /// The member's unique identifier within the organization. - /// - /// 539a36c5-e0d2-4cf9-979e-51ecf5cf6593 - [Required] - public Guid Id { get; set; } - /// - /// The member's unique identifier across Bitwarden. - /// - /// 48b47ee1-493e-4c67-aef7-014996c40eca - [Required] - public Guid? UserId { get; set; } - /// - /// The member's name, set from their user account profile. - /// - /// John Smith - public string Name { get; set; } - /// - /// The member's email address. - /// - /// jsmith@example.com - [Required] - public string Email { get; set; } - /// - /// Returns true if the member has a two-step login method enabled on their user account. - /// - [Required] - public bool TwoFactorEnabled { get; set; } - /// - /// The member's status within the organization. All created members start with a status of "Invited". - /// Once a member accept's their invitation to join the organization, their status changes to "Accepted". - /// Accepted members are then "Confirmed" by an organization administrator. Once a member is "Confirmed", - /// their status can no longer change. - /// - [Required] - public OrganizationUserStatusType Status { get; set; } - /// - /// The associated collections that this member can access. - /// - public IEnumerable Collections { get; set; } } diff --git a/src/Api/Models/Public/Response/PolicyResponseModel.cs b/src/Api/Models/Public/Response/PolicyResponseModel.cs index b30c283229..9806c96d03 100644 --- a/src/Api/Models/Public/Response/PolicyResponseModel.cs +++ b/src/Api/Models/Public/Response/PolicyResponseModel.cs @@ -3,44 +3,45 @@ using System.Text.Json; using Bit.Core.Entities; using Bit.Core.Enums; -namespace Bit.Api.Models.Public.Response; - -/// -/// A policy. -/// -public class PolicyResponseModel : PolicyBaseModel, IResponseModel +namespace Bit.Api.Models.Public.Response { - public PolicyResponseModel(Policy policy) + /// + /// A policy. + /// + public class PolicyResponseModel : PolicyBaseModel, IResponseModel { - if (policy == null) + public PolicyResponseModel(Policy policy) { - throw new ArgumentNullException(nameof(policy)); + if (policy == null) + { + throw new ArgumentNullException(nameof(policy)); + } + + Id = policy.Id; + Type = policy.Type; + Enabled = policy.Enabled; + if (!string.IsNullOrWhiteSpace(policy.Data)) + { + Data = JsonSerializer.Deserialize>(policy.Data); + } } - Id = policy.Id; - Type = policy.Type; - Enabled = policy.Enabled; - if (!string.IsNullOrWhiteSpace(policy.Data)) - { - Data = JsonSerializer.Deserialize>(policy.Data); - } + /// + /// String representing the object's type. Objects of the same type share the same properties. + /// + /// policy + [Required] + public string Object => "policy"; + /// + /// The policy's unique identifier. + /// + /// 539a36c5-e0d2-4cf9-979e-51ecf5cf6593 + [Required] + public Guid Id { get; set; } + /// + /// The type of policy. + /// + [Required] + public PolicyType? Type { get; set; } } - - /// - /// String representing the object's type. Objects of the same type share the same properties. - /// - /// policy - [Required] - public string Object => "policy"; - /// - /// The policy's unique identifier. - /// - /// 539a36c5-e0d2-4cf9-979e-51ecf5cf6593 - [Required] - public Guid Id { get; set; } - /// - /// The type of policy. - /// - [Required] - public PolicyType? Type { get; set; } } diff --git a/src/Api/Models/Request/Accounts/DeleteRecoverRequestModel.cs b/src/Api/Models/Request/Accounts/DeleteRecoverRequestModel.cs index 541df9a810..635d878a5c 100644 --- a/src/Api/Models/Request/Accounts/DeleteRecoverRequestModel.cs +++ b/src/Api/Models/Request/Accounts/DeleteRecoverRequestModel.cs @@ -1,11 +1,12 @@ using System.ComponentModel.DataAnnotations; -namespace Bit.Api.Models.Request.Accounts; - -public class DeleteRecoverRequestModel +namespace Bit.Api.Models.Request.Accounts { - [Required] - [EmailAddress] - [StringLength(256)] - public string Email { get; set; } + public class DeleteRecoverRequestModel + { + [Required] + [EmailAddress] + [StringLength(256)] + public string Email { get; set; } + } } diff --git a/src/Api/Models/Request/Accounts/EmailRequestModel.cs b/src/Api/Models/Request/Accounts/EmailRequestModel.cs index 54e8bfbcc0..7eabe3e2e7 100644 --- a/src/Api/Models/Request/Accounts/EmailRequestModel.cs +++ b/src/Api/Models/Request/Accounts/EmailRequestModel.cs @@ -1,19 +1,20 @@ using System.ComponentModel.DataAnnotations; using Bit.Core.Utilities; -namespace Bit.Api.Models.Request.Accounts; - -public class EmailRequestModel : SecretVerificationRequestModel +namespace Bit.Api.Models.Request.Accounts { - [Required] - [StrictEmailAddress] - [StringLength(256)] - public string NewEmail { get; set; } - [Required] - [StringLength(300)] - public string NewMasterPasswordHash { get; set; } - [Required] - public string Token { get; set; } - [Required] - public string Key { get; set; } + public class EmailRequestModel : SecretVerificationRequestModel + { + [Required] + [StrictEmailAddress] + [StringLength(256)] + public string NewEmail { get; set; } + [Required] + [StringLength(300)] + public string NewMasterPasswordHash { get; set; } + [Required] + public string Token { get; set; } + [Required] + public string Key { get; set; } + } } diff --git a/src/Api/Models/Request/Accounts/EmailTokenRequestModel.cs b/src/Api/Models/Request/Accounts/EmailTokenRequestModel.cs index c4c4f7814a..298b5918de 100644 --- a/src/Api/Models/Request/Accounts/EmailTokenRequestModel.cs +++ b/src/Api/Models/Request/Accounts/EmailTokenRequestModel.cs @@ -1,12 +1,13 @@ using System.ComponentModel.DataAnnotations; using Bit.Core.Utilities; -namespace Bit.Api.Models.Request.Accounts; - -public class EmailTokenRequestModel : SecretVerificationRequestModel +namespace Bit.Api.Models.Request.Accounts { - [Required] - [StrictEmailAddress] - [StringLength(256)] - public string NewEmail { get; set; } + public class EmailTokenRequestModel : SecretVerificationRequestModel + { + [Required] + [StrictEmailAddress] + [StringLength(256)] + public string NewEmail { get; set; } + } } diff --git a/src/Api/Models/Request/Accounts/ImportCiphersRequestModel.cs b/src/Api/Models/Request/Accounts/ImportCiphersRequestModel.cs index 2a675fa48a..321fef6586 100644 --- a/src/Api/Models/Request/Accounts/ImportCiphersRequestModel.cs +++ b/src/Api/Models/Request/Accounts/ImportCiphersRequestModel.cs @@ -1,8 +1,9 @@ -namespace Bit.Api.Models.Request.Accounts; - -public class ImportCiphersRequestModel +namespace Bit.Api.Models.Request.Accounts { - public FolderRequestModel[] Folders { get; set; } - public CipherRequestModel[] Ciphers { get; set; } - public KeyValuePair[] FolderRelationships { get; set; } + public class ImportCiphersRequestModel + { + public FolderRequestModel[] Folders { get; set; } + public CipherRequestModel[] Ciphers { get; set; } + public KeyValuePair[] FolderRelationships { get; set; } + } } diff --git a/src/Api/Models/Request/Accounts/KdfRequestModel.cs b/src/Api/Models/Request/Accounts/KdfRequestModel.cs index ac920c7dbf..eea6ad201b 100644 --- a/src/Api/Models/Request/Accounts/KdfRequestModel.cs +++ b/src/Api/Models/Request/Accounts/KdfRequestModel.cs @@ -1,29 +1,30 @@ using System.ComponentModel.DataAnnotations; using Bit.Core.Enums; -namespace Bit.Api.Models.Request.Accounts; - -public class KdfRequestModel : PasswordRequestModel, IValidatableObject +namespace Bit.Api.Models.Request.Accounts { - [Required] - public KdfType? Kdf { get; set; } - [Required] - public int? KdfIterations { get; set; } - - public IEnumerable Validate(ValidationContext validationContext) + public class KdfRequestModel : PasswordRequestModel, IValidatableObject { - if (Kdf.HasValue && KdfIterations.HasValue) + [Required] + public KdfType? Kdf { get; set; } + [Required] + public int? KdfIterations { get; set; } + + public IEnumerable Validate(ValidationContext validationContext) { - switch (Kdf.Value) + if (Kdf.HasValue && KdfIterations.HasValue) { - case KdfType.PBKDF2_SHA256: - if (KdfIterations.Value < 5000 || KdfIterations.Value > 2_000_000) - { - yield return new ValidationResult("KDF iterations must be between 5000 and 2000000."); - } - break; - default: - break; + switch (Kdf.Value) + { + case KdfType.PBKDF2_SHA256: + if (KdfIterations.Value < 5000 || KdfIterations.Value > 2_000_000) + { + yield return new ValidationResult("KDF iterations must be between 5000 and 2000000."); + } + break; + default: + break; + } } } } diff --git a/src/Api/Models/Request/Accounts/OrganizationApiKeyRequestModel.cs b/src/Api/Models/Request/Accounts/OrganizationApiKeyRequestModel.cs index c7e8408187..331cd7045d 100644 --- a/src/Api/Models/Request/Accounts/OrganizationApiKeyRequestModel.cs +++ b/src/Api/Models/Request/Accounts/OrganizationApiKeyRequestModel.cs @@ -1,8 +1,9 @@ using Bit.Core.Enums; -namespace Bit.Api.Models.Request.Accounts; - -public class OrganizationApiKeyRequestModel : SecretVerificationRequestModel +namespace Bit.Api.Models.Request.Accounts { - public OrganizationApiKeyType Type { get; set; } + public class OrganizationApiKeyRequestModel : SecretVerificationRequestModel + { + public OrganizationApiKeyType Type { get; set; } + } } diff --git a/src/Api/Models/Request/Accounts/PasswordHintRequestModel.cs b/src/Api/Models/Request/Accounts/PasswordHintRequestModel.cs index 340a89be26..148ced2b2e 100644 --- a/src/Api/Models/Request/Accounts/PasswordHintRequestModel.cs +++ b/src/Api/Models/Request/Accounts/PasswordHintRequestModel.cs @@ -1,11 +1,12 @@ using System.ComponentModel.DataAnnotations; -namespace Bit.Api.Models.Request.Accounts; - -public class PasswordHintRequestModel +namespace Bit.Api.Models.Request.Accounts { - [Required] - [EmailAddress] - [StringLength(256)] - public string Email { get; set; } + public class PasswordHintRequestModel + { + [Required] + [EmailAddress] + [StringLength(256)] + public string Email { get; set; } + } } diff --git a/src/Api/Models/Request/Accounts/PasswordRequestModel.cs b/src/Api/Models/Request/Accounts/PasswordRequestModel.cs index d7c22da4b5..0df96f5272 100644 --- a/src/Api/Models/Request/Accounts/PasswordRequestModel.cs +++ b/src/Api/Models/Request/Accounts/PasswordRequestModel.cs @@ -1,14 +1,15 @@ using System.ComponentModel.DataAnnotations; -namespace Bit.Api.Models.Request.Accounts; - -public class PasswordRequestModel : SecretVerificationRequestModel +namespace Bit.Api.Models.Request.Accounts { - [Required] - [StringLength(300)] - public string NewMasterPasswordHash { get; set; } - [StringLength(50)] - public string MasterPasswordHint { get; set; } - [Required] - public string Key { get; set; } + public class PasswordRequestModel : SecretVerificationRequestModel + { + [Required] + [StringLength(300)] + public string NewMasterPasswordHash { get; set; } + [StringLength(50)] + public string MasterPasswordHint { get; set; } + [Required] + public string Key { get; set; } + } } diff --git a/src/Api/Models/Request/Accounts/PremiumRequestModel.cs b/src/Api/Models/Request/Accounts/PremiumRequestModel.cs index 26d199381f..3fd95b1ce4 100644 --- a/src/Api/Models/Request/Accounts/PremiumRequestModel.cs +++ b/src/Api/Models/Request/Accounts/PremiumRequestModel.cs @@ -2,40 +2,41 @@ using Bit.Core.Settings; using Enums = Bit.Core.Enums; -namespace Bit.Api.Models.Request.Accounts; - -public class PremiumRequestModel : IValidatableObject +namespace Bit.Api.Models.Request.Accounts { - [Required] - public Enums.PaymentMethodType? PaymentMethodType { get; set; } - public string PaymentToken { get; set; } - [Range(0, 99)] - public short? AdditionalStorageGb { get; set; } - public IFormFile License { get; set; } - public string Country { get; set; } - public string PostalCode { get; set; } - - public bool Validate(GlobalSettings globalSettings) + public class PremiumRequestModel : IValidatableObject { - if (!(License == null && !globalSettings.SelfHosted) || - (License != null && globalSettings.SelfHosted)) - { - return false; - } - return globalSettings.SelfHosted || !string.IsNullOrWhiteSpace(Country); - } + [Required] + public Enums.PaymentMethodType? PaymentMethodType { get; set; } + public string PaymentToken { get; set; } + [Range(0, 99)] + public short? AdditionalStorageGb { get; set; } + public IFormFile License { get; set; } + public string Country { get; set; } + public string PostalCode { get; set; } - public IEnumerable Validate(ValidationContext validationContext) - { - var creditType = PaymentMethodType.HasValue && PaymentMethodType.Value == Enums.PaymentMethodType.Credit; - if (string.IsNullOrWhiteSpace(PaymentToken) && !creditType && License == null) + public bool Validate(GlobalSettings globalSettings) { - yield return new ValidationResult("Payment token or license is required."); + if (!(License == null && !globalSettings.SelfHosted) || + (License != null && globalSettings.SelfHosted)) + { + return false; + } + return globalSettings.SelfHosted || !string.IsNullOrWhiteSpace(Country); } - if (Country == "US" && string.IsNullOrWhiteSpace(PostalCode)) + + public IEnumerable Validate(ValidationContext validationContext) { - yield return new ValidationResult("Zip / postal code is required.", - new string[] { nameof(PostalCode) }); + var creditType = PaymentMethodType.HasValue && PaymentMethodType.Value == Enums.PaymentMethodType.Credit; + if (string.IsNullOrWhiteSpace(PaymentToken) && !creditType && License == null) + { + yield return new ValidationResult("Payment token or license is required."); + } + if (Country == "US" && string.IsNullOrWhiteSpace(PostalCode)) + { + yield return new ValidationResult("Zip / postal code is required.", + new string[] { nameof(PostalCode) }); + } } } } diff --git a/src/Api/Models/Request/Accounts/RegenerateTwoFactorRequestModel.cs b/src/Api/Models/Request/Accounts/RegenerateTwoFactorRequestModel.cs index 329b3a0c34..06a6148d44 100644 --- a/src/Api/Models/Request/Accounts/RegenerateTwoFactorRequestModel.cs +++ b/src/Api/Models/Request/Accounts/RegenerateTwoFactorRequestModel.cs @@ -1,12 +1,13 @@ using System.ComponentModel.DataAnnotations; -namespace Bit.Api.Models.Request.Accounts; - -public class RegenerateTwoFactorRequestModel +namespace Bit.Api.Models.Request.Accounts { - [Required] - public string MasterPasswordHash { get; set; } - [Required] - [StringLength(50)] - public string Token { get; set; } + public class RegenerateTwoFactorRequestModel + { + [Required] + public string MasterPasswordHash { get; set; } + [Required] + [StringLength(50)] + public string Token { get; set; } + } } diff --git a/src/Api/Models/Request/Accounts/SecretVerificationRequestModel.cs b/src/Api/Models/Request/Accounts/SecretVerificationRequestModel.cs index f35ea96777..e1042d5a39 100644 --- a/src/Api/Models/Request/Accounts/SecretVerificationRequestModel.cs +++ b/src/Api/Models/Request/Accounts/SecretVerificationRequestModel.cs @@ -1,19 +1,20 @@ using System.ComponentModel.DataAnnotations; -namespace Bit.Api.Models.Request.Accounts; - -public class SecretVerificationRequestModel : IValidatableObject +namespace Bit.Api.Models.Request.Accounts { - [StringLength(300)] - public string MasterPasswordHash { get; set; } - public string OTP { get; set; } - public string Secret => !string.IsNullOrEmpty(MasterPasswordHash) ? MasterPasswordHash : OTP; - - public virtual IEnumerable Validate(ValidationContext validationContext) + public class SecretVerificationRequestModel : IValidatableObject { - if (string.IsNullOrEmpty(Secret)) + [StringLength(300)] + public string MasterPasswordHash { get; set; } + public string OTP { get; set; } + public string Secret => !string.IsNullOrEmpty(MasterPasswordHash) ? MasterPasswordHash : OTP; + + public virtual IEnumerable Validate(ValidationContext validationContext) { - yield return new ValidationResult("MasterPasswordHash or OTP must be supplied."); + if (string.IsNullOrEmpty(Secret)) + { + yield return new ValidationResult("MasterPasswordHash or OTP must be supplied."); + } } } } diff --git a/src/Api/Models/Request/Accounts/SetKeyConnectorKeyRequestModel.cs b/src/Api/Models/Request/Accounts/SetKeyConnectorKeyRequestModel.cs index a4906b1b5d..39c17bc4b7 100644 --- a/src/Api/Models/Request/Accounts/SetKeyConnectorKeyRequestModel.cs +++ b/src/Api/Models/Request/Accounts/SetKeyConnectorKeyRequestModel.cs @@ -3,27 +3,28 @@ using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Models.Api.Request.Accounts; -namespace Bit.Api.Models.Request.Accounts; - -public class SetKeyConnectorKeyRequestModel +namespace Bit.Api.Models.Request.Accounts { - [Required] - public string Key { get; set; } - [Required] - public KeysRequestModel Keys { get; set; } - [Required] - public KdfType Kdf { get; set; } - [Required] - public int KdfIterations { get; set; } - [Required] - public string OrgIdentifier { get; set; } - - public User ToUser(User existingUser) + public class SetKeyConnectorKeyRequestModel { - existingUser.Kdf = Kdf; - existingUser.KdfIterations = KdfIterations; - existingUser.Key = Key; - Keys.ToUser(existingUser); - return existingUser; + [Required] + public string Key { get; set; } + [Required] + public KeysRequestModel Keys { get; set; } + [Required] + public KdfType Kdf { get; set; } + [Required] + public int KdfIterations { get; set; } + [Required] + public string OrgIdentifier { get; set; } + + public User ToUser(User existingUser) + { + existingUser.Kdf = Kdf; + existingUser.KdfIterations = KdfIterations; + existingUser.Key = Key; + Keys.ToUser(existingUser); + return existingUser; + } } } diff --git a/src/Api/Models/Request/Accounts/SetPasswordRequestModel.cs b/src/Api/Models/Request/Accounts/SetPasswordRequestModel.cs index 8a345001c2..287ba87696 100644 --- a/src/Api/Models/Request/Accounts/SetPasswordRequestModel.cs +++ b/src/Api/Models/Request/Accounts/SetPasswordRequestModel.cs @@ -3,32 +3,33 @@ using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Models.Api.Request.Accounts; -namespace Bit.Api.Models.Request.Accounts; - -public class SetPasswordRequestModel +namespace Bit.Api.Models.Request.Accounts { - [Required] - [StringLength(300)] - public string MasterPasswordHash { get; set; } - [Required] - public string Key { get; set; } - [StringLength(50)] - public string MasterPasswordHint { get; set; } - [Required] - public KeysRequestModel Keys { get; set; } - [Required] - public KdfType Kdf { get; set; } - [Required] - public int KdfIterations { get; set; } - public string OrgIdentifier { get; set; } - - public User ToUser(User existingUser) + public class SetPasswordRequestModel { - existingUser.MasterPasswordHint = MasterPasswordHint; - existingUser.Kdf = Kdf; - existingUser.KdfIterations = KdfIterations; - existingUser.Key = Key; - Keys.ToUser(existingUser); - return existingUser; + [Required] + [StringLength(300)] + public string MasterPasswordHash { get; set; } + [Required] + public string Key { get; set; } + [StringLength(50)] + public string MasterPasswordHint { get; set; } + [Required] + public KeysRequestModel Keys { get; set; } + [Required] + public KdfType Kdf { get; set; } + [Required] + public int KdfIterations { get; set; } + public string OrgIdentifier { get; set; } + + public User ToUser(User existingUser) + { + existingUser.MasterPasswordHint = MasterPasswordHint; + existingUser.Kdf = Kdf; + existingUser.KdfIterations = KdfIterations; + existingUser.Key = Key; + Keys.ToUser(existingUser); + return existingUser; + } } } diff --git a/src/Api/Models/Request/Accounts/StorageRequestModel.cs b/src/Api/Models/Request/Accounts/StorageRequestModel.cs index 397da74116..beb7c189ba 100644 --- a/src/Api/Models/Request/Accounts/StorageRequestModel.cs +++ b/src/Api/Models/Request/Accounts/StorageRequestModel.cs @@ -1,18 +1,19 @@ using System.ComponentModel.DataAnnotations; -namespace Bit.Api.Models.Request.Accounts; - -public class StorageRequestModel : IValidatableObject +namespace Bit.Api.Models.Request.Accounts { - [Required] - public short? StorageGbAdjustment { get; set; } - - public IEnumerable Validate(ValidationContext validationContext) + public class StorageRequestModel : IValidatableObject { - if (StorageGbAdjustment == 0) + [Required] + public short? StorageGbAdjustment { get; set; } + + public IEnumerable Validate(ValidationContext validationContext) { - yield return new ValidationResult("Storage adjustment cannot be 0.", - new string[] { nameof(StorageGbAdjustment) }); + if (StorageGbAdjustment == 0) + { + yield return new ValidationResult("Storage adjustment cannot be 0.", + new string[] { nameof(StorageGbAdjustment) }); + } } } } diff --git a/src/Api/Models/Request/Accounts/TaxInfoUpdateRequestModel.cs b/src/Api/Models/Request/Accounts/TaxInfoUpdateRequestModel.cs index f51580408a..205356e68f 100644 --- a/src/Api/Models/Request/Accounts/TaxInfoUpdateRequestModel.cs +++ b/src/Api/Models/Request/Accounts/TaxInfoUpdateRequestModel.cs @@ -1,19 +1,20 @@ using System.ComponentModel.DataAnnotations; -namespace Bit.Api.Models.Request.Accounts; - -public class TaxInfoUpdateRequestModel : IValidatableObject +namespace Bit.Api.Models.Request.Accounts { - [Required] - public string Country { get; set; } - public string PostalCode { get; set; } - - public virtual IEnumerable Validate(ValidationContext validationContext) + public class TaxInfoUpdateRequestModel : IValidatableObject { - if (Country == "US" && string.IsNullOrWhiteSpace(PostalCode)) + [Required] + public string Country { get; set; } + public string PostalCode { get; set; } + + public virtual IEnumerable Validate(ValidationContext validationContext) { - yield return new ValidationResult("Zip / postal code is required.", - new string[] { nameof(PostalCode) }); + if (Country == "US" && string.IsNullOrWhiteSpace(PostalCode)) + { + yield return new ValidationResult("Zip / postal code is required.", + new string[] { nameof(PostalCode) }); + } } } } diff --git a/src/Api/Models/Request/Accounts/UpdateKeyRequestModel.cs b/src/Api/Models/Request/Accounts/UpdateKeyRequestModel.cs index 2064c09b9a..31ac2d8301 100644 --- a/src/Api/Models/Request/Accounts/UpdateKeyRequestModel.cs +++ b/src/Api/Models/Request/Accounts/UpdateKeyRequestModel.cs @@ -1,19 +1,20 @@ using System.ComponentModel.DataAnnotations; -namespace Bit.Api.Models.Request.Accounts; - -public class UpdateKeyRequestModel +namespace Bit.Api.Models.Request.Accounts { - [Required] - [StringLength(300)] - public string MasterPasswordHash { get; set; } - [Required] - public IEnumerable Ciphers { get; set; } - [Required] - public IEnumerable Folders { get; set; } - public IEnumerable Sends { get; set; } - [Required] - public string PrivateKey { get; set; } - [Required] - public string Key { get; set; } + public class UpdateKeyRequestModel + { + [Required] + [StringLength(300)] + public string MasterPasswordHash { get; set; } + [Required] + public IEnumerable Ciphers { get; set; } + [Required] + public IEnumerable Folders { get; set; } + public IEnumerable Sends { get; set; } + [Required] + public string PrivateKey { get; set; } + [Required] + public string Key { get; set; } + } } diff --git a/src/Api/Models/Request/Accounts/UpdateProfileRequestModel.cs b/src/Api/Models/Request/Accounts/UpdateProfileRequestModel.cs index fd625fe9de..9f8506dfc9 100644 --- a/src/Api/Models/Request/Accounts/UpdateProfileRequestModel.cs +++ b/src/Api/Models/Request/Accounts/UpdateProfileRequestModel.cs @@ -1,20 +1,21 @@ using System.ComponentModel.DataAnnotations; using Bit.Core.Entities; -namespace Bit.Api.Models.Request.Accounts; - -public class UpdateProfileRequestModel +namespace Bit.Api.Models.Request.Accounts { - [StringLength(50)] - public string Name { get; set; } - [StringLength(50)] - [Obsolete("Changes will be made via the 'password' endpoint going forward.")] - public string MasterPasswordHint { get; set; } - - public User ToUser(User existingUser) + public class UpdateProfileRequestModel { - existingUser.Name = Name; - existingUser.MasterPasswordHint = string.IsNullOrWhiteSpace(MasterPasswordHint) ? null : MasterPasswordHint; - return existingUser; + [StringLength(50)] + public string Name { get; set; } + [StringLength(50)] + [Obsolete("Changes will be made via the 'password' endpoint going forward.")] + public string MasterPasswordHint { get; set; } + + public User ToUser(User existingUser) + { + existingUser.Name = Name; + existingUser.MasterPasswordHint = string.IsNullOrWhiteSpace(MasterPasswordHint) ? null : MasterPasswordHint; + return existingUser; + } } } diff --git a/src/Api/Models/Request/Accounts/UpdateTempPasswordRequestModel.cs b/src/Api/Models/Request/Accounts/UpdateTempPasswordRequestModel.cs index 94bfabeeed..db1c0dbd77 100644 --- a/src/Api/Models/Request/Accounts/UpdateTempPasswordRequestModel.cs +++ b/src/Api/Models/Request/Accounts/UpdateTempPasswordRequestModel.cs @@ -1,10 +1,11 @@ using System.ComponentModel.DataAnnotations; using Bit.Api.Models.Request.Organizations; -namespace Bit.Api.Models.Request.Accounts; - -public class UpdateTempPasswordRequestModel : OrganizationUserResetPasswordRequestModel +namespace Bit.Api.Models.Request.Accounts { - [StringLength(50)] - public string MasterPasswordHint { get; set; } + public class UpdateTempPasswordRequestModel : OrganizationUserResetPasswordRequestModel + { + [StringLength(50)] + public string MasterPasswordHint { get; set; } + } } diff --git a/src/Api/Models/Request/Accounts/VerifyDeleteRecoverRequestModel.cs b/src/Api/Models/Request/Accounts/VerifyDeleteRecoverRequestModel.cs index 1faaade2b2..463750722d 100644 --- a/src/Api/Models/Request/Accounts/VerifyDeleteRecoverRequestModel.cs +++ b/src/Api/Models/Request/Accounts/VerifyDeleteRecoverRequestModel.cs @@ -1,11 +1,12 @@ using System.ComponentModel.DataAnnotations; -namespace Bit.Api.Models.Request.Accounts; - -public class VerifyDeleteRecoverRequestModel +namespace Bit.Api.Models.Request.Accounts { - [Required] - public string UserId { get; set; } - [Required] - public string Token { get; set; } + public class VerifyDeleteRecoverRequestModel + { + [Required] + public string UserId { get; set; } + [Required] + public string Token { get; set; } + } } diff --git a/src/Api/Models/Request/Accounts/VerifyEmailRequestModel.cs b/src/Api/Models/Request/Accounts/VerifyEmailRequestModel.cs index 2e8820e1d0..d859966817 100644 --- a/src/Api/Models/Request/Accounts/VerifyEmailRequestModel.cs +++ b/src/Api/Models/Request/Accounts/VerifyEmailRequestModel.cs @@ -1,11 +1,12 @@ using System.ComponentModel.DataAnnotations; -namespace Bit.Api.Models.Request.Accounts; - -public class VerifyEmailRequestModel +namespace Bit.Api.Models.Request.Accounts { - [Required] - public string UserId { get; set; } - [Required] - public string Token { get; set; } + public class VerifyEmailRequestModel + { + [Required] + public string UserId { get; set; } + [Required] + public string Token { get; set; } + } } diff --git a/src/Api/Models/Request/Accounts/VerifyOTPRequestModel.cs b/src/Api/Models/Request/Accounts/VerifyOTPRequestModel.cs index 63e37cdf16..6466aee7ea 100644 --- a/src/Api/Models/Request/Accounts/VerifyOTPRequestModel.cs +++ b/src/Api/Models/Request/Accounts/VerifyOTPRequestModel.cs @@ -1,9 +1,10 @@ using System.ComponentModel.DataAnnotations; -namespace Bit.Api.Models.Request.Accounts; - -public class VerifyOTPRequestModel +namespace Bit.Api.Models.Request.Accounts { - [Required] - public string OTP { get; set; } + public class VerifyOTPRequestModel + { + [Required] + public string OTP { get; set; } + } } diff --git a/src/Api/Models/Request/AttachmentRequestModel.cs b/src/Api/Models/Request/AttachmentRequestModel.cs index cadeccdc04..b5ca4fb61d 100644 --- a/src/Api/Models/Request/AttachmentRequestModel.cs +++ b/src/Api/Models/Request/AttachmentRequestModel.cs @@ -1,9 +1,10 @@ -namespace Bit.Api.Models.Request; - -public class AttachmentRequestModel +namespace Bit.Api.Models.Request { - public string Key { get; set; } - public string FileName { get; set; } - public long FileSize { get; set; } - public bool AdminRequest { get; set; } = false; + public class AttachmentRequestModel + { + public string Key { get; set; } + public string FileName { get; set; } + public long FileSize { get; set; } + public bool AdminRequest { get; set; } = false; + } } diff --git a/src/Api/Models/Request/BitPayInvoiceRequestModel.cs b/src/Api/Models/Request/BitPayInvoiceRequestModel.cs index ce1d986380..9e87cca002 100644 --- a/src/Api/Models/Request/BitPayInvoiceRequestModel.cs +++ b/src/Api/Models/Request/BitPayInvoiceRequestModel.cs @@ -1,65 +1,66 @@ using System.ComponentModel.DataAnnotations; using Bit.Core.Settings; -namespace Bit.Api.Models.Request; - -public class BitPayInvoiceRequestModel : IValidatableObject +namespace Bit.Api.Models.Request { - public Guid? UserId { get; set; } - public Guid? OrganizationId { get; set; } - public bool Credit { get; set; } - [Required] - public decimal? Amount { get; set; } - public string ReturnUrl { get; set; } - public string Name { get; set; } - public string Email { get; set; } - - public BitPayLight.Models.Invoice.Invoice ToBitpayInvoice(GlobalSettings globalSettings) + public class BitPayInvoiceRequestModel : IValidatableObject { - var inv = new BitPayLight.Models.Invoice.Invoice + public Guid? UserId { get; set; } + public Guid? OrganizationId { get; set; } + public bool Credit { get; set; } + [Required] + public decimal? Amount { get; set; } + public string ReturnUrl { get; set; } + public string Name { get; set; } + public string Email { get; set; } + + public BitPayLight.Models.Invoice.Invoice ToBitpayInvoice(GlobalSettings globalSettings) { - Price = Convert.ToDouble(Amount.Value), - Currency = "USD", - RedirectUrl = ReturnUrl, - Buyer = new BitPayLight.Models.Invoice.Buyer + var inv = new BitPayLight.Models.Invoice.Invoice { - Email = Email, - Name = Name - }, - NotificationUrl = globalSettings.BitPay.NotificationUrl, - FullNotifications = true, - ExtendedNotifications = true - }; + Price = Convert.ToDouble(Amount.Value), + Currency = "USD", + RedirectUrl = ReturnUrl, + Buyer = new BitPayLight.Models.Invoice.Buyer + { + Email = Email, + Name = Name + }, + NotificationUrl = globalSettings.BitPay.NotificationUrl, + FullNotifications = true, + ExtendedNotifications = true + }; - var posData = string.Empty; - if (UserId.HasValue) - { - posData = "userId:" + UserId.Value; - } - else if (OrganizationId.HasValue) - { - posData = "organizationId:" + OrganizationId.Value; + var posData = string.Empty; + if (UserId.HasValue) + { + posData = "userId:" + UserId.Value; + } + else if (OrganizationId.HasValue) + { + posData = "organizationId:" + OrganizationId.Value; + } + + if (Credit) + { + posData += ",accountCredit:1"; + inv.ItemDesc = "Bitwarden Account Credit"; + } + else + { + inv.ItemDesc = "Bitwarden"; + } + + inv.PosData = posData; + return inv; } - if (Credit) + public IEnumerable Validate(ValidationContext validationContext) { - posData += ",accountCredit:1"; - inv.ItemDesc = "Bitwarden Account Credit"; - } - else - { - inv.ItemDesc = "Bitwarden"; - } - - inv.PosData = posData; - return inv; - } - - public IEnumerable Validate(ValidationContext validationContext) - { - if (!UserId.HasValue && !OrganizationId.HasValue) - { - yield return new ValidationResult("User or Ooganization is required."); + if (!UserId.HasValue && !OrganizationId.HasValue) + { + yield return new ValidationResult("User or Ooganization is required."); + } } } } diff --git a/src/Api/Models/Request/CipherPartialRequestModel.cs b/src/Api/Models/Request/CipherPartialRequestModel.cs index bc58eb4273..996aec5fc4 100644 --- a/src/Api/Models/Request/CipherPartialRequestModel.cs +++ b/src/Api/Models/Request/CipherPartialRequestModel.cs @@ -1,10 +1,11 @@ using System.ComponentModel.DataAnnotations; -namespace Bit.Api.Models.Request; - -public class CipherPartialRequestModel +namespace Bit.Api.Models.Request { - [StringLength(36)] - public string FolderId { get; set; } - public bool Favorite { get; set; } + public class CipherPartialRequestModel + { + [StringLength(36)] + public string FolderId { get; set; } + public bool Favorite { get; set; } + } } diff --git a/src/Api/Models/Request/CipherRequestModel.cs b/src/Api/Models/Request/CipherRequestModel.cs index f5f3eee422..90435132ae 100644 --- a/src/Api/Models/Request/CipherRequestModel.cs +++ b/src/Api/Models/Request/CipherRequestModel.cs @@ -8,341 +8,342 @@ using Core.Models.Data; using NS = Newtonsoft.Json; using NSL = Newtonsoft.Json.Linq; -namespace Bit.Api.Models.Request; - -public class CipherRequestModel +namespace Bit.Api.Models.Request { - public CipherType Type { get; set; } - - [StringLength(36)] - public string OrganizationId { get; set; } - public string FolderId { get; set; } - public bool Favorite { get; set; } - public CipherRepromptType Reprompt { get; set; } - [Required] - [EncryptedString] - [EncryptedStringLength(1000)] - public string Name { get; set; } - [EncryptedString] - [EncryptedStringLength(10000)] - public string Notes { get; set; } - public IEnumerable Fields { get; set; } - public IEnumerable PasswordHistory { get; set; } - [Obsolete] - public Dictionary Attachments { get; set; } - // TODO: Rename to Attachments whenever the above is finally removed. - public Dictionary Attachments2 { get; set; } - - public CipherLoginModel Login { get; set; } - public CipherCardModel Card { get; set; } - public CipherIdentityModel Identity { get; set; } - public CipherSecureNoteModel SecureNote { get; set; } - public DateTime? LastKnownRevisionDate { get; set; } = null; - - public CipherDetails ToCipherDetails(Guid userId, bool allowOrgIdSet = true) + public class CipherRequestModel { - var hasOrgId = !string.IsNullOrWhiteSpace(OrganizationId); - var cipher = new CipherDetails + public CipherType Type { get; set; } + + [StringLength(36)] + public string OrganizationId { get; set; } + public string FolderId { get; set; } + public bool Favorite { get; set; } + public CipherRepromptType Reprompt { get; set; } + [Required] + [EncryptedString] + [EncryptedStringLength(1000)] + public string Name { get; set; } + [EncryptedString] + [EncryptedStringLength(10000)] + public string Notes { get; set; } + public IEnumerable Fields { get; set; } + public IEnumerable PasswordHistory { get; set; } + [Obsolete] + public Dictionary Attachments { get; set; } + // TODO: Rename to Attachments whenever the above is finally removed. + public Dictionary Attachments2 { get; set; } + + public CipherLoginModel Login { get; set; } + public CipherCardModel Card { get; set; } + public CipherIdentityModel Identity { get; set; } + public CipherSecureNoteModel SecureNote { get; set; } + public DateTime? LastKnownRevisionDate { get; set; } = null; + + public CipherDetails ToCipherDetails(Guid userId, bool allowOrgIdSet = true) { - Type = Type, - UserId = !hasOrgId ? (Guid?)userId : null, - OrganizationId = allowOrgIdSet && hasOrgId ? new Guid(OrganizationId) : (Guid?)null, - Edit = true, - ViewPassword = true, - }; - ToCipherDetails(cipher); - return cipher; - } - - public CipherDetails ToCipherDetails(CipherDetails existingCipher) - { - existingCipher.FolderId = string.IsNullOrWhiteSpace(FolderId) ? null : (Guid?)new Guid(FolderId); - existingCipher.Favorite = Favorite; - ToCipher(existingCipher); - return existingCipher; - } - - public Cipher ToCipher(Cipher existingCipher) - { - switch (existingCipher.Type) - { - case CipherType.Login: - var loginObj = NSL.JObject.FromObject(ToCipherLoginData(), - new NS.JsonSerializer { NullValueHandling = NS.NullValueHandling.Ignore }); - // TODO: Switch to JsonNode in .NET 6 https://docs.microsoft.com/en-us/dotnet/standard/serialization/system-text-json-use-dom-utf8jsonreader-utf8jsonwriter?pivots=dotnet-6-0 - loginObj[nameof(CipherLoginData.Uri)]?.Parent?.Remove(); - existingCipher.Data = loginObj.ToString(NS.Formatting.None); - break; - case CipherType.Card: - existingCipher.Data = JsonSerializer.Serialize(ToCipherCardData(), JsonHelpers.IgnoreWritingNull); - break; - case CipherType.Identity: - existingCipher.Data = JsonSerializer.Serialize(ToCipherIdentityData(), JsonHelpers.IgnoreWritingNull); - break; - case CipherType.SecureNote: - existingCipher.Data = JsonSerializer.Serialize(ToCipherSecureNoteData(), JsonHelpers.IgnoreWritingNull); - break; - default: - throw new ArgumentException("Unsupported type: " + nameof(Type) + "."); + var hasOrgId = !string.IsNullOrWhiteSpace(OrganizationId); + var cipher = new CipherDetails + { + Type = Type, + UserId = !hasOrgId ? (Guid?)userId : null, + OrganizationId = allowOrgIdSet && hasOrgId ? new Guid(OrganizationId) : (Guid?)null, + Edit = true, + ViewPassword = true, + }; + ToCipherDetails(cipher); + return cipher; } - existingCipher.Reprompt = Reprompt; - - var hasAttachments2 = (Attachments2?.Count ?? 0) > 0; - var hasAttachments = (Attachments?.Count ?? 0) > 0; - - if (!hasAttachments2 && !hasAttachments) + public CipherDetails ToCipherDetails(CipherDetails existingCipher) { + existingCipher.FolderId = string.IsNullOrWhiteSpace(FolderId) ? null : (Guid?)new Guid(FolderId); + existingCipher.Favorite = Favorite; + ToCipher(existingCipher); return existingCipher; } - var attachments = existingCipher.GetAttachments(); - if ((attachments?.Count ?? 0) == 0) + public Cipher ToCipher(Cipher existingCipher) { - return existingCipher; - } - - if (hasAttachments2) - { - foreach (var attachment in attachments.Where(a => Attachments2.ContainsKey(a.Key))) + switch (existingCipher.Type) { - var attachment2 = Attachments2[attachment.Key]; - attachment.Value.FileName = attachment2.FileName; - attachment.Value.Key = attachment2.Key; + case CipherType.Login: + var loginObj = NSL.JObject.FromObject(ToCipherLoginData(), + new NS.JsonSerializer { NullValueHandling = NS.NullValueHandling.Ignore }); + // TODO: Switch to JsonNode in .NET 6 https://docs.microsoft.com/en-us/dotnet/standard/serialization/system-text-json-use-dom-utf8jsonreader-utf8jsonwriter?pivots=dotnet-6-0 + loginObj[nameof(CipherLoginData.Uri)]?.Parent?.Remove(); + existingCipher.Data = loginObj.ToString(NS.Formatting.None); + break; + case CipherType.Card: + existingCipher.Data = JsonSerializer.Serialize(ToCipherCardData(), JsonHelpers.IgnoreWritingNull); + break; + case CipherType.Identity: + existingCipher.Data = JsonSerializer.Serialize(ToCipherIdentityData(), JsonHelpers.IgnoreWritingNull); + break; + case CipherType.SecureNote: + existingCipher.Data = JsonSerializer.Serialize(ToCipherSecureNoteData(), JsonHelpers.IgnoreWritingNull); + break; + default: + throw new ArgumentException("Unsupported type: " + nameof(Type) + "."); } - } - else if (hasAttachments) - { - foreach (var attachment in attachments.Where(a => Attachments.ContainsKey(a.Key))) + + existingCipher.Reprompt = Reprompt; + + var hasAttachments2 = (Attachments2?.Count ?? 0) > 0; + var hasAttachments = (Attachments?.Count ?? 0) > 0; + + if (!hasAttachments2 && !hasAttachments) { - attachment.Value.FileName = Attachments[attachment.Key]; - attachment.Value.Key = null; + return existingCipher; } - } - existingCipher.SetAttachments(attachments); - return existingCipher; - } - - public Cipher ToOrganizationCipher() - { - if (string.IsNullOrWhiteSpace(OrganizationId)) - { - throw new ArgumentNullException(nameof(OrganizationId)); - } - - return ToCipher(new Cipher - { - Type = Type, - OrganizationId = new Guid(OrganizationId) - }); - } - - public CipherDetails ToOrganizationCipherDetails(Guid orgId) - { - return ToCipherDetails(new CipherDetails - { - Type = Type, - OrganizationId = orgId, - Edit = true - }); - } - - private CipherLoginData ToCipherLoginData() - { - return new CipherLoginData - { - Name = Name, - Notes = Notes, - Fields = Fields?.Select(f => f.ToCipherFieldData()), - PasswordHistory = PasswordHistory?.Select(ph => ph.ToCipherPasswordHistoryData()), - - Uris = - Login.Uris?.Where(u => u != null) - .Select(u => u.ToCipherLoginUriData()), - Username = Login.Username, - Password = Login.Password, - PasswordRevisionDate = Login.PasswordRevisionDate, - Totp = Login.Totp, - AutofillOnPageLoad = Login.AutofillOnPageLoad, - }; - } - - private CipherIdentityData ToCipherIdentityData() - { - return new CipherIdentityData - { - Name = Name, - Notes = Notes, - Fields = Fields?.Select(f => f.ToCipherFieldData()), - PasswordHistory = PasswordHistory?.Select(ph => ph.ToCipherPasswordHistoryData()), - - Title = Identity.Title, - FirstName = Identity.FirstName, - MiddleName = Identity.MiddleName, - LastName = Identity.LastName, - Address1 = Identity.Address1, - Address2 = Identity.Address2, - Address3 = Identity.Address3, - City = Identity.City, - State = Identity.State, - PostalCode = Identity.PostalCode, - Country = Identity.Country, - Company = Identity.Company, - Email = Identity.Email, - Phone = Identity.Phone, - SSN = Identity.SSN, - Username = Identity.Username, - PassportNumber = Identity.PassportNumber, - LicenseNumber = Identity.LicenseNumber, - }; - } - - private CipherCardData ToCipherCardData() - { - return new CipherCardData - { - Name = Name, - Notes = Notes, - Fields = Fields?.Select(f => f.ToCipherFieldData()), - PasswordHistory = PasswordHistory?.Select(ph => ph.ToCipherPasswordHistoryData()), - - CardholderName = Card.CardholderName, - Brand = Card.Brand, - Number = Card.Number, - ExpMonth = Card.ExpMonth, - ExpYear = Card.ExpYear, - Code = Card.Code, - }; - } - - private CipherSecureNoteData ToCipherSecureNoteData() - { - return new CipherSecureNoteData - { - Name = Name, - Notes = Notes, - Fields = Fields?.Select(f => f.ToCipherFieldData()), - PasswordHistory = PasswordHistory?.Select(ph => ph.ToCipherPasswordHistoryData()), - - Type = SecureNote.Type, - }; - } -} - -public class CipherWithIdRequestModel : CipherRequestModel -{ - [Required] - public Guid? Id { get; set; } -} - -public class CipherCreateRequestModel : IValidatableObject -{ - public IEnumerable CollectionIds { get; set; } - [Required] - public CipherRequestModel Cipher { get; set; } - - public IEnumerable Validate(ValidationContext validationContext) - { - if (!string.IsNullOrWhiteSpace(Cipher.OrganizationId) && (!CollectionIds?.Any() ?? true)) - { - yield return new ValidationResult("You must select at least one collection.", - new string[] { nameof(CollectionIds) }); - } - } -} - -public class CipherShareRequestModel : IValidatableObject -{ - [Required] - public IEnumerable CollectionIds { get; set; } - [Required] - public CipherRequestModel Cipher { get; set; } - - public IEnumerable Validate(ValidationContext validationContext) - { - if (string.IsNullOrWhiteSpace(Cipher.OrganizationId)) - { - yield return new ValidationResult("Cipher OrganizationId is required.", - new string[] { nameof(Cipher.OrganizationId) }); - } - - if (!CollectionIds?.Any() ?? true) - { - yield return new ValidationResult("You must select at least one collection.", - new string[] { nameof(CollectionIds) }); - } - } -} - -public class CipherCollectionsRequestModel -{ - [Required] - public IEnumerable CollectionIds { get; set; } -} - -public class CipherBulkDeleteRequestModel -{ - [Required] - public IEnumerable Ids { get; set; } - public string OrganizationId { get; set; } -} - -public class CipherBulkRestoreRequestModel -{ - [Required] - public IEnumerable Ids { get; set; } -} - -public class CipherBulkMoveRequestModel -{ - [Required] - public IEnumerable Ids { get; set; } - public string FolderId { get; set; } -} - -public class CipherBulkShareRequestModel : IValidatableObject -{ - [Required] - public IEnumerable CollectionIds { get; set; } - [Required] - public IEnumerable Ciphers { get; set; } - - public IEnumerable Validate(ValidationContext validationContext) - { - if (!Ciphers?.Any() ?? true) - { - yield return new ValidationResult("You must select at least one cipher.", - new string[] { nameof(Ciphers) }); - } - else - { - var allHaveIds = true; - var organizationIds = new HashSet(); - foreach (var c in Ciphers) + var attachments = existingCipher.GetAttachments(); + if ((attachments?.Count ?? 0) == 0) { - organizationIds.Add(c.OrganizationId); - if (allHaveIds) + return existingCipher; + } + + if (hasAttachments2) + { + foreach (var attachment in attachments.Where(a => Attachments2.ContainsKey(a.Key))) { - allHaveIds = !(!c.Id.HasValue || string.IsNullOrWhiteSpace(c.OrganizationId)); + var attachment2 = Attachments2[attachment.Key]; + attachment.Value.FileName = attachment2.FileName; + attachment.Value.Key = attachment2.Key; + } + } + else if (hasAttachments) + { + foreach (var attachment in attachments.Where(a => Attachments.ContainsKey(a.Key))) + { + attachment.Value.FileName = Attachments[attachment.Key]; + attachment.Value.Key = null; } } - if (!allHaveIds) - { - yield return new ValidationResult("All Ciphers must have an Id and OrganizationId.", - new string[] { nameof(Ciphers) }); - } - else if (organizationIds.Count != 1) - { - yield return new ValidationResult("All ciphers must be for the same organization."); - } + existingCipher.SetAttachments(attachments); + return existingCipher; } - if (!CollectionIds?.Any() ?? true) + public Cipher ToOrganizationCipher() { - yield return new ValidationResult("You must select at least one collection.", - new string[] { nameof(CollectionIds) }); + if (string.IsNullOrWhiteSpace(OrganizationId)) + { + throw new ArgumentNullException(nameof(OrganizationId)); + } + + return ToCipher(new Cipher + { + Type = Type, + OrganizationId = new Guid(OrganizationId) + }); + } + + public CipherDetails ToOrganizationCipherDetails(Guid orgId) + { + return ToCipherDetails(new CipherDetails + { + Type = Type, + OrganizationId = orgId, + Edit = true + }); + } + + private CipherLoginData ToCipherLoginData() + { + return new CipherLoginData + { + Name = Name, + Notes = Notes, + Fields = Fields?.Select(f => f.ToCipherFieldData()), + PasswordHistory = PasswordHistory?.Select(ph => ph.ToCipherPasswordHistoryData()), + + Uris = + Login.Uris?.Where(u => u != null) + .Select(u => u.ToCipherLoginUriData()), + Username = Login.Username, + Password = Login.Password, + PasswordRevisionDate = Login.PasswordRevisionDate, + Totp = Login.Totp, + AutofillOnPageLoad = Login.AutofillOnPageLoad, + }; + } + + private CipherIdentityData ToCipherIdentityData() + { + return new CipherIdentityData + { + Name = Name, + Notes = Notes, + Fields = Fields?.Select(f => f.ToCipherFieldData()), + PasswordHistory = PasswordHistory?.Select(ph => ph.ToCipherPasswordHistoryData()), + + Title = Identity.Title, + FirstName = Identity.FirstName, + MiddleName = Identity.MiddleName, + LastName = Identity.LastName, + Address1 = Identity.Address1, + Address2 = Identity.Address2, + Address3 = Identity.Address3, + City = Identity.City, + State = Identity.State, + PostalCode = Identity.PostalCode, + Country = Identity.Country, + Company = Identity.Company, + Email = Identity.Email, + Phone = Identity.Phone, + SSN = Identity.SSN, + Username = Identity.Username, + PassportNumber = Identity.PassportNumber, + LicenseNumber = Identity.LicenseNumber, + }; + } + + private CipherCardData ToCipherCardData() + { + return new CipherCardData + { + Name = Name, + Notes = Notes, + Fields = Fields?.Select(f => f.ToCipherFieldData()), + PasswordHistory = PasswordHistory?.Select(ph => ph.ToCipherPasswordHistoryData()), + + CardholderName = Card.CardholderName, + Brand = Card.Brand, + Number = Card.Number, + ExpMonth = Card.ExpMonth, + ExpYear = Card.ExpYear, + Code = Card.Code, + }; + } + + private CipherSecureNoteData ToCipherSecureNoteData() + { + return new CipherSecureNoteData + { + Name = Name, + Notes = Notes, + Fields = Fields?.Select(f => f.ToCipherFieldData()), + PasswordHistory = PasswordHistory?.Select(ph => ph.ToCipherPasswordHistoryData()), + + Type = SecureNote.Type, + }; + } + } + + public class CipherWithIdRequestModel : CipherRequestModel + { + [Required] + public Guid? Id { get; set; } + } + + public class CipherCreateRequestModel : IValidatableObject + { + public IEnumerable CollectionIds { get; set; } + [Required] + public CipherRequestModel Cipher { get; set; } + + public IEnumerable Validate(ValidationContext validationContext) + { + if (!string.IsNullOrWhiteSpace(Cipher.OrganizationId) && (!CollectionIds?.Any() ?? true)) + { + yield return new ValidationResult("You must select at least one collection.", + new string[] { nameof(CollectionIds) }); + } + } + } + + public class CipherShareRequestModel : IValidatableObject + { + [Required] + public IEnumerable CollectionIds { get; set; } + [Required] + public CipherRequestModel Cipher { get; set; } + + public IEnumerable Validate(ValidationContext validationContext) + { + if (string.IsNullOrWhiteSpace(Cipher.OrganizationId)) + { + yield return new ValidationResult("Cipher OrganizationId is required.", + new string[] { nameof(Cipher.OrganizationId) }); + } + + if (!CollectionIds?.Any() ?? true) + { + yield return new ValidationResult("You must select at least one collection.", + new string[] { nameof(CollectionIds) }); + } + } + } + + public class CipherCollectionsRequestModel + { + [Required] + public IEnumerable CollectionIds { get; set; } + } + + public class CipherBulkDeleteRequestModel + { + [Required] + public IEnumerable Ids { get; set; } + public string OrganizationId { get; set; } + } + + public class CipherBulkRestoreRequestModel + { + [Required] + public IEnumerable Ids { get; set; } + } + + public class CipherBulkMoveRequestModel + { + [Required] + public IEnumerable Ids { get; set; } + public string FolderId { get; set; } + } + + public class CipherBulkShareRequestModel : IValidatableObject + { + [Required] + public IEnumerable CollectionIds { get; set; } + [Required] + public IEnumerable Ciphers { get; set; } + + public IEnumerable Validate(ValidationContext validationContext) + { + if (!Ciphers?.Any() ?? true) + { + yield return new ValidationResult("You must select at least one cipher.", + new string[] { nameof(Ciphers) }); + } + else + { + var allHaveIds = true; + var organizationIds = new HashSet(); + foreach (var c in Ciphers) + { + organizationIds.Add(c.OrganizationId); + if (allHaveIds) + { + allHaveIds = !(!c.Id.HasValue || string.IsNullOrWhiteSpace(c.OrganizationId)); + } + } + + if (!allHaveIds) + { + yield return new ValidationResult("All Ciphers must have an Id and OrganizationId.", + new string[] { nameof(Ciphers) }); + } + else if (organizationIds.Count != 1) + { + yield return new ValidationResult("All ciphers must be for the same organization."); + } + } + + if (!CollectionIds?.Any() ?? true) + { + yield return new ValidationResult("You must select at least one collection.", + new string[] { nameof(CollectionIds) }); + } } } } diff --git a/src/Api/Models/Request/CollectionRequestModel.cs b/src/Api/Models/Request/CollectionRequestModel.cs index fb0be314d3..e09510347c 100644 --- a/src/Api/Models/Request/CollectionRequestModel.cs +++ b/src/Api/Models/Request/CollectionRequestModel.cs @@ -2,30 +2,31 @@ using Bit.Core.Entities; using Bit.Core.Utilities; -namespace Bit.Api.Models.Request; - -public class CollectionRequestModel +namespace Bit.Api.Models.Request { - [Required] - [EncryptedString] - [EncryptedStringLength(1000)] - public string Name { get; set; } - [StringLength(300)] - public string ExternalId { get; set; } - public IEnumerable Groups { get; set; } - - public Collection ToCollection(Guid orgId) + public class CollectionRequestModel { - return ToCollection(new Collection + [Required] + [EncryptedString] + [EncryptedStringLength(1000)] + public string Name { get; set; } + [StringLength(300)] + public string ExternalId { get; set; } + public IEnumerable Groups { get; set; } + + public Collection ToCollection(Guid orgId) { - OrganizationId = orgId - }); - } + return ToCollection(new Collection + { + OrganizationId = orgId + }); + } - public Collection ToCollection(Collection existingCollection) - { - existingCollection.Name = Name; - existingCollection.ExternalId = ExternalId; - return existingCollection; + public Collection ToCollection(Collection existingCollection) + { + existingCollection.Name = Name; + existingCollection.ExternalId = ExternalId; + return existingCollection; + } } } diff --git a/src/Api/Models/Request/DeviceRequestModels.cs b/src/Api/Models/Request/DeviceRequestModels.cs index 8d88c7f9c3..b47693e3be 100644 --- a/src/Api/Models/Request/DeviceRequestModels.cs +++ b/src/Api/Models/Request/DeviceRequestModels.cs @@ -2,48 +2,49 @@ using Bit.Core.Entities; using Bit.Core.Enums; -namespace Bit.Api.Models.Request; - -public class DeviceRequestModel +namespace Bit.Api.Models.Request { - [Required] - public DeviceType? Type { get; set; } - [Required] - [StringLength(50)] - public string Name { get; set; } - [Required] - [StringLength(50)] - public string Identifier { get; set; } - [StringLength(255)] - public string PushToken { get; set; } - - public Device ToDevice(Guid? userId = null) + public class DeviceRequestModel { - return ToDevice(new Device + [Required] + public DeviceType? Type { get; set; } + [Required] + [StringLength(50)] + public string Name { get; set; } + [Required] + [StringLength(50)] + public string Identifier { get; set; } + [StringLength(255)] + public string PushToken { get; set; } + + public Device ToDevice(Guid? userId = null) { - UserId = userId == null ? default(Guid) : userId.Value - }); + return ToDevice(new Device + { + UserId = userId == null ? default(Guid) : userId.Value + }); + } + + public Device ToDevice(Device existingDevice) + { + existingDevice.Name = Name; + existingDevice.Identifier = Identifier; + existingDevice.PushToken = PushToken; + existingDevice.Type = Type.Value; + + return existingDevice; + } } - public Device ToDevice(Device existingDevice) + public class DeviceTokenRequestModel { - existingDevice.Name = Name; - existingDevice.Identifier = Identifier; - existingDevice.PushToken = PushToken; - existingDevice.Type = Type.Value; + [StringLength(255)] + public string PushToken { get; set; } - return existingDevice; - } -} - -public class DeviceTokenRequestModel -{ - [StringLength(255)] - public string PushToken { get; set; } - - public Device ToDevice(Device existingDevice) - { - existingDevice.PushToken = PushToken; - return existingDevice; + public Device ToDevice(Device existingDevice) + { + existingDevice.PushToken = PushToken; + return existingDevice; + } } } diff --git a/src/Api/Models/Request/DeviceVerificationRequestModel.cs b/src/Api/Models/Request/DeviceVerificationRequestModel.cs index d81471916b..e8c22d9fe7 100644 --- a/src/Api/Models/Request/DeviceVerificationRequestModel.cs +++ b/src/Api/Models/Request/DeviceVerificationRequestModel.cs @@ -1,16 +1,17 @@ using System.ComponentModel.DataAnnotations; using Bit.Core.Entities; -namespace Bit.Api.Models.Request; - -public class DeviceVerificationRequestModel +namespace Bit.Api.Models.Request { - [Required] - public bool UnknownDeviceVerificationEnabled { get; set; } - - public User ToUser(User user) + public class DeviceVerificationRequestModel { - user.UnknownDeviceVerificationEnabled = UnknownDeviceVerificationEnabled; - return user; + [Required] + public bool UnknownDeviceVerificationEnabled { get; set; } + + public User ToUser(User user) + { + user.UnknownDeviceVerificationEnabled = UnknownDeviceVerificationEnabled; + return user; + } } } diff --git a/src/Api/Models/Request/EmergencyAccessRequstModels.cs b/src/Api/Models/Request/EmergencyAccessRequstModels.cs index 040316c50b..a8e9f07a0f 100644 --- a/src/Api/Models/Request/EmergencyAccessRequstModels.cs +++ b/src/Api/Models/Request/EmergencyAccessRequstModels.cs @@ -3,46 +3,47 @@ using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Utilities; -namespace Bit.Api.Models.Request; - -public class EmergencyAccessInviteRequestModel +namespace Bit.Api.Models.Request { - [Required] - [StrictEmailAddress] - [StringLength(256)] - public string Email { get; set; } - [Required] - public EmergencyAccessType? Type { get; set; } - [Required] - public int WaitTimeDays { get; set; } -} - -public class EmergencyAccessUpdateRequestModel -{ - [Required] - public EmergencyAccessType Type { get; set; } - [Required] - public int WaitTimeDays { get; set; } - public string KeyEncrypted { get; set; } - - public EmergencyAccess ToEmergencyAccess(EmergencyAccess existingEmergencyAccess) + public class EmergencyAccessInviteRequestModel { - // Ensure we only set keys for a confirmed emergency access. - if (!string.IsNullOrWhiteSpace(existingEmergencyAccess.KeyEncrypted) && !string.IsNullOrWhiteSpace(KeyEncrypted)) + [Required] + [StrictEmailAddress] + [StringLength(256)] + public string Email { get; set; } + [Required] + public EmergencyAccessType? Type { get; set; } + [Required] + public int WaitTimeDays { get; set; } + } + + public class EmergencyAccessUpdateRequestModel + { + [Required] + public EmergencyAccessType Type { get; set; } + [Required] + public int WaitTimeDays { get; set; } + public string KeyEncrypted { get; set; } + + public EmergencyAccess ToEmergencyAccess(EmergencyAccess existingEmergencyAccess) { - existingEmergencyAccess.KeyEncrypted = KeyEncrypted; + // Ensure we only set keys for a confirmed emergency access. + if (!string.IsNullOrWhiteSpace(existingEmergencyAccess.KeyEncrypted) && !string.IsNullOrWhiteSpace(KeyEncrypted)) + { + existingEmergencyAccess.KeyEncrypted = KeyEncrypted; + } + existingEmergencyAccess.Type = Type; + existingEmergencyAccess.WaitTimeDays = WaitTimeDays; + return existingEmergencyAccess; } - existingEmergencyAccess.Type = Type; - existingEmergencyAccess.WaitTimeDays = WaitTimeDays; - return existingEmergencyAccess; + } + + public class EmergencyAccessPasswordRequestModel + { + [Required] + [StringLength(300)] + public string NewMasterPasswordHash { get; set; } + [Required] + public string Key { get; set; } } } - -public class EmergencyAccessPasswordRequestModel -{ - [Required] - [StringLength(300)] - public string NewMasterPasswordHash { get; set; } - [Required] - public string Key { get; set; } -} diff --git a/src/Api/Models/Request/FolderRequestModel.cs b/src/Api/Models/Request/FolderRequestModel.cs index 092b993bb5..52b0fcdb3a 100644 --- a/src/Api/Models/Request/FolderRequestModel.cs +++ b/src/Api/Models/Request/FolderRequestModel.cs @@ -2,31 +2,32 @@ using Bit.Core.Entities; using Bit.Core.Utilities; -namespace Bit.Api.Models.Request; - -public class FolderRequestModel +namespace Bit.Api.Models.Request { - [Required] - [EncryptedString] - [EncryptedStringLength(1000)] - public string Name { get; set; } - - public Folder ToFolder(Guid userId) + public class FolderRequestModel { - return ToFolder(new Folder + [Required] + [EncryptedString] + [EncryptedStringLength(1000)] + public string Name { get; set; } + + public Folder ToFolder(Guid userId) { - UserId = userId - }); + return ToFolder(new Folder + { + UserId = userId + }); + } + + public Folder ToFolder(Folder existingFolder) + { + existingFolder.Name = Name; + return existingFolder; + } } - public Folder ToFolder(Folder existingFolder) + public class FolderWithIdRequestModel : FolderRequestModel { - existingFolder.Name = Name; - return existingFolder; + public Guid Id { get; set; } } } - -public class FolderWithIdRequestModel : FolderRequestModel -{ - public Guid Id { get; set; } -} diff --git a/src/Api/Models/Request/GroupRequestModel.cs b/src/Api/Models/Request/GroupRequestModel.cs index 71e76590be..23b817624c 100644 --- a/src/Api/Models/Request/GroupRequestModel.cs +++ b/src/Api/Models/Request/GroupRequestModel.cs @@ -1,32 +1,33 @@ using System.ComponentModel.DataAnnotations; using Bit.Core.Entities; -namespace Bit.Api.Models.Request; - -public class GroupRequestModel +namespace Bit.Api.Models.Request { - [Required] - [StringLength(100)] - public string Name { get; set; } - [Required] - public bool? AccessAll { get; set; } - [StringLength(300)] - public string ExternalId { get; set; } - public IEnumerable Collections { get; set; } - - public Group ToGroup(Guid orgId) + public class GroupRequestModel { - return ToGroup(new Group + [Required] + [StringLength(100)] + public string Name { get; set; } + [Required] + public bool? AccessAll { get; set; } + [StringLength(300)] + public string ExternalId { get; set; } + public IEnumerable Collections { get; set; } + + public Group ToGroup(Guid orgId) { - OrganizationId = orgId - }); - } + return ToGroup(new Group + { + OrganizationId = orgId + }); + } - public Group ToGroup(Group existingGroup) - { - existingGroup.Name = Name; - existingGroup.AccessAll = AccessAll.Value; - existingGroup.ExternalId = ExternalId; - return existingGroup; + public Group ToGroup(Group existingGroup) + { + existingGroup.Name = Name; + existingGroup.AccessAll = AccessAll.Value; + existingGroup.ExternalId = ExternalId; + return existingGroup; + } } } diff --git a/src/Api/Models/Request/IapCheckRequestModel.cs b/src/Api/Models/Request/IapCheckRequestModel.cs index ededb37ee4..d7ca6ba3bc 100644 --- a/src/Api/Models/Request/IapCheckRequestModel.cs +++ b/src/Api/Models/Request/IapCheckRequestModel.cs @@ -1,19 +1,20 @@ using System.ComponentModel.DataAnnotations; using Enums = Bit.Core.Enums; -namespace Bit.Api.Models.Request; - -public class IapCheckRequestModel : IValidatableObject +namespace Bit.Api.Models.Request { - [Required] - public Enums.PaymentMethodType? PaymentMethodType { get; set; } - - public IEnumerable Validate(ValidationContext validationContext) + public class IapCheckRequestModel : IValidatableObject { - if (PaymentMethodType != Enums.PaymentMethodType.AppleInApp) + [Required] + public Enums.PaymentMethodType? PaymentMethodType { get; set; } + + public IEnumerable Validate(ValidationContext validationContext) { - yield return new ValidationResult("Not a supported in-app purchase payment method.", - new string[] { nameof(PaymentMethodType) }); + if (PaymentMethodType != Enums.PaymentMethodType.AppleInApp) + { + yield return new ValidationResult("Not a supported in-app purchase payment method.", + new string[] { nameof(PaymentMethodType) }); + } } } } diff --git a/src/Api/Models/Request/InstallationRequestModel.cs b/src/Api/Models/Request/InstallationRequestModel.cs index 65b542e62e..9f594f7bb4 100644 --- a/src/Api/Models/Request/InstallationRequestModel.cs +++ b/src/Api/Models/Request/InstallationRequestModel.cs @@ -2,22 +2,23 @@ using Bit.Core.Entities; using Bit.Core.Utilities; -namespace Bit.Api.Models.Request; - -public class InstallationRequestModel +namespace Bit.Api.Models.Request { - [Required] - [EmailAddress] - [StringLength(256)] - public string Email { get; set; } - - public Installation ToInstallation() + public class InstallationRequestModel { - return new Installation + [Required] + [EmailAddress] + [StringLength(256)] + public string Email { get; set; } + + public Installation ToInstallation() { - Key = CoreHelpers.SecureRandomString(20), - Email = Email, - Enabled = true - }; + return new Installation + { + Key = CoreHelpers.SecureRandomString(20), + Email = Email, + Enabled = true + }; + } } } diff --git a/src/Api/Models/Request/LicenseRequestModel.cs b/src/Api/Models/Request/LicenseRequestModel.cs index 7b66d95f0e..382f686152 100644 --- a/src/Api/Models/Request/LicenseRequestModel.cs +++ b/src/Api/Models/Request/LicenseRequestModel.cs @@ -1,9 +1,10 @@ using System.ComponentModel.DataAnnotations; -namespace Bit.Api.Models.Request; - -public class LicenseRequestModel +namespace Bit.Api.Models.Request { - [Required] - public IFormFile License { get; set; } + public class LicenseRequestModel + { + [Required] + public IFormFile License { get; set; } + } } diff --git a/src/Api/Models/Request/Organizations/ImportOrganizationCiphersRequestModel.cs b/src/Api/Models/Request/Organizations/ImportOrganizationCiphersRequestModel.cs index 3aa6ef68c2..b70f395879 100644 --- a/src/Api/Models/Request/Organizations/ImportOrganizationCiphersRequestModel.cs +++ b/src/Api/Models/Request/Organizations/ImportOrganizationCiphersRequestModel.cs @@ -1,8 +1,9 @@ -namespace Bit.Api.Models.Request.Organizations; - -public class ImportOrganizationCiphersRequestModel +namespace Bit.Api.Models.Request.Organizations { - public CollectionRequestModel[] Collections { get; set; } - public CipherRequestModel[] Ciphers { get; set; } - public KeyValuePair[] CollectionRelationships { get; set; } + public class ImportOrganizationCiphersRequestModel + { + public CollectionRequestModel[] Collections { get; set; } + public CipherRequestModel[] Ciphers { get; set; } + public KeyValuePair[] CollectionRelationships { get; set; } + } } diff --git a/src/Api/Models/Request/Organizations/ImportOrganizationUsersRequestModel.cs b/src/Api/Models/Request/Organizations/ImportOrganizationUsersRequestModel.cs index 3f1e2b2441..d35e051e9e 100644 --- a/src/Api/Models/Request/Organizations/ImportOrganizationUsersRequestModel.cs +++ b/src/Api/Models/Request/Organizations/ImportOrganizationUsersRequestModel.cs @@ -1,68 +1,69 @@ using System.ComponentModel.DataAnnotations; using Bit.Core.Models.Business; -namespace Bit.Api.Models.Request.Organizations; - -public class ImportOrganizationUsersRequestModel +namespace Bit.Api.Models.Request.Organizations { - public Group[] Groups { get; set; } - public User[] Users { get; set; } - public bool OverwriteExisting { get; set; } - public bool LargeImport { get; set; } - - public class Group + public class ImportOrganizationUsersRequestModel { - [Required] - [StringLength(100)] - public string Name { get; set; } - [Required] - [StringLength(300)] - public string ExternalId { get; set; } - public IEnumerable Users { get; set; } + public Group[] Groups { get; set; } + public User[] Users { get; set; } + public bool OverwriteExisting { get; set; } + public bool LargeImport { get; set; } - public ImportedGroup ToImportedGroup(Guid organizationId) + public class Group { - var importedGroup = new ImportedGroup + [Required] + [StringLength(100)] + public string Name { get; set; } + [Required] + [StringLength(300)] + public string ExternalId { get; set; } + public IEnumerable Users { get; set; } + + public ImportedGroup ToImportedGroup(Guid organizationId) { - Group = new Core.Entities.Group + var importedGroup = new ImportedGroup { - OrganizationId = organizationId, - Name = Name, + Group = new Core.Entities.Group + { + OrganizationId = organizationId, + Name = Name, + ExternalId = ExternalId + }, + ExternalUserIds = new HashSet(Users) + }; + + return importedGroup; + } + } + + public class User : IValidatableObject + { + [EmailAddress] + [StringLength(256)] + public string Email { get; set; } + public bool Deleted { get; set; } + [Required] + [StringLength(300)] + public string ExternalId { get; set; } + + public ImportedOrganizationUser ToImportedOrganizationUser() + { + var importedUser = new ImportedOrganizationUser + { + Email = Email.ToLowerInvariant(), ExternalId = ExternalId - }, - ExternalUserIds = new HashSet(Users) - }; + }; - return importedGroup; - } - } + return importedUser; + } - public class User : IValidatableObject - { - [EmailAddress] - [StringLength(256)] - public string Email { get; set; } - public bool Deleted { get; set; } - [Required] - [StringLength(300)] - public string ExternalId { get; set; } - - public ImportedOrganizationUser ToImportedOrganizationUser() - { - var importedUser = new ImportedOrganizationUser + public IEnumerable Validate(ValidationContext validationContext) { - Email = Email.ToLowerInvariant(), - ExternalId = ExternalId - }; - - return importedUser; - } - - public IEnumerable Validate(ValidationContext validationContext) - { - if (string.IsNullOrWhiteSpace(Email) && !Deleted) - { - yield return new ValidationResult("Email is required for enabled users.", new string[] { nameof(Email) }); + if (string.IsNullOrWhiteSpace(Email) && !Deleted) + { + yield return new ValidationResult("Email is required for enabled users.", new string[] { nameof(Email) }); + } } } } diff --git a/src/Api/Models/Request/Organizations/OrganizationConnectionRequestModel.cs b/src/Api/Models/Request/Organizations/OrganizationConnectionRequestModel.cs index 9dbc9ca0a0..91132ec5e1 100644 --- a/src/Api/Models/Request/Organizations/OrganizationConnectionRequestModel.cs +++ b/src/Api/Models/Request/Organizations/OrganizationConnectionRequestModel.cs @@ -4,47 +4,48 @@ using Bit.Core.Exceptions; using Bit.Core.Models.Data.Organizations.OrganizationConnections; using Bit.Core.Utilities; -namespace Bit.Api.Models.Request.Organizations; - -public class OrganizationConnectionRequestModel +namespace Bit.Api.Models.Request.Organizations { - public OrganizationConnectionType Type { get; set; } - public Guid OrganizationId { get; set; } - public bool Enabled { get; set; } - public JsonDocument Config { get; set; } - - public OrganizationConnectionRequestModel() { } -} - - -public class OrganizationConnectionRequestModel : OrganizationConnectionRequestModel where T : new() -{ - public T ParsedConfig { get; private set; } - - public OrganizationConnectionRequestModel(OrganizationConnectionRequestModel model) + public class OrganizationConnectionRequestModel { - Type = model.Type; - OrganizationId = model.OrganizationId; - Enabled = model.Enabled; - Config = model.Config; + public OrganizationConnectionType Type { get; set; } + public Guid OrganizationId { get; set; } + public bool Enabled { get; set; } + public JsonDocument Config { get; set; } - try - { - ParsedConfig = model.Config.ToObject(JsonHelpers.IgnoreCase); - } - catch (JsonException) - { - throw new BadRequestException("Organization Connection configuration malformed"); - } + public OrganizationConnectionRequestModel() { } } - public OrganizationConnectionData ToData(Guid? id = null) => - new() + + public class OrganizationConnectionRequestModel : OrganizationConnectionRequestModel where T : new() + { + public T ParsedConfig { get; private set; } + + public OrganizationConnectionRequestModel(OrganizationConnectionRequestModel model) { - Id = id, - Type = Type, - OrganizationId = OrganizationId, - Enabled = Enabled, - Config = ParsedConfig, - }; + Type = model.Type; + OrganizationId = model.OrganizationId; + Enabled = model.Enabled; + Config = model.Config; + + try + { + ParsedConfig = model.Config.ToObject(JsonHelpers.IgnoreCase); + } + catch (JsonException) + { + throw new BadRequestException("Organization Connection configuration malformed"); + } + } + + public OrganizationConnectionData ToData(Guid? id = null) => + new() + { + Id = id, + Type = Type, + OrganizationId = OrganizationId, + Enabled = Enabled, + Config = ParsedConfig, + }; + } } diff --git a/src/Api/Models/Request/Organizations/OrganizationCreateLicenseRequestModel.cs b/src/Api/Models/Request/Organizations/OrganizationCreateLicenseRequestModel.cs index 2d9175158f..722d338b9e 100644 --- a/src/Api/Models/Request/Organizations/OrganizationCreateLicenseRequestModel.cs +++ b/src/Api/Models/Request/Organizations/OrganizationCreateLicenseRequestModel.cs @@ -1,14 +1,15 @@ using System.ComponentModel.DataAnnotations; using Bit.Core.Utilities; -namespace Bit.Api.Models.Request.Organizations; - -public class OrganizationCreateLicenseRequestModel : LicenseRequestModel +namespace Bit.Api.Models.Request.Organizations { - [Required] - public string Key { get; set; } - [EncryptedString] - [EncryptedStringLength(1000)] - public string CollectionName { get; set; } - public OrganizationKeysRequestModel Keys { get; set; } + public class OrganizationCreateLicenseRequestModel : LicenseRequestModel + { + [Required] + public string Key { get; set; } + [EncryptedString] + [EncryptedStringLength(1000)] + public string CollectionName { get; set; } + public OrganizationKeysRequestModel Keys { get; set; } + } } diff --git a/src/Api/Models/Request/Organizations/OrganizationCreateRequestModel.cs b/src/Api/Models/Request/Organizations/OrganizationCreateRequestModel.cs index 3e46021795..4f84ea5c6e 100644 --- a/src/Api/Models/Request/Organizations/OrganizationCreateRequestModel.cs +++ b/src/Api/Models/Request/Organizations/OrganizationCreateRequestModel.cs @@ -4,98 +4,99 @@ using Bit.Core.Enums; using Bit.Core.Models.Business; using Bit.Core.Utilities; -namespace Bit.Api.Models.Request.Organizations; - -public class OrganizationCreateRequestModel : IValidatableObject +namespace Bit.Api.Models.Request.Organizations { - [Required] - [StringLength(50)] - public string Name { get; set; } - [StringLength(50)] - public string BusinessName { get; set; } - [Required] - [StringLength(256)] - [EmailAddress] - public string BillingEmail { get; set; } - public PlanType PlanType { get; set; } - [Required] - public string Key { get; set; } - public OrganizationKeysRequestModel Keys { get; set; } - public PaymentMethodType? PaymentMethodType { get; set; } - public string PaymentToken { get; set; } - [Range(0, int.MaxValue)] - public int AdditionalSeats { get; set; } - [Range(0, 99)] - public short? AdditionalStorageGb { get; set; } - public bool PremiumAccessAddon { get; set; } - [EncryptedString] - [EncryptedStringLength(1000)] - public string CollectionName { get; set; } - public string TaxIdNumber { get; set; } - public string BillingAddressLine1 { get; set; } - public string BillingAddressLine2 { get; set; } - public string BillingAddressCity { get; set; } - public string BillingAddressState { get; set; } - public string BillingAddressPostalCode { get; set; } - [StringLength(2)] - public string BillingAddressCountry { get; set; } - public int? MaxAutoscaleSeats { get; set; } - - public virtual OrganizationSignup ToOrganizationSignup(User user) + public class OrganizationCreateRequestModel : IValidatableObject { - var orgSignup = new OrganizationSignup + [Required] + [StringLength(50)] + public string Name { get; set; } + [StringLength(50)] + public string BusinessName { get; set; } + [Required] + [StringLength(256)] + [EmailAddress] + public string BillingEmail { get; set; } + public PlanType PlanType { get; set; } + [Required] + public string Key { get; set; } + public OrganizationKeysRequestModel Keys { get; set; } + public PaymentMethodType? PaymentMethodType { get; set; } + public string PaymentToken { get; set; } + [Range(0, int.MaxValue)] + public int AdditionalSeats { get; set; } + [Range(0, 99)] + public short? AdditionalStorageGb { get; set; } + public bool PremiumAccessAddon { get; set; } + [EncryptedString] + [EncryptedStringLength(1000)] + public string CollectionName { get; set; } + public string TaxIdNumber { get; set; } + public string BillingAddressLine1 { get; set; } + public string BillingAddressLine2 { get; set; } + public string BillingAddressCity { get; set; } + public string BillingAddressState { get; set; } + public string BillingAddressPostalCode { get; set; } + [StringLength(2)] + public string BillingAddressCountry { get; set; } + public int? MaxAutoscaleSeats { get; set; } + + public virtual OrganizationSignup ToOrganizationSignup(User user) { - Owner = user, - OwnerKey = Key, - Name = Name, - Plan = PlanType, - PaymentMethodType = PaymentMethodType, - PaymentToken = PaymentToken, - AdditionalSeats = AdditionalSeats, - MaxAutoscaleSeats = MaxAutoscaleSeats, - AdditionalStorageGb = AdditionalStorageGb.GetValueOrDefault(0), - PremiumAccessAddon = PremiumAccessAddon, - BillingEmail = BillingEmail, - BusinessName = BusinessName, - CollectionName = CollectionName, - TaxInfo = new TaxInfo + var orgSignup = new OrganizationSignup { - TaxIdNumber = TaxIdNumber, - BillingAddressLine1 = BillingAddressLine1, - BillingAddressLine2 = BillingAddressLine2, - BillingAddressCity = BillingAddressCity, - BillingAddressState = BillingAddressState, - BillingAddressPostalCode = BillingAddressPostalCode, - BillingAddressCountry = BillingAddressCountry, - }, - }; + Owner = user, + OwnerKey = Key, + Name = Name, + Plan = PlanType, + PaymentMethodType = PaymentMethodType, + PaymentToken = PaymentToken, + AdditionalSeats = AdditionalSeats, + MaxAutoscaleSeats = MaxAutoscaleSeats, + AdditionalStorageGb = AdditionalStorageGb.GetValueOrDefault(0), + PremiumAccessAddon = PremiumAccessAddon, + BillingEmail = BillingEmail, + BusinessName = BusinessName, + CollectionName = CollectionName, + TaxInfo = new TaxInfo + { + TaxIdNumber = TaxIdNumber, + BillingAddressLine1 = BillingAddressLine1, + BillingAddressLine2 = BillingAddressLine2, + BillingAddressCity = BillingAddressCity, + BillingAddressState = BillingAddressState, + BillingAddressPostalCode = BillingAddressPostalCode, + BillingAddressCountry = BillingAddressCountry, + }, + }; - Keys?.ToOrganizationSignup(orgSignup); + Keys?.ToOrganizationSignup(orgSignup); - return orgSignup; - } + return orgSignup; + } - public IEnumerable Validate(ValidationContext validationContext) - { - if (PlanType != PlanType.Free && string.IsNullOrWhiteSpace(PaymentToken)) + public IEnumerable Validate(ValidationContext validationContext) { - yield return new ValidationResult("Payment required.", new string[] { nameof(PaymentToken) }); - } - if (PlanType != PlanType.Free && !PaymentMethodType.HasValue) - { - yield return new ValidationResult("Payment method type required.", - new string[] { nameof(PaymentMethodType) }); - } - if (PlanType != PlanType.Free && string.IsNullOrWhiteSpace(BillingAddressCountry)) - { - yield return new ValidationResult("Country required.", - new string[] { nameof(BillingAddressCountry) }); - } - if (PlanType != PlanType.Free && BillingAddressCountry == "US" && - string.IsNullOrWhiteSpace(BillingAddressPostalCode)) - { - yield return new ValidationResult("Zip / postal code is required.", - new string[] { nameof(BillingAddressPostalCode) }); + if (PlanType != PlanType.Free && string.IsNullOrWhiteSpace(PaymentToken)) + { + yield return new ValidationResult("Payment required.", new string[] { nameof(PaymentToken) }); + } + if (PlanType != PlanType.Free && !PaymentMethodType.HasValue) + { + yield return new ValidationResult("Payment method type required.", + new string[] { nameof(PaymentMethodType) }); + } + if (PlanType != PlanType.Free && string.IsNullOrWhiteSpace(BillingAddressCountry)) + { + yield return new ValidationResult("Country required.", + new string[] { nameof(BillingAddressCountry) }); + } + if (PlanType != PlanType.Free && BillingAddressCountry == "US" && + string.IsNullOrWhiteSpace(BillingAddressPostalCode)) + { + yield return new ValidationResult("Zip / postal code is required.", + new string[] { nameof(BillingAddressPostalCode) }); + } } } } diff --git a/src/Api/Models/Request/Organizations/OrganizationKeysRequestModel.cs b/src/Api/Models/Request/Organizations/OrganizationKeysRequestModel.cs index a22b4eaa6f..070b03d19a 100644 --- a/src/Api/Models/Request/Organizations/OrganizationKeysRequestModel.cs +++ b/src/Api/Models/Request/Organizations/OrganizationKeysRequestModel.cs @@ -2,57 +2,58 @@ using Bit.Core.Entities; using Bit.Core.Models.Business; -namespace Bit.Api.Models.Request.Organizations; - -public class OrganizationKeysRequestModel +namespace Bit.Api.Models.Request.Organizations { - [Required] - public string PublicKey { get; set; } - [Required] - public string EncryptedPrivateKey { get; set; } - - public OrganizationSignup ToOrganizationSignup(OrganizationSignup existingSignup) + public class OrganizationKeysRequestModel { - if (string.IsNullOrWhiteSpace(existingSignup.PublicKey)) + [Required] + public string PublicKey { get; set; } + [Required] + public string EncryptedPrivateKey { get; set; } + + public OrganizationSignup ToOrganizationSignup(OrganizationSignup existingSignup) { - existingSignup.PublicKey = PublicKey; + if (string.IsNullOrWhiteSpace(existingSignup.PublicKey)) + { + existingSignup.PublicKey = PublicKey; + } + + if (string.IsNullOrWhiteSpace(existingSignup.PrivateKey)) + { + existingSignup.PrivateKey = EncryptedPrivateKey; + } + + return existingSignup; } - if (string.IsNullOrWhiteSpace(existingSignup.PrivateKey)) + public OrganizationUpgrade ToOrganizationUpgrade(OrganizationUpgrade existingUpgrade) { - existingSignup.PrivateKey = EncryptedPrivateKey; + if (string.IsNullOrWhiteSpace(existingUpgrade.PublicKey)) + { + existingUpgrade.PublicKey = PublicKey; + } + + if (string.IsNullOrWhiteSpace(existingUpgrade.PrivateKey)) + { + existingUpgrade.PrivateKey = EncryptedPrivateKey; + } + + return existingUpgrade; } - return existingSignup; - } - - public OrganizationUpgrade ToOrganizationUpgrade(OrganizationUpgrade existingUpgrade) - { - if (string.IsNullOrWhiteSpace(existingUpgrade.PublicKey)) + public Organization ToOrganization(Organization existingOrg) { - existingUpgrade.PublicKey = PublicKey; + if (string.IsNullOrWhiteSpace(existingOrg.PublicKey)) + { + existingOrg.PublicKey = PublicKey; + } + + if (string.IsNullOrWhiteSpace(existingOrg.PrivateKey)) + { + existingOrg.PrivateKey = EncryptedPrivateKey; + } + + return existingOrg; } - - if (string.IsNullOrWhiteSpace(existingUpgrade.PrivateKey)) - { - existingUpgrade.PrivateKey = EncryptedPrivateKey; - } - - return existingUpgrade; - } - - public Organization ToOrganization(Organization existingOrg) - { - if (string.IsNullOrWhiteSpace(existingOrg.PublicKey)) - { - existingOrg.PublicKey = PublicKey; - } - - if (string.IsNullOrWhiteSpace(existingOrg.PrivateKey)) - { - existingOrg.PrivateKey = EncryptedPrivateKey; - } - - return existingOrg; } } diff --git a/src/Api/Models/Request/Organizations/OrganizationSeatRequestModel.cs b/src/Api/Models/Request/Organizations/OrganizationSeatRequestModel.cs index b3849f0a40..068d096242 100644 --- a/src/Api/Models/Request/Organizations/OrganizationSeatRequestModel.cs +++ b/src/Api/Models/Request/Organizations/OrganizationSeatRequestModel.cs @@ -1,17 +1,18 @@ using System.ComponentModel.DataAnnotations; -namespace Bit.Api.Models.Request.Organizations; - -public class OrganizationSeatRequestModel : IValidatableObject +namespace Bit.Api.Models.Request.Organizations { - [Required] - public int? SeatAdjustment { get; set; } - - public IEnumerable Validate(ValidationContext validationContext) + public class OrganizationSeatRequestModel : IValidatableObject { - if (SeatAdjustment == 0) + [Required] + public int? SeatAdjustment { get; set; } + + public IEnumerable Validate(ValidationContext validationContext) { - yield return new ValidationResult("Seat adjustment cannot be 0.", new string[] { nameof(SeatAdjustment) }); + if (SeatAdjustment == 0) + { + yield return new ValidationResult("Seat adjustment cannot be 0.", new string[] { nameof(SeatAdjustment) }); + } } } } diff --git a/src/Api/Models/Request/Organizations/OrganizationSponsorshipCreateRequestModel.cs b/src/Api/Models/Request/Organizations/OrganizationSponsorshipCreateRequestModel.cs index ba88f1b90e..e3848a4cb5 100644 --- a/src/Api/Models/Request/Organizations/OrganizationSponsorshipCreateRequestModel.cs +++ b/src/Api/Models/Request/Organizations/OrganizationSponsorshipCreateRequestModel.cs @@ -2,18 +2,19 @@ using Bit.Core.Enums; using Bit.Core.Utilities; -namespace Bit.Api.Models.Request.Organizations; - -public class OrganizationSponsorshipCreateRequestModel +namespace Bit.Api.Models.Request.Organizations { - [Required] - public PlanSponsorshipType PlanSponsorshipType { get; set; } + public class OrganizationSponsorshipCreateRequestModel + { + [Required] + public PlanSponsorshipType PlanSponsorshipType { get; set; } - [Required] - [StringLength(256)] - [StrictEmailAddress] - public string SponsoredEmail { get; set; } + [Required] + [StringLength(256)] + [StrictEmailAddress] + public string SponsoredEmail { get; set; } - [StringLength(256)] - public string FriendlyName { get; set; } + [StringLength(256)] + public string FriendlyName { get; set; } + } } diff --git a/src/Api/Models/Request/Organizations/OrganizationSponsorshipRedeemRequestModel.cs b/src/Api/Models/Request/Organizations/OrganizationSponsorshipRedeemRequestModel.cs index 19b11cd77d..4a4cc26026 100644 --- a/src/Api/Models/Request/Organizations/OrganizationSponsorshipRedeemRequestModel.cs +++ b/src/Api/Models/Request/Organizations/OrganizationSponsorshipRedeemRequestModel.cs @@ -1,12 +1,13 @@ using System.ComponentModel.DataAnnotations; using Bit.Core.Enums; -namespace Bit.Api.Models.Request.Organizations; - -public class OrganizationSponsorshipRedeemRequestModel +namespace Bit.Api.Models.Request.Organizations { - [Required] - public PlanSponsorshipType PlanSponsorshipType { get; set; } - [Required] - public Guid SponsoredOrganizationId { get; set; } + public class OrganizationSponsorshipRedeemRequestModel + { + [Required] + public PlanSponsorshipType PlanSponsorshipType { get; set; } + [Required] + public Guid SponsoredOrganizationId { get; set; } + } } diff --git a/src/Api/Models/Request/Organizations/OrganizationSsoRequestModel.cs b/src/Api/Models/Request/Organizations/OrganizationSsoRequestModel.cs index 47594703d0..5291ce1753 100644 --- a/src/Api/Models/Request/Organizations/OrganizationSsoRequestModel.cs +++ b/src/Api/Models/Request/Organizations/OrganizationSsoRequestModel.cs @@ -10,215 +10,216 @@ using Bit.Core.Sso; using Bit.Core.Utilities; using Microsoft.AspNetCore.Authentication.OpenIdConnect; -namespace Bit.Api.Models.Request.Organizations; - -public class OrganizationSsoRequestModel +namespace Bit.Api.Models.Request.Organizations { - [Required] - public bool Enabled { get; set; } - [Required] - public SsoConfigurationDataRequest Data { get; set; } - - public SsoConfig ToSsoConfig(Guid organizationId) + public class OrganizationSsoRequestModel { - return ToSsoConfig(new SsoConfig { OrganizationId = organizationId }); + [Required] + public bool Enabled { get; set; } + [Required] + public SsoConfigurationDataRequest Data { get; set; } + + public SsoConfig ToSsoConfig(Guid organizationId) + { + return ToSsoConfig(new SsoConfig { OrganizationId = organizationId }); + } + + public SsoConfig ToSsoConfig(SsoConfig existingConfig) + { + existingConfig.Enabled = Enabled; + var configurationData = Data.ToConfigurationData(); + existingConfig.SetData(configurationData); + return existingConfig; + } } - public SsoConfig ToSsoConfig(SsoConfig existingConfig) + public class SsoConfigurationDataRequest : IValidatableObject { - existingConfig.Enabled = Enabled; - var configurationData = Data.ToConfigurationData(); - existingConfig.SetData(configurationData); - return existingConfig; - } -} - -public class SsoConfigurationDataRequest : IValidatableObject -{ - public SsoConfigurationDataRequest() { } - - [Required] - public SsoType ConfigType { get; set; } - - public bool KeyConnectorEnabled { get; set; } - public string KeyConnectorUrl { get; set; } - - // OIDC - public string Authority { get; set; } - public string ClientId { get; set; } - public string ClientSecret { get; set; } - public string MetadataAddress { get; set; } - public OpenIdConnectRedirectBehavior RedirectBehavior { get; set; } - public bool? GetClaimsFromUserInfoEndpoint { get; set; } - public string AdditionalScopes { get; set; } - public string AdditionalUserIdClaimTypes { get; set; } - public string AdditionalEmailClaimTypes { get; set; } - public string AdditionalNameClaimTypes { get; set; } - public string AcrValues { get; set; } - public string ExpectedReturnAcrValue { get; set; } - - // SAML2 SP - public Saml2NameIdFormat SpNameIdFormat { get; set; } - public string SpOutboundSigningAlgorithm { get; set; } - public Saml2SigningBehavior SpSigningBehavior { get; set; } - public bool? SpWantAssertionsSigned { get; set; } - public bool? SpValidateCertificates { get; set; } - public string SpMinIncomingSigningAlgorithm { get; set; } - - // SAML2 IDP - public string IdpEntityId { get; set; } - public Saml2BindingType IdpBindingType { get; set; } - public string IdpSingleSignOnServiceUrl { get; set; } - public string IdpSingleLogoutServiceUrl { get; set; } - public string IdpArtifactResolutionServiceUrl { get => null; set { /*IGNORE*/ } } - public string IdpX509PublicCert { get; set; } - public string IdpOutboundSigningAlgorithm { get; set; } - public bool? IdpAllowUnsolicitedAuthnResponse { get; set; } - public bool? IdpDisableOutboundLogoutRequests { get; set; } - public bool? IdpWantAuthnRequestsSigned { get; set; } - - public IEnumerable Validate(ValidationContext context) - { - var i18nService = context.GetService(typeof(II18nService)) as I18nService; - - if (ConfigType == SsoType.OpenIdConnect) - { - if (string.IsNullOrWhiteSpace(Authority)) - { - yield return new ValidationResult(i18nService.GetLocalizedHtmlString("AuthorityValidationError"), - new[] { nameof(Authority) }); - } - - if (string.IsNullOrWhiteSpace(ClientId)) - { - yield return new ValidationResult(i18nService.GetLocalizedHtmlString("ClientIdValidationError"), - new[] { nameof(ClientId) }); - } - - if (string.IsNullOrWhiteSpace(ClientSecret)) - { - yield return new ValidationResult(i18nService.GetLocalizedHtmlString("ClientSecretValidationError"), - new[] { nameof(ClientSecret) }); - } - } - else if (ConfigType == SsoType.Saml2) - { - if (string.IsNullOrWhiteSpace(IdpEntityId)) - { - yield return new ValidationResult(i18nService.GetLocalizedHtmlString("IdpEntityIdValidationError"), - new[] { nameof(IdpEntityId) }); - } - - if (!Uri.IsWellFormedUriString(IdpEntityId, UriKind.Absolute) && string.IsNullOrWhiteSpace(IdpSingleSignOnServiceUrl)) - { - yield return new ValidationResult(i18nService.GetLocalizedHtmlString("IdpSingleSignOnServiceUrlValidationError"), - new[] { nameof(IdpSingleSignOnServiceUrl) }); - } - - if (InvalidServiceUrl(IdpSingleSignOnServiceUrl)) - { - yield return new ValidationResult(i18nService.GetLocalizedHtmlString("IdpSingleSignOnServiceUrlInvalid"), - new[] { nameof(IdpSingleSignOnServiceUrl) }); - } - - if (InvalidServiceUrl(IdpSingleLogoutServiceUrl)) - { - yield return new ValidationResult(i18nService.GetLocalizedHtmlString("IdpSingleLogoutServiceUrlInvalid"), - new[] { nameof(IdpSingleLogoutServiceUrl) }); - } - - if (!string.IsNullOrWhiteSpace(IdpX509PublicCert)) - { - // Validate the certificate is in a valid format - ValidationResult failedResult = null; - try - { - var certData = CoreHelpers.Base64UrlDecode(StripPemCertificateElements(IdpX509PublicCert)); - new X509Certificate2(certData); - } - catch (FormatException) - { - failedResult = new ValidationResult(i18nService.GetLocalizedHtmlString("IdpX509PublicCertInvalidFormatValidationError"), - new[] { nameof(IdpX509PublicCert) }); - } - catch (CryptographicException cryptoEx) - { - failedResult = new ValidationResult(i18nService.GetLocalizedHtmlString("IdpX509PublicCertCryptographicExceptionValidationError", cryptoEx.Message), - new[] { nameof(IdpX509PublicCert) }); - } - catch (Exception ex) - { - failedResult = new ValidationResult(i18nService.GetLocalizedHtmlString("IdpX509PublicCertValidationError", ex.Message), - new[] { nameof(IdpX509PublicCert) }); - } - if (failedResult != null) - { - yield return failedResult; - } - } - } - } - - public SsoConfigurationData ToConfigurationData() - { - return new SsoConfigurationData - { - ConfigType = ConfigType, - KeyConnectorEnabled = KeyConnectorEnabled, - KeyConnectorUrl = KeyConnectorUrl, - Authority = Authority, - ClientId = ClientId, - ClientSecret = ClientSecret, - MetadataAddress = MetadataAddress, - GetClaimsFromUserInfoEndpoint = GetClaimsFromUserInfoEndpoint.GetValueOrDefault(), - RedirectBehavior = RedirectBehavior, - IdpEntityId = IdpEntityId, - IdpBindingType = IdpBindingType, - IdpSingleSignOnServiceUrl = IdpSingleSignOnServiceUrl, - IdpSingleLogoutServiceUrl = IdpSingleLogoutServiceUrl, - IdpArtifactResolutionServiceUrl = null, - IdpX509PublicCert = StripPemCertificateElements(IdpX509PublicCert), - IdpOutboundSigningAlgorithm = IdpOutboundSigningAlgorithm, - IdpAllowUnsolicitedAuthnResponse = IdpAllowUnsolicitedAuthnResponse.GetValueOrDefault(), - IdpDisableOutboundLogoutRequests = IdpDisableOutboundLogoutRequests.GetValueOrDefault(), - IdpWantAuthnRequestsSigned = IdpWantAuthnRequestsSigned.GetValueOrDefault(), - SpNameIdFormat = SpNameIdFormat, - SpOutboundSigningAlgorithm = SpOutboundSigningAlgorithm ?? SamlSigningAlgorithms.Sha256, - SpSigningBehavior = SpSigningBehavior, - SpWantAssertionsSigned = SpWantAssertionsSigned.GetValueOrDefault(), - SpValidateCertificates = SpValidateCertificates.GetValueOrDefault(), - SpMinIncomingSigningAlgorithm = SpMinIncomingSigningAlgorithm, - AdditionalScopes = AdditionalScopes, - AdditionalUserIdClaimTypes = AdditionalUserIdClaimTypes, - AdditionalEmailClaimTypes = AdditionalEmailClaimTypes, - AdditionalNameClaimTypes = AdditionalNameClaimTypes, - AcrValues = AcrValues, - ExpectedReturnAcrValue = ExpectedReturnAcrValue, - }; - } - - private string StripPemCertificateElements(string certificateText) - { - if (string.IsNullOrWhiteSpace(certificateText)) - { - return null; - } - return Regex.Replace(certificateText, - @"(((BEGIN|END) CERTIFICATE)|([\-\n\r\t\s\f]))", - string.Empty, - RegexOptions.Multiline | RegexOptions.IgnoreCase | RegexOptions.CultureInvariant); - } - - private bool InvalidServiceUrl(string url) - { - if (string.IsNullOrWhiteSpace(url)) - { - return false; - } - if (!url.StartsWith("http://") && !url.StartsWith("https://")) - { - return true; - } - return Regex.IsMatch(url, "[<>\"]"); + public SsoConfigurationDataRequest() { } + + [Required] + public SsoType ConfigType { get; set; } + + public bool KeyConnectorEnabled { get; set; } + public string KeyConnectorUrl { get; set; } + + // OIDC + public string Authority { get; set; } + public string ClientId { get; set; } + public string ClientSecret { get; set; } + public string MetadataAddress { get; set; } + public OpenIdConnectRedirectBehavior RedirectBehavior { get; set; } + public bool? GetClaimsFromUserInfoEndpoint { get; set; } + public string AdditionalScopes { get; set; } + public string AdditionalUserIdClaimTypes { get; set; } + public string AdditionalEmailClaimTypes { get; set; } + public string AdditionalNameClaimTypes { get; set; } + public string AcrValues { get; set; } + public string ExpectedReturnAcrValue { get; set; } + + // SAML2 SP + public Saml2NameIdFormat SpNameIdFormat { get; set; } + public string SpOutboundSigningAlgorithm { get; set; } + public Saml2SigningBehavior SpSigningBehavior { get; set; } + public bool? SpWantAssertionsSigned { get; set; } + public bool? SpValidateCertificates { get; set; } + public string SpMinIncomingSigningAlgorithm { get; set; } + + // SAML2 IDP + public string IdpEntityId { get; set; } + public Saml2BindingType IdpBindingType { get; set; } + public string IdpSingleSignOnServiceUrl { get; set; } + public string IdpSingleLogoutServiceUrl { get; set; } + public string IdpArtifactResolutionServiceUrl { get => null; set { /*IGNORE*/ } } + public string IdpX509PublicCert { get; set; } + public string IdpOutboundSigningAlgorithm { get; set; } + public bool? IdpAllowUnsolicitedAuthnResponse { get; set; } + public bool? IdpDisableOutboundLogoutRequests { get; set; } + public bool? IdpWantAuthnRequestsSigned { get; set; } + + public IEnumerable Validate(ValidationContext context) + { + var i18nService = context.GetService(typeof(II18nService)) as I18nService; + + if (ConfigType == SsoType.OpenIdConnect) + { + if (string.IsNullOrWhiteSpace(Authority)) + { + yield return new ValidationResult(i18nService.GetLocalizedHtmlString("AuthorityValidationError"), + new[] { nameof(Authority) }); + } + + if (string.IsNullOrWhiteSpace(ClientId)) + { + yield return new ValidationResult(i18nService.GetLocalizedHtmlString("ClientIdValidationError"), + new[] { nameof(ClientId) }); + } + + if (string.IsNullOrWhiteSpace(ClientSecret)) + { + yield return new ValidationResult(i18nService.GetLocalizedHtmlString("ClientSecretValidationError"), + new[] { nameof(ClientSecret) }); + } + } + else if (ConfigType == SsoType.Saml2) + { + if (string.IsNullOrWhiteSpace(IdpEntityId)) + { + yield return new ValidationResult(i18nService.GetLocalizedHtmlString("IdpEntityIdValidationError"), + new[] { nameof(IdpEntityId) }); + } + + if (!Uri.IsWellFormedUriString(IdpEntityId, UriKind.Absolute) && string.IsNullOrWhiteSpace(IdpSingleSignOnServiceUrl)) + { + yield return new ValidationResult(i18nService.GetLocalizedHtmlString("IdpSingleSignOnServiceUrlValidationError"), + new[] { nameof(IdpSingleSignOnServiceUrl) }); + } + + if (InvalidServiceUrl(IdpSingleSignOnServiceUrl)) + { + yield return new ValidationResult(i18nService.GetLocalizedHtmlString("IdpSingleSignOnServiceUrlInvalid"), + new[] { nameof(IdpSingleSignOnServiceUrl) }); + } + + if (InvalidServiceUrl(IdpSingleLogoutServiceUrl)) + { + yield return new ValidationResult(i18nService.GetLocalizedHtmlString("IdpSingleLogoutServiceUrlInvalid"), + new[] { nameof(IdpSingleLogoutServiceUrl) }); + } + + if (!string.IsNullOrWhiteSpace(IdpX509PublicCert)) + { + // Validate the certificate is in a valid format + ValidationResult failedResult = null; + try + { + var certData = CoreHelpers.Base64UrlDecode(StripPemCertificateElements(IdpX509PublicCert)); + new X509Certificate2(certData); + } + catch (FormatException) + { + failedResult = new ValidationResult(i18nService.GetLocalizedHtmlString("IdpX509PublicCertInvalidFormatValidationError"), + new[] { nameof(IdpX509PublicCert) }); + } + catch (CryptographicException cryptoEx) + { + failedResult = new ValidationResult(i18nService.GetLocalizedHtmlString("IdpX509PublicCertCryptographicExceptionValidationError", cryptoEx.Message), + new[] { nameof(IdpX509PublicCert) }); + } + catch (Exception ex) + { + failedResult = new ValidationResult(i18nService.GetLocalizedHtmlString("IdpX509PublicCertValidationError", ex.Message), + new[] { nameof(IdpX509PublicCert) }); + } + if (failedResult != null) + { + yield return failedResult; + } + } + } + } + + public SsoConfigurationData ToConfigurationData() + { + return new SsoConfigurationData + { + ConfigType = ConfigType, + KeyConnectorEnabled = KeyConnectorEnabled, + KeyConnectorUrl = KeyConnectorUrl, + Authority = Authority, + ClientId = ClientId, + ClientSecret = ClientSecret, + MetadataAddress = MetadataAddress, + GetClaimsFromUserInfoEndpoint = GetClaimsFromUserInfoEndpoint.GetValueOrDefault(), + RedirectBehavior = RedirectBehavior, + IdpEntityId = IdpEntityId, + IdpBindingType = IdpBindingType, + IdpSingleSignOnServiceUrl = IdpSingleSignOnServiceUrl, + IdpSingleLogoutServiceUrl = IdpSingleLogoutServiceUrl, + IdpArtifactResolutionServiceUrl = null, + IdpX509PublicCert = StripPemCertificateElements(IdpX509PublicCert), + IdpOutboundSigningAlgorithm = IdpOutboundSigningAlgorithm, + IdpAllowUnsolicitedAuthnResponse = IdpAllowUnsolicitedAuthnResponse.GetValueOrDefault(), + IdpDisableOutboundLogoutRequests = IdpDisableOutboundLogoutRequests.GetValueOrDefault(), + IdpWantAuthnRequestsSigned = IdpWantAuthnRequestsSigned.GetValueOrDefault(), + SpNameIdFormat = SpNameIdFormat, + SpOutboundSigningAlgorithm = SpOutboundSigningAlgorithm ?? SamlSigningAlgorithms.Sha256, + SpSigningBehavior = SpSigningBehavior, + SpWantAssertionsSigned = SpWantAssertionsSigned.GetValueOrDefault(), + SpValidateCertificates = SpValidateCertificates.GetValueOrDefault(), + SpMinIncomingSigningAlgorithm = SpMinIncomingSigningAlgorithm, + AdditionalScopes = AdditionalScopes, + AdditionalUserIdClaimTypes = AdditionalUserIdClaimTypes, + AdditionalEmailClaimTypes = AdditionalEmailClaimTypes, + AdditionalNameClaimTypes = AdditionalNameClaimTypes, + AcrValues = AcrValues, + ExpectedReturnAcrValue = ExpectedReturnAcrValue, + }; + } + + private string StripPemCertificateElements(string certificateText) + { + if (string.IsNullOrWhiteSpace(certificateText)) + { + return null; + } + return Regex.Replace(certificateText, + @"(((BEGIN|END) CERTIFICATE)|([\-\n\r\t\s\f]))", + string.Empty, + RegexOptions.Multiline | RegexOptions.IgnoreCase | RegexOptions.CultureInvariant); + } + + private bool InvalidServiceUrl(string url) + { + if (string.IsNullOrWhiteSpace(url)) + { + return false; + } + if (!url.StartsWith("http://") && !url.StartsWith("https://")) + { + return true; + } + return Regex.IsMatch(url, "[<>\"]"); + } } } diff --git a/src/Api/Models/Request/Organizations/OrganizationSubscriptionUpdateRequestModel.cs b/src/Api/Models/Request/Organizations/OrganizationSubscriptionUpdateRequestModel.cs index 6db32589a3..9adb0b7ec3 100644 --- a/src/Api/Models/Request/Organizations/OrganizationSubscriptionUpdateRequestModel.cs +++ b/src/Api/Models/Request/Organizations/OrganizationSubscriptionUpdateRequestModel.cs @@ -1,10 +1,11 @@ using System.ComponentModel.DataAnnotations; -namespace Bit.Api.Models.Request.Organizations; - -public class OrganizationSubscriptionUpdateRequestModel +namespace Bit.Api.Models.Request.Organizations { - [Required] - public int SeatAdjustment { get; set; } - public int? MaxAutoscaleSeats { get; set; } + public class OrganizationSubscriptionUpdateRequestModel + { + [Required] + public int SeatAdjustment { get; set; } + public int? MaxAutoscaleSeats { get; set; } + } } diff --git a/src/Api/Models/Request/Organizations/OrganizationTaxInfoUpdateRequestModel.cs b/src/Api/Models/Request/Organizations/OrganizationTaxInfoUpdateRequestModel.cs index c20fa07afb..a67cbbb7e8 100644 --- a/src/Api/Models/Request/Organizations/OrganizationTaxInfoUpdateRequestModel.cs +++ b/src/Api/Models/Request/Organizations/OrganizationTaxInfoUpdateRequestModel.cs @@ -1,12 +1,13 @@ using Bit.Api.Models.Request.Accounts; -namespace Bit.Api.Models.Request.Organizations; - -public class OrganizationTaxInfoUpdateRequestModel : TaxInfoUpdateRequestModel +namespace Bit.Api.Models.Request.Organizations { - public string TaxId { get; set; } - public string Line1 { get; set; } - public string Line2 { get; set; } - public string City { get; set; } - public string State { get; set; } + public class OrganizationTaxInfoUpdateRequestModel : TaxInfoUpdateRequestModel + { + public string TaxId { get; set; } + public string Line1 { get; set; } + public string Line2 { get; set; } + public string City { get; set; } + public string State { get; set; } + } } diff --git a/src/Api/Models/Request/Organizations/OrganizationUpdateRequestModel.cs b/src/Api/Models/Request/Organizations/OrganizationUpdateRequestModel.cs index f67016bce1..24cce9710d 100644 --- a/src/Api/Models/Request/Organizations/OrganizationUpdateRequestModel.cs +++ b/src/Api/Models/Request/Organizations/OrganizationUpdateRequestModel.cs @@ -3,35 +3,36 @@ using Bit.Core.Entities; using Bit.Core.Models.Data; using Bit.Core.Settings; -namespace Bit.Api.Models.Request.Organizations; - -public class OrganizationUpdateRequestModel +namespace Bit.Api.Models.Request.Organizations { - [Required] - [StringLength(50)] - public string Name { get; set; } - [StringLength(50)] - public string BusinessName { get; set; } - [StringLength(50)] - public string Identifier { get; set; } - [EmailAddress] - [Required] - [StringLength(256)] - public string BillingEmail { get; set; } - public Permissions Permissions { get; set; } - public OrganizationKeysRequestModel Keys { get; set; } - - public virtual Organization ToOrganization(Organization existingOrganization, GlobalSettings globalSettings) + public class OrganizationUpdateRequestModel { - if (!globalSettings.SelfHosted) + [Required] + [StringLength(50)] + public string Name { get; set; } + [StringLength(50)] + public string BusinessName { get; set; } + [StringLength(50)] + public string Identifier { get; set; } + [EmailAddress] + [Required] + [StringLength(256)] + public string BillingEmail { get; set; } + public Permissions Permissions { get; set; } + public OrganizationKeysRequestModel Keys { get; set; } + + public virtual Organization ToOrganization(Organization existingOrganization, GlobalSettings globalSettings) { - // These items come from the license file - existingOrganization.Name = Name; - existingOrganization.BusinessName = BusinessName; - existingOrganization.BillingEmail = BillingEmail?.ToLowerInvariant()?.Trim(); + if (!globalSettings.SelfHosted) + { + // These items come from the license file + existingOrganization.Name = Name; + existingOrganization.BusinessName = BusinessName; + existingOrganization.BillingEmail = BillingEmail?.ToLowerInvariant()?.Trim(); + } + existingOrganization.Identifier = Identifier; + Keys?.ToOrganization(existingOrganization); + return existingOrganization; } - existingOrganization.Identifier = Identifier; - Keys?.ToOrganization(existingOrganization); - return existingOrganization; } } diff --git a/src/Api/Models/Request/Organizations/OrganizationUpgradeRequestModel.cs b/src/Api/Models/Request/Organizations/OrganizationUpgradeRequestModel.cs index fb2666cc1e..7ceedef08d 100644 --- a/src/Api/Models/Request/Organizations/OrganizationUpgradeRequestModel.cs +++ b/src/Api/Models/Request/Organizations/OrganizationUpgradeRequestModel.cs @@ -2,40 +2,41 @@ using Bit.Core.Enums; using Bit.Core.Models.Business; -namespace Bit.Api.Models.Request.Organizations; - -public class OrganizationUpgradeRequestModel +namespace Bit.Api.Models.Request.Organizations { - [StringLength(50)] - public string BusinessName { get; set; } - public PlanType PlanType { get; set; } - [Range(0, int.MaxValue)] - public int AdditionalSeats { get; set; } - [Range(0, 99)] - public short? AdditionalStorageGb { get; set; } - public bool PremiumAccessAddon { get; set; } - public string BillingAddressCountry { get; set; } - public string BillingAddressPostalCode { get; set; } - public OrganizationKeysRequestModel Keys { get; set; } - - public OrganizationUpgrade ToOrganizationUpgrade() + public class OrganizationUpgradeRequestModel { - var orgUpgrade = new OrganizationUpgrade + [StringLength(50)] + public string BusinessName { get; set; } + public PlanType PlanType { get; set; } + [Range(0, int.MaxValue)] + public int AdditionalSeats { get; set; } + [Range(0, 99)] + public short? AdditionalStorageGb { get; set; } + public bool PremiumAccessAddon { get; set; } + public string BillingAddressCountry { get; set; } + public string BillingAddressPostalCode { get; set; } + public OrganizationKeysRequestModel Keys { get; set; } + + public OrganizationUpgrade ToOrganizationUpgrade() { - AdditionalSeats = AdditionalSeats, - AdditionalStorageGb = AdditionalStorageGb.GetValueOrDefault(), - BusinessName = BusinessName, - Plan = PlanType, - PremiumAccessAddon = PremiumAccessAddon, - TaxInfo = new TaxInfo() + var orgUpgrade = new OrganizationUpgrade { - BillingAddressCountry = BillingAddressCountry, - BillingAddressPostalCode = BillingAddressPostalCode - } - }; + AdditionalSeats = AdditionalSeats, + AdditionalStorageGb = AdditionalStorageGb.GetValueOrDefault(), + BusinessName = BusinessName, + Plan = PlanType, + PremiumAccessAddon = PremiumAccessAddon, + TaxInfo = new TaxInfo() + { + BillingAddressCountry = BillingAddressCountry, + BillingAddressPostalCode = BillingAddressPostalCode + } + }; - Keys?.ToOrganizationUpgrade(orgUpgrade); + Keys?.ToOrganizationUpgrade(orgUpgrade); - return orgUpgrade; + return orgUpgrade; + } } } diff --git a/src/Api/Models/Request/Organizations/OrganizationUserRequestModels.cs b/src/Api/Models/Request/Organizations/OrganizationUserRequestModels.cs index 4d6fcfedba..09cb7efb11 100644 --- a/src/Api/Models/Request/Organizations/OrganizationUserRequestModels.cs +++ b/src/Api/Models/Request/Organizations/OrganizationUserRequestModels.cs @@ -7,98 +7,99 @@ using Bit.Core.Models.Data; using Bit.Core.Models.Data.Organizations.OrganizationUsers; using Bit.Core.Utilities; -namespace Bit.Api.Models.Request.Organizations; - -public class OrganizationUserInviteRequestModel +namespace Bit.Api.Models.Request.Organizations { - [Required] - [StrictEmailAddressList] - public IEnumerable Emails { get; set; } - [Required] - public OrganizationUserType? Type { get; set; } - public bool AccessAll { get; set; } - public Permissions Permissions { get; set; } - public IEnumerable Collections { get; set; } - - public OrganizationUserInviteData ToData() + public class OrganizationUserInviteRequestModel { - return new OrganizationUserInviteData + [Required] + [StrictEmailAddressList] + public IEnumerable Emails { get; set; } + [Required] + public OrganizationUserType? Type { get; set; } + public bool AccessAll { get; set; } + public Permissions Permissions { get; set; } + public IEnumerable Collections { get; set; } + + public OrganizationUserInviteData ToData() { - Emails = Emails, - Type = Type, - AccessAll = AccessAll, - Collections = Collections?.Select(c => c.ToSelectionReadOnly()), - Permissions = Permissions, - }; + return new OrganizationUserInviteData + { + Emails = Emails, + Type = Type, + AccessAll = AccessAll, + Collections = Collections?.Select(c => c.ToSelectionReadOnly()), + Permissions = Permissions, + }; + } } -} -public class OrganizationUserAcceptRequestModel -{ - [Required] - public string Token { get; set; } - // Used to auto-enroll in master password reset - public string ResetPasswordKey { get; set; } -} - -public class OrganizationUserConfirmRequestModel -{ - [Required] - public string Key { get; set; } -} - -public class OrganizationUserBulkConfirmRequestModelEntry -{ - [Required] - public Guid Id { get; set; } - [Required] - public string Key { get; set; } -} - -public class OrganizationUserBulkConfirmRequestModel -{ - [Required] - public IEnumerable Keys { get; set; } - - public Dictionary ToDictionary() + public class OrganizationUserAcceptRequestModel { - return Keys.ToDictionary(e => e.Id, e => e.Key); + [Required] + public string Token { get; set; } + // Used to auto-enroll in master password reset + public string ResetPasswordKey { get; set; } } -} -public class OrganizationUserUpdateRequestModel -{ - [Required] - public OrganizationUserType? Type { get; set; } - public bool AccessAll { get; set; } - public Permissions Permissions { get; set; } - public IEnumerable Collections { get; set; } - - public OrganizationUser ToOrganizationUser(OrganizationUser existingUser) + public class OrganizationUserConfirmRequestModel { - existingUser.Type = Type.Value; - existingUser.Permissions = JsonSerializer.Serialize(Permissions, new JsonSerializerOptions + [Required] + public string Key { get; set; } + } + + public class OrganizationUserBulkConfirmRequestModelEntry + { + [Required] + public Guid Id { get; set; } + [Required] + public string Key { get; set; } + } + + public class OrganizationUserBulkConfirmRequestModel + { + [Required] + public IEnumerable Keys { get; set; } + + public Dictionary ToDictionary() { - PropertyNamingPolicy = JsonNamingPolicy.CamelCase, - }); - existingUser.AccessAll = AccessAll; - return existingUser; + return Keys.ToDictionary(e => e.Id, e => e.Key); + } + } + + public class OrganizationUserUpdateRequestModel + { + [Required] + public OrganizationUserType? Type { get; set; } + public bool AccessAll { get; set; } + public Permissions Permissions { get; set; } + public IEnumerable Collections { get; set; } + + public OrganizationUser ToOrganizationUser(OrganizationUser existingUser) + { + existingUser.Type = Type.Value; + existingUser.Permissions = JsonSerializer.Serialize(Permissions, new JsonSerializerOptions + { + PropertyNamingPolicy = JsonNamingPolicy.CamelCase, + }); + existingUser.AccessAll = AccessAll; + return existingUser; + } + } + + public class OrganizationUserUpdateGroupsRequestModel + { + [Required] + public IEnumerable GroupIds { get; set; } + } + + public class OrganizationUserResetPasswordEnrollmentRequestModel : SecretVerificationRequestModel + { + public string ResetPasswordKey { get; set; } + } + + public class OrganizationUserBulkRequestModel + { + [Required] + public IEnumerable Ids { get; set; } } } - -public class OrganizationUserUpdateGroupsRequestModel -{ - [Required] - public IEnumerable GroupIds { get; set; } -} - -public class OrganizationUserResetPasswordEnrollmentRequestModel : SecretVerificationRequestModel -{ - public string ResetPasswordKey { get; set; } -} - -public class OrganizationUserBulkRequestModel -{ - [Required] - public IEnumerable Ids { get; set; } -} diff --git a/src/Api/Models/Request/Organizations/OrganizationUserResetPasswordRequestModel.cs b/src/Api/Models/Request/Organizations/OrganizationUserResetPasswordRequestModel.cs index 571f69c1ef..4434a64c90 100644 --- a/src/Api/Models/Request/Organizations/OrganizationUserResetPasswordRequestModel.cs +++ b/src/Api/Models/Request/Organizations/OrganizationUserResetPasswordRequestModel.cs @@ -1,12 +1,13 @@ using System.ComponentModel.DataAnnotations; -namespace Bit.Api.Models.Request.Organizations; - -public class OrganizationUserResetPasswordRequestModel +namespace Bit.Api.Models.Request.Organizations { - [Required] - [StringLength(300)] - public string NewMasterPasswordHash { get; set; } - [Required] - public string Key { get; set; } + public class OrganizationUserResetPasswordRequestModel + { + [Required] + [StringLength(300)] + public string NewMasterPasswordHash { get; set; } + [Required] + public string Key { get; set; } + } } diff --git a/src/Api/Models/Request/Organizations/OrganizationVerifyBankRequestModel.cs b/src/Api/Models/Request/Organizations/OrganizationVerifyBankRequestModel.cs index 71f6873800..9023cd6651 100644 --- a/src/Api/Models/Request/Organizations/OrganizationVerifyBankRequestModel.cs +++ b/src/Api/Models/Request/Organizations/OrganizationVerifyBankRequestModel.cs @@ -1,13 +1,14 @@ using System.ComponentModel.DataAnnotations; -namespace Bit.Api.Models.Request.Organizations; - -public class OrganizationVerifyBankRequestModel +namespace Bit.Api.Models.Request.Organizations { - [Required] - [Range(1, 99)] - public int? Amount1 { get; set; } - [Required] - [Range(1, 99)] - public int? Amount2 { get; set; } + public class OrganizationVerifyBankRequestModel + { + [Required] + [Range(1, 99)] + public int? Amount1 { get; set; } + [Required] + [Range(1, 99)] + public int? Amount2 { get; set; } + } } diff --git a/src/Api/Models/Request/PaymentRequestModel.cs b/src/Api/Models/Request/PaymentRequestModel.cs index 47e39b010d..b10b7df0c7 100644 --- a/src/Api/Models/Request/PaymentRequestModel.cs +++ b/src/Api/Models/Request/PaymentRequestModel.cs @@ -2,12 +2,13 @@ using Bit.Api.Models.Request.Organizations; using Bit.Core.Enums; -namespace Bit.Api.Models.Request; - -public class PaymentRequestModel : OrganizationTaxInfoUpdateRequestModel +namespace Bit.Api.Models.Request { - [Required] - public PaymentMethodType? PaymentMethodType { get; set; } - [Required] - public string PaymentToken { get; set; } + public class PaymentRequestModel : OrganizationTaxInfoUpdateRequestModel + { + [Required] + public PaymentMethodType? PaymentMethodType { get; set; } + [Required] + public string PaymentToken { get; set; } + } } diff --git a/src/Api/Models/Request/PolicyRequestModel.cs b/src/Api/Models/Request/PolicyRequestModel.cs index bc303cd40f..927b7fcc31 100644 --- a/src/Api/Models/Request/PolicyRequestModel.cs +++ b/src/Api/Models/Request/PolicyRequestModel.cs @@ -3,29 +3,30 @@ using System.Text.Json; using Bit.Core.Entities; using Bit.Core.Enums; -namespace Bit.Api.Models.Request; - -public class PolicyRequestModel +namespace Bit.Api.Models.Request { - [Required] - public PolicyType? Type { get; set; } - [Required] - public bool? Enabled { get; set; } - public Dictionary Data { get; set; } - - public Policy ToPolicy(Guid orgId) + public class PolicyRequestModel { - return ToPolicy(new Policy + [Required] + public PolicyType? Type { get; set; } + [Required] + public bool? Enabled { get; set; } + public Dictionary Data { get; set; } + + public Policy ToPolicy(Guid orgId) { - Type = Type.Value, - OrganizationId = orgId - }); - } + return ToPolicy(new Policy + { + Type = Type.Value, + OrganizationId = orgId + }); + } - public Policy ToPolicy(Policy existingPolicy) - { - existingPolicy.Enabled = Enabled.GetValueOrDefault(); - existingPolicy.Data = Data != null ? JsonSerializer.Serialize(Data) : null; - return existingPolicy; + public Policy ToPolicy(Policy existingPolicy) + { + existingPolicy.Enabled = Enabled.GetValueOrDefault(); + existingPolicy.Data = Data != null ? JsonSerializer.Serialize(Data) : null; + return existingPolicy; + } } } diff --git a/src/Api/Models/Request/Providers/ProviderOrganizationAddRequestModel.cs b/src/Api/Models/Request/Providers/ProviderOrganizationAddRequestModel.cs index b6ea8759e0..d3d203f4dd 100644 --- a/src/Api/Models/Request/Providers/ProviderOrganizationAddRequestModel.cs +++ b/src/Api/Models/Request/Providers/ProviderOrganizationAddRequestModel.cs @@ -1,12 +1,13 @@ using System.ComponentModel.DataAnnotations; -namespace Bit.Api.Models.Request.Providers; - -public class ProviderOrganizationAddRequestModel +namespace Bit.Api.Models.Request.Providers { - [Required] - public Guid OrganizationId { get; set; } + public class ProviderOrganizationAddRequestModel + { + [Required] + public Guid OrganizationId { get; set; } - [Required] - public string Key { get; set; } + [Required] + public string Key { get; set; } + } } diff --git a/src/Api/Models/Request/Providers/ProviderOrganizationCreateRequestModel.cs b/src/Api/Models/Request/Providers/ProviderOrganizationCreateRequestModel.cs index 7fead717be..cd796c11fb 100644 --- a/src/Api/Models/Request/Providers/ProviderOrganizationCreateRequestModel.cs +++ b/src/Api/Models/Request/Providers/ProviderOrganizationCreateRequestModel.cs @@ -2,14 +2,15 @@ using Bit.Api.Models.Request.Organizations; using Bit.Core.Utilities; -namespace Bit.Api.Models.Request.Providers; - -public class ProviderOrganizationCreateRequestModel +namespace Bit.Api.Models.Request.Providers { - [Required] - [StringLength(256)] - [StrictEmailAddress] - public string ClientOwnerEmail { get; set; } - [Required] - public OrganizationCreateRequestModel OrganizationCreateRequest { get; set; } + public class ProviderOrganizationCreateRequestModel + { + [Required] + [StringLength(256)] + [StrictEmailAddress] + public string ClientOwnerEmail { get; set; } + [Required] + public OrganizationCreateRequestModel OrganizationCreateRequest { get; set; } + } } diff --git a/src/Api/Models/Request/Providers/ProviderSetupRequestModel.cs b/src/Api/Models/Request/Providers/ProviderSetupRequestModel.cs index 51191f947c..f0b5fffe91 100644 --- a/src/Api/Models/Request/Providers/ProviderSetupRequestModel.cs +++ b/src/Api/Models/Request/Providers/ProviderSetupRequestModel.cs @@ -1,30 +1,31 @@ using System.ComponentModel.DataAnnotations; using Bit.Core.Entities.Provider; -namespace Bit.Api.Models.Request.Providers; - -public class ProviderSetupRequestModel +namespace Bit.Api.Models.Request.Providers { - [Required] - [StringLength(50)] - public string Name { get; set; } - [StringLength(50)] - public string BusinessName { get; set; } - [Required] - [StringLength(256)] - [EmailAddress] - public string BillingEmail { get; set; } - [Required] - public string Token { get; set; } - [Required] - public string Key { get; set; } - - public virtual Provider ToProvider(Provider provider) + public class ProviderSetupRequestModel { - provider.Name = Name; - provider.BusinessName = BusinessName; - provider.BillingEmail = BillingEmail; + [Required] + [StringLength(50)] + public string Name { get; set; } + [StringLength(50)] + public string BusinessName { get; set; } + [Required] + [StringLength(256)] + [EmailAddress] + public string BillingEmail { get; set; } + [Required] + public string Token { get; set; } + [Required] + public string Key { get; set; } - return provider; + public virtual Provider ToProvider(Provider provider) + { + provider.Name = Name; + provider.BusinessName = BusinessName; + provider.BillingEmail = BillingEmail; + + return provider; + } } } diff --git a/src/Api/Models/Request/Providers/ProviderUpdateRequestModel.cs b/src/Api/Models/Request/Providers/ProviderUpdateRequestModel.cs index ceec796dc4..339ac3180a 100644 --- a/src/Api/Models/Request/Providers/ProviderUpdateRequestModel.cs +++ b/src/Api/Models/Request/Providers/ProviderUpdateRequestModel.cs @@ -2,29 +2,30 @@ using Bit.Core.Entities.Provider; using Bit.Core.Settings; -namespace Bit.Api.Models.Request.Providers; - -public class ProviderUpdateRequestModel +namespace Bit.Api.Models.Request.Providers { - [Required] - [StringLength(50)] - public string Name { get; set; } - [StringLength(50)] - public string BusinessName { get; set; } - [EmailAddress] - [Required] - [StringLength(256)] - public string BillingEmail { get; set; } - - public virtual Provider ToProvider(Provider existingProvider, GlobalSettings globalSettings) + public class ProviderUpdateRequestModel { - if (!globalSettings.SelfHosted) + [Required] + [StringLength(50)] + public string Name { get; set; } + [StringLength(50)] + public string BusinessName { get; set; } + [EmailAddress] + [Required] + [StringLength(256)] + public string BillingEmail { get; set; } + + public virtual Provider ToProvider(Provider existingProvider, GlobalSettings globalSettings) { - // These items come from the license file - existingProvider.Name = Name; - existingProvider.BusinessName = BusinessName; - existingProvider.BillingEmail = BillingEmail?.ToLowerInvariant()?.Trim(); + if (!globalSettings.SelfHosted) + { + // These items come from the license file + existingProvider.Name = Name; + existingProvider.BusinessName = BusinessName; + existingProvider.BillingEmail = BillingEmail?.ToLowerInvariant()?.Trim(); + } + return existingProvider; } - return existingProvider; } } diff --git a/src/Api/Models/Request/Providers/ProviderUserRequestModels.cs b/src/Api/Models/Request/Providers/ProviderUserRequestModels.cs index 9c451d8adc..cbce7bc904 100644 --- a/src/Api/Models/Request/Providers/ProviderUserRequestModels.cs +++ b/src/Api/Models/Request/Providers/ProviderUserRequestModels.cs @@ -3,62 +3,63 @@ using Bit.Core.Entities.Provider; using Bit.Core.Enums.Provider; using Bit.Core.Utilities; -namespace Bit.Api.Models.Request.Providers; - -public class ProviderUserInviteRequestModel +namespace Bit.Api.Models.Request.Providers { - [Required] - [StrictEmailAddressList] - public IEnumerable Emails { get; set; } - [Required] - public ProviderUserType? Type { get; set; } -} - -public class ProviderUserAcceptRequestModel -{ - [Required] - public string Token { get; set; } -} - -public class ProviderUserConfirmRequestModel -{ - [Required] - public string Key { get; set; } -} - -public class ProviderUserBulkConfirmRequestModelEntry -{ - [Required] - public Guid Id { get; set; } - [Required] - public string Key { get; set; } -} - -public class ProviderUserBulkConfirmRequestModel -{ - [Required] - public IEnumerable Keys { get; set; } - - public Dictionary ToDictionary() + public class ProviderUserInviteRequestModel { - return Keys.ToDictionary(e => e.Id, e => e.Key); + [Required] + [StrictEmailAddressList] + public IEnumerable Emails { get; set; } + [Required] + public ProviderUserType? Type { get; set; } + } + + public class ProviderUserAcceptRequestModel + { + [Required] + public string Token { get; set; } + } + + public class ProviderUserConfirmRequestModel + { + [Required] + public string Key { get; set; } + } + + public class ProviderUserBulkConfirmRequestModelEntry + { + [Required] + public Guid Id { get; set; } + [Required] + public string Key { get; set; } + } + + public class ProviderUserBulkConfirmRequestModel + { + [Required] + public IEnumerable Keys { get; set; } + + public Dictionary ToDictionary() + { + return Keys.ToDictionary(e => e.Id, e => e.Key); + } + } + + public class ProviderUserUpdateRequestModel + { + [Required] + public ProviderUserType? Type { get; set; } + + public ProviderUser ToProviderUser(ProviderUser existingUser) + { + existingUser.Type = Type.Value; + return existingUser; + } + } + + public class ProviderUserBulkRequestModel + { + [Required] + public IEnumerable Ids { get; set; } } } - -public class ProviderUserUpdateRequestModel -{ - [Required] - public ProviderUserType? Type { get; set; } - - public ProviderUser ToProviderUser(ProviderUser existingUser) - { - existingUser.Type = Type.Value; - return existingUser; - } -} - -public class ProviderUserBulkRequestModel -{ - [Required] - public IEnumerable Ids { get; set; } -} diff --git a/src/Api/Models/Request/SelectionReadOnlyRequestModel.cs b/src/Api/Models/Request/SelectionReadOnlyRequestModel.cs index 5b82dc3e33..f5d2043d59 100644 --- a/src/Api/Models/Request/SelectionReadOnlyRequestModel.cs +++ b/src/Api/Models/Request/SelectionReadOnlyRequestModel.cs @@ -1,22 +1,23 @@ using System.ComponentModel.DataAnnotations; using Bit.Core.Models.Data; -namespace Bit.Api.Models.Request; - -public class SelectionReadOnlyRequestModel +namespace Bit.Api.Models.Request { - [Required] - public string Id { get; set; } - public bool ReadOnly { get; set; } - public bool HidePasswords { get; set; } - - public SelectionReadOnly ToSelectionReadOnly() + public class SelectionReadOnlyRequestModel { - return new SelectionReadOnly + [Required] + public string Id { get; set; } + public bool ReadOnly { get; set; } + public bool HidePasswords { get; set; } + + public SelectionReadOnly ToSelectionReadOnly() { - Id = new Guid(Id), - ReadOnly = ReadOnly, - HidePasswords = HidePasswords, - }; + return new SelectionReadOnly + { + Id = new Guid(Id), + ReadOnly = ReadOnly, + HidePasswords = HidePasswords, + }; + } } } diff --git a/src/Api/Models/Request/SendAccessRequestModel.cs b/src/Api/Models/Request/SendAccessRequestModel.cs index 2a8b3f40a2..3ee43985f1 100644 --- a/src/Api/Models/Request/SendAccessRequestModel.cs +++ b/src/Api/Models/Request/SendAccessRequestModel.cs @@ -1,9 +1,10 @@ using System.ComponentModel.DataAnnotations; -namespace Bit.Api.Models.Request; - -public class SendAccessRequestModel +namespace Bit.Api.Models.Request { - [StringLength(300)] - public string Password { get; set; } + public class SendAccessRequestModel + { + [StringLength(300)] + public string Password { get; set; } + } } diff --git a/src/Api/Models/Request/SendRequestModel.cs b/src/Api/Models/Request/SendRequestModel.cs index 51b15cba39..1e3182359a 100644 --- a/src/Api/Models/Request/SendRequestModel.cs +++ b/src/Api/Models/Request/SendRequestModel.cs @@ -7,134 +7,135 @@ using Bit.Core.Models.Data; using Bit.Core.Services; using Bit.Core.Utilities; -namespace Bit.Api.Models.Request; - -public class SendRequestModel +namespace Bit.Api.Models.Request { - public SendType Type { get; set; } - public long? FileLength { get; set; } = null; - [EncryptedString] - [EncryptedStringLength(1000)] - public string Name { get; set; } - [EncryptedString] - [EncryptedStringLength(1000)] - public string Notes { get; set; } - [Required] - [EncryptedString] - [EncryptedStringLength(1000)] - public string Key { get; set; } - [Range(1, int.MaxValue)] - public int? MaxAccessCount { get; set; } - public DateTime? ExpirationDate { get; set; } - [Required] - public DateTime? DeletionDate { get; set; } - public SendFileModel File { get; set; } - public SendTextModel Text { get; set; } - [StringLength(1000)] - public string Password { get; set; } - [Required] - public bool? Disabled { get; set; } - public bool? HideEmail { get; set; } - - public Send ToSend(Guid userId, ISendService sendService) + public class SendRequestModel { - var send = new Send - { - Type = Type, - UserId = (Guid?)userId - }; - ToSend(send, sendService); - return send; - } + public SendType Type { get; set; } + public long? FileLength { get; set; } = null; + [EncryptedString] + [EncryptedStringLength(1000)] + public string Name { get; set; } + [EncryptedString] + [EncryptedStringLength(1000)] + public string Notes { get; set; } + [Required] + [EncryptedString] + [EncryptedStringLength(1000)] + public string Key { get; set; } + [Range(1, int.MaxValue)] + public int? MaxAccessCount { get; set; } + public DateTime? ExpirationDate { get; set; } + [Required] + public DateTime? DeletionDate { get; set; } + public SendFileModel File { get; set; } + public SendTextModel Text { get; set; } + [StringLength(1000)] + public string Password { get; set; } + [Required] + public bool? Disabled { get; set; } + public bool? HideEmail { get; set; } - public (Send, SendFileData) ToSend(Guid userId, string fileName, ISendService sendService) - { - var send = ToSendBase(new Send + public Send ToSend(Guid userId, ISendService sendService) { - Type = Type, - UserId = (Guid?)userId - }, sendService); - var data = new SendFileData(Name, Notes, fileName); - return (send, data); - } - - public Send ToSend(Send existingSend, ISendService sendService) - { - existingSend = ToSendBase(existingSend, sendService); - switch (existingSend.Type) - { - case SendType.File: - var fileData = JsonSerializer.Deserialize(existingSend.Data); - fileData.Name = Name; - fileData.Notes = Notes; - existingSend.Data = JsonSerializer.Serialize(fileData, JsonHelpers.IgnoreWritingNull); - break; - case SendType.Text: - existingSend.Data = JsonSerializer.Serialize(ToSendTextData(), JsonHelpers.IgnoreWritingNull); - break; - default: - throw new ArgumentException("Unsupported type: " + nameof(Type) + "."); - } - return existingSend; - } - - public void ValidateCreation() - { - var now = DateTime.UtcNow; - // Add 1 minute for a sane buffer and client clock float - var nowPlus1Minute = now.AddMinutes(1); - if (ExpirationDate.HasValue && ExpirationDate.Value <= nowPlus1Minute) - { - throw new BadRequestException("You cannot create a Send that is already expired. " + - "Adjust the expiration date and try again."); - } - ValidateEdit(); - } - - public void ValidateEdit() - { - var now = DateTime.UtcNow; - // Add 1 minute for a sane buffer and client clock float - var nowPlus1Minute = now.AddMinutes(1); - if (DeletionDate.HasValue) - { - if (DeletionDate.Value <= nowPlus1Minute) + var send = new Send { - throw new BadRequestException("You cannot have a Send with a deletion date in the past. " + - "Adjust the deletion date and try again."); + Type = Type, + UserId = (Guid?)userId + }; + ToSend(send, sendService); + return send; + } + + public (Send, SendFileData) ToSend(Guid userId, string fileName, ISendService sendService) + { + var send = ToSendBase(new Send + { + Type = Type, + UserId = (Guid?)userId + }, sendService); + var data = new SendFileData(Name, Notes, fileName); + return (send, data); + } + + public Send ToSend(Send existingSend, ISendService sendService) + { + existingSend = ToSendBase(existingSend, sendService); + switch (existingSend.Type) + { + case SendType.File: + var fileData = JsonSerializer.Deserialize(existingSend.Data); + fileData.Name = Name; + fileData.Notes = Notes; + existingSend.Data = JsonSerializer.Serialize(fileData, JsonHelpers.IgnoreWritingNull); + break; + case SendType.Text: + existingSend.Data = JsonSerializer.Serialize(ToSendTextData(), JsonHelpers.IgnoreWritingNull); + break; + default: + throw new ArgumentException("Unsupported type: " + nameof(Type) + "."); } - if (DeletionDate.Value > now.AddDays(31)) + return existingSend; + } + + public void ValidateCreation() + { + var now = DateTime.UtcNow; + // Add 1 minute for a sane buffer and client clock float + var nowPlus1Minute = now.AddMinutes(1); + if (ExpirationDate.HasValue && ExpirationDate.Value <= nowPlus1Minute) { - throw new BadRequestException("You cannot have a Send with a deletion date that far " + - "into the future. Adjust the Deletion Date to a value less than 31 days from now " + - "and try again."); + throw new BadRequestException("You cannot create a Send that is already expired. " + + "Adjust the expiration date and try again."); + } + ValidateEdit(); + } + + public void ValidateEdit() + { + var now = DateTime.UtcNow; + // Add 1 minute for a sane buffer and client clock float + var nowPlus1Minute = now.AddMinutes(1); + if (DeletionDate.HasValue) + { + if (DeletionDate.Value <= nowPlus1Minute) + { + throw new BadRequestException("You cannot have a Send with a deletion date in the past. " + + "Adjust the deletion date and try again."); + } + if (DeletionDate.Value > now.AddDays(31)) + { + throw new BadRequestException("You cannot have a Send with a deletion date that far " + + "into the future. Adjust the Deletion Date to a value less than 31 days from now " + + "and try again."); + } } } - } - private Send ToSendBase(Send existingSend, ISendService sendService) - { - existingSend.Key = Key; - existingSend.ExpirationDate = ExpirationDate; - existingSend.DeletionDate = DeletionDate.Value; - existingSend.MaxAccessCount = MaxAccessCount; - if (!string.IsNullOrWhiteSpace(Password)) + private Send ToSendBase(Send existingSend, ISendService sendService) { - existingSend.Password = sendService.HashPassword(Password); + existingSend.Key = Key; + existingSend.ExpirationDate = ExpirationDate; + existingSend.DeletionDate = DeletionDate.Value; + existingSend.MaxAccessCount = MaxAccessCount; + if (!string.IsNullOrWhiteSpace(Password)) + { + existingSend.Password = sendService.HashPassword(Password); + } + existingSend.Disabled = Disabled.GetValueOrDefault(); + existingSend.HideEmail = HideEmail.GetValueOrDefault(); + return existingSend; + } + + private SendTextData ToSendTextData() + { + return new SendTextData(Name, Notes, Text.Text, Text.Hidden); } - existingSend.Disabled = Disabled.GetValueOrDefault(); - existingSend.HideEmail = HideEmail.GetValueOrDefault(); - return existingSend; } - private SendTextData ToSendTextData() + public class SendWithIdRequestModel : SendRequestModel { - return new SendTextData(Name, Notes, Text.Text, Text.Hidden); + [Required] + public Guid? Id { get; set; } } } - -public class SendWithIdRequestModel : SendRequestModel -{ - [Required] - public Guid? Id { get; set; } -} diff --git a/src/Api/Models/Request/TwoFactorRequestModels.cs b/src/Api/Models/Request/TwoFactorRequestModels.cs index 3ce42cdb98..4ccc209a68 100644 --- a/src/Api/Models/Request/TwoFactorRequestModels.cs +++ b/src/Api/Models/Request/TwoFactorRequestModels.cs @@ -5,269 +5,270 @@ using Bit.Core.Enums; using Bit.Core.Models; using Fido2NetLib; -namespace Bit.Api.Models.Request; - -public class UpdateTwoFactorAuthenticatorRequestModel : SecretVerificationRequestModel +namespace Bit.Api.Models.Request { - [Required] - [StringLength(50)] - public string Token { get; set; } - [Required] - [StringLength(50)] - public string Key { get; set; } - - public User ToUser(User extistingUser) + public class UpdateTwoFactorAuthenticatorRequestModel : SecretVerificationRequestModel { - var providers = extistingUser.GetTwoFactorProviders(); - if (providers == null) - { - providers = new Dictionary(); - } - else if (providers.ContainsKey(TwoFactorProviderType.Authenticator)) - { - providers.Remove(TwoFactorProviderType.Authenticator); - } + [Required] + [StringLength(50)] + public string Token { get; set; } + [Required] + [StringLength(50)] + public string Key { get; set; } - providers.Add(TwoFactorProviderType.Authenticator, new TwoFactorProvider + public User ToUser(User extistingUser) { - MetaData = new Dictionary { ["Key"] = Key }, - Enabled = true - }); - extistingUser.SetTwoFactorProviders(providers); - return extistingUser; - } -} - -public class UpdateTwoFactorDuoRequestModel : SecretVerificationRequestModel, IValidatableObject -{ - [Required] - [StringLength(50)] - public string IntegrationKey { get; set; } - [Required] - [StringLength(50)] - public string SecretKey { get; set; } - [Required] - [StringLength(50)] - public string Host { get; set; } - - public User ToUser(User extistingUser) - { - var providers = extistingUser.GetTwoFactorProviders(); - if (providers == null) - { - providers = new Dictionary(); - } - else if (providers.ContainsKey(TwoFactorProviderType.Duo)) - { - providers.Remove(TwoFactorProviderType.Duo); - } - - providers.Add(TwoFactorProviderType.Duo, new TwoFactorProvider - { - MetaData = new Dictionary + var providers = extistingUser.GetTwoFactorProviders(); + if (providers == null) { - ["SKey"] = SecretKey, - ["IKey"] = IntegrationKey, - ["Host"] = Host - }, - Enabled = true - }); - extistingUser.SetTwoFactorProviders(providers); - return extistingUser; - } - - public Organization ToOrganization(Organization extistingOrg) - { - var providers = extistingOrg.GetTwoFactorProviders(); - if (providers == null) - { - providers = new Dictionary(); - } - else if (providers.ContainsKey(TwoFactorProviderType.OrganizationDuo)) - { - providers.Remove(TwoFactorProviderType.OrganizationDuo); - } - - providers.Add(TwoFactorProviderType.OrganizationDuo, new TwoFactorProvider - { - MetaData = new Dictionary + providers = new Dictionary(); + } + else if (providers.ContainsKey(TwoFactorProviderType.Authenticator)) { - ["SKey"] = SecretKey, - ["IKey"] = IntegrationKey, - ["Host"] = Host - }, - Enabled = true - }); - extistingOrg.SetTwoFactorProviders(providers); - return extistingOrg; - } + providers.Remove(TwoFactorProviderType.Authenticator); + } - public IEnumerable Validate(ValidationContext validationContext) - { - if (!Core.Utilities.Duo.DuoApi.ValidHost(Host)) - { - yield return new ValidationResult("Host is invalid.", new string[] { nameof(Host) }); - } - } -} - -public class UpdateTwoFactorYubicoOtpRequestModel : SecretVerificationRequestModel, IValidatableObject -{ - public string Key1 { get; set; } - public string Key2 { get; set; } - public string Key3 { get; set; } - public string Key4 { get; set; } - public string Key5 { get; set; } - [Required] - public bool? Nfc { get; set; } - - public User ToUser(User extistingUser) - { - var providers = extistingUser.GetTwoFactorProviders(); - if (providers == null) - { - providers = new Dictionary(); - } - else if (providers.ContainsKey(TwoFactorProviderType.YubiKey)) - { - providers.Remove(TwoFactorProviderType.YubiKey); - } - - providers.Add(TwoFactorProviderType.YubiKey, new TwoFactorProvider - { - MetaData = new Dictionary + providers.Add(TwoFactorProviderType.Authenticator, new TwoFactorProvider { - ["Key1"] = FormatKey(Key1), - ["Key2"] = FormatKey(Key2), - ["Key3"] = FormatKey(Key3), - ["Key4"] = FormatKey(Key4), - ["Key5"] = FormatKey(Key5), - ["Nfc"] = Nfc.Value - }, - Enabled = true - }); - extistingUser.SetTwoFactorProviders(providers); - return extistingUser; + MetaData = new Dictionary { ["Key"] = Key }, + Enabled = true + }); + extistingUser.SetTwoFactorProviders(providers); + return extistingUser; + } } - private string FormatKey(string keyValue) + public class UpdateTwoFactorDuoRequestModel : SecretVerificationRequestModel, IValidatableObject { - if (string.IsNullOrWhiteSpace(keyValue)) + [Required] + [StringLength(50)] + public string IntegrationKey { get; set; } + [Required] + [StringLength(50)] + public string SecretKey { get; set; } + [Required] + [StringLength(50)] + public string Host { get; set; } + + public User ToUser(User extistingUser) { - return null; + var providers = extistingUser.GetTwoFactorProviders(); + if (providers == null) + { + providers = new Dictionary(); + } + else if (providers.ContainsKey(TwoFactorProviderType.Duo)) + { + providers.Remove(TwoFactorProviderType.Duo); + } + + providers.Add(TwoFactorProviderType.Duo, new TwoFactorProvider + { + MetaData = new Dictionary + { + ["SKey"] = SecretKey, + ["IKey"] = IntegrationKey, + ["Host"] = Host + }, + Enabled = true + }); + extistingUser.SetTwoFactorProviders(providers); + return extistingUser; } - return keyValue.Substring(0, 12); + public Organization ToOrganization(Organization extistingOrg) + { + var providers = extistingOrg.GetTwoFactorProviders(); + if (providers == null) + { + providers = new Dictionary(); + } + else if (providers.ContainsKey(TwoFactorProviderType.OrganizationDuo)) + { + providers.Remove(TwoFactorProviderType.OrganizationDuo); + } + + providers.Add(TwoFactorProviderType.OrganizationDuo, new TwoFactorProvider + { + MetaData = new Dictionary + { + ["SKey"] = SecretKey, + ["IKey"] = IntegrationKey, + ["Host"] = Host + }, + Enabled = true + }); + extistingOrg.SetTwoFactorProviders(providers); + return extistingOrg; + } + + public IEnumerable Validate(ValidationContext validationContext) + { + if (!Core.Utilities.Duo.DuoApi.ValidHost(Host)) + { + yield return new ValidationResult("Host is invalid.", new string[] { nameof(Host) }); + } + } } - public IEnumerable Validate(ValidationContext validationContext) + public class UpdateTwoFactorYubicoOtpRequestModel : SecretVerificationRequestModel, IValidatableObject { - if (string.IsNullOrWhiteSpace(Key1) && string.IsNullOrWhiteSpace(Key2) && string.IsNullOrWhiteSpace(Key3) && - string.IsNullOrWhiteSpace(Key4) && string.IsNullOrWhiteSpace(Key5)) + public string Key1 { get; set; } + public string Key2 { get; set; } + public string Key3 { get; set; } + public string Key4 { get; set; } + public string Key5 { get; set; } + [Required] + public bool? Nfc { get; set; } + + public User ToUser(User extistingUser) { - yield return new ValidationResult("A key is required.", new string[] { nameof(Key1) }); + var providers = extistingUser.GetTwoFactorProviders(); + if (providers == null) + { + providers = new Dictionary(); + } + else if (providers.ContainsKey(TwoFactorProviderType.YubiKey)) + { + providers.Remove(TwoFactorProviderType.YubiKey); + } + + providers.Add(TwoFactorProviderType.YubiKey, new TwoFactorProvider + { + MetaData = new Dictionary + { + ["Key1"] = FormatKey(Key1), + ["Key2"] = FormatKey(Key2), + ["Key3"] = FormatKey(Key3), + ["Key4"] = FormatKey(Key4), + ["Key5"] = FormatKey(Key5), + ["Nfc"] = Nfc.Value + }, + Enabled = true + }); + extistingUser.SetTwoFactorProviders(providers); + return extistingUser; } - if (!string.IsNullOrWhiteSpace(Key1) && Key1.Length < 12) + private string FormatKey(string keyValue) { - yield return new ValidationResult("Key 1 in invalid.", new string[] { nameof(Key1) }); + if (string.IsNullOrWhiteSpace(keyValue)) + { + return null; + } + + return keyValue.Substring(0, 12); } - if (!string.IsNullOrWhiteSpace(Key2) && Key2.Length < 12) + public IEnumerable Validate(ValidationContext validationContext) { - yield return new ValidationResult("Key 2 in invalid.", new string[] { nameof(Key2) }); - } + if (string.IsNullOrWhiteSpace(Key1) && string.IsNullOrWhiteSpace(Key2) && string.IsNullOrWhiteSpace(Key3) && + string.IsNullOrWhiteSpace(Key4) && string.IsNullOrWhiteSpace(Key5)) + { + yield return new ValidationResult("A key is required.", new string[] { nameof(Key1) }); + } - if (!string.IsNullOrWhiteSpace(Key3) && Key3.Length < 12) - { - yield return new ValidationResult("Key 3 in invalid.", new string[] { nameof(Key3) }); - } + if (!string.IsNullOrWhiteSpace(Key1) && Key1.Length < 12) + { + yield return new ValidationResult("Key 1 in invalid.", new string[] { nameof(Key1) }); + } - if (!string.IsNullOrWhiteSpace(Key4) && Key4.Length < 12) - { - yield return new ValidationResult("Key 4 in invalid.", new string[] { nameof(Key4) }); - } + if (!string.IsNullOrWhiteSpace(Key2) && Key2.Length < 12) + { + yield return new ValidationResult("Key 2 in invalid.", new string[] { nameof(Key2) }); + } - if (!string.IsNullOrWhiteSpace(Key5) && Key5.Length < 12) - { - yield return new ValidationResult("Key 5 in invalid.", new string[] { nameof(Key5) }); + if (!string.IsNullOrWhiteSpace(Key3) && Key3.Length < 12) + { + yield return new ValidationResult("Key 3 in invalid.", new string[] { nameof(Key3) }); + } + + if (!string.IsNullOrWhiteSpace(Key4) && Key4.Length < 12) + { + yield return new ValidationResult("Key 4 in invalid.", new string[] { nameof(Key4) }); + } + + if (!string.IsNullOrWhiteSpace(Key5) && Key5.Length < 12) + { + yield return new ValidationResult("Key 5 in invalid.", new string[] { nameof(Key5) }); + } } } -} -public class TwoFactorEmailRequestModel : SecretVerificationRequestModel -{ - [Required] - [EmailAddress] - [StringLength(256)] - public string Email { get; set; } - - public string DeviceIdentifier { get; set; } - - public User ToUser(User extistingUser) + public class TwoFactorEmailRequestModel : SecretVerificationRequestModel { - var providers = extistingUser.GetTwoFactorProviders(); - if (providers == null) - { - providers = new Dictionary(); - } - else if (providers.ContainsKey(TwoFactorProviderType.Email)) - { - providers.Remove(TwoFactorProviderType.Email); - } + [Required] + [EmailAddress] + [StringLength(256)] + public string Email { get; set; } - providers.Add(TwoFactorProviderType.Email, new TwoFactorProvider + public string DeviceIdentifier { get; set; } + + public User ToUser(User extistingUser) { - MetaData = new Dictionary { ["Email"] = Email.ToLowerInvariant() }, - Enabled = true - }); - extistingUser.SetTwoFactorProviders(providers); - return extistingUser; + var providers = extistingUser.GetTwoFactorProviders(); + if (providers == null) + { + providers = new Dictionary(); + } + else if (providers.ContainsKey(TwoFactorProviderType.Email)) + { + providers.Remove(TwoFactorProviderType.Email); + } + + providers.Add(TwoFactorProviderType.Email, new TwoFactorProvider + { + MetaData = new Dictionary { ["Email"] = Email.ToLowerInvariant() }, + Enabled = true + }); + extistingUser.SetTwoFactorProviders(providers); + return extistingUser; + } } -} -public class TwoFactorWebAuthnRequestModel : TwoFactorWebAuthnDeleteRequestModel -{ - [Required] - public AuthenticatorAttestationRawResponse DeviceResponse { get; set; } - public string Name { get; set; } -} - -public class TwoFactorWebAuthnDeleteRequestModel : SecretVerificationRequestModel, IValidatableObject -{ - [Required] - public int? Id { get; set; } - - public override IEnumerable Validate(ValidationContext validationContext) + public class TwoFactorWebAuthnRequestModel : TwoFactorWebAuthnDeleteRequestModel { - foreach (var validationResult in base.Validate(validationContext)) - { - yield return validationResult; - } + [Required] + public AuthenticatorAttestationRawResponse DeviceResponse { get; set; } + public string Name { get; set; } + } - if (!Id.HasValue || Id < 0 || Id > 5) + public class TwoFactorWebAuthnDeleteRequestModel : SecretVerificationRequestModel, IValidatableObject + { + [Required] + public int? Id { get; set; } + + public override IEnumerable Validate(ValidationContext validationContext) { - yield return new ValidationResult("Invalid Key Id", new string[] { nameof(Id) }); + foreach (var validationResult in base.Validate(validationContext)) + { + yield return validationResult; + } + + if (!Id.HasValue || Id < 0 || Id > 5) + { + yield return new ValidationResult("Invalid Key Id", new string[] { nameof(Id) }); + } } } -} -public class UpdateTwoFactorEmailRequestModel : TwoFactorEmailRequestModel -{ - [Required] - [StringLength(50)] - public string Token { get; set; } -} + public class UpdateTwoFactorEmailRequestModel : TwoFactorEmailRequestModel + { + [Required] + [StringLength(50)] + public string Token { get; set; } + } -public class TwoFactorProviderRequestModel : SecretVerificationRequestModel -{ - [Required] - public TwoFactorProviderType? Type { get; set; } -} + public class TwoFactorProviderRequestModel : SecretVerificationRequestModel + { + [Required] + public TwoFactorProviderType? Type { get; set; } + } -public class TwoFactorRecoveryRequestModel : TwoFactorEmailRequestModel -{ - [Required] - [StringLength(32)] - public string RecoveryCode { get; set; } + public class TwoFactorRecoveryRequestModel : TwoFactorEmailRequestModel + { + [Required] + [StringLength(32)] + public string RecoveryCode { get; set; } + } } diff --git a/src/Api/Models/Request/UpdateDomainsRequestModel.cs b/src/Api/Models/Request/UpdateDomainsRequestModel.cs index 47c5d05dec..0bc3f03859 100644 --- a/src/Api/Models/Request/UpdateDomainsRequestModel.cs +++ b/src/Api/Models/Request/UpdateDomainsRequestModel.cs @@ -2,18 +2,19 @@ using Bit.Core.Entities; using Bit.Core.Enums; -namespace Bit.Api.Models.Request; - -public class UpdateDomainsRequestModel +namespace Bit.Api.Models.Request { - public IEnumerable> EquivalentDomains { get; set; } - public IEnumerable ExcludedGlobalEquivalentDomains { get; set; } - - public User ToUser(User existingUser) + public class UpdateDomainsRequestModel { - existingUser.EquivalentDomains = EquivalentDomains != null ? JsonSerializer.Serialize(EquivalentDomains) : null; - existingUser.ExcludedGlobalEquivalentDomains = ExcludedGlobalEquivalentDomains != null ? - JsonSerializer.Serialize(ExcludedGlobalEquivalentDomains) : null; - return existingUser; + public IEnumerable> EquivalentDomains { get; set; } + public IEnumerable ExcludedGlobalEquivalentDomains { get; set; } + + public User ToUser(User existingUser) + { + existingUser.EquivalentDomains = EquivalentDomains != null ? JsonSerializer.Serialize(EquivalentDomains) : null; + existingUser.ExcludedGlobalEquivalentDomains = ExcludedGlobalEquivalentDomains != null ? + JsonSerializer.Serialize(ExcludedGlobalEquivalentDomains) : null; + return existingUser; + } } } diff --git a/src/Api/Models/Response/ApiKeyResponseModel.cs b/src/Api/Models/Response/ApiKeyResponseModel.cs index 0661b17bc5..3987e22d50 100644 --- a/src/Api/Models/Response/ApiKeyResponseModel.cs +++ b/src/Api/Models/Response/ApiKeyResponseModel.cs @@ -1,32 +1,33 @@ using Bit.Core.Entities; using Bit.Core.Models.Api; -namespace Bit.Api.Models.Response; - -public class ApiKeyResponseModel : ResponseModel +namespace Bit.Api.Models.Response { - public ApiKeyResponseModel(OrganizationApiKey organizationApiKey, string obj = "apiKey") - : base(obj) + public class ApiKeyResponseModel : ResponseModel { - if (organizationApiKey == null) + public ApiKeyResponseModel(OrganizationApiKey organizationApiKey, string obj = "apiKey") + : base(obj) { - throw new ArgumentNullException(nameof(organizationApiKey)); + if (organizationApiKey == null) + { + throw new ArgumentNullException(nameof(organizationApiKey)); + } + ApiKey = organizationApiKey.ApiKey; + RevisionDate = organizationApiKey.RevisionDate; } - ApiKey = organizationApiKey.ApiKey; - RevisionDate = organizationApiKey.RevisionDate; - } - public ApiKeyResponseModel(User user, string obj = "apiKey") - : base(obj) - { - if (user == null) + public ApiKeyResponseModel(User user, string obj = "apiKey") + : base(obj) { - throw new ArgumentNullException(nameof(user)); + if (user == null) + { + throw new ArgumentNullException(nameof(user)); + } + ApiKey = user.ApiKey; + RevisionDate = user.RevisionDate; } - ApiKey = user.ApiKey; - RevisionDate = user.RevisionDate; - } - public string ApiKey { get; set; } - public DateTime RevisionDate { get; set; } + public string ApiKey { get; set; } + public DateTime RevisionDate { get; set; } + } } diff --git a/src/Api/Models/Response/AttachmentResponseModel.cs b/src/Api/Models/Response/AttachmentResponseModel.cs index 018cdd6504..5659cb5352 100644 --- a/src/Api/Models/Response/AttachmentResponseModel.cs +++ b/src/Api/Models/Response/AttachmentResponseModel.cs @@ -5,48 +5,49 @@ using Bit.Core.Models.Data; using Bit.Core.Settings; using Bit.Core.Utilities; -namespace Bit.Api.Models.Response; - -public class AttachmentResponseModel : ResponseModel +namespace Bit.Api.Models.Response { - public AttachmentResponseModel(AttachmentResponseData data) : base("attachment") + public class AttachmentResponseModel : ResponseModel { - Id = data.Id; - Url = data.Url; - FileName = data.Data.FileName; - Key = data.Data.Key; - Size = data.Data.Size; - SizeName = CoreHelpers.ReadableBytesSize(data.Data.Size); - } - - public AttachmentResponseModel(string id, CipherAttachment.MetaData data, Cipher cipher, - IGlobalSettings globalSettings) - : base("attachment") - { - Id = id; - Url = $"{globalSettings.Attachment.BaseUrl}/{cipher.Id}/{id}"; - FileName = data.FileName; - Key = data.Key; - Size = data.Size; - SizeName = CoreHelpers.ReadableBytesSize(data.Size); - } - - public string Id { get; set; } - public string Url { get; set; } - public string FileName { get; set; } - public string Key { get; set; } - [JsonNumberHandling(JsonNumberHandling.AllowReadingFromString | JsonNumberHandling.WriteAsString)] - public long Size { get; set; } - public string SizeName { get; set; } - - public static IEnumerable FromCipher(Cipher cipher, IGlobalSettings globalSettings) - { - var attachments = cipher.GetAttachments(); - if (attachments == null) + public AttachmentResponseModel(AttachmentResponseData data) : base("attachment") { - return null; + Id = data.Id; + Url = data.Url; + FileName = data.Data.FileName; + Key = data.Data.Key; + Size = data.Data.Size; + SizeName = CoreHelpers.ReadableBytesSize(data.Data.Size); } - return attachments.Select(a => new AttachmentResponseModel(a.Key, a.Value, cipher, globalSettings)); + public AttachmentResponseModel(string id, CipherAttachment.MetaData data, Cipher cipher, + IGlobalSettings globalSettings) + : base("attachment") + { + Id = id; + Url = $"{globalSettings.Attachment.BaseUrl}/{cipher.Id}/{id}"; + FileName = data.FileName; + Key = data.Key; + Size = data.Size; + SizeName = CoreHelpers.ReadableBytesSize(data.Size); + } + + public string Id { get; set; } + public string Url { get; set; } + public string FileName { get; set; } + public string Key { get; set; } + [JsonNumberHandling(JsonNumberHandling.AllowReadingFromString | JsonNumberHandling.WriteAsString)] + public long Size { get; set; } + public string SizeName { get; set; } + + public static IEnumerable FromCipher(Cipher cipher, IGlobalSettings globalSettings) + { + var attachments = cipher.GetAttachments(); + if (attachments == null) + { + return null; + } + + return attachments.Select(a => new AttachmentResponseModel(a.Key, a.Value, cipher, globalSettings)); + } } } diff --git a/src/Api/Models/Response/AttachmentUploadDataResponseModel.cs b/src/Api/Models/Response/AttachmentUploadDataResponseModel.cs index 1c9a5d2a72..7acc5715c0 100644 --- a/src/Api/Models/Response/AttachmentUploadDataResponseModel.cs +++ b/src/Api/Models/Response/AttachmentUploadDataResponseModel.cs @@ -1,15 +1,16 @@ using Bit.Core.Enums; using Bit.Core.Models.Api; -namespace Bit.Api.Models.Response; - -public class AttachmentUploadDataResponseModel : ResponseModel +namespace Bit.Api.Models.Response { - public string AttachmentId { get; set; } - public string Url { get; set; } - public FileUploadType FileUploadType { get; set; } - public CipherResponseModel CipherResponse { get; set; } - public CipherMiniResponseModel CipherMiniResponse { get; set; } + public class AttachmentUploadDataResponseModel : ResponseModel + { + public string AttachmentId { get; set; } + public string Url { get; set; } + public FileUploadType FileUploadType { get; set; } + public CipherResponseModel CipherResponse { get; set; } + public CipherMiniResponseModel CipherMiniResponse { get; set; } - public AttachmentUploadDataResponseModel() : base("attachment-fileUpload") { } + public AttachmentUploadDataResponseModel() : base("attachment-fileUpload") { } + } } diff --git a/src/Api/Models/Response/BillingHistoryResponseModel.cs b/src/Api/Models/Response/BillingHistoryResponseModel.cs index e0e85f0699..892a6530de 100644 --- a/src/Api/Models/Response/BillingHistoryResponseModel.cs +++ b/src/Api/Models/Response/BillingHistoryResponseModel.cs @@ -1,16 +1,17 @@ using Bit.Core.Models.Api; using Bit.Core.Models.Business; -namespace Bit.Api.Models.Response; - -public class BillingHistoryResponseModel : ResponseModel +namespace Bit.Api.Models.Response { - public BillingHistoryResponseModel(BillingInfo billing) - : base("billingHistory") + public class BillingHistoryResponseModel : ResponseModel { - Transactions = billing.Transactions?.Select(t => new BillingTransaction(t)); - Invoices = billing.Invoices?.Select(i => new BillingInvoice(i)); + public BillingHistoryResponseModel(BillingInfo billing) + : base("billingHistory") + { + Transactions = billing.Transactions?.Select(t => new BillingTransaction(t)); + Invoices = billing.Invoices?.Select(i => new BillingInvoice(i)); + } + public IEnumerable Invoices { get; set; } + public IEnumerable Transactions { get; set; } } - public IEnumerable Invoices { get; set; } - public IEnumerable Transactions { get; set; } } diff --git a/src/Api/Models/Response/BillingPaymentResponseModel.cs b/src/Api/Models/Response/BillingPaymentResponseModel.cs index dcc0046133..12c14c4d6e 100644 --- a/src/Api/Models/Response/BillingPaymentResponseModel.cs +++ b/src/Api/Models/Response/BillingPaymentResponseModel.cs @@ -1,17 +1,18 @@ using Bit.Core.Models.Api; using Bit.Core.Models.Business; -namespace Bit.Api.Models.Response; - -public class BillingPaymentResponseModel : ResponseModel +namespace Bit.Api.Models.Response { - public BillingPaymentResponseModel(BillingInfo billing) - : base("billingPayment") + public class BillingPaymentResponseModel : ResponseModel { - Balance = billing.Balance; - PaymentSource = billing.PaymentSource != null ? new BillingSource(billing.PaymentSource) : null; - } + public BillingPaymentResponseModel(BillingInfo billing) + : base("billingPayment") + { + Balance = billing.Balance; + PaymentSource = billing.PaymentSource != null ? new BillingSource(billing.PaymentSource) : null; + } - public decimal Balance { get; set; } - public BillingSource PaymentSource { get; set; } + public decimal Balance { get; set; } + public BillingSource PaymentSource { get; set; } + } } diff --git a/src/Api/Models/Response/BillingResponseModel.cs b/src/Api/Models/Response/BillingResponseModel.cs index c5232242f0..6e2930e1b8 100644 --- a/src/Api/Models/Response/BillingResponseModel.cs +++ b/src/Api/Models/Response/BillingResponseModel.cs @@ -2,81 +2,82 @@ using Bit.Core.Models.Api; using Bit.Core.Models.Business; -namespace Bit.Api.Models.Response; - -public class BillingResponseModel : ResponseModel +namespace Bit.Api.Models.Response { - public BillingResponseModel(BillingInfo billing) - : base("billing") + public class BillingResponseModel : ResponseModel { - Balance = billing.Balance; - PaymentSource = billing.PaymentSource != null ? new BillingSource(billing.PaymentSource) : null; - Transactions = billing.Transactions?.Select(t => new BillingTransaction(t)); - Invoices = billing.Invoices?.Select(i => new BillingInvoice(i)); + public BillingResponseModel(BillingInfo billing) + : base("billing") + { + Balance = billing.Balance; + PaymentSource = billing.PaymentSource != null ? new BillingSource(billing.PaymentSource) : null; + Transactions = billing.Transactions?.Select(t => new BillingTransaction(t)); + Invoices = billing.Invoices?.Select(i => new BillingInvoice(i)); + } + + public decimal Balance { get; set; } + public BillingSource PaymentSource { get; set; } + public IEnumerable Invoices { get; set; } + public IEnumerable Transactions { get; set; } } - public decimal Balance { get; set; } - public BillingSource PaymentSource { get; set; } - public IEnumerable Invoices { get; set; } - public IEnumerable Transactions { get; set; } -} - -public class BillingSource -{ - public BillingSource(BillingInfo.BillingSource source) + public class BillingSource { - Type = source.Type; - CardBrand = source.CardBrand; - Description = source.Description; - NeedsVerification = source.NeedsVerification; + public BillingSource(BillingInfo.BillingSource source) + { + Type = source.Type; + CardBrand = source.CardBrand; + Description = source.Description; + NeedsVerification = source.NeedsVerification; + } + + public PaymentMethodType Type { get; set; } + public string CardBrand { get; set; } + public string Description { get; set; } + public bool NeedsVerification { get; set; } } - public PaymentMethodType Type { get; set; } - public string CardBrand { get; set; } - public string Description { get; set; } - public bool NeedsVerification { get; set; } -} - -public class BillingInvoice -{ - public BillingInvoice(BillingInfo.BillingInvoice inv) + public class BillingInvoice { - Amount = inv.Amount; - Date = inv.Date; - Url = inv.Url; - PdfUrl = inv.PdfUrl; - Number = inv.Number; - Paid = inv.Paid; + public BillingInvoice(BillingInfo.BillingInvoice inv) + { + Amount = inv.Amount; + Date = inv.Date; + Url = inv.Url; + PdfUrl = inv.PdfUrl; + Number = inv.Number; + Paid = inv.Paid; + } + + public decimal Amount { get; set; } + public DateTime? Date { get; set; } + public string Url { get; set; } + public string PdfUrl { get; set; } + public string Number { get; set; } + public bool Paid { get; set; } } - public decimal Amount { get; set; } - public DateTime? Date { get; set; } - public string Url { get; set; } - public string PdfUrl { get; set; } - public string Number { get; set; } - public bool Paid { get; set; } -} - -public class BillingTransaction -{ - public BillingTransaction(BillingInfo.BillingTransaction transaction) + public class BillingTransaction { - CreatedDate = transaction.CreatedDate; - Amount = transaction.Amount; - Refunded = transaction.Refunded; - RefundedAmount = transaction.RefundedAmount; - PartiallyRefunded = transaction.PartiallyRefunded; - Type = transaction.Type; - PaymentMethodType = transaction.PaymentMethodType; - Details = transaction.Details; - } + public BillingTransaction(BillingInfo.BillingTransaction transaction) + { + CreatedDate = transaction.CreatedDate; + Amount = transaction.Amount; + Refunded = transaction.Refunded; + RefundedAmount = transaction.RefundedAmount; + PartiallyRefunded = transaction.PartiallyRefunded; + Type = transaction.Type; + PaymentMethodType = transaction.PaymentMethodType; + Details = transaction.Details; + } - public DateTime CreatedDate { get; set; } - public decimal Amount { get; set; } - public bool? Refunded { get; set; } - public bool? PartiallyRefunded { get; set; } - public decimal? RefundedAmount { get; set; } - public TransactionType Type { get; set; } - public PaymentMethodType? PaymentMethodType { get; set; } - public string Details { get; set; } + public DateTime CreatedDate { get; set; } + public decimal Amount { get; set; } + public bool? Refunded { get; set; } + public bool? PartiallyRefunded { get; set; } + public decimal? RefundedAmount { get; set; } + public TransactionType Type { get; set; } + public PaymentMethodType? PaymentMethodType { get; set; } + public string Details { get; set; } + } } diff --git a/src/Api/Models/Response/CipherResponseModel.cs b/src/Api/Models/Response/CipherResponseModel.cs index 9b0d95894b..5edc27145b 100644 --- a/src/Api/Models/Response/CipherResponseModel.cs +++ b/src/Api/Models/Response/CipherResponseModel.cs @@ -6,141 +6,142 @@ using Bit.Core.Models.Data; using Bit.Core.Settings; using Core.Models.Data; -namespace Bit.Api.Models.Response; - -public class CipherMiniResponseModel : ResponseModel +namespace Bit.Api.Models.Response { - public CipherMiniResponseModel(Cipher cipher, IGlobalSettings globalSettings, bool orgUseTotp, string obj = "cipherMini") - : base(obj) + public class CipherMiniResponseModel : ResponseModel { - if (cipher == null) + public CipherMiniResponseModel(Cipher cipher, IGlobalSettings globalSettings, bool orgUseTotp, string obj = "cipherMini") + : base(obj) { - throw new ArgumentNullException(nameof(cipher)); + if (cipher == null) + { + throw new ArgumentNullException(nameof(cipher)); + } + + Id = cipher.Id.ToString(); + Type = cipher.Type; + + CipherData cipherData; + switch (cipher.Type) + { + case CipherType.Login: + var loginData = JsonSerializer.Deserialize(cipher.Data); + cipherData = loginData; + Data = loginData; + Login = new CipherLoginModel(loginData); + break; + case CipherType.SecureNote: + var secureNoteData = JsonSerializer.Deserialize(cipher.Data); + Data = secureNoteData; + cipherData = secureNoteData; + SecureNote = new CipherSecureNoteModel(secureNoteData); + break; + case CipherType.Card: + var cardData = JsonSerializer.Deserialize(cipher.Data); + Data = cardData; + cipherData = cardData; + Card = new CipherCardModel(cardData); + break; + case CipherType.Identity: + var identityData = JsonSerializer.Deserialize(cipher.Data); + Data = identityData; + cipherData = identityData; + Identity = new CipherIdentityModel(identityData); + break; + default: + throw new ArgumentException("Unsupported " + nameof(Type) + "."); + } + + Name = cipherData.Name; + Notes = cipherData.Notes; + Fields = cipherData.Fields?.Select(f => new CipherFieldModel(f)); + PasswordHistory = cipherData.PasswordHistory?.Select(ph => new CipherPasswordHistoryModel(ph)); + RevisionDate = cipher.RevisionDate; + OrganizationId = cipher.OrganizationId?.ToString(); + Attachments = AttachmentResponseModel.FromCipher(cipher, globalSettings); + OrganizationUseTotp = orgUseTotp; + DeletedDate = cipher.DeletedDate; + Reprompt = cipher.Reprompt.GetValueOrDefault(CipherRepromptType.None); } - Id = cipher.Id.ToString(); - Type = cipher.Type; - - CipherData cipherData; - switch (cipher.Type) - { - case CipherType.Login: - var loginData = JsonSerializer.Deserialize(cipher.Data); - cipherData = loginData; - Data = loginData; - Login = new CipherLoginModel(loginData); - break; - case CipherType.SecureNote: - var secureNoteData = JsonSerializer.Deserialize(cipher.Data); - Data = secureNoteData; - cipherData = secureNoteData; - SecureNote = new CipherSecureNoteModel(secureNoteData); - break; - case CipherType.Card: - var cardData = JsonSerializer.Deserialize(cipher.Data); - Data = cardData; - cipherData = cardData; - Card = new CipherCardModel(cardData); - break; - case CipherType.Identity: - var identityData = JsonSerializer.Deserialize(cipher.Data); - Data = identityData; - cipherData = identityData; - Identity = new CipherIdentityModel(identityData); - break; - default: - throw new ArgumentException("Unsupported " + nameof(Type) + "."); - } - - Name = cipherData.Name; - Notes = cipherData.Notes; - Fields = cipherData.Fields?.Select(f => new CipherFieldModel(f)); - PasswordHistory = cipherData.PasswordHistory?.Select(ph => new CipherPasswordHistoryModel(ph)); - RevisionDate = cipher.RevisionDate; - OrganizationId = cipher.OrganizationId?.ToString(); - Attachments = AttachmentResponseModel.FromCipher(cipher, globalSettings); - OrganizationUseTotp = orgUseTotp; - DeletedDate = cipher.DeletedDate; - Reprompt = cipher.Reprompt.GetValueOrDefault(CipherRepromptType.None); + public string Id { get; set; } + public string OrganizationId { get; set; } + public CipherType Type { get; set; } + public dynamic Data { get; set; } + public string Name { get; set; } + public string Notes { get; set; } + public CipherLoginModel Login { get; set; } + public CipherCardModel Card { get; set; } + public CipherIdentityModel Identity { get; set; } + public CipherSecureNoteModel SecureNote { get; set; } + public IEnumerable Fields { get; set; } + public IEnumerable PasswordHistory { get; set; } + public IEnumerable Attachments { get; set; } + public bool OrganizationUseTotp { get; set; } + public DateTime RevisionDate { get; set; } + public DateTime? DeletedDate { get; set; } + public CipherRepromptType Reprompt { get; set; } } - public string Id { get; set; } - public string OrganizationId { get; set; } - public CipherType Type { get; set; } - public dynamic Data { get; set; } - public string Name { get; set; } - public string Notes { get; set; } - public CipherLoginModel Login { get; set; } - public CipherCardModel Card { get; set; } - public CipherIdentityModel Identity { get; set; } - public CipherSecureNoteModel SecureNote { get; set; } - public IEnumerable Fields { get; set; } - public IEnumerable PasswordHistory { get; set; } - public IEnumerable Attachments { get; set; } - public bool OrganizationUseTotp { get; set; } - public DateTime RevisionDate { get; set; } - public DateTime? DeletedDate { get; set; } - public CipherRepromptType Reprompt { get; set; } -} - -public class CipherResponseModel : CipherMiniResponseModel -{ - public CipherResponseModel(CipherDetails cipher, IGlobalSettings globalSettings, string obj = "cipher") - : base(cipher, globalSettings, cipher.OrganizationUseTotp, obj) - { - FolderId = cipher.FolderId?.ToString(); - Favorite = cipher.Favorite; - Edit = cipher.Edit; - ViewPassword = cipher.ViewPassword; - } - - public string FolderId { get; set; } - public bool Favorite { get; set; } - public bool Edit { get; set; } - public bool ViewPassword { get; set; } -} - -public class CipherDetailsResponseModel : CipherResponseModel -{ - public CipherDetailsResponseModel(CipherDetails cipher, GlobalSettings globalSettings, - IDictionary> collectionCiphers, string obj = "cipherDetails") - : base(cipher, globalSettings, obj) - { - if (collectionCiphers?.ContainsKey(cipher.Id) ?? false) - { - CollectionIds = collectionCiphers[cipher.Id].Select(c => c.CollectionId); - } - else - { - CollectionIds = new Guid[] { }; - } - } - - public CipherDetailsResponseModel(CipherDetails cipher, GlobalSettings globalSettings, - IEnumerable collectionCiphers, string obj = "cipherDetails") - : base(cipher, globalSettings, obj) - { - CollectionIds = collectionCiphers?.Select(c => c.CollectionId) ?? new List(); - } - - public IEnumerable CollectionIds { get; set; } -} - -public class CipherMiniDetailsResponseModel : CipherMiniResponseModel -{ - public CipherMiniDetailsResponseModel(Cipher cipher, GlobalSettings globalSettings, - IDictionary> collectionCiphers, bool orgUseTotp, string obj = "cipherMiniDetails") - : base(cipher, globalSettings, orgUseTotp, obj) - { - if (collectionCiphers?.ContainsKey(cipher.Id) ?? false) - { - CollectionIds = collectionCiphers[cipher.Id].Select(c => c.CollectionId); - } - else - { - CollectionIds = new Guid[] { }; - } - } - - public IEnumerable CollectionIds { get; set; } + public class CipherResponseModel : CipherMiniResponseModel + { + public CipherResponseModel(CipherDetails cipher, IGlobalSettings globalSettings, string obj = "cipher") + : base(cipher, globalSettings, cipher.OrganizationUseTotp, obj) + { + FolderId = cipher.FolderId?.ToString(); + Favorite = cipher.Favorite; + Edit = cipher.Edit; + ViewPassword = cipher.ViewPassword; + } + + public string FolderId { get; set; } + public bool Favorite { get; set; } + public bool Edit { get; set; } + public bool ViewPassword { get; set; } + } + + public class CipherDetailsResponseModel : CipherResponseModel + { + public CipherDetailsResponseModel(CipherDetails cipher, GlobalSettings globalSettings, + IDictionary> collectionCiphers, string obj = "cipherDetails") + : base(cipher, globalSettings, obj) + { + if (collectionCiphers?.ContainsKey(cipher.Id) ?? false) + { + CollectionIds = collectionCiphers[cipher.Id].Select(c => c.CollectionId); + } + else + { + CollectionIds = new Guid[] { }; + } + } + + public CipherDetailsResponseModel(CipherDetails cipher, GlobalSettings globalSettings, + IEnumerable collectionCiphers, string obj = "cipherDetails") + : base(cipher, globalSettings, obj) + { + CollectionIds = collectionCiphers?.Select(c => c.CollectionId) ?? new List(); + } + + public IEnumerable CollectionIds { get; set; } + } + + public class CipherMiniDetailsResponseModel : CipherMiniResponseModel + { + public CipherMiniDetailsResponseModel(Cipher cipher, GlobalSettings globalSettings, + IDictionary> collectionCiphers, bool orgUseTotp, string obj = "cipherMiniDetails") + : base(cipher, globalSettings, orgUseTotp, obj) + { + if (collectionCiphers?.ContainsKey(cipher.Id) ?? false) + { + CollectionIds = collectionCiphers[cipher.Id].Select(c => c.CollectionId); + } + else + { + CollectionIds = new Guid[] { }; + } + } + + public IEnumerable CollectionIds { get; set; } + } } diff --git a/src/Api/Models/Response/CollectionResponseModel.cs b/src/Api/Models/Response/CollectionResponseModel.cs index aa56402c0a..5ac923a9dc 100644 --- a/src/Api/Models/Response/CollectionResponseModel.cs +++ b/src/Api/Models/Response/CollectionResponseModel.cs @@ -2,50 +2,51 @@ using Bit.Core.Models.Api; using Bit.Core.Models.Data; -namespace Bit.Api.Models.Response; - -public class CollectionResponseModel : ResponseModel +namespace Bit.Api.Models.Response { - public CollectionResponseModel(Collection collection, string obj = "collection") - : base(obj) + public class CollectionResponseModel : ResponseModel { - if (collection == null) + public CollectionResponseModel(Collection collection, string obj = "collection") + : base(obj) { - throw new ArgumentNullException(nameof(collection)); + if (collection == null) + { + throw new ArgumentNullException(nameof(collection)); + } + + Id = collection.Id.ToString(); + OrganizationId = collection.OrganizationId.ToString(); + Name = collection.Name; + ExternalId = collection.ExternalId; } - Id = collection.Id.ToString(); - OrganizationId = collection.OrganizationId.ToString(); - Name = collection.Name; - ExternalId = collection.ExternalId; + public string Id { get; set; } + public string OrganizationId { get; set; } + public string Name { get; set; } + public string ExternalId { get; set; } } - public string Id { get; set; } - public string OrganizationId { get; set; } - public string Name { get; set; } - public string ExternalId { get; set; } -} - -public class CollectionDetailsResponseModel : CollectionResponseModel -{ - public CollectionDetailsResponseModel(CollectionDetails collectionDetails) - : base(collectionDetails, "collectionDetails") + public class CollectionDetailsResponseModel : CollectionResponseModel { - ReadOnly = collectionDetails.ReadOnly; - HidePasswords = collectionDetails.HidePasswords; + public CollectionDetailsResponseModel(CollectionDetails collectionDetails) + : base(collectionDetails, "collectionDetails") + { + ReadOnly = collectionDetails.ReadOnly; + HidePasswords = collectionDetails.HidePasswords; + } + + public bool ReadOnly { get; set; } + public bool HidePasswords { get; set; } } - public bool ReadOnly { get; set; } - public bool HidePasswords { get; set; } -} - -public class CollectionGroupDetailsResponseModel : CollectionResponseModel -{ - public CollectionGroupDetailsResponseModel(Collection collection, IEnumerable groups) - : base(collection, "collectionGroupDetails") + public class CollectionGroupDetailsResponseModel : CollectionResponseModel { - Groups = groups.Select(g => new SelectionReadOnlyResponseModel(g)); - } + public CollectionGroupDetailsResponseModel(Collection collection, IEnumerable groups) + : base(collection, "collectionGroupDetails") + { + Groups = groups.Select(g => new SelectionReadOnlyResponseModel(g)); + } - public IEnumerable Groups { get; set; } + public IEnumerable Groups { get; set; } + } } diff --git a/src/Api/Models/Response/DeviceResponseModel.cs b/src/Api/Models/Response/DeviceResponseModel.cs index e88dff9fa0..e25562cbbb 100644 --- a/src/Api/Models/Response/DeviceResponseModel.cs +++ b/src/Api/Models/Response/DeviceResponseModel.cs @@ -2,28 +2,29 @@ using Bit.Core.Enums; using Bit.Core.Models.Api; -namespace Bit.Api.Models.Response; - -public class DeviceResponseModel : ResponseModel +namespace Bit.Api.Models.Response { - public DeviceResponseModel(Device device) - : base("device") + public class DeviceResponseModel : ResponseModel { - if (device == null) + public DeviceResponseModel(Device device) + : base("device") { - throw new ArgumentNullException(nameof(device)); + if (device == null) + { + throw new ArgumentNullException(nameof(device)); + } + + Id = device.Id.ToString(); + Name = device.Name; + Type = device.Type; + Identifier = device.Identifier; + CreationDate = device.CreationDate; } - Id = device.Id.ToString(); - Name = device.Name; - Type = device.Type; - Identifier = device.Identifier; - CreationDate = device.CreationDate; + public string Id { get; set; } + public string Name { get; set; } + public DeviceType Type { get; set; } + public string Identifier { get; set; } + public DateTime CreationDate { get; set; } } - - public string Id { get; set; } - public string Name { get; set; } - public DeviceType Type { get; set; } - public string Identifier { get; set; } - public DateTime CreationDate { get; set; } } diff --git a/src/Api/Models/Response/DeviceVerificationResponseModel.cs b/src/Api/Models/Response/DeviceVerificationResponseModel.cs index 0358ff7771..1f47547615 100644 --- a/src/Api/Models/Response/DeviceVerificationResponseModel.cs +++ b/src/Api/Models/Response/DeviceVerificationResponseModel.cs @@ -1,16 +1,17 @@ using Bit.Core.Models.Api; -namespace Bit.Api.Models.Response; - -public class DeviceVerificationResponseModel : ResponseModel +namespace Bit.Api.Models.Response { - public DeviceVerificationResponseModel(bool isDeviceVerificationSectionEnabled, bool unknownDeviceVerificationEnabled) - : base("deviceVerificationSettings") + public class DeviceVerificationResponseModel : ResponseModel { - IsDeviceVerificationSectionEnabled = isDeviceVerificationSectionEnabled; - UnknownDeviceVerificationEnabled = unknownDeviceVerificationEnabled; - } + public DeviceVerificationResponseModel(bool isDeviceVerificationSectionEnabled, bool unknownDeviceVerificationEnabled) + : base("deviceVerificationSettings") + { + IsDeviceVerificationSectionEnabled = isDeviceVerificationSectionEnabled; + UnknownDeviceVerificationEnabled = unknownDeviceVerificationEnabled; + } - public bool IsDeviceVerificationSectionEnabled { get; } - public bool UnknownDeviceVerificationEnabled { get; } + public bool IsDeviceVerificationSectionEnabled { get; } + public bool UnknownDeviceVerificationEnabled { get; } + } } diff --git a/src/Api/Models/Response/DomainsResponseModel.cs b/src/Api/Models/Response/DomainsResponseModel.cs index b7f1028458..fd6ea46b6c 100644 --- a/src/Api/Models/Response/DomainsResponseModel.cs +++ b/src/Api/Models/Response/DomainsResponseModel.cs @@ -3,53 +3,54 @@ using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Models.Api; -namespace Bit.Api.Models.Response; - -public class DomainsResponseModel : ResponseModel +namespace Bit.Api.Models.Response { - public DomainsResponseModel(User user, bool excluded = true) - : base("domains") + public class DomainsResponseModel : ResponseModel { - if (user == null) + public DomainsResponseModel(User user, bool excluded = true) + : base("domains") { - throw new ArgumentNullException(nameof(user)); + if (user == null) + { + throw new ArgumentNullException(nameof(user)); + } + + EquivalentDomains = user.EquivalentDomains != null ? + JsonSerializer.Deserialize>>(user.EquivalentDomains) : null; + + var excludedGlobalEquivalentDomains = user.ExcludedGlobalEquivalentDomains != null ? + JsonSerializer.Deserialize>(user.ExcludedGlobalEquivalentDomains) : + new List(); + var globalDomains = new List(); + var domainsToInclude = excluded ? Core.Utilities.StaticStore.GlobalDomains : + Core.Utilities.StaticStore.GlobalDomains.Where(d => !excludedGlobalEquivalentDomains.Contains(d.Key)); + foreach (var domain in domainsToInclude) + { + globalDomains.Add(new GlobalDomains(domain.Key, domain.Value, excludedGlobalEquivalentDomains, excluded)); + } + GlobalEquivalentDomains = !globalDomains.Any() ? null : globalDomains; } - EquivalentDomains = user.EquivalentDomains != null ? - JsonSerializer.Deserialize>>(user.EquivalentDomains) : null; + public IEnumerable> EquivalentDomains { get; set; } + public IEnumerable GlobalEquivalentDomains { get; set; } - var excludedGlobalEquivalentDomains = user.ExcludedGlobalEquivalentDomains != null ? - JsonSerializer.Deserialize>(user.ExcludedGlobalEquivalentDomains) : - new List(); - var globalDomains = new List(); - var domainsToInclude = excluded ? Core.Utilities.StaticStore.GlobalDomains : - Core.Utilities.StaticStore.GlobalDomains.Where(d => !excludedGlobalEquivalentDomains.Contains(d.Key)); - foreach (var domain in domainsToInclude) + + public class GlobalDomains { - globalDomains.Add(new GlobalDomains(domain.Key, domain.Value, excludedGlobalEquivalentDomains, excluded)); + public GlobalDomains( + GlobalEquivalentDomainsType globalDomain, + IEnumerable domains, + IEnumerable excludedDomains, + bool excluded) + { + Type = globalDomain; + Domains = domains; + Excluded = excluded && (excludedDomains?.Contains(globalDomain) ?? false); + } + + public GlobalEquivalentDomainsType Type { get; set; } + public IEnumerable Domains { get; set; } + public bool Excluded { get; set; } } - GlobalEquivalentDomains = !globalDomains.Any() ? null : globalDomains; - } - - public IEnumerable> EquivalentDomains { get; set; } - public IEnumerable GlobalEquivalentDomains { get; set; } - - - public class GlobalDomains - { - public GlobalDomains( - GlobalEquivalentDomainsType globalDomain, - IEnumerable domains, - IEnumerable excludedDomains, - bool excluded) - { - Type = globalDomain; - Domains = domains; - Excluded = excluded && (excludedDomains?.Contains(globalDomain) ?? false); - } - - public GlobalEquivalentDomainsType Type { get; set; } - public IEnumerable Domains { get; set; } - public bool Excluded { get; set; } } } diff --git a/src/Api/Models/Response/EmergencyAccessResponseModel.cs b/src/Api/Models/Response/EmergencyAccessResponseModel.cs index ec8dbd1ee0..16d255e923 100644 --- a/src/Api/Models/Response/EmergencyAccessResponseModel.cs +++ b/src/Api/Models/Response/EmergencyAccessResponseModel.cs @@ -5,113 +5,114 @@ using Bit.Core.Models.Data; using Bit.Core.Settings; using Core.Models.Data; -namespace Bit.Api.Models.Response; - -public class EmergencyAccessResponseModel : ResponseModel +namespace Bit.Api.Models.Response { - public EmergencyAccessResponseModel(EmergencyAccess emergencyAccess, string obj = "emergencyAccess") : base(obj) + public class EmergencyAccessResponseModel : ResponseModel { - if (emergencyAccess == null) + public EmergencyAccessResponseModel(EmergencyAccess emergencyAccess, string obj = "emergencyAccess") : base(obj) { - throw new ArgumentNullException(nameof(emergencyAccess)); + if (emergencyAccess == null) + { + throw new ArgumentNullException(nameof(emergencyAccess)); + } + + Id = emergencyAccess.Id.ToString(); + Status = emergencyAccess.Status; + Type = emergencyAccess.Type; + WaitTimeDays = emergencyAccess.WaitTimeDays; } - Id = emergencyAccess.Id.ToString(); - Status = emergencyAccess.Status; - Type = emergencyAccess.Type; - WaitTimeDays = emergencyAccess.WaitTimeDays; - } - - public EmergencyAccessResponseModel(EmergencyAccessDetails emergencyAccess, string obj = "emergencyAccess") : base(obj) - { - if (emergencyAccess == null) + public EmergencyAccessResponseModel(EmergencyAccessDetails emergencyAccess, string obj = "emergencyAccess") : base(obj) { - throw new ArgumentNullException(nameof(emergencyAccess)); + if (emergencyAccess == null) + { + throw new ArgumentNullException(nameof(emergencyAccess)); + } + + Id = emergencyAccess.Id.ToString(); + Status = emergencyAccess.Status; + Type = emergencyAccess.Type; + WaitTimeDays = emergencyAccess.WaitTimeDays; } - Id = emergencyAccess.Id.ToString(); - Status = emergencyAccess.Status; - Type = emergencyAccess.Type; - WaitTimeDays = emergencyAccess.WaitTimeDays; + public string Id { get; private set; } + public EmergencyAccessStatusType Status { get; private set; } + public EmergencyAccessType Type { get; private set; } + public int WaitTimeDays { get; private set; } } - public string Id { get; private set; } - public EmergencyAccessStatusType Status { get; private set; } - public EmergencyAccessType Type { get; private set; } - public int WaitTimeDays { get; private set; } -} - -public class EmergencyAccessGranteeDetailsResponseModel : EmergencyAccessResponseModel -{ - public EmergencyAccessGranteeDetailsResponseModel(EmergencyAccessDetails emergencyAccess) - : base(emergencyAccess, "emergencyAccessGranteeDetails") + public class EmergencyAccessGranteeDetailsResponseModel : EmergencyAccessResponseModel { - if (emergencyAccess == null) + public EmergencyAccessGranteeDetailsResponseModel(EmergencyAccessDetails emergencyAccess) + : base(emergencyAccess, "emergencyAccessGranteeDetails") { - throw new ArgumentNullException(nameof(emergencyAccess)); + if (emergencyAccess == null) + { + throw new ArgumentNullException(nameof(emergencyAccess)); + } + + GranteeId = emergencyAccess.GranteeId.ToString(); + Email = emergencyAccess.GranteeEmail; + Name = emergencyAccess.GranteeName; } - GranteeId = emergencyAccess.GranteeId.ToString(); - Email = emergencyAccess.GranteeEmail; - Name = emergencyAccess.GranteeName; + public string GranteeId { get; private set; } + public string Name { get; private set; } + public string Email { get; private set; } } - public string GranteeId { get; private set; } - public string Name { get; private set; } - public string Email { get; private set; } -} - -public class EmergencyAccessGrantorDetailsResponseModel : EmergencyAccessResponseModel -{ - public EmergencyAccessGrantorDetailsResponseModel(EmergencyAccessDetails emergencyAccess) - : base(emergencyAccess, "emergencyAccessGrantorDetails") + public class EmergencyAccessGrantorDetailsResponseModel : EmergencyAccessResponseModel { - if (emergencyAccess == null) + public EmergencyAccessGrantorDetailsResponseModel(EmergencyAccessDetails emergencyAccess) + : base(emergencyAccess, "emergencyAccessGrantorDetails") { - throw new ArgumentNullException(nameof(emergencyAccess)); + if (emergencyAccess == null) + { + throw new ArgumentNullException(nameof(emergencyAccess)); + } + + GrantorId = emergencyAccess.GrantorId.ToString(); + Email = emergencyAccess.GrantorEmail; + Name = emergencyAccess.GrantorName; } - GrantorId = emergencyAccess.GrantorId.ToString(); - Email = emergencyAccess.GrantorEmail; - Name = emergencyAccess.GrantorName; + public string GrantorId { get; private set; } + public string Name { get; private set; } + public string Email { get; private set; } } - public string GrantorId { get; private set; } - public string Name { get; private set; } - public string Email { get; private set; } -} - -public class EmergencyAccessTakeoverResponseModel : ResponseModel -{ - public EmergencyAccessTakeoverResponseModel(EmergencyAccess emergencyAccess, User grantor, string obj = "emergencyAccessTakeover") : base(obj) + public class EmergencyAccessTakeoverResponseModel : ResponseModel { - if (emergencyAccess == null) + public EmergencyAccessTakeoverResponseModel(EmergencyAccess emergencyAccess, User grantor, string obj = "emergencyAccessTakeover") : base(obj) { - throw new ArgumentNullException(nameof(emergencyAccess)); + if (emergencyAccess == null) + { + throw new ArgumentNullException(nameof(emergencyAccess)); + } + + KeyEncrypted = emergencyAccess.KeyEncrypted; + Kdf = grantor.Kdf; + KdfIterations = grantor.KdfIterations; } - KeyEncrypted = emergencyAccess.KeyEncrypted; - Kdf = grantor.Kdf; - KdfIterations = grantor.KdfIterations; + public int KdfIterations { get; private set; } + public KdfType Kdf { get; private set; } + public string KeyEncrypted { get; private set; } } - public int KdfIterations { get; private set; } - public KdfType Kdf { get; private set; } - public string KeyEncrypted { get; private set; } -} - -public class EmergencyAccessViewResponseModel : ResponseModel -{ - public EmergencyAccessViewResponseModel( - IGlobalSettings globalSettings, - EmergencyAccess emergencyAccess, - IEnumerable ciphers) - : base("emergencyAccessView") + public class EmergencyAccessViewResponseModel : ResponseModel { - KeyEncrypted = emergencyAccess.KeyEncrypted; - Ciphers = ciphers.Select(c => new CipherResponseModel(c, globalSettings)); - } + public EmergencyAccessViewResponseModel( + IGlobalSettings globalSettings, + EmergencyAccess emergencyAccess, + IEnumerable ciphers) + : base("emergencyAccessView") + { + KeyEncrypted = emergencyAccess.KeyEncrypted; + Ciphers = ciphers.Select(c => new CipherResponseModel(c, globalSettings)); + } - public string KeyEncrypted { get; set; } - public IEnumerable Ciphers { get; set; } + public string KeyEncrypted { get; set; } + public IEnumerable Ciphers { get; set; } + } } diff --git a/src/Api/Models/Response/EventResponseModel.cs b/src/Api/Models/Response/EventResponseModel.cs index 6cb723c6e0..40ccbb7e1f 100644 --- a/src/Api/Models/Response/EventResponseModel.cs +++ b/src/Api/Models/Response/EventResponseModel.cs @@ -2,50 +2,51 @@ using Bit.Core.Models.Api; using Bit.Core.Models.Data; -namespace Bit.Api.Models.Response; - -public class EventResponseModel : ResponseModel +namespace Bit.Api.Models.Response { - public EventResponseModel(IEvent ev) - : base("event") + public class EventResponseModel : ResponseModel { - if (ev == null) + public EventResponseModel(IEvent ev) + : base("event") { - throw new ArgumentNullException(nameof(ev)); + if (ev == null) + { + throw new ArgumentNullException(nameof(ev)); + } + + Type = ev.Type; + UserId = ev.UserId; + OrganizationId = ev.OrganizationId; + ProviderId = ev.ProviderId; + CipherId = ev.CipherId; + CollectionId = ev.CollectionId; + GroupId = ev.GroupId; + PolicyId = ev.PolicyId; + OrganizationUserId = ev.OrganizationUserId; + ProviderUserId = ev.ProviderUserId; + ProviderOrganizationId = ev.ProviderOrganizationId; + ActingUserId = ev.ActingUserId; + Date = ev.Date; + DeviceType = ev.DeviceType; + IpAddress = ev.IpAddress; + InstallationId = ev.InstallationId; } - Type = ev.Type; - UserId = ev.UserId; - OrganizationId = ev.OrganizationId; - ProviderId = ev.ProviderId; - CipherId = ev.CipherId; - CollectionId = ev.CollectionId; - GroupId = ev.GroupId; - PolicyId = ev.PolicyId; - OrganizationUserId = ev.OrganizationUserId; - ProviderUserId = ev.ProviderUserId; - ProviderOrganizationId = ev.ProviderOrganizationId; - ActingUserId = ev.ActingUserId; - Date = ev.Date; - DeviceType = ev.DeviceType; - IpAddress = ev.IpAddress; - InstallationId = ev.InstallationId; + public EventType Type { get; set; } + public Guid? UserId { get; set; } + public Guid? OrganizationId { get; set; } + public Guid? ProviderId { get; set; } + public Guid? CipherId { get; set; } + public Guid? CollectionId { get; set; } + public Guid? GroupId { get; set; } + public Guid? PolicyId { get; set; } + public Guid? OrganizationUserId { get; set; } + public Guid? ProviderUserId { get; set; } + public Guid? ProviderOrganizationId { get; set; } + public Guid? ActingUserId { get; set; } + public Guid? InstallationId { get; set; } + public DateTime Date { get; set; } + public DeviceType? DeviceType { get; set; } + public string IpAddress { get; set; } } - - public EventType Type { get; set; } - public Guid? UserId { get; set; } - public Guid? OrganizationId { get; set; } - public Guid? ProviderId { get; set; } - public Guid? CipherId { get; set; } - public Guid? CollectionId { get; set; } - public Guid? GroupId { get; set; } - public Guid? PolicyId { get; set; } - public Guid? OrganizationUserId { get; set; } - public Guid? ProviderUserId { get; set; } - public Guid? ProviderOrganizationId { get; set; } - public Guid? ActingUserId { get; set; } - public Guid? InstallationId { get; set; } - public DateTime Date { get; set; } - public DeviceType? DeviceType { get; set; } - public string IpAddress { get; set; } } diff --git a/src/Api/Models/Response/FolderResponseModel.cs b/src/Api/Models/Response/FolderResponseModel.cs index 03971b4e3a..0396471e12 100644 --- a/src/Api/Models/Response/FolderResponseModel.cs +++ b/src/Api/Models/Response/FolderResponseModel.cs @@ -1,24 +1,25 @@ using Bit.Core.Entities; using Bit.Core.Models.Api; -namespace Bit.Api.Models.Response; - -public class FolderResponseModel : ResponseModel +namespace Bit.Api.Models.Response { - public FolderResponseModel(Folder folder) - : base("folder") + public class FolderResponseModel : ResponseModel { - if (folder == null) + public FolderResponseModel(Folder folder) + : base("folder") { - throw new ArgumentNullException(nameof(folder)); + if (folder == null) + { + throw new ArgumentNullException(nameof(folder)); + } + + Id = folder.Id.ToString(); + Name = folder.Name; + RevisionDate = folder.RevisionDate; } - Id = folder.Id.ToString(); - Name = folder.Name; - RevisionDate = folder.RevisionDate; + public string Id { get; set; } + public string Name { get; set; } + public DateTime RevisionDate { get; set; } } - - public string Id { get; set; } - public string Name { get; set; } - public DateTime RevisionDate { get; set; } } diff --git a/src/Api/Models/Response/GroupResponseModel.cs b/src/Api/Models/Response/GroupResponseModel.cs index 4b6496a40c..c75ff31e2e 100644 --- a/src/Api/Models/Response/GroupResponseModel.cs +++ b/src/Api/Models/Response/GroupResponseModel.cs @@ -2,39 +2,40 @@ using Bit.Core.Models.Api; using Bit.Core.Models.Data; -namespace Bit.Api.Models.Response; - -public class GroupResponseModel : ResponseModel +namespace Bit.Api.Models.Response { - public GroupResponseModel(Group group, string obj = "group") - : base(obj) + public class GroupResponseModel : ResponseModel { - if (group == null) + public GroupResponseModel(Group group, string obj = "group") + : base(obj) { - throw new ArgumentNullException(nameof(group)); + if (group == null) + { + throw new ArgumentNullException(nameof(group)); + } + + Id = group.Id.ToString(); + OrganizationId = group.OrganizationId.ToString(); + Name = group.Name; + AccessAll = group.AccessAll; + ExternalId = group.ExternalId; } - Id = group.Id.ToString(); - OrganizationId = group.OrganizationId.ToString(); - Name = group.Name; - AccessAll = group.AccessAll; - ExternalId = group.ExternalId; + public string Id { get; set; } + public string OrganizationId { get; set; } + public string Name { get; set; } + public bool AccessAll { get; set; } + public string ExternalId { get; set; } } - public string Id { get; set; } - public string OrganizationId { get; set; } - public string Name { get; set; } - public bool AccessAll { get; set; } - public string ExternalId { get; set; } -} - -public class GroupDetailsResponseModel : GroupResponseModel -{ - public GroupDetailsResponseModel(Group group, IEnumerable collections) - : base(group, "groupDetails") + public class GroupDetailsResponseModel : GroupResponseModel { - Collections = collections.Select(c => new SelectionReadOnlyResponseModel(c)); - } + public GroupDetailsResponseModel(Group group, IEnumerable collections) + : base(group, "groupDetails") + { + Collections = collections.Select(c => new SelectionReadOnlyResponseModel(c)); + } - public IEnumerable Collections { get; set; } + public IEnumerable Collections { get; set; } + } } diff --git a/src/Api/Models/Response/InstallationResponseModel.cs b/src/Api/Models/Response/InstallationResponseModel.cs index 75000471da..68e1524b14 100644 --- a/src/Api/Models/Response/InstallationResponseModel.cs +++ b/src/Api/Models/Response/InstallationResponseModel.cs @@ -1,19 +1,20 @@ using Bit.Core.Entities; using Bit.Core.Models.Api; -namespace Bit.Api.Models.Response; - -public class InstallationResponseModel : ResponseModel +namespace Bit.Api.Models.Response { - public InstallationResponseModel(Installation installation, bool withKey) - : base("installation") + public class InstallationResponseModel : ResponseModel { - Id = installation.Id.ToString(); - Key = withKey ? installation.Key : null; - Enabled = installation.Enabled; - } + public InstallationResponseModel(Installation installation, bool withKey) + : base("installation") + { + Id = installation.Id.ToString(); + Key = withKey ? installation.Key : null; + Enabled = installation.Enabled; + } - public string Id { get; set; } - public string Key { get; set; } - public bool Enabled { get; set; } + public string Id { get; set; } + public string Key { get; set; } + public bool Enabled { get; set; } + } } diff --git a/src/Api/Models/Response/KeysResponseModel.cs b/src/Api/Models/Response/KeysResponseModel.cs index 2f7e5e7304..1ca1ae052c 100644 --- a/src/Api/Models/Response/KeysResponseModel.cs +++ b/src/Api/Models/Response/KeysResponseModel.cs @@ -1,24 +1,25 @@ using Bit.Core.Entities; using Bit.Core.Models.Api; -namespace Bit.Api.Models.Response; - -public class KeysResponseModel : ResponseModel +namespace Bit.Api.Models.Response { - public KeysResponseModel(User user) - : base("keys") + public class KeysResponseModel : ResponseModel { - if (user == null) + public KeysResponseModel(User user) + : base("keys") { - throw new ArgumentNullException(nameof(user)); + if (user == null) + { + throw new ArgumentNullException(nameof(user)); + } + + Key = user.Key; + PublicKey = user.PublicKey; + PrivateKey = user.PrivateKey; } - Key = user.Key; - PublicKey = user.PublicKey; - PrivateKey = user.PrivateKey; + public string Key { get; set; } + public string PublicKey { get; set; } + public string PrivateKey { get; set; } } - - public string Key { get; set; } - public string PublicKey { get; set; } - public string PrivateKey { get; set; } } diff --git a/src/Api/Models/Response/ListResponseModel.cs b/src/Api/Models/Response/ListResponseModel.cs index ecfe0a7e19..c16a3461cf 100644 --- a/src/Api/Models/Response/ListResponseModel.cs +++ b/src/Api/Models/Response/ListResponseModel.cs @@ -1,16 +1,17 @@ using Bit.Core.Models.Api; -namespace Bit.Api.Models.Response; - -public class ListResponseModel : ResponseModel where T : ResponseModel +namespace Bit.Api.Models.Response { - public ListResponseModel(IEnumerable data, string continuationToken = null) - : base("list") + public class ListResponseModel : ResponseModel where T : ResponseModel { - Data = data; - ContinuationToken = continuationToken; - } + public ListResponseModel(IEnumerable data, string continuationToken = null) + : base("list") + { + Data = data; + ContinuationToken = continuationToken; + } - public IEnumerable Data { get; set; } - public string ContinuationToken { get; set; } + public IEnumerable Data { get; set; } + public string ContinuationToken { get; set; } + } } diff --git a/src/Api/Models/Response/OrganizationExportResponseModel.cs b/src/Api/Models/Response/OrganizationExportResponseModel.cs index a7533c918d..f5ce61873c 100644 --- a/src/Api/Models/Response/OrganizationExportResponseModel.cs +++ b/src/Api/Models/Response/OrganizationExportResponseModel.cs @@ -1,12 +1,13 @@ -namespace Bit.Api.Models.Response; - -public class OrganizationExportResponseModel +namespace Bit.Api.Models.Response { - public OrganizationExportResponseModel() + public class OrganizationExportResponseModel { + public OrganizationExportResponseModel() + { + } + + public ListResponseModel Collections { get; set; } + + public ListResponseModel Ciphers { get; set; } } - - public ListResponseModel Collections { get; set; } - - public ListResponseModel Ciphers { get; set; } } diff --git a/src/Api/Models/Response/Organizations/OrganizationApiKeyInformationResponseModel.cs b/src/Api/Models/Response/Organizations/OrganizationApiKeyInformationResponseModel.cs index a25cb89355..05adb502a1 100644 --- a/src/Api/Models/Response/Organizations/OrganizationApiKeyInformationResponseModel.cs +++ b/src/Api/Models/Response/Organizations/OrganizationApiKeyInformationResponseModel.cs @@ -2,16 +2,17 @@ using Bit.Core.Enums; using Bit.Core.Models.Api; -namespace Bit.Api.Models.Response.Organizations; - -public class OrganizationApiKeyInformation : ResponseModel +namespace Bit.Api.Models.Response.Organizations { - public OrganizationApiKeyInformation(OrganizationApiKey key) : base("keyInformation") + public class OrganizationApiKeyInformation : ResponseModel { - KeyType = key.Type; - RevisionDate = key.RevisionDate; - } + public OrganizationApiKeyInformation(OrganizationApiKey key) : base("keyInformation") + { + KeyType = key.Type; + RevisionDate = key.RevisionDate; + } - public OrganizationApiKeyType KeyType { get; set; } - public DateTime RevisionDate { get; set; } + public OrganizationApiKeyType KeyType { get; set; } + public DateTime RevisionDate { get; set; } + } } diff --git a/src/Api/Models/Response/Organizations/OrganizationAutoEnrollStatusResponseModel.cs b/src/Api/Models/Response/Organizations/OrganizationAutoEnrollStatusResponseModel.cs index 9c1f0ee22d..529168c6ad 100644 --- a/src/Api/Models/Response/Organizations/OrganizationAutoEnrollStatusResponseModel.cs +++ b/src/Api/Models/Response/Organizations/OrganizationAutoEnrollStatusResponseModel.cs @@ -1,15 +1,16 @@ using Bit.Core.Models.Api; -namespace Bit.Api.Models.Response.Organizations; - -public class OrganizationAutoEnrollStatusResponseModel : ResponseModel +namespace Bit.Api.Models.Response.Organizations { - public OrganizationAutoEnrollStatusResponseModel(Guid orgId, bool resetPasswordEnabled) : base("organizationAutoEnrollStatus") + public class OrganizationAutoEnrollStatusResponseModel : ResponseModel { - Id = orgId.ToString(); - ResetPasswordEnabled = resetPasswordEnabled; - } + public OrganizationAutoEnrollStatusResponseModel(Guid orgId, bool resetPasswordEnabled) : base("organizationAutoEnrollStatus") + { + Id = orgId.ToString(); + ResetPasswordEnabled = resetPasswordEnabled; + } - public string Id { get; set; } - public bool ResetPasswordEnabled { get; set; } + public string Id { get; set; } + public bool ResetPasswordEnabled { get; set; } + } } diff --git a/src/Api/Models/Response/Organizations/OrganizationConnectionResponseModel.cs b/src/Api/Models/Response/Organizations/OrganizationConnectionResponseModel.cs index f199ce56c2..86fb9b4db1 100644 --- a/src/Api/Models/Response/Organizations/OrganizationConnectionResponseModel.cs +++ b/src/Api/Models/Response/Organizations/OrganizationConnectionResponseModel.cs @@ -2,27 +2,28 @@ using Bit.Core.Entities; using Bit.Core.Enums; -namespace Bit.Api.Models.Response.Organizations; - -public class OrganizationConnectionResponseModel +namespace Bit.Api.Models.Response.Organizations { - public Guid? Id { get; set; } - public OrganizationConnectionType Type { get; set; } - public Guid OrganizationId { get; set; } - public bool Enabled { get; set; } - public JsonDocument Config { get; set; } - - public OrganizationConnectionResponseModel(OrganizationConnection connection, Type configType) + public class OrganizationConnectionResponseModel { - if (connection == null) - { - return; - } + public Guid? Id { get; set; } + public OrganizationConnectionType Type { get; set; } + public Guid OrganizationId { get; set; } + public bool Enabled { get; set; } + public JsonDocument Config { get; set; } - Id = connection.Id; - Type = connection.Type; - OrganizationId = connection.OrganizationId; - Enabled = connection.Enabled; - Config = JsonDocument.Parse(connection.Config); + public OrganizationConnectionResponseModel(OrganizationConnection connection, Type configType) + { + if (connection == null) + { + return; + } + + Id = connection.Id; + Type = connection.Type; + OrganizationId = connection.OrganizationId; + Enabled = connection.Enabled; + Config = JsonDocument.Parse(connection.Config); + } } } diff --git a/src/Api/Models/Response/Organizations/OrganizationKeysResponseModel.cs b/src/Api/Models/Response/Organizations/OrganizationKeysResponseModel.cs index 35c2f77e7d..06430bef27 100644 --- a/src/Api/Models/Response/Organizations/OrganizationKeysResponseModel.cs +++ b/src/Api/Models/Response/Organizations/OrganizationKeysResponseModel.cs @@ -1,21 +1,22 @@ using Bit.Core.Entities; using Bit.Core.Models.Api; -namespace Bit.Api.Models.Response.Organizations; - -public class OrganizationKeysResponseModel : ResponseModel +namespace Bit.Api.Models.Response.Organizations { - public OrganizationKeysResponseModel(Organization org) : base("organizationKeys") + public class OrganizationKeysResponseModel : ResponseModel { - if (org == null) + public OrganizationKeysResponseModel(Organization org) : base("organizationKeys") { - throw new ArgumentNullException(nameof(org)); + if (org == null) + { + throw new ArgumentNullException(nameof(org)); + } + + PublicKey = org.PublicKey; + PrivateKey = org.PrivateKey; } - PublicKey = org.PublicKey; - PrivateKey = org.PrivateKey; + public string PublicKey { get; set; } + public string PrivateKey { get; set; } } - - public string PublicKey { get; set; } - public string PrivateKey { get; set; } } diff --git a/src/Api/Models/Response/Organizations/OrganizationResponseModel.cs b/src/Api/Models/Response/Organizations/OrganizationResponseModel.cs index 4aa83d201a..eab23bee97 100644 --- a/src/Api/Models/Response/Organizations/OrganizationResponseModel.cs +++ b/src/Api/Models/Response/Organizations/OrganizationResponseModel.cs @@ -4,109 +4,110 @@ using Bit.Core.Models.Api; using Bit.Core.Models.Business; using Bit.Core.Utilities; -namespace Bit.Api.Models.Response.Organizations; - -public class OrganizationResponseModel : ResponseModel +namespace Bit.Api.Models.Response.Organizations { - public OrganizationResponseModel(Organization organization, string obj = "organization") - : base(obj) + public class OrganizationResponseModel : ResponseModel { - if (organization == null) + public OrganizationResponseModel(Organization organization, string obj = "organization") + : base(obj) { - throw new ArgumentNullException(nameof(organization)); + if (organization == null) + { + throw new ArgumentNullException(nameof(organization)); + } + + Id = organization.Id.ToString(); + Identifier = organization.Identifier; + Name = organization.Name; + BusinessName = organization.BusinessName; + BusinessAddress1 = organization.BusinessAddress1; + BusinessAddress2 = organization.BusinessAddress2; + BusinessAddress3 = organization.BusinessAddress3; + BusinessCountry = organization.BusinessCountry; + BusinessTaxNumber = organization.BusinessTaxNumber; + BillingEmail = organization.BillingEmail; + Plan = new PlanResponseModel(StaticStore.Plans.FirstOrDefault(plan => plan.Type == organization.PlanType)); + PlanType = organization.PlanType; + Seats = organization.Seats; + MaxAutoscaleSeats = organization.MaxAutoscaleSeats; + MaxCollections = organization.MaxCollections; + MaxStorageGb = organization.MaxStorageGb; + UsePolicies = organization.UsePolicies; + UseSso = organization.UseSso; + UseKeyConnector = organization.UseKeyConnector; + UseScim = organization.UseScim; + UseGroups = organization.UseGroups; + UseDirectory = organization.UseDirectory; + UseEvents = organization.UseEvents; + UseTotp = organization.UseTotp; + Use2fa = organization.Use2fa; + UseApi = organization.UseApi; + UseResetPassword = organization.UseResetPassword; + UsersGetPremium = organization.UsersGetPremium; + SelfHost = organization.SelfHost; + HasPublicAndPrivateKeys = organization.PublicKey != null && organization.PrivateKey != null; } - Id = organization.Id.ToString(); - Identifier = organization.Identifier; - Name = organization.Name; - BusinessName = organization.BusinessName; - BusinessAddress1 = organization.BusinessAddress1; - BusinessAddress2 = organization.BusinessAddress2; - BusinessAddress3 = organization.BusinessAddress3; - BusinessCountry = organization.BusinessCountry; - BusinessTaxNumber = organization.BusinessTaxNumber; - BillingEmail = organization.BillingEmail; - Plan = new PlanResponseModel(StaticStore.Plans.FirstOrDefault(plan => plan.Type == organization.PlanType)); - PlanType = organization.PlanType; - Seats = organization.Seats; - MaxAutoscaleSeats = organization.MaxAutoscaleSeats; - MaxCollections = organization.MaxCollections; - MaxStorageGb = organization.MaxStorageGb; - UsePolicies = organization.UsePolicies; - UseSso = organization.UseSso; - UseKeyConnector = organization.UseKeyConnector; - UseScim = organization.UseScim; - UseGroups = organization.UseGroups; - UseDirectory = organization.UseDirectory; - UseEvents = organization.UseEvents; - UseTotp = organization.UseTotp; - Use2fa = organization.Use2fa; - UseApi = organization.UseApi; - UseResetPassword = organization.UseResetPassword; - UsersGetPremium = organization.UsersGetPremium; - SelfHost = organization.SelfHost; - HasPublicAndPrivateKeys = organization.PublicKey != null && organization.PrivateKey != null; + public string Id { get; set; } + public string Identifier { get; set; } + public string Name { get; set; } + public string BusinessName { get; set; } + public string BusinessAddress1 { get; set; } + public string BusinessAddress2 { get; set; } + public string BusinessAddress3 { get; set; } + public string BusinessCountry { get; set; } + public string BusinessTaxNumber { get; set; } + public string BillingEmail { get; set; } + public PlanResponseModel Plan { get; set; } + public PlanType PlanType { get; set; } + public int? Seats { get; set; } + public int? MaxAutoscaleSeats { get; set; } = null; + public short? MaxCollections { get; set; } + public short? MaxStorageGb { get; set; } + public bool UsePolicies { get; set; } + public bool UseSso { get; set; } + public bool UseKeyConnector { get; set; } + public bool UseScim { get; set; } + public bool UseGroups { get; set; } + public bool UseDirectory { get; set; } + public bool UseEvents { get; set; } + public bool UseTotp { get; set; } + public bool Use2fa { get; set; } + public bool UseApi { get; set; } + public bool UseResetPassword { get; set; } + public bool UsersGetPremium { get; set; } + public bool SelfHost { get; set; } + public bool HasPublicAndPrivateKeys { get; set; } } - public string Id { get; set; } - public string Identifier { get; set; } - public string Name { get; set; } - public string BusinessName { get; set; } - public string BusinessAddress1 { get; set; } - public string BusinessAddress2 { get; set; } - public string BusinessAddress3 { get; set; } - public string BusinessCountry { get; set; } - public string BusinessTaxNumber { get; set; } - public string BillingEmail { get; set; } - public PlanResponseModel Plan { get; set; } - public PlanType PlanType { get; set; } - public int? Seats { get; set; } - public int? MaxAutoscaleSeats { get; set; } = null; - public short? MaxCollections { get; set; } - public short? MaxStorageGb { get; set; } - public bool UsePolicies { get; set; } - public bool UseSso { get; set; } - public bool UseKeyConnector { get; set; } - public bool UseScim { get; set; } - public bool UseGroups { get; set; } - public bool UseDirectory { get; set; } - public bool UseEvents { get; set; } - public bool UseTotp { get; set; } - public bool Use2fa { get; set; } - public bool UseApi { get; set; } - public bool UseResetPassword { get; set; } - public bool UsersGetPremium { get; set; } - public bool SelfHost { get; set; } - public bool HasPublicAndPrivateKeys { get; set; } -} - -public class OrganizationSubscriptionResponseModel : OrganizationResponseModel -{ - public OrganizationSubscriptionResponseModel(Organization organization, SubscriptionInfo subscription = null) - : base(organization, "organizationSubscription") + public class OrganizationSubscriptionResponseModel : OrganizationResponseModel { - if (subscription != null) + public OrganizationSubscriptionResponseModel(Organization organization, SubscriptionInfo subscription = null) + : base(organization, "organizationSubscription") { - Subscription = subscription.Subscription != null ? - new BillingSubscription(subscription.Subscription) : null; - UpcomingInvoice = subscription.UpcomingInvoice != null ? - new BillingSubscriptionUpcomingInvoice(subscription.UpcomingInvoice) : null; - Expiration = DateTime.UtcNow.AddYears(1); // Not used, so just give it a value. - } - else - { - Expiration = organization.ExpirationDate; + if (subscription != null) + { + Subscription = subscription.Subscription != null ? + new BillingSubscription(subscription.Subscription) : null; + UpcomingInvoice = subscription.UpcomingInvoice != null ? + new BillingSubscriptionUpcomingInvoice(subscription.UpcomingInvoice) : null; + Expiration = DateTime.UtcNow.AddYears(1); // Not used, so just give it a value. + } + else + { + Expiration = organization.ExpirationDate; + } + + StorageName = organization.Storage.HasValue ? + CoreHelpers.ReadableBytesSize(organization.Storage.Value) : null; + StorageGb = organization.Storage.HasValue ? + Math.Round(organization.Storage.Value / 1073741824D, 2) : 0; // 1 GB } - StorageName = organization.Storage.HasValue ? - CoreHelpers.ReadableBytesSize(organization.Storage.Value) : null; - StorageGb = organization.Storage.HasValue ? - Math.Round(organization.Storage.Value / 1073741824D, 2) : 0; // 1 GB + public string StorageName { get; set; } + public double? StorageGb { get; set; } + public BillingSubscription Subscription { get; set; } + public BillingSubscriptionUpcomingInvoice UpcomingInvoice { get; set; } + public DateTime? Expiration { get; set; } } - - public string StorageName { get; set; } - public double? StorageGb { get; set; } - public BillingSubscription Subscription { get; set; } - public BillingSubscriptionUpcomingInvoice UpcomingInvoice { get; set; } - public DateTime? Expiration { get; set; } } diff --git a/src/Api/Models/Response/Organizations/OrganizationSponsorshipSyncStatusResponseModel.cs b/src/Api/Models/Response/Organizations/OrganizationSponsorshipSyncStatusResponseModel.cs index 33862f391e..33e349bbfd 100644 --- a/src/Api/Models/Response/Organizations/OrganizationSponsorshipSyncStatusResponseModel.cs +++ b/src/Api/Models/Response/Organizations/OrganizationSponsorshipSyncStatusResponseModel.cs @@ -1,14 +1,15 @@ using Bit.Core.Models.Api; -namespace Bit.Api.Models.Response.Organizations; - -public class OrganizationSponsorshipSyncStatusResponseModel : ResponseModel +namespace Bit.Api.Models.Response.Organizations { - public OrganizationSponsorshipSyncStatusResponseModel(DateTime? lastSyncDate) - : base("syncStatus") + public class OrganizationSponsorshipSyncStatusResponseModel : ResponseModel { - LastSyncDate = lastSyncDate; - } + public OrganizationSponsorshipSyncStatusResponseModel(DateTime? lastSyncDate) + : base("syncStatus") + { + LastSyncDate = lastSyncDate; + } - public DateTime? LastSyncDate { get; set; } + public DateTime? LastSyncDate { get; set; } + } } diff --git a/src/Api/Models/Response/Organizations/OrganizationSsoResponseModel.cs b/src/Api/Models/Response/Organizations/OrganizationSsoResponseModel.cs index cd7e6c2665..dd828630b3 100644 --- a/src/Api/Models/Response/Organizations/OrganizationSsoResponseModel.cs +++ b/src/Api/Models/Response/Organizations/OrganizationSsoResponseModel.cs @@ -3,41 +3,42 @@ using Bit.Core.Models.Api; using Bit.Core.Models.Data; using Bit.Core.Settings; -namespace Bit.Api.Models.Response.Organizations; - -public class OrganizationSsoResponseModel : ResponseModel +namespace Bit.Api.Models.Response.Organizations { - public OrganizationSsoResponseModel(Organization organization, GlobalSettings globalSettings, - SsoConfig config = null) : base("organizationSso") + public class OrganizationSsoResponseModel : ResponseModel { - if (config != null) + public OrganizationSsoResponseModel(Organization organization, GlobalSettings globalSettings, + SsoConfig config = null) : base("organizationSso") { - Enabled = config.Enabled; - Data = config.GetData(); + if (config != null) + { + Enabled = config.Enabled; + Data = config.GetData(); + } + + Urls = new SsoUrls(organization.Id.ToString(), globalSettings); } - Urls = new SsoUrls(organization.Id.ToString(), globalSettings); + public bool Enabled { get; set; } + public SsoConfigurationData Data { get; set; } + public SsoUrls Urls { get; set; } } - public bool Enabled { get; set; } - public SsoConfigurationData Data { get; set; } - public SsoUrls Urls { get; set; } -} - -public class SsoUrls -{ - public SsoUrls(string organizationId, GlobalSettings globalSettings) + public class SsoUrls { - CallbackPath = SsoConfigurationData.BuildCallbackPath(globalSettings.BaseServiceUri.Sso); - SignedOutCallbackPath = SsoConfigurationData.BuildSignedOutCallbackPath(globalSettings.BaseServiceUri.Sso); - SpEntityId = SsoConfigurationData.BuildSaml2ModulePath(globalSettings.BaseServiceUri.Sso); - SpMetadataUrl = SsoConfigurationData.BuildSaml2MetadataUrl(globalSettings.BaseServiceUri.Sso, organizationId); - SpAcsUrl = SsoConfigurationData.BuildSaml2AcsUrl(globalSettings.BaseServiceUri.Sso, organizationId); - } + public SsoUrls(string organizationId, GlobalSettings globalSettings) + { + CallbackPath = SsoConfigurationData.BuildCallbackPath(globalSettings.BaseServiceUri.Sso); + SignedOutCallbackPath = SsoConfigurationData.BuildSignedOutCallbackPath(globalSettings.BaseServiceUri.Sso); + SpEntityId = SsoConfigurationData.BuildSaml2ModulePath(globalSettings.BaseServiceUri.Sso); + SpMetadataUrl = SsoConfigurationData.BuildSaml2MetadataUrl(globalSettings.BaseServiceUri.Sso, organizationId); + SpAcsUrl = SsoConfigurationData.BuildSaml2AcsUrl(globalSettings.BaseServiceUri.Sso, organizationId); + } - public string CallbackPath { get; set; } - public string SignedOutCallbackPath { get; set; } - public string SpEntityId { get; set; } - public string SpMetadataUrl { get; set; } - public string SpAcsUrl { get; set; } + public string CallbackPath { get; set; } + public string SignedOutCallbackPath { get; set; } + public string SpEntityId { get; set; } + public string SpMetadataUrl { get; set; } + public string SpAcsUrl { get; set; } + } } diff --git a/src/Api/Models/Response/Organizations/OrganizationUserResponseModel.cs b/src/Api/Models/Response/Organizations/OrganizationUserResponseModel.cs index 619769b066..7be68c41a7 100644 --- a/src/Api/Models/Response/Organizations/OrganizationUserResponseModel.cs +++ b/src/Api/Models/Response/Organizations/OrganizationUserResponseModel.cs @@ -5,138 +5,139 @@ using Bit.Core.Models.Data; using Bit.Core.Models.Data.Organizations.OrganizationUsers; using Bit.Core.Utilities; -namespace Bit.Api.Models.Response.Organizations; - -public class OrganizationUserResponseModel : ResponseModel +namespace Bit.Api.Models.Response.Organizations { - public OrganizationUserResponseModel(OrganizationUser organizationUser, string obj = "organizationUser") - : base(obj) + public class OrganizationUserResponseModel : ResponseModel { - if (organizationUser == null) + public OrganizationUserResponseModel(OrganizationUser organizationUser, string obj = "organizationUser") + : base(obj) { - throw new ArgumentNullException(nameof(organizationUser)); + if (organizationUser == null) + { + throw new ArgumentNullException(nameof(organizationUser)); + } + + Id = organizationUser.Id.ToString(); + UserId = organizationUser.UserId?.ToString(); + Type = organizationUser.Type; + Status = organizationUser.Status; + AccessAll = organizationUser.AccessAll; + Permissions = CoreHelpers.LoadClassFromJsonData(organizationUser.Permissions); + ResetPasswordEnrolled = !string.IsNullOrEmpty(organizationUser.ResetPasswordKey); } - Id = organizationUser.Id.ToString(); - UserId = organizationUser.UserId?.ToString(); - Type = organizationUser.Type; - Status = organizationUser.Status; - AccessAll = organizationUser.AccessAll; - Permissions = CoreHelpers.LoadClassFromJsonData(organizationUser.Permissions); - ResetPasswordEnrolled = !string.IsNullOrEmpty(organizationUser.ResetPasswordKey); - } - - public OrganizationUserResponseModel(OrganizationUserUserDetails organizationUser, string obj = "organizationUser") - : base(obj) - { - if (organizationUser == null) + public OrganizationUserResponseModel(OrganizationUserUserDetails organizationUser, string obj = "organizationUser") + : base(obj) { - throw new ArgumentNullException(nameof(organizationUser)); + if (organizationUser == null) + { + throw new ArgumentNullException(nameof(organizationUser)); + } + + Id = organizationUser.Id.ToString(); + UserId = organizationUser.UserId?.ToString(); + Type = organizationUser.Type; + Status = organizationUser.Status; + AccessAll = organizationUser.AccessAll; + Permissions = CoreHelpers.LoadClassFromJsonData(organizationUser.Permissions); + ResetPasswordEnrolled = !string.IsNullOrEmpty(organizationUser.ResetPasswordKey); + UsesKeyConnector = organizationUser.UsesKeyConnector; } - Id = organizationUser.Id.ToString(); - UserId = organizationUser.UserId?.ToString(); - Type = organizationUser.Type; - Status = organizationUser.Status; - AccessAll = organizationUser.AccessAll; - Permissions = CoreHelpers.LoadClassFromJsonData(organizationUser.Permissions); - ResetPasswordEnrolled = !string.IsNullOrEmpty(organizationUser.ResetPasswordKey); - UsesKeyConnector = organizationUser.UsesKeyConnector; + public string Id { get; set; } + public string UserId { get; set; } + public OrganizationUserType Type { get; set; } + public OrganizationUserStatusType Status { get; set; } + public bool AccessAll { get; set; } + public Permissions Permissions { get; set; } + public bool ResetPasswordEnrolled { get; set; } + public bool UsesKeyConnector { get; set; } } - public string Id { get; set; } - public string UserId { get; set; } - public OrganizationUserType Type { get; set; } - public OrganizationUserStatusType Status { get; set; } - public bool AccessAll { get; set; } - public Permissions Permissions { get; set; } - public bool ResetPasswordEnrolled { get; set; } - public bool UsesKeyConnector { get; set; } -} - -public class OrganizationUserDetailsResponseModel : OrganizationUserResponseModel -{ - public OrganizationUserDetailsResponseModel(OrganizationUser organizationUser, - IEnumerable collections) - : base(organizationUser, "organizationUserDetails") + public class OrganizationUserDetailsResponseModel : OrganizationUserResponseModel { - Collections = collections.Select(c => new SelectionReadOnlyResponseModel(c)); - } - - public IEnumerable Collections { get; set; } -} - -public class OrganizationUserUserDetailsResponseModel : OrganizationUserResponseModel -{ - public OrganizationUserUserDetailsResponseModel(OrganizationUserUserDetails organizationUser, - bool twoFactorEnabled, string obj = "organizationUserUserDetails") - : base(organizationUser, obj) - { - if (organizationUser == null) + public OrganizationUserDetailsResponseModel(OrganizationUser organizationUser, + IEnumerable collections) + : base(organizationUser, "organizationUserDetails") { - throw new ArgumentNullException(nameof(organizationUser)); + Collections = collections.Select(c => new SelectionReadOnlyResponseModel(c)); } - Name = organizationUser.Name; - Email = organizationUser.Email; - TwoFactorEnabled = twoFactorEnabled; - SsoBound = !string.IsNullOrWhiteSpace(organizationUser.SsoExternalId); - // Prevent reset password when using key connector. - ResetPasswordEnrolled = ResetPasswordEnrolled && !organizationUser.UsesKeyConnector; + public IEnumerable Collections { get; set; } } - public string Name { get; set; } - public string Email { get; set; } - public bool TwoFactorEnabled { get; set; } - public bool SsoBound { get; set; } -} - -public class OrganizationUserResetPasswordDetailsResponseModel : ResponseModel -{ - public OrganizationUserResetPasswordDetailsResponseModel(OrganizationUserResetPasswordDetails orgUser, - string obj = "organizationUserResetPasswordDetails") : base(obj) + public class OrganizationUserUserDetailsResponseModel : OrganizationUserResponseModel { - if (orgUser == null) + public OrganizationUserUserDetailsResponseModel(OrganizationUserUserDetails organizationUser, + bool twoFactorEnabled, string obj = "organizationUserUserDetails") + : base(organizationUser, obj) { - throw new ArgumentNullException(nameof(orgUser)); + if (organizationUser == null) + { + throw new ArgumentNullException(nameof(organizationUser)); + } + + Name = organizationUser.Name; + Email = organizationUser.Email; + TwoFactorEnabled = twoFactorEnabled; + SsoBound = !string.IsNullOrWhiteSpace(organizationUser.SsoExternalId); + // Prevent reset password when using key connector. + ResetPasswordEnrolled = ResetPasswordEnrolled && !organizationUser.UsesKeyConnector; } - Kdf = orgUser.Kdf; - KdfIterations = orgUser.KdfIterations; - ResetPasswordKey = orgUser.ResetPasswordKey; - EncryptedPrivateKey = orgUser.EncryptedPrivateKey; + public string Name { get; set; } + public string Email { get; set; } + public bool TwoFactorEnabled { get; set; } + public bool SsoBound { get; set; } } - public KdfType Kdf { get; set; } - public int KdfIterations { get; set; } - public string ResetPasswordKey { get; set; } - public string EncryptedPrivateKey { get; set; } -} - -public class OrganizationUserPublicKeyResponseModel : ResponseModel -{ - public OrganizationUserPublicKeyResponseModel(Guid id, Guid userId, - string key, string obj = "organizationUserPublicKeyResponseModel") : - base(obj) + public class OrganizationUserResetPasswordDetailsResponseModel : ResponseModel { - Id = id; - UserId = userId; - Key = key; + public OrganizationUserResetPasswordDetailsResponseModel(OrganizationUserResetPasswordDetails orgUser, + string obj = "organizationUserResetPasswordDetails") : base(obj) + { + if (orgUser == null) + { + throw new ArgumentNullException(nameof(orgUser)); + } + + Kdf = orgUser.Kdf; + KdfIterations = orgUser.KdfIterations; + ResetPasswordKey = orgUser.ResetPasswordKey; + EncryptedPrivateKey = orgUser.EncryptedPrivateKey; + } + + public KdfType Kdf { get; set; } + public int KdfIterations { get; set; } + public string ResetPasswordKey { get; set; } + public string EncryptedPrivateKey { get; set; } } - public Guid Id { get; set; } - public Guid UserId { get; set; } - public string Key { get; set; } -} - -public class OrganizationUserBulkResponseModel : ResponseModel -{ - public OrganizationUserBulkResponseModel(Guid id, string error, - string obj = "OrganizationBulkConfirmResponseModel") : base(obj) + public class OrganizationUserPublicKeyResponseModel : ResponseModel { - Id = id; - Error = error; + public OrganizationUserPublicKeyResponseModel(Guid id, Guid userId, + string key, string obj = "organizationUserPublicKeyResponseModel") : + base(obj) + { + Id = id; + UserId = userId; + Key = key; + } + + public Guid Id { get; set; } + public Guid UserId { get; set; } + public string Key { get; set; } + } + + public class OrganizationUserBulkResponseModel : ResponseModel + { + public OrganizationUserBulkResponseModel(Guid id, string error, + string obj = "OrganizationBulkConfirmResponseModel") : base(obj) + { + Id = id; + Error = error; + } + public Guid Id { get; set; } + public string Error { get; set; } } - public Guid Id { get; set; } - public string Error { get; set; } } diff --git a/src/Api/Models/Response/PaymentResponseModel.cs b/src/Api/Models/Response/PaymentResponseModel.cs index 067ac969ec..43edb32164 100644 --- a/src/Api/Models/Response/PaymentResponseModel.cs +++ b/src/Api/Models/Response/PaymentResponseModel.cs @@ -1,14 +1,15 @@ using Bit.Core.Models.Api; -namespace Bit.Api.Models.Response; - -public class PaymentResponseModel : ResponseModel +namespace Bit.Api.Models.Response { - public PaymentResponseModel() - : base("payment") - { } + public class PaymentResponseModel : ResponseModel + { + public PaymentResponseModel() + : base("payment") + { } - public ProfileResponseModel UserProfile { get; set; } - public string PaymentIntentClientSecret { get; set; } - public bool Success { get; set; } + public ProfileResponseModel UserProfile { get; set; } + public string PaymentIntentClientSecret { get; set; } + public bool Success { get; set; } + } } diff --git a/src/Api/Models/Response/PlanResponseModel.cs b/src/Api/Models/Response/PlanResponseModel.cs index fd2934e735..5974772a8c 100644 --- a/src/Api/Models/Response/PlanResponseModel.cs +++ b/src/Api/Models/Response/PlanResponseModel.cs @@ -2,100 +2,101 @@ using Bit.Core.Models.Api; using Bit.Core.Models.StaticStore; -namespace Bit.Api.Models.Response; - -public class PlanResponseModel : ResponseModel +namespace Bit.Api.Models.Response { - public PlanResponseModel(Plan plan, string obj = "plan") - : base(obj) + public class PlanResponseModel : ResponseModel { - if (plan == null) + public PlanResponseModel(Plan plan, string obj = "plan") + : base(obj) { - throw new ArgumentNullException(nameof(plan)); + if (plan == null) + { + throw new ArgumentNullException(nameof(plan)); + } + + Type = plan.Type; + Product = plan.Product; + Name = plan.Name; + IsAnnual = plan.IsAnnual; + NameLocalizationKey = plan.NameLocalizationKey; + DescriptionLocalizationKey = plan.DescriptionLocalizationKey; + CanBeUsedByBusiness = plan.CanBeUsedByBusiness; + BaseSeats = plan.BaseSeats; + BaseStorageGb = plan.BaseStorageGb; + MaxCollections = plan.MaxCollections; + MaxUsers = plan.MaxUsers; + HasAdditionalSeatsOption = plan.HasAdditionalSeatsOption; + HasAdditionalStorageOption = plan.HasAdditionalStorageOption; + MaxAdditionalSeats = plan.MaxAdditionalSeats; + MaxAdditionalStorage = plan.MaxAdditionalStorage; + HasPremiumAccessOption = plan.HasPremiumAccessOption; + TrialPeriodDays = plan.TrialPeriodDays; + HasSelfHost = plan.HasSelfHost; + HasPolicies = plan.HasPolicies; + HasGroups = plan.HasGroups; + HasDirectory = plan.HasDirectory; + HasEvents = plan.HasEvents; + HasTotp = plan.HasTotp; + Has2fa = plan.Has2fa; + HasSso = plan.HasSso; + HasResetPassword = plan.HasResetPassword; + UsersGetPremium = plan.UsersGetPremium; + UpgradeSortOrder = plan.UpgradeSortOrder; + DisplaySortOrder = plan.DisplaySortOrder; + LegacyYear = plan.LegacyYear; + Disabled = plan.Disabled; + StripePlanId = plan.StripePlanId; + StripeSeatPlanId = plan.StripeSeatPlanId; + StripeStoragePlanId = plan.StripeStoragePlanId; + BasePrice = plan.BasePrice; + SeatPrice = plan.SeatPrice; + AdditionalStoragePricePerGb = plan.AdditionalStoragePricePerGb; + PremiumAccessOptionPrice = plan.PremiumAccessOptionPrice; } - Type = plan.Type; - Product = plan.Product; - Name = plan.Name; - IsAnnual = plan.IsAnnual; - NameLocalizationKey = plan.NameLocalizationKey; - DescriptionLocalizationKey = plan.DescriptionLocalizationKey; - CanBeUsedByBusiness = plan.CanBeUsedByBusiness; - BaseSeats = plan.BaseSeats; - BaseStorageGb = plan.BaseStorageGb; - MaxCollections = plan.MaxCollections; - MaxUsers = plan.MaxUsers; - HasAdditionalSeatsOption = plan.HasAdditionalSeatsOption; - HasAdditionalStorageOption = plan.HasAdditionalStorageOption; - MaxAdditionalSeats = plan.MaxAdditionalSeats; - MaxAdditionalStorage = plan.MaxAdditionalStorage; - HasPremiumAccessOption = plan.HasPremiumAccessOption; - TrialPeriodDays = plan.TrialPeriodDays; - HasSelfHost = plan.HasSelfHost; - HasPolicies = plan.HasPolicies; - HasGroups = plan.HasGroups; - HasDirectory = plan.HasDirectory; - HasEvents = plan.HasEvents; - HasTotp = plan.HasTotp; - Has2fa = plan.Has2fa; - HasSso = plan.HasSso; - HasResetPassword = plan.HasResetPassword; - UsersGetPremium = plan.UsersGetPremium; - UpgradeSortOrder = plan.UpgradeSortOrder; - DisplaySortOrder = plan.DisplaySortOrder; - LegacyYear = plan.LegacyYear; - Disabled = plan.Disabled; - StripePlanId = plan.StripePlanId; - StripeSeatPlanId = plan.StripeSeatPlanId; - StripeStoragePlanId = plan.StripeStoragePlanId; - BasePrice = plan.BasePrice; - SeatPrice = plan.SeatPrice; - AdditionalStoragePricePerGb = plan.AdditionalStoragePricePerGb; - PremiumAccessOptionPrice = plan.PremiumAccessOptionPrice; + public PlanType Type { get; set; } + public ProductType Product { get; set; } + public string Name { get; set; } + public bool IsAnnual { get; set; } + public string NameLocalizationKey { get; set; } + public string DescriptionLocalizationKey { get; set; } + public bool CanBeUsedByBusiness { get; set; } + public int BaseSeats { get; set; } + public short? BaseStorageGb { get; set; } + public short? MaxCollections { get; set; } + public short? MaxUsers { get; set; } + + public bool HasAdditionalSeatsOption { get; set; } + public int? MaxAdditionalSeats { get; set; } + public bool HasAdditionalStorageOption { get; set; } + public short? MaxAdditionalStorage { get; set; } + public bool HasPremiumAccessOption { get; set; } + public int? TrialPeriodDays { get; set; } + + public bool HasSelfHost { get; set; } + public bool HasPolicies { get; set; } + public bool HasGroups { get; set; } + public bool HasDirectory { get; set; } + public bool HasEvents { get; set; } + public bool HasTotp { get; set; } + public bool Has2fa { get; set; } + public bool HasApi { get; set; } + public bool HasSso { get; set; } + public bool HasResetPassword { get; set; } + public bool UsersGetPremium { get; set; } + + public int UpgradeSortOrder { get; set; } + public int DisplaySortOrder { get; set; } + public int? LegacyYear { get; set; } + public bool Disabled { get; set; } + + public string StripePlanId { get; set; } + public string StripeSeatPlanId { get; set; } + public string StripeStoragePlanId { get; set; } + public string StripePremiumAccessPlanId { get; set; } + public decimal BasePrice { get; set; } + public decimal SeatPrice { get; set; } + public decimal AdditionalStoragePricePerGb { get; set; } + public decimal PremiumAccessOptionPrice { get; set; } } - - public PlanType Type { get; set; } - public ProductType Product { get; set; } - public string Name { get; set; } - public bool IsAnnual { get; set; } - public string NameLocalizationKey { get; set; } - public string DescriptionLocalizationKey { get; set; } - public bool CanBeUsedByBusiness { get; set; } - public int BaseSeats { get; set; } - public short? BaseStorageGb { get; set; } - public short? MaxCollections { get; set; } - public short? MaxUsers { get; set; } - - public bool HasAdditionalSeatsOption { get; set; } - public int? MaxAdditionalSeats { get; set; } - public bool HasAdditionalStorageOption { get; set; } - public short? MaxAdditionalStorage { get; set; } - public bool HasPremiumAccessOption { get; set; } - public int? TrialPeriodDays { get; set; } - - public bool HasSelfHost { get; set; } - public bool HasPolicies { get; set; } - public bool HasGroups { get; set; } - public bool HasDirectory { get; set; } - public bool HasEvents { get; set; } - public bool HasTotp { get; set; } - public bool Has2fa { get; set; } - public bool HasApi { get; set; } - public bool HasSso { get; set; } - public bool HasResetPassword { get; set; } - public bool UsersGetPremium { get; set; } - - public int UpgradeSortOrder { get; set; } - public int DisplaySortOrder { get; set; } - public int? LegacyYear { get; set; } - public bool Disabled { get; set; } - - public string StripePlanId { get; set; } - public string StripeSeatPlanId { get; set; } - public string StripeStoragePlanId { get; set; } - public string StripePremiumAccessPlanId { get; set; } - public decimal BasePrice { get; set; } - public decimal SeatPrice { get; set; } - public decimal AdditionalStoragePricePerGb { get; set; } - public decimal PremiumAccessOptionPrice { get; set; } } diff --git a/src/Api/Models/Response/PolicyResponseModel.cs b/src/Api/Models/Response/PolicyResponseModel.cs index a812a911d0..7f725ba31b 100644 --- a/src/Api/Models/Response/PolicyResponseModel.cs +++ b/src/Api/Models/Response/PolicyResponseModel.cs @@ -3,31 +3,32 @@ using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Models.Api; -namespace Bit.Api.Models.Response; - -public class PolicyResponseModel : ResponseModel +namespace Bit.Api.Models.Response { - public PolicyResponseModel(Policy policy, string obj = "policy") - : base(obj) + public class PolicyResponseModel : ResponseModel { - if (policy == null) + public PolicyResponseModel(Policy policy, string obj = "policy") + : base(obj) { - throw new ArgumentNullException(nameof(policy)); + if (policy == null) + { + throw new ArgumentNullException(nameof(policy)); + } + + Id = policy.Id.ToString(); + OrganizationId = policy.OrganizationId.ToString(); + Type = policy.Type; + Enabled = policy.Enabled; + if (!string.IsNullOrWhiteSpace(policy.Data)) + { + Data = JsonSerializer.Deserialize>(policy.Data); + } } - Id = policy.Id.ToString(); - OrganizationId = policy.OrganizationId.ToString(); - Type = policy.Type; - Enabled = policy.Enabled; - if (!string.IsNullOrWhiteSpace(policy.Data)) - { - Data = JsonSerializer.Deserialize>(policy.Data); - } + public string Id { get; set; } + public string OrganizationId { get; set; } + public PolicyType Type { get; set; } + public Dictionary Data { get; set; } + public bool Enabled { get; set; } } - - public string Id { get; set; } - public string OrganizationId { get; set; } - public PolicyType Type { get; set; } - public Dictionary Data { get; set; } - public bool Enabled { get; set; } } diff --git a/src/Api/Models/Response/ProfileOrganizationResponseModel.cs b/src/Api/Models/Response/ProfileOrganizationResponseModel.cs index 4285ae432b..969dbbaf18 100644 --- a/src/Api/Models/Response/ProfileOrganizationResponseModel.cs +++ b/src/Api/Models/Response/ProfileOrganizationResponseModel.cs @@ -4,97 +4,98 @@ using Bit.Core.Models.Data; using Bit.Core.Models.Data.Organizations.OrganizationUsers; using Bit.Core.Utilities; -namespace Bit.Api.Models.Response; - -public class ProfileOrganizationResponseModel : ResponseModel +namespace Bit.Api.Models.Response { - public ProfileOrganizationResponseModel(string str) : base(str) { } - - public ProfileOrganizationResponseModel(OrganizationUserOrganizationDetails organization) : this("profileOrganization") + public class ProfileOrganizationResponseModel : ResponseModel { - Id = organization.OrganizationId.ToString(); - Name = organization.Name; - UsePolicies = organization.UsePolicies; - UseSso = organization.UseSso; - UseKeyConnector = organization.UseKeyConnector; - UseScim = organization.UseScim; - UseGroups = organization.UseGroups; - UseDirectory = organization.UseDirectory; - UseEvents = organization.UseEvents; - UseTotp = organization.UseTotp; - Use2fa = organization.Use2fa; - UseApi = organization.UseApi; - UseResetPassword = organization.UseResetPassword; - UsersGetPremium = organization.UsersGetPremium; - SelfHost = organization.SelfHost; - Seats = organization.Seats; - MaxCollections = organization.MaxCollections; - MaxStorageGb = organization.MaxStorageGb; - Key = organization.Key; - HasPublicAndPrivateKeys = organization.PublicKey != null && organization.PrivateKey != null; - Status = organization.Status; - Type = organization.Type; - Enabled = organization.Enabled; - SsoBound = !string.IsNullOrWhiteSpace(organization.SsoExternalId); - Identifier = organization.Identifier; - Permissions = CoreHelpers.LoadClassFromJsonData(organization.Permissions); - ResetPasswordEnrolled = organization.ResetPasswordKey != null; - UserId = organization.UserId?.ToString(); - ProviderId = organization.ProviderId?.ToString(); - ProviderName = organization.ProviderName; - FamilySponsorshipFriendlyName = organization.FamilySponsorshipFriendlyName; - FamilySponsorshipAvailable = FamilySponsorshipFriendlyName == null && - StaticStore.GetSponsoredPlan(PlanSponsorshipType.FamiliesForEnterprise) - .UsersCanSponsor(organization); - PlanProductType = StaticStore.GetPlan(organization.PlanType).Product; - FamilySponsorshipLastSyncDate = organization.FamilySponsorshipLastSyncDate; - FamilySponsorshipToDelete = organization.FamilySponsorshipToDelete; - FamilySponsorshipValidUntil = organization.FamilySponsorshipValidUntil; + public ProfileOrganizationResponseModel(string str) : base(str) { } - if (organization.SsoConfig != null) + public ProfileOrganizationResponseModel(OrganizationUserOrganizationDetails organization) : this("profileOrganization") { - var ssoConfigData = SsoConfigurationData.Deserialize(organization.SsoConfig); - KeyConnectorEnabled = ssoConfigData.KeyConnectorEnabled && !string.IsNullOrEmpty(ssoConfigData.KeyConnectorUrl); - KeyConnectorUrl = ssoConfigData.KeyConnectorUrl; - } - } + Id = organization.OrganizationId.ToString(); + Name = organization.Name; + UsePolicies = organization.UsePolicies; + UseSso = organization.UseSso; + UseKeyConnector = organization.UseKeyConnector; + UseScim = organization.UseScim; + UseGroups = organization.UseGroups; + UseDirectory = organization.UseDirectory; + UseEvents = organization.UseEvents; + UseTotp = organization.UseTotp; + Use2fa = organization.Use2fa; + UseApi = organization.UseApi; + UseResetPassword = organization.UseResetPassword; + UsersGetPremium = organization.UsersGetPremium; + SelfHost = organization.SelfHost; + Seats = organization.Seats; + MaxCollections = organization.MaxCollections; + MaxStorageGb = organization.MaxStorageGb; + Key = organization.Key; + HasPublicAndPrivateKeys = organization.PublicKey != null && organization.PrivateKey != null; + Status = organization.Status; + Type = organization.Type; + Enabled = organization.Enabled; + SsoBound = !string.IsNullOrWhiteSpace(organization.SsoExternalId); + Identifier = organization.Identifier; + Permissions = CoreHelpers.LoadClassFromJsonData(organization.Permissions); + ResetPasswordEnrolled = organization.ResetPasswordKey != null; + UserId = organization.UserId?.ToString(); + ProviderId = organization.ProviderId?.ToString(); + ProviderName = organization.ProviderName; + FamilySponsorshipFriendlyName = organization.FamilySponsorshipFriendlyName; + FamilySponsorshipAvailable = FamilySponsorshipFriendlyName == null && + StaticStore.GetSponsoredPlan(PlanSponsorshipType.FamiliesForEnterprise) + .UsersCanSponsor(organization); + PlanProductType = StaticStore.GetPlan(organization.PlanType).Product; + FamilySponsorshipLastSyncDate = organization.FamilySponsorshipLastSyncDate; + FamilySponsorshipToDelete = organization.FamilySponsorshipToDelete; + FamilySponsorshipValidUntil = organization.FamilySponsorshipValidUntil; - public string Id { get; set; } - public string Name { get; set; } - public bool UsePolicies { get; set; } - public bool UseSso { get; set; } - public bool UseKeyConnector { get; set; } - public bool UseScim { get; set; } - public bool UseGroups { get; set; } - public bool UseDirectory { get; set; } - public bool UseEvents { get; set; } - public bool UseTotp { get; set; } - public bool Use2fa { get; set; } - public bool UseApi { get; set; } - public bool UseResetPassword { get; set; } - public bool UsersGetPremium { get; set; } - public bool SelfHost { get; set; } - public int? Seats { get; set; } - public short? MaxCollections { get; set; } - public short? MaxStorageGb { get; set; } - public string Key { get; set; } - public OrganizationUserStatusType Status { get; set; } - public OrganizationUserType Type { get; set; } - public bool Enabled { get; set; } - public bool SsoBound { get; set; } - public string Identifier { get; set; } - public Permissions Permissions { get; set; } - public bool ResetPasswordEnrolled { get; set; } - public string UserId { get; set; } - public bool HasPublicAndPrivateKeys { get; set; } - public string ProviderId { get; set; } - public string ProviderName { get; set; } - public string FamilySponsorshipFriendlyName { get; set; } - public bool FamilySponsorshipAvailable { get; set; } - public ProductType PlanProductType { get; set; } - public bool KeyConnectorEnabled { get; set; } - public string KeyConnectorUrl { get; set; } - public DateTime? FamilySponsorshipLastSyncDate { get; set; } - public DateTime? FamilySponsorshipValidUntil { get; set; } - public bool? FamilySponsorshipToDelete { get; set; } + if (organization.SsoConfig != null) + { + var ssoConfigData = SsoConfigurationData.Deserialize(organization.SsoConfig); + KeyConnectorEnabled = ssoConfigData.KeyConnectorEnabled && !string.IsNullOrEmpty(ssoConfigData.KeyConnectorUrl); + KeyConnectorUrl = ssoConfigData.KeyConnectorUrl; + } + } + + public string Id { get; set; } + public string Name { get; set; } + public bool UsePolicies { get; set; } + public bool UseSso { get; set; } + public bool UseKeyConnector { get; set; } + public bool UseScim { get; set; } + public bool UseGroups { get; set; } + public bool UseDirectory { get; set; } + public bool UseEvents { get; set; } + public bool UseTotp { get; set; } + public bool Use2fa { get; set; } + public bool UseApi { get; set; } + public bool UseResetPassword { get; set; } + public bool UsersGetPremium { get; set; } + public bool SelfHost { get; set; } + public int? Seats { get; set; } + public short? MaxCollections { get; set; } + public short? MaxStorageGb { get; set; } + public string Key { get; set; } + public OrganizationUserStatusType Status { get; set; } + public OrganizationUserType Type { get; set; } + public bool Enabled { get; set; } + public bool SsoBound { get; set; } + public string Identifier { get; set; } + public Permissions Permissions { get; set; } + public bool ResetPasswordEnrolled { get; set; } + public string UserId { get; set; } + public bool HasPublicAndPrivateKeys { get; set; } + public string ProviderId { get; set; } + public string ProviderName { get; set; } + public string FamilySponsorshipFriendlyName { get; set; } + public bool FamilySponsorshipAvailable { get; set; } + public ProductType PlanProductType { get; set; } + public bool KeyConnectorEnabled { get; set; } + public string KeyConnectorUrl { get; set; } + public DateTime? FamilySponsorshipLastSyncDate { get; set; } + public DateTime? FamilySponsorshipValidUntil { get; set; } + public bool? FamilySponsorshipToDelete { get; set; } + } } diff --git a/src/Api/Models/Response/ProfileProviderOrganizationResponseModel.cs b/src/Api/Models/Response/ProfileProviderOrganizationResponseModel.cs index a660662fac..c2d7858b5b 100644 --- a/src/Api/Models/Response/ProfileProviderOrganizationResponseModel.cs +++ b/src/Api/Models/Response/ProfileProviderOrganizationResponseModel.cs @@ -1,42 +1,43 @@ using Bit.Core.Enums; using Bit.Core.Models.Data; -namespace Bit.Api.Models.Response; - -public class ProfileProviderOrganizationResponseModel : ProfileOrganizationResponseModel +namespace Bit.Api.Models.Response { - public ProfileProviderOrganizationResponseModel(ProviderUserOrganizationDetails organization) - : base("profileProviderOrganization") + public class ProfileProviderOrganizationResponseModel : ProfileOrganizationResponseModel { - Id = organization.OrganizationId.ToString(); - Name = organization.Name; - UsePolicies = organization.UsePolicies; - UseSso = organization.UseSso; - UseKeyConnector = organization.UseKeyConnector; - UseScim = organization.UseScim; - UseGroups = organization.UseGroups; - UseDirectory = organization.UseDirectory; - UseEvents = organization.UseEvents; - UseTotp = organization.UseTotp; - Use2fa = organization.Use2fa; - UseApi = organization.UseApi; - UseResetPassword = organization.UseResetPassword; - UsersGetPremium = organization.UsersGetPremium; - SelfHost = organization.SelfHost; - Seats = organization.Seats; - MaxCollections = organization.MaxCollections; - MaxStorageGb = organization.MaxStorageGb; - Key = organization.Key; - HasPublicAndPrivateKeys = organization.PublicKey != null && organization.PrivateKey != null; - Status = OrganizationUserStatusType.Confirmed; // Provider users are always confirmed - Type = OrganizationUserType.Owner; // Provider users behave like Owners - Enabled = organization.Enabled; - SsoBound = false; - Identifier = organization.Identifier; - Permissions = new Permissions(); - ResetPasswordEnrolled = false; - UserId = organization.UserId?.ToString(); - ProviderId = organization.ProviderId?.ToString(); - ProviderName = organization.ProviderName; + public ProfileProviderOrganizationResponseModel(ProviderUserOrganizationDetails organization) + : base("profileProviderOrganization") + { + Id = organization.OrganizationId.ToString(); + Name = organization.Name; + UsePolicies = organization.UsePolicies; + UseSso = organization.UseSso; + UseKeyConnector = organization.UseKeyConnector; + UseScim = organization.UseScim; + UseGroups = organization.UseGroups; + UseDirectory = organization.UseDirectory; + UseEvents = organization.UseEvents; + UseTotp = organization.UseTotp; + Use2fa = organization.Use2fa; + UseApi = organization.UseApi; + UseResetPassword = organization.UseResetPassword; + UsersGetPremium = organization.UsersGetPremium; + SelfHost = organization.SelfHost; + Seats = organization.Seats; + MaxCollections = organization.MaxCollections; + MaxStorageGb = organization.MaxStorageGb; + Key = organization.Key; + HasPublicAndPrivateKeys = organization.PublicKey != null && organization.PrivateKey != null; + Status = OrganizationUserStatusType.Confirmed; // Provider users are always confirmed + Type = OrganizationUserType.Owner; // Provider users behave like Owners + Enabled = organization.Enabled; + SsoBound = false; + Identifier = organization.Identifier; + Permissions = new Permissions(); + ResetPasswordEnrolled = false; + UserId = organization.UserId?.ToString(); + ProviderId = organization.ProviderId?.ToString(); + ProviderName = organization.ProviderName; + } } } diff --git a/src/Api/Models/Response/ProfileResponseModel.cs b/src/Api/Models/Response/ProfileResponseModel.cs index dfa9e5dac4..42e9943c46 100644 --- a/src/Api/Models/Response/ProfileResponseModel.cs +++ b/src/Api/Models/Response/ProfileResponseModel.cs @@ -4,61 +4,62 @@ using Bit.Core.Models.Api; using Bit.Core.Models.Data; using Bit.Core.Models.Data.Organizations.OrganizationUsers; -namespace Bit.Api.Models.Response; - -public class ProfileResponseModel : ResponseModel +namespace Bit.Api.Models.Response { - public ProfileResponseModel(User user, - IEnumerable organizationsUserDetails, - IEnumerable providerUserDetails, - IEnumerable providerUserOrganizationDetails, - bool twoFactorEnabled, - bool premiumFromOrganization) : base("profile") + public class ProfileResponseModel : ResponseModel { - if (user == null) + public ProfileResponseModel(User user, + IEnumerable organizationsUserDetails, + IEnumerable providerUserDetails, + IEnumerable providerUserOrganizationDetails, + bool twoFactorEnabled, + bool premiumFromOrganization) : base("profile") { - throw new ArgumentNullException(nameof(user)); + if (user == null) + { + throw new ArgumentNullException(nameof(user)); + } + + Id = user.Id.ToString(); + Name = user.Name; + Email = user.Email; + EmailVerified = user.EmailVerified; + Premium = user.Premium; + PremiumFromOrganization = premiumFromOrganization; + MasterPasswordHint = string.IsNullOrWhiteSpace(user.MasterPasswordHint) ? null : user.MasterPasswordHint; + Culture = user.Culture; + TwoFactorEnabled = twoFactorEnabled; + Key = user.Key; + PrivateKey = user.PrivateKey; + SecurityStamp = user.SecurityStamp; + ForcePasswordReset = user.ForcePasswordReset; + UsesKeyConnector = user.UsesKeyConnector; + Organizations = organizationsUserDetails?.Select(o => new ProfileOrganizationResponseModel(o)); + Providers = providerUserDetails?.Select(p => new ProfileProviderResponseModel(p)); + ProviderOrganizations = + providerUserOrganizationDetails?.Select(po => new ProfileProviderOrganizationResponseModel(po)); } - Id = user.Id.ToString(); - Name = user.Name; - Email = user.Email; - EmailVerified = user.EmailVerified; - Premium = user.Premium; - PremiumFromOrganization = premiumFromOrganization; - MasterPasswordHint = string.IsNullOrWhiteSpace(user.MasterPasswordHint) ? null : user.MasterPasswordHint; - Culture = user.Culture; - TwoFactorEnabled = twoFactorEnabled; - Key = user.Key; - PrivateKey = user.PrivateKey; - SecurityStamp = user.SecurityStamp; - ForcePasswordReset = user.ForcePasswordReset; - UsesKeyConnector = user.UsesKeyConnector; - Organizations = organizationsUserDetails?.Select(o => new ProfileOrganizationResponseModel(o)); - Providers = providerUserDetails?.Select(p => new ProfileProviderResponseModel(p)); - ProviderOrganizations = - providerUserOrganizationDetails?.Select(po => new ProfileProviderOrganizationResponseModel(po)); - } + public ProfileResponseModel() : base("profile") + { + } - public ProfileResponseModel() : base("profile") - { + public string Id { get; set; } + public string Name { get; set; } + public string Email { get; set; } + public bool EmailVerified { get; set; } + public bool Premium { get; set; } + public bool PremiumFromOrganization { get; set; } + public string MasterPasswordHint { get; set; } + public string Culture { get; set; } + public bool TwoFactorEnabled { get; set; } + public string Key { get; set; } + public string PrivateKey { get; set; } + public string SecurityStamp { get; set; } + public bool ForcePasswordReset { get; set; } + public bool UsesKeyConnector { get; set; } + public IEnumerable Organizations { get; set; } + public IEnumerable Providers { get; set; } + public IEnumerable ProviderOrganizations { get; set; } } - - public string Id { get; set; } - public string Name { get; set; } - public string Email { get; set; } - public bool EmailVerified { get; set; } - public bool Premium { get; set; } - public bool PremiumFromOrganization { get; set; } - public string MasterPasswordHint { get; set; } - public string Culture { get; set; } - public bool TwoFactorEnabled { get; set; } - public string Key { get; set; } - public string PrivateKey { get; set; } - public string SecurityStamp { get; set; } - public bool ForcePasswordReset { get; set; } - public bool UsesKeyConnector { get; set; } - public IEnumerable Organizations { get; set; } - public IEnumerable Providers { get; set; } - public IEnumerable ProviderOrganizations { get; set; } } diff --git a/src/Api/Models/Response/Providers/ProfileProviderResponseModel.cs b/src/Api/Models/Response/Providers/ProfileProviderResponseModel.cs index 7a218d1c78..c8a0c38182 100644 --- a/src/Api/Models/Response/Providers/ProfileProviderResponseModel.cs +++ b/src/Api/Models/Response/Providers/ProfileProviderResponseModel.cs @@ -3,31 +3,32 @@ using Bit.Core.Models.Api; using Bit.Core.Models.Data; using Bit.Core.Utilities; -namespace Bit.Api.Models.Response.Providers; - -public class ProfileProviderResponseModel : ResponseModel +namespace Bit.Api.Models.Response.Providers { - public ProfileProviderResponseModel(ProviderUserProviderDetails provider) - : base("profileProvider") + public class ProfileProviderResponseModel : ResponseModel { - Id = provider.ProviderId.ToString(); - Name = provider.Name; - Key = provider.Key; - Status = provider.Status; - Type = provider.Type; - Enabled = provider.Enabled; - Permissions = CoreHelpers.LoadClassFromJsonData(provider.Permissions); - UserId = provider.UserId?.ToString(); - UseEvents = provider.UseEvents; - } + public ProfileProviderResponseModel(ProviderUserProviderDetails provider) + : base("profileProvider") + { + Id = provider.ProviderId.ToString(); + Name = provider.Name; + Key = provider.Key; + Status = provider.Status; + Type = provider.Type; + Enabled = provider.Enabled; + Permissions = CoreHelpers.LoadClassFromJsonData(provider.Permissions); + UserId = provider.UserId?.ToString(); + UseEvents = provider.UseEvents; + } - public string Id { get; set; } - public string Name { get; set; } - public string Key { get; set; } - public ProviderUserStatusType Status { get; set; } - public ProviderUserType Type { get; set; } - public bool Enabled { get; set; } - public Permissions Permissions { get; set; } - public string UserId { get; set; } - public bool UseEvents { get; set; } + public string Id { get; set; } + public string Name { get; set; } + public string Key { get; set; } + public ProviderUserStatusType Status { get; set; } + public ProviderUserType Type { get; set; } + public bool Enabled { get; set; } + public Permissions Permissions { get; set; } + public string UserId { get; set; } + public bool UseEvents { get; set; } + } } diff --git a/src/Api/Models/Response/Providers/ProviderOrganizationResponseModel.cs b/src/Api/Models/Response/Providers/ProviderOrganizationResponseModel.cs index 9bc7d52dc6..e508787a08 100644 --- a/src/Api/Models/Response/Providers/ProviderOrganizationResponseModel.cs +++ b/src/Api/Models/Response/Providers/ProviderOrganizationResponseModel.cs @@ -2,71 +2,72 @@ using Bit.Core.Models.Api; using Bit.Core.Models.Data; -namespace Bit.Api.Models.Response.Providers; - -public class ProviderOrganizationResponseModel : ResponseModel +namespace Bit.Api.Models.Response.Providers { - public ProviderOrganizationResponseModel(ProviderOrganization providerOrganization, - string obj = "providerOrganization") : base(obj) + public class ProviderOrganizationResponseModel : ResponseModel { - if (providerOrganization == null) + public ProviderOrganizationResponseModel(ProviderOrganization providerOrganization, + string obj = "providerOrganization") : base(obj) { - throw new ArgumentNullException(nameof(providerOrganization)); + if (providerOrganization == null) + { + throw new ArgumentNullException(nameof(providerOrganization)); + } + + Id = providerOrganization.Id; + ProviderId = providerOrganization.ProviderId; + OrganizationId = providerOrganization.OrganizationId; + Key = providerOrganization.Key; + Settings = providerOrganization.Settings; + CreationDate = providerOrganization.CreationDate; + RevisionDate = providerOrganization.RevisionDate; } - Id = providerOrganization.Id; - ProviderId = providerOrganization.ProviderId; - OrganizationId = providerOrganization.OrganizationId; - Key = providerOrganization.Key; - Settings = providerOrganization.Settings; - CreationDate = providerOrganization.CreationDate; - RevisionDate = providerOrganization.RevisionDate; - } - - public ProviderOrganizationResponseModel(ProviderOrganizationOrganizationDetails providerOrganization, - string obj = "providerOrganization") : base(obj) - { - if (providerOrganization == null) + public ProviderOrganizationResponseModel(ProviderOrganizationOrganizationDetails providerOrganization, + string obj = "providerOrganization") : base(obj) { - throw new ArgumentNullException(nameof(providerOrganization)); + if (providerOrganization == null) + { + throw new ArgumentNullException(nameof(providerOrganization)); + } + + Id = providerOrganization.Id; + ProviderId = providerOrganization.ProviderId; + OrganizationId = providerOrganization.OrganizationId; + Key = providerOrganization.Key; + Settings = providerOrganization.Settings; + CreationDate = providerOrganization.CreationDate; + RevisionDate = providerOrganization.RevisionDate; + UserCount = providerOrganization.UserCount; + Seats = providerOrganization.Seats; + Plan = providerOrganization.Plan; } - Id = providerOrganization.Id; - ProviderId = providerOrganization.ProviderId; - OrganizationId = providerOrganization.OrganizationId; - Key = providerOrganization.Key; - Settings = providerOrganization.Settings; - CreationDate = providerOrganization.CreationDate; - RevisionDate = providerOrganization.RevisionDate; - UserCount = providerOrganization.UserCount; - Seats = providerOrganization.Seats; - Plan = providerOrganization.Plan; + public Guid Id { get; set; } + public Guid ProviderId { get; set; } + public Guid OrganizationId { get; set; } + public string Key { get; set; } + public string Settings { get; set; } + public DateTime CreationDate { get; set; } + public DateTime RevisionDate { get; set; } + public int UserCount { get; set; } + public int? Seats { get; set; } + public string Plan { get; set; } } - public Guid Id { get; set; } - public Guid ProviderId { get; set; } - public Guid OrganizationId { get; set; } - public string Key { get; set; } - public string Settings { get; set; } - public DateTime CreationDate { get; set; } - public DateTime RevisionDate { get; set; } - public int UserCount { get; set; } - public int? Seats { get; set; } - public string Plan { get; set; } -} - -public class ProviderOrganizationOrganizationDetailsResponseModel : ProviderOrganizationResponseModel -{ - public ProviderOrganizationOrganizationDetailsResponseModel(ProviderOrganizationOrganizationDetails providerOrganization, - string obj = "providerOrganizationOrganizationDetail") : base(providerOrganization, obj) - { - if (providerOrganization == null) - { - throw new ArgumentNullException(nameof(providerOrganization)); - } - - OrganizationName = providerOrganization.OrganizationName; - } - - public string OrganizationName { get; set; } + public class ProviderOrganizationOrganizationDetailsResponseModel : ProviderOrganizationResponseModel + { + public ProviderOrganizationOrganizationDetailsResponseModel(ProviderOrganizationOrganizationDetails providerOrganization, + string obj = "providerOrganizationOrganizationDetail") : base(providerOrganization, obj) + { + if (providerOrganization == null) + { + throw new ArgumentNullException(nameof(providerOrganization)); + } + + OrganizationName = providerOrganization.OrganizationName; + } + + public string OrganizationName { get; set; } + } } diff --git a/src/Api/Models/Response/Providers/ProviderResponseModel.cs b/src/Api/Models/Response/Providers/ProviderResponseModel.cs index ce62fdaa68..02cea09d14 100644 --- a/src/Api/Models/Response/Providers/ProviderResponseModel.cs +++ b/src/Api/Models/Response/Providers/ProviderResponseModel.cs @@ -1,35 +1,36 @@ using Bit.Core.Entities.Provider; using Bit.Core.Models.Api; -namespace Bit.Api.Models.Response.Providers; - -public class ProviderResponseModel : ResponseModel +namespace Bit.Api.Models.Response.Providers { - public ProviderResponseModel(Provider provider, string obj = "provider") : base(obj) + public class ProviderResponseModel : ResponseModel { - if (provider == null) + public ProviderResponseModel(Provider provider, string obj = "provider") : base(obj) { - throw new ArgumentNullException(nameof(provider)); + if (provider == null) + { + throw new ArgumentNullException(nameof(provider)); + } + + Id = provider.Id; + Name = provider.Name; + BusinessName = provider.BusinessName; + BusinessAddress1 = provider.BusinessAddress1; + BusinessAddress2 = provider.BusinessAddress2; + BusinessAddress3 = provider.BusinessAddress3; + BusinessCountry = provider.BusinessCountry; + BusinessTaxNumber = provider.BusinessTaxNumber; + BillingEmail = provider.BillingEmail; } - Id = provider.Id; - Name = provider.Name; - BusinessName = provider.BusinessName; - BusinessAddress1 = provider.BusinessAddress1; - BusinessAddress2 = provider.BusinessAddress2; - BusinessAddress3 = provider.BusinessAddress3; - BusinessCountry = provider.BusinessCountry; - BusinessTaxNumber = provider.BusinessTaxNumber; - BillingEmail = provider.BillingEmail; + public Guid Id { get; set; } + public string Name { get; set; } + public string BusinessName { get; set; } + public string BusinessAddress1 { get; set; } + public string BusinessAddress2 { get; set; } + public string BusinessAddress3 { get; set; } + public string BusinessCountry { get; set; } + public string BusinessTaxNumber { get; set; } + public string BillingEmail { get; set; } } - - public Guid Id { get; set; } - public string Name { get; set; } - public string BusinessName { get; set; } - public string BusinessAddress1 { get; set; } - public string BusinessAddress2 { get; set; } - public string BusinessAddress3 { get; set; } - public string BusinessCountry { get; set; } - public string BusinessTaxNumber { get; set; } - public string BillingEmail { get; set; } } diff --git a/src/Api/Models/Response/Providers/ProviderUserResponseModel.cs b/src/Api/Models/Response/Providers/ProviderUserResponseModel.cs index b08e39e198..44122b2b07 100644 --- a/src/Api/Models/Response/Providers/ProviderUserResponseModel.cs +++ b/src/Api/Models/Response/Providers/ProviderUserResponseModel.cs @@ -4,88 +4,89 @@ using Bit.Core.Models.Api; using Bit.Core.Models.Data; using Bit.Core.Utilities; -namespace Bit.Api.Models.Response.Providers; - -public class ProviderUserResponseModel : ResponseModel +namespace Bit.Api.Models.Response.Providers { - public ProviderUserResponseModel(ProviderUser providerUser, string obj = "providerUser") - : base(obj) + public class ProviderUserResponseModel : ResponseModel { - if (providerUser == null) + public ProviderUserResponseModel(ProviderUser providerUser, string obj = "providerUser") + : base(obj) { - throw new ArgumentNullException(nameof(providerUser)); + if (providerUser == null) + { + throw new ArgumentNullException(nameof(providerUser)); + } + + Id = providerUser.Id.ToString(); + UserId = providerUser.UserId?.ToString(); + Type = providerUser.Type; + Status = providerUser.Status; + Permissions = CoreHelpers.LoadClassFromJsonData(providerUser.Permissions); } - Id = providerUser.Id.ToString(); - UserId = providerUser.UserId?.ToString(); - Type = providerUser.Type; - Status = providerUser.Status; - Permissions = CoreHelpers.LoadClassFromJsonData(providerUser.Permissions); - } - - public ProviderUserResponseModel(ProviderUserUserDetails providerUser, string obj = "providerUser") - : base(obj) - { - if (providerUser == null) + public ProviderUserResponseModel(ProviderUserUserDetails providerUser, string obj = "providerUser") + : base(obj) { - throw new ArgumentNullException(nameof(providerUser)); + if (providerUser == null) + { + throw new ArgumentNullException(nameof(providerUser)); + } + + Id = providerUser.Id.ToString(); + UserId = providerUser.UserId?.ToString(); + Type = providerUser.Type; + Status = providerUser.Status; + Permissions = CoreHelpers.LoadClassFromJsonData(providerUser.Permissions); } - Id = providerUser.Id.ToString(); - UserId = providerUser.UserId?.ToString(); - Type = providerUser.Type; - Status = providerUser.Status; - Permissions = CoreHelpers.LoadClassFromJsonData(providerUser.Permissions); + public string Id { get; set; } + public string UserId { get; set; } + public ProviderUserType Type { get; set; } + public ProviderUserStatusType Status { get; set; } + public Permissions Permissions { get; set; } } - public string Id { get; set; } - public string UserId { get; set; } - public ProviderUserType Type { get; set; } - public ProviderUserStatusType Status { get; set; } - public Permissions Permissions { get; set; } -} - -public class ProviderUserUserDetailsResponseModel : ProviderUserResponseModel -{ - public ProviderUserUserDetailsResponseModel(ProviderUserUserDetails providerUser, - string obj = "providerUserUserDetails") : base(providerUser, obj) + public class ProviderUserUserDetailsResponseModel : ProviderUserResponseModel { - if (providerUser == null) + public ProviderUserUserDetailsResponseModel(ProviderUserUserDetails providerUser, + string obj = "providerUserUserDetails") : base(providerUser, obj) { - throw new ArgumentNullException(nameof(providerUser)); + if (providerUser == null) + { + throw new ArgumentNullException(nameof(providerUser)); + } + + Name = providerUser.Name; + Email = providerUser.Email; } - Name = providerUser.Name; - Email = providerUser.Email; + public string Name { get; set; } + public string Email { get; set; } } - public string Name { get; set; } - public string Email { get; set; } -} - -public class ProviderUserPublicKeyResponseModel : ResponseModel -{ - public ProviderUserPublicKeyResponseModel(Guid id, Guid userId, string key, - string obj = "providerUserPublicKeyResponseModel") : base(obj) + public class ProviderUserPublicKeyResponseModel : ResponseModel { - Id = id; - UserId = userId; - Key = key; + public ProviderUserPublicKeyResponseModel(Guid id, Guid userId, string key, + string obj = "providerUserPublicKeyResponseModel") : base(obj) + { + Id = id; + UserId = userId; + Key = key; + } + + public Guid Id { get; set; } + public Guid UserId { get; set; } + public string Key { get; set; } } - public Guid Id { get; set; } - public Guid UserId { get; set; } - public string Key { get; set; } -} - -public class ProviderUserBulkResponseModel : ResponseModel -{ - public ProviderUserBulkResponseModel(Guid id, string error, - string obj = "providerBulkConfirmResponseModel") : base(obj) + public class ProviderUserBulkResponseModel : ResponseModel { - Id = id; - Error = error; + public ProviderUserBulkResponseModel(Guid id, string error, + string obj = "providerBulkConfirmResponseModel") : base(obj) + { + Id = id; + Error = error; + } + public Guid Id { get; set; } + public string Error { get; set; } } - public Guid Id { get; set; } - public string Error { get; set; } } diff --git a/src/Api/Models/Response/SelectionReadOnlyResponseModel.cs b/src/Api/Models/Response/SelectionReadOnlyResponseModel.cs index 0d4cc637d1..a3ff0ddf63 100644 --- a/src/Api/Models/Response/SelectionReadOnlyResponseModel.cs +++ b/src/Api/Models/Response/SelectionReadOnlyResponseModel.cs @@ -1,22 +1,23 @@ using Bit.Core.Models.Data; -namespace Bit.Api.Models.Response; - -public class SelectionReadOnlyResponseModel +namespace Bit.Api.Models.Response { - public SelectionReadOnlyResponseModel(SelectionReadOnly selection) + public class SelectionReadOnlyResponseModel { - if (selection == null) + public SelectionReadOnlyResponseModel(SelectionReadOnly selection) { - throw new ArgumentNullException(nameof(selection)); + if (selection == null) + { + throw new ArgumentNullException(nameof(selection)); + } + + Id = selection.Id.ToString(); + ReadOnly = selection.ReadOnly; + HidePasswords = selection.HidePasswords; } - Id = selection.Id.ToString(); - ReadOnly = selection.ReadOnly; - HidePasswords = selection.HidePasswords; + public string Id { get; set; } + public bool ReadOnly { get; set; } + public bool HidePasswords { get; set; } } - - public string Id { get; set; } - public bool ReadOnly { get; set; } - public bool HidePasswords { get; set; } } diff --git a/src/Api/Models/Response/SendAccessResponseModel.cs b/src/Api/Models/Response/SendAccessResponseModel.cs index d4620385b8..7e2adc04f1 100644 --- a/src/Api/Models/Response/SendAccessResponseModel.cs +++ b/src/Api/Models/Response/SendAccessResponseModel.cs @@ -6,47 +6,48 @@ using Bit.Core.Models.Data; using Bit.Core.Settings; using Bit.Core.Utilities; -namespace Bit.Api.Models.Response; - -public class SendAccessResponseModel : ResponseModel +namespace Bit.Api.Models.Response { - public SendAccessResponseModel(Send send, GlobalSettings globalSettings) - : base("send-access") + public class SendAccessResponseModel : ResponseModel { - if (send == null) + public SendAccessResponseModel(Send send, GlobalSettings globalSettings) + : base("send-access") { - throw new ArgumentNullException(nameof(send)); + if (send == null) + { + throw new ArgumentNullException(nameof(send)); + } + + Id = CoreHelpers.Base64UrlEncode(send.Id.ToByteArray()); + Type = send.Type; + + SendData sendData; + switch (send.Type) + { + case SendType.File: + var fileData = JsonSerializer.Deserialize(send.Data); + sendData = fileData; + File = new SendFileModel(fileData); + break; + case SendType.Text: + var textData = JsonSerializer.Deserialize(send.Data); + sendData = textData; + Text = new SendTextModel(textData); + break; + default: + throw new ArgumentException("Unsupported " + nameof(Type) + "."); + } + + Name = sendData.Name; + ExpirationDate = send.ExpirationDate; } - Id = CoreHelpers.Base64UrlEncode(send.Id.ToByteArray()); - Type = send.Type; - - SendData sendData; - switch (send.Type) - { - case SendType.File: - var fileData = JsonSerializer.Deserialize(send.Data); - sendData = fileData; - File = new SendFileModel(fileData); - break; - case SendType.Text: - var textData = JsonSerializer.Deserialize(send.Data); - sendData = textData; - Text = new SendTextModel(textData); - break; - default: - throw new ArgumentException("Unsupported " + nameof(Type) + "."); - } - - Name = sendData.Name; - ExpirationDate = send.ExpirationDate; + public string Id { get; set; } + public SendType Type { get; set; } + public string Name { get; set; } + public SendFileModel File { get; set; } + public SendTextModel Text { get; set; } + public DateTime? ExpirationDate { get; set; } + public string CreatorIdentifier { get; set; } } - - public string Id { get; set; } - public SendType Type { get; set; } - public string Name { get; set; } - public SendFileModel File { get; set; } - public SendTextModel Text { get; set; } - public DateTime? ExpirationDate { get; set; } - public string CreatorIdentifier { get; set; } } diff --git a/src/Api/Models/Response/SendFileDownloadDataResponseModel.cs b/src/Api/Models/Response/SendFileDownloadDataResponseModel.cs index 24e3a53f74..e8efed8a49 100644 --- a/src/Api/Models/Response/SendFileDownloadDataResponseModel.cs +++ b/src/Api/Models/Response/SendFileDownloadDataResponseModel.cs @@ -1,11 +1,12 @@ using Bit.Core.Models.Api; -namespace Bit.Api.Models.Response; - -public class SendFileDownloadDataResponseModel : ResponseModel +namespace Bit.Api.Models.Response { - public string Id { get; set; } - public string Url { get; set; } + public class SendFileDownloadDataResponseModel : ResponseModel + { + public string Id { get; set; } + public string Url { get; set; } - public SendFileDownloadDataResponseModel() : base("send-fileDownload") { } + public SendFileDownloadDataResponseModel() : base("send-fileDownload") { } + } } diff --git a/src/Api/Models/Response/SendFileUploadDataResponseModel.cs b/src/Api/Models/Response/SendFileUploadDataResponseModel.cs index 0e7b4997c9..20e3694fe7 100644 --- a/src/Api/Models/Response/SendFileUploadDataResponseModel.cs +++ b/src/Api/Models/Response/SendFileUploadDataResponseModel.cs @@ -1,14 +1,15 @@ using Bit.Core.Enums; using Bit.Core.Models.Api; -namespace Bit.Api.Models.Response; - -public class SendFileUploadDataResponseModel : ResponseModel +namespace Bit.Api.Models.Response { - public SendFileUploadDataResponseModel() : base("send-fileUpload") { } + public class SendFileUploadDataResponseModel : ResponseModel + { + public SendFileUploadDataResponseModel() : base("send-fileUpload") { } - public string Url { get; set; } - public FileUploadType FileUploadType { get; set; } - public SendResponseModel SendResponse { get; set; } + public string Url { get; set; } + public FileUploadType FileUploadType { get; set; } + public SendResponseModel SendResponse { get; set; } + } } diff --git a/src/Api/Models/Response/SendResponseModel.cs b/src/Api/Models/Response/SendResponseModel.cs index 42552d2a4b..c4f88157d3 100644 --- a/src/Api/Models/Response/SendResponseModel.cs +++ b/src/Api/Models/Response/SendResponseModel.cs @@ -6,66 +6,67 @@ using Bit.Core.Models.Data; using Bit.Core.Settings; using Bit.Core.Utilities; -namespace Bit.Api.Models.Response; - -public class SendResponseModel : ResponseModel +namespace Bit.Api.Models.Response { - public SendResponseModel(Send send, GlobalSettings globalSettings) - : base("send") + public class SendResponseModel : ResponseModel { - if (send == null) + public SendResponseModel(Send send, GlobalSettings globalSettings) + : base("send") { - throw new ArgumentNullException(nameof(send)); + if (send == null) + { + throw new ArgumentNullException(nameof(send)); + } + + Id = send.Id.ToString(); + AccessId = CoreHelpers.Base64UrlEncode(send.Id.ToByteArray()); + Type = send.Type; + Key = send.Key; + MaxAccessCount = send.MaxAccessCount; + AccessCount = send.AccessCount; + RevisionDate = send.RevisionDate; + ExpirationDate = send.ExpirationDate; + DeletionDate = send.DeletionDate; + Password = send.Password; + Disabled = send.Disabled; + HideEmail = send.HideEmail.GetValueOrDefault(); + + SendData sendData; + switch (send.Type) + { + case SendType.File: + var fileData = JsonSerializer.Deserialize(send.Data); + sendData = fileData; + File = new SendFileModel(fileData); + break; + case SendType.Text: + var textData = JsonSerializer.Deserialize(send.Data); + sendData = textData; + Text = new SendTextModel(textData); + break; + default: + throw new ArgumentException("Unsupported " + nameof(Type) + "."); + } + + Name = sendData.Name; + Notes = sendData.Notes; } - Id = send.Id.ToString(); - AccessId = CoreHelpers.Base64UrlEncode(send.Id.ToByteArray()); - Type = send.Type; - Key = send.Key; - MaxAccessCount = send.MaxAccessCount; - AccessCount = send.AccessCount; - RevisionDate = send.RevisionDate; - ExpirationDate = send.ExpirationDate; - DeletionDate = send.DeletionDate; - Password = send.Password; - Disabled = send.Disabled; - HideEmail = send.HideEmail.GetValueOrDefault(); - - SendData sendData; - switch (send.Type) - { - case SendType.File: - var fileData = JsonSerializer.Deserialize(send.Data); - sendData = fileData; - File = new SendFileModel(fileData); - break; - case SendType.Text: - var textData = JsonSerializer.Deserialize(send.Data); - sendData = textData; - Text = new SendTextModel(textData); - break; - default: - throw new ArgumentException("Unsupported " + nameof(Type) + "."); - } - - Name = sendData.Name; - Notes = sendData.Notes; + public string Id { get; set; } + public string AccessId { get; set; } + public SendType Type { get; set; } + public string Name { get; set; } + public string Notes { get; set; } + public SendFileModel File { get; set; } + public SendTextModel Text { get; set; } + public string Key { get; set; } + public int? MaxAccessCount { get; set; } + public int AccessCount { get; set; } + public string Password { get; set; } + public bool Disabled { get; set; } + public DateTime RevisionDate { get; set; } + public DateTime? ExpirationDate { get; set; } + public DateTime DeletionDate { get; set; } + public bool HideEmail { get; set; } } - - public string Id { get; set; } - public string AccessId { get; set; } - public SendType Type { get; set; } - public string Name { get; set; } - public string Notes { get; set; } - public SendFileModel File { get; set; } - public SendTextModel Text { get; set; } - public string Key { get; set; } - public int? MaxAccessCount { get; set; } - public int AccessCount { get; set; } - public string Password { get; set; } - public bool Disabled { get; set; } - public DateTime RevisionDate { get; set; } - public DateTime? ExpirationDate { get; set; } - public DateTime DeletionDate { get; set; } - public bool HideEmail { get; set; } } diff --git a/src/Api/Models/Response/SubscriptionResponseModel.cs b/src/Api/Models/Response/SubscriptionResponseModel.cs index 4888bd2080..e8b9dbbcb0 100644 --- a/src/Api/Models/Response/SubscriptionResponseModel.cs +++ b/src/Api/Models/Response/SubscriptionResponseModel.cs @@ -3,103 +3,104 @@ using Bit.Core.Models.Api; using Bit.Core.Models.Business; using Bit.Core.Utilities; -namespace Bit.Api.Models.Response; - -public class SubscriptionResponseModel : ResponseModel +namespace Bit.Api.Models.Response { - public SubscriptionResponseModel(User user, SubscriptionInfo subscription, UserLicense license) - : base("subscription") + public class SubscriptionResponseModel : ResponseModel { - Subscription = subscription.Subscription != null ? new BillingSubscription(subscription.Subscription) : null; - UpcomingInvoice = subscription.UpcomingInvoice != null ? - new BillingSubscriptionUpcomingInvoice(subscription.UpcomingInvoice) : null; - StorageName = user.Storage.HasValue ? CoreHelpers.ReadableBytesSize(user.Storage.Value) : null; - StorageGb = user.Storage.HasValue ? Math.Round(user.Storage.Value / 1073741824D, 2) : 0; // 1 GB - MaxStorageGb = user.MaxStorageGb; - License = license; - Expiration = License.Expires; - UsingInAppPurchase = subscription.UsingInAppPurchase; - } - - public SubscriptionResponseModel(User user, UserLicense license = null) - : base("subscription") - { - StorageName = user.Storage.HasValue ? CoreHelpers.ReadableBytesSize(user.Storage.Value) : null; - StorageGb = user.Storage.HasValue ? Math.Round(user.Storage.Value / 1073741824D, 2) : 0; // 1 GB - MaxStorageGb = user.MaxStorageGb; - Expiration = user.PremiumExpirationDate; - - if (license != null) + public SubscriptionResponseModel(User user, SubscriptionInfo subscription, UserLicense license) + : base("subscription") { + Subscription = subscription.Subscription != null ? new BillingSubscription(subscription.Subscription) : null; + UpcomingInvoice = subscription.UpcomingInvoice != null ? + new BillingSubscriptionUpcomingInvoice(subscription.UpcomingInvoice) : null; + StorageName = user.Storage.HasValue ? CoreHelpers.ReadableBytesSize(user.Storage.Value) : null; + StorageGb = user.Storage.HasValue ? Math.Round(user.Storage.Value / 1073741824D, 2) : 0; // 1 GB + MaxStorageGb = user.MaxStorageGb; License = license; + Expiration = License.Expires; + UsingInAppPurchase = subscription.UsingInAppPurchase; + } + + public SubscriptionResponseModel(User user, UserLicense license = null) + : base("subscription") + { + StorageName = user.Storage.HasValue ? CoreHelpers.ReadableBytesSize(user.Storage.Value) : null; + StorageGb = user.Storage.HasValue ? Math.Round(user.Storage.Value / 1073741824D, 2) : 0; // 1 GB + MaxStorageGb = user.MaxStorageGb; + Expiration = user.PremiumExpirationDate; + + if (license != null) + { + License = license; + } + } + + public string StorageName { get; set; } + public double? StorageGb { get; set; } + public short? MaxStorageGb { get; set; } + public BillingSubscriptionUpcomingInvoice UpcomingInvoice { get; set; } + public BillingSubscription Subscription { get; set; } + public UserLicense License { get; set; } + public DateTime? Expiration { get; set; } + public bool UsingInAppPurchase { get; set; } + } + + public class BillingSubscription + { + public BillingSubscription(SubscriptionInfo.BillingSubscription sub) + { + Status = sub.Status; + TrialStartDate = sub.TrialStartDate; + TrialEndDate = sub.TrialEndDate; + PeriodStartDate = sub.PeriodStartDate; + PeriodEndDate = sub.PeriodEndDate; + CancelledDate = sub.CancelledDate; + CancelAtEndDate = sub.CancelAtEndDate; + Cancelled = sub.Cancelled; + if (sub.Items != null) + { + Items = sub.Items.Select(i => new BillingSubscriptionItem(i)); + } + } + + public DateTime? TrialStartDate { get; set; } + public DateTime? TrialEndDate { get; set; } + public DateTime? PeriodStartDate { get; set; } + public DateTime? PeriodEndDate { get; set; } + public DateTime? CancelledDate { get; set; } + public bool CancelAtEndDate { get; set; } + public string Status { get; set; } + public bool Cancelled { get; set; } + public IEnumerable Items { get; set; } = new List(); + + public class BillingSubscriptionItem + { + public BillingSubscriptionItem(SubscriptionInfo.BillingSubscription.BillingSubscriptionItem item) + { + Name = item.Name; + Amount = item.Amount; + Interval = item.Interval; + Quantity = item.Quantity; + SponsoredSubscriptionItem = item.SponsoredSubscriptionItem; + } + + public string Name { get; set; } + public decimal Amount { get; set; } + public int Quantity { get; set; } + public string Interval { get; set; } + public bool SponsoredSubscriptionItem { get; set; } } } - public string StorageName { get; set; } - public double? StorageGb { get; set; } - public short? MaxStorageGb { get; set; } - public BillingSubscriptionUpcomingInvoice UpcomingInvoice { get; set; } - public BillingSubscription Subscription { get; set; } - public UserLicense License { get; set; } - public DateTime? Expiration { get; set; } - public bool UsingInAppPurchase { get; set; } -} - -public class BillingSubscription -{ - public BillingSubscription(SubscriptionInfo.BillingSubscription sub) + public class BillingSubscriptionUpcomingInvoice { - Status = sub.Status; - TrialStartDate = sub.TrialStartDate; - TrialEndDate = sub.TrialEndDate; - PeriodStartDate = sub.PeriodStartDate; - PeriodEndDate = sub.PeriodEndDate; - CancelledDate = sub.CancelledDate; - CancelAtEndDate = sub.CancelAtEndDate; - Cancelled = sub.Cancelled; - if (sub.Items != null) + public BillingSubscriptionUpcomingInvoice(SubscriptionInfo.BillingUpcomingInvoice inv) { - Items = sub.Items.Select(i => new BillingSubscriptionItem(i)); - } - } - - public DateTime? TrialStartDate { get; set; } - public DateTime? TrialEndDate { get; set; } - public DateTime? PeriodStartDate { get; set; } - public DateTime? PeriodEndDate { get; set; } - public DateTime? CancelledDate { get; set; } - public bool CancelAtEndDate { get; set; } - public string Status { get; set; } - public bool Cancelled { get; set; } - public IEnumerable Items { get; set; } = new List(); - - public class BillingSubscriptionItem - { - public BillingSubscriptionItem(SubscriptionInfo.BillingSubscription.BillingSubscriptionItem item) - { - Name = item.Name; - Amount = item.Amount; - Interval = item.Interval; - Quantity = item.Quantity; - SponsoredSubscriptionItem = item.SponsoredSubscriptionItem; + Amount = inv.Amount; + Date = inv.Date; } - public string Name { get; set; } public decimal Amount { get; set; } - public int Quantity { get; set; } - public string Interval { get; set; } - public bool SponsoredSubscriptionItem { get; set; } + public DateTime? Date { get; set; } } } - -public class BillingSubscriptionUpcomingInvoice -{ - public BillingSubscriptionUpcomingInvoice(SubscriptionInfo.BillingUpcomingInvoice inv) - { - Amount = inv.Amount; - Date = inv.Date; - } - - public decimal Amount { get; set; } - public DateTime? Date { get; set; } -} diff --git a/src/Api/Models/Response/SyncResponseModel.cs b/src/Api/Models/Response/SyncResponseModel.cs index 6d028b12f8..8c9f126862 100644 --- a/src/Api/Models/Response/SyncResponseModel.cs +++ b/src/Api/Models/Response/SyncResponseModel.cs @@ -5,43 +5,44 @@ using Bit.Core.Models.Data.Organizations.OrganizationUsers; using Bit.Core.Settings; using Core.Models.Data; -namespace Bit.Api.Models.Response; - -public class SyncResponseModel : ResponseModel +namespace Bit.Api.Models.Response { - public SyncResponseModel( - GlobalSettings globalSettings, - User user, - bool userTwoFactorEnabled, - bool userHasPremiumFromOrganization, - IEnumerable organizationUserDetails, - IEnumerable providerUserDetails, - IEnumerable providerUserOrganizationDetails, - IEnumerable folders, - IEnumerable collections, - IEnumerable ciphers, - IDictionary> collectionCiphersDict, - bool excludeDomains, - IEnumerable policies, - IEnumerable sends) - : base("sync") + public class SyncResponseModel : ResponseModel { - Profile = new ProfileResponseModel(user, organizationUserDetails, providerUserDetails, - providerUserOrganizationDetails, userTwoFactorEnabled, userHasPremiumFromOrganization); - Folders = folders.Select(f => new FolderResponseModel(f)); - Ciphers = ciphers.Select(c => new CipherDetailsResponseModel(c, globalSettings, collectionCiphersDict)); - Collections = collections?.Select( - c => new CollectionDetailsResponseModel(c)) ?? new List(); - Domains = excludeDomains ? null : new DomainsResponseModel(user, false); - Policies = policies?.Select(p => new PolicyResponseModel(p)) ?? new List(); - Sends = sends.Select(s => new SendResponseModel(s, globalSettings)); - } + public SyncResponseModel( + GlobalSettings globalSettings, + User user, + bool userTwoFactorEnabled, + bool userHasPremiumFromOrganization, + IEnumerable organizationUserDetails, + IEnumerable providerUserDetails, + IEnumerable providerUserOrganizationDetails, + IEnumerable folders, + IEnumerable collections, + IEnumerable ciphers, + IDictionary> collectionCiphersDict, + bool excludeDomains, + IEnumerable policies, + IEnumerable sends) + : base("sync") + { + Profile = new ProfileResponseModel(user, organizationUserDetails, providerUserDetails, + providerUserOrganizationDetails, userTwoFactorEnabled, userHasPremiumFromOrganization); + Folders = folders.Select(f => new FolderResponseModel(f)); + Ciphers = ciphers.Select(c => new CipherDetailsResponseModel(c, globalSettings, collectionCiphersDict)); + Collections = collections?.Select( + c => new CollectionDetailsResponseModel(c)) ?? new List(); + Domains = excludeDomains ? null : new DomainsResponseModel(user, false); + Policies = policies?.Select(p => new PolicyResponseModel(p)) ?? new List(); + Sends = sends.Select(s => new SendResponseModel(s, globalSettings)); + } - public ProfileResponseModel Profile { get; set; } - public IEnumerable Folders { get; set; } - public IEnumerable Collections { get; set; } - public IEnumerable Ciphers { get; set; } - public DomainsResponseModel Domains { get; set; } - public IEnumerable Policies { get; set; } - public IEnumerable Sends { get; set; } + public ProfileResponseModel Profile { get; set; } + public IEnumerable Folders { get; set; } + public IEnumerable Collections { get; set; } + public IEnumerable Ciphers { get; set; } + public DomainsResponseModel Domains { get; set; } + public IEnumerable Policies { get; set; } + public IEnumerable Sends { get; set; } + } } diff --git a/src/Api/Models/Response/TaxInfoResponseModel.cs b/src/Api/Models/Response/TaxInfoResponseModel.cs index c1cd51267e..6ba6bad458 100644 --- a/src/Api/Models/Response/TaxInfoResponseModel.cs +++ b/src/Api/Models/Response/TaxInfoResponseModel.cs @@ -1,34 +1,35 @@ using Bit.Core.Models.Business; -namespace Bit.Api.Models.Response; - -public class TaxInfoResponseModel +namespace Bit.Api.Models.Response { - public TaxInfoResponseModel() { } - - public TaxInfoResponseModel(TaxInfo taxInfo) + public class TaxInfoResponseModel { - if (taxInfo == null) + public TaxInfoResponseModel() { } + + public TaxInfoResponseModel(TaxInfo taxInfo) { - return; + if (taxInfo == null) + { + return; + } + + TaxIdNumber = taxInfo.TaxIdNumber; + TaxIdType = taxInfo.TaxIdType; + Line1 = taxInfo.BillingAddressLine1; + Line2 = taxInfo.BillingAddressLine2; + City = taxInfo.BillingAddressCity; + State = taxInfo.BillingAddressState; + PostalCode = taxInfo.BillingAddressPostalCode; + Country = taxInfo.BillingAddressCountry; } - TaxIdNumber = taxInfo.TaxIdNumber; - TaxIdType = taxInfo.TaxIdType; - Line1 = taxInfo.BillingAddressLine1; - Line2 = taxInfo.BillingAddressLine2; - City = taxInfo.BillingAddressCity; - State = taxInfo.BillingAddressState; - PostalCode = taxInfo.BillingAddressPostalCode; - Country = taxInfo.BillingAddressCountry; + public string TaxIdNumber { get; set; } + public string TaxIdType { get; set; } + public string Line1 { get; set; } + public string Line2 { get; set; } + public string City { get; set; } + public string State { get; set; } + public string PostalCode { get; set; } + public string Country { get; set; } } - - public string TaxIdNumber { get; set; } - public string TaxIdType { get; set; } - public string Line1 { get; set; } - public string Line2 { get; set; } - public string City { get; set; } - public string State { get; set; } - public string PostalCode { get; set; } - public string Country { get; set; } } diff --git a/src/Api/Models/Response/TaxRateResponseModel.cs b/src/Api/Models/Response/TaxRateResponseModel.cs index 2c3335314c..ec08cb7f73 100644 --- a/src/Api/Models/Response/TaxRateResponseModel.cs +++ b/src/Api/Models/Response/TaxRateResponseModel.cs @@ -1,28 +1,29 @@ using Bit.Core.Entities; using Bit.Core.Models.Api; -namespace Bit.Api.Models.Response; - -public class TaxRateResponseModel : ResponseModel +namespace Bit.Api.Models.Response { - public TaxRateResponseModel(TaxRate taxRate) - : base("profile") + public class TaxRateResponseModel : ResponseModel { - if (taxRate == null) + public TaxRateResponseModel(TaxRate taxRate) + : base("profile") { - throw new ArgumentNullException(nameof(taxRate)); + if (taxRate == null) + { + throw new ArgumentNullException(nameof(taxRate)); + } + + Id = taxRate.Id; + Country = taxRate.Country; + State = taxRate.State; + PostalCode = taxRate.PostalCode; + Rate = taxRate.Rate; } - Id = taxRate.Id; - Country = taxRate.Country; - State = taxRate.State; - PostalCode = taxRate.PostalCode; - Rate = taxRate.Rate; + public string Id { get; set; } + public string Country { get; set; } + public string State { get; set; } + public string PostalCode { get; set; } + public decimal Rate { get; set; } } - - public string Id { get; set; } - public string Country { get; set; } - public string State { get; set; } - public string PostalCode { get; set; } - public decimal Rate { get; set; } } diff --git a/src/Api/Models/Response/TwoFactor/TwoFactorAuthenticatorResponseModel.cs b/src/Api/Models/Response/TwoFactor/TwoFactorAuthenticatorResponseModel.cs index 0a283b7e60..3747a411a9 100644 --- a/src/Api/Models/Response/TwoFactor/TwoFactorAuthenticatorResponseModel.cs +++ b/src/Api/Models/Response/TwoFactor/TwoFactorAuthenticatorResponseModel.cs @@ -3,32 +3,33 @@ using Bit.Core.Enums; using Bit.Core.Models.Api; using OtpNet; -namespace Bit.Api.Models.Response.TwoFactor; - -public class TwoFactorAuthenticatorResponseModel : ResponseModel +namespace Bit.Api.Models.Response.TwoFactor { - public TwoFactorAuthenticatorResponseModel(User user) - : base("twoFactorAuthenticator") + public class TwoFactorAuthenticatorResponseModel : ResponseModel { - if (user == null) + public TwoFactorAuthenticatorResponseModel(User user) + : base("twoFactorAuthenticator") { - throw new ArgumentNullException(nameof(user)); + if (user == null) + { + throw new ArgumentNullException(nameof(user)); + } + + var provider = user.GetTwoFactorProvider(TwoFactorProviderType.Authenticator); + if (provider?.MetaData?.ContainsKey("Key") ?? false) + { + Key = (string)provider.MetaData["Key"]; + Enabled = provider.Enabled; + } + else + { + var key = KeyGeneration.GenerateRandomKey(20); + Key = Base32Encoding.ToString(key); + Enabled = false; + } } - var provider = user.GetTwoFactorProvider(TwoFactorProviderType.Authenticator); - if (provider?.MetaData?.ContainsKey("Key") ?? false) - { - Key = (string)provider.MetaData["Key"]; - Enabled = provider.Enabled; - } - else - { - var key = KeyGeneration.GenerateRandomKey(20); - Key = Base32Encoding.ToString(key); - Enabled = false; - } + public bool Enabled { get; set; } + public string Key { get; set; } } - - public bool Enabled { get; set; } - public string Key { get; set; } } diff --git a/src/Api/Models/Response/TwoFactor/TwoFactorDuoResponseModel.cs b/src/Api/Models/Response/TwoFactor/TwoFactorDuoResponseModel.cs index 3331a8d766..c2461abdb7 100644 --- a/src/Api/Models/Response/TwoFactor/TwoFactorDuoResponseModel.cs +++ b/src/Api/Models/Response/TwoFactor/TwoFactorDuoResponseModel.cs @@ -3,63 +3,64 @@ using Bit.Core.Enums; using Bit.Core.Models; using Bit.Core.Models.Api; -namespace Bit.Api.Models.Response.TwoFactor; - -public class TwoFactorDuoResponseModel : ResponseModel +namespace Bit.Api.Models.Response.TwoFactor { - private const string ResponseObj = "twoFactorDuo"; - - public TwoFactorDuoResponseModel(User user) - : base(ResponseObj) + public class TwoFactorDuoResponseModel : ResponseModel { - if (user == null) + private const string ResponseObj = "twoFactorDuo"; + + public TwoFactorDuoResponseModel(User user) + : base(ResponseObj) { - throw new ArgumentNullException(nameof(user)); + if (user == null) + { + throw new ArgumentNullException(nameof(user)); + } + + var provider = user.GetTwoFactorProvider(TwoFactorProviderType.Duo); + Build(provider); } - var provider = user.GetTwoFactorProvider(TwoFactorProviderType.Duo); - Build(provider); - } - - public TwoFactorDuoResponseModel(Organization org) - : base(ResponseObj) - { - if (org == null) + public TwoFactorDuoResponseModel(Organization org) + : base(ResponseObj) { - throw new ArgumentNullException(nameof(org)); + if (org == null) + { + throw new ArgumentNullException(nameof(org)); + } + + var provider = org.GetTwoFactorProvider(TwoFactorProviderType.OrganizationDuo); + Build(provider); } - var provider = org.GetTwoFactorProvider(TwoFactorProviderType.OrganizationDuo); - Build(provider); - } + public bool Enabled { get; set; } + public string Host { get; set; } + public string SecretKey { get; set; } + public string IntegrationKey { get; set; } - public bool Enabled { get; set; } - public string Host { get; set; } - public string SecretKey { get; set; } - public string IntegrationKey { get; set; } - - private void Build(TwoFactorProvider provider) - { - if (provider?.MetaData != null && provider.MetaData.Count > 0) + private void Build(TwoFactorProvider provider) { - Enabled = provider.Enabled; + if (provider?.MetaData != null && provider.MetaData.Count > 0) + { + Enabled = provider.Enabled; - if (provider.MetaData.ContainsKey("Host")) - { - Host = (string)provider.MetaData["Host"]; + if (provider.MetaData.ContainsKey("Host")) + { + Host = (string)provider.MetaData["Host"]; + } + if (provider.MetaData.ContainsKey("SKey")) + { + SecretKey = (string)provider.MetaData["SKey"]; + } + if (provider.MetaData.ContainsKey("IKey")) + { + IntegrationKey = (string)provider.MetaData["IKey"]; + } } - if (provider.MetaData.ContainsKey("SKey")) + else { - SecretKey = (string)provider.MetaData["SKey"]; + Enabled = false; } - if (provider.MetaData.ContainsKey("IKey")) - { - IntegrationKey = (string)provider.MetaData["IKey"]; - } - } - else - { - Enabled = false; } } } diff --git a/src/Api/Models/Response/TwoFactor/TwoFactorEmailResponseModel.cs b/src/Api/Models/Response/TwoFactor/TwoFactorEmailResponseModel.cs index f2be91f9d9..9f8fecc4f9 100644 --- a/src/Api/Models/Response/TwoFactor/TwoFactorEmailResponseModel.cs +++ b/src/Api/Models/Response/TwoFactor/TwoFactorEmailResponseModel.cs @@ -2,30 +2,31 @@ using Bit.Core.Enums; using Bit.Core.Models.Api; -namespace Bit.Api.Models.Response.TwoFactor; - -public class TwoFactorEmailResponseModel : ResponseModel +namespace Bit.Api.Models.Response.TwoFactor { - public TwoFactorEmailResponseModel(User user) - : base("twoFactorEmail") + public class TwoFactorEmailResponseModel : ResponseModel { - if (user == null) + public TwoFactorEmailResponseModel(User user) + : base("twoFactorEmail") { - throw new ArgumentNullException(nameof(user)); + if (user == null) + { + throw new ArgumentNullException(nameof(user)); + } + + var provider = user.GetTwoFactorProvider(TwoFactorProviderType.Email); + if (provider?.MetaData?.ContainsKey("Email") ?? false) + { + Email = (string)provider.MetaData["Email"]; + Enabled = provider.Enabled; + } + else + { + Enabled = false; + } } - var provider = user.GetTwoFactorProvider(TwoFactorProviderType.Email); - if (provider?.MetaData?.ContainsKey("Email") ?? false) - { - Email = (string)provider.MetaData["Email"]; - Enabled = provider.Enabled; - } - else - { - Enabled = false; - } + public bool Enabled { get; set; } + public string Email { get; set; } } - - public bool Enabled { get; set; } - public string Email { get; set; } } diff --git a/src/Api/Models/Response/TwoFactor/TwoFactorProviderResponseModel.cs b/src/Api/Models/Response/TwoFactor/TwoFactorProviderResponseModel.cs index 0e8522104a..c742d8b2f3 100644 --- a/src/Api/Models/Response/TwoFactor/TwoFactorProviderResponseModel.cs +++ b/src/Api/Models/Response/TwoFactor/TwoFactorProviderResponseModel.cs @@ -3,50 +3,51 @@ using Bit.Core.Enums; using Bit.Core.Models; using Bit.Core.Models.Api; -namespace Bit.Api.Models.Response.TwoFactor; - -public class TwoFactorProviderResponseModel : ResponseModel +namespace Bit.Api.Models.Response.TwoFactor { - private const string ResponseObj = "twoFactorProvider"; - - public TwoFactorProviderResponseModel(TwoFactorProviderType type, TwoFactorProvider provider) - : base(ResponseObj) + public class TwoFactorProviderResponseModel : ResponseModel { - if (provider == null) + private const string ResponseObj = "twoFactorProvider"; + + public TwoFactorProviderResponseModel(TwoFactorProviderType type, TwoFactorProvider provider) + : base(ResponseObj) { - throw new ArgumentNullException(nameof(provider)); + if (provider == null) + { + throw new ArgumentNullException(nameof(provider)); + } + + Enabled = provider.Enabled; + Type = type; } - Enabled = provider.Enabled; - Type = type; - } - - public TwoFactorProviderResponseModel(TwoFactorProviderType type, User user) - : base(ResponseObj) - { - if (user == null) + public TwoFactorProviderResponseModel(TwoFactorProviderType type, User user) + : base(ResponseObj) { - throw new ArgumentNullException(nameof(user)); + if (user == null) + { + throw new ArgumentNullException(nameof(user)); + } + + var provider = user.GetTwoFactorProvider(type); + Enabled = provider?.Enabled ?? false; + Type = type; } - var provider = user.GetTwoFactorProvider(type); - Enabled = provider?.Enabled ?? false; - Type = type; - } - - public TwoFactorProviderResponseModel(TwoFactorProviderType type, Organization organization) - : base(ResponseObj) - { - if (organization == null) + public TwoFactorProviderResponseModel(TwoFactorProviderType type, Organization organization) + : base(ResponseObj) { - throw new ArgumentNullException(nameof(organization)); + if (organization == null) + { + throw new ArgumentNullException(nameof(organization)); + } + + var provider = organization.GetTwoFactorProvider(type); + Enabled = provider?.Enabled ?? false; + Type = type; } - var provider = organization.GetTwoFactorProvider(type); - Enabled = provider?.Enabled ?? false; - Type = type; + public bool Enabled { get; set; } + public TwoFactorProviderType Type { get; set; } } - - public bool Enabled { get; set; } - public TwoFactorProviderType Type { get; set; } } diff --git a/src/Api/Models/Response/TwoFactor/TwoFactorRecoverResponseModel.cs b/src/Api/Models/Response/TwoFactor/TwoFactorRecoverResponseModel.cs index 26324de7cf..5d87a0e94d 100644 --- a/src/Api/Models/Response/TwoFactor/TwoFactorRecoverResponseModel.cs +++ b/src/Api/Models/Response/TwoFactor/TwoFactorRecoverResponseModel.cs @@ -1,20 +1,21 @@ using Bit.Core.Entities; using Bit.Core.Models.Api; -namespace Bit.Api.Models.Response.TwoFactor; - -public class TwoFactorRecoverResponseModel : ResponseModel +namespace Bit.Api.Models.Response.TwoFactor { - public TwoFactorRecoverResponseModel(User user) - : base("twoFactorRecover") + public class TwoFactorRecoverResponseModel : ResponseModel { - if (user == null) + public TwoFactorRecoverResponseModel(User user) + : base("twoFactorRecover") { - throw new ArgumentNullException(nameof(user)); + if (user == null) + { + throw new ArgumentNullException(nameof(user)); + } + + Code = user.TwoFactorRecoveryCode; } - Code = user.TwoFactorRecoveryCode; + public string Code { get; set; } } - - public string Code { get; set; } } diff --git a/src/Api/Models/Response/TwoFactor/TwoFactorWebAuthnResponseModel.cs b/src/Api/Models/Response/TwoFactor/TwoFactorWebAuthnResponseModel.cs index 3e2ab2bc64..05c3b2f44a 100644 --- a/src/Api/Models/Response/TwoFactor/TwoFactorWebAuthnResponseModel.cs +++ b/src/Api/Models/Response/TwoFactor/TwoFactorWebAuthnResponseModel.cs @@ -3,39 +3,40 @@ using Bit.Core.Enums; using Bit.Core.Models; using Bit.Core.Models.Api; -namespace Bit.Api.Models.Response.TwoFactor; - -public class TwoFactorWebAuthnResponseModel : ResponseModel +namespace Bit.Api.Models.Response.TwoFactor { - public TwoFactorWebAuthnResponseModel(User user) - : base("twoFactorWebAuthn") + public class TwoFactorWebAuthnResponseModel : ResponseModel { - if (user == null) + public TwoFactorWebAuthnResponseModel(User user) + : base("twoFactorWebAuthn") { - throw new ArgumentNullException(nameof(user)); + if (user == null) + { + throw new ArgumentNullException(nameof(user)); + } + + var provider = user.GetTwoFactorProvider(TwoFactorProviderType.WebAuthn); + Enabled = provider?.Enabled ?? false; + Keys = provider?.MetaData? + .Where(k => k.Key.StartsWith("Key")) + .Select(k => new KeyModel(k.Key, new TwoFactorProvider.WebAuthnData((dynamic)k.Value))); } - var provider = user.GetTwoFactorProvider(TwoFactorProviderType.WebAuthn); - Enabled = provider?.Enabled ?? false; - Keys = provider?.MetaData? - .Where(k => k.Key.StartsWith("Key")) - .Select(k => new KeyModel(k.Key, new TwoFactorProvider.WebAuthnData((dynamic)k.Value))); - } + public bool Enabled { get; set; } + public IEnumerable Keys { get; set; } - public bool Enabled { get; set; } - public IEnumerable Keys { get; set; } - - public class KeyModel - { - public KeyModel(string id, TwoFactorProvider.WebAuthnData data) + public class KeyModel { - Name = data.Name; - Id = Convert.ToInt32(id.Replace("Key", string.Empty)); - Migrated = data.Migrated; - } + public KeyModel(string id, TwoFactorProvider.WebAuthnData data) + { + Name = data.Name; + Id = Convert.ToInt32(id.Replace("Key", string.Empty)); + Migrated = data.Migrated; + } - public string Name { get; set; } - public int Id { get; set; } - public bool Migrated { get; set; } + public string Name { get; set; } + public int Id { get; set; } + public bool Migrated { get; set; } + } } } diff --git a/src/Api/Models/Response/TwoFactor/TwoFactorYubiKeyResponseModel.cs b/src/Api/Models/Response/TwoFactor/TwoFactorYubiKeyResponseModel.cs index 48c7670c32..9654bd1e6b 100644 --- a/src/Api/Models/Response/TwoFactor/TwoFactorYubiKeyResponseModel.cs +++ b/src/Api/Models/Response/TwoFactor/TwoFactorYubiKeyResponseModel.cs @@ -2,59 +2,60 @@ using Bit.Core.Enums; using Bit.Core.Models.Api; -namespace Bit.Api.Models.Response.TwoFactor; - -public class TwoFactorYubiKeyResponseModel : ResponseModel +namespace Bit.Api.Models.Response.TwoFactor { - public TwoFactorYubiKeyResponseModel(User user) - : base("twoFactorYubiKey") + public class TwoFactorYubiKeyResponseModel : ResponseModel { - if (user == null) + public TwoFactorYubiKeyResponseModel(User user) + : base("twoFactorYubiKey") { - throw new ArgumentNullException(nameof(user)); + if (user == null) + { + throw new ArgumentNullException(nameof(user)); + } + + var provider = user.GetTwoFactorProvider(TwoFactorProviderType.YubiKey); + if (provider?.MetaData != null && provider.MetaData.Count > 0) + { + Enabled = provider.Enabled; + + if (provider.MetaData.ContainsKey("Key1")) + { + Key1 = (string)provider.MetaData["Key1"]; + } + if (provider.MetaData.ContainsKey("Key2")) + { + Key2 = (string)provider.MetaData["Key2"]; + } + if (provider.MetaData.ContainsKey("Key3")) + { + Key3 = (string)provider.MetaData["Key3"]; + } + if (provider.MetaData.ContainsKey("Key4")) + { + Key4 = (string)provider.MetaData["Key4"]; + } + if (provider.MetaData.ContainsKey("Key5")) + { + Key5 = (string)provider.MetaData["Key5"]; + } + if (provider.MetaData.ContainsKey("Nfc")) + { + Nfc = (bool)provider.MetaData["Nfc"]; + } + } + else + { + Enabled = false; + } } - var provider = user.GetTwoFactorProvider(TwoFactorProviderType.YubiKey); - if (provider?.MetaData != null && provider.MetaData.Count > 0) - { - Enabled = provider.Enabled; - - if (provider.MetaData.ContainsKey("Key1")) - { - Key1 = (string)provider.MetaData["Key1"]; - } - if (provider.MetaData.ContainsKey("Key2")) - { - Key2 = (string)provider.MetaData["Key2"]; - } - if (provider.MetaData.ContainsKey("Key3")) - { - Key3 = (string)provider.MetaData["Key3"]; - } - if (provider.MetaData.ContainsKey("Key4")) - { - Key4 = (string)provider.MetaData["Key4"]; - } - if (provider.MetaData.ContainsKey("Key5")) - { - Key5 = (string)provider.MetaData["Key5"]; - } - if (provider.MetaData.ContainsKey("Nfc")) - { - Nfc = (bool)provider.MetaData["Nfc"]; - } - } - else - { - Enabled = false; - } + public bool Enabled { get; set; } + public string Key1 { get; set; } + public string Key2 { get; set; } + public string Key3 { get; set; } + public string Key4 { get; set; } + public string Key5 { get; set; } + public bool Nfc { get; set; } } - - public bool Enabled { get; set; } - public string Key1 { get; set; } - public string Key2 { get; set; } - public string Key3 { get; set; } - public string Key4 { get; set; } - public string Key5 { get; set; } - public bool Nfc { get; set; } } diff --git a/src/Api/Models/Response/UserKeyResponseModel.cs b/src/Api/Models/Response/UserKeyResponseModel.cs index d80571993d..b31f1e95ae 100644 --- a/src/Api/Models/Response/UserKeyResponseModel.cs +++ b/src/Api/Models/Response/UserKeyResponseModel.cs @@ -1,16 +1,17 @@ using Bit.Core.Models.Api; -namespace Bit.Api.Models.Response; - -public class UserKeyResponseModel : ResponseModel +namespace Bit.Api.Models.Response { - public UserKeyResponseModel(Guid id, string key) - : base("userKey") + public class UserKeyResponseModel : ResponseModel { - UserId = id.ToString(); - PublicKey = key; - } + public UserKeyResponseModel(Guid id, string key) + : base("userKey") + { + UserId = id.ToString(); + PublicKey = key; + } - public string UserId { get; set; } - public string PublicKey { get; set; } + public string UserId { get; set; } + public string PublicKey { get; set; } + } } diff --git a/src/Api/Models/SendFileModel.cs b/src/Api/Models/SendFileModel.cs index bfe10f86f1..653510c896 100644 --- a/src/Api/Models/SendFileModel.cs +++ b/src/Api/Models/SendFileModel.cs @@ -2,25 +2,26 @@ using Bit.Core.Models.Data; using Bit.Core.Utilities; -namespace Bit.Api.Models; - -public class SendFileModel +namespace Bit.Api.Models { - public SendFileModel() { } - - public SendFileModel(SendFileData data) + public class SendFileModel { - Id = data.Id; - FileName = data.FileName; - Size = data.Size; - SizeName = CoreHelpers.ReadableBytesSize(data.Size); - } + public SendFileModel() { } - public string Id { get; set; } - [EncryptedString] - [EncryptedStringLength(1000)] - public string FileName { get; set; } - [JsonNumberHandling(JsonNumberHandling.AllowReadingFromString | JsonNumberHandling.WriteAsString)] - public long? Size { get; set; } - public string SizeName { get; set; } + public SendFileModel(SendFileData data) + { + Id = data.Id; + FileName = data.FileName; + Size = data.Size; + SizeName = CoreHelpers.ReadableBytesSize(data.Size); + } + + public string Id { get; set; } + [EncryptedString] + [EncryptedStringLength(1000)] + public string FileName { get; set; } + [JsonNumberHandling(JsonNumberHandling.AllowReadingFromString | JsonNumberHandling.WriteAsString)] + public long? Size { get; set; } + public string SizeName { get; set; } + } } diff --git a/src/Api/Models/SendTextModel.cs b/src/Api/Models/SendTextModel.cs index ba2e6f8a62..a362a61d9e 100644 --- a/src/Api/Models/SendTextModel.cs +++ b/src/Api/Models/SendTextModel.cs @@ -1,20 +1,21 @@ using Bit.Core.Models.Data; using Bit.Core.Utilities; -namespace Bit.Api.Models; - -public class SendTextModel +namespace Bit.Api.Models { - public SendTextModel() { } - - public SendTextModel(SendTextData data) + public class SendTextModel { - Text = data.Text; - Hidden = data.Hidden; - } + public SendTextModel() { } - [EncryptedString] - [EncryptedStringLength(1000)] - public string Text { get; set; } - public bool Hidden { get; set; } + public SendTextModel(SendTextData data) + { + Text = data.Text; + Hidden = data.Hidden; + } + + [EncryptedString] + [EncryptedStringLength(1000)] + public string Text { get; set; } + public bool Hidden { get; set; } + } } diff --git a/src/Api/Program.cs b/src/Api/Program.cs index b7e80d6c26..bcd6284af0 100644 --- a/src/Api/Program.cs +++ b/src/Api/Program.cs @@ -3,45 +3,46 @@ using Bit.Core.Utilities; using Microsoft.IdentityModel.Tokens; using Serilog.Events; -namespace Bit.Api; - -public class Program +namespace Bit.Api { - public static void Main(string[] args) + public class Program { - Host - .CreateDefaultBuilder(args) - .ConfigureCustomAppConfiguration(args) - .ConfigureWebHostDefaults(webBuilder => - { - webBuilder.UseStartup(); - webBuilder.ConfigureLogging((hostingContext, logging) => - logging.AddSerilog(hostingContext, e => - { - var context = e.Properties["SourceContext"].ToString(); - if (e.Exception != null && - (e.Exception.GetType() == typeof(SecurityTokenValidationException) || - e.Exception.Message == "Bad security stamp.")) + public static void Main(string[] args) + { + Host + .CreateDefaultBuilder(args) + .ConfigureCustomAppConfiguration(args) + .ConfigureWebHostDefaults(webBuilder => + { + webBuilder.UseStartup(); + webBuilder.ConfigureLogging((hostingContext, logging) => + logging.AddSerilog(hostingContext, e => { - return false; - } + var context = e.Properties["SourceContext"].ToString(); + if (e.Exception != null && + (e.Exception.GetType() == typeof(SecurityTokenValidationException) || + e.Exception.Message == "Bad security stamp.")) + { + return false; + } - if (e.Level == LogEventLevel.Information && - context.Contains(typeof(IpRateLimitMiddleware).FullName)) - { - return true; - } + if (e.Level == LogEventLevel.Information && + context.Contains(typeof(IpRateLimitMiddleware).FullName)) + { + return true; + } - if (context.Contains("IdentityServer4.Validation.TokenValidator") || - context.Contains("IdentityServer4.Validation.TokenRequestValidator")) - { - return e.Level > LogEventLevel.Error; - } + if (context.Contains("IdentityServer4.Validation.TokenValidator") || + context.Contains("IdentityServer4.Validation.TokenRequestValidator")) + { + return e.Level > LogEventLevel.Error; + } - return e.Level >= LogEventLevel.Error; - })); - }) - .Build() - .Run(); + return e.Level >= LogEventLevel.Error; + })); + }) + .Build() + .Run(); + } } } diff --git a/src/Api/Public/Controllers/CollectionsController.cs b/src/Api/Public/Controllers/CollectionsController.cs index ae56d6824a..677d53861a 100644 --- a/src/Api/Public/Controllers/CollectionsController.cs +++ b/src/Api/Public/Controllers/CollectionsController.cs @@ -7,113 +7,114 @@ using Bit.Core.Services; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; -namespace Bit.Api.Public.Controllers; - -[Route("public/collections")] -[Authorize("Organization")] -public class CollectionsController : Controller +namespace Bit.Api.Public.Controllers { - private readonly ICollectionRepository _collectionRepository; - private readonly ICollectionService _collectionService; - private readonly ICurrentContext _currentContext; - - public CollectionsController( - ICollectionRepository collectionRepository, - ICollectionService collectionService, - ICurrentContext currentContext) + [Route("public/collections")] + [Authorize("Organization")] + public class CollectionsController : Controller { - _collectionRepository = collectionRepository; - _collectionService = collectionService; - _currentContext = currentContext; - } + private readonly ICollectionRepository _collectionRepository; + private readonly ICollectionService _collectionService; + private readonly ICurrentContext _currentContext; - /// - /// Retrieve a collection. - /// - /// - /// Retrieves the details of an existing collection. You need only supply the unique collection identifier - /// that was returned upon collection creation. - /// - /// The identifier of the collection to be retrieved. - [HttpGet("{id}")] - [ProducesResponseType(typeof(CollectionResponseModel), (int)HttpStatusCode.OK)] - [ProducesResponseType((int)HttpStatusCode.NotFound)] - public async Task Get(Guid id) - { - var collectionWithGroups = await _collectionRepository.GetByIdWithGroupsAsync(id); - var collection = collectionWithGroups?.Item1; - if (collection == null || collection.OrganizationId != _currentContext.OrganizationId) + public CollectionsController( + ICollectionRepository collectionRepository, + ICollectionService collectionService, + ICurrentContext currentContext) { - return new NotFoundResult(); + _collectionRepository = collectionRepository; + _collectionService = collectionService; + _currentContext = currentContext; } - var response = new CollectionResponseModel(collection, collectionWithGroups.Item2); - return new JsonResult(response); - } - /// - /// List all collections. - /// - /// - /// Returns a list of your organization's collections. - /// Collection objects listed in this call do not include information about their associated groups. - /// - [HttpGet] - [ProducesResponseType(typeof(ListResponseModel), (int)HttpStatusCode.OK)] - public async Task List() - { - var collections = await _collectionRepository.GetManyByOrganizationIdAsync( - _currentContext.OrganizationId.Value); - // TODO: Get all CollectionGroup associations for the organization and marry them up here for the response. - var collectionResponses = collections.Select(c => new CollectionResponseModel(c, null)); - var response = new ListResponseModel(collectionResponses); - return new JsonResult(response); - } - - /// - /// Update a collection. - /// - /// - /// Updates the specified collection object. If a property is not provided, - /// the value of the existing property will be reset. - /// - /// The identifier of the collection to be updated. - /// The request model. - [HttpPut("{id}")] - [ProducesResponseType(typeof(CollectionResponseModel), (int)HttpStatusCode.OK)] - [ProducesResponseType(typeof(ErrorResponseModel), (int)HttpStatusCode.BadRequest)] - [ProducesResponseType((int)HttpStatusCode.NotFound)] - public async Task Put(Guid id, [FromBody] CollectionUpdateRequestModel model) - { - var existingCollection = await _collectionRepository.GetByIdAsync(id); - if (existingCollection == null || existingCollection.OrganizationId != _currentContext.OrganizationId) + /// + /// Retrieve a collection. + /// + /// + /// Retrieves the details of an existing collection. You need only supply the unique collection identifier + /// that was returned upon collection creation. + /// + /// The identifier of the collection to be retrieved. + [HttpGet("{id}")] + [ProducesResponseType(typeof(CollectionResponseModel), (int)HttpStatusCode.OK)] + [ProducesResponseType((int)HttpStatusCode.NotFound)] + public async Task Get(Guid id) { - return new NotFoundResult(); + var collectionWithGroups = await _collectionRepository.GetByIdWithGroupsAsync(id); + var collection = collectionWithGroups?.Item1; + if (collection == null || collection.OrganizationId != _currentContext.OrganizationId) + { + return new NotFoundResult(); + } + var response = new CollectionResponseModel(collection, collectionWithGroups.Item2); + return new JsonResult(response); } - var updatedCollection = model.ToCollection(existingCollection); - var associations = model.Groups?.Select(c => c.ToSelectionReadOnly()); - await _collectionService.SaveAsync(updatedCollection, associations); - var response = new CollectionResponseModel(updatedCollection, associations); - return new JsonResult(response); - } - /// - /// Delete a collection. - /// - /// - /// Permanently deletes a collection. This cannot be undone. - /// - /// The identifier of the collection to be deleted. - [HttpDelete("{id}")] - [ProducesResponseType((int)HttpStatusCode.OK)] - [ProducesResponseType((int)HttpStatusCode.NotFound)] - public async Task Delete(Guid id) - { - var collection = await _collectionRepository.GetByIdAsync(id); - if (collection == null || collection.OrganizationId != _currentContext.OrganizationId) + /// + /// List all collections. + /// + /// + /// Returns a list of your organization's collections. + /// Collection objects listed in this call do not include information about their associated groups. + /// + [HttpGet] + [ProducesResponseType(typeof(ListResponseModel), (int)HttpStatusCode.OK)] + public async Task List() { - return new NotFoundResult(); + var collections = await _collectionRepository.GetManyByOrganizationIdAsync( + _currentContext.OrganizationId.Value); + // TODO: Get all CollectionGroup associations for the organization and marry them up here for the response. + var collectionResponses = collections.Select(c => new CollectionResponseModel(c, null)); + var response = new ListResponseModel(collectionResponses); + return new JsonResult(response); + } + + /// + /// Update a collection. + /// + /// + /// Updates the specified collection object. If a property is not provided, + /// the value of the existing property will be reset. + /// + /// The identifier of the collection to be updated. + /// The request model. + [HttpPut("{id}")] + [ProducesResponseType(typeof(CollectionResponseModel), (int)HttpStatusCode.OK)] + [ProducesResponseType(typeof(ErrorResponseModel), (int)HttpStatusCode.BadRequest)] + [ProducesResponseType((int)HttpStatusCode.NotFound)] + public async Task Put(Guid id, [FromBody] CollectionUpdateRequestModel model) + { + var existingCollection = await _collectionRepository.GetByIdAsync(id); + if (existingCollection == null || existingCollection.OrganizationId != _currentContext.OrganizationId) + { + return new NotFoundResult(); + } + var updatedCollection = model.ToCollection(existingCollection); + var associations = model.Groups?.Select(c => c.ToSelectionReadOnly()); + await _collectionService.SaveAsync(updatedCollection, associations); + var response = new CollectionResponseModel(updatedCollection, associations); + return new JsonResult(response); + } + + /// + /// Delete a collection. + /// + /// + /// Permanently deletes a collection. This cannot be undone. + /// + /// The identifier of the collection to be deleted. + [HttpDelete("{id}")] + [ProducesResponseType((int)HttpStatusCode.OK)] + [ProducesResponseType((int)HttpStatusCode.NotFound)] + public async Task Delete(Guid id) + { + var collection = await _collectionRepository.GetByIdAsync(id); + if (collection == null || collection.OrganizationId != _currentContext.OrganizationId) + { + return new NotFoundResult(); + } + await _collectionRepository.DeleteAsync(collection); + return new OkResult(); } - await _collectionRepository.DeleteAsync(collection); - return new OkResult(); } } diff --git a/src/Api/Public/Controllers/EventsController.cs b/src/Api/Public/Controllers/EventsController.cs index 6e9c734c13..5fe5bdb7b3 100644 --- a/src/Api/Public/Controllers/EventsController.cs +++ b/src/Api/Public/Controllers/EventsController.cs @@ -7,64 +7,65 @@ using Bit.Core.Repositories; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; -namespace Bit.Api.Public.Controllers; - -[Route("public/events")] -[Authorize("Organization")] -public class EventsController : Controller +namespace Bit.Api.Public.Controllers { - private readonly IEventRepository _eventRepository; - private readonly ICipherRepository _cipherRepository; - private readonly ICurrentContext _currentContext; - - public EventsController( - IEventRepository eventRepository, - ICipherRepository cipherRepository, - ICurrentContext currentContext) + [Route("public/events")] + [Authorize("Organization")] + public class EventsController : Controller { - _eventRepository = eventRepository; - _cipherRepository = cipherRepository; - _currentContext = currentContext; - } + private readonly IEventRepository _eventRepository; + private readonly ICipherRepository _cipherRepository; + private readonly ICurrentContext _currentContext; - /// - /// List all events. - /// - /// - /// Returns a filtered list of your organization's event logs, paged by a continuation token. - /// If no filters are provided, it will return the last 30 days of event for the organization. - /// - [HttpGet] - [ProducesResponseType(typeof(ListResponseModel), (int)HttpStatusCode.OK)] - public async Task List([FromQuery] EventFilterRequestModel request) - { - var dateRange = request.ToDateRange(); - var result = new PagedResult(); - if (request.ActingUserId.HasValue) + public EventsController( + IEventRepository eventRepository, + ICipherRepository cipherRepository, + ICurrentContext currentContext) { - result = await _eventRepository.GetManyByOrganizationActingUserAsync( - _currentContext.OrganizationId.Value, request.ActingUserId.Value, dateRange.Item1, dateRange.Item2, - new PageOptions { ContinuationToken = request.ContinuationToken }); + _eventRepository = eventRepository; + _cipherRepository = cipherRepository; + _currentContext = currentContext; } - else if (request.ItemId.HasValue) + + /// + /// List all events. + /// + /// + /// Returns a filtered list of your organization's event logs, paged by a continuation token. + /// If no filters are provided, it will return the last 30 days of event for the organization. + /// + [HttpGet] + [ProducesResponseType(typeof(ListResponseModel), (int)HttpStatusCode.OK)] + public async Task List([FromQuery] EventFilterRequestModel request) { - var cipher = await _cipherRepository.GetByIdAsync(request.ItemId.Value); - if (cipher != null && cipher.OrganizationId == _currentContext.OrganizationId.Value) + var dateRange = request.ToDateRange(); + var result = new PagedResult(); + if (request.ActingUserId.HasValue) { - result = await _eventRepository.GetManyByCipherAsync( - cipher, dateRange.Item1, dateRange.Item2, + result = await _eventRepository.GetManyByOrganizationActingUserAsync( + _currentContext.OrganizationId.Value, request.ActingUserId.Value, dateRange.Item1, dateRange.Item2, + new PageOptions { ContinuationToken = request.ContinuationToken }); + } + else if (request.ItemId.HasValue) + { + var cipher = await _cipherRepository.GetByIdAsync(request.ItemId.Value); + if (cipher != null && cipher.OrganizationId == _currentContext.OrganizationId.Value) + { + result = await _eventRepository.GetManyByCipherAsync( + cipher, dateRange.Item1, dateRange.Item2, + new PageOptions { ContinuationToken = request.ContinuationToken }); + } + } + else + { + result = await _eventRepository.GetManyByOrganizationAsync( + _currentContext.OrganizationId.Value, dateRange.Item1, dateRange.Item2, new PageOptions { ContinuationToken = request.ContinuationToken }); } - } - else - { - result = await _eventRepository.GetManyByOrganizationAsync( - _currentContext.OrganizationId.Value, dateRange.Item1, dateRange.Item2, - new PageOptions { ContinuationToken = request.ContinuationToken }); - } - var eventResponses = result.Data.Select(e => new EventResponseModel(e)); - var response = new ListResponseModel(eventResponses, result.ContinuationToken); - return new JsonResult(response); + var eventResponses = result.Data.Select(e => new EventResponseModel(e)); + var response = new ListResponseModel(eventResponses, result.ContinuationToken); + return new JsonResult(response); + } } } diff --git a/src/Api/Public/Controllers/GroupsController.cs b/src/Api/Public/Controllers/GroupsController.cs index f65f7b9fe3..ef29db5685 100644 --- a/src/Api/Public/Controllers/GroupsController.cs +++ b/src/Api/Public/Controllers/GroupsController.cs @@ -7,176 +7,177 @@ using Bit.Core.Services; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; -namespace Bit.Api.Public.Controllers; - -[Route("public/groups")] -[Authorize("Organization")] -public class GroupsController : Controller +namespace Bit.Api.Public.Controllers { - private readonly IGroupRepository _groupRepository; - private readonly IGroupService _groupService; - private readonly ICurrentContext _currentContext; - - public GroupsController( - IGroupRepository groupRepository, - IGroupService groupService, - ICurrentContext currentContext) + [Route("public/groups")] + [Authorize("Organization")] + public class GroupsController : Controller { - _groupRepository = groupRepository; - _groupService = groupService; - _currentContext = currentContext; - } + private readonly IGroupRepository _groupRepository; + private readonly IGroupService _groupService; + private readonly ICurrentContext _currentContext; - /// - /// Retrieve a group. - /// - /// - /// Retrieves the details of an existing group. You need only supply the unique group identifier - /// that was returned upon group creation. - /// - /// The identifier of the group to be retrieved. - [HttpGet("{id}")] - [ProducesResponseType(typeof(GroupResponseModel), (int)HttpStatusCode.OK)] - [ProducesResponseType((int)HttpStatusCode.NotFound)] - public async Task Get(Guid id) - { - var groupDetails = await _groupRepository.GetByIdWithCollectionsAsync(id); - var group = groupDetails?.Item1; - if (group == null || group.OrganizationId != _currentContext.OrganizationId) + public GroupsController( + IGroupRepository groupRepository, + IGroupService groupService, + ICurrentContext currentContext) { - return new NotFoundResult(); + _groupRepository = groupRepository; + _groupService = groupService; + _currentContext = currentContext; } - var response = new GroupResponseModel(group, groupDetails.Item2); - return new JsonResult(response); - } - /// - /// Retrieve a groups's member ids - /// - /// - /// Retrieves the unique identifiers for all members that are associated with this group. You need only - /// supply the unique group identifier that was returned upon group creation. - /// - /// The identifier of the group to be retrieved. - [HttpGet("{id}/member-ids")] - [ProducesResponseType(typeof(HashSet), (int)HttpStatusCode.OK)] - [ProducesResponseType((int)HttpStatusCode.NotFound)] - public async Task GetMemberIds(Guid id) - { - var group = await _groupRepository.GetByIdAsync(id); - if (group == null || group.OrganizationId != _currentContext.OrganizationId) + /// + /// Retrieve a group. + /// + /// + /// Retrieves the details of an existing group. You need only supply the unique group identifier + /// that was returned upon group creation. + /// + /// The identifier of the group to be retrieved. + [HttpGet("{id}")] + [ProducesResponseType(typeof(GroupResponseModel), (int)HttpStatusCode.OK)] + [ProducesResponseType((int)HttpStatusCode.NotFound)] + public async Task Get(Guid id) { - return new NotFoundResult(); + var groupDetails = await _groupRepository.GetByIdWithCollectionsAsync(id); + var group = groupDetails?.Item1; + if (group == null || group.OrganizationId != _currentContext.OrganizationId) + { + return new NotFoundResult(); + } + var response = new GroupResponseModel(group, groupDetails.Item2); + return new JsonResult(response); } - var orgUserIds = await _groupRepository.GetManyUserIdsByIdAsync(id); - return new JsonResult(orgUserIds); - } - /// - /// List all groups. - /// - /// - /// Returns a list of your organization's groups. - /// Group objects listed in this call do not include information about their associated collections. - /// - [HttpGet] - [ProducesResponseType(typeof(ListResponseModel), (int)HttpStatusCode.OK)] - public async Task List() - { - var groups = await _groupRepository.GetManyByOrganizationIdAsync(_currentContext.OrganizationId.Value); - // TODO: Get all CollectionGroup associations for the organization and marry them up here for the response. - var groupResponses = groups.Select(g => new GroupResponseModel(g, null)); - var response = new ListResponseModel(groupResponses); - return new JsonResult(response); - } - - /// - /// Create a group. - /// - /// - /// Creates a new group object. - /// - /// The request model. - [HttpPost] - [ProducesResponseType(typeof(GroupResponseModel), (int)HttpStatusCode.OK)] - [ProducesResponseType(typeof(ErrorResponseModel), (int)HttpStatusCode.BadRequest)] - public async Task Post([FromBody] GroupCreateUpdateRequestModel model) - { - var group = model.ToGroup(_currentContext.OrganizationId.Value); - var associations = model.Collections?.Select(c => c.ToSelectionReadOnly()); - await _groupService.SaveAsync(group, associations); - var response = new GroupResponseModel(group, associations); - return new JsonResult(response); - } - - /// - /// Update a group. - /// - /// - /// Updates the specified group object. If a property is not provided, - /// the value of the existing property will be reset. - /// - /// The identifier of the group to be updated. - /// The request model. - [HttpPut("{id}")] - [ProducesResponseType(typeof(GroupResponseModel), (int)HttpStatusCode.OK)] - [ProducesResponseType(typeof(ErrorResponseModel), (int)HttpStatusCode.BadRequest)] - [ProducesResponseType((int)HttpStatusCode.NotFound)] - public async Task Put(Guid id, [FromBody] GroupCreateUpdateRequestModel model) - { - var existingGroup = await _groupRepository.GetByIdAsync(id); - if (existingGroup == null || existingGroup.OrganizationId != _currentContext.OrganizationId) + /// + /// Retrieve a groups's member ids + /// + /// + /// Retrieves the unique identifiers for all members that are associated with this group. You need only + /// supply the unique group identifier that was returned upon group creation. + /// + /// The identifier of the group to be retrieved. + [HttpGet("{id}/member-ids")] + [ProducesResponseType(typeof(HashSet), (int)HttpStatusCode.OK)] + [ProducesResponseType((int)HttpStatusCode.NotFound)] + public async Task GetMemberIds(Guid id) { - return new NotFoundResult(); + var group = await _groupRepository.GetByIdAsync(id); + if (group == null || group.OrganizationId != _currentContext.OrganizationId) + { + return new NotFoundResult(); + } + var orgUserIds = await _groupRepository.GetManyUserIdsByIdAsync(id); + return new JsonResult(orgUserIds); } - var updatedGroup = model.ToGroup(existingGroup); - var associations = model.Collections?.Select(c => c.ToSelectionReadOnly()); - await _groupService.SaveAsync(updatedGroup, associations); - var response = new GroupResponseModel(updatedGroup, associations); - return new JsonResult(response); - } - /// - /// Update a group's members. - /// - /// - /// Updates the specified group's member associations. - /// - /// The identifier of the group to be updated. - /// The request model. - [HttpPut("{id}/member-ids")] - [ProducesResponseType((int)HttpStatusCode.OK)] - [ProducesResponseType(typeof(ErrorResponseModel), (int)HttpStatusCode.BadRequest)] - [ProducesResponseType((int)HttpStatusCode.NotFound)] - public async Task PutMemberIds(Guid id, [FromBody] UpdateMemberIdsRequestModel model) - { - var existingGroup = await _groupRepository.GetByIdAsync(id); - if (existingGroup == null || existingGroup.OrganizationId != _currentContext.OrganizationId) + /// + /// List all groups. + /// + /// + /// Returns a list of your organization's groups. + /// Group objects listed in this call do not include information about their associated collections. + /// + [HttpGet] + [ProducesResponseType(typeof(ListResponseModel), (int)HttpStatusCode.OK)] + public async Task List() { - return new NotFoundResult(); + var groups = await _groupRepository.GetManyByOrganizationIdAsync(_currentContext.OrganizationId.Value); + // TODO: Get all CollectionGroup associations for the organization and marry them up here for the response. + var groupResponses = groups.Select(g => new GroupResponseModel(g, null)); + var response = new ListResponseModel(groupResponses); + return new JsonResult(response); } - await _groupRepository.UpdateUsersAsync(existingGroup.Id, model.MemberIds); - return new OkResult(); - } - /// - /// Delete a group. - /// - /// - /// Permanently deletes a group. This cannot be undone. - /// - /// The identifier of the group to be deleted. - [HttpDelete("{id}")] - [ProducesResponseType((int)HttpStatusCode.OK)] - [ProducesResponseType((int)HttpStatusCode.NotFound)] - public async Task Delete(Guid id) - { - var group = await _groupRepository.GetByIdAsync(id); - if (group == null || group.OrganizationId != _currentContext.OrganizationId) + /// + /// Create a group. + /// + /// + /// Creates a new group object. + /// + /// The request model. + [HttpPost] + [ProducesResponseType(typeof(GroupResponseModel), (int)HttpStatusCode.OK)] + [ProducesResponseType(typeof(ErrorResponseModel), (int)HttpStatusCode.BadRequest)] + public async Task Post([FromBody] GroupCreateUpdateRequestModel model) { - return new NotFoundResult(); + var group = model.ToGroup(_currentContext.OrganizationId.Value); + var associations = model.Collections?.Select(c => c.ToSelectionReadOnly()); + await _groupService.SaveAsync(group, associations); + var response = new GroupResponseModel(group, associations); + return new JsonResult(response); + } + + /// + /// Update a group. + /// + /// + /// Updates the specified group object. If a property is not provided, + /// the value of the existing property will be reset. + /// + /// The identifier of the group to be updated. + /// The request model. + [HttpPut("{id}")] + [ProducesResponseType(typeof(GroupResponseModel), (int)HttpStatusCode.OK)] + [ProducesResponseType(typeof(ErrorResponseModel), (int)HttpStatusCode.BadRequest)] + [ProducesResponseType((int)HttpStatusCode.NotFound)] + public async Task Put(Guid id, [FromBody] GroupCreateUpdateRequestModel model) + { + var existingGroup = await _groupRepository.GetByIdAsync(id); + if (existingGroup == null || existingGroup.OrganizationId != _currentContext.OrganizationId) + { + return new NotFoundResult(); + } + var updatedGroup = model.ToGroup(existingGroup); + var associations = model.Collections?.Select(c => c.ToSelectionReadOnly()); + await _groupService.SaveAsync(updatedGroup, associations); + var response = new GroupResponseModel(updatedGroup, associations); + return new JsonResult(response); + } + + /// + /// Update a group's members. + /// + /// + /// Updates the specified group's member associations. + /// + /// The identifier of the group to be updated. + /// The request model. + [HttpPut("{id}/member-ids")] + [ProducesResponseType((int)HttpStatusCode.OK)] + [ProducesResponseType(typeof(ErrorResponseModel), (int)HttpStatusCode.BadRequest)] + [ProducesResponseType((int)HttpStatusCode.NotFound)] + public async Task PutMemberIds(Guid id, [FromBody] UpdateMemberIdsRequestModel model) + { + var existingGroup = await _groupRepository.GetByIdAsync(id); + if (existingGroup == null || existingGroup.OrganizationId != _currentContext.OrganizationId) + { + return new NotFoundResult(); + } + await _groupRepository.UpdateUsersAsync(existingGroup.Id, model.MemberIds); + return new OkResult(); + } + + /// + /// Delete a group. + /// + /// + /// Permanently deletes a group. This cannot be undone. + /// + /// The identifier of the group to be deleted. + [HttpDelete("{id}")] + [ProducesResponseType((int)HttpStatusCode.OK)] + [ProducesResponseType((int)HttpStatusCode.NotFound)] + public async Task Delete(Guid id) + { + var group = await _groupRepository.GetByIdAsync(id); + if (group == null || group.OrganizationId != _currentContext.OrganizationId) + { + return new NotFoundResult(); + } + await _groupRepository.DeleteAsync(group); + return new OkResult(); } - await _groupRepository.DeleteAsync(group); - return new OkResult(); } } diff --git a/src/Api/Public/Controllers/MembersController.cs b/src/Api/Public/Controllers/MembersController.cs index 5ea079ee39..bfe7f86b77 100644 --- a/src/Api/Public/Controllers/MembersController.cs +++ b/src/Api/Public/Controllers/MembersController.cs @@ -8,226 +8,227 @@ using Bit.Core.Services; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; -namespace Bit.Api.Public.Controllers; - -[Route("public/members")] -[Authorize("Organization")] -public class MembersController : Controller +namespace Bit.Api.Public.Controllers { - private readonly IOrganizationUserRepository _organizationUserRepository; - private readonly IGroupRepository _groupRepository; - private readonly IOrganizationService _organizationService; - private readonly IUserService _userService; - private readonly ICurrentContext _currentContext; - - public MembersController( - IOrganizationUserRepository organizationUserRepository, - IGroupRepository groupRepository, - IOrganizationService organizationService, - IUserService userService, - ICurrentContext currentContext) + [Route("public/members")] + [Authorize("Organization")] + public class MembersController : Controller { - _organizationUserRepository = organizationUserRepository; - _groupRepository = groupRepository; - _organizationService = organizationService; - _userService = userService; - _currentContext = currentContext; - } + private readonly IOrganizationUserRepository _organizationUserRepository; + private readonly IGroupRepository _groupRepository; + private readonly IOrganizationService _organizationService; + private readonly IUserService _userService; + private readonly ICurrentContext _currentContext; - /// - /// Retrieve a member. - /// - /// - /// Retrieves the details of an existing member of the organization. You need only supply the - /// unique member identifier that was returned upon member creation. - /// - /// The identifier of the member to be retrieved. - [HttpGet("{id}")] - [ProducesResponseType(typeof(MemberResponseModel), (int)HttpStatusCode.OK)] - [ProducesResponseType((int)HttpStatusCode.NotFound)] - public async Task Get(Guid id) - { - var userDetails = await _organizationUserRepository.GetDetailsByIdWithCollectionsAsync(id); - var orgUser = userDetails?.Item1; - if (orgUser == null || orgUser.OrganizationId != _currentContext.OrganizationId) + public MembersController( + IOrganizationUserRepository organizationUserRepository, + IGroupRepository groupRepository, + IOrganizationService organizationService, + IUserService userService, + ICurrentContext currentContext) { - return new NotFoundResult(); + _organizationUserRepository = organizationUserRepository; + _groupRepository = groupRepository; + _organizationService = organizationService; + _userService = userService; + _currentContext = currentContext; } - var response = new MemberResponseModel(orgUser, await _userService.TwoFactorIsEnabledAsync(orgUser), - userDetails.Item2); - return new JsonResult(response); - } - /// - /// Retrieve a member's group ids - /// - /// - /// Retrieves the unique identifiers for all groups that are associated with this member. You need only - /// supply the unique member identifier that was returned upon member creation. - /// - /// The identifier of the member to be retrieved. - [HttpGet("{id}/group-ids")] - [ProducesResponseType(typeof(HashSet), (int)HttpStatusCode.OK)] - [ProducesResponseType((int)HttpStatusCode.NotFound)] - public async Task GetGroupIds(Guid id) - { - var orgUser = await _organizationUserRepository.GetByIdAsync(id); - if (orgUser == null || orgUser.OrganizationId != _currentContext.OrganizationId) + /// + /// Retrieve a member. + /// + /// + /// Retrieves the details of an existing member of the organization. You need only supply the + /// unique member identifier that was returned upon member creation. + /// + /// The identifier of the member to be retrieved. + [HttpGet("{id}")] + [ProducesResponseType(typeof(MemberResponseModel), (int)HttpStatusCode.OK)] + [ProducesResponseType((int)HttpStatusCode.NotFound)] + public async Task Get(Guid id) { - return new NotFoundResult(); + var userDetails = await _organizationUserRepository.GetDetailsByIdWithCollectionsAsync(id); + var orgUser = userDetails?.Item1; + if (orgUser == null || orgUser.OrganizationId != _currentContext.OrganizationId) + { + return new NotFoundResult(); + } + var response = new MemberResponseModel(orgUser, await _userService.TwoFactorIsEnabledAsync(orgUser), + userDetails.Item2); + return new JsonResult(response); } - var groupIds = await _groupRepository.GetManyIdsByUserIdAsync(id); - return new JsonResult(groupIds); - } - /// - /// List all members. - /// - /// - /// Returns a list of your organization's members. - /// Member objects listed in this call do not include information about their associated collections. - /// - [HttpGet] - [ProducesResponseType(typeof(ListResponseModel), (int)HttpStatusCode.OK)] - public async Task List() - { - var users = await _organizationUserRepository.GetManyDetailsByOrganizationAsync( - _currentContext.OrganizationId.Value); - // TODO: Get all CollectionUser associations for the organization and marry them up here for the response. - var memberResponsesTasks = users.Select(async u => new MemberResponseModel(u, - await _userService.TwoFactorIsEnabledAsync(u), null)); - var memberResponses = await Task.WhenAll(memberResponsesTasks); - var response = new ListResponseModel(memberResponses); - return new JsonResult(response); - } + /// + /// Retrieve a member's group ids + /// + /// + /// Retrieves the unique identifiers for all groups that are associated with this member. You need only + /// supply the unique member identifier that was returned upon member creation. + /// + /// The identifier of the member to be retrieved. + [HttpGet("{id}/group-ids")] + [ProducesResponseType(typeof(HashSet), (int)HttpStatusCode.OK)] + [ProducesResponseType((int)HttpStatusCode.NotFound)] + public async Task GetGroupIds(Guid id) + { + var orgUser = await _organizationUserRepository.GetByIdAsync(id); + if (orgUser == null || orgUser.OrganizationId != _currentContext.OrganizationId) + { + return new NotFoundResult(); + } + var groupIds = await _groupRepository.GetManyIdsByUserIdAsync(id); + return new JsonResult(groupIds); + } - /// - /// Create a member. - /// - /// - /// Creates a new member object by inviting a user to the organization. - /// - /// The request model. - [HttpPost] - [ProducesResponseType(typeof(MemberResponseModel), (int)HttpStatusCode.OK)] - [ProducesResponseType(typeof(ErrorResponseModel), (int)HttpStatusCode.BadRequest)] - public async Task Post([FromBody] MemberCreateRequestModel model) - { - var associations = model.Collections?.Select(c => c.ToSelectionReadOnly()); - var invite = new OrganizationUserInvite + /// + /// List all members. + /// + /// + /// Returns a list of your organization's members. + /// Member objects listed in this call do not include information about their associated collections. + /// + [HttpGet] + [ProducesResponseType(typeof(ListResponseModel), (int)HttpStatusCode.OK)] + public async Task List() { - Emails = new List { model.Email }, - Type = model.Type.Value, - AccessAll = model.AccessAll.Value, - Collections = associations - }; - var user = await _organizationService.InviteUserAsync(_currentContext.OrganizationId.Value, null, - model.Email, model.Type.Value, model.AccessAll.Value, model.ExternalId, associations); - var response = new MemberResponseModel(user, associations); - return new JsonResult(response); - } + var users = await _organizationUserRepository.GetManyDetailsByOrganizationAsync( + _currentContext.OrganizationId.Value); + // TODO: Get all CollectionUser associations for the organization and marry them up here for the response. + var memberResponsesTasks = users.Select(async u => new MemberResponseModel(u, + await _userService.TwoFactorIsEnabledAsync(u), null)); + var memberResponses = await Task.WhenAll(memberResponsesTasks); + var response = new ListResponseModel(memberResponses); + return new JsonResult(response); + } - /// - /// Update a member. - /// - /// - /// Updates the specified member object. If a property is not provided, - /// the value of the existing property will be reset. - /// - /// The identifier of the member to be updated. - /// The request model. - [HttpPut("{id}")] - [ProducesResponseType(typeof(MemberResponseModel), (int)HttpStatusCode.OK)] - [ProducesResponseType(typeof(ErrorResponseModel), (int)HttpStatusCode.BadRequest)] - [ProducesResponseType((int)HttpStatusCode.NotFound)] - public async Task Put(Guid id, [FromBody] MemberUpdateRequestModel model) - { - var existingUser = await _organizationUserRepository.GetByIdAsync(id); - if (existingUser == null || existingUser.OrganizationId != _currentContext.OrganizationId) + /// + /// Create a member. + /// + /// + /// Creates a new member object by inviting a user to the organization. + /// + /// The request model. + [HttpPost] + [ProducesResponseType(typeof(MemberResponseModel), (int)HttpStatusCode.OK)] + [ProducesResponseType(typeof(ErrorResponseModel), (int)HttpStatusCode.BadRequest)] + public async Task Post([FromBody] MemberCreateRequestModel model) { - return new NotFoundResult(); + var associations = model.Collections?.Select(c => c.ToSelectionReadOnly()); + var invite = new OrganizationUserInvite + { + Emails = new List { model.Email }, + Type = model.Type.Value, + AccessAll = model.AccessAll.Value, + Collections = associations + }; + var user = await _organizationService.InviteUserAsync(_currentContext.OrganizationId.Value, null, + model.Email, model.Type.Value, model.AccessAll.Value, model.ExternalId, associations); + var response = new MemberResponseModel(user, associations); + return new JsonResult(response); } - var updatedUser = model.ToOrganizationUser(existingUser); - var associations = model.Collections?.Select(c => c.ToSelectionReadOnly()); - await _organizationService.SaveUserAsync(updatedUser, null, associations); - MemberResponseModel response = null; - if (existingUser.UserId.HasValue) - { - var existingUserDetails = await _organizationUserRepository.GetDetailsByIdAsync(id); - response = new MemberResponseModel(existingUserDetails, - await _userService.TwoFactorIsEnabledAsync(existingUserDetails), associations); - } - else - { - response = new MemberResponseModel(updatedUser, associations); - } - return new JsonResult(response); - } - /// - /// Update a member's groups. - /// - /// - /// Updates the specified member's group associations. - /// - /// The identifier of the member to be updated. - /// The request model. - [HttpPut("{id}/group-ids")] - [ProducesResponseType((int)HttpStatusCode.OK)] - [ProducesResponseType(typeof(ErrorResponseModel), (int)HttpStatusCode.BadRequest)] - [ProducesResponseType((int)HttpStatusCode.NotFound)] - public async Task PutGroupIds(Guid id, [FromBody] UpdateGroupIdsRequestModel model) - { - var existingUser = await _organizationUserRepository.GetByIdAsync(id); - if (existingUser == null || existingUser.OrganizationId != _currentContext.OrganizationId) + /// + /// Update a member. + /// + /// + /// Updates the specified member object. If a property is not provided, + /// the value of the existing property will be reset. + /// + /// The identifier of the member to be updated. + /// The request model. + [HttpPut("{id}")] + [ProducesResponseType(typeof(MemberResponseModel), (int)HttpStatusCode.OK)] + [ProducesResponseType(typeof(ErrorResponseModel), (int)HttpStatusCode.BadRequest)] + [ProducesResponseType((int)HttpStatusCode.NotFound)] + public async Task Put(Guid id, [FromBody] MemberUpdateRequestModel model) { - return new NotFoundResult(); + var existingUser = await _organizationUserRepository.GetByIdAsync(id); + if (existingUser == null || existingUser.OrganizationId != _currentContext.OrganizationId) + { + return new NotFoundResult(); + } + var updatedUser = model.ToOrganizationUser(existingUser); + var associations = model.Collections?.Select(c => c.ToSelectionReadOnly()); + await _organizationService.SaveUserAsync(updatedUser, null, associations); + MemberResponseModel response = null; + if (existingUser.UserId.HasValue) + { + var existingUserDetails = await _organizationUserRepository.GetDetailsByIdAsync(id); + response = new MemberResponseModel(existingUserDetails, + await _userService.TwoFactorIsEnabledAsync(existingUserDetails), associations); + } + else + { + response = new MemberResponseModel(updatedUser, associations); + } + return new JsonResult(response); } - await _organizationService.UpdateUserGroupsAsync(existingUser, model.GroupIds, null); - return new OkResult(); - } - /// - /// Delete a member. - /// - /// - /// Permanently deletes a member from the organization. This cannot be undone. - /// The user account will still remain. The user is only removed from the organization. - /// - /// The identifier of the member to be deleted. - [HttpDelete("{id}")] - [ProducesResponseType((int)HttpStatusCode.OK)] - [ProducesResponseType((int)HttpStatusCode.NotFound)] - public async Task Delete(Guid id) - { - var user = await _organizationUserRepository.GetByIdAsync(id); - if (user == null || user.OrganizationId != _currentContext.OrganizationId) + /// + /// Update a member's groups. + /// + /// + /// Updates the specified member's group associations. + /// + /// The identifier of the member to be updated. + /// The request model. + [HttpPut("{id}/group-ids")] + [ProducesResponseType((int)HttpStatusCode.OK)] + [ProducesResponseType(typeof(ErrorResponseModel), (int)HttpStatusCode.BadRequest)] + [ProducesResponseType((int)HttpStatusCode.NotFound)] + public async Task PutGroupIds(Guid id, [FromBody] UpdateGroupIdsRequestModel model) { - return new NotFoundResult(); + var existingUser = await _organizationUserRepository.GetByIdAsync(id); + if (existingUser == null || existingUser.OrganizationId != _currentContext.OrganizationId) + { + return new NotFoundResult(); + } + await _organizationService.UpdateUserGroupsAsync(existingUser, model.GroupIds, null); + return new OkResult(); } - await _organizationService.DeleteUserAsync(_currentContext.OrganizationId.Value, id, null); - return new OkResult(); - } - /// - /// Re-invite a member. - /// - /// - /// Re-sends the invitation email to an organization member. - /// - /// The identifier of the member to re-invite. - [HttpPost("{id}/reinvite")] - [ProducesResponseType((int)HttpStatusCode.OK)] - [ProducesResponseType(typeof(ErrorResponseModel), (int)HttpStatusCode.BadRequest)] - [ProducesResponseType((int)HttpStatusCode.NotFound)] - public async Task PostReinvite(Guid id) - { - var existingUser = await _organizationUserRepository.GetByIdAsync(id); - if (existingUser == null || existingUser.OrganizationId != _currentContext.OrganizationId) + /// + /// Delete a member. + /// + /// + /// Permanently deletes a member from the organization. This cannot be undone. + /// The user account will still remain. The user is only removed from the organization. + /// + /// The identifier of the member to be deleted. + [HttpDelete("{id}")] + [ProducesResponseType((int)HttpStatusCode.OK)] + [ProducesResponseType((int)HttpStatusCode.NotFound)] + public async Task Delete(Guid id) { - return new NotFoundResult(); + var user = await _organizationUserRepository.GetByIdAsync(id); + if (user == null || user.OrganizationId != _currentContext.OrganizationId) + { + return new NotFoundResult(); + } + await _organizationService.DeleteUserAsync(_currentContext.OrganizationId.Value, id, null); + return new OkResult(); + } + + /// + /// Re-invite a member. + /// + /// + /// Re-sends the invitation email to an organization member. + /// + /// The identifier of the member to re-invite. + [HttpPost("{id}/reinvite")] + [ProducesResponseType((int)HttpStatusCode.OK)] + [ProducesResponseType(typeof(ErrorResponseModel), (int)HttpStatusCode.BadRequest)] + [ProducesResponseType((int)HttpStatusCode.NotFound)] + public async Task PostReinvite(Guid id) + { + var existingUser = await _organizationUserRepository.GetByIdAsync(id); + if (existingUser == null || existingUser.OrganizationId != _currentContext.OrganizationId) + { + return new NotFoundResult(); + } + await _organizationService.ResendInviteAsync(_currentContext.OrganizationId.Value, null, id); + return new OkResult(); } - await _organizationService.ResendInviteAsync(_currentContext.OrganizationId.Value, null, id); - return new OkResult(); } } diff --git a/src/Api/Public/Controllers/OrganizationController.cs b/src/Api/Public/Controllers/OrganizationController.cs index ce0683b95d..978811d39d 100644 --- a/src/Api/Public/Controllers/OrganizationController.cs +++ b/src/Api/Public/Controllers/OrganizationController.cs @@ -8,51 +8,52 @@ using Bit.Core.Settings; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; -namespace Bit.Api.Public.Controllers; - -[Route("public/organization")] -[Authorize("Organization")] -public class OrganizationController : Controller +namespace Bit.Api.Public.Controllers { - private readonly IOrganizationService _organizationService; - private readonly ICurrentContext _currentContext; - private readonly GlobalSettings _globalSettings; - - public OrganizationController( - IOrganizationService organizationService, - ICurrentContext currentContext, - GlobalSettings globalSettings) + [Route("public/organization")] + [Authorize("Organization")] + public class OrganizationController : Controller { - _organizationService = organizationService; - _currentContext = currentContext; - _globalSettings = globalSettings; - } + private readonly IOrganizationService _organizationService; + private readonly ICurrentContext _currentContext; + private readonly GlobalSettings _globalSettings; - /// - /// Import members and groups. - /// - /// - /// Import members and groups from an external system. - /// - /// The request model. - [HttpPost("import")] - [ProducesResponseType(typeof(MemberResponseModel), (int)HttpStatusCode.OK)] - [ProducesResponseType(typeof(ErrorResponseModel), (int)HttpStatusCode.BadRequest)] - public async Task Import([FromBody] OrganizationImportRequestModel model) - { - if (!_globalSettings.SelfHosted && !model.LargeImport && - (model.Groups.Count() > 2000 || model.Members.Count(u => !u.Deleted) > 2000)) + public OrganizationController( + IOrganizationService organizationService, + ICurrentContext currentContext, + GlobalSettings globalSettings) { - throw new BadRequestException("You cannot import this much data at once."); + _organizationService = organizationService; + _currentContext = currentContext; + _globalSettings = globalSettings; } - await _organizationService.ImportAsync( - _currentContext.OrganizationId.Value, - null, - model.Groups.Select(g => g.ToImportedGroup(_currentContext.OrganizationId.Value)), - model.Members.Where(u => !u.Deleted).Select(u => u.ToImportedOrganizationUser()), - model.Members.Where(u => u.Deleted).Select(u => u.ExternalId), - model.OverwriteExisting.GetValueOrDefault()); - return new OkResult(); + /// + /// Import members and groups. + /// + /// + /// Import members and groups from an external system. + /// + /// The request model. + [HttpPost("import")] + [ProducesResponseType(typeof(MemberResponseModel), (int)HttpStatusCode.OK)] + [ProducesResponseType(typeof(ErrorResponseModel), (int)HttpStatusCode.BadRequest)] + public async Task Import([FromBody] OrganizationImportRequestModel model) + { + if (!_globalSettings.SelfHosted && !model.LargeImport && + (model.Groups.Count() > 2000 || model.Members.Count(u => !u.Deleted) > 2000)) + { + throw new BadRequestException("You cannot import this much data at once."); + } + + await _organizationService.ImportAsync( + _currentContext.OrganizationId.Value, + null, + model.Groups.Select(g => g.ToImportedGroup(_currentContext.OrganizationId.Value)), + model.Members.Where(u => !u.Deleted).Select(u => u.ToImportedOrganizationUser()), + model.Members.Where(u => u.Deleted).Select(u => u.ExternalId), + model.OverwriteExisting.GetValueOrDefault()); + return new OkResult(); + } } } diff --git a/src/Api/Public/Controllers/PoliciesController.cs b/src/Api/Public/Controllers/PoliciesController.cs index b208938ed7..65556ebac7 100644 --- a/src/Api/Public/Controllers/PoliciesController.cs +++ b/src/Api/Public/Controllers/PoliciesController.cs @@ -8,97 +8,98 @@ using Bit.Core.Services; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; -namespace Bit.Api.Public.Controllers; - -[Route("public/policies")] -[Authorize("Organization")] -public class PoliciesController : Controller +namespace Bit.Api.Public.Controllers { - private readonly IPolicyRepository _policyRepository; - private readonly IPolicyService _policyService; - private readonly IUserService _userService; - private readonly IOrganizationService _organizationService; - private readonly ICurrentContext _currentContext; - - public PoliciesController( - IPolicyRepository policyRepository, - IPolicyService policyService, - IUserService userService, - IOrganizationService organizationService, - ICurrentContext currentContext) + [Route("public/policies")] + [Authorize("Organization")] + public class PoliciesController : Controller { - _policyRepository = policyRepository; - _policyService = policyService; - _userService = userService; - _organizationService = organizationService; - _currentContext = currentContext; - } + private readonly IPolicyRepository _policyRepository; + private readonly IPolicyService _policyService; + private readonly IUserService _userService; + private readonly IOrganizationService _organizationService; + private readonly ICurrentContext _currentContext; - /// - /// Retrieve a policy. - /// - /// - /// Retrieves the details of a policy. - /// - /// The type of policy to be retrieved. - [HttpGet("{type}")] - [ProducesResponseType(typeof(GroupResponseModel), (int)HttpStatusCode.OK)] - [ProducesResponseType((int)HttpStatusCode.NotFound)] - public async Task Get(PolicyType type) - { - var policy = await _policyRepository.GetByOrganizationIdTypeAsync( - _currentContext.OrganizationId.Value, type); - if (policy == null) + public PoliciesController( + IPolicyRepository policyRepository, + IPolicyService policyService, + IUserService userService, + IOrganizationService organizationService, + ICurrentContext currentContext) { - return new NotFoundResult(); + _policyRepository = policyRepository; + _policyService = policyService; + _userService = userService; + _organizationService = organizationService; + _currentContext = currentContext; } - var response = new PolicyResponseModel(policy); - return new JsonResult(response); - } - /// - /// List all policies. - /// - /// - /// Returns a list of your organization's policies. - /// - [HttpGet] - [ProducesResponseType(typeof(ListResponseModel), (int)HttpStatusCode.OK)] - public async Task List() - { - var policies = await _policyRepository.GetManyByOrganizationIdAsync(_currentContext.OrganizationId.Value); - var policyResponses = policies.Select(p => new PolicyResponseModel(p)); - var response = new ListResponseModel(policyResponses); - return new JsonResult(response); - } + /// + /// Retrieve a policy. + /// + /// + /// Retrieves the details of a policy. + /// + /// The type of policy to be retrieved. + [HttpGet("{type}")] + [ProducesResponseType(typeof(GroupResponseModel), (int)HttpStatusCode.OK)] + [ProducesResponseType((int)HttpStatusCode.NotFound)] + public async Task Get(PolicyType type) + { + var policy = await _policyRepository.GetByOrganizationIdTypeAsync( + _currentContext.OrganizationId.Value, type); + if (policy == null) + { + return new NotFoundResult(); + } + var response = new PolicyResponseModel(policy); + return new JsonResult(response); + } - /// - /// Update a policy. - /// - /// - /// Updates the specified policy. If a property is not provided, - /// the value of the existing property will be reset. - /// - /// The type of policy to be updated. - /// The request model. - [HttpPut("{id}")] - [ProducesResponseType(typeof(PolicyResponseModel), (int)HttpStatusCode.OK)] - [ProducesResponseType(typeof(ErrorResponseModel), (int)HttpStatusCode.BadRequest)] - [ProducesResponseType((int)HttpStatusCode.NotFound)] - public async Task Put(PolicyType type, [FromBody] PolicyUpdateRequestModel model) - { - var policy = await _policyRepository.GetByOrganizationIdTypeAsync( - _currentContext.OrganizationId.Value, type); - if (policy == null) + /// + /// List all policies. + /// + /// + /// Returns a list of your organization's policies. + /// + [HttpGet] + [ProducesResponseType(typeof(ListResponseModel), (int)HttpStatusCode.OK)] + public async Task List() { - policy = model.ToPolicy(_currentContext.OrganizationId.Value); + var policies = await _policyRepository.GetManyByOrganizationIdAsync(_currentContext.OrganizationId.Value); + var policyResponses = policies.Select(p => new PolicyResponseModel(p)); + var response = new ListResponseModel(policyResponses); + return new JsonResult(response); } - else + + /// + /// Update a policy. + /// + /// + /// Updates the specified policy. If a property is not provided, + /// the value of the existing property will be reset. + /// + /// The type of policy to be updated. + /// The request model. + [HttpPut("{id}")] + [ProducesResponseType(typeof(PolicyResponseModel), (int)HttpStatusCode.OK)] + [ProducesResponseType(typeof(ErrorResponseModel), (int)HttpStatusCode.BadRequest)] + [ProducesResponseType((int)HttpStatusCode.NotFound)] + public async Task Put(PolicyType type, [FromBody] PolicyUpdateRequestModel model) { - policy = model.ToPolicy(policy); + var policy = await _policyRepository.GetByOrganizationIdTypeAsync( + _currentContext.OrganizationId.Value, type); + if (policy == null) + { + policy = model.ToPolicy(_currentContext.OrganizationId.Value); + } + else + { + policy = model.ToPolicy(policy); + } + await _policyService.SaveAsync(policy, _userService, _organizationService, null); + var response = new PolicyResponseModel(policy); + return new JsonResult(response); } - await _policyService.SaveAsync(policy, _userService, _organizationService, null); - var response = new PolicyResponseModel(policy); - return new JsonResult(response); } } diff --git a/src/Api/Startup.cs b/src/Api/Startup.cs index 20b707f5d2..7bebb1a9b2 100644 --- a/src/Api/Startup.cs +++ b/src/Api/Startup.cs @@ -17,208 +17,209 @@ using Microsoft.Extensions.DependencyInjection.Extensions; using Bit.Commercial.Core.Utilities; #endif -namespace Bit.Api; - -public class Startup +namespace Bit.Api { - public Startup(IWebHostEnvironment env, IConfiguration configuration) + public class Startup { - CultureInfo.DefaultThreadCurrentCulture = new CultureInfo("en-US"); - Configuration = configuration; - Environment = env; - } - - public IConfiguration Configuration { get; private set; } - public IWebHostEnvironment Environment { get; set; } - - public void ConfigureServices(IServiceCollection services) - { - // Options - services.AddOptions(); - - // Settings - var globalSettings = services.AddGlobalSettingsServices(Configuration, Environment); - if (!globalSettings.SelfHosted) + public Startup(IWebHostEnvironment env, IConfiguration configuration) { - services.Configure(Configuration.GetSection("IpRateLimitOptions")); - services.Configure(Configuration.GetSection("IpRateLimitPolicies")); + CultureInfo.DefaultThreadCurrentCulture = new CultureInfo("en-US"); + Configuration = configuration; + Environment = env; } - // Data Protection - services.AddCustomDataProtectionServices(Environment, globalSettings); + public IConfiguration Configuration { get; private set; } + public IWebHostEnvironment Environment { get; set; } - // Event Grid - if (!string.IsNullOrWhiteSpace(globalSettings.EventGridKey)) + public void ConfigureServices(IServiceCollection services) { - ApiHelpers.EventGridKey = globalSettings.EventGridKey; - } + // Options + services.AddOptions(); - // Stripe Billing - StripeConfiguration.ApiKey = globalSettings.Stripe.ApiKey; - StripeConfiguration.MaxNetworkRetries = globalSettings.Stripe.MaxNetworkRetries; - - // Repositories - services.AddSqlServerRepositories(globalSettings); - - // Context - services.AddScoped(); - services.TryAddSingleton(); - - // Caching - services.AddMemoryCache(); - services.AddDistributedCache(globalSettings); - - // BitPay - services.AddSingleton(); - - if (!globalSettings.SelfHosted) - { - services.AddIpRateLimiting(globalSettings); - } - - // Identity - services.AddCustomIdentityServices(globalSettings); - services.AddIdentityAuthenticationServices(globalSettings, Environment, config => - { - config.AddPolicy("Application", policy => + // Settings + var globalSettings = services.AddGlobalSettingsServices(Configuration, Environment); + if (!globalSettings.SelfHosted) { - policy.RequireAuthenticatedUser(); - policy.RequireClaim(JwtClaimTypes.AuthenticationMethod, "Application", "external"); - policy.RequireClaim(JwtClaimTypes.Scope, "api"); - }); - config.AddPolicy("Web", policy => - { - policy.RequireAuthenticatedUser(); - policy.RequireClaim(JwtClaimTypes.AuthenticationMethod, "Application", "external"); - policy.RequireClaim(JwtClaimTypes.Scope, "api"); - policy.RequireClaim(JwtClaimTypes.ClientId, "web"); - }); - config.AddPolicy("Push", policy => - { - policy.RequireAuthenticatedUser(); - policy.RequireClaim(JwtClaimTypes.Scope, "api.push"); - }); - config.AddPolicy("Licensing", policy => - { - policy.RequireAuthenticatedUser(); - policy.RequireClaim(JwtClaimTypes.Scope, "api.licensing"); - }); - config.AddPolicy("Organization", policy => - { - policy.RequireAuthenticatedUser(); - policy.RequireClaim(JwtClaimTypes.Scope, "api.organization"); - }); - config.AddPolicy("Installation", policy => - { - policy.RequireAuthenticatedUser(); - policy.RequireClaim(JwtClaimTypes.Scope, "api.installation"); - }); - }); + services.Configure(Configuration.GetSection("IpRateLimitOptions")); + services.Configure(Configuration.GetSection("IpRateLimitPolicies")); + } - services.AddScoped(); + // Data Protection + services.AddCustomDataProtectionServices(Environment, globalSettings); - // Services - services.AddBaseServices(globalSettings); - services.AddDefaultServices(globalSettings); - services.AddCoreLocalizationServices(); + // Event Grid + if (!string.IsNullOrWhiteSpace(globalSettings.EventGridKey)) + { + ApiHelpers.EventGridKey = globalSettings.EventGridKey; + } + + // Stripe Billing + StripeConfiguration.ApiKey = globalSettings.Stripe.ApiKey; + StripeConfiguration.MaxNetworkRetries = globalSettings.Stripe.MaxNetworkRetries; + + // Repositories + services.AddSqlServerRepositories(globalSettings); + + // Context + services.AddScoped(); + services.TryAddSingleton(); + + // Caching + services.AddMemoryCache(); + services.AddDistributedCache(globalSettings); + + // BitPay + services.AddSingleton(); + + if (!globalSettings.SelfHosted) + { + services.AddIpRateLimiting(globalSettings); + } + + // Identity + services.AddCustomIdentityServices(globalSettings); + services.AddIdentityAuthenticationServices(globalSettings, Environment, config => + { + config.AddPolicy("Application", policy => + { + policy.RequireAuthenticatedUser(); + policy.RequireClaim(JwtClaimTypes.AuthenticationMethod, "Application", "external"); + policy.RequireClaim(JwtClaimTypes.Scope, "api"); + }); + config.AddPolicy("Web", policy => + { + policy.RequireAuthenticatedUser(); + policy.RequireClaim(JwtClaimTypes.AuthenticationMethod, "Application", "external"); + policy.RequireClaim(JwtClaimTypes.Scope, "api"); + policy.RequireClaim(JwtClaimTypes.ClientId, "web"); + }); + config.AddPolicy("Push", policy => + { + policy.RequireAuthenticatedUser(); + policy.RequireClaim(JwtClaimTypes.Scope, "api.push"); + }); + config.AddPolicy("Licensing", policy => + { + policy.RequireAuthenticatedUser(); + policy.RequireClaim(JwtClaimTypes.Scope, "api.licensing"); + }); + config.AddPolicy("Organization", policy => + { + policy.RequireAuthenticatedUser(); + policy.RequireClaim(JwtClaimTypes.Scope, "api.organization"); + }); + config.AddPolicy("Installation", policy => + { + policy.RequireAuthenticatedUser(); + policy.RequireClaim(JwtClaimTypes.Scope, "api.installation"); + }); + }); + + services.AddScoped(); + + // Services + services.AddBaseServices(globalSettings); + services.AddDefaultServices(globalSettings); + services.AddCoreLocalizationServices(); #if OSS - services.AddOosServices(); + services.AddOosServices(); #else - services.AddCommCoreServices(); + services.AddCommCoreServices(); #endif - // MVC - services.AddMvc(config => - { - config.Conventions.Add(new ApiExplorerGroupConvention()); - config.Conventions.Add(new PublicApiControllersModelConvention()); - }); - - services.AddSwagger(globalSettings); - Jobs.JobsHostedService.AddJobsServices(services, globalSettings.SelfHosted); - services.AddHostedService(); - - if (CoreHelpers.SettingHasValue(globalSettings.ServiceBus.ConnectionString) && - CoreHelpers.SettingHasValue(globalSettings.ServiceBus.ApplicationCacheTopicName)) - { - services.AddHostedService(); - } - } - - public void Configure( - IApplicationBuilder app, - IWebHostEnvironment env, - IHostApplicationLifetime appLifetime, - GlobalSettings globalSettings, - ILogger logger) - { - IdentityModelEventSource.ShowPII = true; - app.UseSerilog(env, appLifetime, globalSettings); - - // Add general security headers - app.UseMiddleware(); - - // Default Middleware - app.UseDefaultMiddleware(env, globalSettings); - - if (!globalSettings.SelfHosted) - { - // Rate limiting - app.UseMiddleware(); - } - else - { - app.UseForwardedHeaders(globalSettings); - } - - // Add localization - app.UseCoreLocalization(); - - // Add static files to the request pipeline. - app.UseStaticFiles(); - - // Add routing - app.UseRouting(); - - // Add Cors - app.UseCors(policy => policy.SetIsOriginAllowed(o => CoreHelpers.IsCorsOriginAllowed(o, globalSettings)) - .AllowAnyMethod().AllowAnyHeader().AllowCredentials()); - - // Add authentication and authorization to the request pipeline. - app.UseAuthentication(); - app.UseAuthorization(); - - // Add current context - app.UseMiddleware(); - - // Add endpoints to the request pipeline. - app.UseEndpoints(endpoints => endpoints.MapDefaultControllerRoute()); - - // Add Swagger - if (Environment.IsDevelopment() || globalSettings.SelfHosted) - { - app.UseSwagger(config => + // MVC + services.AddMvc(config => { - config.RouteTemplate = "specs/{documentName}/swagger.json"; - config.PreSerializeFilters.Add((swaggerDoc, httpReq) => - swaggerDoc.Servers = new List - { - new OpenApiServer { Url = globalSettings.BaseServiceUri.Api } - }); + config.Conventions.Add(new ApiExplorerGroupConvention()); + config.Conventions.Add(new PublicApiControllersModelConvention()); }); - app.UseSwaggerUI(config => + + services.AddSwagger(globalSettings); + Jobs.JobsHostedService.AddJobsServices(services, globalSettings.SelfHosted); + services.AddHostedService(); + + if (CoreHelpers.SettingHasValue(globalSettings.ServiceBus.ConnectionString) && + CoreHelpers.SettingHasValue(globalSettings.ServiceBus.ApplicationCacheTopicName)) { - config.DocumentTitle = "Bitwarden API Documentation"; - config.RoutePrefix = "docs"; - config.SwaggerEndpoint($"{globalSettings.BaseServiceUri.Api}/specs/public/swagger.json", - "Bitwarden Public API"); - config.OAuthClientId("accountType.id"); - config.OAuthClientSecret("secretKey"); - }); + services.AddHostedService(); + } } - // Log startup - logger.LogInformation(Constants.BypassFiltersEventId, globalSettings.ProjectName + " started."); + public void Configure( + IApplicationBuilder app, + IWebHostEnvironment env, + IHostApplicationLifetime appLifetime, + GlobalSettings globalSettings, + ILogger logger) + { + IdentityModelEventSource.ShowPII = true; + app.UseSerilog(env, appLifetime, globalSettings); + + // Add general security headers + app.UseMiddleware(); + + // Default Middleware + app.UseDefaultMiddleware(env, globalSettings); + + if (!globalSettings.SelfHosted) + { + // Rate limiting + app.UseMiddleware(); + } + else + { + app.UseForwardedHeaders(globalSettings); + } + + // Add localization + app.UseCoreLocalization(); + + // Add static files to the request pipeline. + app.UseStaticFiles(); + + // Add routing + app.UseRouting(); + + // Add Cors + app.UseCors(policy => policy.SetIsOriginAllowed(o => CoreHelpers.IsCorsOriginAllowed(o, globalSettings)) + .AllowAnyMethod().AllowAnyHeader().AllowCredentials()); + + // Add authentication and authorization to the request pipeline. + app.UseAuthentication(); + app.UseAuthorization(); + + // Add current context + app.UseMiddleware(); + + // Add endpoints to the request pipeline. + app.UseEndpoints(endpoints => endpoints.MapDefaultControllerRoute()); + + // Add Swagger + if (Environment.IsDevelopment() || globalSettings.SelfHosted) + { + app.UseSwagger(config => + { + config.RouteTemplate = "specs/{documentName}/swagger.json"; + config.PreSerializeFilters.Add((swaggerDoc, httpReq) => + swaggerDoc.Servers = new List + { + new OpenApiServer { Url = globalSettings.BaseServiceUri.Api } + }); + }); + app.UseSwaggerUI(config => + { + config.DocumentTitle = "Bitwarden API Documentation"; + config.RoutePrefix = "docs"; + config.SwaggerEndpoint($"{globalSettings.BaseServiceUri.Api}/specs/public/swagger.json", + "Bitwarden Public API"); + config.OAuthClientId("accountType.id"); + config.OAuthClientSecret("secretKey"); + }); + } + + // Log startup + logger.LogInformation(Constants.BypassFiltersEventId, globalSettings.ProjectName + " started."); + } } } diff --git a/src/Api/Utilities/ApiExplorerGroupConvention.cs b/src/Api/Utilities/ApiExplorerGroupConvention.cs index 42b1c8d6e7..5b8d7559ae 100644 --- a/src/Api/Utilities/ApiExplorerGroupConvention.cs +++ b/src/Api/Utilities/ApiExplorerGroupConvention.cs @@ -1,12 +1,13 @@ using Microsoft.AspNetCore.Mvc.ApplicationModels; -namespace Bit.Api.Utilities; - -public class ApiExplorerGroupConvention : IControllerModelConvention +namespace Bit.Api.Utilities { - public void Apply(ControllerModel controller) + public class ApiExplorerGroupConvention : IControllerModelConvention { - var controllerNamespace = controller.ControllerType.Namespace; - controller.ApiExplorer.GroupName = controllerNamespace.Contains(".Public.") ? "public" : "internal"; + public void Apply(ControllerModel controller) + { + var controllerNamespace = controller.ControllerType.Namespace; + controller.ApiExplorer.GroupName = controllerNamespace.Contains(".Public.") ? "public" : "internal"; + } } } diff --git a/src/Api/Utilities/ApiHelpers.cs b/src/Api/Utilities/ApiHelpers.cs index 58097089f4..920c15dd2c 100644 --- a/src/Api/Utilities/ApiHelpers.cs +++ b/src/Api/Utilities/ApiHelpers.cs @@ -4,69 +4,70 @@ using Azure.Messaging.EventGrid.SystemEvents; using Bit.Core.Utilities; using Microsoft.AspNetCore.Mvc; -namespace Bit.Api.Utilities; - -public static class ApiHelpers +namespace Bit.Api.Utilities { - public static string EventGridKey { get; set; } - public async static Task ReadJsonFileFromBody(HttpContext httpContext, IFormFile file, long maxSize = 51200) + public static class ApiHelpers { - T obj = default(T); - if (file != null && httpContext.Request.ContentLength.HasValue && httpContext.Request.ContentLength.Value <= maxSize) + public static string EventGridKey { get; set; } + public async static Task ReadJsonFileFromBody(HttpContext httpContext, IFormFile file, long maxSize = 51200) { - try + T obj = default(T); + if (file != null && httpContext.Request.ContentLength.HasValue && httpContext.Request.ContentLength.Value <= maxSize) { - using var stream = file.OpenReadStream(); - obj = await JsonSerializer.DeserializeAsync(stream, JsonHelpers.IgnoreCase); - } - catch { } - } - - return obj; - } - - /// - /// Validates Azure event subscription and calls the appropriate event handler. Responds HttpOk. - /// - /// HttpRequest received from Azure - /// Dictionary of eventType strings and their associated handlers. - /// OkObjectResult - /// Reference https://docs.microsoft.com/en-us/azure/event-grid/receive-events - public async static Task HandleAzureEvents(HttpRequest request, - Dictionary> eventTypeHandlers) - { - var queryKey = request.Query["key"]; - - if (!CoreHelpers.FixedTimeEquals(queryKey, EventGridKey)) - { - return new UnauthorizedObjectResult("Authentication failed. Please use a valid key."); - } - - var response = string.Empty; - var requestData = await BinaryData.FromStreamAsync(request.Body); - var eventGridEvents = EventGridEvent.ParseMany(requestData); - foreach (var eventGridEvent in eventGridEvents) - { - if (eventGridEvent.TryGetSystemEventData(out object systemEvent)) - { - if (systemEvent is SubscriptionValidationEventData eventData) + try { - // Might want to enable additional validation: subject, topic etc. - var responseData = new SubscriptionValidationResponse() - { - ValidationResponse = eventData.ValidationCode - }; + using var stream = file.OpenReadStream(); + obj = await JsonSerializer.DeserializeAsync(stream, JsonHelpers.IgnoreCase); + } + catch { } + } - return new OkObjectResult(responseData); + return obj; + } + + /// + /// Validates Azure event subscription and calls the appropriate event handler. Responds HttpOk. + /// + /// HttpRequest received from Azure + /// Dictionary of eventType strings and their associated handlers. + /// OkObjectResult + /// Reference https://docs.microsoft.com/en-us/azure/event-grid/receive-events + public async static Task HandleAzureEvents(HttpRequest request, + Dictionary> eventTypeHandlers) + { + var queryKey = request.Query["key"]; + + if (!CoreHelpers.FixedTimeEquals(queryKey, EventGridKey)) + { + return new UnauthorizedObjectResult("Authentication failed. Please use a valid key."); + } + + var response = string.Empty; + var requestData = await BinaryData.FromStreamAsync(request.Body); + var eventGridEvents = EventGridEvent.ParseMany(requestData); + foreach (var eventGridEvent in eventGridEvents) + { + if (eventGridEvent.TryGetSystemEventData(out object systemEvent)) + { + if (systemEvent is SubscriptionValidationEventData eventData) + { + // Might want to enable additional validation: subject, topic etc. + var responseData = new SubscriptionValidationResponse() + { + ValidationResponse = eventData.ValidationCode + }; + + return new OkObjectResult(responseData); + } + } + + if (eventTypeHandlers.ContainsKey(eventGridEvent.EventType)) + { + await eventTypeHandlers[eventGridEvent.EventType](eventGridEvent); } } - if (eventTypeHandlers.ContainsKey(eventGridEvent.EventType)) - { - await eventTypeHandlers[eventGridEvent.EventType](eventGridEvent); - } + return new OkObjectResult(response); } - - return new OkObjectResult(response); } } diff --git a/src/Api/Utilities/DisableFormValueModelBindingAttribute.cs b/src/Api/Utilities/DisableFormValueModelBindingAttribute.cs index e0c6045461..27169a57c4 100644 --- a/src/Api/Utilities/DisableFormValueModelBindingAttribute.cs +++ b/src/Api/Utilities/DisableFormValueModelBindingAttribute.cs @@ -1,20 +1,21 @@ using Microsoft.AspNetCore.Mvc.Filters; using Microsoft.AspNetCore.Mvc.ModelBinding; -namespace Bit.Api.Utilities; - -[AttributeUsage(AttributeTargets.Class | AttributeTargets.Method)] -public class DisableFormValueModelBindingAttribute : Attribute, IResourceFilter +namespace Bit.Api.Utilities { - public void OnResourceExecuting(ResourceExecutingContext context) + [AttributeUsage(AttributeTargets.Class | AttributeTargets.Method)] + public class DisableFormValueModelBindingAttribute : Attribute, IResourceFilter { - var factories = context.ValueProviderFactories; - factories.RemoveType(); - factories.RemoveType(); - factories.RemoveType(); - } + public void OnResourceExecuting(ResourceExecutingContext context) + { + var factories = context.ValueProviderFactories; + factories.RemoveType(); + factories.RemoveType(); + factories.RemoveType(); + } - public void OnResourceExecuted(ResourceExecutedContext context) - { + public void OnResourceExecuted(ResourceExecutedContext context) + { + } } } diff --git a/src/Api/Utilities/ExceptionHandlerFilterAttribute.cs b/src/Api/Utilities/ExceptionHandlerFilterAttribute.cs index 422bfa62d4..846bcda251 100644 --- a/src/Api/Utilities/ExceptionHandlerFilterAttribute.cs +++ b/src/Api/Utilities/ExceptionHandlerFilterAttribute.cs @@ -6,116 +6,117 @@ using Microsoft.IdentityModel.Tokens; using Stripe; using InternalApi = Bit.Core.Models.Api; -namespace Bit.Api.Utilities; - -public class ExceptionHandlerFilterAttribute : ExceptionFilterAttribute +namespace Bit.Api.Utilities { - private readonly bool _publicApi; - - public ExceptionHandlerFilterAttribute(bool publicApi) + public class ExceptionHandlerFilterAttribute : ExceptionFilterAttribute { - _publicApi = publicApi; - } + private readonly bool _publicApi; - public override void OnException(ExceptionContext context) - { - var errorMessage = "An error has occurred."; - - var exception = context.Exception; - if (exception == null) + public ExceptionHandlerFilterAttribute(bool publicApi) { - // Should never happen. - return; + _publicApi = publicApi; } - ErrorResponseModel publicErrorModel = null; - InternalApi.ErrorResponseModel internalErrorModel = null; - if (exception is BadRequestException badRequestException) + public override void OnException(ExceptionContext context) { - context.HttpContext.Response.StatusCode = 400; - if (badRequestException.ModelState != null) + var errorMessage = "An error has occurred."; + + var exception = context.Exception; + if (exception == null) { - if (_publicApi) + // Should never happen. + return; + } + + ErrorResponseModel publicErrorModel = null; + InternalApi.ErrorResponseModel internalErrorModel = null; + if (exception is BadRequestException badRequestException) + { + context.HttpContext.Response.StatusCode = 400; + if (badRequestException.ModelState != null) { - publicErrorModel = new ErrorResponseModel(badRequestException.ModelState); + if (_publicApi) + { + publicErrorModel = new ErrorResponseModel(badRequestException.ModelState); + } + else + { + internalErrorModel = new InternalApi.ErrorResponseModel(badRequestException.ModelState); + } } else { - internalErrorModel = new InternalApi.ErrorResponseModel(badRequestException.ModelState); + errorMessage = badRequestException.Message; } } + else if (exception is StripeException stripeException && stripeException?.StripeError?.Type == "card_error") + { + context.HttpContext.Response.StatusCode = 400; + if (_publicApi) + { + publicErrorModel = new ErrorResponseModel(stripeException.StripeError.Param, + stripeException.Message); + } + else + { + internalErrorModel = new InternalApi.ErrorResponseModel(stripeException.StripeError.Param, + stripeException.Message); + } + } + else if (exception is GatewayException) + { + errorMessage = exception.Message; + context.HttpContext.Response.StatusCode = 400; + } + else if (exception is NotSupportedException && !string.IsNullOrWhiteSpace(exception.Message)) + { + errorMessage = exception.Message; + context.HttpContext.Response.StatusCode = 400; + } + else if (exception is ApplicationException) + { + context.HttpContext.Response.StatusCode = 402; + } + else if (exception is NotFoundException) + { + errorMessage = "Resource not found."; + context.HttpContext.Response.StatusCode = 404; + } + else if (exception is SecurityTokenValidationException) + { + errorMessage = "Invalid token."; + context.HttpContext.Response.StatusCode = 403; + } + else if (exception is UnauthorizedAccessException) + { + errorMessage = "Unauthorized."; + context.HttpContext.Response.StatusCode = 401; + } else { - errorMessage = badRequestException.Message; + var logger = context.HttpContext.RequestServices.GetRequiredService>(); + logger.LogError(0, exception, exception.Message); + errorMessage = "An unhandled server error has occurred."; + context.HttpContext.Response.StatusCode = 500; } - } - else if (exception is StripeException stripeException && stripeException?.StripeError?.Type == "card_error") - { - context.HttpContext.Response.StatusCode = 400; + if (_publicApi) { - publicErrorModel = new ErrorResponseModel(stripeException.StripeError.Param, - stripeException.Message); + var errorModel = publicErrorModel ?? new ErrorResponseModel(errorMessage); + context.Result = new ObjectResult(errorModel); } else { - internalErrorModel = new InternalApi.ErrorResponseModel(stripeException.StripeError.Param, - stripeException.Message); + var errorModel = internalErrorModel ?? new InternalApi.ErrorResponseModel(errorMessage); + var env = context.HttpContext.RequestServices.GetRequiredService(); + if (env.IsDevelopment()) + { + errorModel.ExceptionMessage = exception.Message; + errorModel.ExceptionStackTrace = exception.StackTrace; + errorModel.InnerExceptionMessage = exception?.InnerException?.Message; + } + context.Result = new ObjectResult(errorModel); } } - else if (exception is GatewayException) - { - errorMessage = exception.Message; - context.HttpContext.Response.StatusCode = 400; - } - else if (exception is NotSupportedException && !string.IsNullOrWhiteSpace(exception.Message)) - { - errorMessage = exception.Message; - context.HttpContext.Response.StatusCode = 400; - } - else if (exception is ApplicationException) - { - context.HttpContext.Response.StatusCode = 402; - } - else if (exception is NotFoundException) - { - errorMessage = "Resource not found."; - context.HttpContext.Response.StatusCode = 404; - } - else if (exception is SecurityTokenValidationException) - { - errorMessage = "Invalid token."; - context.HttpContext.Response.StatusCode = 403; - } - else if (exception is UnauthorizedAccessException) - { - errorMessage = "Unauthorized."; - context.HttpContext.Response.StatusCode = 401; - } - else - { - var logger = context.HttpContext.RequestServices.GetRequiredService>(); - logger.LogError(0, exception, exception.Message); - errorMessage = "An unhandled server error has occurred."; - context.HttpContext.Response.StatusCode = 500; - } - - if (_publicApi) - { - var errorModel = publicErrorModel ?? new ErrorResponseModel(errorMessage); - context.Result = new ObjectResult(errorModel); - } - else - { - var errorModel = internalErrorModel ?? new InternalApi.ErrorResponseModel(errorMessage); - var env = context.HttpContext.RequestServices.GetRequiredService(); - if (env.IsDevelopment()) - { - errorModel.ExceptionMessage = exception.Message; - errorModel.ExceptionStackTrace = exception.StackTrace; - errorModel.InnerExceptionMessage = exception?.InnerException?.Message; - } - context.Result = new ObjectResult(errorModel); - } } } diff --git a/src/Api/Utilities/ModelStateValidationFilterAttribute.cs b/src/Api/Utilities/ModelStateValidationFilterAttribute.cs index 3fe4f748fb..d6803f91c2 100644 --- a/src/Api/Utilities/ModelStateValidationFilterAttribute.cs +++ b/src/Api/Utilities/ModelStateValidationFilterAttribute.cs @@ -3,26 +3,27 @@ using Microsoft.AspNetCore.Mvc; using Microsoft.AspNetCore.Mvc.Filters; using InternalApi = Bit.Core.Models.Api; -namespace Bit.Api.Utilities; - -public class ModelStateValidationFilterAttribute : SharedWeb.Utilities.ModelStateValidationFilterAttribute +namespace Bit.Api.Utilities { - private readonly bool _publicApi; - - public ModelStateValidationFilterAttribute(bool publicApi) + public class ModelStateValidationFilterAttribute : SharedWeb.Utilities.ModelStateValidationFilterAttribute { - _publicApi = publicApi; - } + private readonly bool _publicApi; - protected override void OnModelStateInvalid(ActionExecutingContext context) - { - if (_publicApi) + public ModelStateValidationFilterAttribute(bool publicApi) { - context.Result = new BadRequestObjectResult(new ErrorResponseModel(context.ModelState)); + _publicApi = publicApi; } - else + + protected override void OnModelStateInvalid(ActionExecutingContext context) { - context.Result = new BadRequestObjectResult(new InternalApi.ErrorResponseModel(context.ModelState)); + if (_publicApi) + { + context.Result = new BadRequestObjectResult(new ErrorResponseModel(context.ModelState)); + } + else + { + context.Result = new BadRequestObjectResult(new InternalApi.ErrorResponseModel(context.ModelState)); + } } } } diff --git a/src/Api/Utilities/MultipartFormDataHelper.cs b/src/Api/Utilities/MultipartFormDataHelper.cs index c7ca42d507..a3e4b1967a 100644 --- a/src/Api/Utilities/MultipartFormDataHelper.cs +++ b/src/Api/Utilities/MultipartFormDataHelper.cs @@ -5,41 +5,75 @@ using Microsoft.AspNetCore.WebUtilities; using Microsoft.Extensions.Primitives; using Microsoft.Net.Http.Headers; -namespace Bit.Api.Utilities; - -public static class MultipartFormDataHelper +namespace Bit.Api.Utilities { - private static readonly FormOptions _defaultFormOptions = new FormOptions(); - - public static async Task GetFileAsync(this HttpRequest request, Func callback) + public static class MultipartFormDataHelper { - var boundary = GetBoundary(MediaTypeHeaderValue.Parse(request.ContentType), - _defaultFormOptions.MultipartBoundaryLengthLimit); - var reader = new MultipartReader(boundary, request.Body); + private static readonly FormOptions _defaultFormOptions = new FormOptions(); - var firstSection = await reader.ReadNextSectionAsync(); - if (firstSection != null) + public static async Task GetFileAsync(this HttpRequest request, Func callback) { - if (ContentDispositionHeaderValue.TryParse(firstSection.ContentDisposition, out var firstContent)) + var boundary = GetBoundary(MediaTypeHeaderValue.Parse(request.ContentType), + _defaultFormOptions.MultipartBoundaryLengthLimit); + var reader = new MultipartReader(boundary, request.Body); + + var firstSection = await reader.ReadNextSectionAsync(); + if (firstSection != null) { - if (HasFileContentDisposition(firstContent)) + if (ContentDispositionHeaderValue.TryParse(firstSection.ContentDisposition, out var firstContent)) { - // Old style with just data - var fileName = HeaderUtilities.RemoveQuotes(firstContent.FileName).ToString(); - using (firstSection.Body) + if (HasFileContentDisposition(firstContent)) { - await callback(firstSection.Body, fileName, null); + // Old style with just data + var fileName = HeaderUtilities.RemoveQuotes(firstContent.FileName).ToString(); + using (firstSection.Body) + { + await callback(firstSection.Body, fileName, null); + } + } + else if (HasDispositionName(firstContent, "key")) + { + // New style with key, then data + string key = null; + using (var sr = new StreamReader(firstSection.Body)) + { + key = await sr.ReadToEndAsync(); + } + + var secondSection = await reader.ReadNextSectionAsync(); + if (secondSection != null) + { + if (ContentDispositionHeaderValue.TryParse(secondSection.ContentDisposition, + out var secondContent) && HasFileContentDisposition(secondContent)) + { + var fileName = HeaderUtilities.RemoveQuotes(secondContent.FileName).ToString(); + using (secondSection.Body) + { + await callback(secondSection.Body, fileName, key); + } + } + + secondSection = null; + } } } - else if (HasDispositionName(firstContent, "key")) - { - // New style with key, then data - string key = null; - using (var sr = new StreamReader(firstSection.Body)) - { - key = await sr.ReadToEndAsync(); - } + firstSection = null; + } + } + + public static async Task GetSendFileAsync(this HttpRequest request, Func callback) + { + var boundary = GetBoundary(MediaTypeHeaderValue.Parse(request.ContentType), + _defaultFormOptions.MultipartBoundaryLengthLimit); + var reader = new MultipartReader(boundary, request.Body); + + var firstSection = await reader.ReadNextSectionAsync(); + if (firstSection != null) + { + if (ContentDispositionHeaderValue.TryParse(firstSection.ContentDisposition, out _)) + { var secondSection = await reader.ReadNextSectionAsync(); if (secondSection != null) { @@ -49,102 +83,69 @@ public static class MultipartFormDataHelper var fileName = HeaderUtilities.RemoveQuotes(secondContent.FileName).ToString(); using (secondSection.Body) { - await callback(secondSection.Body, fileName, key); + var model = await JsonSerializer.DeserializeAsync(firstSection.Body); + await callback(secondSection.Body, fileName, model); } } secondSection = null; } + } + + firstSection = null; } - - firstSection = null; } - } - public static async Task GetSendFileAsync(this HttpRequest request, Func callback) - { - var boundary = GetBoundary(MediaTypeHeaderValue.Parse(request.ContentType), - _defaultFormOptions.MultipartBoundaryLengthLimit); - var reader = new MultipartReader(boundary, request.Body); - - var firstSection = await reader.ReadNextSectionAsync(); - if (firstSection != null) + public static async Task GetFileAsync(this HttpRequest request, Func callback) { - if (ContentDispositionHeaderValue.TryParse(firstSection.ContentDisposition, out _)) + var boundary = GetBoundary(MediaTypeHeaderValue.Parse(request.ContentType), + _defaultFormOptions.MultipartBoundaryLengthLimit); + var reader = new MultipartReader(boundary, request.Body); + + var dataSection = await reader.ReadNextSectionAsync(); + if (dataSection != null) { - var secondSection = await reader.ReadNextSectionAsync(); - if (secondSection != null) + if (ContentDispositionHeaderValue.TryParse(dataSection.ContentDisposition, out var dataContent) + && HasFileContentDisposition(dataContent)) { - if (ContentDispositionHeaderValue.TryParse(secondSection.ContentDisposition, - out var secondContent) && HasFileContentDisposition(secondContent)) + using (dataSection.Body) { - var fileName = HeaderUtilities.RemoveQuotes(secondContent.FileName).ToString(); - using (secondSection.Body) - { - var model = await JsonSerializer.DeserializeAsync(firstSection.Body); - await callback(secondSection.Body, fileName, model); - } + await callback(dataSection.Body); } - - secondSection = null; } - + dataSection = null; } - - firstSection = null; } - } - public static async Task GetFileAsync(this HttpRequest request, Func callback) - { - var boundary = GetBoundary(MediaTypeHeaderValue.Parse(request.ContentType), - _defaultFormOptions.MultipartBoundaryLengthLimit); - var reader = new MultipartReader(boundary, request.Body); - var dataSection = await reader.ReadNextSectionAsync(); - if (dataSection != null) + private static string GetBoundary(MediaTypeHeaderValue contentType, int lengthLimit) { - if (ContentDispositionHeaderValue.TryParse(dataSection.ContentDisposition, out var dataContent) - && HasFileContentDisposition(dataContent)) + var boundary = HeaderUtilities.RemoveQuotes(contentType.Boundary); + if (StringSegment.IsNullOrEmpty(boundary)) { - using (dataSection.Body) - { - await callback(dataSection.Body); - } + throw new InvalidDataException("Missing content-type boundary."); } - dataSection = null; + + if (boundary.Length > lengthLimit) + { + throw new InvalidDataException($"Multipart boundary length limit {lengthLimit} exceeded."); + } + + return boundary.ToString(); } - } - - private static string GetBoundary(MediaTypeHeaderValue contentType, int lengthLimit) - { - var boundary = HeaderUtilities.RemoveQuotes(contentType.Boundary); - if (StringSegment.IsNullOrEmpty(boundary)) + private static bool HasFileContentDisposition(ContentDispositionHeaderValue content) { - throw new InvalidDataException("Missing content-type boundary."); + // Content-Disposition: form-data; name="data"; filename="Misc 002.jpg" + return content != null && content.DispositionType.Equals("form-data") && + (!StringSegment.IsNullOrEmpty(content.FileName) || !StringSegment.IsNullOrEmpty(content.FileNameStar)); } - if (boundary.Length > lengthLimit) + private static bool HasDispositionName(ContentDispositionHeaderValue content, string name) { - throw new InvalidDataException($"Multipart boundary length limit {lengthLimit} exceeded."); + // Content-Disposition: form-data; name="key"; + return content != null && content.DispositionType.Equals("form-data") && content.Name == name; } - - return boundary.ToString(); - } - - private static bool HasFileContentDisposition(ContentDispositionHeaderValue content) - { - // Content-Disposition: form-data; name="data"; filename="Misc 002.jpg" - return content != null && content.DispositionType.Equals("form-data") && - (!StringSegment.IsNullOrEmpty(content.FileName) || !StringSegment.IsNullOrEmpty(content.FileNameStar)); - } - - private static bool HasDispositionName(ContentDispositionHeaderValue content, string name) - { - // Content-Disposition: form-data; name="key"; - return content != null && content.DispositionType.Equals("form-data") && content.Name == name; } } diff --git a/src/Api/Utilities/PublicApiControllersModelConvention.cs b/src/Api/Utilities/PublicApiControllersModelConvention.cs index a7fabb0319..64101148e1 100644 --- a/src/Api/Utilities/PublicApiControllersModelConvention.cs +++ b/src/Api/Utilities/PublicApiControllersModelConvention.cs @@ -1,14 +1,15 @@ using Microsoft.AspNetCore.Mvc.ApplicationModels; -namespace Bit.Api.Utilities; - -public class PublicApiControllersModelConvention : IControllerModelConvention +namespace Bit.Api.Utilities { - public void Apply(ControllerModel controller) + public class PublicApiControllersModelConvention : IControllerModelConvention { - var controllerNamespace = controller.ControllerType.Namespace; - var publicApi = controllerNamespace.Contains(".Public."); - controller.Filters.Add(new ExceptionHandlerFilterAttribute(publicApi)); - controller.Filters.Add(new ModelStateValidationFilterAttribute(publicApi)); + public void Apply(ControllerModel controller) + { + var controllerNamespace = controller.ControllerType.Namespace; + var publicApi = controllerNamespace.Contains(".Public."); + controller.Filters.Add(new ExceptionHandlerFilterAttribute(publicApi)); + controller.Filters.Add(new ModelStateValidationFilterAttribute(publicApi)); + } } } diff --git a/src/Api/Utilities/SecretsManagerAttribute.cs b/src/Api/Utilities/SecretsManagerAttribute.cs index 87540c56e4..44ba465861 100644 --- a/src/Api/Utilities/SecretsManagerAttribute.cs +++ b/src/Api/Utilities/SecretsManagerAttribute.cs @@ -1,20 +1,22 @@ using Microsoft.AspNetCore.Mvc; using Microsoft.AspNetCore.Mvc.Filters; -namespace Bit.Api.Utilities; - -[AttributeUsage(AttributeTargets.Class | AttributeTargets.Method)] -public class SecretsManagerAttribute : Attribute, IResourceFilter +namespace Bit.Api.Utilities { - public void OnResourceExecuting(ResourceExecutingContext context) - { - var env = context.HttpContext.RequestServices.GetService(); - if (!env.IsDevelopment()) - { - context.Result = new NotFoundResult(); - } - } - public void OnResourceExecuted(ResourceExecutedContext context) { } + [AttributeUsage(AttributeTargets.Class | AttributeTargets.Method)] + public class SecretsManagerAttribute : Attribute, IResourceFilter + { + public void OnResourceExecuting(ResourceExecutingContext context) + { + var env = context.HttpContext.RequestServices.GetService(); + if (!env.IsDevelopment()) + { + context.Result = new NotFoundResult(); + } + } + + public void OnResourceExecuted(ResourceExecutedContext context) { } + } } diff --git a/src/Api/Utilities/ServiceCollectionExtensions.cs b/src/Api/Utilities/ServiceCollectionExtensions.cs index ff0ff0705e..4e57d51650 100644 --- a/src/Api/Utilities/ServiceCollectionExtensions.cs +++ b/src/Api/Utilities/ServiceCollectionExtensions.cs @@ -1,71 +1,72 @@ using Bit.Core.Settings; using Microsoft.OpenApi.Models; -namespace Bit.Api.Utilities; - -public static class ServiceCollectionExtensions +namespace Bit.Api.Utilities { - public static void AddSwagger(this IServiceCollection services, GlobalSettings globalSettings) + public static class ServiceCollectionExtensions { - services.AddSwaggerGen(config => + public static void AddSwagger(this IServiceCollection services, GlobalSettings globalSettings) { - config.SwaggerDoc("public", new OpenApiInfo + services.AddSwaggerGen(config => { - Title = "Bitwarden Public API", - Version = "latest", - Contact = new OpenApiContact + config.SwaggerDoc("public", new OpenApiInfo { - Name = "Bitwarden Support", - Url = new Uri("https://bitwarden.com"), - Email = "support@bitwarden.com" - }, - Description = "The Bitwarden public APIs.", - License = new OpenApiLicense - { - Name = "GNU Affero General Public License v3.0", - Url = new Uri("https://github.com/bitwarden/server/blob/master/LICENSE.txt") - } - }); - config.SwaggerDoc("internal", new OpenApiInfo { Title = "Bitwarden Internal API", Version = "latest" }); - - config.AddSecurityDefinition("OAuth2 Client Credentials", new OpenApiSecurityScheme - { - Type = SecuritySchemeType.OAuth2, - Flows = new OpenApiOAuthFlows - { - ClientCredentials = new OpenApiOAuthFlow + Title = "Bitwarden Public API", + Version = "latest", + Contact = new OpenApiContact { - TokenUrl = new Uri($"{globalSettings.BaseServiceUri.Identity}/connect/token"), - Scopes = new Dictionary - { - { "api.organization", "Organization APIs" }, - }, - } - }, - }); - - config.AddSecurityRequirement(new OpenApiSecurityRequirement - { - { - new OpenApiSecurityScheme - { - Reference = new OpenApiReference - { - Type = ReferenceType.SecurityScheme, - Id = "OAuth2 Client Credentials" - }, + Name = "Bitwarden Support", + Url = new Uri("https://bitwarden.com"), + Email = "support@bitwarden.com" }, - new[] { "api.organization" } - } + Description = "The Bitwarden public APIs.", + License = new OpenApiLicense + { + Name = "GNU Affero General Public License v3.0", + Url = new Uri("https://github.com/bitwarden/server/blob/master/LICENSE.txt") + } + }); + config.SwaggerDoc("internal", new OpenApiInfo { Title = "Bitwarden Internal API", Version = "latest" }); + + config.AddSecurityDefinition("OAuth2 Client Credentials", new OpenApiSecurityScheme + { + Type = SecuritySchemeType.OAuth2, + Flows = new OpenApiOAuthFlows + { + ClientCredentials = new OpenApiOAuthFlow + { + TokenUrl = new Uri($"{globalSettings.BaseServiceUri.Identity}/connect/token"), + Scopes = new Dictionary + { + { "api.organization", "Organization APIs" }, + }, + } + }, + }); + + config.AddSecurityRequirement(new OpenApiSecurityRequirement + { + { + new OpenApiSecurityScheme + { + Reference = new OpenApiReference + { + Type = ReferenceType.SecurityScheme, + Id = "OAuth2 Client Credentials" + }, + }, + new[] { "api.organization" } + } + }); + + config.DescribeAllParametersInCamelCase(); + // config.UseReferencedDefinitionsForEnums(); + + var apiFilePath = Path.Combine(AppContext.BaseDirectory, "Api.xml"); + config.IncludeXmlComments(apiFilePath, true); + var coreFilePath = Path.Combine(AppContext.BaseDirectory, "Core.xml"); + config.IncludeXmlComments(coreFilePath); }); - - config.DescribeAllParametersInCamelCase(); - // config.UseReferencedDefinitionsForEnums(); - - var apiFilePath = Path.Combine(AppContext.BaseDirectory, "Api.xml"); - config.IncludeXmlComments(apiFilePath, true); - var coreFilePath = Path.Combine(AppContext.BaseDirectory, "Core.xml"); - config.IncludeXmlComments(coreFilePath); - }); + } } } diff --git a/src/Billing/BillingSettings.cs b/src/Billing/BillingSettings.cs index 5be6b205fe..0e61775ee4 100644 --- a/src/Billing/BillingSettings.cs +++ b/src/Billing/BillingSettings.cs @@ -1,22 +1,23 @@ -namespace Bit.Billing; - -public class BillingSettings +namespace Bit.Billing { - public virtual string JobsKey { get; set; } - public virtual string StripeWebhookKey { get; set; } - public virtual string StripeWebhookSecret { get; set; } - public virtual bool StripeEventParseThrowMismatch { get; set; } = true; - public virtual string BitPayWebhookKey { get; set; } - public virtual string AppleWebhookKey { get; set; } - public virtual string FreshdeskWebhookKey { get; set; } - public virtual string FreshdeskApiKey { get; set; } - public virtual string FreshsalesApiKey { get; set; } - public virtual PayPalSettings PayPal { get; set; } = new PayPalSettings(); - - public class PayPalSettings + public class BillingSettings { - public virtual bool Production { get; set; } - public virtual string BusinessId { get; set; } - public virtual string WebhookKey { get; set; } + public virtual string JobsKey { get; set; } + public virtual string StripeWebhookKey { get; set; } + public virtual string StripeWebhookSecret { get; set; } + public virtual bool StripeEventParseThrowMismatch { get; set; } = true; + public virtual string BitPayWebhookKey { get; set; } + public virtual string AppleWebhookKey { get; set; } + public virtual string FreshdeskWebhookKey { get; set; } + public virtual string FreshdeskApiKey { get; set; } + public virtual string FreshsalesApiKey { get; set; } + public virtual PayPalSettings PayPal { get; set; } = new PayPalSettings(); + + public class PayPalSettings + { + public virtual bool Production { get; set; } + public virtual string BusinessId { get; set; } + public virtual string WebhookKey { get; set; } + } } } diff --git a/src/Billing/Constants/HandledStripeWebhook.cs b/src/Billing/Constants/HandledStripeWebhook.cs index f40b370f4d..08d6daafcd 100644 --- a/src/Billing/Constants/HandledStripeWebhook.cs +++ b/src/Billing/Constants/HandledStripeWebhook.cs @@ -1,13 +1,14 @@ -namespace Bit.Billing.Constants; - -public static class HandledStripeWebhook +namespace Bit.Billing.Constants { - public static string SubscriptionDeleted => "customer.subscription.deleted"; - public static string SubscriptionUpdated => "customer.subscription.updated"; - public static string UpcomingInvoice => "invoice.upcoming"; - public static string ChargeSucceeded => "charge.succeeded"; - public static string ChargeRefunded => "charge.refunded"; - public static string PaymentSucceeded => "invoice.payment_succeeded"; - public static string PaymentFailed => "invoice.payment_failed"; - public static string InvoiceCreated => "invoice.created"; + public static class HandledStripeWebhook + { + public static string SubscriptionDeleted => "customer.subscription.deleted"; + public static string SubscriptionUpdated => "customer.subscription.updated"; + public static string UpcomingInvoice => "invoice.upcoming"; + public static string ChargeSucceeded => "charge.succeeded"; + public static string ChargeRefunded => "charge.refunded"; + public static string PaymentSucceeded => "invoice.payment_succeeded"; + public static string PaymentFailed => "invoice.payment_failed"; + public static string InvoiceCreated => "invoice.created"; + } } diff --git a/src/Billing/Controllers/AppleController.cs b/src/Billing/Controllers/AppleController.cs index 1bcbbf2ad6..dc8c827868 100644 --- a/src/Billing/Controllers/AppleController.cs +++ b/src/Billing/Controllers/AppleController.cs @@ -4,58 +4,59 @@ using Bit.Core.Utilities; using Microsoft.AspNetCore.Mvc; using Microsoft.Extensions.Options; -namespace Bit.Billing.Controllers; - -[Route("apple")] -public class AppleController : Controller +namespace Bit.Billing.Controllers { - private readonly BillingSettings _billingSettings; - private readonly ILogger _logger; - - public AppleController( - IOptions billingSettings, - ILogger logger) + [Route("apple")] + public class AppleController : Controller { - _billingSettings = billingSettings?.Value; - _logger = logger; - } + private readonly BillingSettings _billingSettings; + private readonly ILogger _logger; - [HttpPost("iap")] - public async Task PostIap() - { - if (HttpContext?.Request?.Query == null) + public AppleController( + IOptions billingSettings, + ILogger logger) { - return new BadRequestResult(); + _billingSettings = billingSettings?.Value; + _logger = logger; } - var key = HttpContext.Request.Query.ContainsKey("key") ? - HttpContext.Request.Query["key"].ToString() : null; - if (!CoreHelpers.FixedTimeEquals(key, _billingSettings.AppleWebhookKey)) + [HttpPost("iap")] + public async Task PostIap() { - return new BadRequestResult(); - } + if (HttpContext?.Request?.Query == null) + { + return new BadRequestResult(); + } - string body = null; - using (var reader = new StreamReader(HttpContext.Request.Body, Encoding.UTF8)) - { - body = await reader.ReadToEndAsync(); - } + var key = HttpContext.Request.Query.ContainsKey("key") ? + HttpContext.Request.Query["key"].ToString() : null; + if (!CoreHelpers.FixedTimeEquals(key, _billingSettings.AppleWebhookKey)) + { + return new BadRequestResult(); + } - if (string.IsNullOrWhiteSpace(body)) - { - return new BadRequestResult(); - } + string body = null; + using (var reader = new StreamReader(HttpContext.Request.Body, Encoding.UTF8)) + { + body = await reader.ReadToEndAsync(); + } - try - { - var json = JsonSerializer.Serialize(JsonSerializer.Deserialize(body), JsonHelpers.Indented); - _logger.LogInformation(Bit.Core.Constants.BypassFiltersEventId, "Apple IAP Notification:\n\n{0}", json); - return new OkResult(); - } - catch (Exception e) - { - _logger.LogError(e, "Error processing IAP status notification."); - return new BadRequestResult(); + if (string.IsNullOrWhiteSpace(body)) + { + return new BadRequestResult(); + } + + try + { + var json = JsonSerializer.Serialize(JsonSerializer.Deserialize(body), JsonHelpers.Indented); + _logger.LogInformation(Bit.Core.Constants.BypassFiltersEventId, "Apple IAP Notification:\n\n{0}", json); + return new OkResult(); + } + catch (Exception e) + { + _logger.LogError(e, "Error processing IAP status notification."); + return new BadRequestResult(); + } } } } diff --git a/src/Billing/Controllers/BitPayController.cs b/src/Billing/Controllers/BitPayController.cs index 539d355951..520f00a22a 100644 --- a/src/Billing/Controllers/BitPayController.cs +++ b/src/Billing/Controllers/BitPayController.cs @@ -9,199 +9,200 @@ using Bit.Core.Utilities; using Microsoft.AspNetCore.Mvc; using Microsoft.Extensions.Options; -namespace Bit.Billing.Controllers; - -[Route("bitpay")] -public class BitPayController : Controller +namespace Bit.Billing.Controllers { - private readonly BillingSettings _billingSettings; - private readonly BitPayClient _bitPayClient; - private readonly ITransactionRepository _transactionRepository; - private readonly IOrganizationRepository _organizationRepository; - private readonly IUserRepository _userRepository; - private readonly IMailService _mailService; - private readonly IPaymentService _paymentService; - private readonly ILogger _logger; - - public BitPayController( - IOptions billingSettings, - BitPayClient bitPayClient, - ITransactionRepository transactionRepository, - IOrganizationRepository organizationRepository, - IUserRepository userRepository, - IMailService mailService, - IPaymentService paymentService, - ILogger logger) + [Route("bitpay")] + public class BitPayController : Controller { - _billingSettings = billingSettings?.Value; - _bitPayClient = bitPayClient; - _transactionRepository = transactionRepository; - _organizationRepository = organizationRepository; - _userRepository = userRepository; - _mailService = mailService; - _paymentService = paymentService; - _logger = logger; - } + private readonly BillingSettings _billingSettings; + private readonly BitPayClient _bitPayClient; + private readonly ITransactionRepository _transactionRepository; + private readonly IOrganizationRepository _organizationRepository; + private readonly IUserRepository _userRepository; + private readonly IMailService _mailService; + private readonly IPaymentService _paymentService; + private readonly ILogger _logger; - [HttpPost("ipn")] - public async Task PostIpn([FromBody] BitPayEventModel model, [FromQuery] string key) - { - if (!CoreHelpers.FixedTimeEquals(key, _billingSettings.BitPayWebhookKey)) + public BitPayController( + IOptions billingSettings, + BitPayClient bitPayClient, + ITransactionRepository transactionRepository, + IOrganizationRepository organizationRepository, + IUserRepository userRepository, + IMailService mailService, + IPaymentService paymentService, + ILogger logger) { - return new BadRequestResult(); - } - if (model == null || string.IsNullOrWhiteSpace(model.Data?.Id) || - string.IsNullOrWhiteSpace(model.Event?.Name)) - { - return new BadRequestResult(); + _billingSettings = billingSettings?.Value; + _bitPayClient = bitPayClient; + _transactionRepository = transactionRepository; + _organizationRepository = organizationRepository; + _userRepository = userRepository; + _mailService = mailService; + _paymentService = paymentService; + _logger = logger; } - if (model.Event.Name != "invoice_confirmed") + [HttpPost("ipn")] + public async Task PostIpn([FromBody] BitPayEventModel model, [FromQuery] string key) { - // Only processing confirmed invoice events for now. - return new OkResult(); - } - - var invoice = await _bitPayClient.GetInvoiceAsync(model.Data.Id); - if (invoice == null) - { - // Request forged...? - _logger.LogWarning("Invoice not found. #" + model.Data.Id); - return new BadRequestResult(); - } - - if (invoice.Status != "confirmed" && invoice.Status != "completed") - { - _logger.LogWarning("Invoice status of '" + invoice.Status + "' is not acceptable. #" + invoice.Id); - return new BadRequestResult(); - } - - if (invoice.Currency != "USD") - { - // Only process USD payments - _logger.LogWarning("Non USD payment received. #" + invoice.Id); - return new OkResult(); - } - - var ids = GetIdsFromPosData(invoice); - if (!ids.Item1.HasValue && !ids.Item2.HasValue) - { - return new OkResult(); - } - - var isAccountCredit = IsAccountCredit(invoice); - if (!isAccountCredit) - { - // Only processing credits - _logger.LogWarning("Non-credit payment received. #" + invoice.Id); - return new OkResult(); - } - - var transaction = await _transactionRepository.GetByGatewayIdAsync(GatewayType.BitPay, invoice.Id); - if (transaction != null) - { - _logger.LogWarning("Already processed this invoice. #" + invoice.Id); - return new OkResult(); - } - - try - { - var tx = new Transaction + if (!CoreHelpers.FixedTimeEquals(key, _billingSettings.BitPayWebhookKey)) { - Amount = Convert.ToDecimal(invoice.Price), - CreationDate = GetTransactionDate(invoice), - OrganizationId = ids.Item1, - UserId = ids.Item2, - Type = TransactionType.Credit, - Gateway = GatewayType.BitPay, - GatewayId = invoice.Id, - PaymentMethodType = PaymentMethodType.BitPay, - Details = $"{invoice.Currency}, BitPay {invoice.Id}" - }; - await _transactionRepository.CreateAsync(tx); - - if (isAccountCredit) - { - string billingEmail = null; - if (tx.OrganizationId.HasValue) - { - var org = await _organizationRepository.GetByIdAsync(tx.OrganizationId.Value); - if (org != null) - { - billingEmail = org.BillingEmailAddress(); - if (await _paymentService.CreditAccountAsync(org, tx.Amount)) - { - await _organizationRepository.ReplaceAsync(org); - } - } - } - else - { - var user = await _userRepository.GetByIdAsync(tx.UserId.Value); - if (user != null) - { - billingEmail = user.BillingEmailAddress(); - if (await _paymentService.CreditAccountAsync(user, tx.Amount)) - { - await _userRepository.ReplaceAsync(user); - } - } - } - - if (!string.IsNullOrWhiteSpace(billingEmail)) - { - await _mailService.SendAddedCreditAsync(billingEmail, tx.Amount); - } + return new BadRequestResult(); } - } - // Catch foreign key violations because user/org could have been deleted. - catch (SqlException e) when (e.Number == 547) { } - - return new OkResult(); - } - - private bool IsAccountCredit(BitPayLight.Models.Invoice.Invoice invoice) - { - return invoice != null && invoice.PosData != null && invoice.PosData.Contains("accountCredit:1"); - } - - private DateTime GetTransactionDate(BitPayLight.Models.Invoice.Invoice invoice) - { - var transactions = invoice.Transactions?.Where(t => t.Type == null && - !string.IsNullOrWhiteSpace(t.Confirmations) && t.Confirmations != "0"); - if (transactions != null && transactions.Count() == 1) - { - return DateTime.Parse(transactions.First().ReceivedTime, CultureInfo.InvariantCulture, - DateTimeStyles.RoundtripKind); - } - return CoreHelpers.FromEpocMilliseconds(invoice.CurrentTime); - } - - public Tuple GetIdsFromPosData(BitPayLight.Models.Invoice.Invoice invoice) - { - Guid? orgId = null; - Guid? userId = null; - - if (invoice != null && !string.IsNullOrWhiteSpace(invoice.PosData) && invoice.PosData.Contains(":")) - { - var mainParts = invoice.PosData.Split(','); - foreach (var mainPart in mainParts) + if (model == null || string.IsNullOrWhiteSpace(model.Data?.Id) || + string.IsNullOrWhiteSpace(model.Event?.Name)) { - var parts = mainPart.Split(':'); - if (parts.Length > 1 && Guid.TryParse(parts[1], out var id)) + return new BadRequestResult(); + } + + if (model.Event.Name != "invoice_confirmed") + { + // Only processing confirmed invoice events for now. + return new OkResult(); + } + + var invoice = await _bitPayClient.GetInvoiceAsync(model.Data.Id); + if (invoice == null) + { + // Request forged...? + _logger.LogWarning("Invoice not found. #" + model.Data.Id); + return new BadRequestResult(); + } + + if (invoice.Status != "confirmed" && invoice.Status != "completed") + { + _logger.LogWarning("Invoice status of '" + invoice.Status + "' is not acceptable. #" + invoice.Id); + return new BadRequestResult(); + } + + if (invoice.Currency != "USD") + { + // Only process USD payments + _logger.LogWarning("Non USD payment received. #" + invoice.Id); + return new OkResult(); + } + + var ids = GetIdsFromPosData(invoice); + if (!ids.Item1.HasValue && !ids.Item2.HasValue) + { + return new OkResult(); + } + + var isAccountCredit = IsAccountCredit(invoice); + if (!isAccountCredit) + { + // Only processing credits + _logger.LogWarning("Non-credit payment received. #" + invoice.Id); + return new OkResult(); + } + + var transaction = await _transactionRepository.GetByGatewayIdAsync(GatewayType.BitPay, invoice.Id); + if (transaction != null) + { + _logger.LogWarning("Already processed this invoice. #" + invoice.Id); + return new OkResult(); + } + + try + { + var tx = new Transaction { - if (parts[0] == "userId") + Amount = Convert.ToDecimal(invoice.Price), + CreationDate = GetTransactionDate(invoice), + OrganizationId = ids.Item1, + UserId = ids.Item2, + Type = TransactionType.Credit, + Gateway = GatewayType.BitPay, + GatewayId = invoice.Id, + PaymentMethodType = PaymentMethodType.BitPay, + Details = $"{invoice.Currency}, BitPay {invoice.Id}" + }; + await _transactionRepository.CreateAsync(tx); + + if (isAccountCredit) + { + string billingEmail = null; + if (tx.OrganizationId.HasValue) { - userId = id; + var org = await _organizationRepository.GetByIdAsync(tx.OrganizationId.Value); + if (org != null) + { + billingEmail = org.BillingEmailAddress(); + if (await _paymentService.CreditAccountAsync(org, tx.Amount)) + { + await _organizationRepository.ReplaceAsync(org); + } + } } - else if (parts[0] == "organizationId") + else { - orgId = id; + var user = await _userRepository.GetByIdAsync(tx.UserId.Value); + if (user != null) + { + billingEmail = user.BillingEmailAddress(); + if (await _paymentService.CreditAccountAsync(user, tx.Amount)) + { + await _userRepository.ReplaceAsync(user); + } + } + } + + if (!string.IsNullOrWhiteSpace(billingEmail)) + { + await _mailService.SendAddedCreditAsync(billingEmail, tx.Amount); } } } + // Catch foreign key violations because user/org could have been deleted. + catch (SqlException e) when (e.Number == 547) { } + + return new OkResult(); } - return new Tuple(orgId, userId); + private bool IsAccountCredit(BitPayLight.Models.Invoice.Invoice invoice) + { + return invoice != null && invoice.PosData != null && invoice.PosData.Contains("accountCredit:1"); + } + + private DateTime GetTransactionDate(BitPayLight.Models.Invoice.Invoice invoice) + { + var transactions = invoice.Transactions?.Where(t => t.Type == null && + !string.IsNullOrWhiteSpace(t.Confirmations) && t.Confirmations != "0"); + if (transactions != null && transactions.Count() == 1) + { + return DateTime.Parse(transactions.First().ReceivedTime, CultureInfo.InvariantCulture, + DateTimeStyles.RoundtripKind); + } + return CoreHelpers.FromEpocMilliseconds(invoice.CurrentTime); + } + + public Tuple GetIdsFromPosData(BitPayLight.Models.Invoice.Invoice invoice) + { + Guid? orgId = null; + Guid? userId = null; + + if (invoice != null && !string.IsNullOrWhiteSpace(invoice.PosData) && invoice.PosData.Contains(":")) + { + var mainParts = invoice.PosData.Split(','); + foreach (var mainPart in mainParts) + { + var parts = mainPart.Split(':'); + if (parts.Length > 1 && Guid.TryParse(parts[1], out var id)) + { + if (parts[0] == "userId") + { + userId = id; + } + else if (parts[0] == "organizationId") + { + orgId = id; + } + } + } + } + + return new Tuple(orgId, userId); + } } } diff --git a/src/Billing/Controllers/FreshdeskController.cs b/src/Billing/Controllers/FreshdeskController.cs index e38a892425..7e7b0a6b41 100644 --- a/src/Billing/Controllers/FreshdeskController.cs +++ b/src/Billing/Controllers/FreshdeskController.cs @@ -8,165 +8,166 @@ using Bit.Core.Utilities; using Microsoft.AspNetCore.Mvc; using Microsoft.Extensions.Options; -namespace Bit.Billing.Controllers; - -[Route("freshdesk")] -public class FreshdeskController : Controller +namespace Bit.Billing.Controllers { - private readonly BillingSettings _billingSettings; - private readonly IUserRepository _userRepository; - private readonly IOrganizationRepository _organizationRepository; - private readonly IOrganizationUserRepository _organizationUserRepository; - private readonly ILogger _logger; - private readonly GlobalSettings _globalSettings; - private readonly IHttpClientFactory _httpClientFactory; - - public FreshdeskController( - IUserRepository userRepository, - IOrganizationRepository organizationRepository, - IOrganizationUserRepository organizationUserRepository, - IOptions billingSettings, - ILogger logger, - GlobalSettings globalSettings, - IHttpClientFactory httpClientFactory) + [Route("freshdesk")] + public class FreshdeskController : Controller { - _billingSettings = billingSettings?.Value; - _userRepository = userRepository; - _organizationRepository = organizationRepository; - _organizationUserRepository = organizationUserRepository; - _logger = logger; - _globalSettings = globalSettings; - _httpClientFactory = httpClientFactory; - } + private readonly BillingSettings _billingSettings; + private readonly IUserRepository _userRepository; + private readonly IOrganizationRepository _organizationRepository; + private readonly IOrganizationUserRepository _organizationUserRepository; + private readonly ILogger _logger; + private readonly GlobalSettings _globalSettings; + private readonly IHttpClientFactory _httpClientFactory; - [HttpPost("webhook")] - public async Task PostWebhook([FromQuery, Required] string key, - [FromBody, Required] FreshdeskWebhookModel model) - { - if (string.IsNullOrWhiteSpace(key) || !CoreHelpers.FixedTimeEquals(key, _billingSettings.FreshdeskWebhookKey)) + public FreshdeskController( + IUserRepository userRepository, + IOrganizationRepository organizationRepository, + IOrganizationUserRepository organizationUserRepository, + IOptions billingSettings, + ILogger logger, + GlobalSettings globalSettings, + IHttpClientFactory httpClientFactory) { - return new BadRequestResult(); + _billingSettings = billingSettings?.Value; + _userRepository = userRepository; + _organizationRepository = organizationRepository; + _organizationUserRepository = organizationUserRepository; + _logger = logger; + _globalSettings = globalSettings; + _httpClientFactory = httpClientFactory; } - try + [HttpPost("webhook")] + public async Task PostWebhook([FromQuery, Required] string key, + [FromBody, Required] FreshdeskWebhookModel model) { - var ticketId = model.TicketId; - var ticketContactEmail = model.TicketContactEmail; - var ticketTags = model.TicketTags; - if (string.IsNullOrWhiteSpace(ticketId) || string.IsNullOrWhiteSpace(ticketContactEmail)) + if (string.IsNullOrWhiteSpace(key) || !CoreHelpers.FixedTimeEquals(key, _billingSettings.FreshdeskWebhookKey)) { return new BadRequestResult(); } - var updateBody = new Dictionary(); - var note = string.Empty; - var customFields = new Dictionary(); - var user = await _userRepository.GetByEmailAsync(ticketContactEmail); - if (user != null) + try { - var userLink = $"{_globalSettings.BaseServiceUri.Admin}/users/edit/{user.Id}"; - note += $"
  • User, {user.Email}: {userLink}
  • "; - customFields.Add("cf_user", userLink); - var tags = new HashSet(); - if (user.Premium) + var ticketId = model.TicketId; + var ticketContactEmail = model.TicketContactEmail; + var ticketTags = model.TicketTags; + if (string.IsNullOrWhiteSpace(ticketId) || string.IsNullOrWhiteSpace(ticketContactEmail)) { - tags.Add("Premium"); + return new BadRequestResult(); } - var orgs = await _organizationRepository.GetManyByUserIdAsync(user.Id); - foreach (var org in orgs) + var updateBody = new Dictionary(); + var note = string.Empty; + var customFields = new Dictionary(); + var user = await _userRepository.GetByEmailAsync(ticketContactEmail); + if (user != null) { - var orgNote = $"{org.Name} ({org.Seats.GetValueOrDefault()}): " + - $"{_globalSettings.BaseServiceUri.Admin}/organizations/edit/{org.Id}"; - note += $"
  • Org, {orgNote}
  • "; - if (!customFields.Any(kvp => kvp.Key == "cf_org")) + var userLink = $"{_globalSettings.BaseServiceUri.Admin}/users/edit/{user.Id}"; + note += $"
  • User, {user.Email}: {userLink}
  • "; + customFields.Add("cf_user", userLink); + var tags = new HashSet(); + if (user.Premium) { - customFields.Add("cf_org", orgNote); - } - else - { - customFields["cf_org"] += $"\n{orgNote}"; + tags.Add("Premium"); } + var orgs = await _organizationRepository.GetManyByUserIdAsync(user.Id); - var planName = GetAttribute(org.PlanType).Name.Split(" ").FirstOrDefault(); - if (!string.IsNullOrWhiteSpace(planName)) + foreach (var org in orgs) { - tags.Add(string.Format("Org: {0}", planName)); - } - } - if (tags.Any()) - { - var tagsToUpdate = tags.ToList(); - if (!string.IsNullOrWhiteSpace(ticketTags)) - { - var splitTicketTags = ticketTags.Split(','); - for (var i = 0; i < splitTicketTags.Length; i++) + var orgNote = $"{org.Name} ({org.Seats.GetValueOrDefault()}): " + + $"{_globalSettings.BaseServiceUri.Admin}/organizations/edit/{org.Id}"; + note += $"
  • Org, {orgNote}
  • "; + if (!customFields.Any(kvp => kvp.Key == "cf_org")) { - tagsToUpdate.Insert(i, splitTicketTags[i]); + customFields.Add("cf_org", orgNote); + } + else + { + customFields["cf_org"] += $"\n{orgNote}"; + } + + var planName = GetAttribute(org.PlanType).Name.Split(" ").FirstOrDefault(); + if (!string.IsNullOrWhiteSpace(planName)) + { + tags.Add(string.Format("Org: {0}", planName)); } } - updateBody.Add("tags", tagsToUpdate); + if (tags.Any()) + { + var tagsToUpdate = tags.ToList(); + if (!string.IsNullOrWhiteSpace(ticketTags)) + { + var splitTicketTags = ticketTags.Split(','); + for (var i = 0; i < splitTicketTags.Length; i++) + { + tagsToUpdate.Insert(i, splitTicketTags[i]); + } + } + updateBody.Add("tags", tagsToUpdate); + } + + if (customFields.Any()) + { + updateBody.Add("custom_fields", customFields); + } + var updateRequest = new HttpRequestMessage(HttpMethod.Put, + string.Format("https://bitwarden.freshdesk.com/api/v2/tickets/{0}", ticketId)) + { + Content = JsonContent.Create(updateBody), + }; + await CallFreshdeskApiAsync(updateRequest); + + var noteBody = new Dictionary + { + { "body", $"
      {note}
    " }, + { "private", true } + }; + var noteRequest = new HttpRequestMessage(HttpMethod.Post, + string.Format("https://bitwarden.freshdesk.com/api/v2/tickets/{0}/notes", ticketId)) + { + Content = JsonContent.Create(noteBody), + }; + await CallFreshdeskApiAsync(noteRequest); } - if (customFields.Any()) + return new OkResult(); + } + catch (Exception e) + { + _logger.LogError(e, "Error processing freshdesk webhook."); + return new BadRequestResult(); + } + } + + private async Task CallFreshdeskApiAsync(HttpRequestMessage request, int retriedCount = 0) + { + try + { + var freshdeskAuthkey = Convert.ToBase64String(Encoding.UTF8.GetBytes($"{_billingSettings.FreshdeskApiKey}:X")); + var httpClient = _httpClientFactory.CreateClient("FreshdeskApi"); + request.Headers.Add("Authorization", freshdeskAuthkey); + var response = await httpClient.SendAsync(request); + if (response.StatusCode != System.Net.HttpStatusCode.TooManyRequests || retriedCount > 3) { - updateBody.Add("custom_fields", customFields); + return response; } - var updateRequest = new HttpRequestMessage(HttpMethod.Put, - string.Format("https://bitwarden.freshdesk.com/api/v2/tickets/{0}", ticketId)) - { - Content = JsonContent.Create(updateBody), - }; - await CallFreshdeskApiAsync(updateRequest); - - var noteBody = new Dictionary - { - { "body", $"
      {note}
    " }, - { "private", true } - }; - var noteRequest = new HttpRequestMessage(HttpMethod.Post, - string.Format("https://bitwarden.freshdesk.com/api/v2/tickets/{0}/notes", ticketId)) - { - Content = JsonContent.Create(noteBody), - }; - await CallFreshdeskApiAsync(noteRequest); } - - return new OkResult(); - } - catch (Exception e) - { - _logger.LogError(e, "Error processing freshdesk webhook."); - return new BadRequestResult(); - } - } - - private async Task CallFreshdeskApiAsync(HttpRequestMessage request, int retriedCount = 0) - { - try - { - var freshdeskAuthkey = Convert.ToBase64String(Encoding.UTF8.GetBytes($"{_billingSettings.FreshdeskApiKey}:X")); - var httpClient = _httpClientFactory.CreateClient("FreshdeskApi"); - request.Headers.Add("Authorization", freshdeskAuthkey); - var response = await httpClient.SendAsync(request); - if (response.StatusCode != System.Net.HttpStatusCode.TooManyRequests || retriedCount > 3) + catch { - return response; + if (retriedCount > 3) + { + throw; + } } + await Task.Delay(30000 * (retriedCount + 1)); + return await CallFreshdeskApiAsync(request, retriedCount++); } - catch - { - if (retriedCount > 3) - { - throw; - } - } - await Task.Delay(30000 * (retriedCount + 1)); - return await CallFreshdeskApiAsync(request, retriedCount++); - } - private TAttribute GetAttribute(Enum enumValue) where TAttribute : Attribute - { - return enumValue.GetType().GetMember(enumValue.ToString()).First().GetCustomAttribute(); + private TAttribute GetAttribute(Enum enumValue) where TAttribute : Attribute + { + return enumValue.GetType().GetMember(enumValue.ToString()).First().GetCustomAttribute(); + } } } diff --git a/src/Billing/Controllers/FreshsalesController.cs b/src/Billing/Controllers/FreshsalesController.cs index 95b9e25065..866b95d178 100644 --- a/src/Billing/Controllers/FreshsalesController.cs +++ b/src/Billing/Controllers/FreshsalesController.cs @@ -7,228 +7,229 @@ using Bit.Core.Utilities; using Microsoft.AspNetCore.Mvc; using Microsoft.Extensions.Options; -namespace Bit.Billing.Controllers; - -[Route("freshsales")] -public class FreshsalesController : Controller +namespace Bit.Billing.Controllers { - private readonly IUserRepository _userRepository; - private readonly IOrganizationRepository _organizationRepository; - private readonly ILogger _logger; - private readonly GlobalSettings _globalSettings; - - private readonly string _freshsalesApiKey; - - private readonly HttpClient _httpClient; - - public FreshsalesController(IUserRepository userRepository, - IOrganizationRepository organizationRepository, - IOptions billingSettings, - ILogger logger, - GlobalSettings globalSettings) + [Route("freshsales")] + public class FreshsalesController : Controller { - _userRepository = userRepository; - _organizationRepository = organizationRepository; - _logger = logger; - _globalSettings = globalSettings; + private readonly IUserRepository _userRepository; + private readonly IOrganizationRepository _organizationRepository; + private readonly ILogger _logger; + private readonly GlobalSettings _globalSettings; - _httpClient = new HttpClient + private readonly string _freshsalesApiKey; + + private readonly HttpClient _httpClient; + + public FreshsalesController(IUserRepository userRepository, + IOrganizationRepository organizationRepository, + IOptions billingSettings, + ILogger logger, + GlobalSettings globalSettings) { - BaseAddress = new Uri("https://bitwarden.freshsales.io/api/") - }; + _userRepository = userRepository; + _organizationRepository = organizationRepository; + _logger = logger; + _globalSettings = globalSettings; - _freshsalesApiKey = billingSettings.Value.FreshsalesApiKey; - - _httpClient.DefaultRequestHeaders.Authorization = new AuthenticationHeaderValue( - "Token", - $"token={_freshsalesApiKey}"); - } - - - [HttpPost("webhook")] - public async Task PostWebhook([FromHeader(Name = "Authorization")] string key, - [FromBody] CustomWebhookRequestModel request, - CancellationToken cancellationToken) - { - if (string.IsNullOrWhiteSpace(key) || !CoreHelpers.FixedTimeEquals(_freshsalesApiKey, key)) - { - return Unauthorized(); - } - - try - { - var leadResponse = await _httpClient.GetFromJsonAsync>( - $"leads/{request.LeadId}", - cancellationToken); - - var lead = leadResponse.Lead; - - var primaryEmail = lead.Emails - .Where(e => e.IsPrimary) - .FirstOrDefault(); - - if (primaryEmail == null) + _httpClient = new HttpClient { - return BadRequest(new { Message = "Lead has not primary email." }); - } - - var user = await _userRepository.GetByEmailAsync(primaryEmail.Value); - - if (user == null) - { - return NoContent(); - } - - var newTags = new HashSet(); - - if (user.Premium) - { - newTags.Add("Premium"); - } - - var noteItems = new List - { - $"User, {user.Email}: {_globalSettings.BaseServiceUri.Admin}/users/edit/{user.Id}" + BaseAddress = new Uri("https://bitwarden.freshsales.io/api/") }; - var orgs = await _organizationRepository.GetManyByUserIdAsync(user.Id); + _freshsalesApiKey = billingSettings.Value.FreshsalesApiKey; - foreach (var org in orgs) + _httpClient.DefaultRequestHeaders.Authorization = new AuthenticationHeaderValue( + "Token", + $"token={_freshsalesApiKey}"); + } + + + [HttpPost("webhook")] + public async Task PostWebhook([FromHeader(Name = "Authorization")] string key, + [FromBody] CustomWebhookRequestModel request, + CancellationToken cancellationToken) + { + if (string.IsNullOrWhiteSpace(key) || !CoreHelpers.FixedTimeEquals(_freshsalesApiKey, key)) { - noteItems.Add($"Org, {org.Name}: {_globalSettings.BaseServiceUri.Admin}/organizations/edit/{org.Id}"); - if (TryGetPlanName(org.PlanType, out var planName)) + return Unauthorized(); + } + + try + { + var leadResponse = await _httpClient.GetFromJsonAsync>( + $"leads/{request.LeadId}", + cancellationToken); + + var lead = leadResponse.Lead; + + var primaryEmail = lead.Emails + .Where(e => e.IsPrimary) + .FirstOrDefault(); + + if (primaryEmail == null) { - newTags.Add($"Org: {planName}"); + return BadRequest(new { Message = "Lead has not primary email." }); } + + var user = await _userRepository.GetByEmailAsync(primaryEmail.Value); + + if (user == null) + { + return NoContent(); + } + + var newTags = new HashSet(); + + if (user.Premium) + { + newTags.Add("Premium"); + } + + var noteItems = new List + { + $"User, {user.Email}: {_globalSettings.BaseServiceUri.Admin}/users/edit/{user.Id}" + }; + + var orgs = await _organizationRepository.GetManyByUserIdAsync(user.Id); + + foreach (var org in orgs) + { + noteItems.Add($"Org, {org.Name}: {_globalSettings.BaseServiceUri.Admin}/organizations/edit/{org.Id}"); + if (TryGetPlanName(org.PlanType, out var planName)) + { + newTags.Add($"Org: {planName}"); + } + } + + if (newTags.Any()) + { + var allTags = newTags.Concat(lead.Tags); + var updateLeadResponse = await _httpClient.PutAsJsonAsync( + $"leads/{request.LeadId}", + CreateWrapper(new { tags = allTags }), + cancellationToken); + updateLeadResponse.EnsureSuccessStatusCode(); + } + + var createNoteResponse = await _httpClient.PostAsJsonAsync( + "notes", + CreateNoteRequestModel(request.LeadId, string.Join('\n', noteItems)), cancellationToken); + createNoteResponse.EnsureSuccessStatusCode(); + return NoContent(); } - - if (newTags.Any()) + catch (Exception ex) { - var allTags = newTags.Concat(lead.Tags); - var updateLeadResponse = await _httpClient.PutAsJsonAsync( - $"leads/{request.LeadId}", - CreateWrapper(new { tags = allTags }), - cancellationToken); - updateLeadResponse.EnsureSuccessStatusCode(); + Console.WriteLine(ex); + _logger.LogError(ex, "Error processing freshsales webhook"); + return BadRequest(new { ex.Message }); } - - var createNoteResponse = await _httpClient.PostAsJsonAsync( - "notes", - CreateNoteRequestModel(request.LeadId, string.Join('\n', noteItems)), cancellationToken); - createNoteResponse.EnsureSuccessStatusCode(); - return NoContent(); } - catch (Exception ex) - { - Console.WriteLine(ex); - _logger.LogError(ex, "Error processing freshsales webhook"); - return BadRequest(new { ex.Message }); - } - } - private static LeadWrapper CreateWrapper(T lead) - { - return new LeadWrapper + private static LeadWrapper CreateWrapper(T lead) { - Lead = lead, - }; - } - - private static CreateNoteRequestModel CreateNoteRequestModel(long leadId, string content) - { - return new CreateNoteRequestModel - { - Note = new EditNoteModel + return new LeadWrapper { - Description = content, - TargetableType = "Lead", - TargetableId = leadId, - }, - }; - } + Lead = lead, + }; + } - private static bool TryGetPlanName(PlanType planType, out string planName) - { - switch (planType) + private static CreateNoteRequestModel CreateNoteRequestModel(long leadId, string content) { - case PlanType.Free: - planName = "Free"; - return true; - case PlanType.FamiliesAnnually: - case PlanType.FamiliesAnnually2019: - planName = "Families"; - return true; - case PlanType.TeamsAnnually: - case PlanType.TeamsAnnually2019: - case PlanType.TeamsMonthly: - case PlanType.TeamsMonthly2019: - planName = "Teams"; - return true; - case PlanType.EnterpriseAnnually: - case PlanType.EnterpriseAnnually2019: - case PlanType.EnterpriseMonthly: - case PlanType.EnterpriseMonthly2019: - planName = "Enterprise"; - return true; - case PlanType.Custom: - planName = "Custom"; - return true; - default: - planName = null; - return false; + return new CreateNoteRequestModel + { + Note = new EditNoteModel + { + Description = content, + TargetableType = "Lead", + TargetableId = leadId, + }, + }; + } + + private static bool TryGetPlanName(PlanType planType, out string planName) + { + switch (planType) + { + case PlanType.Free: + planName = "Free"; + return true; + case PlanType.FamiliesAnnually: + case PlanType.FamiliesAnnually2019: + planName = "Families"; + return true; + case PlanType.TeamsAnnually: + case PlanType.TeamsAnnually2019: + case PlanType.TeamsMonthly: + case PlanType.TeamsMonthly2019: + planName = "Teams"; + return true; + case PlanType.EnterpriseAnnually: + case PlanType.EnterpriseAnnually2019: + case PlanType.EnterpriseMonthly: + case PlanType.EnterpriseMonthly2019: + planName = "Enterprise"; + return true; + case PlanType.Custom: + planName = "Custom"; + return true; + default: + planName = null; + return false; + } } } -} -public class CustomWebhookRequestModel -{ - [JsonPropertyName("leadId")] - public long LeadId { get; set; } -} - -public class LeadWrapper -{ - [JsonPropertyName("lead")] - public T Lead { get; set; } - - public static LeadWrapper Create(TItem lead) + public class CustomWebhookRequestModel { - return new LeadWrapper + [JsonPropertyName("leadId")] + public long LeadId { get; set; } + } + + public class LeadWrapper + { + [JsonPropertyName("lead")] + public T Lead { get; set; } + + public static LeadWrapper Create(TItem lead) { - Lead = lead, - }; + return new LeadWrapper + { + Lead = lead, + }; + } + } + + public class FreshsalesLeadModel + { + public string[] Tags { get; set; } + public FreshsalesEmailModel[] Emails { get; set; } + } + + public class FreshsalesEmailModel + { + [JsonPropertyName("value")] + public string Value { get; set; } + + [JsonPropertyName("is_primary")] + public bool IsPrimary { get; set; } + } + + public class CreateNoteRequestModel + { + [JsonPropertyName("note")] + public EditNoteModel Note { get; set; } + } + + public class EditNoteModel + { + [JsonPropertyName("description")] + public string Description { get; set; } + + [JsonPropertyName("targetable_type")] + public string TargetableType { get; set; } + + [JsonPropertyName("targetable_id")] + public long TargetableId { get; set; } } } - -public class FreshsalesLeadModel -{ - public string[] Tags { get; set; } - public FreshsalesEmailModel[] Emails { get; set; } -} - -public class FreshsalesEmailModel -{ - [JsonPropertyName("value")] - public string Value { get; set; } - - [JsonPropertyName("is_primary")] - public bool IsPrimary { get; set; } -} - -public class CreateNoteRequestModel -{ - [JsonPropertyName("note")] - public EditNoteModel Note { get; set; } -} - -public class EditNoteModel -{ - [JsonPropertyName("description")] - public string Description { get; set; } - - [JsonPropertyName("targetable_type")] - public string TargetableType { get; set; } - - [JsonPropertyName("targetable_id")] - public long TargetableId { get; set; } -} diff --git a/src/Billing/Controllers/InfoController.cs b/src/Billing/Controllers/InfoController.cs index 58b29f4c4f..5d7ce57541 100644 --- a/src/Billing/Controllers/InfoController.cs +++ b/src/Billing/Controllers/InfoController.cs @@ -1,20 +1,21 @@ using Bit.Core.Utilities; using Microsoft.AspNetCore.Mvc; -namespace Bit.Billing.Controllers; - -public class InfoController : Controller +namespace Bit.Billing.Controllers { - [HttpGet("~/alive")] - [HttpGet("~/now")] - public DateTime GetAlive() + public class InfoController : Controller { - return DateTime.UtcNow; - } + [HttpGet("~/alive")] + [HttpGet("~/now")] + public DateTime GetAlive() + { + return DateTime.UtcNow; + } - [HttpGet("~/version")] - public JsonResult GetVersion() - { - return Json(CoreHelpers.GetVersion()); + [HttpGet("~/version")] + public JsonResult GetVersion() + { + return Json(CoreHelpers.GetVersion()); + } } } diff --git a/src/Billing/Controllers/LoginController.cs b/src/Billing/Controllers/LoginController.cs index c2df41b92c..448e2b9b20 100644 --- a/src/Billing/Controllers/LoginController.cs +++ b/src/Billing/Controllers/LoginController.cs @@ -1,53 +1,54 @@ using Microsoft.AspNetCore.Mvc; -namespace Billing.Controllers; - -public class LoginController : Controller +namespace Billing.Controllers { - /* - private readonly PasswordlessSignInManager _signInManager; - - public LoginController( - PasswordlessSignInManager signInManager) + public class LoginController : Controller { - _signInManager = signInManager; - } + /* + private readonly PasswordlessSignInManager _signInManager; - public IActionResult Index() - { - return View(); - } - - [HttpPost] - [ValidateAntiForgeryToken] - public async Task Index(LoginModel model) - { - if (ModelState.IsValid) + public LoginController( + PasswordlessSignInManager signInManager) { - var result = await _signInManager.PasswordlessSignInAsync(model.Email, - Url.Action("Confirm", "Login", null, Request.Scheme)); - if (result.Succeeded) - { - return RedirectToAction("Index", "Home"); - } - else - { - ModelState.AddModelError(string.Empty, "Account not found."); - } + _signInManager = signInManager; } - return View(model); - } - - public async Task Confirm(string email, string token) - { - var result = await _signInManager.PasswordlessSignInAsync(email, token, false); - if (!result.Succeeded) + public IActionResult Index() { - return View("Error"); + return View(); } - return RedirectToAction("Index", "Home"); + [HttpPost] + [ValidateAntiForgeryToken] + public async Task Index(LoginModel model) + { + if (ModelState.IsValid) + { + var result = await _signInManager.PasswordlessSignInAsync(model.Email, + Url.Action("Confirm", "Login", null, Request.Scheme)); + if (result.Succeeded) + { + return RedirectToAction("Index", "Home"); + } + else + { + ModelState.AddModelError(string.Empty, "Account not found."); + } + } + + return View(model); + } + + public async Task Confirm(string email, string token) + { + var result = await _signInManager.PasswordlessSignInAsync(email, token, false); + if (!result.Succeeded) + { + return View("Error"); + } + + return RedirectToAction("Index", "Home"); + } + */ } - */ } diff --git a/src/Billing/Controllers/PayPalController.cs b/src/Billing/Controllers/PayPalController.cs index 67826afc68..64811b5aef 100644 --- a/src/Billing/Controllers/PayPalController.cs +++ b/src/Billing/Controllers/PayPalController.cs @@ -9,226 +9,227 @@ using Bit.Core.Utilities; using Microsoft.AspNetCore.Mvc; using Microsoft.Extensions.Options; -namespace Bit.Billing.Controllers; - -[Route("paypal")] -public class PayPalController : Controller +namespace Bit.Billing.Controllers { - private readonly BillingSettings _billingSettings; - private readonly PayPalIpnClient _paypalIpnClient; - private readonly ITransactionRepository _transactionRepository; - private readonly IOrganizationRepository _organizationRepository; - private readonly IUserRepository _userRepository; - private readonly IMailService _mailService; - private readonly IPaymentService _paymentService; - private readonly ILogger _logger; - - public PayPalController( - IOptions billingSettings, - PayPalIpnClient paypalIpnClient, - ITransactionRepository transactionRepository, - IOrganizationRepository organizationRepository, - IUserRepository userRepository, - IMailService mailService, - IPaymentService paymentService, - ILogger logger) + [Route("paypal")] + public class PayPalController : Controller { - _billingSettings = billingSettings?.Value; - _paypalIpnClient = paypalIpnClient; - _transactionRepository = transactionRepository; - _organizationRepository = organizationRepository; - _userRepository = userRepository; - _mailService = mailService; - _paymentService = paymentService; - _logger = logger; - } + private readonly BillingSettings _billingSettings; + private readonly PayPalIpnClient _paypalIpnClient; + private readonly ITransactionRepository _transactionRepository; + private readonly IOrganizationRepository _organizationRepository; + private readonly IUserRepository _userRepository; + private readonly IMailService _mailService; + private readonly IPaymentService _paymentService; + private readonly ILogger _logger; - [HttpPost("ipn")] - public async Task PostIpn() - { - _logger.LogDebug("PayPal webhook has been hit."); - if (HttpContext?.Request?.Query == null) + public PayPalController( + IOptions billingSettings, + PayPalIpnClient paypalIpnClient, + ITransactionRepository transactionRepository, + IOrganizationRepository organizationRepository, + IUserRepository userRepository, + IMailService mailService, + IPaymentService paymentService, + ILogger logger) { - return new BadRequestResult(); + _billingSettings = billingSettings?.Value; + _paypalIpnClient = paypalIpnClient; + _transactionRepository = transactionRepository; + _organizationRepository = organizationRepository; + _userRepository = userRepository; + _mailService = mailService; + _paymentService = paymentService; + _logger = logger; } - var key = HttpContext.Request.Query.ContainsKey("key") ? - HttpContext.Request.Query["key"].ToString() : null; - if (!CoreHelpers.FixedTimeEquals(key, _billingSettings.PayPal.WebhookKey)) + [HttpPost("ipn")] + public async Task PostIpn() { - _logger.LogWarning("PayPal webhook key is incorrect or does not exist."); - return new BadRequestResult(); - } - - string body = null; - using (var reader = new StreamReader(HttpContext.Request.Body, Encoding.UTF8)) - { - body = await reader.ReadToEndAsync(); - } - - if (string.IsNullOrWhiteSpace(body)) - { - return new BadRequestResult(); - } - - var verified = await _paypalIpnClient.VerifyIpnAsync(body); - if (!verified) - { - _logger.LogWarning("Unverified IPN received."); - return new BadRequestResult(); - } - - var ipnTransaction = new PayPalIpnClient.IpnTransaction(body); - if (ipnTransaction.TxnType != "web_accept" && ipnTransaction.TxnType != "merch_pmt" && - ipnTransaction.PaymentStatus != "Refunded") - { - // Only processing billing agreement payments, buy now button payments, and refunds for now. - return new OkResult(); - } - - if (ipnTransaction.ReceiverId != _billingSettings.PayPal.BusinessId) - { - _logger.LogWarning("Receiver was not proper business id. " + ipnTransaction.ReceiverId); - return new BadRequestResult(); - } - - if (ipnTransaction.PaymentStatus == "Refunded" && ipnTransaction.ParentTxnId == null) - { - // Refunds require parent transaction - return new OkResult(); - } - - if (ipnTransaction.PaymentType == "echeck" && ipnTransaction.PaymentStatus != "Refunded") - { - // Not accepting eChecks, unless it is a refund - _logger.LogWarning("Got an eCheck payment. " + ipnTransaction.TxnId); - return new OkResult(); - } - - if (ipnTransaction.McCurrency != "USD") - { - // Only process USD payments - _logger.LogWarning("Received a payment not in USD. " + ipnTransaction.TxnId); - return new OkResult(); - } - - var ids = ipnTransaction.GetIdsFromCustom(); - if (!ids.Item1.HasValue && !ids.Item2.HasValue) - { - return new OkResult(); - } - - if (ipnTransaction.PaymentStatus == "Completed") - { - var transaction = await _transactionRepository.GetByGatewayIdAsync( - GatewayType.PayPal, ipnTransaction.TxnId); - if (transaction != null) + _logger.LogDebug("PayPal webhook has been hit."); + if (HttpContext?.Request?.Query == null) { - _logger.LogWarning("Already processed this completed transaction. #" + ipnTransaction.TxnId); - return new OkResult(); - } - - var isAccountCredit = ipnTransaction.IsAccountCredit(); - try - { - var tx = new Transaction - { - Amount = ipnTransaction.McGross, - CreationDate = ipnTransaction.PaymentDate, - OrganizationId = ids.Item1, - UserId = ids.Item2, - Type = isAccountCredit ? TransactionType.Credit : TransactionType.Charge, - Gateway = GatewayType.PayPal, - GatewayId = ipnTransaction.TxnId, - PaymentMethodType = PaymentMethodType.PayPal, - Details = ipnTransaction.TxnId - }; - await _transactionRepository.CreateAsync(tx); - - if (isAccountCredit) - { - string billingEmail = null; - if (tx.OrganizationId.HasValue) - { - var org = await _organizationRepository.GetByIdAsync(tx.OrganizationId.Value); - if (org != null) - { - billingEmail = org.BillingEmailAddress(); - if (await _paymentService.CreditAccountAsync(org, tx.Amount)) - { - await _organizationRepository.ReplaceAsync(org); - } - } - } - else - { - var user = await _userRepository.GetByIdAsync(tx.UserId.Value); - if (user != null) - { - billingEmail = user.BillingEmailAddress(); - if (await _paymentService.CreditAccountAsync(user, tx.Amount)) - { - await _userRepository.ReplaceAsync(user); - } - } - } - - if (!string.IsNullOrWhiteSpace(billingEmail)) - { - await _mailService.SendAddedCreditAsync(billingEmail, tx.Amount); - } - } - } - // Catch foreign key violations because user/org could have been deleted. - catch (SqlException e) when (e.Number == 547) { } - } - else if (ipnTransaction.PaymentStatus == "Refunded" || ipnTransaction.PaymentStatus == "Reversed") - { - var refundTransaction = await _transactionRepository.GetByGatewayIdAsync( - GatewayType.PayPal, ipnTransaction.TxnId); - if (refundTransaction != null) - { - _logger.LogWarning("Already processed this refunded transaction. #" + ipnTransaction.TxnId); - return new OkResult(); - } - - var parentTransaction = await _transactionRepository.GetByGatewayIdAsync( - GatewayType.PayPal, ipnTransaction.ParentTxnId); - if (parentTransaction == null) - { - _logger.LogWarning("Parent transaction was not found. " + ipnTransaction.TxnId); return new BadRequestResult(); } - var refundAmount = System.Math.Abs(ipnTransaction.McGross); - var remainingAmount = parentTransaction.Amount - - parentTransaction.RefundedAmount.GetValueOrDefault(); - if (refundAmount > 0 && !parentTransaction.Refunded.GetValueOrDefault() && - remainingAmount >= refundAmount) + var key = HttpContext.Request.Query.ContainsKey("key") ? + HttpContext.Request.Query["key"].ToString() : null; + if (!CoreHelpers.FixedTimeEquals(key, _billingSettings.PayPal.WebhookKey)) { - parentTransaction.RefundedAmount = - parentTransaction.RefundedAmount.GetValueOrDefault() + refundAmount; - if (parentTransaction.RefundedAmount == parentTransaction.Amount) + _logger.LogWarning("PayPal webhook key is incorrect or does not exist."); + return new BadRequestResult(); + } + + string body = null; + using (var reader = new StreamReader(HttpContext.Request.Body, Encoding.UTF8)) + { + body = await reader.ReadToEndAsync(); + } + + if (string.IsNullOrWhiteSpace(body)) + { + return new BadRequestResult(); + } + + var verified = await _paypalIpnClient.VerifyIpnAsync(body); + if (!verified) + { + _logger.LogWarning("Unverified IPN received."); + return new BadRequestResult(); + } + + var ipnTransaction = new PayPalIpnClient.IpnTransaction(body); + if (ipnTransaction.TxnType != "web_accept" && ipnTransaction.TxnType != "merch_pmt" && + ipnTransaction.PaymentStatus != "Refunded") + { + // Only processing billing agreement payments, buy now button payments, and refunds for now. + return new OkResult(); + } + + if (ipnTransaction.ReceiverId != _billingSettings.PayPal.BusinessId) + { + _logger.LogWarning("Receiver was not proper business id. " + ipnTransaction.ReceiverId); + return new BadRequestResult(); + } + + if (ipnTransaction.PaymentStatus == "Refunded" && ipnTransaction.ParentTxnId == null) + { + // Refunds require parent transaction + return new OkResult(); + } + + if (ipnTransaction.PaymentType == "echeck" && ipnTransaction.PaymentStatus != "Refunded") + { + // Not accepting eChecks, unless it is a refund + _logger.LogWarning("Got an eCheck payment. " + ipnTransaction.TxnId); + return new OkResult(); + } + + if (ipnTransaction.McCurrency != "USD") + { + // Only process USD payments + _logger.LogWarning("Received a payment not in USD. " + ipnTransaction.TxnId); + return new OkResult(); + } + + var ids = ipnTransaction.GetIdsFromCustom(); + if (!ids.Item1.HasValue && !ids.Item2.HasValue) + { + return new OkResult(); + } + + if (ipnTransaction.PaymentStatus == "Completed") + { + var transaction = await _transactionRepository.GetByGatewayIdAsync( + GatewayType.PayPal, ipnTransaction.TxnId); + if (transaction != null) { - parentTransaction.Refunded = true; + _logger.LogWarning("Already processed this completed transaction. #" + ipnTransaction.TxnId); + return new OkResult(); } - await _transactionRepository.ReplaceAsync(parentTransaction); - await _transactionRepository.CreateAsync(new Transaction + var isAccountCredit = ipnTransaction.IsAccountCredit(); + try { - Amount = refundAmount, - CreationDate = ipnTransaction.PaymentDate, - OrganizationId = ids.Item1, - UserId = ids.Item2, - Type = TransactionType.Refund, - Gateway = GatewayType.PayPal, - GatewayId = ipnTransaction.TxnId, - PaymentMethodType = PaymentMethodType.PayPal, - Details = ipnTransaction.TxnId - }); - } - } + var tx = new Transaction + { + Amount = ipnTransaction.McGross, + CreationDate = ipnTransaction.PaymentDate, + OrganizationId = ids.Item1, + UserId = ids.Item2, + Type = isAccountCredit ? TransactionType.Credit : TransactionType.Charge, + Gateway = GatewayType.PayPal, + GatewayId = ipnTransaction.TxnId, + PaymentMethodType = PaymentMethodType.PayPal, + Details = ipnTransaction.TxnId + }; + await _transactionRepository.CreateAsync(tx); - return new OkResult(); + if (isAccountCredit) + { + string billingEmail = null; + if (tx.OrganizationId.HasValue) + { + var org = await _organizationRepository.GetByIdAsync(tx.OrganizationId.Value); + if (org != null) + { + billingEmail = org.BillingEmailAddress(); + if (await _paymentService.CreditAccountAsync(org, tx.Amount)) + { + await _organizationRepository.ReplaceAsync(org); + } + } + } + else + { + var user = await _userRepository.GetByIdAsync(tx.UserId.Value); + if (user != null) + { + billingEmail = user.BillingEmailAddress(); + if (await _paymentService.CreditAccountAsync(user, tx.Amount)) + { + await _userRepository.ReplaceAsync(user); + } + } + } + + if (!string.IsNullOrWhiteSpace(billingEmail)) + { + await _mailService.SendAddedCreditAsync(billingEmail, tx.Amount); + } + } + } + // Catch foreign key violations because user/org could have been deleted. + catch (SqlException e) when (e.Number == 547) { } + } + else if (ipnTransaction.PaymentStatus == "Refunded" || ipnTransaction.PaymentStatus == "Reversed") + { + var refundTransaction = await _transactionRepository.GetByGatewayIdAsync( + GatewayType.PayPal, ipnTransaction.TxnId); + if (refundTransaction != null) + { + _logger.LogWarning("Already processed this refunded transaction. #" + ipnTransaction.TxnId); + return new OkResult(); + } + + var parentTransaction = await _transactionRepository.GetByGatewayIdAsync( + GatewayType.PayPal, ipnTransaction.ParentTxnId); + if (parentTransaction == null) + { + _logger.LogWarning("Parent transaction was not found. " + ipnTransaction.TxnId); + return new BadRequestResult(); + } + + var refundAmount = System.Math.Abs(ipnTransaction.McGross); + var remainingAmount = parentTransaction.Amount - + parentTransaction.RefundedAmount.GetValueOrDefault(); + if (refundAmount > 0 && !parentTransaction.Refunded.GetValueOrDefault() && + remainingAmount >= refundAmount) + { + parentTransaction.RefundedAmount = + parentTransaction.RefundedAmount.GetValueOrDefault() + refundAmount; + if (parentTransaction.RefundedAmount == parentTransaction.Amount) + { + parentTransaction.Refunded = true; + } + + await _transactionRepository.ReplaceAsync(parentTransaction); + await _transactionRepository.CreateAsync(new Transaction + { + Amount = refundAmount, + CreationDate = ipnTransaction.PaymentDate, + OrganizationId = ids.Item1, + UserId = ids.Item2, + Type = TransactionType.Refund, + Gateway = GatewayType.PayPal, + GatewayId = ipnTransaction.TxnId, + PaymentMethodType = PaymentMethodType.PayPal, + Details = ipnTransaction.TxnId + }); + } + } + + return new OkResult(); + } } } diff --git a/src/Billing/Controllers/StripeController.cs b/src/Billing/Controllers/StripeController.cs index d9f3bc744f..4cabb96456 100644 --- a/src/Billing/Controllers/StripeController.cs +++ b/src/Billing/Controllers/StripeController.cs @@ -13,825 +13,826 @@ using Microsoft.Extensions.Options; using Stripe; using TaxRate = Bit.Core.Entities.TaxRate; -namespace Bit.Billing.Controllers; - -[Route("stripe")] -public class StripeController : Controller +namespace Bit.Billing.Controllers { - private const decimal PremiumPlanAppleIapPrice = 14.99M; - private const string PremiumPlanId = "premium-annually"; - private const string PremiumPlanIdAppStore = "premium-annually-app"; - - private readonly BillingSettings _billingSettings; - private readonly IWebHostEnvironment _hostingEnvironment; - private readonly IOrganizationService _organizationService; - private readonly IValidateSponsorshipCommand _validateSponsorshipCommand; - private readonly IOrganizationSponsorshipRenewCommand _organizationSponsorshipRenewCommand; - private readonly IOrganizationRepository _organizationRepository; - private readonly ITransactionRepository _transactionRepository; - private readonly IUserService _userService; - private readonly IAppleIapService _appleIapService; - private readonly IMailService _mailService; - private readonly ILogger _logger; - private readonly Braintree.BraintreeGateway _btGateway; - private readonly IReferenceEventService _referenceEventService; - private readonly ITaxRateRepository _taxRateRepository; - private readonly IUserRepository _userRepository; - - public StripeController( - GlobalSettings globalSettings, - IOptions billingSettings, - IWebHostEnvironment hostingEnvironment, - IOrganizationService organizationService, - IValidateSponsorshipCommand validateSponsorshipCommand, - IOrganizationSponsorshipRenewCommand organizationSponsorshipRenewCommand, - IOrganizationRepository organizationRepository, - ITransactionRepository transactionRepository, - IUserService userService, - IAppleIapService appleIapService, - IMailService mailService, - IReferenceEventService referenceEventService, - ILogger logger, - ITaxRateRepository taxRateRepository, - IUserRepository userRepository) + [Route("stripe")] + public class StripeController : Controller { - _billingSettings = billingSettings?.Value; - _hostingEnvironment = hostingEnvironment; - _organizationService = organizationService; - _validateSponsorshipCommand = validateSponsorshipCommand; - _organizationSponsorshipRenewCommand = organizationSponsorshipRenewCommand; - _organizationRepository = organizationRepository; - _transactionRepository = transactionRepository; - _userService = userService; - _appleIapService = appleIapService; - _mailService = mailService; - _referenceEventService = referenceEventService; - _taxRateRepository = taxRateRepository; - _userRepository = userRepository; - _logger = logger; - _btGateway = new Braintree.BraintreeGateway - { - Environment = globalSettings.Braintree.Production ? - Braintree.Environment.PRODUCTION : Braintree.Environment.SANDBOX, - MerchantId = globalSettings.Braintree.MerchantId, - PublicKey = globalSettings.Braintree.PublicKey, - PrivateKey = globalSettings.Braintree.PrivateKey - }; - } + private const decimal PremiumPlanAppleIapPrice = 14.99M; + private const string PremiumPlanId = "premium-annually"; + private const string PremiumPlanIdAppStore = "premium-annually-app"; - [HttpPost("webhook")] - public async Task PostWebhook([FromQuery] string key) - { - if (!CoreHelpers.FixedTimeEquals(key, _billingSettings.StripeWebhookKey)) + private readonly BillingSettings _billingSettings; + private readonly IWebHostEnvironment _hostingEnvironment; + private readonly IOrganizationService _organizationService; + private readonly IValidateSponsorshipCommand _validateSponsorshipCommand; + private readonly IOrganizationSponsorshipRenewCommand _organizationSponsorshipRenewCommand; + private readonly IOrganizationRepository _organizationRepository; + private readonly ITransactionRepository _transactionRepository; + private readonly IUserService _userService; + private readonly IAppleIapService _appleIapService; + private readonly IMailService _mailService; + private readonly ILogger _logger; + private readonly Braintree.BraintreeGateway _btGateway; + private readonly IReferenceEventService _referenceEventService; + private readonly ITaxRateRepository _taxRateRepository; + private readonly IUserRepository _userRepository; + + public StripeController( + GlobalSettings globalSettings, + IOptions billingSettings, + IWebHostEnvironment hostingEnvironment, + IOrganizationService organizationService, + IValidateSponsorshipCommand validateSponsorshipCommand, + IOrganizationSponsorshipRenewCommand organizationSponsorshipRenewCommand, + IOrganizationRepository organizationRepository, + ITransactionRepository transactionRepository, + IUserService userService, + IAppleIapService appleIapService, + IMailService mailService, + IReferenceEventService referenceEventService, + ILogger logger, + ITaxRateRepository taxRateRepository, + IUserRepository userRepository) { - return new BadRequestResult(); + _billingSettings = billingSettings?.Value; + _hostingEnvironment = hostingEnvironment; + _organizationService = organizationService; + _validateSponsorshipCommand = validateSponsorshipCommand; + _organizationSponsorshipRenewCommand = organizationSponsorshipRenewCommand; + _organizationRepository = organizationRepository; + _transactionRepository = transactionRepository; + _userService = userService; + _appleIapService = appleIapService; + _mailService = mailService; + _referenceEventService = referenceEventService; + _taxRateRepository = taxRateRepository; + _userRepository = userRepository; + _logger = logger; + _btGateway = new Braintree.BraintreeGateway + { + Environment = globalSettings.Braintree.Production ? + Braintree.Environment.PRODUCTION : Braintree.Environment.SANDBOX, + MerchantId = globalSettings.Braintree.MerchantId, + PublicKey = globalSettings.Braintree.PublicKey, + PrivateKey = globalSettings.Braintree.PrivateKey + }; } - Stripe.Event parsedEvent; - using (var sr = new StreamReader(HttpContext.Request.Body)) + [HttpPost("webhook")] + public async Task PostWebhook([FromQuery] string key) { - var json = await sr.ReadToEndAsync(); - parsedEvent = EventUtility.ConstructEvent(json, Request.Headers["Stripe-Signature"], - _billingSettings.StripeWebhookSecret, - throwOnApiVersionMismatch: _billingSettings.StripeEventParseThrowMismatch); - } - - if (string.IsNullOrWhiteSpace(parsedEvent?.Id)) - { - _logger.LogWarning("No event id."); - return new BadRequestResult(); - } - - if (_hostingEnvironment.IsProduction() && !parsedEvent.Livemode) - { - _logger.LogWarning("Getting test events in production."); - return new BadRequestResult(); - } - - var subDeleted = parsedEvent.Type.Equals(HandledStripeWebhook.SubscriptionDeleted); - var subUpdated = parsedEvent.Type.Equals(HandledStripeWebhook.SubscriptionUpdated); - - if (subDeleted || subUpdated) - { - var subscription = await GetSubscriptionAsync(parsedEvent, true); - var ids = GetIdsFromMetaData(subscription.Metadata); - - var subCanceled = subDeleted && subscription.Status == "canceled"; - var subUnpaid = subUpdated && subscription.Status == "unpaid"; - var subIncompleteExpired = subUpdated && subscription.Status == "incomplete_expired"; - - if (subCanceled || subUnpaid || subIncompleteExpired) + if (!CoreHelpers.FixedTimeEquals(key, _billingSettings.StripeWebhookKey)) { - // org - if (ids.Item1.HasValue) - { - await _organizationService.DisableAsync(ids.Item1.Value, subscription.CurrentPeriodEnd); - } - // user - else if (ids.Item2.HasValue) - { - await _userService.DisablePremiumAsync(ids.Item2.Value, subscription.CurrentPeriodEnd); - } - } - - if (subUpdated) - { - // org - if (ids.Item1.HasValue) - { - await _organizationService.UpdateExpirationDateAsync(ids.Item1.Value, - subscription.CurrentPeriodEnd); - if (IsSponsoredSubscription(subscription)) - { - await _organizationSponsorshipRenewCommand.UpdateExpirationDateAsync(ids.Item1.Value, subscription.CurrentPeriodEnd); - } - } - // user - else if (ids.Item2.HasValue) - { - await _userService.UpdatePremiumExpirationAsync(ids.Item2.Value, - subscription.CurrentPeriodEnd); - } - } - } - else if (parsedEvent.Type.Equals(HandledStripeWebhook.UpcomingInvoice)) - { - var invoice = await GetInvoiceAsync(parsedEvent); - var subscriptionService = new SubscriptionService(); - var subscription = await subscriptionService.GetAsync(invoice.SubscriptionId); - if (subscription == null) - { - throw new Exception("Invoice subscription is null. " + invoice.Id); - } - - subscription = await VerifyCorrectTaxRateForCharge(invoice, subscription); - - string email = null; - var ids = GetIdsFromMetaData(subscription.Metadata); - // org - if (ids.Item1.HasValue) - { - // sponsored org - if (IsSponsoredSubscription(subscription)) - { - await _validateSponsorshipCommand.ValidateSponsorshipAsync(ids.Item1.Value); - } - - var org = await _organizationRepository.GetByIdAsync(ids.Item1.Value); - if (org != null && OrgPlanForInvoiceNotifications(org)) - { - email = org.BillingEmail; - } - } - // user - else if (ids.Item2.HasValue) - { - var user = await _userService.GetUserByIdAsync(ids.Item2.Value); - if (user.Premium) - { - email = user.Email; - } - } - - if (!string.IsNullOrWhiteSpace(email) && invoice.NextPaymentAttempt.HasValue) - { - var items = invoice.Lines.Select(i => i.Description).ToList(); - await _mailService.SendInvoiceUpcomingAsync(email, invoice.AmountDue / 100M, - invoice.NextPaymentAttempt.Value, items, true); - } - } - else if (parsedEvent.Type.Equals(HandledStripeWebhook.ChargeSucceeded)) - { - var charge = await GetChargeAsync(parsedEvent); - var chargeTransaction = await _transactionRepository.GetByGatewayIdAsync( - GatewayType.Stripe, charge.Id); - if (chargeTransaction != null) - { - _logger.LogWarning("Charge success already processed. " + charge.Id); - return new OkResult(); - } - - Tuple ids = null; - Subscription subscription = null; - var subscriptionService = new SubscriptionService(); - - if (charge.InvoiceId != null) - { - var invoiceService = new InvoiceService(); - var invoice = await invoiceService.GetAsync(charge.InvoiceId); - if (invoice?.SubscriptionId != null) - { - subscription = await subscriptionService.GetAsync(invoice.SubscriptionId); - ids = GetIdsFromMetaData(subscription?.Metadata); - } - } - - if (subscription == null || ids == null || (ids.Item1.HasValue && ids.Item2.HasValue)) - { - var subscriptions = await subscriptionService.ListAsync(new SubscriptionListOptions - { - Customer = charge.CustomerId - }); - foreach (var sub in subscriptions) - { - if (sub.Status != "canceled" && sub.Status != "incomplete_expired") - { - ids = GetIdsFromMetaData(sub.Metadata); - if (ids.Item1.HasValue || ids.Item2.HasValue) - { - subscription = sub; - break; - } - } - } - } - - if (!ids.Item1.HasValue && !ids.Item2.HasValue) - { - _logger.LogWarning("Charge success has no subscriber ids. " + charge.Id); return new BadRequestResult(); } - var tx = new Transaction + Stripe.Event parsedEvent; + using (var sr = new StreamReader(HttpContext.Request.Body)) { - Amount = charge.Amount / 100M, - CreationDate = charge.Created, - OrganizationId = ids.Item1, - UserId = ids.Item2, - Type = TransactionType.Charge, - Gateway = GatewayType.Stripe, - GatewayId = charge.Id - }; - - if (charge.Source != null && charge.Source is Card card) - { - tx.PaymentMethodType = PaymentMethodType.Card; - tx.Details = $"{card.Brand}, *{card.Last4}"; - } - else if (charge.Source != null && charge.Source is BankAccount bankAccount) - { - tx.PaymentMethodType = PaymentMethodType.BankAccount; - tx.Details = $"{bankAccount.BankName}, *{bankAccount.Last4}"; - } - else if (charge.Source != null && charge.Source is Source source) - { - if (source.Card != null) - { - tx.PaymentMethodType = PaymentMethodType.Card; - tx.Details = $"{source.Card.Brand}, *{source.Card.Last4}"; - } - else if (source.AchDebit != null) - { - tx.PaymentMethodType = PaymentMethodType.BankAccount; - tx.Details = $"{source.AchDebit.BankName}, *{source.AchDebit.Last4}"; - } - else if (source.AchCreditTransfer != null) - { - tx.PaymentMethodType = PaymentMethodType.BankAccount; - tx.Details = $"ACH => {source.AchCreditTransfer.BankName}, " + - $"{source.AchCreditTransfer.AccountNumber}"; - } - } - else if (charge.PaymentMethodDetails != null) - { - if (charge.PaymentMethodDetails.Card != null) - { - tx.PaymentMethodType = PaymentMethodType.Card; - tx.Details = $"{charge.PaymentMethodDetails.Card.Brand?.ToUpperInvariant()}, " + - $"*{charge.PaymentMethodDetails.Card.Last4}"; - } - else if (charge.PaymentMethodDetails.AchDebit != null) - { - tx.PaymentMethodType = PaymentMethodType.BankAccount; - tx.Details = $"{charge.PaymentMethodDetails.AchDebit.BankName}, " + - $"*{charge.PaymentMethodDetails.AchDebit.Last4}"; - } - else if (charge.PaymentMethodDetails.AchCreditTransfer != null) - { - tx.PaymentMethodType = PaymentMethodType.BankAccount; - tx.Details = $"ACH => {charge.PaymentMethodDetails.AchCreditTransfer.BankName}, " + - $"{charge.PaymentMethodDetails.AchCreditTransfer.AccountNumber}"; - } + var json = await sr.ReadToEndAsync(); + parsedEvent = EventUtility.ConstructEvent(json, Request.Headers["Stripe-Signature"], + _billingSettings.StripeWebhookSecret, + throwOnApiVersionMismatch: _billingSettings.StripeEventParseThrowMismatch); } - if (!tx.PaymentMethodType.HasValue) + if (string.IsNullOrWhiteSpace(parsedEvent?.Id)) { - _logger.LogWarning("Charge success from unsupported source/method. " + charge.Id); - return new OkResult(); + _logger.LogWarning("No event id."); + return new BadRequestResult(); } - try + if (_hostingEnvironment.IsProduction() && !parsedEvent.Livemode) { - await _transactionRepository.CreateAsync(tx); - } - // Catch foreign key violations because user/org could have been deleted. - catch (SqlException e) when (e.Number == 547) { } - } - else if (parsedEvent.Type.Equals(HandledStripeWebhook.ChargeRefunded)) - { - var charge = await GetChargeAsync(parsedEvent); - var chargeTransaction = await _transactionRepository.GetByGatewayIdAsync( - GatewayType.Stripe, charge.Id); - if (chargeTransaction == null) - { - throw new Exception("Cannot find refunded charge. " + charge.Id); + _logger.LogWarning("Getting test events in production."); + return new BadRequestResult(); } - var amountRefunded = charge.AmountRefunded / 100M; + var subDeleted = parsedEvent.Type.Equals(HandledStripeWebhook.SubscriptionDeleted); + var subUpdated = parsedEvent.Type.Equals(HandledStripeWebhook.SubscriptionUpdated); - if (!chargeTransaction.Refunded.GetValueOrDefault() && - chargeTransaction.RefundedAmount.GetValueOrDefault() < amountRefunded) + if (subDeleted || subUpdated) { - chargeTransaction.RefundedAmount = amountRefunded; - if (charge.Refunded) + var subscription = await GetSubscriptionAsync(parsedEvent, true); + var ids = GetIdsFromMetaData(subscription.Metadata); + + var subCanceled = subDeleted && subscription.Status == "canceled"; + var subUnpaid = subUpdated && subscription.Status == "unpaid"; + var subIncompleteExpired = subUpdated && subscription.Status == "incomplete_expired"; + + if (subCanceled || subUnpaid || subIncompleteExpired) { - chargeTransaction.Refunded = true; - } - await _transactionRepository.ReplaceAsync(chargeTransaction); - - foreach (var refund in charge.Refunds) - { - var refundTransaction = await _transactionRepository.GetByGatewayIdAsync( - GatewayType.Stripe, refund.Id); - if (refundTransaction != null) - { - continue; - } - - await _transactionRepository.CreateAsync(new Transaction - { - Amount = refund.Amount / 100M, - CreationDate = refund.Created, - OrganizationId = chargeTransaction.OrganizationId, - UserId = chargeTransaction.UserId, - Type = TransactionType.Refund, - Gateway = GatewayType.Stripe, - GatewayId = refund.Id, - PaymentMethodType = chargeTransaction.PaymentMethodType, - Details = chargeTransaction.Details - }); - } - } - else - { - _logger.LogWarning("Charge refund amount doesn't seem correct. " + charge.Id); - } - } - else if (parsedEvent.Type.Equals(HandledStripeWebhook.PaymentSucceeded)) - { - var invoice = await GetInvoiceAsync(parsedEvent, true); - if (invoice.Paid && invoice.BillingReason == "subscription_create") - { - var subscriptionService = new SubscriptionService(); - var subscription = await subscriptionService.GetAsync(invoice.SubscriptionId); - if (subscription?.Status == "active") - { - if (DateTime.UtcNow - invoice.Created < TimeSpan.FromMinutes(1)) - { - await Task.Delay(5000); - } - - var ids = GetIdsFromMetaData(subscription.Metadata); // org if (ids.Item1.HasValue) { - if (subscription.Items.Any(i => StaticStore.Plans.Any(p => p.StripePlanId == i.Plan.Id))) - { - await _organizationService.EnableAsync(ids.Item1.Value, subscription.CurrentPeriodEnd); + await _organizationService.DisableAsync(ids.Item1.Value, subscription.CurrentPeriodEnd); + } + // user + else if (ids.Item2.HasValue) + { + await _userService.DisablePremiumAsync(ids.Item2.Value, subscription.CurrentPeriodEnd); + } + } - var organization = await _organizationRepository.GetByIdAsync(ids.Item1.Value); - await _referenceEventService.RaiseEventAsync( - new ReferenceEvent(ReferenceEventType.Rebilled, organization) - { - PlanName = organization?.Plan, - PlanType = organization?.PlanType, - Seats = organization?.Seats, - Storage = organization?.MaxStorageGb, - }); + if (subUpdated) + { + // org + if (ids.Item1.HasValue) + { + await _organizationService.UpdateExpirationDateAsync(ids.Item1.Value, + subscription.CurrentPeriodEnd); + if (IsSponsoredSubscription(subscription)) + { + await _organizationSponsorshipRenewCommand.UpdateExpirationDateAsync(ids.Item1.Value, subscription.CurrentPeriodEnd); } } // user else if (ids.Item2.HasValue) { - if (subscription.Items.Any(i => i.Plan.Id == PremiumPlanId)) - { - await _userService.EnablePremiumAsync(ids.Item2.Value, subscription.CurrentPeriodEnd); + await _userService.UpdatePremiumExpirationAsync(ids.Item2.Value, + subscription.CurrentPeriodEnd); + } + } + } + else if (parsedEvent.Type.Equals(HandledStripeWebhook.UpcomingInvoice)) + { + var invoice = await GetInvoiceAsync(parsedEvent); + var subscriptionService = new SubscriptionService(); + var subscription = await subscriptionService.GetAsync(invoice.SubscriptionId); + if (subscription == null) + { + throw new Exception("Invoice subscription is null. " + invoice.Id); + } - var user = await _userRepository.GetByIdAsync(ids.Item2.Value); - await _referenceEventService.RaiseEventAsync( - new ReferenceEvent(ReferenceEventType.Rebilled, user) - { - PlanName = PremiumPlanId, - Storage = user?.MaxStorageGb, - }); + subscription = await VerifyCorrectTaxRateForCharge(invoice, subscription); + + string email = null; + var ids = GetIdsFromMetaData(subscription.Metadata); + // org + if (ids.Item1.HasValue) + { + // sponsored org + if (IsSponsoredSubscription(subscription)) + { + await _validateSponsorshipCommand.ValidateSponsorshipAsync(ids.Item1.Value); + } + + var org = await _organizationRepository.GetByIdAsync(ids.Item1.Value); + if (org != null && OrgPlanForInvoiceNotifications(org)) + { + email = org.BillingEmail; + } + } + // user + else if (ids.Item2.HasValue) + { + var user = await _userService.GetUserByIdAsync(ids.Item2.Value); + if (user.Premium) + { + email = user.Email; + } + } + + if (!string.IsNullOrWhiteSpace(email) && invoice.NextPaymentAttempt.HasValue) + { + var items = invoice.Lines.Select(i => i.Description).ToList(); + await _mailService.SendInvoiceUpcomingAsync(email, invoice.AmountDue / 100M, + invoice.NextPaymentAttempt.Value, items, true); + } + } + else if (parsedEvent.Type.Equals(HandledStripeWebhook.ChargeSucceeded)) + { + var charge = await GetChargeAsync(parsedEvent); + var chargeTransaction = await _transactionRepository.GetByGatewayIdAsync( + GatewayType.Stripe, charge.Id); + if (chargeTransaction != null) + { + _logger.LogWarning("Charge success already processed. " + charge.Id); + return new OkResult(); + } + + Tuple ids = null; + Subscription subscription = null; + var subscriptionService = new SubscriptionService(); + + if (charge.InvoiceId != null) + { + var invoiceService = new InvoiceService(); + var invoice = await invoiceService.GetAsync(charge.InvoiceId); + if (invoice?.SubscriptionId != null) + { + subscription = await subscriptionService.GetAsync(invoice.SubscriptionId); + ids = GetIdsFromMetaData(subscription?.Metadata); + } + } + + if (subscription == null || ids == null || (ids.Item1.HasValue && ids.Item2.HasValue)) + { + var subscriptions = await subscriptionService.ListAsync(new SubscriptionListOptions + { + Customer = charge.CustomerId + }); + foreach (var sub in subscriptions) + { + if (sub.Status != "canceled" && sub.Status != "incomplete_expired") + { + ids = GetIdsFromMetaData(sub.Metadata); + if (ids.Item1.HasValue || ids.Item2.HasValue) + { + subscription = sub; + break; + } + } + } + } + + if (!ids.Item1.HasValue && !ids.Item2.HasValue) + { + _logger.LogWarning("Charge success has no subscriber ids. " + charge.Id); + return new BadRequestResult(); + } + + var tx = new Transaction + { + Amount = charge.Amount / 100M, + CreationDate = charge.Created, + OrganizationId = ids.Item1, + UserId = ids.Item2, + Type = TransactionType.Charge, + Gateway = GatewayType.Stripe, + GatewayId = charge.Id + }; + + if (charge.Source != null && charge.Source is Card card) + { + tx.PaymentMethodType = PaymentMethodType.Card; + tx.Details = $"{card.Brand}, *{card.Last4}"; + } + else if (charge.Source != null && charge.Source is BankAccount bankAccount) + { + tx.PaymentMethodType = PaymentMethodType.BankAccount; + tx.Details = $"{bankAccount.BankName}, *{bankAccount.Last4}"; + } + else if (charge.Source != null && charge.Source is Source source) + { + if (source.Card != null) + { + tx.PaymentMethodType = PaymentMethodType.Card; + tx.Details = $"{source.Card.Brand}, *{source.Card.Last4}"; + } + else if (source.AchDebit != null) + { + tx.PaymentMethodType = PaymentMethodType.BankAccount; + tx.Details = $"{source.AchDebit.BankName}, *{source.AchDebit.Last4}"; + } + else if (source.AchCreditTransfer != null) + { + tx.PaymentMethodType = PaymentMethodType.BankAccount; + tx.Details = $"ACH => {source.AchCreditTransfer.BankName}, " + + $"{source.AchCreditTransfer.AccountNumber}"; + } + } + else if (charge.PaymentMethodDetails != null) + { + if (charge.PaymentMethodDetails.Card != null) + { + tx.PaymentMethodType = PaymentMethodType.Card; + tx.Details = $"{charge.PaymentMethodDetails.Card.Brand?.ToUpperInvariant()}, " + + $"*{charge.PaymentMethodDetails.Card.Last4}"; + } + else if (charge.PaymentMethodDetails.AchDebit != null) + { + tx.PaymentMethodType = PaymentMethodType.BankAccount; + tx.Details = $"{charge.PaymentMethodDetails.AchDebit.BankName}, " + + $"*{charge.PaymentMethodDetails.AchDebit.Last4}"; + } + else if (charge.PaymentMethodDetails.AchCreditTransfer != null) + { + tx.PaymentMethodType = PaymentMethodType.BankAccount; + tx.Details = $"ACH => {charge.PaymentMethodDetails.AchCreditTransfer.BankName}, " + + $"{charge.PaymentMethodDetails.AchCreditTransfer.AccountNumber}"; + } + } + + if (!tx.PaymentMethodType.HasValue) + { + _logger.LogWarning("Charge success from unsupported source/method. " + charge.Id); + return new OkResult(); + } + + try + { + await _transactionRepository.CreateAsync(tx); + } + // Catch foreign key violations because user/org could have been deleted. + catch (SqlException e) when (e.Number == 547) { } + } + else if (parsedEvent.Type.Equals(HandledStripeWebhook.ChargeRefunded)) + { + var charge = await GetChargeAsync(parsedEvent); + var chargeTransaction = await _transactionRepository.GetByGatewayIdAsync( + GatewayType.Stripe, charge.Id); + if (chargeTransaction == null) + { + throw new Exception("Cannot find refunded charge. " + charge.Id); + } + + var amountRefunded = charge.AmountRefunded / 100M; + + if (!chargeTransaction.Refunded.GetValueOrDefault() && + chargeTransaction.RefundedAmount.GetValueOrDefault() < amountRefunded) + { + chargeTransaction.RefundedAmount = amountRefunded; + if (charge.Refunded) + { + chargeTransaction.Refunded = true; + } + await _transactionRepository.ReplaceAsync(chargeTransaction); + + foreach (var refund in charge.Refunds) + { + var refundTransaction = await _transactionRepository.GetByGatewayIdAsync( + GatewayType.Stripe, refund.Id); + if (refundTransaction != null) + { + continue; + } + + await _transactionRepository.CreateAsync(new Transaction + { + Amount = refund.Amount / 100M, + CreationDate = refund.Created, + OrganizationId = chargeTransaction.OrganizationId, + UserId = chargeTransaction.UserId, + Type = TransactionType.Refund, + Gateway = GatewayType.Stripe, + GatewayId = refund.Id, + PaymentMethodType = chargeTransaction.PaymentMethodType, + Details = chargeTransaction.Details + }); + } + } + else + { + _logger.LogWarning("Charge refund amount doesn't seem correct. " + charge.Id); + } + } + else if (parsedEvent.Type.Equals(HandledStripeWebhook.PaymentSucceeded)) + { + var invoice = await GetInvoiceAsync(parsedEvent, true); + if (invoice.Paid && invoice.BillingReason == "subscription_create") + { + var subscriptionService = new SubscriptionService(); + var subscription = await subscriptionService.GetAsync(invoice.SubscriptionId); + if (subscription?.Status == "active") + { + if (DateTime.UtcNow - invoice.Created < TimeSpan.FromMinutes(1)) + { + await Task.Delay(5000); + } + + var ids = GetIdsFromMetaData(subscription.Metadata); + // org + if (ids.Item1.HasValue) + { + if (subscription.Items.Any(i => StaticStore.Plans.Any(p => p.StripePlanId == i.Plan.Id))) + { + await _organizationService.EnableAsync(ids.Item1.Value, subscription.CurrentPeriodEnd); + + var organization = await _organizationRepository.GetByIdAsync(ids.Item1.Value); + await _referenceEventService.RaiseEventAsync( + new ReferenceEvent(ReferenceEventType.Rebilled, organization) + { + PlanName = organization?.Plan, + PlanType = organization?.PlanType, + Seats = organization?.Seats, + Storage = organization?.MaxStorageGb, + }); + } + } + // user + else if (ids.Item2.HasValue) + { + if (subscription.Items.Any(i => i.Plan.Id == PremiumPlanId)) + { + await _userService.EnablePremiumAsync(ids.Item2.Value, subscription.CurrentPeriodEnd); + + var user = await _userRepository.GetByIdAsync(ids.Item2.Value); + await _referenceEventService.RaiseEventAsync( + new ReferenceEvent(ReferenceEventType.Rebilled, user) + { + PlanName = PremiumPlanId, + Storage = user?.MaxStorageGb, + }); + } } } } } - } - else if (parsedEvent.Type.Equals(HandledStripeWebhook.PaymentFailed)) - { - await HandlePaymentFailed(await GetInvoiceAsync(parsedEvent, true)); - } - else if (parsedEvent.Type.Equals(HandledStripeWebhook.InvoiceCreated)) - { - var invoice = await GetInvoiceAsync(parsedEvent, true); - if (!invoice.Paid && UnpaidAutoChargeInvoiceForSubscriptionCycle(invoice)) + else if (parsedEvent.Type.Equals(HandledStripeWebhook.PaymentFailed)) { - await AttemptToPayInvoiceAsync(invoice); + await HandlePaymentFailed(await GetInvoiceAsync(parsedEvent, true)); } - } - else - { - _logger.LogWarning("Unsupported event received. " + parsedEvent.Type); - } - - return new OkResult(); - } - - private Tuple GetIdsFromMetaData(IDictionary metaData) - { - if (metaData == null || !metaData.Any()) - { - return new Tuple(null, null); - } - - Guid? orgId = null; - Guid? userId = null; - - if (metaData.ContainsKey("organizationId")) - { - orgId = new Guid(metaData["organizationId"]); - } - else if (metaData.ContainsKey("userId")) - { - userId = new Guid(metaData["userId"]); - } - - if (userId == null && orgId == null) - { - var orgIdKey = metaData.Keys.FirstOrDefault(k => k.ToLowerInvariant() == "organizationid"); - if (!string.IsNullOrWhiteSpace(orgIdKey)) + else if (parsedEvent.Type.Equals(HandledStripeWebhook.InvoiceCreated)) { - orgId = new Guid(metaData[orgIdKey]); + var invoice = await GetInvoiceAsync(parsedEvent, true); + if (!invoice.Paid && UnpaidAutoChargeInvoiceForSubscriptionCycle(invoice)) + { + await AttemptToPayInvoiceAsync(invoice); + } } else { - var userIdKey = metaData.Keys.FirstOrDefault(k => k.ToLowerInvariant() == "userid"); - if (!string.IsNullOrWhiteSpace(userIdKey)) + _logger.LogWarning("Unsupported event received. " + parsedEvent.Type); + } + + return new OkResult(); + } + + private Tuple GetIdsFromMetaData(IDictionary metaData) + { + if (metaData == null || !metaData.Any()) + { + return new Tuple(null, null); + } + + Guid? orgId = null; + Guid? userId = null; + + if (metaData.ContainsKey("organizationId")) + { + orgId = new Guid(metaData["organizationId"]); + } + else if (metaData.ContainsKey("userId")) + { + userId = new Guid(metaData["userId"]); + } + + if (userId == null && orgId == null) + { + var orgIdKey = metaData.Keys.FirstOrDefault(k => k.ToLowerInvariant() == "organizationid"); + if (!string.IsNullOrWhiteSpace(orgIdKey)) { - userId = new Guid(metaData[userIdKey]); + orgId = new Guid(metaData[orgIdKey]); } - } - } - - return new Tuple(orgId, userId); - } - - private bool OrgPlanForInvoiceNotifications(Organization org) - { - switch (org.PlanType) - { - case PlanType.FamiliesAnnually: - case PlanType.TeamsAnnually: - case PlanType.EnterpriseAnnually: - return true; - default: - return false; - } - } - - private async Task AttemptToPayInvoiceAsync(Invoice invoice) - { - var customerService = new CustomerService(); - var customer = await customerService.GetAsync(invoice.CustomerId); - if (customer?.Metadata?.ContainsKey("appleReceipt") ?? false) - { - return await AttemptToPayInvoiceWithAppleReceiptAsync(invoice, customer); - } - else if (customer?.Metadata?.ContainsKey("btCustomerId") ?? false) - { - return await AttemptToPayInvoiceWithBraintreeAsync(invoice, customer); - } - return false; - } - - private async Task AttemptToPayInvoiceWithAppleReceiptAsync(Invoice invoice, Customer customer) - { - if (!customer?.Metadata?.ContainsKey("appleReceipt") ?? true) - { - return false; - } - - var originalAppleReceiptTransactionId = customer.Metadata["appleReceipt"]; - var appleReceiptRecord = await _appleIapService.GetReceiptAsync(originalAppleReceiptTransactionId); - if (string.IsNullOrWhiteSpace(appleReceiptRecord?.Item1) || !appleReceiptRecord.Item2.HasValue) - { - return false; - } - - var subscriptionService = new SubscriptionService(); - var subscription = await subscriptionService.GetAsync(invoice.SubscriptionId); - var ids = GetIdsFromMetaData(subscription?.Metadata); - if (!ids.Item2.HasValue) - { - // Apple receipt is only for user subscriptions - return false; - } - - if (appleReceiptRecord.Item2.Value != ids.Item2.Value) - { - _logger.LogError("User Ids for Apple Receipt and subscription do not match: {0} != {1}.", - appleReceiptRecord.Item2.Value, ids.Item2.Value); - return false; - } - - var appleReceiptStatus = await _appleIapService.GetVerifiedReceiptStatusAsync(appleReceiptRecord.Item1); - if (appleReceiptStatus == null) - { - // TODO: cancel sub if receipt is cancelled? - return false; - } - - var receiptExpiration = appleReceiptStatus.GetLastExpiresDate().GetValueOrDefault(DateTime.MinValue); - var invoiceDue = invoice.DueDate.GetValueOrDefault(DateTime.MinValue); - if (receiptExpiration <= invoiceDue) - { - _logger.LogWarning("Apple receipt expiration is before invoice due date. {0} <= {1}", - receiptExpiration, invoiceDue); - return false; - } - - var receiptLastTransactionId = appleReceiptStatus.GetLastTransactionId(); - var existingTransaction = await _transactionRepository.GetByGatewayIdAsync( - GatewayType.AppStore, receiptLastTransactionId); - if (existingTransaction != null) - { - _logger.LogWarning("There is already an existing transaction for this Apple receipt.", - receiptLastTransactionId); - return false; - } - - var appleTransaction = appleReceiptStatus.BuildTransactionFromLastTransaction( - PremiumPlanAppleIapPrice, ids.Item2.Value); - appleTransaction.Type = TransactionType.Charge; - - var invoiceService = new InvoiceService(); - try - { - await invoiceService.UpdateAsync(invoice.Id, new InvoiceUpdateOptions - { - Metadata = new Dictionary + else { - ["appleReceipt"] = appleReceiptStatus.GetOriginalTransactionId(), - ["appleReceiptTransactionId"] = receiptLastTransactionId - } - }); - - await _transactionRepository.CreateAsync(appleTransaction); - await invoiceService.PayAsync(invoice.Id, new InvoicePayOptions { PaidOutOfBand = true }); - } - catch (Exception e) - { - if (e.Message.Contains("Invoice is already paid")) - { - await invoiceService.UpdateAsync(invoice.Id, new InvoiceUpdateOptions - { - Metadata = invoice.Metadata - }); - } - else - { - throw; - } - } - - return true; - } - - private async Task AttemptToPayInvoiceWithBraintreeAsync(Invoice invoice, Customer customer) - { - if (!customer?.Metadata?.ContainsKey("btCustomerId") ?? true) - { - return false; - } - - var subscriptionService = new SubscriptionService(); - var subscription = await subscriptionService.GetAsync(invoice.SubscriptionId); - var ids = GetIdsFromMetaData(subscription?.Metadata); - if (!ids.Item1.HasValue && !ids.Item2.HasValue) - { - return false; - } - - var orgTransaction = ids.Item1.HasValue; - var btObjIdField = orgTransaction ? "organization_id" : "user_id"; - var btObjId = ids.Item1 ?? ids.Item2.Value; - var btInvoiceAmount = (invoice.AmountDue / 100M); - - var existingTransactions = orgTransaction ? - await _transactionRepository.GetManyByOrganizationIdAsync(ids.Item1.Value) : - await _transactionRepository.GetManyByUserIdAsync(ids.Item2.Value); - var duplicateTimeSpan = TimeSpan.FromHours(24); - var now = DateTime.UtcNow; - var duplicateTransaction = existingTransactions? - .FirstOrDefault(t => (now - t.CreationDate) < duplicateTimeSpan); - if (duplicateTransaction != null) - { - _logger.LogWarning("There is already a recent PayPal transaction ({0}). " + - "Do not charge again to prevent possible duplicate.", duplicateTransaction.GatewayId); - return false; - } - - var transactionResult = await _btGateway.Transaction.SaleAsync( - new Braintree.TransactionRequest - { - Amount = btInvoiceAmount, - CustomerId = customer.Metadata["btCustomerId"], - Options = new Braintree.TransactionOptionsRequest - { - SubmitForSettlement = true, - PayPal = new Braintree.TransactionOptionsPayPalRequest + var userIdKey = metaData.Keys.FirstOrDefault(k => k.ToLowerInvariant() == "userid"); + if (!string.IsNullOrWhiteSpace(userIdKey)) { - CustomField = $"{btObjIdField}:{btObjId}" + userId = new Guid(metaData[userIdKey]); } - }, - CustomFields = new Dictionary - { - [btObjIdField] = btObjId.ToString() } - }); + } - if (!transactionResult.IsSuccess()) + return new Tuple(orgId, userId); + } + + private bool OrgPlanForInvoiceNotifications(Organization org) { - if (invoice.AttemptCount < 4) + switch (org.PlanType) { - await _mailService.SendPaymentFailedAsync(customer.Email, btInvoiceAmount, true); + case PlanType.FamiliesAnnually: + case PlanType.TeamsAnnually: + case PlanType.EnterpriseAnnually: + return true; + default: + return false; + } + } + + private async Task AttemptToPayInvoiceAsync(Invoice invoice) + { + var customerService = new CustomerService(); + var customer = await customerService.GetAsync(invoice.CustomerId); + if (customer?.Metadata?.ContainsKey("appleReceipt") ?? false) + { + return await AttemptToPayInvoiceWithAppleReceiptAsync(invoice, customer); + } + else if (customer?.Metadata?.ContainsKey("btCustomerId") ?? false) + { + return await AttemptToPayInvoiceWithBraintreeAsync(invoice, customer); } return false; } - var invoiceService = new InvoiceService(); - try + private async Task AttemptToPayInvoiceWithAppleReceiptAsync(Invoice invoice, Customer customer) { - await invoiceService.UpdateAsync(invoice.Id, new InvoiceUpdateOptions + if (!customer?.Metadata?.ContainsKey("appleReceipt") ?? true) { - Metadata = new Dictionary - { - ["btTransactionId"] = transactionResult.Target.Id, - ["btPayPalTransactionId"] = - transactionResult.Target.PayPalDetails?.AuthorizationId - } - }); - await invoiceService.PayAsync(invoice.Id, new InvoicePayOptions { PaidOutOfBand = true }); - } - catch (Exception e) - { - await _btGateway.Transaction.RefundAsync(transactionResult.Target.Id); - if (e.Message.Contains("Invoice is already paid")) - { - await invoiceService.UpdateAsync(invoice.Id, new InvoiceUpdateOptions - { - Metadata = invoice.Metadata - }); + return false; } - else + + var originalAppleReceiptTransactionId = customer.Metadata["appleReceipt"]; + var appleReceiptRecord = await _appleIapService.GetReceiptAsync(originalAppleReceiptTransactionId); + if (string.IsNullOrWhiteSpace(appleReceiptRecord?.Item1) || !appleReceiptRecord.Item2.HasValue) { - throw; + return false; } - } - return true; - } - - private bool UnpaidAutoChargeInvoiceForSubscriptionCycle(Invoice invoice) - { - return invoice.AmountDue > 0 && !invoice.Paid && invoice.CollectionMethod == "charge_automatically" && - invoice.BillingReason == "subscription_cycle" && invoice.SubscriptionId != null; - } - - private async Task GetChargeAsync(Stripe.Event parsedEvent, bool fresh = false) - { - if (!(parsedEvent.Data.Object is Charge eventCharge)) - { - throw new Exception("Charge is null (from parsed event). " + parsedEvent.Id); - } - if (!fresh) - { - return eventCharge; - } - var chargeService = new ChargeService(); - var charge = await chargeService.GetAsync(eventCharge.Id); - if (charge == null) - { - throw new Exception("Charge is null. " + eventCharge.Id); - } - return charge; - } - - private async Task GetInvoiceAsync(Stripe.Event parsedEvent, bool fresh = false) - { - if (!(parsedEvent.Data.Object is Invoice eventInvoice)) - { - throw new Exception("Invoice is null (from parsed event). " + parsedEvent.Id); - } - if (!fresh) - { - return eventInvoice; - } - var invoiceService = new InvoiceService(); - var invoice = await invoiceService.GetAsync(eventInvoice.Id); - if (invoice == null) - { - throw new Exception("Invoice is null. " + eventInvoice.Id); - } - return invoice; - } - - private async Task GetSubscriptionAsync(Stripe.Event parsedEvent, bool fresh = false) - { - if (!(parsedEvent.Data.Object is Subscription eventSubscription)) - { - throw new Exception("Subscription is null (from parsed event). " + parsedEvent.Id); - } - if (!fresh) - { - return eventSubscription; - } - var subscriptionService = new SubscriptionService(); - var subscription = await subscriptionService.GetAsync(eventSubscription.Id); - if (subscription == null) - { - throw new Exception("Subscription is null. " + eventSubscription.Id); - } - return subscription; - } - - private async Task VerifyCorrectTaxRateForCharge(Invoice invoice, Subscription subscription) - { - if (!string.IsNullOrWhiteSpace(invoice?.CustomerAddress?.Country) && !string.IsNullOrWhiteSpace(invoice?.CustomerAddress?.PostalCode)) - { - var localBitwardenTaxRates = await _taxRateRepository.GetByLocationAsync( - new TaxRate() - { - Country = invoice.CustomerAddress.Country, - PostalCode = invoice.CustomerAddress.PostalCode - } - ); - - if (localBitwardenTaxRates.Any()) - { - var stripeTaxRate = await new TaxRateService().GetAsync(localBitwardenTaxRates.First().Id); - if (stripeTaxRate != null && !subscription.DefaultTaxRates.Any(x => x == stripeTaxRate)) - { - subscription.DefaultTaxRates = new List { stripeTaxRate }; - var subscriptionOptions = new SubscriptionUpdateOptions() { DefaultTaxRates = new List() { stripeTaxRate.Id } }; - subscription = await new SubscriptionService().UpdateAsync(subscription.Id, subscriptionOptions); - } - } - } - return subscription; - } - - private static bool IsSponsoredSubscription(Subscription subscription) => - StaticStore.SponsoredPlans.Any(p => p.StripePlanId == subscription.Id); - - private async Task HandlePaymentFailed(Invoice invoice) - { - if (!invoice.Paid && invoice.AttemptCount > 1 && UnpaidAutoChargeInvoiceForSubscriptionCycle(invoice)) - { var subscriptionService = new SubscriptionService(); var subscription = await subscriptionService.GetAsync(invoice.SubscriptionId); - // attempt count 4 = 11 days after initial failure - if (invoice.AttemptCount > 3 && subscription.Items.Any(i => i.Price.Id == PremiumPlanId || i.Price.Id == PremiumPlanIdAppStore)) + var ids = GetIdsFromMetaData(subscription?.Metadata); + if (!ids.Item2.HasValue) { - await CancelSubscription(invoice.SubscriptionId); - await VoidOpenInvoices(invoice.SubscriptionId); + // Apple receipt is only for user subscriptions + return false; } - else + + if (appleReceiptRecord.Item2.Value != ids.Item2.Value) { - await AttemptToPayInvoiceAsync(invoice); + _logger.LogError("User Ids for Apple Receipt and subscription do not match: {0} != {1}.", + appleReceiptRecord.Item2.Value, ids.Item2.Value); + return false; + } + + var appleReceiptStatus = await _appleIapService.GetVerifiedReceiptStatusAsync(appleReceiptRecord.Item1); + if (appleReceiptStatus == null) + { + // TODO: cancel sub if receipt is cancelled? + return false; + } + + var receiptExpiration = appleReceiptStatus.GetLastExpiresDate().GetValueOrDefault(DateTime.MinValue); + var invoiceDue = invoice.DueDate.GetValueOrDefault(DateTime.MinValue); + if (receiptExpiration <= invoiceDue) + { + _logger.LogWarning("Apple receipt expiration is before invoice due date. {0} <= {1}", + receiptExpiration, invoiceDue); + return false; + } + + var receiptLastTransactionId = appleReceiptStatus.GetLastTransactionId(); + var existingTransaction = await _transactionRepository.GetByGatewayIdAsync( + GatewayType.AppStore, receiptLastTransactionId); + if (existingTransaction != null) + { + _logger.LogWarning("There is already an existing transaction for this Apple receipt.", + receiptLastTransactionId); + return false; + } + + var appleTransaction = appleReceiptStatus.BuildTransactionFromLastTransaction( + PremiumPlanAppleIapPrice, ids.Item2.Value); + appleTransaction.Type = TransactionType.Charge; + + var invoiceService = new InvoiceService(); + try + { + await invoiceService.UpdateAsync(invoice.Id, new InvoiceUpdateOptions + { + Metadata = new Dictionary + { + ["appleReceipt"] = appleReceiptStatus.GetOriginalTransactionId(), + ["appleReceiptTransactionId"] = receiptLastTransactionId + } + }); + + await _transactionRepository.CreateAsync(appleTransaction); + await invoiceService.PayAsync(invoice.Id, new InvoicePayOptions { PaidOutOfBand = true }); + } + catch (Exception e) + { + if (e.Message.Contains("Invoice is already paid")) + { + await invoiceService.UpdateAsync(invoice.Id, new InvoiceUpdateOptions + { + Metadata = invoice.Metadata + }); + } + else + { + throw; + } + } + + return true; + } + + private async Task AttemptToPayInvoiceWithBraintreeAsync(Invoice invoice, Customer customer) + { + if (!customer?.Metadata?.ContainsKey("btCustomerId") ?? true) + { + return false; + } + + var subscriptionService = new SubscriptionService(); + var subscription = await subscriptionService.GetAsync(invoice.SubscriptionId); + var ids = GetIdsFromMetaData(subscription?.Metadata); + if (!ids.Item1.HasValue && !ids.Item2.HasValue) + { + return false; + } + + var orgTransaction = ids.Item1.HasValue; + var btObjIdField = orgTransaction ? "organization_id" : "user_id"; + var btObjId = ids.Item1 ?? ids.Item2.Value; + var btInvoiceAmount = (invoice.AmountDue / 100M); + + var existingTransactions = orgTransaction ? + await _transactionRepository.GetManyByOrganizationIdAsync(ids.Item1.Value) : + await _transactionRepository.GetManyByUserIdAsync(ids.Item2.Value); + var duplicateTimeSpan = TimeSpan.FromHours(24); + var now = DateTime.UtcNow; + var duplicateTransaction = existingTransactions? + .FirstOrDefault(t => (now - t.CreationDate) < duplicateTimeSpan); + if (duplicateTransaction != null) + { + _logger.LogWarning("There is already a recent PayPal transaction ({0}). " + + "Do not charge again to prevent possible duplicate.", duplicateTransaction.GatewayId); + return false; + } + + var transactionResult = await _btGateway.Transaction.SaleAsync( + new Braintree.TransactionRequest + { + Amount = btInvoiceAmount, + CustomerId = customer.Metadata["btCustomerId"], + Options = new Braintree.TransactionOptionsRequest + { + SubmitForSettlement = true, + PayPal = new Braintree.TransactionOptionsPayPalRequest + { + CustomField = $"{btObjIdField}:{btObjId}" + } + }, + CustomFields = new Dictionary + { + [btObjIdField] = btObjId.ToString() + } + }); + + if (!transactionResult.IsSuccess()) + { + if (invoice.AttemptCount < 4) + { + await _mailService.SendPaymentFailedAsync(customer.Email, btInvoiceAmount, true); + } + return false; + } + + var invoiceService = new InvoiceService(); + try + { + await invoiceService.UpdateAsync(invoice.Id, new InvoiceUpdateOptions + { + Metadata = new Dictionary + { + ["btTransactionId"] = transactionResult.Target.Id, + ["btPayPalTransactionId"] = + transactionResult.Target.PayPalDetails?.AuthorizationId + } + }); + await invoiceService.PayAsync(invoice.Id, new InvoicePayOptions { PaidOutOfBand = true }); + } + catch (Exception e) + { + await _btGateway.Transaction.RefundAsync(transactionResult.Target.Id); + if (e.Message.Contains("Invoice is already paid")) + { + await invoiceService.UpdateAsync(invoice.Id, new InvoiceUpdateOptions + { + Metadata = invoice.Metadata + }); + } + else + { + throw; + } + } + + return true; + } + + private bool UnpaidAutoChargeInvoiceForSubscriptionCycle(Invoice invoice) + { + return invoice.AmountDue > 0 && !invoice.Paid && invoice.CollectionMethod == "charge_automatically" && + invoice.BillingReason == "subscription_cycle" && invoice.SubscriptionId != null; + } + + private async Task GetChargeAsync(Stripe.Event parsedEvent, bool fresh = false) + { + if (!(parsedEvent.Data.Object is Charge eventCharge)) + { + throw new Exception("Charge is null (from parsed event). " + parsedEvent.Id); + } + if (!fresh) + { + return eventCharge; + } + var chargeService = new ChargeService(); + var charge = await chargeService.GetAsync(eventCharge.Id); + if (charge == null) + { + throw new Exception("Charge is null. " + eventCharge.Id); + } + return charge; + } + + private async Task GetInvoiceAsync(Stripe.Event parsedEvent, bool fresh = false) + { + if (!(parsedEvent.Data.Object is Invoice eventInvoice)) + { + throw new Exception("Invoice is null (from parsed event). " + parsedEvent.Id); + } + if (!fresh) + { + return eventInvoice; + } + var invoiceService = new InvoiceService(); + var invoice = await invoiceService.GetAsync(eventInvoice.Id); + if (invoice == null) + { + throw new Exception("Invoice is null. " + eventInvoice.Id); + } + return invoice; + } + + private async Task GetSubscriptionAsync(Stripe.Event parsedEvent, bool fresh = false) + { + if (!(parsedEvent.Data.Object is Subscription eventSubscription)) + { + throw new Exception("Subscription is null (from parsed event). " + parsedEvent.Id); + } + if (!fresh) + { + return eventSubscription; + } + var subscriptionService = new SubscriptionService(); + var subscription = await subscriptionService.GetAsync(eventSubscription.Id); + if (subscription == null) + { + throw new Exception("Subscription is null. " + eventSubscription.Id); + } + return subscription; + } + + private async Task VerifyCorrectTaxRateForCharge(Invoice invoice, Subscription subscription) + { + if (!string.IsNullOrWhiteSpace(invoice?.CustomerAddress?.Country) && !string.IsNullOrWhiteSpace(invoice?.CustomerAddress?.PostalCode)) + { + var localBitwardenTaxRates = await _taxRateRepository.GetByLocationAsync( + new TaxRate() + { + Country = invoice.CustomerAddress.Country, + PostalCode = invoice.CustomerAddress.PostalCode + } + ); + + if (localBitwardenTaxRates.Any()) + { + var stripeTaxRate = await new TaxRateService().GetAsync(localBitwardenTaxRates.First().Id); + if (stripeTaxRate != null && !subscription.DefaultTaxRates.Any(x => x == stripeTaxRate)) + { + subscription.DefaultTaxRates = new List { stripeTaxRate }; + var subscriptionOptions = new SubscriptionUpdateOptions() { DefaultTaxRates = new List() { stripeTaxRate.Id } }; + subscription = await new SubscriptionService().UpdateAsync(subscription.Id, subscriptionOptions); + } + } + } + return subscription; + } + + private static bool IsSponsoredSubscription(Subscription subscription) => + StaticStore.SponsoredPlans.Any(p => p.StripePlanId == subscription.Id); + + private async Task HandlePaymentFailed(Invoice invoice) + { + if (!invoice.Paid && invoice.AttemptCount > 1 && UnpaidAutoChargeInvoiceForSubscriptionCycle(invoice)) + { + var subscriptionService = new SubscriptionService(); + var subscription = await subscriptionService.GetAsync(invoice.SubscriptionId); + // attempt count 4 = 11 days after initial failure + if (invoice.AttemptCount > 3 && subscription.Items.Any(i => i.Price.Id == PremiumPlanId || i.Price.Id == PremiumPlanIdAppStore)) + { + await CancelSubscription(invoice.SubscriptionId); + await VoidOpenInvoices(invoice.SubscriptionId); + } + else + { + await AttemptToPayInvoiceAsync(invoice); + } } } - } - private async Task CancelSubscription(string subscriptionId) - { - await new SubscriptionService().CancelAsync(subscriptionId, new SubscriptionCancelOptions()); - } + private async Task CancelSubscription(string subscriptionId) + { + await new SubscriptionService().CancelAsync(subscriptionId, new SubscriptionCancelOptions()); + } - private async Task VoidOpenInvoices(string subscriptionId) - { - var invoiceService = new InvoiceService(); - var options = new InvoiceListOptions + private async Task VoidOpenInvoices(string subscriptionId) { - Status = "open", - Subscription = subscriptionId - }; - var invoices = invoiceService.List(options); - foreach (var invoice in invoices) - { - await invoiceService.VoidInvoiceAsync(invoice.Id); + var invoiceService = new InvoiceService(); + var options = new InvoiceListOptions + { + Status = "open", + Subscription = subscriptionId + }; + var invoices = invoiceService.List(options); + foreach (var invoice in invoices) + { + await invoiceService.VoidInvoiceAsync(invoice.Id); + } } } } diff --git a/src/Billing/Jobs/JobsHostedService.cs b/src/Billing/Jobs/JobsHostedService.cs index 1a5c80774c..ea91924a15 100644 --- a/src/Billing/Jobs/JobsHostedService.cs +++ b/src/Billing/Jobs/JobsHostedService.cs @@ -3,42 +3,43 @@ using Bit.Core.Jobs; using Bit.Core.Settings; using Quartz; -namespace Bit.Billing.Jobs; - -public class JobsHostedService : BaseJobsHostedService +namespace Bit.Billing.Jobs { - public JobsHostedService( - GlobalSettings globalSettings, - IServiceProvider serviceProvider, - ILogger logger, - ILogger listenerLogger) - : base(globalSettings, serviceProvider, logger, listenerLogger) { } - - public override async Task StartAsync(CancellationToken cancellationToken) + public class JobsHostedService : BaseJobsHostedService { - var timeZone = RuntimeInformation.IsOSPlatform(OSPlatform.Windows) ? - TimeZoneInfo.FindSystemTimeZoneById("Eastern Standard Time") : - TimeZoneInfo.FindSystemTimeZoneById("America/New_York"); - if (_globalSettings.SelfHosted) + public JobsHostedService( + GlobalSettings globalSettings, + IServiceProvider serviceProvider, + ILogger logger, + ILogger listenerLogger) + : base(globalSettings, serviceProvider, logger, listenerLogger) { } + + public override async Task StartAsync(CancellationToken cancellationToken) { - timeZone = TimeZoneInfo.Local; + var timeZone = RuntimeInformation.IsOSPlatform(OSPlatform.Windows) ? + TimeZoneInfo.FindSystemTimeZoneById("Eastern Standard Time") : + TimeZoneInfo.FindSystemTimeZoneById("America/New_York"); + if (_globalSettings.SelfHosted) + { + timeZone = TimeZoneInfo.Local; + } + + var everyDayAtNinePmTrigger = TriggerBuilder.Create() + .WithIdentity("EveryDayAtNinePmTrigger") + .StartNow() + .WithCronSchedule("0 0 21 * * ?", x => x.InTimeZone(timeZone)) + .Build(); + + Jobs = new List>(); + + // Add jobs here + + await base.StartAsync(cancellationToken); } - var everyDayAtNinePmTrigger = TriggerBuilder.Create() - .WithIdentity("EveryDayAtNinePmTrigger") - .StartNow() - .WithCronSchedule("0 0 21 * * ?", x => x.InTimeZone(timeZone)) - .Build(); - - Jobs = new List>(); - - // Add jobs here - - await base.StartAsync(cancellationToken); - } - - public static void AddJobsServices(IServiceCollection services) - { - // Register jobs here + public static void AddJobsServices(IServiceCollection services) + { + // Register jobs here + } } } diff --git a/src/Billing/Models/BitPayEventModel.cs b/src/Billing/Models/BitPayEventModel.cs index e16391317a..b7ed06462d 100644 --- a/src/Billing/Models/BitPayEventModel.cs +++ b/src/Billing/Models/BitPayEventModel.cs @@ -1,27 +1,28 @@ -namespace Bit.Billing.Models; - -public class BitPayEventModel +namespace Bit.Billing.Models { - public EventModel Event { get; set; } - public InvoiceDataModel Data { get; set; } - - public class EventModel + public class BitPayEventModel { - public int Code { get; set; } - public string Name { get; set; } - } + public EventModel Event { get; set; } + public InvoiceDataModel Data { get; set; } - public class InvoiceDataModel - { - public string Id { get; set; } - public string Url { get; set; } - public string Status { get; set; } - public string Currency { get; set; } - public decimal Price { get; set; } - public string PosData { get; set; } - public bool ExceptionStatus { get; set; } - public long CurrentTime { get; set; } - public long AmountPaid { get; set; } - public string TransactionCurrency { get; set; } + public class EventModel + { + public int Code { get; set; } + public string Name { get; set; } + } + + public class InvoiceDataModel + { + public string Id { get; set; } + public string Url { get; set; } + public string Status { get; set; } + public string Currency { get; set; } + public decimal Price { get; set; } + public string PosData { get; set; } + public bool ExceptionStatus { get; set; } + public long CurrentTime { get; set; } + public long AmountPaid { get; set; } + public string TransactionCurrency { get; set; } + } } } diff --git a/src/Billing/Models/FreshdeskWebhookModel.cs b/src/Billing/Models/FreshdeskWebhookModel.cs index e9fe8e026a..c371c70fb5 100644 --- a/src/Billing/Models/FreshdeskWebhookModel.cs +++ b/src/Billing/Models/FreshdeskWebhookModel.cs @@ -1,15 +1,16 @@ using System.Text.Json.Serialization; -namespace Bit.Billing.Models; - -public class FreshdeskWebhookModel +namespace Bit.Billing.Models { - [JsonPropertyName("ticket_id")] - public string TicketId { get; set; } + public class FreshdeskWebhookModel + { + [JsonPropertyName("ticket_id")] + public string TicketId { get; set; } - [JsonPropertyName("ticket_contact_email")] - public string TicketContactEmail { get; set; } + [JsonPropertyName("ticket_contact_email")] + public string TicketContactEmail { get; set; } - [JsonPropertyName("ticket_tags")] - public string TicketTags { get; set; } + [JsonPropertyName("ticket_tags")] + public string TicketTags { get; set; } + } } diff --git a/src/Billing/Models/LoginModel.cs b/src/Billing/Models/LoginModel.cs index 5fe04ad454..51fdf0915e 100644 --- a/src/Billing/Models/LoginModel.cs +++ b/src/Billing/Models/LoginModel.cs @@ -1,10 +1,11 @@ using System.ComponentModel.DataAnnotations; -namespace Bit.Billing.Models; - -public class LoginModel +namespace Bit.Billing.Models { - [Required] - [EmailAddress] - public string Email { get; set; } + public class LoginModel + { + [Required] + [EmailAddress] + public string Email { get; set; } + } } diff --git a/src/Billing/Program.cs b/src/Billing/Program.cs index d7ebadd92f..7b42ad73f9 100644 --- a/src/Billing/Program.cs +++ b/src/Billing/Program.cs @@ -1,38 +1,39 @@ using Bit.Core.Utilities; using Serilog.Events; -namespace Bit.Billing; - -public class Program +namespace Bit.Billing { - public static void Main(string[] args) + public class Program { - Host - .CreateDefaultBuilder(args) - .ConfigureWebHostDefaults(webBuilder => - { - webBuilder.UseStartup(); - webBuilder.ConfigureLogging((hostingContext, logging) => - logging.AddSerilog(hostingContext, e => - { - var context = e.Properties["SourceContext"].ToString(); - if (e.Level == LogEventLevel.Information && - (context.StartsWith("\"Bit.Billing.Jobs") || context.StartsWith("\"Bit.Core.Jobs"))) + public static void Main(string[] args) + { + Host + .CreateDefaultBuilder(args) + .ConfigureWebHostDefaults(webBuilder => + { + webBuilder.UseStartup(); + webBuilder.ConfigureLogging((hostingContext, logging) => + logging.AddSerilog(hostingContext, e => { - return true; - } + var context = e.Properties["SourceContext"].ToString(); + if (e.Level == LogEventLevel.Information && + (context.StartsWith("\"Bit.Billing.Jobs") || context.StartsWith("\"Bit.Core.Jobs"))) + { + return true; + } - if (e.Properties.ContainsKey("RequestPath") && - !string.IsNullOrWhiteSpace(e.Properties["RequestPath"]?.ToString()) && - (context.Contains(".Server.Kestrel") || context.Contains(".Core.IISHttpServer"))) - { - return false; - } + if (e.Properties.ContainsKey("RequestPath") && + !string.IsNullOrWhiteSpace(e.Properties["RequestPath"]?.ToString()) && + (context.Contains(".Server.Kestrel") || context.Contains(".Core.IISHttpServer"))) + { + return false; + } - return e.Level >= LogEventLevel.Warning; - })); - }) - .Build() - .Run(); + return e.Level >= LogEventLevel.Warning; + })); + }) + .Build() + .Run(); + } } } diff --git a/src/Billing/Startup.cs b/src/Billing/Startup.cs index 328e6133d9..a2a161a88a 100644 --- a/src/Billing/Startup.cs +++ b/src/Billing/Startup.cs @@ -6,93 +6,94 @@ using Bit.SharedWeb.Utilities; using Microsoft.Extensions.DependencyInjection.Extensions; using Stripe; -namespace Bit.Billing; - -public class Startup +namespace Bit.Billing { - public Startup(IWebHostEnvironment env, IConfiguration configuration) + public class Startup { - CultureInfo.DefaultThreadCurrentCulture = new CultureInfo("en-US"); - Configuration = configuration; - Environment = env; - } - - public IConfiguration Configuration { get; } - public IWebHostEnvironment Environment { get; set; } - - public void ConfigureServices(IServiceCollection services) - { - // Options - services.AddOptions(); - - // Settings - var globalSettings = services.AddGlobalSettingsServices(Configuration, Environment); - services.Configure(Configuration.GetSection("BillingSettings")); - - // Stripe Billing - StripeConfiguration.ApiKey = globalSettings.Stripe.ApiKey; - StripeConfiguration.MaxNetworkRetries = globalSettings.Stripe.MaxNetworkRetries; - - // Repositories - services.AddSqlServerRepositories(globalSettings); - - // PayPal Client - services.AddSingleton(); - - // BitPay Client - services.AddSingleton(); - - // Context - services.AddScoped(); - - // Identity - services.AddCustomIdentityServices(globalSettings); - //services.AddPasswordlessIdentityServices(globalSettings); - - // Services - services.AddBaseServices(globalSettings); - services.AddDefaultServices(globalSettings); - - services.TryAddSingleton(); - - // Mvc - services.AddMvc(config => + public Startup(IWebHostEnvironment env, IConfiguration configuration) { - config.Filters.Add(new LoggingExceptionHandlerFilterAttribute()); - }); - services.Configure(options => options.LowercaseUrls = true); - - // Authentication - services.AddAuthentication(); - - // Jobs service, uncomment when we have some jobs to run - // Jobs.JobsHostedService.AddJobsServices(services); - // services.AddHostedService(); - - // Set up HttpClients - services.AddHttpClient("FreshdeskApi"); - } - - public void Configure( - IApplicationBuilder app, - IWebHostEnvironment env, - IHostApplicationLifetime appLifetime, - GlobalSettings globalSettings) - { - app.UseSerilog(env, appLifetime, globalSettings); - - // Add general security headers - app.UseMiddleware(); - - if (env.IsDevelopment()) - { - app.UseDeveloperExceptionPage(); + CultureInfo.DefaultThreadCurrentCulture = new CultureInfo("en-US"); + Configuration = configuration; + Environment = env; } - app.UseStaticFiles(); - app.UseRouting(); - app.UseAuthentication(); - app.UseAuthorization(); - app.UseEndpoints(endpoints => endpoints.MapDefaultControllerRoute()); + public IConfiguration Configuration { get; } + public IWebHostEnvironment Environment { get; set; } + + public void ConfigureServices(IServiceCollection services) + { + // Options + services.AddOptions(); + + // Settings + var globalSettings = services.AddGlobalSettingsServices(Configuration, Environment); + services.Configure(Configuration.GetSection("BillingSettings")); + + // Stripe Billing + StripeConfiguration.ApiKey = globalSettings.Stripe.ApiKey; + StripeConfiguration.MaxNetworkRetries = globalSettings.Stripe.MaxNetworkRetries; + + // Repositories + services.AddSqlServerRepositories(globalSettings); + + // PayPal Client + services.AddSingleton(); + + // BitPay Client + services.AddSingleton(); + + // Context + services.AddScoped(); + + // Identity + services.AddCustomIdentityServices(globalSettings); + //services.AddPasswordlessIdentityServices(globalSettings); + + // Services + services.AddBaseServices(globalSettings); + services.AddDefaultServices(globalSettings); + + services.TryAddSingleton(); + + // Mvc + services.AddMvc(config => + { + config.Filters.Add(new LoggingExceptionHandlerFilterAttribute()); + }); + services.Configure(options => options.LowercaseUrls = true); + + // Authentication + services.AddAuthentication(); + + // Jobs service, uncomment when we have some jobs to run + // Jobs.JobsHostedService.AddJobsServices(services); + // services.AddHostedService(); + + // Set up HttpClients + services.AddHttpClient("FreshdeskApi"); + } + + public void Configure( + IApplicationBuilder app, + IWebHostEnvironment env, + IHostApplicationLifetime appLifetime, + GlobalSettings globalSettings) + { + app.UseSerilog(env, appLifetime, globalSettings); + + // Add general security headers + app.UseMiddleware(); + + if (env.IsDevelopment()) + { + app.UseDeveloperExceptionPage(); + } + + app.UseStaticFiles(); + app.UseRouting(); + app.UseAuthentication(); + app.UseAuthorization(); + app.UseEndpoints(endpoints => endpoints.MapDefaultControllerRoute()); + } } } diff --git a/src/Billing/Utilities/PayPalIpnClient.cs b/src/Billing/Utilities/PayPalIpnClient.cs index 15f7a2f156..2a7b0e8526 100644 --- a/src/Billing/Utilities/PayPalIpnClient.cs +++ b/src/Billing/Utilities/PayPalIpnClient.cs @@ -4,170 +4,171 @@ using System.Text; using System.Web; using Microsoft.Extensions.Options; -namespace Bit.Billing.Utilities; - -public class PayPalIpnClient +namespace Bit.Billing.Utilities { - private readonly HttpClient _httpClient = new HttpClient(); - private readonly Uri _ipnUri; - - public PayPalIpnClient(IOptions billingSettings) + public class PayPalIpnClient { - var bSettings = billingSettings?.Value; - _ipnUri = new Uri(bSettings.PayPal.Production ? "https://www.paypal.com/cgi-bin/webscr" : - "https://www.sandbox.paypal.com/cgi-bin/webscr"); - } + private readonly HttpClient _httpClient = new HttpClient(); + private readonly Uri _ipnUri; - public async Task VerifyIpnAsync(string ipnBody) - { - if (ipnBody == null) + public PayPalIpnClient(IOptions billingSettings) { - throw new ArgumentException("No IPN body."); + var bSettings = billingSettings?.Value; + _ipnUri = new Uri(bSettings.PayPal.Production ? "https://www.paypal.com/cgi-bin/webscr" : + "https://www.sandbox.paypal.com/cgi-bin/webscr"); } - var request = new HttpRequestMessage + public async Task VerifyIpnAsync(string ipnBody) { - Method = HttpMethod.Post, - RequestUri = _ipnUri - }; - var cmdIpnBody = string.Concat("cmd=_notify-validate&", ipnBody); - request.Content = new StringContent(cmdIpnBody, Encoding.UTF8, "application/x-www-form-urlencoded"); - var response = await _httpClient.SendAsync(request); - if (!response.IsSuccessStatusCode) - { - throw new Exception("Failed to verify IPN, status: " + response.StatusCode); - } - var responseContent = await response.Content.ReadAsStringAsync(); - if (responseContent.Equals("VERIFIED")) - { - return true; - } - else if (responseContent.Equals("INVALID")) - { - return false; - } - else - { - throw new Exception("Failed to verify IPN."); - } - } - - public class IpnTransaction - { - private string[] _dateFormats = new string[] - { - "HH:mm:ss dd MMM yyyy PDT", "HH:mm:ss dd MMM yyyy PST", "HH:mm:ss dd MMM, yyyy PST", - "HH:mm:ss dd MMM, yyyy PDT","HH:mm:ss MMM dd, yyyy PST", "HH:mm:ss MMM dd, yyyy PDT" - }; - - public IpnTransaction(string ipnFormData) - { - if (string.IsNullOrWhiteSpace(ipnFormData)) + if (ipnBody == null) { - return; + throw new ArgumentException("No IPN body."); } - var qsData = HttpUtility.ParseQueryString(ipnFormData); - var dataDict = qsData.Keys.Cast().ToDictionary(k => k, v => qsData[v].ToString()); - - TxnId = GetDictValue(dataDict, "txn_id"); - TxnType = GetDictValue(dataDict, "txn_type"); - ParentTxnId = GetDictValue(dataDict, "parent_txn_id"); - PaymentStatus = GetDictValue(dataDict, "payment_status"); - PaymentType = GetDictValue(dataDict, "payment_type"); - McCurrency = GetDictValue(dataDict, "mc_currency"); - Custom = GetDictValue(dataDict, "custom"); - ItemName = GetDictValue(dataDict, "item_name"); - ItemNumber = GetDictValue(dataDict, "item_number"); - PayerId = GetDictValue(dataDict, "payer_id"); - PayerEmail = GetDictValue(dataDict, "payer_email"); - ReceiverId = GetDictValue(dataDict, "receiver_id"); - ReceiverEmail = GetDictValue(dataDict, "receiver_email"); - - PaymentDate = ConvertDate(GetDictValue(dataDict, "payment_date")); - - var mcGrossString = GetDictValue(dataDict, "mc_gross"); - if (!string.IsNullOrWhiteSpace(mcGrossString) && decimal.TryParse(mcGrossString, out var mcGross)) + var request = new HttpRequestMessage { - McGross = mcGross; + Method = HttpMethod.Post, + RequestUri = _ipnUri + }; + var cmdIpnBody = string.Concat("cmd=_notify-validate&", ipnBody); + request.Content = new StringContent(cmdIpnBody, Encoding.UTF8, "application/x-www-form-urlencoded"); + var response = await _httpClient.SendAsync(request); + if (!response.IsSuccessStatusCode) + { + throw new Exception("Failed to verify IPN, status: " + response.StatusCode); } - var mcFeeString = GetDictValue(dataDict, "mc_fee"); - if (!string.IsNullOrWhiteSpace(mcFeeString) && decimal.TryParse(mcFeeString, out var mcFee)) + var responseContent = await response.Content.ReadAsStringAsync(); + if (responseContent.Equals("VERIFIED")) { - McFee = mcFee; + return true; + } + else if (responseContent.Equals("INVALID")) + { + return false; + } + else + { + throw new Exception("Failed to verify IPN."); } } - public string TxnId { get; set; } - public string TxnType { get; set; } - public string ParentTxnId { get; set; } - public string PaymentStatus { get; set; } - public string PaymentType { get; set; } - public decimal McGross { get; set; } - public decimal McFee { get; set; } - public string McCurrency { get; set; } - public string Custom { get; set; } - public string ItemName { get; set; } - public string ItemNumber { get; set; } - public string PayerId { get; set; } - public string PayerEmail { get; set; } - public string ReceiverId { get; set; } - public string ReceiverEmail { get; set; } - public DateTime PaymentDate { get; set; } - - public Tuple GetIdsFromCustom() + public class IpnTransaction { - Guid? orgId = null; - Guid? userId = null; - - if (!string.IsNullOrWhiteSpace(Custom) && Custom.Contains(":")) + private string[] _dateFormats = new string[] { - var mainParts = Custom.Split(','); - foreach (var mainPart in mainParts) + "HH:mm:ss dd MMM yyyy PDT", "HH:mm:ss dd MMM yyyy PST", "HH:mm:ss dd MMM, yyyy PST", + "HH:mm:ss dd MMM, yyyy PDT","HH:mm:ss MMM dd, yyyy PST", "HH:mm:ss MMM dd, yyyy PDT" + }; + + public IpnTransaction(string ipnFormData) + { + if (string.IsNullOrWhiteSpace(ipnFormData)) { - var parts = mainPart.Split(':'); - if (parts.Length > 1 && Guid.TryParse(parts[1], out var id)) + return; + } + + var qsData = HttpUtility.ParseQueryString(ipnFormData); + var dataDict = qsData.Keys.Cast().ToDictionary(k => k, v => qsData[v].ToString()); + + TxnId = GetDictValue(dataDict, "txn_id"); + TxnType = GetDictValue(dataDict, "txn_type"); + ParentTxnId = GetDictValue(dataDict, "parent_txn_id"); + PaymentStatus = GetDictValue(dataDict, "payment_status"); + PaymentType = GetDictValue(dataDict, "payment_type"); + McCurrency = GetDictValue(dataDict, "mc_currency"); + Custom = GetDictValue(dataDict, "custom"); + ItemName = GetDictValue(dataDict, "item_name"); + ItemNumber = GetDictValue(dataDict, "item_number"); + PayerId = GetDictValue(dataDict, "payer_id"); + PayerEmail = GetDictValue(dataDict, "payer_email"); + ReceiverId = GetDictValue(dataDict, "receiver_id"); + ReceiverEmail = GetDictValue(dataDict, "receiver_email"); + + PaymentDate = ConvertDate(GetDictValue(dataDict, "payment_date")); + + var mcGrossString = GetDictValue(dataDict, "mc_gross"); + if (!string.IsNullOrWhiteSpace(mcGrossString) && decimal.TryParse(mcGrossString, out var mcGross)) + { + McGross = mcGross; + } + var mcFeeString = GetDictValue(dataDict, "mc_fee"); + if (!string.IsNullOrWhiteSpace(mcFeeString) && decimal.TryParse(mcFeeString, out var mcFee)) + { + McFee = mcFee; + } + } + + public string TxnId { get; set; } + public string TxnType { get; set; } + public string ParentTxnId { get; set; } + public string PaymentStatus { get; set; } + public string PaymentType { get; set; } + public decimal McGross { get; set; } + public decimal McFee { get; set; } + public string McCurrency { get; set; } + public string Custom { get; set; } + public string ItemName { get; set; } + public string ItemNumber { get; set; } + public string PayerId { get; set; } + public string PayerEmail { get; set; } + public string ReceiverId { get; set; } + public string ReceiverEmail { get; set; } + public DateTime PaymentDate { get; set; } + + public Tuple GetIdsFromCustom() + { + Guid? orgId = null; + Guid? userId = null; + + if (!string.IsNullOrWhiteSpace(Custom) && Custom.Contains(":")) + { + var mainParts = Custom.Split(','); + foreach (var mainPart in mainParts) { - if (parts[0] == "user_id") + var parts = mainPart.Split(':'); + if (parts.Length > 1 && Guid.TryParse(parts[1], out var id)) { - userId = id; - } - else if (parts[0] == "organization_id") - { - orgId = id; + if (parts[0] == "user_id") + { + userId = id; + } + else if (parts[0] == "organization_id") + { + orgId = id; + } } } } + + return new Tuple(orgId, userId); } - return new Tuple(orgId, userId); - } - - public bool IsAccountCredit() - { - return !string.IsNullOrWhiteSpace(Custom) && Custom.Contains("account_credit:1"); - } - - private string GetDictValue(IDictionary dict, string key) - { - return dict.ContainsKey(key) ? dict[key] : null; - } - - private DateTime ConvertDate(string dateString) - { - if (!string.IsNullOrWhiteSpace(dateString)) + public bool IsAccountCredit() { - var parsed = DateTime.TryParseExact(dateString, _dateFormats, - CultureInfo.InvariantCulture, DateTimeStyles.None, out var paymentDate); - if (parsed) - { - var pacificTime = RuntimeInformation.IsOSPlatform(OSPlatform.Windows) ? - TimeZoneInfo.FindSystemTimeZoneById("Pacific Standard Time") : - TimeZoneInfo.FindSystemTimeZoneById("America/Los_Angeles"); - return TimeZoneInfo.ConvertTimeToUtc(paymentDate, pacificTime); - } + return !string.IsNullOrWhiteSpace(Custom) && Custom.Contains("account_credit:1"); + } + + private string GetDictValue(IDictionary dict, string key) + { + return dict.ContainsKey(key) ? dict[key] : null; + } + + private DateTime ConvertDate(string dateString) + { + if (!string.IsNullOrWhiteSpace(dateString)) + { + var parsed = DateTime.TryParseExact(dateString, _dateFormats, + CultureInfo.InvariantCulture, DateTimeStyles.None, out var paymentDate); + if (parsed) + { + var pacificTime = RuntimeInformation.IsOSPlatform(OSPlatform.Windows) ? + TimeZoneInfo.FindSystemTimeZoneById("Pacific Standard Time") : + TimeZoneInfo.FindSystemTimeZoneById("America/Los_Angeles"); + return TimeZoneInfo.ConvertTimeToUtc(paymentDate, pacificTime); + } + } + return default(DateTime); } - return default(DateTime); } } } diff --git a/src/Core/Constants.cs b/src/Core/Constants.cs index 8d1f009d40..68fd94295e 100644 --- a/src/Core/Constants.cs +++ b/src/Core/Constants.cs @@ -1,22 +1,23 @@ -namespace Bit.Core; - -public static class Constants +namespace Bit.Core { - public const int BypassFiltersEventId = 12482444; + public static class Constants + { + public const int BypassFiltersEventId = 12482444; - // File size limits - give 1 MB extra for cushion. - // Note: if request size limits are changed, 'client_max_body_size' - // in nginx/proxy.conf may also need to be updated accordingly. - public const long FileSize101mb = 101L * 1024L * 1024L; - public const long FileSize501mb = 501L * 1024L * 1024L; -} + // File size limits - give 1 MB extra for cushion. + // Note: if request size limits are changed, 'client_max_body_size' + // in nginx/proxy.conf may also need to be updated accordingly. + public const long FileSize101mb = 101L * 1024L * 1024L; + public const long FileSize501mb = 501L * 1024L * 1024L; + } -public static class TokenPurposes -{ - public const string LinkSso = "LinkSso"; -} + public static class TokenPurposes + { + public const string LinkSso = "LinkSso"; + } -public static class AuthenticationSchemes -{ - public const string BitwardenExternalCookieAuthenticationScheme = "bw.external"; + public static class AuthenticationSchemes + { + public const string BitwardenExternalCookieAuthenticationScheme = "bw.external"; + } } diff --git a/src/Core/Context/CurrentContentOrganization.cs b/src/Core/Context/CurrentContentOrganization.cs index 040c1ece49..7a54b27277 100644 --- a/src/Core/Context/CurrentContentOrganization.cs +++ b/src/Core/Context/CurrentContentOrganization.cs @@ -3,20 +3,21 @@ using Bit.Core.Enums; using Bit.Core.Models.Data; using Bit.Core.Utilities; -namespace Bit.Core.Context; - -public class CurrentContentOrganization +namespace Bit.Core.Context { - public CurrentContentOrganization() { } - - public CurrentContentOrganization(OrganizationUser orgUser) + public class CurrentContentOrganization { - Id = orgUser.OrganizationId; - Type = orgUser.Type; - Permissions = CoreHelpers.LoadClassFromJsonData(orgUser.Permissions); - } + public CurrentContentOrganization() { } - public Guid Id { get; set; } - public OrganizationUserType Type { get; set; } - public Permissions Permissions { get; set; } + public CurrentContentOrganization(OrganizationUser orgUser) + { + Id = orgUser.OrganizationId; + Type = orgUser.Type; + Permissions = CoreHelpers.LoadClassFromJsonData(orgUser.Permissions); + } + + public Guid Id { get; set; } + public OrganizationUserType Type { get; set; } + public Permissions Permissions { get; set; } + } } diff --git a/src/Core/Context/CurrentContentProvider.cs b/src/Core/Context/CurrentContentProvider.cs index f089be7b8a..f1925f5517 100644 --- a/src/Core/Context/CurrentContentProvider.cs +++ b/src/Core/Context/CurrentContentProvider.cs @@ -3,20 +3,21 @@ using Bit.Core.Enums.Provider; using Bit.Core.Models.Data; using Bit.Core.Utilities; -namespace Bit.Core.Context; - -public class CurrentContentProvider +namespace Bit.Core.Context { - public CurrentContentProvider() { } - - public CurrentContentProvider(ProviderUser providerUser) + public class CurrentContentProvider { - Id = providerUser.ProviderId; - Type = providerUser.Type; - Permissions = CoreHelpers.LoadClassFromJsonData(providerUser.Permissions); - } + public CurrentContentProvider() { } - public Guid Id { get; set; } - public ProviderUserType Type { get; set; } - public Permissions Permissions { get; set; } + public CurrentContentProvider(ProviderUser providerUser) + { + Id = providerUser.ProviderId; + Type = providerUser.Type; + Permissions = CoreHelpers.LoadClassFromJsonData(providerUser.Permissions); + } + + public Guid Id { get; set; } + public ProviderUserType Type { get; set; } + public Permissions Permissions { get; set; } + } } diff --git a/src/Core/Context/CurrentContext.cs b/src/Core/Context/CurrentContext.cs index d78340d700..47effcab13 100644 --- a/src/Core/Context/CurrentContext.cs +++ b/src/Core/Context/CurrentContext.cs @@ -8,485 +8,486 @@ using Bit.Core.Settings; using Bit.Core.Utilities; using Microsoft.AspNetCore.Http; -namespace Bit.Core.Context; - -public class CurrentContext : ICurrentContext +namespace Bit.Core.Context { - private readonly IProviderUserRepository _providerUserRepository; - private bool _builtHttpContext; - private bool _builtClaimsPrincipal; - private IEnumerable _providerUserOrganizations; - - public virtual HttpContext HttpContext { get; set; } - public virtual Guid? UserId { get; set; } - public virtual User User { get; set; } - public virtual string DeviceIdentifier { get; set; } - public virtual DeviceType? DeviceType { get; set; } - public virtual string IpAddress { get; set; } - public virtual List Organizations { get; set; } - public virtual List Providers { get; set; } - public virtual Guid? InstallationId { get; set; } - public virtual Guid? OrganizationId { get; set; } - public virtual bool CloudflareWorkerProxied { get; set; } - public virtual bool IsBot { get; set; } - public virtual bool MaybeBot { get; set; } - public virtual int? BotScore { get; set; } - public virtual string ClientId { get; set; } - - public CurrentContext(IProviderUserRepository providerUserRepository) + public class CurrentContext : ICurrentContext { - _providerUserRepository = providerUserRepository; - } + private readonly IProviderUserRepository _providerUserRepository; + private bool _builtHttpContext; + private bool _builtClaimsPrincipal; + private IEnumerable _providerUserOrganizations; - public async virtual Task BuildAsync(HttpContext httpContext, GlobalSettings globalSettings) - { - if (_builtHttpContext) + public virtual HttpContext HttpContext { get; set; } + public virtual Guid? UserId { get; set; } + public virtual User User { get; set; } + public virtual string DeviceIdentifier { get; set; } + public virtual DeviceType? DeviceType { get; set; } + public virtual string IpAddress { get; set; } + public virtual List Organizations { get; set; } + public virtual List Providers { get; set; } + public virtual Guid? InstallationId { get; set; } + public virtual Guid? OrganizationId { get; set; } + public virtual bool CloudflareWorkerProxied { get; set; } + public virtual bool IsBot { get; set; } + public virtual bool MaybeBot { get; set; } + public virtual int? BotScore { get; set; } + public virtual string ClientId { get; set; } + + public CurrentContext(IProviderUserRepository providerUserRepository) { - return; + _providerUserRepository = providerUserRepository; } - _builtHttpContext = true; - HttpContext = httpContext; - await BuildAsync(httpContext.User, globalSettings); - - if (DeviceIdentifier == null && httpContext.Request.Headers.ContainsKey("Device-Identifier")) + public async virtual Task BuildAsync(HttpContext httpContext, GlobalSettings globalSettings) { - DeviceIdentifier = httpContext.Request.Headers["Device-Identifier"]; + if (_builtHttpContext) + { + return; + } + + _builtHttpContext = true; + HttpContext = httpContext; + await BuildAsync(httpContext.User, globalSettings); + + if (DeviceIdentifier == null && httpContext.Request.Headers.ContainsKey("Device-Identifier")) + { + DeviceIdentifier = httpContext.Request.Headers["Device-Identifier"]; + } + + if (httpContext.Request.Headers.ContainsKey("Device-Type") && + Enum.TryParse(httpContext.Request.Headers["Device-Type"].ToString(), out DeviceType dType)) + { + DeviceType = dType; + } + + if (!BotScore.HasValue && httpContext.Request.Headers.ContainsKey("X-Cf-Bot-Score") && + int.TryParse(httpContext.Request.Headers["X-Cf-Bot-Score"], out var parsedBotScore)) + { + BotScore = parsedBotScore; + } + + if (httpContext.Request.Headers.ContainsKey("X-Cf-Worked-Proxied")) + { + CloudflareWorkerProxied = httpContext.Request.Headers["X-Cf-Worked-Proxied"] == "1"; + } + + if (httpContext.Request.Headers.ContainsKey("X-Cf-Is-Bot")) + { + IsBot = httpContext.Request.Headers["X-Cf-Is-Bot"] == "1"; + } + + if (httpContext.Request.Headers.ContainsKey("X-Cf-Maybe-Bot")) + { + MaybeBot = httpContext.Request.Headers["X-Cf-Maybe-Bot"] == "1"; + } } - if (httpContext.Request.Headers.ContainsKey("Device-Type") && - Enum.TryParse(httpContext.Request.Headers["Device-Type"].ToString(), out DeviceType dType)) + public async virtual Task BuildAsync(ClaimsPrincipal user, GlobalSettings globalSettings) { - DeviceType = dType; + if (_builtClaimsPrincipal) + { + return; + } + + _builtClaimsPrincipal = true; + IpAddress = HttpContext.GetIpAddress(globalSettings); + await SetContextAsync(user); } - if (!BotScore.HasValue && httpContext.Request.Headers.ContainsKey("X-Cf-Bot-Score") && - int.TryParse(httpContext.Request.Headers["X-Cf-Bot-Score"], out var parsedBotScore)) + public virtual Task SetContextAsync(ClaimsPrincipal user) { - BotScore = parsedBotScore; - } + if (user == null || !user.Claims.Any()) + { + return Task.FromResult(0); + } - if (httpContext.Request.Headers.ContainsKey("X-Cf-Worked-Proxied")) - { - CloudflareWorkerProxied = httpContext.Request.Headers["X-Cf-Worked-Proxied"] == "1"; - } + var claimsDict = user.Claims.GroupBy(c => c.Type).ToDictionary(c => c.Key, c => c.Select(v => v)); - if (httpContext.Request.Headers.ContainsKey("X-Cf-Is-Bot")) - { - IsBot = httpContext.Request.Headers["X-Cf-Is-Bot"] == "1"; - } + var subject = GetClaimValue(claimsDict, "sub"); + if (Guid.TryParse(subject, out var subIdGuid)) + { + UserId = subIdGuid; + } - if (httpContext.Request.Headers.ContainsKey("X-Cf-Maybe-Bot")) - { - MaybeBot = httpContext.Request.Headers["X-Cf-Maybe-Bot"] == "1"; - } - } + ClientId = GetClaimValue(claimsDict, "client_id"); + var clientSubject = GetClaimValue(claimsDict, "client_sub"); + var orgApi = false; + if (clientSubject != null) + { + if (ClientId?.StartsWith("installation.") ?? false) + { + if (Guid.TryParse(clientSubject, out var idGuid)) + { + InstallationId = idGuid; + } + } + else if (ClientId?.StartsWith("organization.") ?? false) + { + if (Guid.TryParse(clientSubject, out var idGuid)) + { + OrganizationId = idGuid; + orgApi = true; + } + } + } - public async virtual Task BuildAsync(ClaimsPrincipal user, GlobalSettings globalSettings) - { - if (_builtClaimsPrincipal) - { - return; - } + DeviceIdentifier = GetClaimValue(claimsDict, "device"); - _builtClaimsPrincipal = true; - IpAddress = HttpContext.GetIpAddress(globalSettings); - await SetContextAsync(user); - } + Organizations = GetOrganizations(claimsDict, orgApi); + + Providers = GetProviders(claimsDict); - public virtual Task SetContextAsync(ClaimsPrincipal user) - { - if (user == null || !user.Claims.Any()) - { return Task.FromResult(0); } - var claimsDict = user.Claims.GroupBy(c => c.Type).ToDictionary(c => c.Key, c => c.Select(v => v)); - - var subject = GetClaimValue(claimsDict, "sub"); - if (Guid.TryParse(subject, out var subIdGuid)) + private List GetOrganizations(Dictionary> claimsDict, bool orgApi) { - UserId = subIdGuid; - } - - ClientId = GetClaimValue(claimsDict, "client_id"); - var clientSubject = GetClaimValue(claimsDict, "client_sub"); - var orgApi = false; - if (clientSubject != null) - { - if (ClientId?.StartsWith("installation.") ?? false) + var organizations = new List(); + if (claimsDict.ContainsKey("orgowner")) { - if (Guid.TryParse(clientSubject, out var idGuid)) - { - InstallationId = idGuid; - } + organizations.AddRange(claimsDict["orgowner"].Select(c => + new CurrentContentOrganization + { + Id = new Guid(c.Value), + Type = OrganizationUserType.Owner + })); } - else if (ClientId?.StartsWith("organization.") ?? false) + else if (orgApi && OrganizationId.HasValue) { - if (Guid.TryParse(clientSubject, out var idGuid)) + organizations.Add(new CurrentContentOrganization { - OrganizationId = idGuid; - orgApi = true; - } - } - } - - DeviceIdentifier = GetClaimValue(claimsDict, "device"); - - Organizations = GetOrganizations(claimsDict, orgApi); - - Providers = GetProviders(claimsDict); - - return Task.FromResult(0); - } - - private List GetOrganizations(Dictionary> claimsDict, bool orgApi) - { - var organizations = new List(); - if (claimsDict.ContainsKey("orgowner")) - { - organizations.AddRange(claimsDict["orgowner"].Select(c => - new CurrentContentOrganization - { - Id = new Guid(c.Value), + Id = OrganizationId.Value, Type = OrganizationUserType.Owner - })); - } - else if (orgApi && OrganizationId.HasValue) - { - organizations.Add(new CurrentContentOrganization + }); + } + + if (claimsDict.ContainsKey("orgadmin")) { - Id = OrganizationId.Value, - Type = OrganizationUserType.Owner - }); + organizations.AddRange(claimsDict["orgadmin"].Select(c => + new CurrentContentOrganization + { + Id = new Guid(c.Value), + Type = OrganizationUserType.Admin + })); + } + + if (claimsDict.ContainsKey("orguser")) + { + organizations.AddRange(claimsDict["orguser"].Select(c => + new CurrentContentOrganization + { + Id = new Guid(c.Value), + Type = OrganizationUserType.User + })); + } + + if (claimsDict.ContainsKey("orgmanager")) + { + organizations.AddRange(claimsDict["orgmanager"].Select(c => + new CurrentContentOrganization + { + Id = new Guid(c.Value), + Type = OrganizationUserType.Manager + })); + } + + if (claimsDict.ContainsKey("orgcustom")) + { + organizations.AddRange(claimsDict["orgcustom"].Select(c => + new CurrentContentOrganization + { + Id = new Guid(c.Value), + Type = OrganizationUserType.Custom, + Permissions = SetOrganizationPermissionsFromClaims(c.Value, claimsDict) + })); + } + + return organizations; } - if (claimsDict.ContainsKey("orgadmin")) + private List GetProviders(Dictionary> claimsDict) { - organizations.AddRange(claimsDict["orgadmin"].Select(c => - new CurrentContentOrganization - { - Id = new Guid(c.Value), - Type = OrganizationUserType.Admin - })); + var providers = new List(); + if (claimsDict.ContainsKey("providerprovideradmin")) + { + providers.AddRange(claimsDict["providerprovideradmin"].Select(c => + new CurrentContentProvider + { + Id = new Guid(c.Value), + Type = ProviderUserType.ProviderAdmin + })); + } + + if (claimsDict.ContainsKey("providerserviceuser")) + { + providers.AddRange(claimsDict["providerserviceuser"].Select(c => + new CurrentContentProvider + { + Id = new Guid(c.Value), + Type = ProviderUserType.ServiceUser + })); + } + + return providers; } - if (claimsDict.ContainsKey("orguser")) + public async Task OrganizationUser(Guid orgId) { - organizations.AddRange(claimsDict["orguser"].Select(c => - new CurrentContentOrganization - { - Id = new Guid(c.Value), - Type = OrganizationUserType.User - })); + return (Organizations?.Any(o => o.Id == orgId) ?? false) || await OrganizationOwner(orgId); } - if (claimsDict.ContainsKey("orgmanager")) + public async Task OrganizationManager(Guid orgId) { - organizations.AddRange(claimsDict["orgmanager"].Select(c => - new CurrentContentOrganization - { - Id = new Guid(c.Value), - Type = OrganizationUserType.Manager - })); + return await OrganizationAdmin(orgId) || + (Organizations?.Any(o => o.Id == orgId && o.Type == OrganizationUserType.Manager) ?? false); } - if (claimsDict.ContainsKey("orgcustom")) + public async Task OrganizationAdmin(Guid orgId) { - organizations.AddRange(claimsDict["orgcustom"].Select(c => - new CurrentContentOrganization - { - Id = new Guid(c.Value), - Type = OrganizationUserType.Custom, - Permissions = SetOrganizationPermissionsFromClaims(c.Value, claimsDict) - })); + return await OrganizationOwner(orgId) || + (Organizations?.Any(o => o.Id == orgId && o.Type == OrganizationUserType.Admin) ?? false); } - return organizations; - } - - private List GetProviders(Dictionary> claimsDict) - { - var providers = new List(); - if (claimsDict.ContainsKey("providerprovideradmin")) + public async Task OrganizationOwner(Guid orgId) { - providers.AddRange(claimsDict["providerprovideradmin"].Select(c => - new CurrentContentProvider - { - Id = new Guid(c.Value), - Type = ProviderUserType.ProviderAdmin - })); + if (Organizations?.Any(o => o.Id == orgId && o.Type == OrganizationUserType.Owner) ?? false) + { + return true; + } + + if (Providers.Any()) + { + return await ProviderUserForOrgAsync(orgId); + } + + return false; } - if (claimsDict.ContainsKey("providerserviceuser")) + public Task OrganizationCustom(Guid orgId) { - providers.AddRange(claimsDict["providerserviceuser"].Select(c => - new CurrentContentProvider - { - Id = new Guid(c.Value), - Type = ProviderUserType.ServiceUser - })); + return Task.FromResult(Organizations?.Any(o => o.Id == orgId && o.Type == OrganizationUserType.Custom) ?? false); } - return providers; - } - - public async Task OrganizationUser(Guid orgId) - { - return (Organizations?.Any(o => o.Id == orgId) ?? false) || await OrganizationOwner(orgId); - } - - public async Task OrganizationManager(Guid orgId) - { - return await OrganizationAdmin(orgId) || - (Organizations?.Any(o => o.Id == orgId && o.Type == OrganizationUserType.Manager) ?? false); - } - - public async Task OrganizationAdmin(Guid orgId) - { - return await OrganizationOwner(orgId) || - (Organizations?.Any(o => o.Id == orgId && o.Type == OrganizationUserType.Admin) ?? false); - } - - public async Task OrganizationOwner(Guid orgId) - { - if (Organizations?.Any(o => o.Id == orgId && o.Type == OrganizationUserType.Owner) ?? false) + public async Task AccessEventLogs(Guid orgId) { - return true; + return await OrganizationAdmin(orgId) || (Organizations?.Any(o => o.Id == orgId + && (o.Permissions?.AccessEventLogs ?? false)) ?? false); } - if (Providers.Any()) + public async Task AccessImportExport(Guid orgId) { - return await ProviderUserForOrgAsync(orgId); + return await OrganizationAdmin(orgId) || (Organizations?.Any(o => o.Id == orgId + && (o.Permissions?.AccessImportExport ?? false)) ?? false); } - return false; - } - - public Task OrganizationCustom(Guid orgId) - { - return Task.FromResult(Organizations?.Any(o => o.Id == orgId && o.Type == OrganizationUserType.Custom) ?? false); - } - - public async Task AccessEventLogs(Guid orgId) - { - return await OrganizationAdmin(orgId) || (Organizations?.Any(o => o.Id == orgId - && (o.Permissions?.AccessEventLogs ?? false)) ?? false); - } - - public async Task AccessImportExport(Guid orgId) - { - return await OrganizationAdmin(orgId) || (Organizations?.Any(o => o.Id == orgId - && (o.Permissions?.AccessImportExport ?? false)) ?? false); - } - - public async Task AccessReports(Guid orgId) - { - return await OrganizationAdmin(orgId) || (Organizations?.Any(o => o.Id == orgId - && (o.Permissions?.AccessReports ?? false)) ?? false); - } - - public async Task CreateNewCollections(Guid orgId) - { - return await OrganizationManager(orgId) || (Organizations?.Any(o => o.Id == orgId - && (o.Permissions?.CreateNewCollections ?? false)) ?? false); - } - - public async Task EditAnyCollection(Guid orgId) - { - return await OrganizationAdmin(orgId) || (Organizations?.Any(o => o.Id == orgId - && (o.Permissions?.EditAnyCollection ?? false)) ?? false); - } - - public async Task DeleteAnyCollection(Guid orgId) - { - return await OrganizationAdmin(orgId) || (Organizations?.Any(o => o.Id == orgId - && (o.Permissions?.DeleteAnyCollection ?? false)) ?? false); - } - - public async Task ViewAllCollections(Guid orgId) - { - return await CreateNewCollections(orgId) || await EditAnyCollection(orgId) || await DeleteAnyCollection(orgId); - } - - public async Task EditAssignedCollections(Guid orgId) - { - return await OrganizationManager(orgId) || (Organizations?.Any(o => o.Id == orgId - && (o.Permissions?.EditAssignedCollections ?? false)) ?? false); - } - - public async Task DeleteAssignedCollections(Guid orgId) - { - return await OrganizationManager(orgId) || (Organizations?.Any(o => o.Id == orgId - && (o.Permissions?.DeleteAssignedCollections ?? false)) ?? false); - } - - public async Task ViewAssignedCollections(Guid orgId) - { - return await EditAssignedCollections(orgId) || await DeleteAssignedCollections(orgId); - } - - public async Task ManageGroups(Guid orgId) - { - return await OrganizationAdmin(orgId) || (Organizations?.Any(o => o.Id == orgId - && (o.Permissions?.ManageGroups ?? false)) ?? false); - } - - public async Task ManagePolicies(Guid orgId) - { - return await OrganizationAdmin(orgId) || (Organizations?.Any(o => o.Id == orgId - && (o.Permissions?.ManagePolicies ?? false)) ?? false); - } - - public async Task ManageSso(Guid orgId) - { - return await OrganizationAdmin(orgId) || (Organizations?.Any(o => o.Id == orgId - && (o.Permissions?.ManageSso ?? false)) ?? false); - } - - public async Task ManageScim(Guid orgId) - { - return await OrganizationAdmin(orgId) || (Organizations?.Any(o => o.Id == orgId - && (o.Permissions?.ManageScim ?? false)) ?? false); - } - - public async Task ManageUsers(Guid orgId) - { - return await OrganizationAdmin(orgId) || (Organizations?.Any(o => o.Id == orgId - && (o.Permissions?.ManageUsers ?? false)) ?? false); - } - - public async Task ManageResetPassword(Guid orgId) - { - return await OrganizationAdmin(orgId) || (Organizations?.Any(o => o.Id == orgId - && (o.Permissions?.ManageResetPassword ?? false)) ?? false); - } - - public async Task ManageBilling(Guid orgId) - { - var orgManagedByProvider = await ProviderIdForOrg(orgId) != null; - - return orgManagedByProvider - ? await ProviderUserForOrgAsync(orgId) - : await OrganizationOwner(orgId); - } - - public bool ProviderProviderAdmin(Guid providerId) - { - return Providers?.Any(o => o.Id == providerId && o.Type == ProviderUserType.ProviderAdmin) ?? false; - } - - public bool ProviderManageUsers(Guid providerId) - { - return ProviderProviderAdmin(providerId); - } - - public bool ProviderAccessEventLogs(Guid providerId) - { - return ProviderProviderAdmin(providerId); - } - - public bool AccessProviderOrganizations(Guid providerId) - { - return ProviderUser(providerId); - } - - public bool ManageProviderOrganizations(Guid providerId) - { - return ProviderProviderAdmin(providerId); - } - - public bool ProviderUser(Guid providerId) - { - return Providers?.Any(o => o.Id == providerId) ?? false; - } - - public async Task ProviderUserForOrgAsync(Guid orgId) - { - return (await GetProviderOrganizations()).Any(po => po.OrganizationId == orgId); - } - - public async Task ProviderIdForOrg(Guid orgId) - { - if (Organizations?.Any(org => org.Id == orgId) ?? false) + public async Task AccessReports(Guid orgId) { - return null; + return await OrganizationAdmin(orgId) || (Organizations?.Any(o => o.Id == orgId + && (o.Permissions?.AccessReports ?? false)) ?? false); } - var po = (await GetProviderOrganizations()) - ?.FirstOrDefault(po => po.OrganizationId == orgId); - - return po?.ProviderId; - } - - public async Task> OrganizationMembershipAsync( - IOrganizationUserRepository organizationUserRepository, Guid userId) - { - if (Organizations == null) + public async Task CreateNewCollections(Guid orgId) { - var userOrgs = await organizationUserRepository.GetManyByUserAsync(userId); - Organizations = userOrgs.Where(ou => ou.Status == OrganizationUserStatusType.Confirmed) - .Select(ou => new CurrentContentOrganization(ou)).ToList(); - } - return Organizations; - } - - public async Task> ProviderMembershipAsync( - IProviderUserRepository providerUserRepository, Guid userId) - { - if (Providers == null) - { - var userProviders = await providerUserRepository.GetManyByUserAsync(userId); - Providers = userProviders.Where(ou => ou.Status == ProviderUserStatusType.Confirmed) - .Select(ou => new CurrentContentProvider(ou)).ToList(); - } - return Providers; - } - - private string GetClaimValue(Dictionary> claims, string type) - { - if (!claims.ContainsKey(type)) - { - return null; + return await OrganizationManager(orgId) || (Organizations?.Any(o => o.Id == orgId + && (o.Permissions?.CreateNewCollections ?? false)) ?? false); } - return claims[type].FirstOrDefault()?.Value; - } - - private Permissions SetOrganizationPermissionsFromClaims(string organizationId, Dictionary> claimsDict) - { - bool hasClaim(string claimKey) + public async Task EditAnyCollection(Guid orgId) { - return claimsDict.ContainsKey(claimKey) ? - claimsDict[claimKey].Any(x => x.Value == organizationId) : false; + return await OrganizationAdmin(orgId) || (Organizations?.Any(o => o.Id == orgId + && (o.Permissions?.EditAnyCollection ?? false)) ?? false); } - return new Permissions + public async Task DeleteAnyCollection(Guid orgId) { - AccessEventLogs = hasClaim("accesseventlogs"), - AccessImportExport = hasClaim("accessimportexport"), - AccessReports = hasClaim("accessreports"), - CreateNewCollections = hasClaim("createnewcollections"), - EditAnyCollection = hasClaim("editanycollection"), - DeleteAnyCollection = hasClaim("deleteanycollection"), - EditAssignedCollections = hasClaim("editassignedcollections"), - DeleteAssignedCollections = hasClaim("deleteassignedcollections"), - ManageGroups = hasClaim("managegroups"), - ManagePolicies = hasClaim("managepolicies"), - ManageSso = hasClaim("managesso"), - ManageUsers = hasClaim("manageusers"), - ManageResetPassword = hasClaim("manageresetpassword"), - ManageScim = hasClaim("managescim"), - }; - } - - protected async Task> GetProviderOrganizations() - { - if (_providerUserOrganizations == null && UserId.HasValue) - { - _providerUserOrganizations = await _providerUserRepository.GetManyOrganizationDetailsByUserAsync(UserId.Value, ProviderUserStatusType.Confirmed); + return await OrganizationAdmin(orgId) || (Organizations?.Any(o => o.Id == orgId + && (o.Permissions?.DeleteAnyCollection ?? false)) ?? false); } - return _providerUserOrganizations; + public async Task ViewAllCollections(Guid orgId) + { + return await CreateNewCollections(orgId) || await EditAnyCollection(orgId) || await DeleteAnyCollection(orgId); + } + + public async Task EditAssignedCollections(Guid orgId) + { + return await OrganizationManager(orgId) || (Organizations?.Any(o => o.Id == orgId + && (o.Permissions?.EditAssignedCollections ?? false)) ?? false); + } + + public async Task DeleteAssignedCollections(Guid orgId) + { + return await OrganizationManager(orgId) || (Organizations?.Any(o => o.Id == orgId + && (o.Permissions?.DeleteAssignedCollections ?? false)) ?? false); + } + + public async Task ViewAssignedCollections(Guid orgId) + { + return await EditAssignedCollections(orgId) || await DeleteAssignedCollections(orgId); + } + + public async Task ManageGroups(Guid orgId) + { + return await OrganizationAdmin(orgId) || (Organizations?.Any(o => o.Id == orgId + && (o.Permissions?.ManageGroups ?? false)) ?? false); + } + + public async Task ManagePolicies(Guid orgId) + { + return await OrganizationAdmin(orgId) || (Organizations?.Any(o => o.Id == orgId + && (o.Permissions?.ManagePolicies ?? false)) ?? false); + } + + public async Task ManageSso(Guid orgId) + { + return await OrganizationAdmin(orgId) || (Organizations?.Any(o => o.Id == orgId + && (o.Permissions?.ManageSso ?? false)) ?? false); + } + + public async Task ManageScim(Guid orgId) + { + return await OrganizationAdmin(orgId) || (Organizations?.Any(o => o.Id == orgId + && (o.Permissions?.ManageScim ?? false)) ?? false); + } + + public async Task ManageUsers(Guid orgId) + { + return await OrganizationAdmin(orgId) || (Organizations?.Any(o => o.Id == orgId + && (o.Permissions?.ManageUsers ?? false)) ?? false); + } + + public async Task ManageResetPassword(Guid orgId) + { + return await OrganizationAdmin(orgId) || (Organizations?.Any(o => o.Id == orgId + && (o.Permissions?.ManageResetPassword ?? false)) ?? false); + } + + public async Task ManageBilling(Guid orgId) + { + var orgManagedByProvider = await ProviderIdForOrg(orgId) != null; + + return orgManagedByProvider + ? await ProviderUserForOrgAsync(orgId) + : await OrganizationOwner(orgId); + } + + public bool ProviderProviderAdmin(Guid providerId) + { + return Providers?.Any(o => o.Id == providerId && o.Type == ProviderUserType.ProviderAdmin) ?? false; + } + + public bool ProviderManageUsers(Guid providerId) + { + return ProviderProviderAdmin(providerId); + } + + public bool ProviderAccessEventLogs(Guid providerId) + { + return ProviderProviderAdmin(providerId); + } + + public bool AccessProviderOrganizations(Guid providerId) + { + return ProviderUser(providerId); + } + + public bool ManageProviderOrganizations(Guid providerId) + { + return ProviderProviderAdmin(providerId); + } + + public bool ProviderUser(Guid providerId) + { + return Providers?.Any(o => o.Id == providerId) ?? false; + } + + public async Task ProviderUserForOrgAsync(Guid orgId) + { + return (await GetProviderOrganizations()).Any(po => po.OrganizationId == orgId); + } + + public async Task ProviderIdForOrg(Guid orgId) + { + if (Organizations?.Any(org => org.Id == orgId) ?? false) + { + return null; + } + + var po = (await GetProviderOrganizations()) + ?.FirstOrDefault(po => po.OrganizationId == orgId); + + return po?.ProviderId; + } + + public async Task> OrganizationMembershipAsync( + IOrganizationUserRepository organizationUserRepository, Guid userId) + { + if (Organizations == null) + { + var userOrgs = await organizationUserRepository.GetManyByUserAsync(userId); + Organizations = userOrgs.Where(ou => ou.Status == OrganizationUserStatusType.Confirmed) + .Select(ou => new CurrentContentOrganization(ou)).ToList(); + } + return Organizations; + } + + public async Task> ProviderMembershipAsync( + IProviderUserRepository providerUserRepository, Guid userId) + { + if (Providers == null) + { + var userProviders = await providerUserRepository.GetManyByUserAsync(userId); + Providers = userProviders.Where(ou => ou.Status == ProviderUserStatusType.Confirmed) + .Select(ou => new CurrentContentProvider(ou)).ToList(); + } + return Providers; + } + + private string GetClaimValue(Dictionary> claims, string type) + { + if (!claims.ContainsKey(type)) + { + return null; + } + + return claims[type].FirstOrDefault()?.Value; + } + + private Permissions SetOrganizationPermissionsFromClaims(string organizationId, Dictionary> claimsDict) + { + bool hasClaim(string claimKey) + { + return claimsDict.ContainsKey(claimKey) ? + claimsDict[claimKey].Any(x => x.Value == organizationId) : false; + } + + return new Permissions + { + AccessEventLogs = hasClaim("accesseventlogs"), + AccessImportExport = hasClaim("accessimportexport"), + AccessReports = hasClaim("accessreports"), + CreateNewCollections = hasClaim("createnewcollections"), + EditAnyCollection = hasClaim("editanycollection"), + DeleteAnyCollection = hasClaim("deleteanycollection"), + EditAssignedCollections = hasClaim("editassignedcollections"), + DeleteAssignedCollections = hasClaim("deleteassignedcollections"), + ManageGroups = hasClaim("managegroups"), + ManagePolicies = hasClaim("managepolicies"), + ManageSso = hasClaim("managesso"), + ManageUsers = hasClaim("manageusers"), + ManageResetPassword = hasClaim("manageresetpassword"), + ManageScim = hasClaim("managescim"), + }; + } + + protected async Task> GetProviderOrganizations() + { + if (_providerUserOrganizations == null && UserId.HasValue) + { + _providerUserOrganizations = await _providerUserRepository.GetManyOrganizationDetailsByUserAsync(UserId.Value, ProviderUserStatusType.Confirmed); + } + + return _providerUserOrganizations; + } } } diff --git a/src/Core/Context/ICurrentContext.cs b/src/Core/Context/ICurrentContext.cs index b53e43dfac..d82ad12e46 100644 --- a/src/Core/Context/ICurrentContext.cs +++ b/src/Core/Context/ICurrentContext.cs @@ -5,64 +5,65 @@ using Bit.Core.Repositories; using Bit.Core.Settings; using Microsoft.AspNetCore.Http; -namespace Bit.Core.Context; - -public interface ICurrentContext +namespace Bit.Core.Context { - HttpContext HttpContext { get; set; } - Guid? UserId { get; set; } - User User { get; set; } - string DeviceIdentifier { get; set; } - DeviceType? DeviceType { get; set; } - string IpAddress { get; set; } - List Organizations { get; set; } - Guid? InstallationId { get; set; } - Guid? OrganizationId { get; set; } - bool IsBot { get; set; } - bool MaybeBot { get; set; } - int? BotScore { get; set; } - string ClientId { get; set; } - Task BuildAsync(HttpContext httpContext, GlobalSettings globalSettings); - Task BuildAsync(ClaimsPrincipal user, GlobalSettings globalSettings); + public interface ICurrentContext + { + HttpContext HttpContext { get; set; } + Guid? UserId { get; set; } + User User { get; set; } + string DeviceIdentifier { get; set; } + DeviceType? DeviceType { get; set; } + string IpAddress { get; set; } + List Organizations { get; set; } + Guid? InstallationId { get; set; } + Guid? OrganizationId { get; set; } + bool IsBot { get; set; } + bool MaybeBot { get; set; } + int? BotScore { get; set; } + string ClientId { get; set; } + Task BuildAsync(HttpContext httpContext, GlobalSettings globalSettings); + Task BuildAsync(ClaimsPrincipal user, GlobalSettings globalSettings); - Task SetContextAsync(ClaimsPrincipal user); + Task SetContextAsync(ClaimsPrincipal user); - Task OrganizationUser(Guid orgId); - Task OrganizationManager(Guid orgId); - Task OrganizationAdmin(Guid orgId); - Task OrganizationOwner(Guid orgId); - Task OrganizationCustom(Guid orgId); - Task AccessEventLogs(Guid orgId); - Task AccessImportExport(Guid orgId); - Task AccessReports(Guid orgId); - Task CreateNewCollections(Guid orgId); - Task EditAnyCollection(Guid orgId); - Task DeleteAnyCollection(Guid orgId); - Task ViewAllCollections(Guid orgId); - Task EditAssignedCollections(Guid orgId); - Task DeleteAssignedCollections(Guid orgId); - Task ViewAssignedCollections(Guid orgId); - Task ManageGroups(Guid orgId); - Task ManagePolicies(Guid orgId); - Task ManageSso(Guid orgId); - Task ManageUsers(Guid orgId); - Task ManageScim(Guid orgId); - Task ManageResetPassword(Guid orgId); - Task ManageBilling(Guid orgId); - Task ProviderUserForOrgAsync(Guid orgId); - bool ProviderProviderAdmin(Guid providerId); - bool ProviderUser(Guid providerId); - bool ProviderManageUsers(Guid providerId); - bool ProviderAccessEventLogs(Guid providerId); - bool AccessProviderOrganizations(Guid providerId); - bool ManageProviderOrganizations(Guid providerId); + Task OrganizationUser(Guid orgId); + Task OrganizationManager(Guid orgId); + Task OrganizationAdmin(Guid orgId); + Task OrganizationOwner(Guid orgId); + Task OrganizationCustom(Guid orgId); + Task AccessEventLogs(Guid orgId); + Task AccessImportExport(Guid orgId); + Task AccessReports(Guid orgId); + Task CreateNewCollections(Guid orgId); + Task EditAnyCollection(Guid orgId); + Task DeleteAnyCollection(Guid orgId); + Task ViewAllCollections(Guid orgId); + Task EditAssignedCollections(Guid orgId); + Task DeleteAssignedCollections(Guid orgId); + Task ViewAssignedCollections(Guid orgId); + Task ManageGroups(Guid orgId); + Task ManagePolicies(Guid orgId); + Task ManageSso(Guid orgId); + Task ManageUsers(Guid orgId); + Task ManageScim(Guid orgId); + Task ManageResetPassword(Guid orgId); + Task ManageBilling(Guid orgId); + Task ProviderUserForOrgAsync(Guid orgId); + bool ProviderProviderAdmin(Guid providerId); + bool ProviderUser(Guid providerId); + bool ProviderManageUsers(Guid providerId); + bool ProviderAccessEventLogs(Guid providerId); + bool AccessProviderOrganizations(Guid providerId); + bool ManageProviderOrganizations(Guid providerId); - Task> OrganizationMembershipAsync( - IOrganizationUserRepository organizationUserRepository, Guid userId); + Task> OrganizationMembershipAsync( + IOrganizationUserRepository organizationUserRepository, Guid userId); - Task> ProviderMembershipAsync( - IProviderUserRepository providerUserRepository, Guid userId); + Task> ProviderMembershipAsync( + IProviderUserRepository providerUserRepository, Guid userId); - Task ProviderIdForOrg(Guid orgId); + Task ProviderIdForOrg(Guid orgId); + } } diff --git a/src/Core/Entities/Cipher.cs b/src/Core/Entities/Cipher.cs index 186a7c5b84..c4e57aa76f 100644 --- a/src/Core/Entities/Cipher.cs +++ b/src/Core/Entities/Cipher.cs @@ -2,107 +2,108 @@ using Bit.Core.Models.Data; using Bit.Core.Utilities; -namespace Bit.Core.Entities; - -public class Cipher : ITableObject, ICloneable +namespace Bit.Core.Entities { - private Dictionary _attachmentData; - - public Guid Id { get; set; } - public Guid? UserId { get; set; } - public Guid? OrganizationId { get; set; } - public Enums.CipherType Type { get; set; } - public string Data { get; set; } - public string Favorites { get; set; } - public string Folders { get; set; } - public string Attachments { get; set; } - public DateTime CreationDate { get; set; } = DateTime.UtcNow; - public DateTime RevisionDate { get; set; } = DateTime.UtcNow; - public DateTime? DeletedDate { get; set; } - public Enums.CipherRepromptType? Reprompt { get; set; } - - public void SetNewId() + public class Cipher : ITableObject, ICloneable { - Id = CoreHelpers.GenerateComb(); - } + private Dictionary _attachmentData; - public Dictionary GetAttachments() - { - if (string.IsNullOrWhiteSpace(Attachments)) + public Guid Id { get; set; } + public Guid? UserId { get; set; } + public Guid? OrganizationId { get; set; } + public Enums.CipherType Type { get; set; } + public string Data { get; set; } + public string Favorites { get; set; } + public string Folders { get; set; } + public string Attachments { get; set; } + public DateTime CreationDate { get; set; } = DateTime.UtcNow; + public DateTime RevisionDate { get; set; } = DateTime.UtcNow; + public DateTime? DeletedDate { get; set; } + public Enums.CipherRepromptType? Reprompt { get; set; } + + public void SetNewId() { - return null; + Id = CoreHelpers.GenerateComb(); } - if (_attachmentData != null) + public Dictionary GetAttachments() { - return _attachmentData; - } - - try - { - _attachmentData = JsonSerializer.Deserialize>(Attachments); - foreach (var kvp in _attachmentData) + if (string.IsNullOrWhiteSpace(Attachments)) { - kvp.Value.AttachmentId = kvp.Key; + return null; + } + + if (_attachmentData != null) + { + return _attachmentData; + } + + try + { + _attachmentData = JsonSerializer.Deserialize>(Attachments); + foreach (var kvp in _attachmentData) + { + kvp.Value.AttachmentId = kvp.Key; + } + return _attachmentData; + } + catch + { + return null; } - return _attachmentData; } - catch + + public void SetAttachments(Dictionary data) { - return null; - } - } + if (data == null || data.Count == 0) + { + _attachmentData = null; + Attachments = null; + return; + } - public void SetAttachments(Dictionary data) - { - if (data == null || data.Count == 0) + _attachmentData = data; + Attachments = JsonSerializer.Serialize(_attachmentData); + } + + public void AddAttachment(string id, CipherAttachment.MetaData data) { - _attachmentData = null; - Attachments = null; - return; + var attachments = GetAttachments(); + if (attachments == null) + { + attachments = new Dictionary(); + } + + attachments.Add(id, data); + SetAttachments(attachments); } - _attachmentData = data; - Attachments = JsonSerializer.Serialize(_attachmentData); - } - - public void AddAttachment(string id, CipherAttachment.MetaData data) - { - var attachments = GetAttachments(); - if (attachments == null) + public void DeleteAttachment(string id) { - attachments = new Dictionary(); + var attachments = GetAttachments(); + if (!attachments?.ContainsKey(id) ?? true) + { + return; + } + + attachments.Remove(id); + SetAttachments(attachments); } - attachments.Add(id, data); - SetAttachments(attachments); - } - - public void DeleteAttachment(string id) - { - var attachments = GetAttachments(); - if (!attachments?.ContainsKey(id) ?? true) + public bool ContainsAttachment(string id) { - return; + var attachments = GetAttachments(); + return attachments?.ContainsKey(id) ?? false; } - attachments.Remove(id); - SetAttachments(attachments); - } + object ICloneable.Clone() => Clone(); + public Cipher Clone() + { + var clone = CoreHelpers.CloneObject(this); + clone.CreationDate = CreationDate; + clone.RevisionDate = RevisionDate; - public bool ContainsAttachment(string id) - { - var attachments = GetAttachments(); - return attachments?.ContainsKey(id) ?? false; - } - - object ICloneable.Clone() => Clone(); - public Cipher Clone() - { - var clone = CoreHelpers.CloneObject(this); - clone.CreationDate = CreationDate; - clone.RevisionDate = RevisionDate; - - return clone; + return clone; + } } } diff --git a/src/Core/Entities/Collection.cs b/src/Core/Entities/Collection.cs index fb7225fc20..fb6e646fc2 100644 --- a/src/Core/Entities/Collection.cs +++ b/src/Core/Entities/Collection.cs @@ -1,20 +1,21 @@ using System.ComponentModel.DataAnnotations; using Bit.Core.Utilities; -namespace Bit.Core.Entities; - -public class Collection : ITableObject +namespace Bit.Core.Entities { - public Guid Id { get; set; } - public Guid OrganizationId { get; set; } - public string Name { get; set; } - [MaxLength(300)] - public string ExternalId { get; set; } - public DateTime CreationDate { get; set; } = DateTime.UtcNow; - public DateTime RevisionDate { get; set; } = DateTime.UtcNow; - - public void SetNewId() + public class Collection : ITableObject { - Id = CoreHelpers.GenerateComb(); + public Guid Id { get; set; } + public Guid OrganizationId { get; set; } + public string Name { get; set; } + [MaxLength(300)] + public string ExternalId { get; set; } + public DateTime CreationDate { get; set; } = DateTime.UtcNow; + public DateTime RevisionDate { get; set; } = DateTime.UtcNow; + + public void SetNewId() + { + Id = CoreHelpers.GenerateComb(); + } } } diff --git a/src/Core/Entities/CollectionCipher.cs b/src/Core/Entities/CollectionCipher.cs index d212ced514..f04c2bdf48 100644 --- a/src/Core/Entities/CollectionCipher.cs +++ b/src/Core/Entities/CollectionCipher.cs @@ -1,7 +1,8 @@ -namespace Bit.Core.Entities; - -public class CollectionCipher +namespace Bit.Core.Entities { - public Guid CollectionId { get; set; } - public Guid CipherId { get; set; } + public class CollectionCipher + { + public Guid CollectionId { get; set; } + public Guid CipherId { get; set; } + } } diff --git a/src/Core/Entities/CollectionGroup.cs b/src/Core/Entities/CollectionGroup.cs index 8224aed466..c68ae30055 100644 --- a/src/Core/Entities/CollectionGroup.cs +++ b/src/Core/Entities/CollectionGroup.cs @@ -1,9 +1,10 @@ -namespace Bit.Core.Entities; - -public class CollectionGroup +namespace Bit.Core.Entities { - public Guid CollectionId { get; set; } - public Guid GroupId { get; set; } - public bool ReadOnly { get; set; } - public bool HidePasswords { get; set; } + public class CollectionGroup + { + public Guid CollectionId { get; set; } + public Guid GroupId { get; set; } + public bool ReadOnly { get; set; } + public bool HidePasswords { get; set; } + } } diff --git a/src/Core/Entities/CollectionUser.cs b/src/Core/Entities/CollectionUser.cs index bb22e7b7c9..5b5d01fcc7 100644 --- a/src/Core/Entities/CollectionUser.cs +++ b/src/Core/Entities/CollectionUser.cs @@ -1,9 +1,10 @@ -namespace Bit.Core.Entities; - -public class CollectionUser +namespace Bit.Core.Entities { - public Guid CollectionId { get; set; } - public Guid OrganizationUserId { get; set; } - public bool ReadOnly { get; set; } - public bool HidePasswords { get; set; } + public class CollectionUser + { + public Guid CollectionId { get; set; } + public Guid OrganizationUserId { get; set; } + public bool ReadOnly { get; set; } + public bool HidePasswords { get; set; } + } } diff --git a/src/Core/Entities/Device.cs b/src/Core/Entities/Device.cs index 3b5fb1a247..9cca56c3fa 100644 --- a/src/Core/Entities/Device.cs +++ b/src/Core/Entities/Device.cs @@ -1,24 +1,25 @@ using System.ComponentModel.DataAnnotations; using Bit.Core.Utilities; -namespace Bit.Core.Entities; - -public class Device : ITableObject +namespace Bit.Core.Entities { - public Guid Id { get; set; } - public Guid UserId { get; set; } - [MaxLength(50)] - public string Name { get; set; } - public Enums.DeviceType Type { get; set; } - [MaxLength(50)] - public string Identifier { get; set; } - [MaxLength(255)] - public string PushToken { get; set; } - public DateTime CreationDate { get; internal set; } = DateTime.UtcNow; - public DateTime RevisionDate { get; internal set; } = DateTime.UtcNow; - - public void SetNewId() + public class Device : ITableObject { - Id = CoreHelpers.GenerateComb(); + public Guid Id { get; set; } + public Guid UserId { get; set; } + [MaxLength(50)] + public string Name { get; set; } + public Enums.DeviceType Type { get; set; } + [MaxLength(50)] + public string Identifier { get; set; } + [MaxLength(255)] + public string PushToken { get; set; } + public DateTime CreationDate { get; internal set; } = DateTime.UtcNow; + public DateTime RevisionDate { get; internal set; } = DateTime.UtcNow; + + public void SetNewId() + { + Id = CoreHelpers.GenerateComb(); + } } } diff --git a/src/Core/Entities/EmergencyAccess.cs b/src/Core/Entities/EmergencyAccess.cs index e78f90e662..eafd9ee8e3 100644 --- a/src/Core/Entities/EmergencyAccess.cs +++ b/src/Core/Entities/EmergencyAccess.cs @@ -2,45 +2,46 @@ using Bit.Core.Enums; using Bit.Core.Utilities; -namespace Bit.Core.Entities; - -public class EmergencyAccess : ITableObject +namespace Bit.Core.Entities { - public Guid Id { get; set; } - public Guid GrantorId { get; set; } - public Guid? GranteeId { get; set; } - [MaxLength(256)] - public string Email { get; set; } - public string KeyEncrypted { get; set; } - public EmergencyAccessType Type { get; set; } - public EmergencyAccessStatusType Status { get; set; } - public int WaitTimeDays { get; set; } - public DateTime? RecoveryInitiatedDate { get; set; } - public DateTime? LastNotificationDate { get; set; } - public DateTime CreationDate { get; set; } = DateTime.UtcNow; - public DateTime RevisionDate { get; set; } = DateTime.UtcNow; - - public void SetNewId() + public class EmergencyAccess : ITableObject { - Id = CoreHelpers.GenerateComb(); - } + public Guid Id { get; set; } + public Guid GrantorId { get; set; } + public Guid? GranteeId { get; set; } + [MaxLength(256)] + public string Email { get; set; } + public string KeyEncrypted { get; set; } + public EmergencyAccessType Type { get; set; } + public EmergencyAccessStatusType Status { get; set; } + public int WaitTimeDays { get; set; } + public DateTime? RecoveryInitiatedDate { get; set; } + public DateTime? LastNotificationDate { get; set; } + public DateTime CreationDate { get; set; } = DateTime.UtcNow; + public DateTime RevisionDate { get; set; } = DateTime.UtcNow; - public EmergencyAccess ToEmergencyAccess() - { - return new EmergencyAccess + public void SetNewId() { - Id = Id, - GrantorId = GrantorId, - GranteeId = GranteeId, - Email = Email, - KeyEncrypted = KeyEncrypted, - Type = Type, - Status = Status, - WaitTimeDays = WaitTimeDays, - RecoveryInitiatedDate = RecoveryInitiatedDate, - LastNotificationDate = LastNotificationDate, - CreationDate = CreationDate, - RevisionDate = RevisionDate, - }; + Id = CoreHelpers.GenerateComb(); + } + + public EmergencyAccess ToEmergencyAccess() + { + return new EmergencyAccess + { + Id = Id, + GrantorId = GrantorId, + GranteeId = GranteeId, + Email = Email, + KeyEncrypted = KeyEncrypted, + Type = Type, + Status = Status, + WaitTimeDays = WaitTimeDays, + RecoveryInitiatedDate = RecoveryInitiatedDate, + LastNotificationDate = LastNotificationDate, + CreationDate = CreationDate, + RevisionDate = RevisionDate, + }; + } } } diff --git a/src/Core/Entities/Event.cs b/src/Core/Entities/Event.cs index 99e2091c9a..d17116ce6f 100644 --- a/src/Core/Entities/Event.cs +++ b/src/Core/Entities/Event.cs @@ -3,53 +3,54 @@ using Bit.Core.Enums; using Bit.Core.Models.Data; using Bit.Core.Utilities; -namespace Bit.Core.Entities; - -public class Event : ITableObject, IEvent +namespace Bit.Core.Entities { - public Event() { } - - public Event(IEvent e) + public class Event : ITableObject, IEvent { - Date = e.Date; - Type = e.Type; - UserId = e.UserId; - OrganizationId = e.OrganizationId; - ProviderId = e.ProviderId; - CipherId = e.CipherId; - CollectionId = e.CollectionId; - PolicyId = e.PolicyId; - GroupId = e.GroupId; - OrganizationUserId = e.OrganizationUserId; - InstallationId = e.InstallationId; - ProviderUserId = e.ProviderUserId; - ProviderOrganizationId = e.ProviderOrganizationId; - DeviceType = e.DeviceType; - IpAddress = e.IpAddress; - ActingUserId = e.ActingUserId; - } + public Event() { } - public Guid Id { get; set; } - public DateTime Date { get; set; } - public EventType Type { get; set; } - public Guid? UserId { get; set; } - public Guid? OrganizationId { get; set; } - public Guid? InstallationId { get; set; } - public Guid? ProviderId { get; set; } - public Guid? CipherId { get; set; } - public Guid? CollectionId { get; set; } - public Guid? PolicyId { get; set; } - public Guid? GroupId { get; set; } - public Guid? OrganizationUserId { get; set; } - public Guid? ProviderUserId { get; set; } - public Guid? ProviderOrganizationId { get; set; } - public DeviceType? DeviceType { get; set; } - [MaxLength(50)] - public string IpAddress { get; set; } - public Guid? ActingUserId { get; set; } + public Event(IEvent e) + { + Date = e.Date; + Type = e.Type; + UserId = e.UserId; + OrganizationId = e.OrganizationId; + ProviderId = e.ProviderId; + CipherId = e.CipherId; + CollectionId = e.CollectionId; + PolicyId = e.PolicyId; + GroupId = e.GroupId; + OrganizationUserId = e.OrganizationUserId; + InstallationId = e.InstallationId; + ProviderUserId = e.ProviderUserId; + ProviderOrganizationId = e.ProviderOrganizationId; + DeviceType = e.DeviceType; + IpAddress = e.IpAddress; + ActingUserId = e.ActingUserId; + } - public void SetNewId() - { - Id = CoreHelpers.GenerateComb(); + public Guid Id { get; set; } + public DateTime Date { get; set; } + public EventType Type { get; set; } + public Guid? UserId { get; set; } + public Guid? OrganizationId { get; set; } + public Guid? InstallationId { get; set; } + public Guid? ProviderId { get; set; } + public Guid? CipherId { get; set; } + public Guid? CollectionId { get; set; } + public Guid? PolicyId { get; set; } + public Guid? GroupId { get; set; } + public Guid? OrganizationUserId { get; set; } + public Guid? ProviderUserId { get; set; } + public Guid? ProviderOrganizationId { get; set; } + public DeviceType? DeviceType { get; set; } + [MaxLength(50)] + public string IpAddress { get; set; } + public Guid? ActingUserId { get; set; } + + public void SetNewId() + { + Id = CoreHelpers.GenerateComb(); + } } } diff --git a/src/Core/Entities/Folder.cs b/src/Core/Entities/Folder.cs index fd6d4dafa2..5fc97a3e54 100644 --- a/src/Core/Entities/Folder.cs +++ b/src/Core/Entities/Folder.cs @@ -1,17 +1,18 @@ using Bit.Core.Utilities; -namespace Bit.Core.Entities; - -public class Folder : ITableObject +namespace Bit.Core.Entities { - public Guid Id { get; set; } - public Guid UserId { get; set; } - public string Name { get; set; } - public DateTime CreationDate { get; internal set; } = DateTime.UtcNow; - public DateTime RevisionDate { get; internal set; } = DateTime.UtcNow; - - public void SetNewId() + public class Folder : ITableObject { - Id = CoreHelpers.GenerateComb(); + public Guid Id { get; set; } + public Guid UserId { get; set; } + public string Name { get; set; } + public DateTime CreationDate { get; internal set; } = DateTime.UtcNow; + public DateTime RevisionDate { get; internal set; } = DateTime.UtcNow; + + public void SetNewId() + { + Id = CoreHelpers.GenerateComb(); + } } } diff --git a/src/Core/Entities/Grant.cs b/src/Core/Entities/Grant.cs index f66ff11345..f2bd464fba 100644 --- a/src/Core/Entities/Grant.cs +++ b/src/Core/Entities/Grant.cs @@ -1,23 +1,24 @@ using System.ComponentModel.DataAnnotations; -namespace Bit.Core.Entities; - -public class Grant +namespace Bit.Core.Entities { - [MaxLength(200)] - public string Key { get; set; } - [MaxLength(50)] - public string Type { get; set; } - [MaxLength(200)] - public string SubjectId { get; set; } - [MaxLength(100)] - public string SessionId { get; set; } - [MaxLength(200)] - public string ClientId { get; set; } - [MaxLength(200)] - public string Description { get; set; } - public DateTime CreationDate { get; set; } - public DateTime? ExpirationDate { get; set; } - public DateTime? ConsumedDate { get; set; } - public string Data { get; set; } + public class Grant + { + [MaxLength(200)] + public string Key { get; set; } + [MaxLength(50)] + public string Type { get; set; } + [MaxLength(200)] + public string SubjectId { get; set; } + [MaxLength(100)] + public string SessionId { get; set; } + [MaxLength(200)] + public string ClientId { get; set; } + [MaxLength(200)] + public string Description { get; set; } + public DateTime CreationDate { get; set; } + public DateTime? ExpirationDate { get; set; } + public DateTime? ConsumedDate { get; set; } + public string Data { get; set; } + } } diff --git a/src/Core/Entities/Group.cs b/src/Core/Entities/Group.cs index 3c15380fa4..0ca760cff6 100644 --- a/src/Core/Entities/Group.cs +++ b/src/Core/Entities/Group.cs @@ -2,22 +2,23 @@ using Bit.Core.Models; using Bit.Core.Utilities; -namespace Bit.Core.Entities; - -public class Group : ITableObject, IExternal +namespace Bit.Core.Entities { - public Guid Id { get; set; } - public Guid OrganizationId { get; set; } - [MaxLength(100)] - public string Name { get; set; } - public bool AccessAll { get; set; } - [MaxLength(300)] - public string ExternalId { get; set; } - public DateTime CreationDate { get; internal set; } = DateTime.UtcNow; - public DateTime RevisionDate { get; internal set; } = DateTime.UtcNow; - - public void SetNewId() + public class Group : ITableObject, IExternal { - Id = CoreHelpers.GenerateComb(); + public Guid Id { get; set; } + public Guid OrganizationId { get; set; } + [MaxLength(100)] + public string Name { get; set; } + public bool AccessAll { get; set; } + [MaxLength(300)] + public string ExternalId { get; set; } + public DateTime CreationDate { get; internal set; } = DateTime.UtcNow; + public DateTime RevisionDate { get; internal set; } = DateTime.UtcNow; + + public void SetNewId() + { + Id = CoreHelpers.GenerateComb(); + } } } diff --git a/src/Core/Entities/GroupUser.cs b/src/Core/Entities/GroupUser.cs index 3497c2c744..c7933d5e7e 100644 --- a/src/Core/Entities/GroupUser.cs +++ b/src/Core/Entities/GroupUser.cs @@ -1,7 +1,8 @@ -namespace Bit.Core.Entities; - -public class GroupUser +namespace Bit.Core.Entities { - public Guid GroupId { get; set; } - public Guid OrganizationUserId { get; set; } + public class GroupUser + { + public Guid GroupId { get; set; } + public Guid OrganizationUserId { get; set; } + } } diff --git a/src/Core/Entities/IReferenceable.cs b/src/Core/Entities/IReferenceable.cs index 79837781e0..a5373978da 100644 --- a/src/Core/Entities/IReferenceable.cs +++ b/src/Core/Entities/IReferenceable.cs @@ -1,8 +1,9 @@ -namespace Bit.Core.Entities; - -public interface IReferenceable +namespace Bit.Core.Entities { - Guid Id { get; set; } - string ReferenceData { get; set; } - bool IsUser(); + public interface IReferenceable + { + Guid Id { get; set; } + string ReferenceData { get; set; } + bool IsUser(); + } } diff --git a/src/Core/Entities/IRevisable.cs b/src/Core/Entities/IRevisable.cs index bba3b3c94f..6de7478c00 100644 --- a/src/Core/Entities/IRevisable.cs +++ b/src/Core/Entities/IRevisable.cs @@ -1,7 +1,8 @@ -namespace Bit.Core.Entities; - -public interface IRevisable +namespace Bit.Core.Entities { - DateTime CreationDate { get; } - DateTime RevisionDate { get; } + public interface IRevisable + { + DateTime CreationDate { get; } + DateTime RevisionDate { get; } + } } diff --git a/src/Core/Entities/IStorable.cs b/src/Core/Entities/IStorable.cs index fd0da49fea..67c16098fb 100644 --- a/src/Core/Entities/IStorable.cs +++ b/src/Core/Entities/IStorable.cs @@ -1,9 +1,10 @@ -namespace Bit.Core.Entities; - -public interface IStorable +namespace Bit.Core.Entities { - long? Storage { get; set; } - short? MaxStorageGb { get; set; } - long StorageBytesRemaining(); - long StorageBytesRemaining(short maxStorageGb); + public interface IStorable + { + long? Storage { get; set; } + short? MaxStorageGb { get; set; } + long StorageBytesRemaining(); + long StorageBytesRemaining(short maxStorageGb); + } } diff --git a/src/Core/Entities/IStorableSubscriber.cs b/src/Core/Entities/IStorableSubscriber.cs index 27fcb25f6c..e37966dea9 100644 --- a/src/Core/Entities/IStorableSubscriber.cs +++ b/src/Core/Entities/IStorableSubscriber.cs @@ -1,4 +1,5 @@ -namespace Bit.Core.Entities; - -public interface IStorableSubscriber : IStorable, ISubscriber -{ } +namespace Bit.Core.Entities +{ + public interface IStorableSubscriber : IStorable, ISubscriber + { } +} diff --git a/src/Core/Entities/ISubscriber.cs b/src/Core/Entities/ISubscriber.cs index 6753e648e3..1c80ffc207 100644 --- a/src/Core/Entities/ISubscriber.cs +++ b/src/Core/Entities/ISubscriber.cs @@ -1,17 +1,18 @@ using Bit.Core.Enums; -namespace Bit.Core.Entities; - -public interface ISubscriber +namespace Bit.Core.Entities { - Guid Id { get; } - GatewayType? Gateway { get; set; } - string GatewayCustomerId { get; set; } - string GatewaySubscriptionId { get; set; } - string BillingEmailAddress(); - string BillingName(); - string BraintreeCustomerIdPrefix(); - string BraintreeIdField(); - string GatewayIdField(); - bool IsUser(); + public interface ISubscriber + { + Guid Id { get; } + GatewayType? Gateway { get; set; } + string GatewayCustomerId { get; set; } + string GatewaySubscriptionId { get; set; } + string BillingEmailAddress(); + string BillingName(); + string BraintreeCustomerIdPrefix(); + string BraintreeIdField(); + string GatewayIdField(); + bool IsUser(); + } } diff --git a/src/Core/Entities/ITableObject.cs b/src/Core/Entities/ITableObject.cs index 1f54b8cc17..f9ecb864b8 100644 --- a/src/Core/Entities/ITableObject.cs +++ b/src/Core/Entities/ITableObject.cs @@ -1,7 +1,8 @@ -namespace Bit.Core.Entities; - -public interface ITableObject where T : IEquatable +namespace Bit.Core.Entities { - T Id { get; set; } - void SetNewId(); + public interface ITableObject where T : IEquatable + { + T Id { get; set; } + void SetNewId(); + } } diff --git a/src/Core/Entities/Installation.cs b/src/Core/Entities/Installation.cs index a91ecef2e7..36966d8619 100644 --- a/src/Core/Entities/Installation.cs +++ b/src/Core/Entities/Installation.cs @@ -1,20 +1,21 @@ using System.ComponentModel.DataAnnotations; using Bit.Core.Utilities; -namespace Bit.Core.Entities; - -public class Installation : ITableObject +namespace Bit.Core.Entities { - public Guid Id { get; set; } - [MaxLength(256)] - public string Email { get; set; } - [MaxLength(150)] - public string Key { get; set; } - public bool Enabled { get; set; } - public DateTime CreationDate { get; internal set; } = DateTime.UtcNow; - - public void SetNewId() + public class Installation : ITableObject { - Id = CoreHelpers.GenerateComb(); + public Guid Id { get; set; } + [MaxLength(256)] + public string Email { get; set; } + [MaxLength(150)] + public string Key { get; set; } + public bool Enabled { get; set; } + public DateTime CreationDate { get; internal set; } = DateTime.UtcNow; + + public void SetNewId() + { + Id = CoreHelpers.GenerateComb(); + } } } diff --git a/src/Core/Entities/Organization.cs b/src/Core/Entities/Organization.cs index 823eb5bafc..818db32307 100644 --- a/src/Core/Entities/Organization.cs +++ b/src/Core/Entities/Organization.cs @@ -4,195 +4,196 @@ using Bit.Core.Enums; using Bit.Core.Models; using Bit.Core.Utilities; -namespace Bit.Core.Entities; - -public class Organization : ITableObject, ISubscriber, IStorable, IStorableSubscriber, IRevisable, IReferenceable +namespace Bit.Core.Entities { - private Dictionary _twoFactorProviders; - - public Guid Id { get; set; } - [MaxLength(50)] - public string Identifier { get; set; } - [MaxLength(50)] - public string Name { get; set; } - [MaxLength(50)] - public string BusinessName { get; set; } - [MaxLength(50)] - public string BusinessAddress1 { get; set; } - [MaxLength(50)] - public string BusinessAddress2 { get; set; } - [MaxLength(50)] - public string BusinessAddress3 { get; set; } - [MaxLength(2)] - public string BusinessCountry { get; set; } - [MaxLength(30)] - public string BusinessTaxNumber { get; set; } - [MaxLength(256)] - public string BillingEmail { get; set; } - [MaxLength(50)] - public string Plan { get; set; } - public PlanType PlanType { get; set; } - public int? Seats { get; set; } - public short? MaxCollections { get; set; } - public bool UsePolicies { get; set; } - public bool UseSso { get; set; } - public bool UseKeyConnector { get; set; } - public bool UseScim { get; set; } - public bool UseGroups { get; set; } - public bool UseDirectory { get; set; } - public bool UseEvents { get; set; } - public bool UseTotp { get; set; } - public bool Use2fa { get; set; } - public bool UseApi { get; set; } - public bool UseResetPassword { get; set; } - public bool SelfHost { get; set; } - public bool UsersGetPremium { get; set; } - public long? Storage { get; set; } - public short? MaxStorageGb { get; set; } - public GatewayType? Gateway { get; set; } - [MaxLength(50)] - public string GatewayCustomerId { get; set; } - [MaxLength(50)] - public string GatewaySubscriptionId { get; set; } - public string ReferenceData { get; set; } - public bool Enabled { get; set; } = true; - [MaxLength(100)] - public string LicenseKey { get; set; } - public string PublicKey { get; set; } - public string PrivateKey { get; set; } - public string TwoFactorProviders { get; set; } - public DateTime? ExpirationDate { get; set; } - public DateTime CreationDate { get; set; } = DateTime.UtcNow; - public DateTime RevisionDate { get; set; } = DateTime.UtcNow; - public int? MaxAutoscaleSeats { get; set; } = null; - public DateTime? OwnersNotifiedOfAutoscaling { get; set; } = null; - - public void SetNewId() + public class Organization : ITableObject, ISubscriber, IStorable, IStorableSubscriber, IRevisable, IReferenceable { - if (Id == default(Guid)) + private Dictionary _twoFactorProviders; + + public Guid Id { get; set; } + [MaxLength(50)] + public string Identifier { get; set; } + [MaxLength(50)] + public string Name { get; set; } + [MaxLength(50)] + public string BusinessName { get; set; } + [MaxLength(50)] + public string BusinessAddress1 { get; set; } + [MaxLength(50)] + public string BusinessAddress2 { get; set; } + [MaxLength(50)] + public string BusinessAddress3 { get; set; } + [MaxLength(2)] + public string BusinessCountry { get; set; } + [MaxLength(30)] + public string BusinessTaxNumber { get; set; } + [MaxLength(256)] + public string BillingEmail { get; set; } + [MaxLength(50)] + public string Plan { get; set; } + public PlanType PlanType { get; set; } + public int? Seats { get; set; } + public short? MaxCollections { get; set; } + public bool UsePolicies { get; set; } + public bool UseSso { get; set; } + public bool UseKeyConnector { get; set; } + public bool UseScim { get; set; } + public bool UseGroups { get; set; } + public bool UseDirectory { get; set; } + public bool UseEvents { get; set; } + public bool UseTotp { get; set; } + public bool Use2fa { get; set; } + public bool UseApi { get; set; } + public bool UseResetPassword { get; set; } + public bool SelfHost { get; set; } + public bool UsersGetPremium { get; set; } + public long? Storage { get; set; } + public short? MaxStorageGb { get; set; } + public GatewayType? Gateway { get; set; } + [MaxLength(50)] + public string GatewayCustomerId { get; set; } + [MaxLength(50)] + public string GatewaySubscriptionId { get; set; } + public string ReferenceData { get; set; } + public bool Enabled { get; set; } = true; + [MaxLength(100)] + public string LicenseKey { get; set; } + public string PublicKey { get; set; } + public string PrivateKey { get; set; } + public string TwoFactorProviders { get; set; } + public DateTime? ExpirationDate { get; set; } + public DateTime CreationDate { get; set; } = DateTime.UtcNow; + public DateTime RevisionDate { get; set; } = DateTime.UtcNow; + public int? MaxAutoscaleSeats { get; set; } = null; + public DateTime? OwnersNotifiedOfAutoscaling { get; set; } = null; + + public void SetNewId() { - Id = CoreHelpers.GenerateComb(); - } - } - - public string BillingEmailAddress() - { - return BillingEmail?.ToLowerInvariant()?.Trim(); - } - - public string BillingName() - { - return BusinessName; - } - - public string BraintreeCustomerIdPrefix() - { - return "o"; - } - - public string BraintreeIdField() - { - return "organization_id"; - } - - public string GatewayIdField() - { - return "organizationId"; - } - - public bool IsUser() - { - return false; - } - - public long StorageBytesRemaining() - { - if (!MaxStorageGb.HasValue) - { - return 0; - } - - return StorageBytesRemaining(MaxStorageGb.Value); - } - - public long StorageBytesRemaining(short maxStorageGb) - { - var maxStorageBytes = maxStorageGb * 1073741824L; - if (!Storage.HasValue) - { - return maxStorageBytes; - } - - return maxStorageBytes - Storage.Value; - } - - public Dictionary GetTwoFactorProviders() - { - if (string.IsNullOrWhiteSpace(TwoFactorProviders)) - { - return null; - } - - try - { - if (_twoFactorProviders == null) + if (Id == default(Guid)) { - _twoFactorProviders = - JsonHelpers.LegacyDeserialize>( - TwoFactorProviders); + Id = CoreHelpers.GenerateComb(); + } + } + + public string BillingEmailAddress() + { + return BillingEmail?.ToLowerInvariant()?.Trim(); + } + + public string BillingName() + { + return BusinessName; + } + + public string BraintreeCustomerIdPrefix() + { + return "o"; + } + + public string BraintreeIdField() + { + return "organization_id"; + } + + public string GatewayIdField() + { + return "organizationId"; + } + + public bool IsUser() + { + return false; + } + + public long StorageBytesRemaining() + { + if (!MaxStorageGb.HasValue) + { + return 0; } - return _twoFactorProviders; + return StorageBytesRemaining(MaxStorageGb.Value); } - catch (JsonException) + + public long StorageBytesRemaining(short maxStorageGb) { - return null; - } - } + var maxStorageBytes = maxStorageGb * 1073741824L; + if (!Storage.HasValue) + { + return maxStorageBytes; + } - public void SetTwoFactorProviders(Dictionary providers) - { - if (!providers.Any()) + return maxStorageBytes - Storage.Value; + } + + public Dictionary GetTwoFactorProviders() { - TwoFactorProviders = null; - _twoFactorProviders = null; - return; + if (string.IsNullOrWhiteSpace(TwoFactorProviders)) + { + return null; + } + + try + { + if (_twoFactorProviders == null) + { + _twoFactorProviders = + JsonHelpers.LegacyDeserialize>( + TwoFactorProviders); + } + + return _twoFactorProviders; + } + catch (JsonException) + { + return null; + } } - TwoFactorProviders = JsonHelpers.LegacySerialize(providers, JsonHelpers.LegacyEnumKeyResolver); - _twoFactorProviders = providers; - } - - public bool TwoFactorProviderIsEnabled(TwoFactorProviderType provider) - { - var providers = GetTwoFactorProviders(); - if (providers == null || !providers.ContainsKey(provider)) + public void SetTwoFactorProviders(Dictionary providers) { - return false; + if (!providers.Any()) + { + TwoFactorProviders = null; + _twoFactorProviders = null; + return; + } + + TwoFactorProviders = JsonHelpers.LegacySerialize(providers, JsonHelpers.LegacyEnumKeyResolver); + _twoFactorProviders = providers; } - return providers[provider].Enabled && Use2fa; - } - - public bool TwoFactorIsEnabled() - { - var providers = GetTwoFactorProviders(); - if (providers == null) + public bool TwoFactorProviderIsEnabled(TwoFactorProviderType provider) { - return false; + var providers = GetTwoFactorProviders(); + if (providers == null || !providers.ContainsKey(provider)) + { + return false; + } + + return providers[provider].Enabled && Use2fa; } - return providers.Any(p => (p.Value?.Enabled ?? false) && Use2fa); - } - - public TwoFactorProvider GetTwoFactorProvider(TwoFactorProviderType provider) - { - var providers = GetTwoFactorProviders(); - if (providers == null || !providers.ContainsKey(provider)) + public bool TwoFactorIsEnabled() { - return null; + var providers = GetTwoFactorProviders(); + if (providers == null) + { + return false; + } + + return providers.Any(p => (p.Value?.Enabled ?? false) && Use2fa); } - return providers[provider]; + public TwoFactorProvider GetTwoFactorProvider(TwoFactorProviderType provider) + { + var providers = GetTwoFactorProviders(); + if (providers == null || !providers.ContainsKey(provider)) + { + return null; + } + + return providers[provider]; + } } } diff --git a/src/Core/Entities/OrganizationApiKey.cs b/src/Core/Entities/OrganizationApiKey.cs index af9f3c9122..f3a71bde2c 100644 --- a/src/Core/Entities/OrganizationApiKey.cs +++ b/src/Core/Entities/OrganizationApiKey.cs @@ -2,19 +2,20 @@ using Bit.Core.Enums; using Bit.Core.Utilities; -namespace Bit.Core.Entities; - -public class OrganizationApiKey : ITableObject +namespace Bit.Core.Entities { - public Guid Id { get; set; } - public Guid OrganizationId { get; set; } - public OrganizationApiKeyType Type { get; set; } - [MaxLength(30)] - public string ApiKey { get; set; } - public DateTime RevisionDate { get; set; } - - public void SetNewId() + public class OrganizationApiKey : ITableObject { - Id = CoreHelpers.GenerateComb(); + public Guid Id { get; set; } + public Guid OrganizationId { get; set; } + public OrganizationApiKeyType Type { get; set; } + [MaxLength(30)] + public string ApiKey { get; set; } + public DateTime RevisionDate { get; set; } + + public void SetNewId() + { + Id = CoreHelpers.GenerateComb(); + } } } diff --git a/src/Core/Entities/OrganizationConnection.cs b/src/Core/Entities/OrganizationConnection.cs index cc07177384..804913fd69 100644 --- a/src/Core/Entities/OrganizationConnection.cs +++ b/src/Core/Entities/OrganizationConnection.cs @@ -2,44 +2,45 @@ using Bit.Core.Enums; using Bit.Core.Utilities; -namespace Bit.Core.Entities; - -public class OrganizationConnection : OrganizationConnection where T : new() +namespace Bit.Core.Entities { - public new T Config + public class OrganizationConnection : OrganizationConnection where T : new() { - get => base.GetConfig(); - set => base.SetConfig(value); - } -} - -public class OrganizationConnection : ITableObject -{ - public Guid Id { get; set; } - public OrganizationConnectionType Type { get; set; } - public Guid OrganizationId { get; set; } - public bool Enabled { get; set; } - public string Config { get; set; } - - public void SetNewId() - { - Id = CoreHelpers.GenerateComb(); - } - - public T GetConfig() where T : new() - { - try + public new T Config { - return JsonSerializer.Deserialize(Config); - } - catch (JsonException) - { - return default; + get => base.GetConfig(); + set => base.SetConfig(value); } } - public void SetConfig(T config) where T : new() + public class OrganizationConnection : ITableObject { - Config = JsonSerializer.Serialize(config); + public Guid Id { get; set; } + public OrganizationConnectionType Type { get; set; } + public Guid OrganizationId { get; set; } + public bool Enabled { get; set; } + public string Config { get; set; } + + public void SetNewId() + { + Id = CoreHelpers.GenerateComb(); + } + + public T GetConfig() where T : new() + { + try + { + return JsonSerializer.Deserialize(Config); + } + catch (JsonException) + { + return default; + } + } + + public void SetConfig(T config) where T : new() + { + Config = JsonSerializer.Serialize(config); + } } } diff --git a/src/Core/Entities/OrganizationSponsorship.cs b/src/Core/Entities/OrganizationSponsorship.cs index 8d747bd623..27d07e8f71 100644 --- a/src/Core/Entities/OrganizationSponsorship.cs +++ b/src/Core/Entities/OrganizationSponsorship.cs @@ -2,25 +2,26 @@ using Bit.Core.Enums; using Bit.Core.Utilities; -namespace Bit.Core.Entities; - -public class OrganizationSponsorship : ITableObject +namespace Bit.Core.Entities { - public Guid Id { get; set; } - public Guid? SponsoringOrganizationId { get; set; } - public Guid SponsoringOrganizationUserId { get; set; } - public Guid? SponsoredOrganizationId { get; set; } - [MaxLength(256)] - public string FriendlyName { get; set; } - [MaxLength(256)] - public string OfferedToEmail { get; set; } - public PlanSponsorshipType? PlanSponsorshipType { get; set; } - public DateTime? LastSyncDate { get; set; } - public DateTime? ValidUntil { get; set; } - public bool ToDelete { get; set; } - - public void SetNewId() + public class OrganizationSponsorship : ITableObject { - Id = CoreHelpers.GenerateComb(); + public Guid Id { get; set; } + public Guid? SponsoringOrganizationId { get; set; } + public Guid SponsoringOrganizationUserId { get; set; } + public Guid? SponsoredOrganizationId { get; set; } + [MaxLength(256)] + public string FriendlyName { get; set; } + [MaxLength(256)] + public string OfferedToEmail { get; set; } + public PlanSponsorshipType? PlanSponsorshipType { get; set; } + public DateTime? LastSyncDate { get; set; } + public DateTime? ValidUntil { get; set; } + public bool ToDelete { get; set; } + + public void SetNewId() + { + Id = CoreHelpers.GenerateComb(); + } } } diff --git a/src/Core/Entities/OrganizationUser.cs b/src/Core/Entities/OrganizationUser.cs index ee1bdc15d4..390374dec6 100644 --- a/src/Core/Entities/OrganizationUser.cs +++ b/src/Core/Entities/OrganizationUser.cs @@ -3,28 +3,29 @@ using Bit.Core.Enums; using Bit.Core.Models; using Bit.Core.Utilities; -namespace Bit.Core.Entities; - -public class OrganizationUser : ITableObject, IExternal +namespace Bit.Core.Entities { - public Guid Id { get; set; } - public Guid OrganizationId { get; set; } - public Guid? UserId { get; set; } - [MaxLength(256)] - public string Email { get; set; } - public string Key { get; set; } - public string ResetPasswordKey { get; set; } - public OrganizationUserStatusType Status { get; set; } - public OrganizationUserType Type { get; set; } - public bool AccessAll { get; set; } - [MaxLength(300)] - public string ExternalId { get; set; } - public DateTime CreationDate { get; internal set; } = DateTime.UtcNow; - public DateTime RevisionDate { get; internal set; } = DateTime.UtcNow; - public string Permissions { get; set; } - - public void SetNewId() + public class OrganizationUser : ITableObject, IExternal { - Id = CoreHelpers.GenerateComb(); + public Guid Id { get; set; } + public Guid OrganizationId { get; set; } + public Guid? UserId { get; set; } + [MaxLength(256)] + public string Email { get; set; } + public string Key { get; set; } + public string ResetPasswordKey { get; set; } + public OrganizationUserStatusType Status { get; set; } + public OrganizationUserType Type { get; set; } + public bool AccessAll { get; set; } + [MaxLength(300)] + public string ExternalId { get; set; } + public DateTime CreationDate { get; internal set; } = DateTime.UtcNow; + public DateTime RevisionDate { get; internal set; } = DateTime.UtcNow; + public string Permissions { get; set; } + + public void SetNewId() + { + Id = CoreHelpers.GenerateComb(); + } } } diff --git a/src/Core/Entities/Policy.cs b/src/Core/Entities/Policy.cs index 4863b8ccc8..7a5f958719 100644 --- a/src/Core/Entities/Policy.cs +++ b/src/Core/Entities/Policy.cs @@ -2,30 +2,31 @@ using Bit.Core.Models.Data.Organizations.Policies; using Bit.Core.Utilities; -namespace Bit.Core.Entities; - -public class Policy : ITableObject +namespace Bit.Core.Entities { - public Guid Id { get; set; } - public Guid OrganizationId { get; set; } - public PolicyType Type { get; set; } - public string Data { get; set; } - public bool Enabled { get; set; } - public DateTime CreationDate { get; internal set; } = DateTime.UtcNow; - public DateTime RevisionDate { get; internal set; } = DateTime.UtcNow; - - public void SetNewId() + public class Policy : ITableObject { - Id = CoreHelpers.GenerateComb(); - } + public Guid Id { get; set; } + public Guid OrganizationId { get; set; } + public PolicyType Type { get; set; } + public string Data { get; set; } + public bool Enabled { get; set; } + public DateTime CreationDate { get; internal set; } = DateTime.UtcNow; + public DateTime RevisionDate { get; internal set; } = DateTime.UtcNow; - public T GetDataModel() where T : IPolicyDataModel, new() - { - return CoreHelpers.LoadClassFromJsonData(Data); - } + public void SetNewId() + { + Id = CoreHelpers.GenerateComb(); + } - public void SetDataModel(T dataModel) where T : IPolicyDataModel, new() - { - Data = CoreHelpers.ClassToJsonData(dataModel); + public T GetDataModel() where T : IPolicyDataModel, new() + { + return CoreHelpers.LoadClassFromJsonData(Data); + } + + public void SetDataModel(T dataModel) where T : IPolicyDataModel, new() + { + Data = CoreHelpers.ClassToJsonData(dataModel); + } } } diff --git a/src/Core/Entities/Provider/Provider.cs b/src/Core/Entities/Provider/Provider.cs index 440be7d434..95da01f93b 100644 --- a/src/Core/Entities/Provider/Provider.cs +++ b/src/Core/Entities/Provider/Provider.cs @@ -1,30 +1,31 @@ using Bit.Core.Enums.Provider; using Bit.Core.Utilities; -namespace Bit.Core.Entities.Provider; - -public class Provider : ITableObject +namespace Bit.Core.Entities.Provider { - public Guid Id { get; set; } - public string Name { get; set; } - public string BusinessName { get; set; } - public string BusinessAddress1 { get; set; } - public string BusinessAddress2 { get; set; } - public string BusinessAddress3 { get; set; } - public string BusinessCountry { get; set; } - public string BusinessTaxNumber { get; set; } - public string BillingEmail { get; set; } - public ProviderStatusType Status { get; set; } - public bool UseEvents { get; set; } - public bool Enabled { get; set; } = true; - public DateTime CreationDate { get; internal set; } = DateTime.UtcNow; - public DateTime RevisionDate { get; internal set; } = DateTime.UtcNow; - - public void SetNewId() + public class Provider : ITableObject { - if (Id == default) + public Guid Id { get; set; } + public string Name { get; set; } + public string BusinessName { get; set; } + public string BusinessAddress1 { get; set; } + public string BusinessAddress2 { get; set; } + public string BusinessAddress3 { get; set; } + public string BusinessCountry { get; set; } + public string BusinessTaxNumber { get; set; } + public string BillingEmail { get; set; } + public ProviderStatusType Status { get; set; } + public bool UseEvents { get; set; } + public bool Enabled { get; set; } = true; + public DateTime CreationDate { get; internal set; } = DateTime.UtcNow; + public DateTime RevisionDate { get; internal set; } = DateTime.UtcNow; + + public void SetNewId() { - Id = CoreHelpers.GenerateComb(); + if (Id == default) + { + Id = CoreHelpers.GenerateComb(); + } } } } diff --git a/src/Core/Entities/Provider/ProviderOrganization.cs b/src/Core/Entities/Provider/ProviderOrganization.cs index 6cafef67b7..6bb1eec54c 100644 --- a/src/Core/Entities/Provider/ProviderOrganization.cs +++ b/src/Core/Entities/Provider/ProviderOrganization.cs @@ -1,22 +1,23 @@ using Bit.Core.Utilities; -namespace Bit.Core.Entities.Provider; - -public class ProviderOrganization : ITableObject +namespace Bit.Core.Entities.Provider { - public Guid Id { get; set; } - public Guid ProviderId { get; set; } - public Guid OrganizationId { get; set; } - public string Key { get; set; } - public string Settings { get; set; } - public DateTime CreationDate { get; internal set; } = DateTime.UtcNow; - public DateTime RevisionDate { get; internal set; } = DateTime.UtcNow; - - public void SetNewId() + public class ProviderOrganization : ITableObject { - if (Id == default) + public Guid Id { get; set; } + public Guid ProviderId { get; set; } + public Guid OrganizationId { get; set; } + public string Key { get; set; } + public string Settings { get; set; } + public DateTime CreationDate { get; internal set; } = DateTime.UtcNow; + public DateTime RevisionDate { get; internal set; } = DateTime.UtcNow; + + public void SetNewId() { - Id = CoreHelpers.GenerateComb(); + if (Id == default) + { + Id = CoreHelpers.GenerateComb(); + } } } } diff --git a/src/Core/Entities/Provider/ProviderUser.cs b/src/Core/Entities/Provider/ProviderUser.cs index 9b86d591c9..c3d0582da3 100644 --- a/src/Core/Entities/Provider/ProviderUser.cs +++ b/src/Core/Entities/Provider/ProviderUser.cs @@ -1,26 +1,27 @@ using Bit.Core.Enums.Provider; using Bit.Core.Utilities; -namespace Bit.Core.Entities.Provider; - -public class ProviderUser : ITableObject +namespace Bit.Core.Entities.Provider { - public Guid Id { get; set; } - public Guid ProviderId { get; set; } - public Guid? UserId { get; set; } - public string Email { get; set; } - public string Key { get; set; } - public ProviderUserStatusType Status { get; set; } - public ProviderUserType Type { get; set; } - public string Permissions { get; set; } - public DateTime CreationDate { get; set; } = DateTime.UtcNow; - public DateTime RevisionDate { get; set; } = DateTime.UtcNow; - - public void SetNewId() + public class ProviderUser : ITableObject { - if (Id == default) + public Guid Id { get; set; } + public Guid ProviderId { get; set; } + public Guid? UserId { get; set; } + public string Email { get; set; } + public string Key { get; set; } + public ProviderUserStatusType Status { get; set; } + public ProviderUserType Type { get; set; } + public string Permissions { get; set; } + public DateTime CreationDate { get; set; } = DateTime.UtcNow; + public DateTime RevisionDate { get; set; } = DateTime.UtcNow; + + public void SetNewId() { - Id = CoreHelpers.GenerateComb(); + if (Id == default) + { + Id = CoreHelpers.GenerateComb(); + } } } } diff --git a/src/Core/Entities/Role.cs b/src/Core/Entities/Role.cs index 5e1f6319c2..2acdb1c65f 100644 --- a/src/Core/Entities/Role.cs +++ b/src/Core/Entities/Role.cs @@ -1,9 +1,10 @@ -namespace Bit.Core.Entities; - -/// -/// This class is not used. It is implemented to make the Identity provider happy. -/// -public class Role +namespace Bit.Core.Entities { - public string Name { get; set; } + /// + /// This class is not used. It is implemented to make the Identity provider happy. + /// + public class Role + { + public string Name { get; set; } + } } diff --git a/src/Core/Entities/Send.cs b/src/Core/Entities/Send.cs index 7cc8f3b257..cbe2006e8e 100644 --- a/src/Core/Entities/Send.cs +++ b/src/Core/Entities/Send.cs @@ -2,29 +2,30 @@ using Bit.Core.Enums; using Bit.Core.Utilities; -namespace Bit.Core.Entities; - -public class Send : ITableObject +namespace Bit.Core.Entities { - public Guid Id { get; set; } - public Guid? UserId { get; set; } - public Guid? OrganizationId { get; set; } - public SendType Type { get; set; } - public string Data { get; set; } - public string Key { get; set; } - [MaxLength(300)] - public string Password { get; set; } - public int? MaxAccessCount { get; set; } - public int AccessCount { get; set; } - public DateTime CreationDate { get; internal set; } = DateTime.UtcNow; - public DateTime RevisionDate { get; internal set; } = DateTime.UtcNow; - public DateTime? ExpirationDate { get; set; } - public DateTime DeletionDate { get; set; } - public bool Disabled { get; set; } - public bool? HideEmail { get; set; } - - public void SetNewId() + public class Send : ITableObject { - Id = CoreHelpers.GenerateComb(); + public Guid Id { get; set; } + public Guid? UserId { get; set; } + public Guid? OrganizationId { get; set; } + public SendType Type { get; set; } + public string Data { get; set; } + public string Key { get; set; } + [MaxLength(300)] + public string Password { get; set; } + public int? MaxAccessCount { get; set; } + public int AccessCount { get; set; } + public DateTime CreationDate { get; internal set; } = DateTime.UtcNow; + public DateTime RevisionDate { get; internal set; } = DateTime.UtcNow; + public DateTime? ExpirationDate { get; set; } + public DateTime DeletionDate { get; set; } + public bool Disabled { get; set; } + public bool? HideEmail { get; set; } + + public void SetNewId() + { + Id = CoreHelpers.GenerateComb(); + } } } diff --git a/src/Core/Entities/SsoConfig.cs b/src/Core/Entities/SsoConfig.cs index 09f3697b7e..63bf9173ca 100644 --- a/src/Core/Entities/SsoConfig.cs +++ b/src/Core/Entities/SsoConfig.cs @@ -1,29 +1,30 @@ using Bit.Core.Models.Data; -namespace Bit.Core.Entities; - -public class SsoConfig : ITableObject +namespace Bit.Core.Entities { - public long Id { get; set; } - public bool Enabled { get; set; } = true; - public Guid OrganizationId { get; set; } - public string Data { get; set; } - public DateTime CreationDate { get; internal set; } = DateTime.UtcNow; - public DateTime RevisionDate { get; internal set; } = DateTime.UtcNow; - - public void SetNewId() + public class SsoConfig : ITableObject { - // int will be auto-populated - Id = 0; - } + public long Id { get; set; } + public bool Enabled { get; set; } = true; + public Guid OrganizationId { get; set; } + public string Data { get; set; } + public DateTime CreationDate { get; internal set; } = DateTime.UtcNow; + public DateTime RevisionDate { get; internal set; } = DateTime.UtcNow; - public SsoConfigurationData GetData() - { - return SsoConfigurationData.Deserialize(Data); - } + public void SetNewId() + { + // int will be auto-populated + Id = 0; + } - public void SetData(SsoConfigurationData data) - { - Data = data.Serialize(); + public SsoConfigurationData GetData() + { + return SsoConfigurationData.Deserialize(Data); + } + + public void SetData(SsoConfigurationData data) + { + Data = data.Serialize(); + } } } diff --git a/src/Core/Entities/SsoUser.cs b/src/Core/Entities/SsoUser.cs index 6bc32c20d5..47818e2bdc 100644 --- a/src/Core/Entities/SsoUser.cs +++ b/src/Core/Entities/SsoUser.cs @@ -1,19 +1,20 @@ using System.ComponentModel.DataAnnotations; -namespace Bit.Core.Entities; - -public class SsoUser : ITableObject +namespace Bit.Core.Entities { - public long Id { get; set; } - public Guid UserId { get; set; } - public Guid? OrganizationId { get; set; } - [MaxLength(50)] - public string ExternalId { get; set; } - public DateTime CreationDate { get; internal set; } = DateTime.UtcNow; - - public void SetNewId() + public class SsoUser : ITableObject { - // int will be auto-populated - Id = 0; + public long Id { get; set; } + public Guid UserId { get; set; } + public Guid? OrganizationId { get; set; } + [MaxLength(50)] + public string ExternalId { get; set; } + public DateTime CreationDate { get; internal set; } = DateTime.UtcNow; + + public void SetNewId() + { + // int will be auto-populated + Id = 0; + } } } diff --git a/src/Core/Entities/TaxRate.cs b/src/Core/Entities/TaxRate.cs index a04ccf445c..bf53c8cf04 100644 --- a/src/Core/Entities/TaxRate.cs +++ b/src/Core/Entities/TaxRate.cs @@ -1,23 +1,24 @@ using System.ComponentModel.DataAnnotations; -namespace Bit.Core.Entities; - -public class TaxRate : ITableObject +namespace Bit.Core.Entities { - [MaxLength(40)] - public string Id { get; set; } - [MaxLength(50)] - public string Country { get; set; } - [MaxLength(2)] - public string State { get; set; } - [MaxLength(10)] - public string PostalCode { get; set; } - public decimal Rate { get; set; } - public bool Active { get; set; } - - public void SetNewId() + public class TaxRate : ITableObject { - // Id is created by Stripe, should exist before this gets called - return; + [MaxLength(40)] + public string Id { get; set; } + [MaxLength(50)] + public string Country { get; set; } + [MaxLength(2)] + public string State { get; set; } + [MaxLength(10)] + public string PostalCode { get; set; } + public decimal Rate { get; set; } + public bool Active { get; set; } + + public void SetNewId() + { + // Id is created by Stripe, should exist before this gets called + return; + } } } diff --git a/src/Core/Entities/Transaction.cs b/src/Core/Entities/Transaction.cs index f82b76a12a..b2a01908cb 100644 --- a/src/Core/Entities/Transaction.cs +++ b/src/Core/Entities/Transaction.cs @@ -2,27 +2,28 @@ using Bit.Core.Enums; using Bit.Core.Utilities; -namespace Bit.Core.Entities; - -public class Transaction : ITableObject +namespace Bit.Core.Entities { - public Guid Id { get; set; } - public Guid? UserId { get; set; } - public Guid? OrganizationId { get; set; } - public TransactionType Type { get; set; } - public decimal Amount { get; set; } - public bool? Refunded { get; set; } - public decimal? RefundedAmount { get; set; } - [MaxLength(100)] - public string Details { get; set; } - public PaymentMethodType? PaymentMethodType { get; set; } - public GatewayType? Gateway { get; set; } - [MaxLength(50)] - public string GatewayId { get; set; } - public DateTime CreationDate { get; set; } = DateTime.UtcNow; - - public void SetNewId() + public class Transaction : ITableObject { - Id = CoreHelpers.GenerateComb(); + public Guid Id { get; set; } + public Guid? UserId { get; set; } + public Guid? OrganizationId { get; set; } + public TransactionType Type { get; set; } + public decimal Amount { get; set; } + public bool? Refunded { get; set; } + public decimal? RefundedAmount { get; set; } + [MaxLength(100)] + public string Details { get; set; } + public PaymentMethodType? PaymentMethodType { get; set; } + public GatewayType? Gateway { get; set; } + [MaxLength(50)] + public string GatewayId { get; set; } + public DateTime CreationDate { get; set; } = DateTime.UtcNow; + + public void SetNewId() + { + Id = CoreHelpers.GenerateComb(); + } } } diff --git a/src/Core/Entities/User.cs b/src/Core/Entities/User.cs index 5236fe249d..e5d79c7226 100644 --- a/src/Core/Entities/User.cs +++ b/src/Core/Entities/User.cs @@ -5,188 +5,189 @@ using Bit.Core.Models; using Bit.Core.Utilities; using Microsoft.AspNetCore.Identity; -namespace Bit.Core.Entities; - -public class User : ITableObject, ISubscriber, IStorable, IStorableSubscriber, IRevisable, ITwoFactorProvidersUser, IReferenceable +namespace Bit.Core.Entities { - private Dictionary _twoFactorProviders; - - public Guid Id { get; set; } - [MaxLength(50)] - public string Name { get; set; } - [Required] - [MaxLength(256)] - public string Email { get; set; } - public bool EmailVerified { get; set; } - [MaxLength(300)] - public string MasterPassword { get; set; } - [MaxLength(50)] - public string MasterPasswordHint { get; set; } - [MaxLength(10)] - public string Culture { get; set; } = "en-US"; - [Required] - [MaxLength(50)] - public string SecurityStamp { get; set; } - public string TwoFactorProviders { get; set; } - [MaxLength(32)] - public string TwoFactorRecoveryCode { get; set; } - public string EquivalentDomains { get; set; } - public string ExcludedGlobalEquivalentDomains { get; set; } - public DateTime AccountRevisionDate { get; set; } = DateTime.UtcNow; - public string Key { get; set; } - public string PublicKey { get; set; } - public string PrivateKey { get; set; } - public bool Premium { get; set; } - public DateTime? PremiumExpirationDate { get; set; } - public DateTime? RenewalReminderDate { get; set; } - public long? Storage { get; set; } - public short? MaxStorageGb { get; set; } - public GatewayType? Gateway { get; set; } - [MaxLength(50)] - public string GatewayCustomerId { get; set; } - [MaxLength(50)] - public string GatewaySubscriptionId { get; set; } - public string ReferenceData { get; set; } - [MaxLength(100)] - public string LicenseKey { get; set; } - [Required] - [MaxLength(30)] - public string ApiKey { get; set; } - public KdfType Kdf { get; set; } = KdfType.PBKDF2_SHA256; - public int KdfIterations { get; set; } = 5000; - public DateTime CreationDate { get; set; } = DateTime.UtcNow; - public DateTime RevisionDate { get; set; } = DateTime.UtcNow; - public bool ForcePasswordReset { get; set; } - public bool UsesKeyConnector { get; set; } - public int FailedLoginCount { get; set; } - public DateTime? LastFailedLoginDate { get; set; } - public bool UnknownDeviceVerificationEnabled { get; set; } - - public void SetNewId() + public class User : ITableObject, ISubscriber, IStorable, IStorableSubscriber, IRevisable, ITwoFactorProvidersUser, IReferenceable { - Id = CoreHelpers.GenerateComb(); - } + private Dictionary _twoFactorProviders; - public string BillingEmailAddress() - { - return Email?.ToLowerInvariant()?.Trim(); - } + public Guid Id { get; set; } + [MaxLength(50)] + public string Name { get; set; } + [Required] + [MaxLength(256)] + public string Email { get; set; } + public bool EmailVerified { get; set; } + [MaxLength(300)] + public string MasterPassword { get; set; } + [MaxLength(50)] + public string MasterPasswordHint { get; set; } + [MaxLength(10)] + public string Culture { get; set; } = "en-US"; + [Required] + [MaxLength(50)] + public string SecurityStamp { get; set; } + public string TwoFactorProviders { get; set; } + [MaxLength(32)] + public string TwoFactorRecoveryCode { get; set; } + public string EquivalentDomains { get; set; } + public string ExcludedGlobalEquivalentDomains { get; set; } + public DateTime AccountRevisionDate { get; set; } = DateTime.UtcNow; + public string Key { get; set; } + public string PublicKey { get; set; } + public string PrivateKey { get; set; } + public bool Premium { get; set; } + public DateTime? PremiumExpirationDate { get; set; } + public DateTime? RenewalReminderDate { get; set; } + public long? Storage { get; set; } + public short? MaxStorageGb { get; set; } + public GatewayType? Gateway { get; set; } + [MaxLength(50)] + public string GatewayCustomerId { get; set; } + [MaxLength(50)] + public string GatewaySubscriptionId { get; set; } + public string ReferenceData { get; set; } + [MaxLength(100)] + public string LicenseKey { get; set; } + [Required] + [MaxLength(30)] + public string ApiKey { get; set; } + public KdfType Kdf { get; set; } = KdfType.PBKDF2_SHA256; + public int KdfIterations { get; set; } = 5000; + public DateTime CreationDate { get; set; } = DateTime.UtcNow; + public DateTime RevisionDate { get; set; } = DateTime.UtcNow; + public bool ForcePasswordReset { get; set; } + public bool UsesKeyConnector { get; set; } + public int FailedLoginCount { get; set; } + public DateTime? LastFailedLoginDate { get; set; } + public bool UnknownDeviceVerificationEnabled { get; set; } - public string BillingName() - { - return Name; - } - - public string BraintreeCustomerIdPrefix() - { - return "u"; - } - - public string BraintreeIdField() - { - return "user_id"; - } - - public string GatewayIdField() - { - return "userId"; - } - - public bool IsUser() - { - return true; - } - - public Dictionary GetTwoFactorProviders() - { - if (string.IsNullOrWhiteSpace(TwoFactorProviders)) + public void SetNewId() { - return null; + Id = CoreHelpers.GenerateComb(); } - try + public string BillingEmailAddress() { - if (_twoFactorProviders == null) + return Email?.ToLowerInvariant()?.Trim(); + } + + public string BillingName() + { + return Name; + } + + public string BraintreeCustomerIdPrefix() + { + return "u"; + } + + public string BraintreeIdField() + { + return "user_id"; + } + + public string GatewayIdField() + { + return "userId"; + } + + public bool IsUser() + { + return true; + } + + public Dictionary GetTwoFactorProviders() + { + if (string.IsNullOrWhiteSpace(TwoFactorProviders)) { - _twoFactorProviders = - JsonHelpers.LegacyDeserialize>( - TwoFactorProviders); + return null; } - return _twoFactorProviders; - } - catch (JsonException) - { - return null; - } - } + try + { + if (_twoFactorProviders == null) + { + _twoFactorProviders = + JsonHelpers.LegacyDeserialize>( + TwoFactorProviders); + } - public Guid? GetUserId() - { - return Id; - } - - public bool GetPremium() - { - return Premium; - } - - public void SetTwoFactorProviders(Dictionary providers) - { - // When replacing with system.text remember to remove the extra serialization in WebAuthnTokenProvider. - TwoFactorProviders = JsonHelpers.LegacySerialize(providers, JsonHelpers.LegacyEnumKeyResolver); - _twoFactorProviders = providers; - } - - public void ClearTwoFactorProviders() - { - SetTwoFactorProviders(new Dictionary()); - } - - public TwoFactorProvider GetTwoFactorProvider(TwoFactorProviderType provider) - { - var providers = GetTwoFactorProviders(); - if (providers == null || !providers.ContainsKey(provider)) - { - return null; + return _twoFactorProviders; + } + catch (JsonException) + { + return null; + } } - return providers[provider]; - } - - public long StorageBytesRemaining() - { - if (!MaxStorageGb.HasValue) + public Guid? GetUserId() { - return 0; + return Id; } - return StorageBytesRemaining(MaxStorageGb.Value); - } - - public long StorageBytesRemaining(short maxStorageGb) - { - var maxStorageBytes = maxStorageGb * 1073741824L; - if (!Storage.HasValue) + public bool GetPremium() { - return maxStorageBytes; + return Premium; } - return maxStorageBytes - Storage.Value; - } - - public IdentityUser ToIdentityUser(bool twoFactorEnabled) - { - return new IdentityUser + public void SetTwoFactorProviders(Dictionary providers) { - Id = Id.ToString(), - Email = Email, - NormalizedEmail = Email, - EmailConfirmed = EmailVerified, - UserName = Email, - NormalizedUserName = Email, - TwoFactorEnabled = twoFactorEnabled, - SecurityStamp = SecurityStamp - }; + // When replacing with system.text remember to remove the extra serialization in WebAuthnTokenProvider. + TwoFactorProviders = JsonHelpers.LegacySerialize(providers, JsonHelpers.LegacyEnumKeyResolver); + _twoFactorProviders = providers; + } + + public void ClearTwoFactorProviders() + { + SetTwoFactorProviders(new Dictionary()); + } + + public TwoFactorProvider GetTwoFactorProvider(TwoFactorProviderType provider) + { + var providers = GetTwoFactorProviders(); + if (providers == null || !providers.ContainsKey(provider)) + { + return null; + } + + return providers[provider]; + } + + public long StorageBytesRemaining() + { + if (!MaxStorageGb.HasValue) + { + return 0; + } + + return StorageBytesRemaining(MaxStorageGb.Value); + } + + public long StorageBytesRemaining(short maxStorageGb) + { + var maxStorageBytes = maxStorageGb * 1073741824L; + if (!Storage.HasValue) + { + return maxStorageBytes; + } + + return maxStorageBytes - Storage.Value; + } + + public IdentityUser ToIdentityUser(bool twoFactorEnabled) + { + return new IdentityUser + { + Id = Id.ToString(), + Email = Email, + NormalizedEmail = Email, + EmailConfirmed = EmailVerified, + UserName = Email, + NormalizedUserName = Email, + TwoFactorEnabled = twoFactorEnabled, + SecurityStamp = SecurityStamp + }; + } } } diff --git a/src/Core/Enums/ApplicationCacheMessageType.cs b/src/Core/Enums/ApplicationCacheMessageType.cs index 94889ed4eb..b91b079953 100644 --- a/src/Core/Enums/ApplicationCacheMessageType.cs +++ b/src/Core/Enums/ApplicationCacheMessageType.cs @@ -1,7 +1,8 @@ -namespace Bit.Core.Enums; - -public enum ApplicationCacheMessageType : byte +namespace Bit.Core.Enums { - UpsertOrganizationAbility = 0, - DeleteOrganizationAbility = 1 + public enum ApplicationCacheMessageType : byte + { + UpsertOrganizationAbility = 0, + DeleteOrganizationAbility = 1 + } } diff --git a/src/Core/Enums/BitwardenClient.cs b/src/Core/Enums/BitwardenClient.cs index 6a1244c0c4..067eef92b5 100644 --- a/src/Core/Enums/BitwardenClient.cs +++ b/src/Core/Enums/BitwardenClient.cs @@ -1,12 +1,13 @@ -namespace Bit.Core.Enums; - -public static class BitwardenClient +namespace Bit.Core.Enums { - public const string - Web = "web", - Browser = "browser", - Desktop = "desktop", - Mobile = "mobile", - Cli = "cli", - DirectoryConnector = "connector"; + public static class BitwardenClient + { + public const string + Web = "web", + Browser = "browser", + Desktop = "desktop", + Mobile = "mobile", + Cli = "cli", + DirectoryConnector = "connector"; + } } diff --git a/src/Core/Enums/CipherRepromptType.cs b/src/Core/Enums/CipherRepromptType.cs index 3c64c19450..0e5b60ff20 100644 --- a/src/Core/Enums/CipherRepromptType.cs +++ b/src/Core/Enums/CipherRepromptType.cs @@ -1,7 +1,8 @@ -namespace Bit.Core.Enums; - -public enum CipherRepromptType : byte +namespace Bit.Core.Enums { - None = 0, - Password = 1, + public enum CipherRepromptType : byte + { + None = 0, + Password = 1, + } } diff --git a/src/Core/Enums/CipherStateAction.cs b/src/Core/Enums/CipherStateAction.cs index 926c8b06c7..87b73a41c7 100644 --- a/src/Core/Enums/CipherStateAction.cs +++ b/src/Core/Enums/CipherStateAction.cs @@ -1,8 +1,9 @@ -namespace Bit.Core.Enums; - -public enum CipherStateAction +namespace Bit.Core.Enums { - Restore, - SoftDelete, - HardDelete, + public enum CipherStateAction + { + Restore, + SoftDelete, + HardDelete, + } } diff --git a/src/Core/Enums/CipherType.cs b/src/Core/Enums/CipherType.cs index d9f37bcbc6..0aca948640 100644 --- a/src/Core/Enums/CipherType.cs +++ b/src/Core/Enums/CipherType.cs @@ -1,11 +1,12 @@ -namespace Bit.Core.Enums; - -public enum CipherType : byte +namespace Bit.Core.Enums { - // Folder is deprecated - //Folder = 0, - Login = 1, - SecureNote = 2, - Card = 3, - Identity = 4 + public enum CipherType : byte + { + // Folder is deprecated + //Folder = 0, + Login = 1, + SecureNote = 2, + Card = 3, + Identity = 4 + } } diff --git a/src/Core/Enums/DeviceType.cs b/src/Core/Enums/DeviceType.cs index 361d9ac38b..53aa21c76e 100644 --- a/src/Core/Enums/DeviceType.cs +++ b/src/Core/Enums/DeviceType.cs @@ -1,49 +1,50 @@ using System.ComponentModel.DataAnnotations; -namespace Bit.Core.Enums; - -public enum DeviceType : byte +namespace Bit.Core.Enums { - [Display(Name = "Android")] - Android = 0, - [Display(Name = "iOS")] - iOS = 1, - [Display(Name = "Chrome Extension")] - ChromeExtension = 2, - [Display(Name = "Firefox Extension")] - FirefoxExtension = 3, - [Display(Name = "Opera Extension")] - OperaExtension = 4, - [Display(Name = "Edge Extension")] - EdgeExtension = 5, - [Display(Name = "Windows")] - WindowsDesktop = 6, - [Display(Name = "macOS")] - MacOsDesktop = 7, - [Display(Name = "Linux")] - LinuxDesktop = 8, - [Display(Name = "Chrome")] - ChromeBrowser = 9, - [Display(Name = "Firefox")] - FirefoxBrowser = 10, - [Display(Name = "Opera")] - OperaBrowser = 11, - [Display(Name = "Edge")] - EdgeBrowser = 12, - [Display(Name = "Internet Explorer")] - IEBrowser = 13, - [Display(Name = "Unknown Browser")] - UnknownBrowser = 14, - [Display(Name = "Android")] - AndroidAmazon = 15, - [Display(Name = "UWP")] - UWP = 16, - [Display(Name = "Safari")] - SafariBrowser = 17, - [Display(Name = "Vivaldi")] - VivaldiBrowser = 18, - [Display(Name = "Vivaldi Extension")] - VivaldiExtension = 19, - [Display(Name = "Safari Extension")] - SafariExtension = 20 + public enum DeviceType : byte + { + [Display(Name = "Android")] + Android = 0, + [Display(Name = "iOS")] + iOS = 1, + [Display(Name = "Chrome Extension")] + ChromeExtension = 2, + [Display(Name = "Firefox Extension")] + FirefoxExtension = 3, + [Display(Name = "Opera Extension")] + OperaExtension = 4, + [Display(Name = "Edge Extension")] + EdgeExtension = 5, + [Display(Name = "Windows")] + WindowsDesktop = 6, + [Display(Name = "macOS")] + MacOsDesktop = 7, + [Display(Name = "Linux")] + LinuxDesktop = 8, + [Display(Name = "Chrome")] + ChromeBrowser = 9, + [Display(Name = "Firefox")] + FirefoxBrowser = 10, + [Display(Name = "Opera")] + OperaBrowser = 11, + [Display(Name = "Edge")] + EdgeBrowser = 12, + [Display(Name = "Internet Explorer")] + IEBrowser = 13, + [Display(Name = "Unknown Browser")] + UnknownBrowser = 14, + [Display(Name = "Android")] + AndroidAmazon = 15, + [Display(Name = "UWP")] + UWP = 16, + [Display(Name = "Safari")] + SafariBrowser = 17, + [Display(Name = "Vivaldi")] + VivaldiBrowser = 18, + [Display(Name = "Vivaldi Extension")] + VivaldiExtension = 19, + [Display(Name = "Safari Extension")] + SafariExtension = 20 + } } diff --git a/src/Core/Enums/EmergencyAccessStatusType.cs b/src/Core/Enums/EmergencyAccessStatusType.cs index 79fca334e6..2c5b472a92 100644 --- a/src/Core/Enums/EmergencyAccessStatusType.cs +++ b/src/Core/Enums/EmergencyAccessStatusType.cs @@ -1,10 +1,11 @@ -namespace Bit.Core.Enums; - -public enum EmergencyAccessStatusType : byte +namespace Bit.Core.Enums { - Invited = 0, - Accepted = 1, - Confirmed = 2, - RecoveryInitiated = 3, - RecoveryApproved = 4, + public enum EmergencyAccessStatusType : byte + { + Invited = 0, + Accepted = 1, + Confirmed = 2, + RecoveryInitiated = 3, + RecoveryApproved = 4, + } } diff --git a/src/Core/Enums/EmergencyAccessType.cs b/src/Core/Enums/EmergencyAccessType.cs index 5742bb5314..d622857aad 100644 --- a/src/Core/Enums/EmergencyAccessType.cs +++ b/src/Core/Enums/EmergencyAccessType.cs @@ -1,7 +1,8 @@ -namespace Bit.Core.Enums; - -public enum EmergencyAccessType : byte +namespace Bit.Core.Enums { - View = 0, - Takeover = 1, + public enum EmergencyAccessType : byte + { + View = 0, + Takeover = 1, + } } diff --git a/src/Core/Enums/EncryptionType.cs b/src/Core/Enums/EncryptionType.cs index a37110911f..2b6eaf086c 100644 --- a/src/Core/Enums/EncryptionType.cs +++ b/src/Core/Enums/EncryptionType.cs @@ -1,12 +1,13 @@ -namespace Bit.Core.Enums; - -public enum EncryptionType : byte +namespace Bit.Core.Enums { - AesCbc256_B64 = 0, - AesCbc128_HmacSha256_B64 = 1, - AesCbc256_HmacSha256_B64 = 2, - Rsa2048_OaepSha256_B64 = 3, - Rsa2048_OaepSha1_B64 = 4, - Rsa2048_OaepSha256_HmacSha256_B64 = 5, - Rsa2048_OaepSha1_HmacSha256_B64 = 6 + public enum EncryptionType : byte + { + AesCbc256_B64 = 0, + AesCbc128_HmacSha256_B64 = 1, + AesCbc256_HmacSha256_B64 = 2, + Rsa2048_OaepSha256_B64 = 3, + Rsa2048_OaepSha1_B64 = 4, + Rsa2048_OaepSha256_HmacSha256_B64 = 5, + Rsa2048_OaepSha1_HmacSha256_B64 = 6 + } } diff --git a/src/Core/Enums/EventType.cs b/src/Core/Enums/EventType.cs index 09a1afffd9..98d8440088 100644 --- a/src/Core/Enums/EventType.cs +++ b/src/Core/Enums/EventType.cs @@ -1,78 +1,79 @@ -namespace Bit.Core.Enums; - -public enum EventType : int +namespace Bit.Core.Enums { - User_LoggedIn = 1000, - User_ChangedPassword = 1001, - User_Updated2fa = 1002, - User_Disabled2fa = 1003, - User_Recovered2fa = 1004, - User_FailedLogIn = 1005, - User_FailedLogIn2fa = 1006, - User_ClientExportedVault = 1007, - User_UpdatedTempPassword = 1008, - User_MigratedKeyToKeyConnector = 1009, + public enum EventType : int + { + User_LoggedIn = 1000, + User_ChangedPassword = 1001, + User_Updated2fa = 1002, + User_Disabled2fa = 1003, + User_Recovered2fa = 1004, + User_FailedLogIn = 1005, + User_FailedLogIn2fa = 1006, + User_ClientExportedVault = 1007, + User_UpdatedTempPassword = 1008, + User_MigratedKeyToKeyConnector = 1009, - Cipher_Created = 1100, - Cipher_Updated = 1101, - Cipher_Deleted = 1102, - Cipher_AttachmentCreated = 1103, - Cipher_AttachmentDeleted = 1104, - Cipher_Shared = 1105, - Cipher_UpdatedCollections = 1106, - Cipher_ClientViewed = 1107, - Cipher_ClientToggledPasswordVisible = 1108, - Cipher_ClientToggledHiddenFieldVisible = 1109, - Cipher_ClientToggledCardCodeVisible = 1110, - Cipher_ClientCopiedPassword = 1111, - Cipher_ClientCopiedHiddenField = 1112, - Cipher_ClientCopiedCardCode = 1113, - Cipher_ClientAutofilled = 1114, - Cipher_SoftDeleted = 1115, - Cipher_Restored = 1116, - Cipher_ClientToggledCardNumberVisible = 1117, + Cipher_Created = 1100, + Cipher_Updated = 1101, + Cipher_Deleted = 1102, + Cipher_AttachmentCreated = 1103, + Cipher_AttachmentDeleted = 1104, + Cipher_Shared = 1105, + Cipher_UpdatedCollections = 1106, + Cipher_ClientViewed = 1107, + Cipher_ClientToggledPasswordVisible = 1108, + Cipher_ClientToggledHiddenFieldVisible = 1109, + Cipher_ClientToggledCardCodeVisible = 1110, + Cipher_ClientCopiedPassword = 1111, + Cipher_ClientCopiedHiddenField = 1112, + Cipher_ClientCopiedCardCode = 1113, + Cipher_ClientAutofilled = 1114, + Cipher_SoftDeleted = 1115, + Cipher_Restored = 1116, + Cipher_ClientToggledCardNumberVisible = 1117, - Collection_Created = 1300, - Collection_Updated = 1301, - Collection_Deleted = 1302, + Collection_Created = 1300, + Collection_Updated = 1301, + Collection_Deleted = 1302, - Group_Created = 1400, - Group_Updated = 1401, - Group_Deleted = 1402, + Group_Created = 1400, + Group_Updated = 1401, + Group_Deleted = 1402, - OrganizationUser_Invited = 1500, - OrganizationUser_Confirmed = 1501, - OrganizationUser_Updated = 1502, - OrganizationUser_Removed = 1503, - OrganizationUser_UpdatedGroups = 1504, - OrganizationUser_UnlinkedSso = 1505, - OrganizationUser_ResetPassword_Enroll = 1506, - OrganizationUser_ResetPassword_Withdraw = 1507, - OrganizationUser_AdminResetPassword = 1508, - OrganizationUser_ResetSsoLink = 1509, - OrganizationUser_FirstSsoLogin = 1510, - OrganizationUser_Revoked = 1511, - OrganizationUser_Restored = 1512, + OrganizationUser_Invited = 1500, + OrganizationUser_Confirmed = 1501, + OrganizationUser_Updated = 1502, + OrganizationUser_Removed = 1503, + OrganizationUser_UpdatedGroups = 1504, + OrganizationUser_UnlinkedSso = 1505, + OrganizationUser_ResetPassword_Enroll = 1506, + OrganizationUser_ResetPassword_Withdraw = 1507, + OrganizationUser_AdminResetPassword = 1508, + OrganizationUser_ResetSsoLink = 1509, + OrganizationUser_FirstSsoLogin = 1510, + OrganizationUser_Revoked = 1511, + OrganizationUser_Restored = 1512, - Organization_Updated = 1600, - Organization_PurgedVault = 1601, - Organization_ClientExportedVault = 1602, - Organization_VaultAccessed = 1603, - Organization_EnabledSso = 1604, - Organization_DisabledSso = 1605, - Organization_EnabledKeyConnector = 1606, - Organization_DisabledKeyConnector = 1607, - Organization_SponsorshipsSynced = 1608, + Organization_Updated = 1600, + Organization_PurgedVault = 1601, + Organization_ClientExportedVault = 1602, + Organization_VaultAccessed = 1603, + Organization_EnabledSso = 1604, + Organization_DisabledSso = 1605, + Organization_EnabledKeyConnector = 1606, + Organization_DisabledKeyConnector = 1607, + Organization_SponsorshipsSynced = 1608, - Policy_Updated = 1700, + Policy_Updated = 1700, - ProviderUser_Invited = 1800, - ProviderUser_Confirmed = 1801, - ProviderUser_Updated = 1802, - ProviderUser_Removed = 1803, + ProviderUser_Invited = 1800, + ProviderUser_Confirmed = 1801, + ProviderUser_Updated = 1802, + ProviderUser_Removed = 1803, - ProviderOrganization_Created = 1900, - ProviderOrganization_Added = 1901, - ProviderOrganization_Removed = 1902, - ProviderOrganization_VaultAccessed = 1903, + ProviderOrganization_Created = 1900, + ProviderOrganization_Added = 1901, + ProviderOrganization_Removed = 1902, + ProviderOrganization_VaultAccessed = 1903, + } } diff --git a/src/Core/Enums/FieldType.cs b/src/Core/Enums/FieldType.cs index 4642b63a81..5eef485b7a 100644 --- a/src/Core/Enums/FieldType.cs +++ b/src/Core/Enums/FieldType.cs @@ -1,9 +1,10 @@ -namespace Bit.Core.Enums; - -public enum FieldType : byte +namespace Bit.Core.Enums { - Text = 0, - Hidden = 1, - Boolean = 2, - Linked = 3, + public enum FieldType : byte + { + Text = 0, + Hidden = 1, + Boolean = 2, + Linked = 3, + } } diff --git a/src/Core/Enums/FileUploadType.cs b/src/Core/Enums/FileUploadType.cs index 4d32589b6a..4bdefd4dd0 100644 --- a/src/Core/Enums/FileUploadType.cs +++ b/src/Core/Enums/FileUploadType.cs @@ -1,7 +1,8 @@ -namespace Bit.Core.Enums; - -public enum FileUploadType +namespace Bit.Core.Enums { - Direct = 0, - Azure = 1, + public enum FileUploadType + { + Direct = 0, + Azure = 1, + } } diff --git a/src/Core/Enums/GatewayType.cs b/src/Core/Enums/GatewayType.cs index 5ad73cf0f0..68c959ad73 100644 --- a/src/Core/Enums/GatewayType.cs +++ b/src/Core/Enums/GatewayType.cs @@ -1,21 +1,22 @@ using System.ComponentModel.DataAnnotations; -namespace Bit.Core.Enums; - -public enum GatewayType : byte +namespace Bit.Core.Enums { - [Display(Name = "Stripe")] - Stripe = 0, - [Display(Name = "Braintree")] - Braintree = 1, - [Display(Name = "Apple App Store")] - AppStore = 2, - [Display(Name = "Google Play Store")] - PlayStore = 3, - [Display(Name = "BitPay")] - BitPay = 4, - [Display(Name = "PayPal")] - PayPal = 5, - [Display(Name = "Bank")] - Bank = 6, + public enum GatewayType : byte + { + [Display(Name = "Stripe")] + Stripe = 0, + [Display(Name = "Braintree")] + Braintree = 1, + [Display(Name = "Apple App Store")] + AppStore = 2, + [Display(Name = "Google Play Store")] + PlayStore = 3, + [Display(Name = "BitPay")] + BitPay = 4, + [Display(Name = "PayPal")] + PayPal = 5, + [Display(Name = "Bank")] + Bank = 6, + } } diff --git a/src/Core/Enums/GlobalEquivalentDomainsType.cs b/src/Core/Enums/GlobalEquivalentDomainsType.cs index 1291736d72..22b0cdd3a6 100644 --- a/src/Core/Enums/GlobalEquivalentDomainsType.cs +++ b/src/Core/Enums/GlobalEquivalentDomainsType.cs @@ -1,94 +1,95 @@ -namespace Bit.Core.Enums; - -public enum GlobalEquivalentDomainsType : byte +namespace Bit.Core.Enums { - Google = 0, - Apple = 1, - Ameritrade = 2, - BoA = 3, - Sprint = 4, - WellsFargo = 5, - Merrill = 6, - Citi = 7, - Cnet = 8, - Gap = 9, - Microsoft = 10, - United = 11, - Yahoo = 12, - Zonelabs = 13, - PayPal = 14, - Avon = 15, - Diapers = 16, - Contacts = 17, - Amazon = 18, - Cox = 19, - Norton = 20, - Verizon = 21, - Buy = 22, - Sirius = 23, - Ea = 24, - Basecamp = 25, - Steam = 26, - Chart = 27, - Gotomeeting = 28, - Gogo = 29, - Oracle = 30, - Discover = 31, - Dcu = 32, - Healthcare = 33, - Pepco = 34, - Century21 = 35, - Comcast = 36, - Cricket = 37, - Mtb = 38, - Dropbox = 39, - Snapfish = 40, - Alibaba = 41, - Playstation = 42, - Mercado = 43, - Zendesk = 44, - Autodesk = 45, - RailNation = 46, - Wpcu = 47, - Mathletics = 48, - Discountbank = 49, - Mi = 50, - Facebook = 51, - Postepay = 52, - Skysports = 53, - Disney = 54, - Pokemon = 55, - Uv = 56, - Yahavo = 57, - Mdsol = 58, - Sears = 59, - Xiami = 60, - Belkin = 61, - Turbotax = 62, - Shopify = 63, - Ebay = 64, - Techdata = 65, - Schwab = 66, - Mozilla = 67, // deprecated - Tesla = 68, - MorganStanley = 69, - TaxAct = 70, - Wikimedia = 71, - Airbnb = 72, - Eventbrite = 73, - StackExchange = 74, - Docusign = 75, - Envato = 76, - X10Hosting = 77, - Cisco = 78, - CedarFair = 79, - Ubiquiti = 80, - Discord = 81, - Netcup = 82, - Yandex = 83, - Sony = 84, - Proton = 85, - Ubisoft = 86, - TransferWise = 87, - TakeawayEU = 88, + public enum GlobalEquivalentDomainsType : byte + { + Google = 0, + Apple = 1, + Ameritrade = 2, + BoA = 3, + Sprint = 4, + WellsFargo = 5, + Merrill = 6, + Citi = 7, + Cnet = 8, + Gap = 9, + Microsoft = 10, + United = 11, + Yahoo = 12, + Zonelabs = 13, + PayPal = 14, + Avon = 15, + Diapers = 16, + Contacts = 17, + Amazon = 18, + Cox = 19, + Norton = 20, + Verizon = 21, + Buy = 22, + Sirius = 23, + Ea = 24, + Basecamp = 25, + Steam = 26, + Chart = 27, + Gotomeeting = 28, + Gogo = 29, + Oracle = 30, + Discover = 31, + Dcu = 32, + Healthcare = 33, + Pepco = 34, + Century21 = 35, + Comcast = 36, + Cricket = 37, + Mtb = 38, + Dropbox = 39, + Snapfish = 40, + Alibaba = 41, + Playstation = 42, + Mercado = 43, + Zendesk = 44, + Autodesk = 45, + RailNation = 46, + Wpcu = 47, + Mathletics = 48, + Discountbank = 49, + Mi = 50, + Facebook = 51, + Postepay = 52, + Skysports = 53, + Disney = 54, + Pokemon = 55, + Uv = 56, + Yahavo = 57, + Mdsol = 58, + Sears = 59, + Xiami = 60, + Belkin = 61, + Turbotax = 62, + Shopify = 63, + Ebay = 64, + Techdata = 65, + Schwab = 66, + Mozilla = 67, // deprecated + Tesla = 68, + MorganStanley = 69, + TaxAct = 70, + Wikimedia = 71, + Airbnb = 72, + Eventbrite = 73, + StackExchange = 74, + Docusign = 75, + Envato = 76, + X10Hosting = 77, + Cisco = 78, + CedarFair = 79, + Ubiquiti = 80, + Discord = 81, + Netcup = 82, + Yandex = 83, + Sony = 84, + Proton = 85, + Ubisoft = 86, + TransferWise = 87, + TakeawayEU = 88, + } } diff --git a/src/Core/Enums/KdfType.cs b/src/Core/Enums/KdfType.cs index 212794eac6..1c845846ae 100644 --- a/src/Core/Enums/KdfType.cs +++ b/src/Core/Enums/KdfType.cs @@ -1,6 +1,7 @@ -namespace Bit.Core.Enums; - -public enum KdfType : byte +namespace Bit.Core.Enums { - PBKDF2_SHA256 = 0 + public enum KdfType : byte + { + PBKDF2_SHA256 = 0 + } } diff --git a/src/Core/Enums/LicenseType.cs b/src/Core/Enums/LicenseType.cs index 90ca0d7a68..60d622b9c9 100644 --- a/src/Core/Enums/LicenseType.cs +++ b/src/Core/Enums/LicenseType.cs @@ -1,7 +1,8 @@ -namespace Bit.Core.Enums; - -public enum LicenseType : byte +namespace Bit.Core.Enums { - User = 0, - Organization = 1, + public enum LicenseType : byte + { + User = 0, + Organization = 1, + } } diff --git a/src/Core/Enums/OrganizationApiKeyType.cs b/src/Core/Enums/OrganizationApiKeyType.cs index 8fdbf931aa..153079cf21 100644 --- a/src/Core/Enums/OrganizationApiKeyType.cs +++ b/src/Core/Enums/OrganizationApiKeyType.cs @@ -1,8 +1,9 @@ -namespace Bit.Core.Enums; - -public enum OrganizationApiKeyType : byte +namespace Bit.Core.Enums { - Default = 0, - BillingSync = 1, - Scim = 2, + public enum OrganizationApiKeyType : byte + { + Default = 0, + BillingSync = 1, + Scim = 2, + } } diff --git a/src/Core/Enums/OrganizationConnectionType.cs b/src/Core/Enums/OrganizationConnectionType.cs index 995cfc8662..e998e5532e 100644 --- a/src/Core/Enums/OrganizationConnectionType.cs +++ b/src/Core/Enums/OrganizationConnectionType.cs @@ -1,7 +1,8 @@ -namespace Bit.Core.Enums; - -public enum OrganizationConnectionType : byte +namespace Bit.Core.Enums { - CloudBillingSync = 1, - Scim = 2, + public enum OrganizationConnectionType : byte + { + CloudBillingSync = 1, + Scim = 2, + } } diff --git a/src/Core/Enums/OrganizationUserStatusType.cs b/src/Core/Enums/OrganizationUserStatusType.cs index 576e98ea74..8c39c053f1 100644 --- a/src/Core/Enums/OrganizationUserStatusType.cs +++ b/src/Core/Enums/OrganizationUserStatusType.cs @@ -1,9 +1,10 @@ -namespace Bit.Core.Enums; - -public enum OrganizationUserStatusType : short +namespace Bit.Core.Enums { - Invited = 0, - Accepted = 1, - Confirmed = 2, - Revoked = -1, + public enum OrganizationUserStatusType : short + { + Invited = 0, + Accepted = 1, + Confirmed = 2, + Revoked = -1, + } } diff --git a/src/Core/Enums/OrganizationUserType.cs b/src/Core/Enums/OrganizationUserType.cs index 620eaeb330..738c80657e 100644 --- a/src/Core/Enums/OrganizationUserType.cs +++ b/src/Core/Enums/OrganizationUserType.cs @@ -1,10 +1,11 @@ -namespace Bit.Core.Enums; - -public enum OrganizationUserType : byte +namespace Bit.Core.Enums { - Owner = 0, - Admin = 1, - User = 2, - Manager = 3, - Custom = 4, + public enum OrganizationUserType : byte + { + Owner = 0, + Admin = 1, + User = 2, + Manager = 3, + Custom = 4, + } } diff --git a/src/Core/Enums/PaymentMethodType.cs b/src/Core/Enums/PaymentMethodType.cs index 0b6c235b3b..b0290f92b3 100644 --- a/src/Core/Enums/PaymentMethodType.cs +++ b/src/Core/Enums/PaymentMethodType.cs @@ -1,27 +1,28 @@ using System.ComponentModel.DataAnnotations; -namespace Bit.Core.Enums; - -public enum PaymentMethodType : byte +namespace Bit.Core.Enums { - [Display(Name = "Card")] - Card = 0, - [Display(Name = "Bank Account")] - BankAccount = 1, - [Display(Name = "PayPal")] - PayPal = 2, - [Display(Name = "BitPay")] - BitPay = 3, - [Display(Name = "Credit")] - Credit = 4, - [Display(Name = "Wire Transfer")] - WireTransfer = 5, - [Display(Name = "Apple In-App Purchase")] - AppleInApp = 6, - [Display(Name = "Google In-App Purchase")] - GoogleInApp = 7, - [Display(Name = "Check")] - Check = 8, - [Display(Name = "None")] - None = 255, + public enum PaymentMethodType : byte + { + [Display(Name = "Card")] + Card = 0, + [Display(Name = "Bank Account")] + BankAccount = 1, + [Display(Name = "PayPal")] + PayPal = 2, + [Display(Name = "BitPay")] + BitPay = 3, + [Display(Name = "Credit")] + Credit = 4, + [Display(Name = "Wire Transfer")] + WireTransfer = 5, + [Display(Name = "Apple In-App Purchase")] + AppleInApp = 6, + [Display(Name = "Google In-App Purchase")] + GoogleInApp = 7, + [Display(Name = "Check")] + Check = 8, + [Display(Name = "None")] + None = 255, + } } diff --git a/src/Core/Enums/PlanSponsorshipType.cs b/src/Core/Enums/PlanSponsorshipType.cs index 2bb7a15b10..59f778e101 100644 --- a/src/Core/Enums/PlanSponsorshipType.cs +++ b/src/Core/Enums/PlanSponsorshipType.cs @@ -1,9 +1,10 @@ using System.ComponentModel.DataAnnotations; -namespace Bit.Core.Enums; - -public enum PlanSponsorshipType : byte +namespace Bit.Core.Enums { - [Display(Name = "Families For Enterprise")] - FamiliesForEnterprise = 0, + public enum PlanSponsorshipType : byte + { + [Display(Name = "Families For Enterprise")] + FamiliesForEnterprise = 0, + } } diff --git a/src/Core/Enums/PlanType.cs b/src/Core/Enums/PlanType.cs index ac32f217e4..037f1f8938 100644 --- a/src/Core/Enums/PlanType.cs +++ b/src/Core/Enums/PlanType.cs @@ -1,31 +1,32 @@ using System.ComponentModel.DataAnnotations; -namespace Bit.Core.Enums; - -public enum PlanType : byte +namespace Bit.Core.Enums { - [Display(Name = "Free")] - Free = 0, - [Display(Name = "Families 2019")] - FamiliesAnnually2019 = 1, - [Display(Name = "Teams (Monthly) 2019")] - TeamsMonthly2019 = 2, - [Display(Name = "Teams (Annually) 2019")] - TeamsAnnually2019 = 3, - [Display(Name = "Enterprise (Monthly) 2019")] - EnterpriseMonthly2019 = 4, - [Display(Name = "Enterprise (Annually) 2019")] - EnterpriseAnnually2019 = 5, - [Display(Name = "Custom")] - Custom = 6, - [Display(Name = "Families")] - FamiliesAnnually = 7, - [Display(Name = "Teams (Monthly)")] - TeamsMonthly = 8, - [Display(Name = "Teams (Annually)")] - TeamsAnnually = 9, - [Display(Name = "Enterprise (Monthly)")] - EnterpriseMonthly = 10, - [Display(Name = "Enterprise (Annually)")] - EnterpriseAnnually = 11, + public enum PlanType : byte + { + [Display(Name = "Free")] + Free = 0, + [Display(Name = "Families 2019")] + FamiliesAnnually2019 = 1, + [Display(Name = "Teams (Monthly) 2019")] + TeamsMonthly2019 = 2, + [Display(Name = "Teams (Annually) 2019")] + TeamsAnnually2019 = 3, + [Display(Name = "Enterprise (Monthly) 2019")] + EnterpriseMonthly2019 = 4, + [Display(Name = "Enterprise (Annually) 2019")] + EnterpriseAnnually2019 = 5, + [Display(Name = "Custom")] + Custom = 6, + [Display(Name = "Families")] + FamiliesAnnually = 7, + [Display(Name = "Teams (Monthly)")] + TeamsMonthly = 8, + [Display(Name = "Teams (Annually)")] + TeamsAnnually = 9, + [Display(Name = "Enterprise (Monthly)")] + EnterpriseMonthly = 10, + [Display(Name = "Enterprise (Annually)")] + EnterpriseAnnually = 11, + } } diff --git a/src/Core/Enums/PolicyType.cs b/src/Core/Enums/PolicyType.cs index e4c1208362..ac76699957 100644 --- a/src/Core/Enums/PolicyType.cs +++ b/src/Core/Enums/PolicyType.cs @@ -1,16 +1,17 @@ -namespace Bit.Core.Enums; - -public enum PolicyType : byte +namespace Bit.Core.Enums { - TwoFactorAuthentication = 0, - MasterPassword = 1, - PasswordGenerator = 2, - SingleOrg = 3, - RequireSso = 4, - PersonalOwnership = 5, - DisableSend = 6, - SendOptions = 7, - ResetPassword = 8, - MaximumVaultTimeout = 9, - DisablePersonalVaultExport = 10, + public enum PolicyType : byte + { + TwoFactorAuthentication = 0, + MasterPassword = 1, + PasswordGenerator = 2, + SingleOrg = 3, + RequireSso = 4, + PersonalOwnership = 5, + DisableSend = 6, + SendOptions = 7, + ResetPassword = 8, + MaximumVaultTimeout = 9, + DisablePersonalVaultExport = 10, + } } diff --git a/src/Core/Enums/ProductType.cs b/src/Core/Enums/ProductType.cs index 1e443f56f9..2f9b1d478d 100644 --- a/src/Core/Enums/ProductType.cs +++ b/src/Core/Enums/ProductType.cs @@ -1,16 +1,17 @@ using System.ComponentModel.DataAnnotations; -namespace Bit.Core.Enums; - -public enum ProductType : byte +namespace Bit.Core.Enums { - [Display(Name = "Free")] - Free = 0, - [Display(Name = "Families")] - Families = 1, - [Display(Name = "Teams")] - Teams = 2, - [Display(Name = "Enterprise")] - Enterprise = 3, + public enum ProductType : byte + { + [Display(Name = "Free")] + Free = 0, + [Display(Name = "Families")] + Families = 1, + [Display(Name = "Teams")] + Teams = 2, + [Display(Name = "Enterprise")] + Enterprise = 3, + } } diff --git a/src/Core/Enums/Provider/ProviderStatusType.cs b/src/Core/Enums/Provider/ProviderStatusType.cs index bcb1f8cd2d..16d8d63303 100644 --- a/src/Core/Enums/Provider/ProviderStatusType.cs +++ b/src/Core/Enums/Provider/ProviderStatusType.cs @@ -1,7 +1,8 @@ -namespace Bit.Core.Enums.Provider; - -public enum ProviderStatusType : byte +namespace Bit.Core.Enums.Provider { - Pending = 0, - Created = 1, + public enum ProviderStatusType : byte + { + Pending = 0, + Created = 1, + } } diff --git a/src/Core/Enums/Provider/ProviderUserStatusType.cs b/src/Core/Enums/Provider/ProviderUserStatusType.cs index 60571386d7..73e9c8e335 100644 --- a/src/Core/Enums/Provider/ProviderUserStatusType.cs +++ b/src/Core/Enums/Provider/ProviderUserStatusType.cs @@ -1,8 +1,9 @@ -namespace Bit.Core.Enums.Provider; - -public enum ProviderUserStatusType : byte +namespace Bit.Core.Enums.Provider { - Invited = 0, - Accepted = 1, - Confirmed = 2, + public enum ProviderUserStatusType : byte + { + Invited = 0, + Accepted = 1, + Confirmed = 2, + } } diff --git a/src/Core/Enums/Provider/ProviderUserType.cs b/src/Core/Enums/Provider/ProviderUserType.cs index d13591290d..7147d21a3b 100644 --- a/src/Core/Enums/Provider/ProviderUserType.cs +++ b/src/Core/Enums/Provider/ProviderUserType.cs @@ -1,7 +1,8 @@ -namespace Bit.Core.Enums.Provider; - -public enum ProviderUserType : byte +namespace Bit.Core.Enums.Provider { - ProviderAdmin = 0, - ServiceUser = 1, + public enum ProviderUserType : byte + { + ProviderAdmin = 0, + ServiceUser = 1, + } } diff --git a/src/Core/Enums/PushType.cs b/src/Core/Enums/PushType.cs index 9054d1d40b..7899656b5f 100644 --- a/src/Core/Enums/PushType.cs +++ b/src/Core/Enums/PushType.cs @@ -1,23 +1,24 @@ -namespace Bit.Core.Enums; - -public enum PushType : byte +namespace Bit.Core.Enums { - SyncCipherUpdate = 0, - SyncCipherCreate = 1, - SyncLoginDelete = 2, - SyncFolderDelete = 3, - SyncCiphers = 4, + public enum PushType : byte + { + SyncCipherUpdate = 0, + SyncCipherCreate = 1, + SyncLoginDelete = 2, + SyncFolderDelete = 3, + SyncCiphers = 4, - SyncVault = 5, - SyncOrgKeys = 6, - SyncFolderCreate = 7, - SyncFolderUpdate = 8, - SyncCipherDelete = 9, - SyncSettings = 10, + SyncVault = 5, + SyncOrgKeys = 6, + SyncFolderCreate = 7, + SyncFolderUpdate = 8, + SyncCipherDelete = 9, + SyncSettings = 10, - LogOut = 11, + LogOut = 11, - SyncSendCreate = 12, - SyncSendUpdate = 13, - SyncSendDelete = 14, + SyncSendCreate = 12, + SyncSendUpdate = 13, + SyncSendDelete = 14, + } } diff --git a/src/Core/Enums/ReferenceEventSource.cs b/src/Core/Enums/ReferenceEventSource.cs index 3d7ad85ffa..0a19b0772d 100644 --- a/src/Core/Enums/ReferenceEventSource.cs +++ b/src/Core/Enums/ReferenceEventSource.cs @@ -1,11 +1,12 @@ using System.Runtime.Serialization; -namespace Bit.Core.Enums; - -public enum ReferenceEventSource +namespace Bit.Core.Enums { - [EnumMember(Value = "organization")] - Organization, - [EnumMember(Value = "user")] - User, + public enum ReferenceEventSource + { + [EnumMember(Value = "organization")] + Organization, + [EnumMember(Value = "user")] + User, + } } diff --git a/src/Core/Enums/ReferenceEventType.cs b/src/Core/Enums/ReferenceEventType.cs index 1a925736c4..efd631f32a 100644 --- a/src/Core/Enums/ReferenceEventType.cs +++ b/src/Core/Enums/ReferenceEventType.cs @@ -1,43 +1,44 @@ using System.Runtime.Serialization; -namespace Bit.Core.Enums; - -public enum ReferenceEventType +namespace Bit.Core.Enums { - [EnumMember(Value = "signup")] - Signup, - [EnumMember(Value = "upgrade-plan")] - UpgradePlan, - [EnumMember(Value = "adjust-storage")] - AdjustStorage, - [EnumMember(Value = "adjust-seats")] - AdjustSeats, - [EnumMember(Value = "cancel-subscription")] - CancelSubscription, - [EnumMember(Value = "reinstate-subscription")] - ReinstateSubscription, - [EnumMember(Value = "delete-account")] - DeleteAccount, - [EnumMember(Value = "confirm-email")] - ConfirmEmailAddress, - [EnumMember(Value = "invited-users")] - InvitedUsers, - [EnumMember(Value = "rebilled")] - Rebilled, - [EnumMember(Value = "send-created")] - SendCreated, - [EnumMember(Value = "send-accessed")] - SendAccessed, - [EnumMember(Value = "directory-synced")] - DirectorySynced, - [EnumMember(Value = "vault-imported")] - VaultImported, - [EnumMember(Value = "cipher-created")] - CipherCreated, - [EnumMember(Value = "group-created")] - GroupCreated, - [EnumMember(Value = "collection-created")] - CollectionCreated, - [EnumMember(Value = "organization-edited-by-admin")] - OrganizationEditedByAdmin + public enum ReferenceEventType + { + [EnumMember(Value = "signup")] + Signup, + [EnumMember(Value = "upgrade-plan")] + UpgradePlan, + [EnumMember(Value = "adjust-storage")] + AdjustStorage, + [EnumMember(Value = "adjust-seats")] + AdjustSeats, + [EnumMember(Value = "cancel-subscription")] + CancelSubscription, + [EnumMember(Value = "reinstate-subscription")] + ReinstateSubscription, + [EnumMember(Value = "delete-account")] + DeleteAccount, + [EnumMember(Value = "confirm-email")] + ConfirmEmailAddress, + [EnumMember(Value = "invited-users")] + InvitedUsers, + [EnumMember(Value = "rebilled")] + Rebilled, + [EnumMember(Value = "send-created")] + SendCreated, + [EnumMember(Value = "send-accessed")] + SendAccessed, + [EnumMember(Value = "directory-synced")] + DirectorySynced, + [EnumMember(Value = "vault-imported")] + VaultImported, + [EnumMember(Value = "cipher-created")] + CipherCreated, + [EnumMember(Value = "group-created")] + GroupCreated, + [EnumMember(Value = "collection-created")] + CollectionCreated, + [EnumMember(Value = "organization-edited-by-admin")] + OrganizationEditedByAdmin + } } diff --git a/src/Core/Enums/Saml2BindingType.cs b/src/Core/Enums/Saml2BindingType.cs index c02a5d7ccb..0c0882bc48 100644 --- a/src/Core/Enums/Saml2BindingType.cs +++ b/src/Core/Enums/Saml2BindingType.cs @@ -1,7 +1,8 @@ -namespace Bit.Core.Enums; - -public enum Saml2BindingType : byte +namespace Bit.Core.Enums { - HttpRedirect = 1, - HttpPost = 2, + public enum Saml2BindingType : byte + { + HttpRedirect = 1, + HttpPost = 2, + } } diff --git a/src/Core/Enums/Saml2NameIdFormat.cs b/src/Core/Enums/Saml2NameIdFormat.cs index f90426e5c2..9ba83e58fd 100644 --- a/src/Core/Enums/Saml2NameIdFormat.cs +++ b/src/Core/Enums/Saml2NameIdFormat.cs @@ -1,14 +1,15 @@ -namespace Bit.Core.Enums; - -public enum Saml2NameIdFormat : byte +namespace Bit.Core.Enums { - NotConfigured = 0, - Unspecified = 1, - EmailAddress = 2, - X509SubjectName = 3, - WindowsDomainQualifiedName = 4, - KerberosPrincipalName = 5, - EntityIdentifier = 6, - Persistent = 7, - Transient = 8, + public enum Saml2NameIdFormat : byte + { + NotConfigured = 0, + Unspecified = 1, + EmailAddress = 2, + X509SubjectName = 3, + WindowsDomainQualifiedName = 4, + KerberosPrincipalName = 5, + EntityIdentifier = 6, + Persistent = 7, + Transient = 8, + } } diff --git a/src/Core/Enums/Saml2SigningBehavior.cs b/src/Core/Enums/Saml2SigningBehavior.cs index 25344dbc86..a02e5b1d91 100644 --- a/src/Core/Enums/Saml2SigningBehavior.cs +++ b/src/Core/Enums/Saml2SigningBehavior.cs @@ -1,8 +1,9 @@ -namespace Bit.Core.Enums; - -public enum Saml2SigningBehavior : byte +namespace Bit.Core.Enums { - IfIdpWantAuthnRequestsSigned = 0, - Always = 1, - Never = 3 + public enum Saml2SigningBehavior : byte + { + IfIdpWantAuthnRequestsSigned = 0, + Always = 1, + Never = 3 + } } diff --git a/src/Core/Enums/ScimProviderType.cs b/src/Core/Enums/ScimProviderType.cs index c1d4670392..18039c87c9 100644 --- a/src/Core/Enums/ScimProviderType.cs +++ b/src/Core/Enums/ScimProviderType.cs @@ -1,12 +1,13 @@ -namespace Bit.Core.Enums; - -public enum ScimProviderType : byte +namespace Bit.Core.Enums { - Default = 0, - AzureAd = 1, - Okta = 2, - OneLogin = 3, - JumpCloud = 4, - GoogleWorkspace = 5, - Rippling = 6, + public enum ScimProviderType : byte + { + Default = 0, + AzureAd = 1, + Okta = 2, + OneLogin = 3, + JumpCloud = 4, + GoogleWorkspace = 5, + Rippling = 6, + } } diff --git a/src/Core/Enums/SecureNoteType.cs b/src/Core/Enums/SecureNoteType.cs index cdd565e7c1..cc84edfc35 100644 --- a/src/Core/Enums/SecureNoteType.cs +++ b/src/Core/Enums/SecureNoteType.cs @@ -1,6 +1,7 @@ -namespace Bit.Core.Enums; - -public enum SecureNoteType : byte +namespace Bit.Core.Enums { - Generic = 0 + public enum SecureNoteType : byte + { + Generic = 0 + } } diff --git a/src/Core/Enums/SendType.cs b/src/Core/Enums/SendType.cs index ce59df6b39..a52008556a 100644 --- a/src/Core/Enums/SendType.cs +++ b/src/Core/Enums/SendType.cs @@ -1,7 +1,8 @@ -namespace Bit.Core.Enums; - -public enum SendType : byte +namespace Bit.Core.Enums { - Text = 0, - File = 1 + public enum SendType : byte + { + Text = 0, + File = 1 + } } diff --git a/src/Core/Enums/SsoType.cs b/src/Core/Enums/SsoType.cs index 3e890817f7..3c1884bd79 100644 --- a/src/Core/Enums/SsoType.cs +++ b/src/Core/Enums/SsoType.cs @@ -1,7 +1,8 @@ -namespace Bit.Core.Enums; - -public enum SsoType : byte +namespace Bit.Core.Enums { - OpenIdConnect = 1, - Saml2 = 2, + public enum SsoType : byte + { + OpenIdConnect = 1, + Saml2 = 2, + } } diff --git a/src/Core/Enums/SupportedDatabaseProviders.cs b/src/Core/Enums/SupportedDatabaseProviders.cs index 81e60b58ec..c38a023c48 100644 --- a/src/Core/Enums/SupportedDatabaseProviders.cs +++ b/src/Core/Enums/SupportedDatabaseProviders.cs @@ -1,9 +1,10 @@ -namespace Bit.Core.Enums; - -public enum SupportedDatabaseProviders +namespace Bit.Core.Enums { - SqlServer, - MySql, - Postgres, + public enum SupportedDatabaseProviders + { + SqlServer, + MySql, + Postgres, + } } diff --git a/src/Core/Enums/TransactionType.cs b/src/Core/Enums/TransactionType.cs index 6a5107763f..02556ae1da 100644 --- a/src/Core/Enums/TransactionType.cs +++ b/src/Core/Enums/TransactionType.cs @@ -1,17 +1,18 @@ using System.ComponentModel.DataAnnotations; -namespace Bit.Core.Enums; - -public enum TransactionType : byte +namespace Bit.Core.Enums { - [Display(Name = "Charge")] - Charge = 0, - [Display(Name = "Credit")] - Credit = 1, - [Display(Name = "Promotional Credit")] - PromotionalCredit = 2, - [Display(Name = "Referral Credit")] - ReferralCredit = 3, - [Display(Name = "Refund")] - Refund = 4, + public enum TransactionType : byte + { + [Display(Name = "Charge")] + Charge = 0, + [Display(Name = "Credit")] + Credit = 1, + [Display(Name = "Promotional Credit")] + PromotionalCredit = 2, + [Display(Name = "Referral Credit")] + ReferralCredit = 3, + [Display(Name = "Refund")] + Refund = 4, + } } diff --git a/src/Core/Enums/TwoFactorProviderType.cs b/src/Core/Enums/TwoFactorProviderType.cs index 31d6269910..40c4e55111 100644 --- a/src/Core/Enums/TwoFactorProviderType.cs +++ b/src/Core/Enums/TwoFactorProviderType.cs @@ -1,13 +1,14 @@ -namespace Bit.Core.Enums; - -public enum TwoFactorProviderType : byte +namespace Bit.Core.Enums { - Authenticator = 0, - Email = 1, - Duo = 2, - YubiKey = 3, - U2f = 4, // Deprecated - Remember = 5, - OrganizationDuo = 6, - WebAuthn = 7, + public enum TwoFactorProviderType : byte + { + Authenticator = 0, + Email = 1, + Duo = 2, + YubiKey = 3, + U2f = 4, // Deprecated + Remember = 5, + OrganizationDuo = 6, + WebAuthn = 7, + } } diff --git a/src/Core/Enums/UriMatchType.cs b/src/Core/Enums/UriMatchType.cs index 593caf40ca..5694372989 100644 --- a/src/Core/Enums/UriMatchType.cs +++ b/src/Core/Enums/UriMatchType.cs @@ -1,11 +1,12 @@ -namespace Bit.Core.Enums; - -public enum UriMatchType : byte +namespace Bit.Core.Enums { - Domain = 0, - Host = 1, - StartsWith = 2, - Exact = 3, - RegularExpression = 4, - Never = 5 + public enum UriMatchType : byte + { + Domain = 0, + Host = 1, + StartsWith = 2, + Exact = 3, + RegularExpression = 4, + Never = 5 + } } diff --git a/src/Core/Exceptions/BadRequestException.cs b/src/Core/Exceptions/BadRequestException.cs index d18bd041e3..686bf786c1 100644 --- a/src/Core/Exceptions/BadRequestException.cs +++ b/src/Core/Exceptions/BadRequestException.cs @@ -1,30 +1,31 @@ using Microsoft.AspNetCore.Mvc.ModelBinding; -namespace Bit.Core.Exceptions; - -public class BadRequestException : Exception +namespace Bit.Core.Exceptions { - public BadRequestException(string message) - : base(message) - { } - - public BadRequestException(string key, string errorMessage) - : base("The model state is invalid.") + public class BadRequestException : Exception { - ModelState = new ModelStateDictionary(); - ModelState.AddModelError(key, errorMessage); - } + public BadRequestException(string message) + : base(message) + { } - public BadRequestException(ModelStateDictionary modelState) - : base("The model state is invalid.") - { - if (modelState.IsValid || modelState.ErrorCount == 0) + public BadRequestException(string key, string errorMessage) + : base("The model state is invalid.") { - return; + ModelState = new ModelStateDictionary(); + ModelState.AddModelError(key, errorMessage); } - ModelState = modelState; - } + public BadRequestException(ModelStateDictionary modelState) + : base("The model state is invalid.") + { + if (modelState.IsValid || modelState.ErrorCount == 0) + { + return; + } - public ModelStateDictionary ModelState { get; set; } + ModelState = modelState; + } + + public ModelStateDictionary ModelState { get; set; } + } } diff --git a/src/Core/Exceptions/GatewayException.cs b/src/Core/Exceptions/GatewayException.cs index 73e8cd7613..d97511a68d 100644 --- a/src/Core/Exceptions/GatewayException.cs +++ b/src/Core/Exceptions/GatewayException.cs @@ -1,8 +1,9 @@ -namespace Bit.Core.Exceptions; - -public class GatewayException : Exception +namespace Bit.Core.Exceptions { - public GatewayException(string message, Exception innerException = null) - : base(message, innerException) - { } + public class GatewayException : Exception + { + public GatewayException(string message, Exception innerException = null) + : base(message, innerException) + { } + } } diff --git a/src/Core/Exceptions/InvalidEmailException.cs b/src/Core/Exceptions/InvalidEmailException.cs index 1f17acf62e..64ede1fdb2 100644 --- a/src/Core/Exceptions/InvalidEmailException.cs +++ b/src/Core/Exceptions/InvalidEmailException.cs @@ -1,10 +1,11 @@ -namespace Bit.Core.Exceptions; - -public class InvalidEmailException : Exception +namespace Bit.Core.Exceptions { - public InvalidEmailException() - : base("Invalid email.") + public class InvalidEmailException : Exception { + public InvalidEmailException() + : base("Invalid email.") + { + } } } diff --git a/src/Core/Exceptions/InvalidGatewayCustomerIdException.cs b/src/Core/Exceptions/InvalidGatewayCustomerIdException.cs index cfc7c56c1c..ad3a4544ac 100644 --- a/src/Core/Exceptions/InvalidGatewayCustomerIdException.cs +++ b/src/Core/Exceptions/InvalidGatewayCustomerIdException.cs @@ -1,10 +1,11 @@ -namespace Bit.Core.Exceptions; - -public class InvalidGatewayCustomerIdException : Exception +namespace Bit.Core.Exceptions { - public InvalidGatewayCustomerIdException() - : base("Invalid gateway customerId.") + public class InvalidGatewayCustomerIdException : Exception { + public InvalidGatewayCustomerIdException() + : base("Invalid gateway customerId.") + { + } } } diff --git a/src/Core/Exceptions/NotFoundException.cs b/src/Core/Exceptions/NotFoundException.cs index 3f52f792c4..a47023093a 100644 --- a/src/Core/Exceptions/NotFoundException.cs +++ b/src/Core/Exceptions/NotFoundException.cs @@ -1,3 +1,4 @@ -namespace Bit.Core.Exceptions; - -public class NotFoundException : Exception { } +namespace Bit.Core.Exceptions +{ + public class NotFoundException : Exception { } +} diff --git a/src/Core/HostedServices/ApplicationCacheHostedService.cs b/src/Core/HostedServices/ApplicationCacheHostedService.cs index d5f4b77e3f..a5a27e5dec 100644 --- a/src/Core/HostedServices/ApplicationCacheHostedService.cs +++ b/src/Core/HostedServices/ApplicationCacheHostedService.cs @@ -8,99 +8,100 @@ using Microsoft.Azure.ServiceBus.Management; using Microsoft.Extensions.Hosting; using Microsoft.Extensions.Logging; -namespace Bit.Core.HostedServices; - -public class ApplicationCacheHostedService : IHostedService, IDisposable +namespace Bit.Core.HostedServices { - private readonly InMemoryServiceBusApplicationCacheService _applicationCacheService; - private readonly IOrganizationRepository _organizationRepository; - protected readonly ILogger _logger; - private readonly SubscriptionClient _subscriptionClient; - private readonly ManagementClient _managementClient; - private readonly string _subName; - private readonly string _topicName; - - public ApplicationCacheHostedService( - IApplicationCacheService applicationCacheService, - IOrganizationRepository organizationRepository, - ILogger logger, - GlobalSettings globalSettings) + public class ApplicationCacheHostedService : IHostedService, IDisposable { - _topicName = globalSettings.ServiceBus.ApplicationCacheTopicName; - _subName = CoreHelpers.GetApplicationCacheServiceBusSubcriptionName(globalSettings); - _applicationCacheService = applicationCacheService as InMemoryServiceBusApplicationCacheService; - _organizationRepository = organizationRepository; - _logger = logger; - _managementClient = new ManagementClient(globalSettings.ServiceBus.ConnectionString); - _subscriptionClient = new SubscriptionClient(globalSettings.ServiceBus.ConnectionString, - _topicName, _subName); - } + private readonly InMemoryServiceBusApplicationCacheService _applicationCacheService; + private readonly IOrganizationRepository _organizationRepository; + protected readonly ILogger _logger; + private readonly SubscriptionClient _subscriptionClient; + private readonly ManagementClient _managementClient; + private readonly string _subName; + private readonly string _topicName; - public virtual async Task StartAsync(CancellationToken cancellationToken) - { - try + public ApplicationCacheHostedService( + IApplicationCacheService applicationCacheService, + IOrganizationRepository organizationRepository, + ILogger logger, + GlobalSettings globalSettings) { - await _managementClient.CreateSubscriptionAsync(new SubscriptionDescription(_topicName, _subName) - { - DefaultMessageTimeToLive = TimeSpan.FromDays(14), - LockDuration = TimeSpan.FromSeconds(30), - EnableDeadLetteringOnFilterEvaluationExceptions = true, - EnableDeadLetteringOnMessageExpiration = true, - }, new RuleDescription("default", new SqlFilter($"sys.Label != '{_subName}'"))); + _topicName = globalSettings.ServiceBus.ApplicationCacheTopicName; + _subName = CoreHelpers.GetApplicationCacheServiceBusSubcriptionName(globalSettings); + _applicationCacheService = applicationCacheService as InMemoryServiceBusApplicationCacheService; + _organizationRepository = organizationRepository; + _logger = logger; + _managementClient = new ManagementClient(globalSettings.ServiceBus.ConnectionString); + _subscriptionClient = new SubscriptionClient(globalSettings.ServiceBus.ConnectionString, + _topicName, _subName); } - catch (MessagingEntityAlreadyExistsException) { } - _subscriptionClient.RegisterMessageHandler(ProcessMessageAsync, - new MessageHandlerOptions(ExceptionReceivedHandlerAsync) - { - MaxConcurrentCalls = 2, - AutoComplete = false, - }); - } - public virtual async Task StopAsync(CancellationToken cancellationToken) - { - await _subscriptionClient.CloseAsync(); - try + public virtual async Task StartAsync(CancellationToken cancellationToken) { - await _managementClient.DeleteSubscriptionAsync(_topicName, _subName, cancellationToken); + try + { + await _managementClient.CreateSubscriptionAsync(new SubscriptionDescription(_topicName, _subName) + { + DefaultMessageTimeToLive = TimeSpan.FromDays(14), + LockDuration = TimeSpan.FromSeconds(30), + EnableDeadLetteringOnFilterEvaluationExceptions = true, + EnableDeadLetteringOnMessageExpiration = true, + }, new RuleDescription("default", new SqlFilter($"sys.Label != '{_subName}'"))); + } + catch (MessagingEntityAlreadyExistsException) { } + _subscriptionClient.RegisterMessageHandler(ProcessMessageAsync, + new MessageHandlerOptions(ExceptionReceivedHandlerAsync) + { + MaxConcurrentCalls = 2, + AutoComplete = false, + }); } - catch { } - } - public virtual void Dispose() - { } - - private async Task ProcessMessageAsync(Message message, CancellationToken cancellationToken) - { - if (message.Label != _subName && _applicationCacheService != null) + public virtual async Task StopAsync(CancellationToken cancellationToken) { - switch ((ApplicationCacheMessageType)message.UserProperties["type"]) + await _subscriptionClient.CloseAsync(); + try { - case ApplicationCacheMessageType.UpsertOrganizationAbility: - var upsertedOrgId = (Guid)message.UserProperties["id"]; - var upsertedOrg = await _organizationRepository.GetByIdAsync(upsertedOrgId); - if (upsertedOrg != null) - { - await _applicationCacheService.BaseUpsertOrganizationAbilityAsync(upsertedOrg); - } - break; - case ApplicationCacheMessageType.DeleteOrganizationAbility: - await _applicationCacheService.BaseDeleteOrganizationAbilityAsync( - (Guid)message.UserProperties["id"]); - break; - default: - break; + await _managementClient.DeleteSubscriptionAsync(_topicName, _subName, cancellationToken); + } + catch { } + } + + public virtual void Dispose() + { } + + private async Task ProcessMessageAsync(Message message, CancellationToken cancellationToken) + { + if (message.Label != _subName && _applicationCacheService != null) + { + switch ((ApplicationCacheMessageType)message.UserProperties["type"]) + { + case ApplicationCacheMessageType.UpsertOrganizationAbility: + var upsertedOrgId = (Guid)message.UserProperties["id"]; + var upsertedOrg = await _organizationRepository.GetByIdAsync(upsertedOrgId); + if (upsertedOrg != null) + { + await _applicationCacheService.BaseUpsertOrganizationAbilityAsync(upsertedOrg); + } + break; + case ApplicationCacheMessageType.DeleteOrganizationAbility: + await _applicationCacheService.BaseDeleteOrganizationAbilityAsync( + (Guid)message.UserProperties["id"]); + break; + default: + break; + } + } + if (!cancellationToken.IsCancellationRequested) + { + await _subscriptionClient.CompleteAsync(message.SystemProperties.LockToken); } } - if (!cancellationToken.IsCancellationRequested) + + private Task ExceptionReceivedHandlerAsync(ExceptionReceivedEventArgs args) { - await _subscriptionClient.CompleteAsync(message.SystemProperties.LockToken); + _logger.LogError(args.Exception, "Message handler encountered an exception."); + return Task.FromResult(0); } } - - private Task ExceptionReceivedHandlerAsync(ExceptionReceivedEventArgs args) - { - _logger.LogError(args.Exception, "Message handler encountered an exception."); - return Task.FromResult(0); - } } diff --git a/src/Core/HostedServices/IpRateLimitSeedStartupService.cs b/src/Core/HostedServices/IpRateLimitSeedStartupService.cs index a6869d929c..dd77982cb5 100644 --- a/src/Core/HostedServices/IpRateLimitSeedStartupService.cs +++ b/src/Core/HostedServices/IpRateLimitSeedStartupService.cs @@ -1,40 +1,41 @@ using AspNetCoreRateLimit; using Microsoft.Extensions.Hosting; -namespace Bit.Core.HostedServices; - -/// -/// A startup service that will seed the IP rate limiting stores with any values in the -/// GlobalSettings configuration. -/// -/// -/// Using an here because it runs before the request processing pipeline -/// is configured, so that any rate limiting configuration is seeded/applied before any requests come in. -/// -/// -/// This is a cleaner alternative to modifying Program.cs in every project that requires rate limiting as -/// described/suggested here: -/// https://github.com/stefanprodan/AspNetCoreRateLimit/wiki/Version-3.0.0-Breaking-Changes -/// -/// -public class IpRateLimitSeedStartupService : IHostedService +namespace Bit.Core.HostedServices { - private readonly IIpPolicyStore _ipPolicyStore; - private readonly IClientPolicyStore _clientPolicyStore; - - public IpRateLimitSeedStartupService(IIpPolicyStore ipPolicyStore, IClientPolicyStore clientPolicyStore) + /// + /// A startup service that will seed the IP rate limiting stores with any values in the + /// GlobalSettings configuration. + /// + /// + /// Using an here because it runs before the request processing pipeline + /// is configured, so that any rate limiting configuration is seeded/applied before any requests come in. + /// + /// + /// This is a cleaner alternative to modifying Program.cs in every project that requires rate limiting as + /// described/suggested here: + /// https://github.com/stefanprodan/AspNetCoreRateLimit/wiki/Version-3.0.0-Breaking-Changes + /// + /// + public class IpRateLimitSeedStartupService : IHostedService { - _ipPolicyStore = ipPolicyStore; - _clientPolicyStore = clientPolicyStore; - } + private readonly IIpPolicyStore _ipPolicyStore; + private readonly IClientPolicyStore _clientPolicyStore; - public async Task StartAsync(CancellationToken cancellationToken) - { - // Seed the policies from GlobalSettings - await _ipPolicyStore.SeedAsync(); - await _clientPolicyStore.SeedAsync(); - } + public IpRateLimitSeedStartupService(IIpPolicyStore ipPolicyStore, IClientPolicyStore clientPolicyStore) + { + _ipPolicyStore = ipPolicyStore; + _clientPolicyStore = clientPolicyStore; + } - // noop - public Task StopAsync(CancellationToken cancellationToken) => Task.CompletedTask; + public async Task StartAsync(CancellationToken cancellationToken) + { + // Seed the policies from GlobalSettings + await _ipPolicyStore.SeedAsync(); + await _clientPolicyStore.SeedAsync(); + } + + // noop + public Task StopAsync(CancellationToken cancellationToken) => Task.CompletedTask; + } } diff --git a/src/Core/Identity/AuthenticatorTokenProvider.cs b/src/Core/Identity/AuthenticatorTokenProvider.cs index 8bda023e52..5eef3869db 100644 --- a/src/Core/Identity/AuthenticatorTokenProvider.cs +++ b/src/Core/Identity/AuthenticatorTokenProvider.cs @@ -5,41 +5,42 @@ using Microsoft.AspNetCore.Identity; using Microsoft.Extensions.DependencyInjection; using OtpNet; -namespace Bit.Core.Identity; - -public class AuthenticatorTokenProvider : IUserTwoFactorTokenProvider +namespace Bit.Core.Identity { - private readonly IServiceProvider _serviceProvider; - - public AuthenticatorTokenProvider(IServiceProvider serviceProvider) + public class AuthenticatorTokenProvider : IUserTwoFactorTokenProvider { - _serviceProvider = serviceProvider; - } + private readonly IServiceProvider _serviceProvider; - public async Task CanGenerateTwoFactorTokenAsync(UserManager manager, User user) - { - var provider = user.GetTwoFactorProvider(TwoFactorProviderType.Authenticator); - if (string.IsNullOrWhiteSpace((string)provider?.MetaData["Key"])) + public AuthenticatorTokenProvider(IServiceProvider serviceProvider) { - return false; + _serviceProvider = serviceProvider; } - return await _serviceProvider.GetRequiredService() - .TwoFactorProviderIsEnabledAsync(TwoFactorProviderType.Authenticator, user); - } - public Task GenerateAsync(string purpose, UserManager manager, User user) - { - return Task.FromResult(null); - } + public async Task CanGenerateTwoFactorTokenAsync(UserManager manager, User user) + { + var provider = user.GetTwoFactorProvider(TwoFactorProviderType.Authenticator); + if (string.IsNullOrWhiteSpace((string)provider?.MetaData["Key"])) + { + return false; + } + return await _serviceProvider.GetRequiredService() + .TwoFactorProviderIsEnabledAsync(TwoFactorProviderType.Authenticator, user); + } - public Task ValidateAsync(string purpose, string token, UserManager manager, User user) - { - var provider = user.GetTwoFactorProvider(TwoFactorProviderType.Authenticator); - var otp = new Totp(Base32Encoding.ToBytes((string)provider.MetaData["Key"])); + public Task GenerateAsync(string purpose, UserManager manager, User user) + { + return Task.FromResult(null); + } - long timeStepMatched; - var valid = otp.VerifyTotp(token, out timeStepMatched, new VerificationWindow(1, 1)); + public Task ValidateAsync(string purpose, string token, UserManager manager, User user) + { + var provider = user.GetTwoFactorProvider(TwoFactorProviderType.Authenticator); + var otp = new Totp(Base32Encoding.ToBytes((string)provider.MetaData["Key"])); - return Task.FromResult(valid); + long timeStepMatched; + var valid = otp.VerifyTotp(token, out timeStepMatched, new VerificationWindow(1, 1)); + + return Task.FromResult(valid); + } } } diff --git a/src/Core/Identity/CustomIdentityServiceCollectionExtensions.cs b/src/Core/Identity/CustomIdentityServiceCollectionExtensions.cs index a63bde8793..0acb4a3f4b 100644 --- a/src/Core/Identity/CustomIdentityServiceCollectionExtensions.cs +++ b/src/Core/Identity/CustomIdentityServiceCollectionExtensions.cs @@ -2,48 +2,49 @@ using Microsoft.AspNetCore.Identity; using Microsoft.Extensions.DependencyInjection.Extensions; -namespace Microsoft.Extensions.DependencyInjection; - -// ref: https://github.com/aspnet/Identity/blob/dev/src/Microsoft.AspNetCore.Identity/IdentityServiceCollectionExtensions.cs -public static class CustomIdentityServiceCollectionExtensions +namespace Microsoft.Extensions.DependencyInjection { - public static IdentityBuilder AddIdentityWithoutCookieAuth( - this IServiceCollection services) - where TUser : class - where TRole : class + // ref: https://github.com/aspnet/Identity/blob/dev/src/Microsoft.AspNetCore.Identity/IdentityServiceCollectionExtensions.cs + public static class CustomIdentityServiceCollectionExtensions { - return services.AddIdentityWithoutCookieAuth(setupAction: null); - } - - public static IdentityBuilder AddIdentityWithoutCookieAuth( - this IServiceCollection services, - Action setupAction) - where TUser : class - where TRole : class - { - // Hosting doesn't add IHttpContextAccessor by default - services.AddHttpContextAccessor(); - // Identity services - services.TryAddScoped, UserValidator>(); - services.TryAddScoped, PasswordValidator>(); - services.TryAddScoped, PasswordHasher>(); - services.TryAddScoped(); - services.TryAddScoped, RoleValidator>(); - // No interface for the error describer so we can add errors without rev'ing the interface - services.TryAddScoped(); - services.TryAddScoped>(); - services.TryAddScoped>(); - services.TryAddScoped, UserClaimsPrincipalFactory>(); - services.TryAddScoped, DefaultUserConfirmation>(); - services.TryAddScoped>(); - services.TryAddScoped>(); - services.TryAddScoped>(); - - if (setupAction != null) + public static IdentityBuilder AddIdentityWithoutCookieAuth( + this IServiceCollection services) + where TUser : class + where TRole : class { - services.Configure(setupAction); + return services.AddIdentityWithoutCookieAuth(setupAction: null); } - return new IdentityBuilder(typeof(TUser), typeof(TRole), services); + public static IdentityBuilder AddIdentityWithoutCookieAuth( + this IServiceCollection services, + Action setupAction) + where TUser : class + where TRole : class + { + // Hosting doesn't add IHttpContextAccessor by default + services.AddHttpContextAccessor(); + // Identity services + services.TryAddScoped, UserValidator>(); + services.TryAddScoped, PasswordValidator>(); + services.TryAddScoped, PasswordHasher>(); + services.TryAddScoped(); + services.TryAddScoped, RoleValidator>(); + // No interface for the error describer so we can add errors without rev'ing the interface + services.TryAddScoped(); + services.TryAddScoped>(); + services.TryAddScoped>(); + services.TryAddScoped, UserClaimsPrincipalFactory>(); + services.TryAddScoped, DefaultUserConfirmation>(); + services.TryAddScoped>(); + services.TryAddScoped>(); + services.TryAddScoped>(); + + if (setupAction != null) + { + services.Configure(setupAction); + } + + return new IdentityBuilder(typeof(TUser), typeof(TRole), services); + } } } diff --git a/src/Core/Identity/DuoWebTokenProvider.cs b/src/Core/Identity/DuoWebTokenProvider.cs index 396f3b4005..3ef02df6f6 100644 --- a/src/Core/Identity/DuoWebTokenProvider.cs +++ b/src/Core/Identity/DuoWebTokenProvider.cs @@ -7,80 +7,81 @@ using Bit.Core.Utilities.Duo; using Microsoft.AspNetCore.Identity; using Microsoft.Extensions.DependencyInjection; -namespace Bit.Core.Identity; - -public class DuoWebTokenProvider : IUserTwoFactorTokenProvider +namespace Bit.Core.Identity { - private readonly IServiceProvider _serviceProvider; - private readonly GlobalSettings _globalSettings; - - public DuoWebTokenProvider( - IServiceProvider serviceProvider, - GlobalSettings globalSettings) + public class DuoWebTokenProvider : IUserTwoFactorTokenProvider { - _serviceProvider = serviceProvider; - _globalSettings = globalSettings; - } + private readonly IServiceProvider _serviceProvider; + private readonly GlobalSettings _globalSettings; - public async Task CanGenerateTwoFactorTokenAsync(UserManager manager, User user) - { - var userService = _serviceProvider.GetRequiredService(); - if (!(await userService.CanAccessPremium(user))) + public DuoWebTokenProvider( + IServiceProvider serviceProvider, + GlobalSettings globalSettings) { - return false; + _serviceProvider = serviceProvider; + _globalSettings = globalSettings; } - var provider = user.GetTwoFactorProvider(TwoFactorProviderType.Duo); - if (!HasProperMetaData(provider)) + public async Task CanGenerateTwoFactorTokenAsync(UserManager manager, User user) { - return false; + var userService = _serviceProvider.GetRequiredService(); + if (!(await userService.CanAccessPremium(user))) + { + return false; + } + + var provider = user.GetTwoFactorProvider(TwoFactorProviderType.Duo); + if (!HasProperMetaData(provider)) + { + return false; + } + + return await userService.TwoFactorProviderIsEnabledAsync(TwoFactorProviderType.Duo, user); } - return await userService.TwoFactorProviderIsEnabledAsync(TwoFactorProviderType.Duo, user); - } - - public async Task GenerateAsync(string purpose, UserManager manager, User user) - { - var userService = _serviceProvider.GetRequiredService(); - if (!(await userService.CanAccessPremium(user))) + public async Task GenerateAsync(string purpose, UserManager manager, User user) { - return null; + var userService = _serviceProvider.GetRequiredService(); + if (!(await userService.CanAccessPremium(user))) + { + return null; + } + + var provider = user.GetTwoFactorProvider(TwoFactorProviderType.Duo); + if (!HasProperMetaData(provider)) + { + return null; + } + + var signatureRequest = DuoWeb.SignRequest((string)provider.MetaData["IKey"], + (string)provider.MetaData["SKey"], _globalSettings.Duo.AKey, user.Email); + return signatureRequest; } - var provider = user.GetTwoFactorProvider(TwoFactorProviderType.Duo); - if (!HasProperMetaData(provider)) + public async Task ValidateAsync(string purpose, string token, UserManager manager, User user) { - return null; + var userService = _serviceProvider.GetRequiredService(); + if (!(await userService.CanAccessPremium(user))) + { + return false; + } + + var provider = user.GetTwoFactorProvider(TwoFactorProviderType.Duo); + if (!HasProperMetaData(provider)) + { + return false; + } + + var response = DuoWeb.VerifyResponse((string)provider.MetaData["IKey"], (string)provider.MetaData["SKey"], + _globalSettings.Duo.AKey, token); + + return response == user.Email; } - var signatureRequest = DuoWeb.SignRequest((string)provider.MetaData["IKey"], - (string)provider.MetaData["SKey"], _globalSettings.Duo.AKey, user.Email); - return signatureRequest; - } - - public async Task ValidateAsync(string purpose, string token, UserManager manager, User user) - { - var userService = _serviceProvider.GetRequiredService(); - if (!(await userService.CanAccessPremium(user))) + private bool HasProperMetaData(TwoFactorProvider provider) { - return false; + return provider?.MetaData != null && provider.MetaData.ContainsKey("IKey") && + provider.MetaData.ContainsKey("SKey") && provider.MetaData.ContainsKey("Host"); } - - var provider = user.GetTwoFactorProvider(TwoFactorProviderType.Duo); - if (!HasProperMetaData(provider)) - { - return false; - } - - var response = DuoWeb.VerifyResponse((string)provider.MetaData["IKey"], (string)provider.MetaData["SKey"], - _globalSettings.Duo.AKey, token); - - return response == user.Email; - } - - private bool HasProperMetaData(TwoFactorProvider provider) - { - return provider?.MetaData != null && provider.MetaData.ContainsKey("IKey") && - provider.MetaData.ContainsKey("SKey") && provider.MetaData.ContainsKey("Host"); } } diff --git a/src/Core/Identity/EmailTokenProvider.cs b/src/Core/Identity/EmailTokenProvider.cs index 71987fa86d..a0002b47fa 100644 --- a/src/Core/Identity/EmailTokenProvider.cs +++ b/src/Core/Identity/EmailTokenProvider.cs @@ -5,79 +5,80 @@ using Bit.Core.Services; using Microsoft.AspNetCore.Identity; using Microsoft.Extensions.DependencyInjection; -namespace Bit.Core.Identity; - -public class EmailTokenProvider : IUserTwoFactorTokenProvider +namespace Bit.Core.Identity { - private readonly IServiceProvider _serviceProvider; - - public EmailTokenProvider(IServiceProvider serviceProvider) + public class EmailTokenProvider : IUserTwoFactorTokenProvider { - _serviceProvider = serviceProvider; - } + private readonly IServiceProvider _serviceProvider; - public async Task CanGenerateTwoFactorTokenAsync(UserManager manager, User user) - { - var provider = user.GetTwoFactorProvider(TwoFactorProviderType.Email); - if (!HasProperMetaData(provider)) + public EmailTokenProvider(IServiceProvider serviceProvider) { - return false; + _serviceProvider = serviceProvider; } - return await _serviceProvider.GetRequiredService(). - TwoFactorProviderIsEnabledAsync(TwoFactorProviderType.Email, user); - } - - public Task GenerateAsync(string purpose, UserManager manager, User user) - { - var provider = user.GetTwoFactorProvider(TwoFactorProviderType.Email); - if (!HasProperMetaData(provider)) + public async Task CanGenerateTwoFactorTokenAsync(UserManager manager, User user) { - return null; + var provider = user.GetTwoFactorProvider(TwoFactorProviderType.Email); + if (!HasProperMetaData(provider)) + { + return false; + } + + return await _serviceProvider.GetRequiredService(). + TwoFactorProviderIsEnabledAsync(TwoFactorProviderType.Email, user); } - return Task.FromResult(RedactEmail((string)provider.MetaData["Email"])); - } - - public Task ValidateAsync(string purpose, string token, UserManager manager, User user) - { - return _serviceProvider.GetRequiredService().VerifyTwoFactorEmailAsync(user, token); - } - - private bool HasProperMetaData(TwoFactorProvider provider) - { - return provider?.MetaData != null && provider.MetaData.ContainsKey("Email") && - !string.IsNullOrWhiteSpace((string)provider.MetaData["Email"]); - } - - private static string RedactEmail(string email) - { - var emailParts = email.Split('@'); - - string shownPart = null; - if (emailParts[0].Length > 2 && emailParts[0].Length <= 4) + public Task GenerateAsync(string purpose, UserManager manager, User user) { - shownPart = emailParts[0].Substring(0, 1); - } - else if (emailParts[0].Length > 4) - { - shownPart = emailParts[0].Substring(0, 2); - } - else - { - shownPart = string.Empty; + var provider = user.GetTwoFactorProvider(TwoFactorProviderType.Email); + if (!HasProperMetaData(provider)) + { + return null; + } + + return Task.FromResult(RedactEmail((string)provider.MetaData["Email"])); } - string redactedPart = null; - if (emailParts[0].Length > 4) + public Task ValidateAsync(string purpose, string token, UserManager manager, User user) { - redactedPart = new string('*', emailParts[0].Length - 2); - } - else - { - redactedPart = new string('*', emailParts[0].Length - shownPart.Length); + return _serviceProvider.GetRequiredService().VerifyTwoFactorEmailAsync(user, token); } - return $"{shownPart}{redactedPart}@{emailParts[1]}"; + private bool HasProperMetaData(TwoFactorProvider provider) + { + return provider?.MetaData != null && provider.MetaData.ContainsKey("Email") && + !string.IsNullOrWhiteSpace((string)provider.MetaData["Email"]); + } + + private static string RedactEmail(string email) + { + var emailParts = email.Split('@'); + + string shownPart = null; + if (emailParts[0].Length > 2 && emailParts[0].Length <= 4) + { + shownPart = emailParts[0].Substring(0, 1); + } + else if (emailParts[0].Length > 4) + { + shownPart = emailParts[0].Substring(0, 2); + } + else + { + shownPart = string.Empty; + } + + string redactedPart = null; + if (emailParts[0].Length > 4) + { + redactedPart = new string('*', emailParts[0].Length - 2); + } + else + { + redactedPart = new string('*', emailParts[0].Length - shownPart.Length); + } + + return $"{shownPart}{redactedPart}@{emailParts[1]}"; + } } } diff --git a/src/Core/Identity/IOrganizationTwoFactorTokenProvider.cs b/src/Core/Identity/IOrganizationTwoFactorTokenProvider.cs index 0046add96d..11157a7830 100644 --- a/src/Core/Identity/IOrganizationTwoFactorTokenProvider.cs +++ b/src/Core/Identity/IOrganizationTwoFactorTokenProvider.cs @@ -1,10 +1,11 @@ using Bit.Core.Entities; -namespace Bit.Core.Identity; - -public interface IOrganizationTwoFactorTokenProvider +namespace Bit.Core.Identity { - Task CanGenerateTwoFactorTokenAsync(Organization organization); - Task GenerateAsync(Organization organization, User user); - Task ValidateAsync(string token, Organization organization, User user); + public interface IOrganizationTwoFactorTokenProvider + { + Task CanGenerateTwoFactorTokenAsync(Organization organization); + Task GenerateAsync(Organization organization, User user); + Task ValidateAsync(string token, Organization organization, User user); + } } diff --git a/src/Core/Identity/LowerInvariantLookupNormalizer.cs b/src/Core/Identity/LowerInvariantLookupNormalizer.cs index 880a2bbfbf..591b840a47 100644 --- a/src/Core/Identity/LowerInvariantLookupNormalizer.cs +++ b/src/Core/Identity/LowerInvariantLookupNormalizer.cs @@ -1,21 +1,22 @@ using Microsoft.AspNetCore.Identity; -namespace Bit.Core.Identity; - -public class LowerInvariantLookupNormalizer : ILookupNormalizer +namespace Bit.Core.Identity { - public string NormalizeEmail(string email) + public class LowerInvariantLookupNormalizer : ILookupNormalizer { - return Normalize(email); - } + public string NormalizeEmail(string email) + { + return Normalize(email); + } - public string NormalizeName(string name) - { - return Normalize(name); - } + public string NormalizeName(string name) + { + return Normalize(name); + } - private string Normalize(string key) - { - return key?.Normalize().ToLowerInvariant(); + private string Normalize(string key) + { + return key?.Normalize().ToLowerInvariant(); + } } } diff --git a/src/Core/Identity/OrganizationDuoWebTokenProvider.cs b/src/Core/Identity/OrganizationDuoWebTokenProvider.cs index 53d979d907..cd3f271845 100644 --- a/src/Core/Identity/OrganizationDuoWebTokenProvider.cs +++ b/src/Core/Identity/OrganizationDuoWebTokenProvider.cs @@ -4,72 +4,73 @@ using Bit.Core.Models; using Bit.Core.Settings; using Bit.Core.Utilities.Duo; -namespace Bit.Core.Identity; - -public interface IOrganizationDuoWebTokenProvider : IOrganizationTwoFactorTokenProvider { } - -public class OrganizationDuoWebTokenProvider : IOrganizationDuoWebTokenProvider +namespace Bit.Core.Identity { - private readonly GlobalSettings _globalSettings; + public interface IOrganizationDuoWebTokenProvider : IOrganizationTwoFactorTokenProvider { } - public OrganizationDuoWebTokenProvider(GlobalSettings globalSettings) + public class OrganizationDuoWebTokenProvider : IOrganizationDuoWebTokenProvider { - _globalSettings = globalSettings; - } + private readonly GlobalSettings _globalSettings; - public Task CanGenerateTwoFactorTokenAsync(Organization organization) - { - if (organization == null || !organization.Enabled || !organization.Use2fa) + public OrganizationDuoWebTokenProvider(GlobalSettings globalSettings) { - return Task.FromResult(false); + _globalSettings = globalSettings; } - var provider = organization.GetTwoFactorProvider(TwoFactorProviderType.OrganizationDuo); - var canGenerate = organization.TwoFactorProviderIsEnabled(TwoFactorProviderType.OrganizationDuo) - && HasProperMetaData(provider); - return Task.FromResult(canGenerate); - } - - public Task GenerateAsync(Organization organization, User user) - { - if (organization == null || !organization.Enabled || !organization.Use2fa) + public Task CanGenerateTwoFactorTokenAsync(Organization organization) { - return Task.FromResult(null); + if (organization == null || !organization.Enabled || !organization.Use2fa) + { + return Task.FromResult(false); + } + + var provider = organization.GetTwoFactorProvider(TwoFactorProviderType.OrganizationDuo); + var canGenerate = organization.TwoFactorProviderIsEnabled(TwoFactorProviderType.OrganizationDuo) + && HasProperMetaData(provider); + return Task.FromResult(canGenerate); } - var provider = organization.GetTwoFactorProvider(TwoFactorProviderType.OrganizationDuo); - if (!HasProperMetaData(provider)) + public Task GenerateAsync(Organization organization, User user) { - return Task.FromResult(null); + if (organization == null || !organization.Enabled || !organization.Use2fa) + { + return Task.FromResult(null); + } + + var provider = organization.GetTwoFactorProvider(TwoFactorProviderType.OrganizationDuo); + if (!HasProperMetaData(provider)) + { + return Task.FromResult(null); + } + + var signatureRequest = DuoWeb.SignRequest(provider.MetaData["IKey"].ToString(), + provider.MetaData["SKey"].ToString(), _globalSettings.Duo.AKey, user.Email); + return Task.FromResult(signatureRequest); } - var signatureRequest = DuoWeb.SignRequest(provider.MetaData["IKey"].ToString(), - provider.MetaData["SKey"].ToString(), _globalSettings.Duo.AKey, user.Email); - return Task.FromResult(signatureRequest); - } - - public Task ValidateAsync(string token, Organization organization, User user) - { - if (organization == null || !organization.Enabled || !organization.Use2fa) + public Task ValidateAsync(string token, Organization organization, User user) { - return Task.FromResult(false); + if (organization == null || !organization.Enabled || !organization.Use2fa) + { + return Task.FromResult(false); + } + + var provider = organization.GetTwoFactorProvider(TwoFactorProviderType.OrganizationDuo); + if (!HasProperMetaData(provider)) + { + return Task.FromResult(false); + } + + var response = DuoWeb.VerifyResponse(provider.MetaData["IKey"].ToString(), + provider.MetaData["SKey"].ToString(), _globalSettings.Duo.AKey, token); + + return Task.FromResult(response == user.Email); } - var provider = organization.GetTwoFactorProvider(TwoFactorProviderType.OrganizationDuo); - if (!HasProperMetaData(provider)) + private bool HasProperMetaData(TwoFactorProvider provider) { - return Task.FromResult(false); + return provider?.MetaData != null && provider.MetaData.ContainsKey("IKey") && + provider.MetaData.ContainsKey("SKey") && provider.MetaData.ContainsKey("Host"); } - - var response = DuoWeb.VerifyResponse(provider.MetaData["IKey"].ToString(), - provider.MetaData["SKey"].ToString(), _globalSettings.Duo.AKey, token); - - return Task.FromResult(response == user.Email); - } - - private bool HasProperMetaData(TwoFactorProvider provider) - { - return provider?.MetaData != null && provider.MetaData.ContainsKey("IKey") && - provider.MetaData.ContainsKey("SKey") && provider.MetaData.ContainsKey("Host"); } } diff --git a/src/Core/Identity/PasswordlessSignInManager.cs b/src/Core/Identity/PasswordlessSignInManager.cs index 1ca010835b..a9f400058a 100644 --- a/src/Core/Identity/PasswordlessSignInManager.cs +++ b/src/Core/Identity/PasswordlessSignInManager.cs @@ -5,85 +5,86 @@ using Microsoft.AspNetCore.Identity; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; -namespace Bit.Core.Identity; - -public class PasswordlessSignInManager : SignInManager where TUser : class +namespace Bit.Core.Identity { - public const string PasswordlessSignInPurpose = "PasswordlessSignIn"; - - private readonly IMailService _mailService; - - public PasswordlessSignInManager(UserManager userManager, - IHttpContextAccessor contextAccessor, - IUserClaimsPrincipalFactory claimsFactory, - IOptions optionsAccessor, - ILogger> logger, - IAuthenticationSchemeProvider schemes, - IUserConfirmation confirmation, - IMailService mailService) - : base(userManager, contextAccessor, claimsFactory, optionsAccessor, logger, schemes, confirmation) + public class PasswordlessSignInManager : SignInManager where TUser : class { - _mailService = mailService; - } + public const string PasswordlessSignInPurpose = "PasswordlessSignIn"; - public async Task PasswordlessSignInAsync(string email, string returnUrl) - { - var user = await UserManager.FindByEmailAsync(email); - if (user == null) + private readonly IMailService _mailService; + + public PasswordlessSignInManager(UserManager userManager, + IHttpContextAccessor contextAccessor, + IUserClaimsPrincipalFactory claimsFactory, + IOptions optionsAccessor, + ILogger> logger, + IAuthenticationSchemeProvider schemes, + IUserConfirmation confirmation, + IMailService mailService) + : base(userManager, contextAccessor, claimsFactory, optionsAccessor, logger, schemes, confirmation) { - return SignInResult.Failed; + _mailService = mailService; } - var token = await UserManager.GenerateUserTokenAsync(user, Options.Tokens.PasswordResetTokenProvider, - PasswordlessSignInPurpose); - await _mailService.SendPasswordlessSignInAsync(returnUrl, token, email); - return SignInResult.Success; - } - - public async Task PasswordlessSignInAsync(TUser user, string token, bool isPersistent) - { - if (user == null) + public async Task PasswordlessSignInAsync(string email, string returnUrl) { - throw new ArgumentNullException(nameof(user)); - } + var user = await UserManager.FindByEmailAsync(email); + if (user == null) + { + return SignInResult.Failed; + } - var attempt = await CheckPasswordlessSignInAsync(user, token); - return attempt.Succeeded ? - await SignInOrTwoFactorAsync(user, isPersistent, bypassTwoFactor: true) : attempt; - } - - public async Task PasswordlessSignInAsync(string email, string token, bool isPersistent) - { - var user = await UserManager.FindByEmailAsync(email); - if (user == null) - { - return SignInResult.Failed; - } - - return await PasswordlessSignInAsync(user, token, isPersistent); - } - - public virtual async Task CheckPasswordlessSignInAsync(TUser user, string token) - { - if (user == null) - { - throw new ArgumentNullException(nameof(user)); - } - - var error = await PreSignInCheck(user); - if (error != null) - { - return error; - } - - if (await UserManager.VerifyUserTokenAsync(user, Options.Tokens.PasswordResetTokenProvider, - PasswordlessSignInPurpose, token)) - { + var token = await UserManager.GenerateUserTokenAsync(user, Options.Tokens.PasswordResetTokenProvider, + PasswordlessSignInPurpose); + await _mailService.SendPasswordlessSignInAsync(returnUrl, token, email); return SignInResult.Success; } - Logger.LogWarning(2, "User {userId} failed to provide the correct token.", - await UserManager.GetUserIdAsync(user)); - return SignInResult.Failed; + public async Task PasswordlessSignInAsync(TUser user, string token, bool isPersistent) + { + if (user == null) + { + throw new ArgumentNullException(nameof(user)); + } + + var attempt = await CheckPasswordlessSignInAsync(user, token); + return attempt.Succeeded ? + await SignInOrTwoFactorAsync(user, isPersistent, bypassTwoFactor: true) : attempt; + } + + public async Task PasswordlessSignInAsync(string email, string token, bool isPersistent) + { + var user = await UserManager.FindByEmailAsync(email); + if (user == null) + { + return SignInResult.Failed; + } + + return await PasswordlessSignInAsync(user, token, isPersistent); + } + + public virtual async Task CheckPasswordlessSignInAsync(TUser user, string token) + { + if (user == null) + { + throw new ArgumentNullException(nameof(user)); + } + + var error = await PreSignInCheck(user); + if (error != null) + { + return error; + } + + if (await UserManager.VerifyUserTokenAsync(user, Options.Tokens.PasswordResetTokenProvider, + PasswordlessSignInPurpose, token)) + { + return SignInResult.Success; + } + + Logger.LogWarning(2, "User {userId} failed to provide the correct token.", + await UserManager.GetUserIdAsync(user)); + return SignInResult.Failed; + } } } diff --git a/src/Core/Identity/ReadOnlyDatabaseIdentityUserStore.cs b/src/Core/Identity/ReadOnlyDatabaseIdentityUserStore.cs index 70d3da0072..7f4b76755b 100644 --- a/src/Core/Identity/ReadOnlyDatabaseIdentityUserStore.cs +++ b/src/Core/Identity/ReadOnlyDatabaseIdentityUserStore.cs @@ -2,37 +2,38 @@ using Bit.Core.Services; using Microsoft.AspNetCore.Identity; -namespace Bit.Core.Identity; - -public class ReadOnlyDatabaseIdentityUserStore : ReadOnlyIdentityUserStore +namespace Bit.Core.Identity { - private readonly IUserService _userService; - private readonly IUserRepository _userRepository; - - public ReadOnlyDatabaseIdentityUserStore( - IUserService userService, - IUserRepository userRepository) + public class ReadOnlyDatabaseIdentityUserStore : ReadOnlyIdentityUserStore { - _userService = userService; - _userRepository = userRepository; - } + private readonly IUserService _userService; + private readonly IUserRepository _userRepository; - public override async Task FindByEmailAsync(string normalizedEmail, - CancellationToken cancellationToken = default(CancellationToken)) - { - var user = await _userRepository.GetByEmailAsync(normalizedEmail); - return user?.ToIdentityUser(await _userService.TwoFactorIsEnabledAsync(user)); - } - - public override async Task FindByIdAsync(string userId, - CancellationToken cancellationToken = default(CancellationToken)) - { - if (!Guid.TryParse(userId, out var userIdGuid)) + public ReadOnlyDatabaseIdentityUserStore( + IUserService userService, + IUserRepository userRepository) { - return null; + _userService = userService; + _userRepository = userRepository; } - var user = await _userRepository.GetByIdAsync(userIdGuid); - return user?.ToIdentityUser(await _userService.TwoFactorIsEnabledAsync(user)); + public override async Task FindByEmailAsync(string normalizedEmail, + CancellationToken cancellationToken = default(CancellationToken)) + { + var user = await _userRepository.GetByEmailAsync(normalizedEmail); + return user?.ToIdentityUser(await _userService.TwoFactorIsEnabledAsync(user)); + } + + public override async Task FindByIdAsync(string userId, + CancellationToken cancellationToken = default(CancellationToken)) + { + if (!Guid.TryParse(userId, out var userIdGuid)) + { + return null; + } + + var user = await _userRepository.GetByIdAsync(userIdGuid); + return user?.ToIdentityUser(await _userService.TwoFactorIsEnabledAsync(user)); + } } } diff --git a/src/Core/Identity/ReadOnlyEnvIdentityUserStore.cs b/src/Core/Identity/ReadOnlyEnvIdentityUserStore.cs index 341bcd38a9..26cc7a3c86 100644 --- a/src/Core/Identity/ReadOnlyEnvIdentityUserStore.cs +++ b/src/Core/Identity/ReadOnlyEnvIdentityUserStore.cs @@ -2,65 +2,66 @@ using Microsoft.AspNetCore.Identity; using Microsoft.Extensions.Configuration; -namespace Bit.Core.Identity; - -public class ReadOnlyEnvIdentityUserStore : ReadOnlyIdentityUserStore +namespace Bit.Core.Identity { - private readonly IConfiguration _configuration; - - public ReadOnlyEnvIdentityUserStore(IConfiguration configuration) + public class ReadOnlyEnvIdentityUserStore : ReadOnlyIdentityUserStore { - _configuration = configuration; - } + private readonly IConfiguration _configuration; - public override Task FindByEmailAsync(string normalizedEmail, - CancellationToken cancellationToken = default(CancellationToken)) - { - var usersCsv = _configuration["adminSettings:admins"]; - if (!CoreHelpers.SettingHasValue(usersCsv)) + public ReadOnlyEnvIdentityUserStore(IConfiguration configuration) { - return Task.FromResult(null); + _configuration = configuration; } - var users = usersCsv.ToLowerInvariant().Split(','); - var usersDict = new Dictionary(); - foreach (var u in users) + public override Task FindByEmailAsync(string normalizedEmail, + CancellationToken cancellationToken = default(CancellationToken)) { - var parts = u.Split(':'); - if (parts.Length == 2) + var usersCsv = _configuration["adminSettings:admins"]; + if (!CoreHelpers.SettingHasValue(usersCsv)) { - var email = parts[0].Trim(); - var stamp = parts[1].Trim(); - usersDict.Add(email, stamp); + return Task.FromResult(null); } - else + + var users = usersCsv.ToLowerInvariant().Split(','); + var usersDict = new Dictionary(); + foreach (var u in users) { - var email = parts[0].Trim(); - usersDict.Add(email, email); + var parts = u.Split(':'); + if (parts.Length == 2) + { + var email = parts[0].Trim(); + var stamp = parts[1].Trim(); + usersDict.Add(email, stamp); + } + else + { + var email = parts[0].Trim(); + usersDict.Add(email, email); + } } + + var userStamp = usersDict.ContainsKey(normalizedEmail) ? usersDict[normalizedEmail] : null; + if (userStamp == null) + { + return Task.FromResult(null); + } + + return Task.FromResult(new IdentityUser + { + Id = normalizedEmail, + Email = normalizedEmail, + NormalizedEmail = normalizedEmail, + EmailConfirmed = true, + UserName = normalizedEmail, + NormalizedUserName = normalizedEmail, + SecurityStamp = userStamp + }); } - var userStamp = usersDict.ContainsKey(normalizedEmail) ? usersDict[normalizedEmail] : null; - if (userStamp == null) + public override Task FindByIdAsync(string userId, + CancellationToken cancellationToken = default(CancellationToken)) { - return Task.FromResult(null); + return FindByEmailAsync(userId, cancellationToken); } - - return Task.FromResult(new IdentityUser - { - Id = normalizedEmail, - Email = normalizedEmail, - NormalizedEmail = normalizedEmail, - EmailConfirmed = true, - UserName = normalizedEmail, - NormalizedUserName = normalizedEmail, - SecurityStamp = userStamp - }); - } - - public override Task FindByIdAsync(string userId, - CancellationToken cancellationToken = default(CancellationToken)) - { - return FindByEmailAsync(userId, cancellationToken); } } diff --git a/src/Core/Identity/ReadOnlyIdentityUserStore.cs b/src/Core/Identity/ReadOnlyIdentityUserStore.cs index 50c42c8197..d27b0a32fa 100644 --- a/src/Core/Identity/ReadOnlyIdentityUserStore.cs +++ b/src/Core/Identity/ReadOnlyIdentityUserStore.cs @@ -1,119 +1,120 @@ using Microsoft.AspNetCore.Identity; -namespace Bit.Core.Identity; - -public abstract class ReadOnlyIdentityUserStore : - IUserStore, - IUserEmailStore, - IUserSecurityStampStore +namespace Bit.Core.Identity { - public void Dispose() { } - - public Task CreateAsync(IdentityUser user, - CancellationToken cancellationToken = default(CancellationToken)) + public abstract class ReadOnlyIdentityUserStore : + IUserStore, + IUserEmailStore, + IUserSecurityStampStore { - throw new NotImplementedException(); - } + public void Dispose() { } - public Task DeleteAsync(IdentityUser user, - CancellationToken cancellationToken = default(CancellationToken)) - { - throw new NotImplementedException(); - } + public Task CreateAsync(IdentityUser user, + CancellationToken cancellationToken = default(CancellationToken)) + { + throw new NotImplementedException(); + } - public abstract Task FindByEmailAsync(string normalizedEmail, - CancellationToken cancellationToken = default(CancellationToken)); + public Task DeleteAsync(IdentityUser user, + CancellationToken cancellationToken = default(CancellationToken)) + { + throw new NotImplementedException(); + } - public abstract Task FindByIdAsync(string userId, - CancellationToken cancellationToken = default(CancellationToken)); + public abstract Task FindByEmailAsync(string normalizedEmail, + CancellationToken cancellationToken = default(CancellationToken)); - public async Task FindByNameAsync(string normalizedUserName, - CancellationToken cancellationToken = default(CancellationToken)) - { - return await FindByEmailAsync(normalizedUserName, cancellationToken); - } + public abstract Task FindByIdAsync(string userId, + CancellationToken cancellationToken = default(CancellationToken)); - public Task GetEmailAsync(IdentityUser user, - CancellationToken cancellationToken = default(CancellationToken)) - { - return Task.FromResult(user.Email); - } + public async Task FindByNameAsync(string normalizedUserName, + CancellationToken cancellationToken = default(CancellationToken)) + { + return await FindByEmailAsync(normalizedUserName, cancellationToken); + } - public Task GetEmailConfirmedAsync(IdentityUser user, - CancellationToken cancellationToken = default(CancellationToken)) - { - return Task.FromResult(user.EmailConfirmed); - } + public Task GetEmailAsync(IdentityUser user, + CancellationToken cancellationToken = default(CancellationToken)) + { + return Task.FromResult(user.Email); + } - public Task GetNormalizedEmailAsync(IdentityUser user, - CancellationToken cancellationToken = default(CancellationToken)) - { - return Task.FromResult(user.Email); - } + public Task GetEmailConfirmedAsync(IdentityUser user, + CancellationToken cancellationToken = default(CancellationToken)) + { + return Task.FromResult(user.EmailConfirmed); + } - public Task GetNormalizedUserNameAsync(IdentityUser user, - CancellationToken cancellationToken = default(CancellationToken)) - { - return Task.FromResult(user.Email); - } + public Task GetNormalizedEmailAsync(IdentityUser user, + CancellationToken cancellationToken = default(CancellationToken)) + { + return Task.FromResult(user.Email); + } - public Task GetUserIdAsync(IdentityUser user, - CancellationToken cancellationToken = default(CancellationToken)) - { - return Task.FromResult(user.Id); - } + public Task GetNormalizedUserNameAsync(IdentityUser user, + CancellationToken cancellationToken = default(CancellationToken)) + { + return Task.FromResult(user.Email); + } - public Task GetUserNameAsync(IdentityUser user, - CancellationToken cancellationToken = default(CancellationToken)) - { - return Task.FromResult(user.Email); - } + public Task GetUserIdAsync(IdentityUser user, + CancellationToken cancellationToken = default(CancellationToken)) + { + return Task.FromResult(user.Id); + } - public Task SetEmailAsync(IdentityUser user, string email, - CancellationToken cancellationToken = default(CancellationToken)) - { - throw new NotImplementedException(); - } + public Task GetUserNameAsync(IdentityUser user, + CancellationToken cancellationToken = default(CancellationToken)) + { + return Task.FromResult(user.Email); + } - public Task SetEmailConfirmedAsync(IdentityUser user, bool confirmed, - CancellationToken cancellationToken = default(CancellationToken)) - { - throw new NotImplementedException(); - } + public Task SetEmailAsync(IdentityUser user, string email, + CancellationToken cancellationToken = default(CancellationToken)) + { + throw new NotImplementedException(); + } - public Task SetNormalizedEmailAsync(IdentityUser user, string normalizedEmail, - CancellationToken cancellationToken = default(CancellationToken)) - { - user.NormalizedEmail = normalizedEmail; - return Task.FromResult(0); - } + public Task SetEmailConfirmedAsync(IdentityUser user, bool confirmed, + CancellationToken cancellationToken = default(CancellationToken)) + { + throw new NotImplementedException(); + } - public Task SetNormalizedUserNameAsync(IdentityUser user, string normalizedName, - CancellationToken cancellationToken = default(CancellationToken)) - { - user.NormalizedUserName = normalizedName; - return Task.FromResult(0); - } + public Task SetNormalizedEmailAsync(IdentityUser user, string normalizedEmail, + CancellationToken cancellationToken = default(CancellationToken)) + { + user.NormalizedEmail = normalizedEmail; + return Task.FromResult(0); + } - public Task SetUserNameAsync(IdentityUser user, string userName, - CancellationToken cancellationToken = default(CancellationToken)) - { - throw new NotImplementedException(); - } + public Task SetNormalizedUserNameAsync(IdentityUser user, string normalizedName, + CancellationToken cancellationToken = default(CancellationToken)) + { + user.NormalizedUserName = normalizedName; + return Task.FromResult(0); + } - public Task UpdateAsync(IdentityUser user, - CancellationToken cancellationToken = default(CancellationToken)) - { - return Task.FromResult(IdentityResult.Success); - } + public Task SetUserNameAsync(IdentityUser user, string userName, + CancellationToken cancellationToken = default(CancellationToken)) + { + throw new NotImplementedException(); + } - public Task SetSecurityStampAsync(IdentityUser user, string stamp, CancellationToken cancellationToken) - { - throw new NotImplementedException(); - } + public Task UpdateAsync(IdentityUser user, + CancellationToken cancellationToken = default(CancellationToken)) + { + return Task.FromResult(IdentityResult.Success); + } - public Task GetSecurityStampAsync(IdentityUser user, CancellationToken cancellationToken) - { - return Task.FromResult(user.SecurityStamp); + public Task SetSecurityStampAsync(IdentityUser user, string stamp, CancellationToken cancellationToken) + { + throw new NotImplementedException(); + } + + public Task GetSecurityStampAsync(IdentityUser user, CancellationToken cancellationToken) + { + return Task.FromResult(user.SecurityStamp); + } } } diff --git a/src/Core/Identity/RoleStore.cs b/src/Core/Identity/RoleStore.cs index d6fe3f42f0..f96748f249 100644 --- a/src/Core/Identity/RoleStore.cs +++ b/src/Core/Identity/RoleStore.cs @@ -1,60 +1,61 @@ using Bit.Core.Entities; using Microsoft.AspNetCore.Identity; -namespace Bit.Core.Identity; - -public class RoleStore : IRoleStore +namespace Bit.Core.Identity { - public void Dispose() { } - - public Task CreateAsync(Role role, CancellationToken cancellationToken) + public class RoleStore : IRoleStore { - throw new NotImplementedException(); - } + public void Dispose() { } - public Task DeleteAsync(Role role, CancellationToken cancellationToken) - { - throw new NotImplementedException(); - } + public Task CreateAsync(Role role, CancellationToken cancellationToken) + { + throw new NotImplementedException(); + } - public Task FindByIdAsync(string roleId, CancellationToken cancellationToken) - { - throw new NotImplementedException(); - } + public Task DeleteAsync(Role role, CancellationToken cancellationToken) + { + throw new NotImplementedException(); + } - public Task FindByNameAsync(string normalizedRoleName, CancellationToken cancellationToken) - { - throw new NotImplementedException(); - } + public Task FindByIdAsync(string roleId, CancellationToken cancellationToken) + { + throw new NotImplementedException(); + } - public Task GetNormalizedRoleNameAsync(Role role, CancellationToken cancellationToken) - { - return Task.FromResult(role.Name); - } + public Task FindByNameAsync(string normalizedRoleName, CancellationToken cancellationToken) + { + throw new NotImplementedException(); + } - public Task GetRoleIdAsync(Role role, CancellationToken cancellationToken) - { - throw new NotImplementedException(); - } + public Task GetNormalizedRoleNameAsync(Role role, CancellationToken cancellationToken) + { + return Task.FromResult(role.Name); + } - public Task GetRoleNameAsync(Role role, CancellationToken cancellationToken) - { - return Task.FromResult(role.Name); - } + public Task GetRoleIdAsync(Role role, CancellationToken cancellationToken) + { + throw new NotImplementedException(); + } - public Task SetNormalizedRoleNameAsync(Role role, string normalizedName, CancellationToken cancellationToken) - { - return Task.FromResult(0); - } + public Task GetRoleNameAsync(Role role, CancellationToken cancellationToken) + { + return Task.FromResult(role.Name); + } - public Task SetRoleNameAsync(Role role, string roleName, CancellationToken cancellationToken) - { - role.Name = roleName; - return Task.FromResult(0); - } + public Task SetNormalizedRoleNameAsync(Role role, string normalizedName, CancellationToken cancellationToken) + { + return Task.FromResult(0); + } - public Task UpdateAsync(Role role, CancellationToken cancellationToken) - { - throw new NotImplementedException(); + public Task SetRoleNameAsync(Role role, string roleName, CancellationToken cancellationToken) + { + role.Name = roleName; + return Task.FromResult(0); + } + + public Task UpdateAsync(Role role, CancellationToken cancellationToken) + { + throw new NotImplementedException(); + } } } diff --git a/src/Core/Identity/TwoFactorRememberTokenProvider.cs b/src/Core/Identity/TwoFactorRememberTokenProvider.cs index 711c8c9331..2902280ff2 100644 --- a/src/Core/Identity/TwoFactorRememberTokenProvider.cs +++ b/src/Core/Identity/TwoFactorRememberTokenProvider.cs @@ -4,17 +4,18 @@ using Microsoft.AspNetCore.Identity; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; -namespace Bit.Core.Identity; - -public class TwoFactorRememberTokenProvider : DataProtectorTokenProvider +namespace Bit.Core.Identity { - public TwoFactorRememberTokenProvider( - IDataProtectionProvider dataProtectionProvider, - IOptions options, - ILogger> logger) - : base(dataProtectionProvider, options, logger) + public class TwoFactorRememberTokenProvider : DataProtectorTokenProvider + { + public TwoFactorRememberTokenProvider( + IDataProtectionProvider dataProtectionProvider, + IOptions options, + ILogger> logger) + : base(dataProtectionProvider, options, logger) + { } + } + + public class TwoFactorRememberTokenProviderOptions : DataProtectionTokenProviderOptions { } } - -public class TwoFactorRememberTokenProviderOptions : DataProtectionTokenProviderOptions -{ } diff --git a/src/Core/Identity/UserStore.cs b/src/Core/Identity/UserStore.cs index afa0656c11..53bd744847 100644 --- a/src/Core/Identity/UserStore.cs +++ b/src/Core/Identity/UserStore.cs @@ -5,179 +5,180 @@ using Bit.Core.Services; using Microsoft.AspNetCore.Identity; using Microsoft.Extensions.DependencyInjection; -namespace Bit.Core.Identity; - -public class UserStore : - IUserStore, - IUserPasswordStore, - IUserEmailStore, - IUserTwoFactorStore, - IUserSecurityStampStore +namespace Bit.Core.Identity { - private readonly IServiceProvider _serviceProvider; - private readonly IUserRepository _userRepository; - private readonly ICurrentContext _currentContext; - - public UserStore( - IServiceProvider serviceProvider, - IUserRepository userRepository, - ICurrentContext currentContext) + public class UserStore : + IUserStore, + IUserPasswordStore, + IUserEmailStore, + IUserTwoFactorStore, + IUserSecurityStampStore { - _serviceProvider = serviceProvider; - _userRepository = userRepository; - _currentContext = currentContext; - } + private readonly IServiceProvider _serviceProvider; + private readonly IUserRepository _userRepository; + private readonly ICurrentContext _currentContext; - public void Dispose() { } - - public async Task CreateAsync(User user, CancellationToken cancellationToken = default(CancellationToken)) - { - await _userRepository.CreateAsync(user); - return IdentityResult.Success; - } - - public async Task DeleteAsync(User user, CancellationToken cancellationToken = default(CancellationToken)) - { - await _userRepository.DeleteAsync(user); - return IdentityResult.Success; - } - - public async Task FindByEmailAsync(string normalizedEmail, CancellationToken cancellationToken = default(CancellationToken)) - { - if (_currentContext?.User != null && _currentContext.User.Email == normalizedEmail) + public UserStore( + IServiceProvider serviceProvider, + IUserRepository userRepository, + ICurrentContext currentContext) { + _serviceProvider = serviceProvider; + _userRepository = userRepository; + _currentContext = currentContext; + } + + public void Dispose() { } + + public async Task CreateAsync(User user, CancellationToken cancellationToken = default(CancellationToken)) + { + await _userRepository.CreateAsync(user); + return IdentityResult.Success; + } + + public async Task DeleteAsync(User user, CancellationToken cancellationToken = default(CancellationToken)) + { + await _userRepository.DeleteAsync(user); + return IdentityResult.Success; + } + + public async Task FindByEmailAsync(string normalizedEmail, CancellationToken cancellationToken = default(CancellationToken)) + { + if (_currentContext?.User != null && _currentContext.User.Email == normalizedEmail) + { + return _currentContext.User; + } + + _currentContext.User = await _userRepository.GetByEmailAsync(normalizedEmail); return _currentContext.User; } - _currentContext.User = await _userRepository.GetByEmailAsync(normalizedEmail); - return _currentContext.User; - } - - public async Task FindByIdAsync(string userId, CancellationToken cancellationToken = default(CancellationToken)) - { - if (_currentContext?.User != null && - string.Equals(_currentContext.User.Id.ToString(), userId, StringComparison.InvariantCultureIgnoreCase)) + public async Task FindByIdAsync(string userId, CancellationToken cancellationToken = default(CancellationToken)) { + if (_currentContext?.User != null && + string.Equals(_currentContext.User.Id.ToString(), userId, StringComparison.InvariantCultureIgnoreCase)) + { + return _currentContext.User; + } + + Guid userIdGuid; + if (!Guid.TryParse(userId, out userIdGuid)) + { + return null; + } + + _currentContext.User = await _userRepository.GetByIdAsync(userIdGuid); return _currentContext.User; } - Guid userIdGuid; - if (!Guid.TryParse(userId, out userIdGuid)) + public async Task FindByNameAsync(string normalizedUserName, CancellationToken cancellationToken = default(CancellationToken)) { - return null; + return await FindByEmailAsync(normalizedUserName, cancellationToken); } - _currentContext.User = await _userRepository.GetByIdAsync(userIdGuid); - return _currentContext.User; - } + public Task GetEmailAsync(User user, CancellationToken cancellationToken = default(CancellationToken)) + { + return Task.FromResult(user.Email); + } - public async Task FindByNameAsync(string normalizedUserName, CancellationToken cancellationToken = default(CancellationToken)) - { - return await FindByEmailAsync(normalizedUserName, cancellationToken); - } + public Task GetEmailConfirmedAsync(User user, CancellationToken cancellationToken = default(CancellationToken)) + { + return Task.FromResult(user.EmailVerified); + } - public Task GetEmailAsync(User user, CancellationToken cancellationToken = default(CancellationToken)) - { - return Task.FromResult(user.Email); - } + public Task GetNormalizedEmailAsync(User user, CancellationToken cancellationToken = default(CancellationToken)) + { + return Task.FromResult(user.Email); + } - public Task GetEmailConfirmedAsync(User user, CancellationToken cancellationToken = default(CancellationToken)) - { - return Task.FromResult(user.EmailVerified); - } + public Task GetNormalizedUserNameAsync(User user, CancellationToken cancellationToken = default(CancellationToken)) + { + return Task.FromResult(user.Email); + } - public Task GetNormalizedEmailAsync(User user, CancellationToken cancellationToken = default(CancellationToken)) - { - return Task.FromResult(user.Email); - } + public Task GetPasswordHashAsync(User user, CancellationToken cancellationToken = default(CancellationToken)) + { + return Task.FromResult(user.MasterPassword); + } - public Task GetNormalizedUserNameAsync(User user, CancellationToken cancellationToken = default(CancellationToken)) - { - return Task.FromResult(user.Email); - } + public Task GetUserIdAsync(User user, CancellationToken cancellationToken = default(CancellationToken)) + { + return Task.FromResult(user.Id.ToString()); + } - public Task GetPasswordHashAsync(User user, CancellationToken cancellationToken = default(CancellationToken)) - { - return Task.FromResult(user.MasterPassword); - } + public Task GetUserNameAsync(User user, CancellationToken cancellationToken = default(CancellationToken)) + { + return Task.FromResult(user.Email); + } - public Task GetUserIdAsync(User user, CancellationToken cancellationToken = default(CancellationToken)) - { - return Task.FromResult(user.Id.ToString()); - } + public Task HasPasswordAsync(User user, CancellationToken cancellationToken = default(CancellationToken)) + { + return Task.FromResult(!string.IsNullOrWhiteSpace(user.MasterPassword)); + } - public Task GetUserNameAsync(User user, CancellationToken cancellationToken = default(CancellationToken)) - { - return Task.FromResult(user.Email); - } + public Task SetEmailAsync(User user, string email, CancellationToken cancellationToken = default(CancellationToken)) + { + user.Email = email; + return Task.FromResult(0); + } - public Task HasPasswordAsync(User user, CancellationToken cancellationToken = default(CancellationToken)) - { - return Task.FromResult(!string.IsNullOrWhiteSpace(user.MasterPassword)); - } + public Task SetEmailConfirmedAsync(User user, bool confirmed, CancellationToken cancellationToken = default(CancellationToken)) + { + user.EmailVerified = confirmed; + return Task.FromResult(0); + } - public Task SetEmailAsync(User user, string email, CancellationToken cancellationToken = default(CancellationToken)) - { - user.Email = email; - return Task.FromResult(0); - } + public Task SetNormalizedEmailAsync(User user, string normalizedEmail, CancellationToken cancellationToken = default(CancellationToken)) + { + user.Email = normalizedEmail; + return Task.FromResult(0); + } - public Task SetEmailConfirmedAsync(User user, bool confirmed, CancellationToken cancellationToken = default(CancellationToken)) - { - user.EmailVerified = confirmed; - return Task.FromResult(0); - } + public Task SetNormalizedUserNameAsync(User user, string normalizedName, CancellationToken cancellationToken = default(CancellationToken)) + { + user.Email = normalizedName; + return Task.FromResult(0); + } - public Task SetNormalizedEmailAsync(User user, string normalizedEmail, CancellationToken cancellationToken = default(CancellationToken)) - { - user.Email = normalizedEmail; - return Task.FromResult(0); - } + public Task SetPasswordHashAsync(User user, string passwordHash, CancellationToken cancellationToken = default(CancellationToken)) + { + user.MasterPassword = passwordHash; + return Task.FromResult(0); + } - public Task SetNormalizedUserNameAsync(User user, string normalizedName, CancellationToken cancellationToken = default(CancellationToken)) - { - user.Email = normalizedName; - return Task.FromResult(0); - } + public Task SetUserNameAsync(User user, string userName, CancellationToken cancellationToken = default(CancellationToken)) + { + user.Email = userName; + return Task.FromResult(0); + } - public Task SetPasswordHashAsync(User user, string passwordHash, CancellationToken cancellationToken = default(CancellationToken)) - { - user.MasterPassword = passwordHash; - return Task.FromResult(0); - } + public async Task UpdateAsync(User user, CancellationToken cancellationToken = default(CancellationToken)) + { + user.RevisionDate = user.AccountRevisionDate = DateTime.UtcNow; + await _userRepository.ReplaceAsync(user); + return IdentityResult.Success; + } - public Task SetUserNameAsync(User user, string userName, CancellationToken cancellationToken = default(CancellationToken)) - { - user.Email = userName; - return Task.FromResult(0); - } + public Task SetTwoFactorEnabledAsync(User user, bool enabled, CancellationToken cancellationToken) + { + // Do nothing... + return Task.FromResult(0); + } - public async Task UpdateAsync(User user, CancellationToken cancellationToken = default(CancellationToken)) - { - user.RevisionDate = user.AccountRevisionDate = DateTime.UtcNow; - await _userRepository.ReplaceAsync(user); - return IdentityResult.Success; - } + public async Task GetTwoFactorEnabledAsync(User user, CancellationToken cancellationToken) + { + return await _serviceProvider.GetRequiredService().TwoFactorIsEnabledAsync(user); + } - public Task SetTwoFactorEnabledAsync(User user, bool enabled, CancellationToken cancellationToken) - { - // Do nothing... - return Task.FromResult(0); - } + public Task SetSecurityStampAsync(User user, string stamp, CancellationToken cancellationToken) + { + user.SecurityStamp = stamp; + return Task.FromResult(0); + } - public async Task GetTwoFactorEnabledAsync(User user, CancellationToken cancellationToken) - { - return await _serviceProvider.GetRequiredService().TwoFactorIsEnabledAsync(user); - } - - public Task SetSecurityStampAsync(User user, string stamp, CancellationToken cancellationToken) - { - user.SecurityStamp = stamp; - return Task.FromResult(0); - } - - public Task GetSecurityStampAsync(User user, CancellationToken cancellationToken) - { - return Task.FromResult(user.SecurityStamp); + public Task GetSecurityStampAsync(User user, CancellationToken cancellationToken) + { + return Task.FromResult(user.SecurityStamp); + } } } diff --git a/src/Core/Identity/WebAuthnTokenProvider.cs b/src/Core/Identity/WebAuthnTokenProvider.cs index b34b6b1871..ee857422a1 100644 --- a/src/Core/Identity/WebAuthnTokenProvider.cs +++ b/src/Core/Identity/WebAuthnTokenProvider.cs @@ -10,145 +10,146 @@ using Fido2NetLib.Objects; using Microsoft.AspNetCore.Identity; using Microsoft.Extensions.DependencyInjection; -namespace Bit.Core.Identity; - -public class WebAuthnTokenProvider : IUserTwoFactorTokenProvider +namespace Bit.Core.Identity { - private readonly IServiceProvider _serviceProvider; - private readonly IFido2 _fido2; - private readonly GlobalSettings _globalSettings; - - public WebAuthnTokenProvider(IServiceProvider serviceProvider, IFido2 fido2, GlobalSettings globalSettings) + public class WebAuthnTokenProvider : IUserTwoFactorTokenProvider { - _serviceProvider = serviceProvider; - _fido2 = fido2; - _globalSettings = globalSettings; - } + private readonly IServiceProvider _serviceProvider; + private readonly IFido2 _fido2; + private readonly GlobalSettings _globalSettings; - public async Task CanGenerateTwoFactorTokenAsync(UserManager manager, User user) - { - var userService = _serviceProvider.GetRequiredService(); - if (!(await userService.CanAccessPremium(user))) + public WebAuthnTokenProvider(IServiceProvider serviceProvider, IFido2 fido2, GlobalSettings globalSettings) { - return false; + _serviceProvider = serviceProvider; + _fido2 = fido2; + _globalSettings = globalSettings; } - var webAuthnProvider = user.GetTwoFactorProvider(TwoFactorProviderType.WebAuthn); - if (!HasProperMetaData(webAuthnProvider)) + public async Task CanGenerateTwoFactorTokenAsync(UserManager manager, User user) { - return false; + var userService = _serviceProvider.GetRequiredService(); + if (!(await userService.CanAccessPremium(user))) + { + return false; + } + + var webAuthnProvider = user.GetTwoFactorProvider(TwoFactorProviderType.WebAuthn); + if (!HasProperMetaData(webAuthnProvider)) + { + return false; + } + + return await userService.TwoFactorProviderIsEnabledAsync(TwoFactorProviderType.WebAuthn, user); } - return await userService.TwoFactorProviderIsEnabledAsync(TwoFactorProviderType.WebAuthn, user); - } - - public async Task GenerateAsync(string purpose, UserManager manager, User user) - { - var userService = _serviceProvider.GetRequiredService(); - if (!(await userService.CanAccessPremium(user))) + public async Task GenerateAsync(string purpose, UserManager manager, User user) { - return null; + var userService = _serviceProvider.GetRequiredService(); + if (!(await userService.CanAccessPremium(user))) + { + return null; + } + + var provider = user.GetTwoFactorProvider(TwoFactorProviderType.WebAuthn); + var keys = LoadKeys(provider); + var existingCredentials = keys.Select(key => key.Item2.Descriptor).ToList(); + + if (existingCredentials.Count == 0) + { + return null; + } + + var exts = new AuthenticationExtensionsClientInputs() + { + UserVerificationMethod = true, + AppID = CoreHelpers.U2fAppIdUrl(_globalSettings), + }; + + var options = _fido2.GetAssertionOptions(existingCredentials, UserVerificationRequirement.Discouraged, exts); + + // TODO: Remove this when newtonsoft legacy converters are gone + provider.MetaData["login"] = JsonSerializer.Serialize(options); + + var providers = user.GetTwoFactorProviders(); + providers[TwoFactorProviderType.WebAuthn] = provider; + user.SetTwoFactorProviders(providers); + await userService.UpdateTwoFactorProviderAsync(user, TwoFactorProviderType.WebAuthn, logEvent: false); + + return options.ToJson(); } - var provider = user.GetTwoFactorProvider(TwoFactorProviderType.WebAuthn); - var keys = LoadKeys(provider); - var existingCredentials = keys.Select(key => key.Item2.Descriptor).ToList(); - - if (existingCredentials.Count == 0) + public async Task ValidateAsync(string purpose, string token, UserManager manager, User user) { - return null; + var userService = _serviceProvider.GetRequiredService(); + if (!(await userService.CanAccessPremium(user)) || string.IsNullOrWhiteSpace(token)) + { + return false; + } + + var provider = user.GetTwoFactorProvider(TwoFactorProviderType.WebAuthn); + var keys = LoadKeys(provider); + + if (!provider.MetaData.ContainsKey("login")) + { + return false; + } + + var clientResponse = JsonSerializer.Deserialize(token, + new JsonSerializerOptions { PropertyNameCaseInsensitive = true }); + + var jsonOptions = provider.MetaData["login"].ToString(); + var options = AssertionOptions.FromJson(jsonOptions); + + var webAuthCred = keys.Find(k => k.Item2.Descriptor.Id.SequenceEqual(clientResponse.Id)); + + if (webAuthCred == null) + { + return false; + } + + IsUserHandleOwnerOfCredentialIdAsync callback = (args) => Task.FromResult(true); + + var res = await _fido2.MakeAssertionAsync(clientResponse, options, webAuthCred.Item2.PublicKey, webAuthCred.Item2.SignatureCounter, callback); + + provider.MetaData.Remove("login"); + + // Update SignatureCounter + webAuthCred.Item2.SignatureCounter = res.Counter; + + var providers = user.GetTwoFactorProviders(); + providers[TwoFactorProviderType.WebAuthn].MetaData[webAuthCred.Item1] = webAuthCred.Item2; + user.SetTwoFactorProviders(providers); + await userService.UpdateTwoFactorProviderAsync(user, TwoFactorProviderType.WebAuthn, logEvent: false); + + return res.Status == "ok"; } - var exts = new AuthenticationExtensionsClientInputs() + private bool HasProperMetaData(TwoFactorProvider provider) { - UserVerificationMethod = true, - AppID = CoreHelpers.U2fAppIdUrl(_globalSettings), - }; - - var options = _fido2.GetAssertionOptions(existingCredentials, UserVerificationRequirement.Discouraged, exts); - - // TODO: Remove this when newtonsoft legacy converters are gone - provider.MetaData["login"] = JsonSerializer.Serialize(options); - - var providers = user.GetTwoFactorProviders(); - providers[TwoFactorProviderType.WebAuthn] = provider; - user.SetTwoFactorProviders(providers); - await userService.UpdateTwoFactorProviderAsync(user, TwoFactorProviderType.WebAuthn, logEvent: false); - - return options.ToJson(); - } - - public async Task ValidateAsync(string purpose, string token, UserManager manager, User user) - { - var userService = _serviceProvider.GetRequiredService(); - if (!(await userService.CanAccessPremium(user)) || string.IsNullOrWhiteSpace(token)) - { - return false; + return provider?.MetaData?.Any() ?? false; } - var provider = user.GetTwoFactorProvider(TwoFactorProviderType.WebAuthn); - var keys = LoadKeys(provider); - - if (!provider.MetaData.ContainsKey("login")) + private List> LoadKeys(TwoFactorProvider provider) { - return false; - } + var keys = new List>(); + if (!HasProperMetaData(provider)) + { + return keys; + } - var clientResponse = JsonSerializer.Deserialize(token, - new JsonSerializerOptions { PropertyNameCaseInsensitive = true }); + // Support up to 5 keys + for (var i = 1; i <= 5; i++) + { + var keyName = $"Key{i}"; + if (provider.MetaData.ContainsKey(keyName)) + { + var key = new TwoFactorProvider.WebAuthnData((dynamic)provider.MetaData[keyName]); - var jsonOptions = provider.MetaData["login"].ToString(); - var options = AssertionOptions.FromJson(jsonOptions); + keys.Add(new Tuple(keyName, key)); + } + } - var webAuthCred = keys.Find(k => k.Item2.Descriptor.Id.SequenceEqual(clientResponse.Id)); - - if (webAuthCred == null) - { - return false; - } - - IsUserHandleOwnerOfCredentialIdAsync callback = (args) => Task.FromResult(true); - - var res = await _fido2.MakeAssertionAsync(clientResponse, options, webAuthCred.Item2.PublicKey, webAuthCred.Item2.SignatureCounter, callback); - - provider.MetaData.Remove("login"); - - // Update SignatureCounter - webAuthCred.Item2.SignatureCounter = res.Counter; - - var providers = user.GetTwoFactorProviders(); - providers[TwoFactorProviderType.WebAuthn].MetaData[webAuthCred.Item1] = webAuthCred.Item2; - user.SetTwoFactorProviders(providers); - await userService.UpdateTwoFactorProviderAsync(user, TwoFactorProviderType.WebAuthn, logEvent: false); - - return res.Status == "ok"; - } - - private bool HasProperMetaData(TwoFactorProvider provider) - { - return provider?.MetaData?.Any() ?? false; - } - - private List> LoadKeys(TwoFactorProvider provider) - { - var keys = new List>(); - if (!HasProperMetaData(provider)) - { return keys; } - - // Support up to 5 keys - for (var i = 1; i <= 5; i++) - { - var keyName = $"Key{i}"; - if (provider.MetaData.ContainsKey(keyName)) - { - var key = new TwoFactorProvider.WebAuthnData((dynamic)provider.MetaData[keyName]); - - keys.Add(new Tuple(keyName, key)); - } - } - - return keys; } } diff --git a/src/Core/Identity/YubicoOtpTokenProvider.cs b/src/Core/Identity/YubicoOtpTokenProvider.cs index 3d7bb9fe71..763cdbf4b9 100644 --- a/src/Core/Identity/YubicoOtpTokenProvider.cs +++ b/src/Core/Identity/YubicoOtpTokenProvider.cs @@ -6,70 +6,71 @@ using Microsoft.AspNetCore.Identity; using Microsoft.Extensions.DependencyInjection; using YubicoDotNetClient; -namespace Bit.Core.Identity; - -public class YubicoOtpTokenProvider : IUserTwoFactorTokenProvider +namespace Bit.Core.Identity { - private readonly IServiceProvider _serviceProvider; - private readonly GlobalSettings _globalSettings; - - public YubicoOtpTokenProvider( - IServiceProvider serviceProvider, - GlobalSettings globalSettings) + public class YubicoOtpTokenProvider : IUserTwoFactorTokenProvider { - _serviceProvider = serviceProvider; - _globalSettings = globalSettings; - } + private readonly IServiceProvider _serviceProvider; + private readonly GlobalSettings _globalSettings; - public async Task CanGenerateTwoFactorTokenAsync(UserManager manager, User user) - { - var userService = _serviceProvider.GetRequiredService(); - if (!(await userService.CanAccessPremium(user))) + public YubicoOtpTokenProvider( + IServiceProvider serviceProvider, + GlobalSettings globalSettings) { - return false; + _serviceProvider = serviceProvider; + _globalSettings = globalSettings; } - var provider = user.GetTwoFactorProvider(TwoFactorProviderType.YubiKey); - if (!provider?.MetaData.Values.Any(v => !string.IsNullOrWhiteSpace((string)v)) ?? true) + public async Task CanGenerateTwoFactorTokenAsync(UserManager manager, User user) { - return false; + var userService = _serviceProvider.GetRequiredService(); + if (!(await userService.CanAccessPremium(user))) + { + return false; + } + + var provider = user.GetTwoFactorProvider(TwoFactorProviderType.YubiKey); + if (!provider?.MetaData.Values.Any(v => !string.IsNullOrWhiteSpace((string)v)) ?? true) + { + return false; + } + + return await userService.TwoFactorProviderIsEnabledAsync(TwoFactorProviderType.YubiKey, user); } - return await userService.TwoFactorProviderIsEnabledAsync(TwoFactorProviderType.YubiKey, user); - } - - public Task GenerateAsync(string purpose, UserManager manager, User user) - { - return Task.FromResult(null); - } - - public async Task ValidateAsync(string purpose, string token, UserManager manager, User user) - { - var userService = _serviceProvider.GetRequiredService(); - if (!(await userService.CanAccessPremium(user))) + public Task GenerateAsync(string purpose, UserManager manager, User user) { - return false; + return Task.FromResult(null); } - if (string.IsNullOrWhiteSpace(token) || token.Length < 32 || token.Length > 48) + public async Task ValidateAsync(string purpose, string token, UserManager manager, User user) { - return false; - } + var userService = _serviceProvider.GetRequiredService(); + if (!(await userService.CanAccessPremium(user))) + { + return false; + } - var id = token.Substring(0, 12); + if (string.IsNullOrWhiteSpace(token) || token.Length < 32 || token.Length > 48) + { + return false; + } - var provider = user.GetTwoFactorProvider(TwoFactorProviderType.YubiKey); - if (!provider.MetaData.ContainsValue(id)) - { - return false; - } + var id = token.Substring(0, 12); - var client = new YubicoClient(_globalSettings.Yubico.ClientId, _globalSettings.Yubico.Key); - if (_globalSettings.Yubico.ValidationUrls != null && _globalSettings.Yubico.ValidationUrls.Length > 0) - { - client.SetUrls(_globalSettings.Yubico.ValidationUrls); + var provider = user.GetTwoFactorProvider(TwoFactorProviderType.YubiKey); + if (!provider.MetaData.ContainsValue(id)) + { + return false; + } + + var client = new YubicoClient(_globalSettings.Yubico.ClientId, _globalSettings.Yubico.Key); + if (_globalSettings.Yubico.ValidationUrls != null && _globalSettings.Yubico.ValidationUrls.Length > 0) + { + client.SetUrls(_globalSettings.Yubico.ValidationUrls); + } + var response = await client.VerifyAsync(token); + return response.Status == YubicoResponseStatus.Ok; } - var response = await client.VerifyAsync(token); - return response.Status == YubicoResponseStatus.Ok; } } diff --git a/src/Core/IdentityServer/ApiClient.cs b/src/Core/IdentityServer/ApiClient.cs index b289da0015..a17bb32f9b 100644 --- a/src/Core/IdentityServer/ApiClient.cs +++ b/src/Core/IdentityServer/ApiClient.cs @@ -1,77 +1,78 @@ using Bit.Core.Settings; using IdentityServer4.Models; -namespace Bit.Core.IdentityServer; - -public class ApiClient : Client +namespace Bit.Core.IdentityServer { - public ApiClient( - GlobalSettings globalSettings, - string id, - int refreshTokenSlidingDays, - int accessTokenLifetimeHours, - string[] scopes = null) + public class ApiClient : Client { - ClientId = id; - AllowedGrantTypes = new[] { GrantType.ResourceOwnerPassword, GrantType.AuthorizationCode }; - RefreshTokenExpiration = TokenExpiration.Sliding; - RefreshTokenUsage = TokenUsage.ReUse; - SlidingRefreshTokenLifetime = 86400 * refreshTokenSlidingDays; - AbsoluteRefreshTokenLifetime = 0; // forever - UpdateAccessTokenClaimsOnRefresh = true; - AccessTokenLifetime = 3600 * accessTokenLifetimeHours; - AllowOfflineAccess = true; + public ApiClient( + GlobalSettings globalSettings, + string id, + int refreshTokenSlidingDays, + int accessTokenLifetimeHours, + string[] scopes = null) + { + ClientId = id; + AllowedGrantTypes = new[] { GrantType.ResourceOwnerPassword, GrantType.AuthorizationCode }; + RefreshTokenExpiration = TokenExpiration.Sliding; + RefreshTokenUsage = TokenUsage.ReUse; + SlidingRefreshTokenLifetime = 86400 * refreshTokenSlidingDays; + AbsoluteRefreshTokenLifetime = 0; // forever + UpdateAccessTokenClaimsOnRefresh = true; + AccessTokenLifetime = 3600 * accessTokenLifetimeHours; + AllowOfflineAccess = true; - RequireConsent = false; - RequirePkce = true; - RequireClientSecret = false; - if (id == "web") - { - RedirectUris = new[] { $"{globalSettings.BaseServiceUri.Vault}/sso-connector.html" }; - PostLogoutRedirectUris = new[] { globalSettings.BaseServiceUri.Vault }; - AllowedCorsOrigins = new[] { globalSettings.BaseServiceUri.Vault }; - } - else if (id == "desktop") - { - RedirectUris = new[] { "bitwarden://sso-callback" }; - PostLogoutRedirectUris = new[] { "bitwarden://logged-out" }; - } - else if (id == "connector") - { - var connectorUris = new List(); - for (var port = 8065; port <= 8070; port++) + RequireConsent = false; + RequirePkce = true; + RequireClientSecret = false; + if (id == "web") { - connectorUris.Add(string.Format("http://localhost:{0}", port)); + RedirectUris = new[] { $"{globalSettings.BaseServiceUri.Vault}/sso-connector.html" }; + PostLogoutRedirectUris = new[] { globalSettings.BaseServiceUri.Vault }; + AllowedCorsOrigins = new[] { globalSettings.BaseServiceUri.Vault }; } - RedirectUris = connectorUris.Append("bwdc://sso-callback").ToList(); - PostLogoutRedirectUris = connectorUris.Append("bwdc://logged-out").ToList(); - } - else if (id == "browser") - { - RedirectUris = new[] { $"{globalSettings.BaseServiceUri.Vault}/sso-connector.html" }; - PostLogoutRedirectUris = new[] { globalSettings.BaseServiceUri.Vault }; - AllowedCorsOrigins = new[] { globalSettings.BaseServiceUri.Vault }; - } - else if (id == "cli") - { - var cliUris = new List(); - for (var port = 8065; port <= 8070; port++) + else if (id == "desktop") { - cliUris.Add(string.Format("http://localhost:{0}", port)); + RedirectUris = new[] { "bitwarden://sso-callback" }; + PostLogoutRedirectUris = new[] { "bitwarden://logged-out" }; + } + else if (id == "connector") + { + var connectorUris = new List(); + for (var port = 8065; port <= 8070; port++) + { + connectorUris.Add(string.Format("http://localhost:{0}", port)); + } + RedirectUris = connectorUris.Append("bwdc://sso-callback").ToList(); + PostLogoutRedirectUris = connectorUris.Append("bwdc://logged-out").ToList(); + } + else if (id == "browser") + { + RedirectUris = new[] { $"{globalSettings.BaseServiceUri.Vault}/sso-connector.html" }; + PostLogoutRedirectUris = new[] { globalSettings.BaseServiceUri.Vault }; + AllowedCorsOrigins = new[] { globalSettings.BaseServiceUri.Vault }; + } + else if (id == "cli") + { + var cliUris = new List(); + for (var port = 8065; port <= 8070; port++) + { + cliUris.Add(string.Format("http://localhost:{0}", port)); + } + RedirectUris = cliUris; + PostLogoutRedirectUris = cliUris; + } + else if (id == "mobile") + { + RedirectUris = new[] { "bitwarden://sso-callback" }; + PostLogoutRedirectUris = new[] { "bitwarden://logged-out" }; } - RedirectUris = cliUris; - PostLogoutRedirectUris = cliUris; - } - else if (id == "mobile") - { - RedirectUris = new[] { "bitwarden://sso-callback" }; - PostLogoutRedirectUris = new[] { "bitwarden://logged-out" }; - } - if (scopes == null) - { - scopes = new string[] { "api" }; + if (scopes == null) + { + scopes = new string[] { "api" }; + } + AllowedScopes = scopes; } - AllowedScopes = scopes; } } diff --git a/src/Core/IdentityServer/ApiResources.cs b/src/Core/IdentityServer/ApiResources.cs index 5a19fa2caf..55b3427cd8 100644 --- a/src/Core/IdentityServer/ApiResources.cs +++ b/src/Core/IdentityServer/ApiResources.cs @@ -1,35 +1,36 @@ using IdentityModel; using IdentityServer4.Models; -namespace Bit.Core.IdentityServer; - -public class ApiResources +namespace Bit.Core.IdentityServer { - public static IEnumerable GetApiResources() + public class ApiResources { - return new List + public static IEnumerable GetApiResources() { - new ApiResource("api", new string[] { - JwtClaimTypes.Name, - JwtClaimTypes.Email, - JwtClaimTypes.EmailVerified, - "sstamp", // security stamp - "premium", - "device", - "orgowner", - "orgadmin", - "orgmanager", - "orguser", - "orgcustom", - "providerprovideradmin", - "providerserviceuser", - }), - new ApiResource("internal", new string[] { JwtClaimTypes.Subject }), - new ApiResource("api.push", new string[] { JwtClaimTypes.Subject }), - new ApiResource("api.licensing", new string[] { JwtClaimTypes.Subject }), - new ApiResource("api.organization", new string[] { JwtClaimTypes.Subject }), - new ApiResource("api.provider", new string[] { JwtClaimTypes.Subject }), - new ApiResource("api.installation", new string[] { JwtClaimTypes.Subject }), - }; + return new List + { + new ApiResource("api", new string[] { + JwtClaimTypes.Name, + JwtClaimTypes.Email, + JwtClaimTypes.EmailVerified, + "sstamp", // security stamp + "premium", + "device", + "orgowner", + "orgadmin", + "orgmanager", + "orguser", + "orgcustom", + "providerprovideradmin", + "providerserviceuser", + }), + new ApiResource("internal", new string[] { JwtClaimTypes.Subject }), + new ApiResource("api.push", new string[] { JwtClaimTypes.Subject }), + new ApiResource("api.licensing", new string[] { JwtClaimTypes.Subject }), + new ApiResource("api.organization", new string[] { JwtClaimTypes.Subject }), + new ApiResource("api.provider", new string[] { JwtClaimTypes.Subject }), + new ApiResource("api.installation", new string[] { JwtClaimTypes.Subject }), + }; + } } } diff --git a/src/Core/IdentityServer/ApiScopes.cs b/src/Core/IdentityServer/ApiScopes.cs index 2af512eb89..e98465964c 100644 --- a/src/Core/IdentityServer/ApiScopes.cs +++ b/src/Core/IdentityServer/ApiScopes.cs @@ -1,19 +1,20 @@ using IdentityServer4.Models; -namespace Bit.Core.IdentityServer; - -public class ApiScopes +namespace Bit.Core.IdentityServer { - public static IEnumerable GetApiScopes() + public class ApiScopes { - return new List + public static IEnumerable GetApiScopes() { - new ApiScope("api", "API Access"), - new ApiScope("api.push", "API Push Access"), - new ApiScope("api.licensing", "API Licensing Access"), - new ApiScope("api.organization", "API Organization Access"), - new ApiScope("api.installation", "API Installation Access"), - new ApiScope("internal", "Internal Access") - }; + return new List + { + new ApiScope("api", "API Access"), + new ApiScope("api.push", "API Push Access"), + new ApiScope("api.licensing", "API Licensing Access"), + new ApiScope("api.organization", "API Organization Access"), + new ApiScope("api.installation", "API Installation Access"), + new ApiScope("internal", "Internal Access") + }; + } } } diff --git a/src/Core/IdentityServer/AuthorizationCodeStore.cs b/src/Core/IdentityServer/AuthorizationCodeStore.cs index fc07f7aa6e..7bf01f6eb2 100644 --- a/src/Core/IdentityServer/AuthorizationCodeStore.cs +++ b/src/Core/IdentityServer/AuthorizationCodeStore.cs @@ -6,38 +6,39 @@ using IdentityServer4.Stores; using IdentityServer4.Stores.Serialization; using Microsoft.Extensions.Logging; -namespace Bit.Core.IdentityServer; - -// ref: https://raw.githubusercontent.com/IdentityServer/IdentityServer4/3.1.3/src/IdentityServer4/src/Stores/Default/DefaultAuthorizationCodeStore.cs -public class AuthorizationCodeStore : DefaultGrantStore, IAuthorizationCodeStore +namespace Bit.Core.IdentityServer { - public AuthorizationCodeStore( - IPersistedGrantStore store, - IPersistentGrantSerializer serializer, - IHandleGenerationService handleGenerationService, - ILogger logger) - : base(IdentityServerConstants.PersistedGrantTypes.AuthorizationCode, store, serializer, - handleGenerationService, logger) - { } - - public Task StoreAuthorizationCodeAsync(AuthorizationCode code) + // ref: https://raw.githubusercontent.com/IdentityServer/IdentityServer4/3.1.3/src/IdentityServer4/src/Stores/Default/DefaultAuthorizationCodeStore.cs + public class AuthorizationCodeStore : DefaultGrantStore, IAuthorizationCodeStore { - return CreateItemAsync(code, code.ClientId, code.Subject.GetSubjectId(), code.SessionId, - code.Description, code.CreationTime, code.Lifetime); - } + public AuthorizationCodeStore( + IPersistedGrantStore store, + IPersistentGrantSerializer serializer, + IHandleGenerationService handleGenerationService, + ILogger logger) + : base(IdentityServerConstants.PersistedGrantTypes.AuthorizationCode, store, serializer, + handleGenerationService, logger) + { } - public Task GetAuthorizationCodeAsync(string code) - { - return GetItemAsync(code); - } + public Task StoreAuthorizationCodeAsync(AuthorizationCode code) + { + return CreateItemAsync(code, code.ClientId, code.Subject.GetSubjectId(), code.SessionId, + code.Description, code.CreationTime, code.Lifetime); + } - public Task RemoveAuthorizationCodeAsync(string code) - { - // return RemoveItemAsync(code); + public Task GetAuthorizationCodeAsync(string code) + { + return GetItemAsync(code); + } - // We don't want to delete authorization codes during validation. - // We'll rely on the authorization code lifecycle for short term validation and the - // DatabaseExpiredGrantsJob to clean up old authorization codes. - return Task.FromResult(0); + public Task RemoveAuthorizationCodeAsync(string code) + { + // return RemoveItemAsync(code); + + // We don't want to delete authorization codes during validation. + // We'll rely on the authorization code lifecycle for short term validation and the + // DatabaseExpiredGrantsJob to clean up old authorization codes. + return Task.FromResult(0); + } } } diff --git a/src/Core/IdentityServer/BaseRequestValidator.cs b/src/Core/IdentityServer/BaseRequestValidator.cs index d2c72c132b..632b548b2e 100644 --- a/src/Core/IdentityServer/BaseRequestValidator.cs +++ b/src/Core/IdentityServer/BaseRequestValidator.cs @@ -17,606 +17,607 @@ using IdentityServer4.Validation; using Microsoft.AspNetCore.Identity; using Microsoft.Extensions.Logging; -namespace Bit.Core.IdentityServer; - -public abstract class BaseRequestValidator where T : class +namespace Bit.Core.IdentityServer { - private UserManager _userManager; - private readonly IDeviceRepository _deviceRepository; - private readonly IDeviceService _deviceService; - private readonly IUserService _userService; - private readonly IEventService _eventService; - private readonly IOrganizationDuoWebTokenProvider _organizationDuoWebTokenProvider; - private readonly IOrganizationRepository _organizationRepository; - private readonly IOrganizationUserRepository _organizationUserRepository; - private readonly IApplicationCacheService _applicationCacheService; - private readonly IMailService _mailService; - private readonly ILogger _logger; - private readonly ICurrentContext _currentContext; - private readonly GlobalSettings _globalSettings; - private readonly IPolicyRepository _policyRepository; - private readonly IUserRepository _userRepository; - private readonly ICaptchaValidationService _captchaValidationService; - - public BaseRequestValidator( - UserManager userManager, - IDeviceRepository deviceRepository, - IDeviceService deviceService, - IUserService userService, - IEventService eventService, - IOrganizationDuoWebTokenProvider organizationDuoWebTokenProvider, - IOrganizationRepository organizationRepository, - IOrganizationUserRepository organizationUserRepository, - IApplicationCacheService applicationCacheService, - IMailService mailService, - ILogger logger, - ICurrentContext currentContext, - GlobalSettings globalSettings, - IPolicyRepository policyRepository, - IUserRepository userRepository, - ICaptchaValidationService captchaValidationService) + public abstract class BaseRequestValidator where T : class { - _userManager = userManager; - _deviceRepository = deviceRepository; - _deviceService = deviceService; - _userService = userService; - _eventService = eventService; - _organizationDuoWebTokenProvider = organizationDuoWebTokenProvider; - _organizationRepository = organizationRepository; - _organizationUserRepository = organizationUserRepository; - _applicationCacheService = applicationCacheService; - _mailService = mailService; - _logger = logger; - _currentContext = currentContext; - _globalSettings = globalSettings; - _policyRepository = policyRepository; - _userRepository = userRepository; - _captchaValidationService = captchaValidationService; - } + private UserManager _userManager; + private readonly IDeviceRepository _deviceRepository; + private readonly IDeviceService _deviceService; + private readonly IUserService _userService; + private readonly IEventService _eventService; + private readonly IOrganizationDuoWebTokenProvider _organizationDuoWebTokenProvider; + private readonly IOrganizationRepository _organizationRepository; + private readonly IOrganizationUserRepository _organizationUserRepository; + private readonly IApplicationCacheService _applicationCacheService; + private readonly IMailService _mailService; + private readonly ILogger _logger; + private readonly ICurrentContext _currentContext; + private readonly GlobalSettings _globalSettings; + private readonly IPolicyRepository _policyRepository; + private readonly IUserRepository _userRepository; + private readonly ICaptchaValidationService _captchaValidationService; - protected async Task ValidateAsync(T context, ValidatedTokenRequest request, - CustomValidatorRequestContext validatorContext) - { - var isBot = (validatorContext.CaptchaResponse?.IsBot ?? false); - if (isBot) + public BaseRequestValidator( + UserManager userManager, + IDeviceRepository deviceRepository, + IDeviceService deviceService, + IUserService userService, + IEventService eventService, + IOrganizationDuoWebTokenProvider organizationDuoWebTokenProvider, + IOrganizationRepository organizationRepository, + IOrganizationUserRepository organizationUserRepository, + IApplicationCacheService applicationCacheService, + IMailService mailService, + ILogger logger, + ICurrentContext currentContext, + GlobalSettings globalSettings, + IPolicyRepository policyRepository, + IUserRepository userRepository, + ICaptchaValidationService captchaValidationService) { - _logger.LogInformation(Constants.BypassFiltersEventId, - "Login attempt for {0} detected as a captcha bot with score {1}.", - request.UserName, validatorContext.CaptchaResponse.Score); + _userManager = userManager; + _deviceRepository = deviceRepository; + _deviceService = deviceService; + _userService = userService; + _eventService = eventService; + _organizationDuoWebTokenProvider = organizationDuoWebTokenProvider; + _organizationRepository = organizationRepository; + _organizationUserRepository = organizationUserRepository; + _applicationCacheService = applicationCacheService; + _mailService = mailService; + _logger = logger; + _currentContext = currentContext; + _globalSettings = globalSettings; + _policyRepository = policyRepository; + _userRepository = userRepository; + _captchaValidationService = captchaValidationService; } - var twoFactorToken = request.Raw["TwoFactorToken"]?.ToString(); - var twoFactorProvider = request.Raw["TwoFactorProvider"]?.ToString(); - var twoFactorRemember = request.Raw["TwoFactorRemember"]?.ToString() == "1"; - var twoFactorRequest = !string.IsNullOrWhiteSpace(twoFactorToken) && - !string.IsNullOrWhiteSpace(twoFactorProvider); - - var valid = await ValidateContextAsync(context, validatorContext); - var user = validatorContext.User; - if (!valid) + protected async Task ValidateAsync(T context, ValidatedTokenRequest request, + CustomValidatorRequestContext validatorContext) { - await UpdateFailedAuthDetailsAsync(user, false, !validatorContext.KnownDevice); - } - if (!valid || isBot) - { - await BuildErrorResultAsync("Username or password is incorrect. Try again.", false, context, user); - return; - } - - var (isTwoFactorRequired, requires2FABecauseNewDevice, twoFactorOrganization) = await RequiresTwoFactorAsync(user, request); - if (isTwoFactorRequired) - { - // Just defaulting it - var twoFactorProviderType = TwoFactorProviderType.Authenticator; - if (!twoFactorRequest || !Enum.TryParse(twoFactorProvider, out twoFactorProviderType)) + var isBot = (validatorContext.CaptchaResponse?.IsBot ?? false); + if (isBot) { - await BuildTwoFactorResultAsync(user, twoFactorOrganization, context, requires2FABecauseNewDevice); + _logger.LogInformation(Constants.BypassFiltersEventId, + "Login attempt for {0} detected as a captcha bot with score {1}.", + request.UserName, validatorContext.CaptchaResponse.Score); + } + + var twoFactorToken = request.Raw["TwoFactorToken"]?.ToString(); + var twoFactorProvider = request.Raw["TwoFactorProvider"]?.ToString(); + var twoFactorRemember = request.Raw["TwoFactorRemember"]?.ToString() == "1"; + var twoFactorRequest = !string.IsNullOrWhiteSpace(twoFactorToken) && + !string.IsNullOrWhiteSpace(twoFactorProvider); + + var valid = await ValidateContextAsync(context, validatorContext); + var user = validatorContext.User; + if (!valid) + { + await UpdateFailedAuthDetailsAsync(user, false, !validatorContext.KnownDevice); + } + if (!valid || isBot) + { + await BuildErrorResultAsync("Username or password is incorrect. Try again.", false, context, user); return; } - BeforeVerifyTwoFactor(user, twoFactorProviderType, requires2FABecauseNewDevice); - - var verified = await VerifyTwoFactor(user, twoFactorOrganization, - twoFactorProviderType, twoFactorToken); - - AfterVerifyTwoFactor(user, twoFactorProviderType, requires2FABecauseNewDevice); - - if ((!verified || isBot) && twoFactorProviderType != TwoFactorProviderType.Remember) + var (isTwoFactorRequired, requires2FABecauseNewDevice, twoFactorOrganization) = await RequiresTwoFactorAsync(user, request); + if (isTwoFactorRequired) { - await UpdateFailedAuthDetailsAsync(user, true, !validatorContext.KnownDevice); - await BuildErrorResultAsync("Two-step token is invalid. Try again.", true, context, user); - return; - } - else if ((!verified || isBot) && twoFactorProviderType == TwoFactorProviderType.Remember) - { - // Delay for brute force. - await Task.Delay(2000); - await BuildTwoFactorResultAsync(user, twoFactorOrganization, context, requires2FABecauseNewDevice); - return; - } - } - else - { - twoFactorRequest = false; - twoFactorRemember = false; - twoFactorToken = null; - } - - // Returns true if can finish validation process - if (await IsValidAuthTypeAsync(user, request.GrantType)) - { - var device = await SaveDeviceAsync(user, request); - if (device == null) - { - await BuildErrorResultAsync("No device information provided.", false, context, user); - return; - } - await BuildSuccessResultAsync(user, context, device, twoFactorRequest && twoFactorRemember); - } - else - { - SetSsoResult(context, new Dictionary - {{ - "ErrorModel", new ErrorResponseModel("SSO authentication is required.") - }}); - } - } - - protected abstract Task ValidateContextAsync(T context, CustomValidatorRequestContext validatorContext); - - protected async Task BuildSuccessResultAsync(User user, T context, Device device, bool sendRememberToken) - { - await _eventService.LogUserEventAsync(user.Id, EventType.User_LoggedIn); - - var claims = new List(); - - if (device != null) - { - claims.Add(new Claim("device", device.Identifier)); - } - - var customResponse = new Dictionary(); - if (!string.IsNullOrWhiteSpace(user.PrivateKey)) - { - customResponse.Add("PrivateKey", user.PrivateKey); - } - - if (!string.IsNullOrWhiteSpace(user.Key)) - { - customResponse.Add("Key", user.Key); - } - - customResponse.Add("ForcePasswordReset", user.ForcePasswordReset); - customResponse.Add("ResetMasterPassword", string.IsNullOrWhiteSpace(user.MasterPassword)); - customResponse.Add("Kdf", (byte)user.Kdf); - customResponse.Add("KdfIterations", user.KdfIterations); - - if (sendRememberToken) - { - var token = await _userManager.GenerateTwoFactorTokenAsync(user, - CoreHelpers.CustomProviderName(TwoFactorProviderType.Remember)); - customResponse.Add("TwoFactorToken", token); - } - - await ResetFailedAuthDetailsAsync(user); - await SetSuccessResult(context, user, claims, customResponse); - } - - protected async Task BuildTwoFactorResultAsync(User user, Organization organization, T context, bool requires2FABecauseNewDevice) - { - var providerKeys = new List(); - var providers = new Dictionary>(); - - var enabledProviders = new List>(); - if (organization?.GetTwoFactorProviders() != null) - { - enabledProviders.AddRange(organization.GetTwoFactorProviders().Where( - p => organization.TwoFactorProviderIsEnabled(p.Key))); - } - - if (user.GetTwoFactorProviders() != null) - { - foreach (var p in user.GetTwoFactorProviders()) - { - if (await _userService.TwoFactorProviderIsEnabledAsync(p.Key, user)) + // Just defaulting it + var twoFactorProviderType = TwoFactorProviderType.Authenticator; + if (!twoFactorRequest || !Enum.TryParse(twoFactorProvider, out twoFactorProviderType)) { - enabledProviders.Add(p); + await BuildTwoFactorResultAsync(user, twoFactorOrganization, context, requires2FABecauseNewDevice); + return; + } + + BeforeVerifyTwoFactor(user, twoFactorProviderType, requires2FABecauseNewDevice); + + var verified = await VerifyTwoFactor(user, twoFactorOrganization, + twoFactorProviderType, twoFactorToken); + + AfterVerifyTwoFactor(user, twoFactorProviderType, requires2FABecauseNewDevice); + + if ((!verified || isBot) && twoFactorProviderType != TwoFactorProviderType.Remember) + { + await UpdateFailedAuthDetailsAsync(user, true, !validatorContext.KnownDevice); + await BuildErrorResultAsync("Two-step token is invalid. Try again.", true, context, user); + return; + } + else if ((!verified || isBot) && twoFactorProviderType == TwoFactorProviderType.Remember) + { + // Delay for brute force. + await Task.Delay(2000); + await BuildTwoFactorResultAsync(user, twoFactorOrganization, context, requires2FABecauseNewDevice); + return; } } - } - - if (!enabledProviders.Any()) - { - if (!requires2FABecauseNewDevice) + else { - await BuildErrorResultAsync("No two-step providers enabled.", false, context, user); - return; + twoFactorRequest = false; + twoFactorRemember = false; + twoFactorToken = null; } - var emailProvider = new TwoFactorProvider + // Returns true if can finish validation process + if (await IsValidAuthTypeAsync(user, request.GrantType)) { - MetaData = new Dictionary { ["Email"] = user.Email.ToLowerInvariant() }, - Enabled = true - }; - enabledProviders.Add(new KeyValuePair( - TwoFactorProviderType.Email, emailProvider)); - user.SetTwoFactorProviders(new Dictionary - { - [TwoFactorProviderType.Email] = emailProvider - }); - } - - foreach (var provider in enabledProviders) - { - providerKeys.Add((byte)provider.Key); - var infoDict = await BuildTwoFactorParams(organization, user, provider.Key, provider.Value); - providers.Add(((byte)provider.Key).ToString(), infoDict); - } - - SetTwoFactorResult(context, - new Dictionary - { - { "TwoFactorProviders", providers.Keys }, - { "TwoFactorProviders2", providers } - }); - - if (enabledProviders.Count() == 1 && enabledProviders.First().Key == TwoFactorProviderType.Email) - { - // Send email now if this is their only 2FA method - await _userService.SendTwoFactorEmailAsync(user, requires2FABecauseNewDevice); - } - } - - protected async Task BuildErrorResultAsync(string message, bool twoFactorRequest, T context, User user) - { - if (user != null) - { - await _eventService.LogUserEventAsync(user.Id, - twoFactorRequest ? EventType.User_FailedLogIn2fa : EventType.User_FailedLogIn); - } - - if (_globalSettings.SelfHosted) - { - _logger.LogWarning(Constants.BypassFiltersEventId, - string.Format("Failed login attempt{0}{1}", twoFactorRequest ? ", 2FA invalid." : ".", - $" {_currentContext.IpAddress}")); - } - - await Task.Delay(2000); // Delay for brute force. - SetErrorResult(context, - new Dictionary - {{ - "ErrorModel", new ErrorResponseModel(message) - }}); - } - - protected abstract void SetTwoFactorResult(T context, Dictionary customResponse); - - protected abstract void SetSsoResult(T context, Dictionary customResponse); - - protected abstract Task SetSuccessResult(T context, User user, List claims, - Dictionary customResponse); - - protected abstract void SetErrorResult(T context, Dictionary customResponse); - - private async Task> RequiresTwoFactorAsync(User user, ValidatedTokenRequest request) - { - if (request.GrantType == "client_credentials") - { - // Do not require MFA for api key logins - return new Tuple(false, false, null); - } - - var individualRequired = _userManager.SupportsUserTwoFactor && - await _userManager.GetTwoFactorEnabledAsync(user) && - (await _userManager.GetValidTwoFactorProvidersAsync(user)).Count > 0; - - Organization firstEnabledOrg = null; - var orgs = (await _currentContext.OrganizationMembershipAsync(_organizationUserRepository, user.Id)) - .ToList(); - if (orgs.Any()) - { - var orgAbilities = await _applicationCacheService.GetOrganizationAbilitiesAsync(); - var twoFactorOrgs = orgs.Where(o => OrgUsing2fa(orgAbilities, o.Id)); - if (twoFactorOrgs.Any()) - { - var userOrgs = await _organizationRepository.GetManyByUserIdAsync(user.Id); - firstEnabledOrg = userOrgs.FirstOrDefault( - o => orgs.Any(om => om.Id == o.Id) && o.TwoFactorIsEnabled()); - } - } - - var requires2FA = individualRequired || firstEnabledOrg != null; - var requires2FABecauseNewDevice = !requires2FA - && - await _userService.Needs2FABecauseNewDeviceAsync( - user, - GetDeviceFromRequest(request)?.Identifier, - request.GrantType); - - requires2FA = requires2FA || requires2FABecauseNewDevice; - - return new Tuple(requires2FA, requires2FABecauseNewDevice, firstEnabledOrg); - } - - private async Task IsValidAuthTypeAsync(User user, string grantType) - { - if (grantType == "authorization_code" || grantType == "client_credentials") - { - // Already using SSO to authorize, finish successfully - // Or login via api key, skip SSO requirement - return true; - } - - // Is user apart of any orgs? Use cache for initial checks. - var orgs = (await _currentContext.OrganizationMembershipAsync(_organizationUserRepository, user.Id)) - .ToList(); - if (orgs.Any()) - { - // Get all org abilities - var orgAbilities = await _applicationCacheService.GetOrganizationAbilitiesAsync(); - // Parse all user orgs that are enabled and have the ability to use sso - var ssoOrgs = orgs.Where(o => OrgCanUseSso(orgAbilities, o.Id)); - if (ssoOrgs.Any()) - { - // Parse users orgs and determine if require sso policy is enabled - var userOrgs = await _organizationUserRepository.GetManyDetailsByUserAsync(user.Id, - OrganizationUserStatusType.Confirmed); - foreach (var userOrg in userOrgs.Where(o => o.Enabled && o.UseSso)) + var device = await SaveDeviceAsync(user, request); + if (device == null) { - var orgPolicy = await _policyRepository.GetByOrganizationIdTypeAsync(userOrg.OrganizationId, - PolicyType.RequireSso); - // Owners and Admins are exempt from this policy - if (orgPolicy != null && orgPolicy.Enabled && - userOrg.Type != OrganizationUserType.Owner && userOrg.Type != OrganizationUserType.Admin) + await BuildErrorResultAsync("No device information provided.", false, context, user); + return; + } + await BuildSuccessResultAsync(user, context, device, twoFactorRequest && twoFactorRemember); + } + else + { + SetSsoResult(context, new Dictionary + {{ + "ErrorModel", new ErrorResponseModel("SSO authentication is required.") + }}); + } + } + + protected abstract Task ValidateContextAsync(T context, CustomValidatorRequestContext validatorContext); + + protected async Task BuildSuccessResultAsync(User user, T context, Device device, bool sendRememberToken) + { + await _eventService.LogUserEventAsync(user.Id, EventType.User_LoggedIn); + + var claims = new List(); + + if (device != null) + { + claims.Add(new Claim("device", device.Identifier)); + } + + var customResponse = new Dictionary(); + if (!string.IsNullOrWhiteSpace(user.PrivateKey)) + { + customResponse.Add("PrivateKey", user.PrivateKey); + } + + if (!string.IsNullOrWhiteSpace(user.Key)) + { + customResponse.Add("Key", user.Key); + } + + customResponse.Add("ForcePasswordReset", user.ForcePasswordReset); + customResponse.Add("ResetMasterPassword", string.IsNullOrWhiteSpace(user.MasterPassword)); + customResponse.Add("Kdf", (byte)user.Kdf); + customResponse.Add("KdfIterations", user.KdfIterations); + + if (sendRememberToken) + { + var token = await _userManager.GenerateTwoFactorTokenAsync(user, + CoreHelpers.CustomProviderName(TwoFactorProviderType.Remember)); + customResponse.Add("TwoFactorToken", token); + } + + await ResetFailedAuthDetailsAsync(user); + await SetSuccessResult(context, user, claims, customResponse); + } + + protected async Task BuildTwoFactorResultAsync(User user, Organization organization, T context, bool requires2FABecauseNewDevice) + { + var providerKeys = new List(); + var providers = new Dictionary>(); + + var enabledProviders = new List>(); + if (organization?.GetTwoFactorProviders() != null) + { + enabledProviders.AddRange(organization.GetTwoFactorProviders().Where( + p => organization.TwoFactorProviderIsEnabled(p.Key))); + } + + if (user.GetTwoFactorProviders() != null) + { + foreach (var p in user.GetTwoFactorProviders()) + { + if (await _userService.TwoFactorProviderIsEnabledAsync(p.Key, user)) { - return false; + enabledProviders.Add(p); } } } - } - // Default - continue validation process - return true; - } - - private bool OrgUsing2fa(IDictionary orgAbilities, Guid orgId) - { - return orgAbilities != null && orgAbilities.ContainsKey(orgId) && - orgAbilities[orgId].Enabled && orgAbilities[orgId].Using2fa; - } - - private bool OrgCanUseSso(IDictionary orgAbilities, Guid orgId) - { - return orgAbilities != null && orgAbilities.ContainsKey(orgId) && - orgAbilities[orgId].Enabled && orgAbilities[orgId].UseSso; - } - - private Device GetDeviceFromRequest(ValidatedRequest request) - { - var deviceIdentifier = request.Raw["DeviceIdentifier"]?.ToString(); - var deviceType = request.Raw["DeviceType"]?.ToString(); - var deviceName = request.Raw["DeviceName"]?.ToString(); - var devicePushToken = request.Raw["DevicePushToken"]?.ToString(); - - if (string.IsNullOrWhiteSpace(deviceIdentifier) || string.IsNullOrWhiteSpace(deviceType) || - string.IsNullOrWhiteSpace(deviceName) || !Enum.TryParse(deviceType, out DeviceType type)) - { - return null; - } - - return new Device - { - Identifier = deviceIdentifier, - Name = deviceName, - Type = type, - PushToken = string.IsNullOrWhiteSpace(devicePushToken) ? null : devicePushToken - }; - } - - private void BeforeVerifyTwoFactor(User user, TwoFactorProviderType type, bool requires2FABecauseNewDevice) - { - if (type == TwoFactorProviderType.Email && requires2FABecauseNewDevice) - { - user.SetTwoFactorProviders(new Dictionary + if (!enabledProviders.Any()) { - [TwoFactorProviderType.Email] = new TwoFactorProvider + if (!requires2FABecauseNewDevice) + { + await BuildErrorResultAsync("No two-step providers enabled.", false, context, user); + return; + } + + var emailProvider = new TwoFactorProvider { MetaData = new Dictionary { ["Email"] = user.Email.ToLowerInvariant() }, Enabled = true - } - }); + }; + enabledProviders.Add(new KeyValuePair( + TwoFactorProviderType.Email, emailProvider)); + user.SetTwoFactorProviders(new Dictionary + { + [TwoFactorProviderType.Email] = emailProvider + }); + } + + foreach (var provider in enabledProviders) + { + providerKeys.Add((byte)provider.Key); + var infoDict = await BuildTwoFactorParams(organization, user, provider.Key, provider.Value); + providers.Add(((byte)provider.Key).ToString(), infoDict); + } + + SetTwoFactorResult(context, + new Dictionary + { + { "TwoFactorProviders", providers.Keys }, + { "TwoFactorProviders2", providers } + }); + + if (enabledProviders.Count() == 1 && enabledProviders.First().Key == TwoFactorProviderType.Email) + { + // Send email now if this is their only 2FA method + await _userService.SendTwoFactorEmailAsync(user, requires2FABecauseNewDevice); + } } - } - private void AfterVerifyTwoFactor(User user, TwoFactorProviderType type, bool requires2FABecauseNewDevice) - { - if (type == TwoFactorProviderType.Email && requires2FABecauseNewDevice) + protected async Task BuildErrorResultAsync(string message, bool twoFactorRequest, T context, User user) { - user.ClearTwoFactorProviders(); + if (user != null) + { + await _eventService.LogUserEventAsync(user.Id, + twoFactorRequest ? EventType.User_FailedLogIn2fa : EventType.User_FailedLogIn); + } + + if (_globalSettings.SelfHosted) + { + _logger.LogWarning(Constants.BypassFiltersEventId, + string.Format("Failed login attempt{0}{1}", twoFactorRequest ? ", 2FA invalid." : ".", + $" {_currentContext.IpAddress}")); + } + + await Task.Delay(2000); // Delay for brute force. + SetErrorResult(context, + new Dictionary + {{ + "ErrorModel", new ErrorResponseModel(message) + }}); } - } - private async Task VerifyTwoFactor(User user, Organization organization, TwoFactorProviderType type, - string token) - { - switch (type) + protected abstract void SetTwoFactorResult(T context, Dictionary customResponse); + + protected abstract void SetSsoResult(T context, Dictionary customResponse); + + protected abstract Task SetSuccessResult(T context, User user, List claims, + Dictionary customResponse); + + protected abstract void SetErrorResult(T context, Dictionary customResponse); + + private async Task> RequiresTwoFactorAsync(User user, ValidatedTokenRequest request) { - case TwoFactorProviderType.Authenticator: - case TwoFactorProviderType.Email: - case TwoFactorProviderType.Duo: - case TwoFactorProviderType.YubiKey: - case TwoFactorProviderType.WebAuthn: - case TwoFactorProviderType.Remember: - if (type != TwoFactorProviderType.Remember && - !(await _userService.TwoFactorProviderIsEnabledAsync(type, user))) - { - return false; - } - return await _userManager.VerifyTwoFactorTokenAsync(user, - CoreHelpers.CustomProviderName(type), token); - case TwoFactorProviderType.OrganizationDuo: - if (!organization?.TwoFactorProviderIsEnabled(type) ?? true) - { - return false; - } + if (request.GrantType == "client_credentials") + { + // Do not require MFA for api key logins + return new Tuple(false, false, null); + } - return await _organizationDuoWebTokenProvider.ValidateAsync(token, organization, user); - default: - return false; + var individualRequired = _userManager.SupportsUserTwoFactor && + await _userManager.GetTwoFactorEnabledAsync(user) && + (await _userManager.GetValidTwoFactorProvidersAsync(user)).Count > 0; + + Organization firstEnabledOrg = null; + var orgs = (await _currentContext.OrganizationMembershipAsync(_organizationUserRepository, user.Id)) + .ToList(); + if (orgs.Any()) + { + var orgAbilities = await _applicationCacheService.GetOrganizationAbilitiesAsync(); + var twoFactorOrgs = orgs.Where(o => OrgUsing2fa(orgAbilities, o.Id)); + if (twoFactorOrgs.Any()) + { + var userOrgs = await _organizationRepository.GetManyByUserIdAsync(user.Id); + firstEnabledOrg = userOrgs.FirstOrDefault( + o => orgs.Any(om => om.Id == o.Id) && o.TwoFactorIsEnabled()); + } + } + + var requires2FA = individualRequired || firstEnabledOrg != null; + var requires2FABecauseNewDevice = !requires2FA + && + await _userService.Needs2FABecauseNewDeviceAsync( + user, + GetDeviceFromRequest(request)?.Identifier, + request.GrantType); + + requires2FA = requires2FA || requires2FABecauseNewDevice; + + return new Tuple(requires2FA, requires2FABecauseNewDevice, firstEnabledOrg); } - } - private async Task> BuildTwoFactorParams(Organization organization, User user, - TwoFactorProviderType type, TwoFactorProvider provider) - { - switch (type) + private async Task IsValidAuthTypeAsync(User user, string grantType) { - case TwoFactorProviderType.Duo: - case TwoFactorProviderType.WebAuthn: - case TwoFactorProviderType.Email: - case TwoFactorProviderType.YubiKey: - if (!(await _userService.TwoFactorProviderIsEnabledAsync(type, user))) - { - return null; - } + if (grantType == "authorization_code" || grantType == "client_credentials") + { + // Already using SSO to authorize, finish successfully + // Or login via api key, skip SSO requirement + return true; + } - var token = await _userManager.GenerateTwoFactorTokenAsync(user, - CoreHelpers.CustomProviderName(type)); - if (type == TwoFactorProviderType.Duo) + // Is user apart of any orgs? Use cache for initial checks. + var orgs = (await _currentContext.OrganizationMembershipAsync(_organizationUserRepository, user.Id)) + .ToList(); + if (orgs.Any()) + { + // Get all org abilities + var orgAbilities = await _applicationCacheService.GetOrganizationAbilitiesAsync(); + // Parse all user orgs that are enabled and have the ability to use sso + var ssoOrgs = orgs.Where(o => OrgCanUseSso(orgAbilities, o.Id)); + if (ssoOrgs.Any()) { - return new Dictionary + // Parse users orgs and determine if require sso policy is enabled + var userOrgs = await _organizationUserRepository.GetManyDetailsByUserAsync(user.Id, + OrganizationUserStatusType.Confirmed); + foreach (var userOrg in userOrgs.Where(o => o.Enabled && o.UseSso)) { - ["Host"] = provider.MetaData["Host"], - ["Signature"] = token - }; + var orgPolicy = await _policyRepository.GetByOrganizationIdTypeAsync(userOrg.OrganizationId, + PolicyType.RequireSso); + // Owners and Admins are exempt from this policy + if (orgPolicy != null && orgPolicy.Enabled && + userOrg.Type != OrganizationUserType.Owner && userOrg.Type != OrganizationUserType.Admin) + { + return false; + } + } } - else if (type == TwoFactorProviderType.WebAuthn) + } + + // Default - continue validation process + return true; + } + + private bool OrgUsing2fa(IDictionary orgAbilities, Guid orgId) + { + return orgAbilities != null && orgAbilities.ContainsKey(orgId) && + orgAbilities[orgId].Enabled && orgAbilities[orgId].Using2fa; + } + + private bool OrgCanUseSso(IDictionary orgAbilities, Guid orgId) + { + return orgAbilities != null && orgAbilities.ContainsKey(orgId) && + orgAbilities[orgId].Enabled && orgAbilities[orgId].UseSso; + } + + private Device GetDeviceFromRequest(ValidatedRequest request) + { + var deviceIdentifier = request.Raw["DeviceIdentifier"]?.ToString(); + var deviceType = request.Raw["DeviceType"]?.ToString(); + var deviceName = request.Raw["DeviceName"]?.ToString(); + var devicePushToken = request.Raw["DevicePushToken"]?.ToString(); + + if (string.IsNullOrWhiteSpace(deviceIdentifier) || string.IsNullOrWhiteSpace(deviceType) || + string.IsNullOrWhiteSpace(deviceName) || !Enum.TryParse(deviceType, out DeviceType type)) + { + return null; + } + + return new Device + { + Identifier = deviceIdentifier, + Name = deviceName, + Type = type, + PushToken = string.IsNullOrWhiteSpace(devicePushToken) ? null : devicePushToken + }; + } + + private void BeforeVerifyTwoFactor(User user, TwoFactorProviderType type, bool requires2FABecauseNewDevice) + { + if (type == TwoFactorProviderType.Email && requires2FABecauseNewDevice) + { + user.SetTwoFactorProviders(new Dictionary { - if (token == null) + [TwoFactorProviderType.Email] = new TwoFactorProvider + { + MetaData = new Dictionary { ["Email"] = user.Email.ToLowerInvariant() }, + Enabled = true + } + }); + } + } + + private void AfterVerifyTwoFactor(User user, TwoFactorProviderType type, bool requires2FABecauseNewDevice) + { + if (type == TwoFactorProviderType.Email && requires2FABecauseNewDevice) + { + user.ClearTwoFactorProviders(); + } + } + + private async Task VerifyTwoFactor(User user, Organization organization, TwoFactorProviderType type, + string token) + { + switch (type) + { + case TwoFactorProviderType.Authenticator: + case TwoFactorProviderType.Email: + case TwoFactorProviderType.Duo: + case TwoFactorProviderType.YubiKey: + case TwoFactorProviderType.WebAuthn: + case TwoFactorProviderType.Remember: + if (type != TwoFactorProviderType.Remember && + !(await _userService.TwoFactorProviderIsEnabledAsync(type, user))) + { + return false; + } + return await _userManager.VerifyTwoFactorTokenAsync(user, + CoreHelpers.CustomProviderName(type), token); + case TwoFactorProviderType.OrganizationDuo: + if (!organization?.TwoFactorProviderIsEnabled(type) ?? true) + { + return false; + } + + return await _organizationDuoWebTokenProvider.ValidateAsync(token, organization, user); + default: + return false; + } + } + + private async Task> BuildTwoFactorParams(Organization organization, User user, + TwoFactorProviderType type, TwoFactorProvider provider) + { + switch (type) + { + case TwoFactorProviderType.Duo: + case TwoFactorProviderType.WebAuthn: + case TwoFactorProviderType.Email: + case TwoFactorProviderType.YubiKey: + if (!(await _userService.TwoFactorProviderIsEnabledAsync(type, user))) { return null; } - return JsonSerializer.Deserialize>(token); - } - else if (type == TwoFactorProviderType.Email) - { - return new Dictionary + var token = await _userManager.GenerateTwoFactorTokenAsync(user, + CoreHelpers.CustomProviderName(type)); + if (type == TwoFactorProviderType.Duo) { - ["Email"] = token - }; - } - else if (type == TwoFactorProviderType.YubiKey) - { - return new Dictionary - { - ["Nfc"] = (bool)provider.MetaData["Nfc"] - }; - } - return null; - case TwoFactorProviderType.OrganizationDuo: - if (await _organizationDuoWebTokenProvider.CanGenerateTwoFactorTokenAsync(organization)) - { - return new Dictionary - { - ["Host"] = provider.MetaData["Host"], - ["Signature"] = await _organizationDuoWebTokenProvider.GenerateAsync(organization, user) - }; - } - return null; - default: - return null; - } - } - - protected async Task KnownDeviceAsync(User user, ValidatedTokenRequest request) => - (await GetKnownDeviceAsync(user, request)) != default; - - protected async Task GetKnownDeviceAsync(User user, ValidatedTokenRequest request) - { - if (user == null) - { - return default; - } - - return await _deviceRepository.GetByIdentifierAsync(GetDeviceFromRequest(request).Identifier, user.Id); - } - - private async Task SaveDeviceAsync(User user, ValidatedTokenRequest request) - { - var device = GetDeviceFromRequest(request); - if (device != null) - { - var existingDevice = await GetKnownDeviceAsync(user, request); - if (existingDevice == null) - { - device.UserId = user.Id; - await _deviceService.SaveAsync(device); - - var now = DateTime.UtcNow; - if (now - user.CreationDate > TimeSpan.FromMinutes(10)) - { - var deviceType = device.Type.GetType().GetMember(device.Type.ToString()) - .FirstOrDefault()?.GetCustomAttribute()?.GetName(); - if (!_globalSettings.DisableEmailNewDevice) - { - await _mailService.SendNewDeviceLoggedInEmail(user.Email, deviceType, now, - _currentContext.IpAddress); + return new Dictionary + { + ["Host"] = provider.MetaData["Host"], + ["Signature"] = token + }; } + else if (type == TwoFactorProviderType.WebAuthn) + { + if (token == null) + { + return null; + } + + return JsonSerializer.Deserialize>(token); + } + else if (type == TwoFactorProviderType.Email) + { + return new Dictionary + { + ["Email"] = token + }; + } + else if (type == TwoFactorProviderType.YubiKey) + { + return new Dictionary + { + ["Nfc"] = (bool)provider.MetaData["Nfc"] + }; + } + return null; + case TwoFactorProviderType.OrganizationDuo: + if (await _organizationDuoWebTokenProvider.CanGenerateTwoFactorTokenAsync(organization)) + { + return new Dictionary + { + ["Host"] = provider.MetaData["Host"], + ["Signature"] = await _organizationDuoWebTokenProvider.GenerateAsync(organization, user) + }; + } + return null; + default: + return null; + } + } + + protected async Task KnownDeviceAsync(User user, ValidatedTokenRequest request) => + (await GetKnownDeviceAsync(user, request)) != default; + + protected async Task GetKnownDeviceAsync(User user, ValidatedTokenRequest request) + { + if (user == null) + { + return default; + } + + return await _deviceRepository.GetByIdentifierAsync(GetDeviceFromRequest(request).Identifier, user.Id); + } + + private async Task SaveDeviceAsync(User user, ValidatedTokenRequest request) + { + var device = GetDeviceFromRequest(request); + if (device != null) + { + var existingDevice = await GetKnownDeviceAsync(user, request); + if (existingDevice == null) + { + device.UserId = user.Id; + await _deviceService.SaveAsync(device); + + var now = DateTime.UtcNow; + if (now - user.CreationDate > TimeSpan.FromMinutes(10)) + { + var deviceType = device.Type.GetType().GetMember(device.Type.ToString()) + .FirstOrDefault()?.GetCustomAttribute()?.GetName(); + if (!_globalSettings.DisableEmailNewDevice) + { + await _mailService.SendNewDeviceLoggedInEmail(user.Email, deviceType, now, + _currentContext.IpAddress); + } + } + + return device; } - return device; + return existingDevice; } - return existingDevice; + return null; } - return null; - } - - private async Task ResetFailedAuthDetailsAsync(User user) - { - // Early escape if db hit not necessary - if (user == null || user.FailedLoginCount == 0) + private async Task ResetFailedAuthDetailsAsync(User user) { - return; - } - - user.FailedLoginCount = 0; - user.RevisionDate = DateTime.UtcNow; - await _userRepository.ReplaceAsync(user); - } - - private async Task UpdateFailedAuthDetailsAsync(User user, bool twoFactorInvalid, bool unknownDevice) - { - if (user == null) - { - return; - } - - var utcNow = DateTime.UtcNow; - user.FailedLoginCount = ++user.FailedLoginCount; - user.LastFailedLoginDate = user.RevisionDate = utcNow; - await _userRepository.ReplaceAsync(user); - - if (ValidateFailedAuthEmailConditions(unknownDevice, user)) - { - if (twoFactorInvalid) + // Early escape if db hit not necessary + if (user == null || user.FailedLoginCount == 0) { - await _mailService.SendFailedTwoFactorAttemptsEmailAsync(user.Email, utcNow, _currentContext.IpAddress); + return; } - else + + user.FailedLoginCount = 0; + user.RevisionDate = DateTime.UtcNow; + await _userRepository.ReplaceAsync(user); + } + + private async Task UpdateFailedAuthDetailsAsync(User user, bool twoFactorInvalid, bool unknownDevice) + { + if (user == null) { - await _mailService.SendFailedLoginAttemptsEmailAsync(user.Email, utcNow, _currentContext.IpAddress); + return; + } + + var utcNow = DateTime.UtcNow; + user.FailedLoginCount = ++user.FailedLoginCount; + user.LastFailedLoginDate = user.RevisionDate = utcNow; + await _userRepository.ReplaceAsync(user); + + if (ValidateFailedAuthEmailConditions(unknownDevice, user)) + { + if (twoFactorInvalid) + { + await _mailService.SendFailedTwoFactorAttemptsEmailAsync(user.Email, utcNow, _currentContext.IpAddress); + } + else + { + await _mailService.SendFailedLoginAttemptsEmailAsync(user.Email, utcNow, _currentContext.IpAddress); + } } } - } - private bool ValidateFailedAuthEmailConditions(bool unknownDevice, User user) - { - var failedLoginCeiling = _globalSettings.Captcha.MaximumFailedLoginAttempts; - var failedLoginCount = user?.FailedLoginCount ?? 0; - return unknownDevice && failedLoginCeiling > 0 && failedLoginCount == failedLoginCeiling; + private bool ValidateFailedAuthEmailConditions(bool unknownDevice, User user) + { + var failedLoginCeiling = _globalSettings.Captcha.MaximumFailedLoginAttempts; + var failedLoginCount = user?.FailedLoginCount ?? 0; + return unknownDevice && failedLoginCeiling > 0 && failedLoginCount == failedLoginCeiling; + } } } diff --git a/src/Core/IdentityServer/ClientStore.cs b/src/Core/IdentityServer/ClientStore.cs index 2e6fa06bde..ebe247f19b 100644 --- a/src/Core/IdentityServer/ClientStore.cs +++ b/src/Core/IdentityServer/ClientStore.cs @@ -10,171 +10,172 @@ using IdentityModel; using IdentityServer4.Models; using IdentityServer4.Stores; -namespace Bit.Core.IdentityServer; - -public class ClientStore : IClientStore +namespace Bit.Core.IdentityServer { - private readonly IInstallationRepository _installationRepository; - private readonly IOrganizationRepository _organizationRepository; - private readonly IUserRepository _userRepository; - private readonly GlobalSettings _globalSettings; - private readonly StaticClientStore _staticClientStore; - private readonly ILicensingService _licensingService; - private readonly ICurrentContext _currentContext; - private readonly IOrganizationUserRepository _organizationUserRepository; - private readonly IProviderUserRepository _providerUserRepository; - private readonly IProviderOrganizationRepository _providerOrganizationRepository; - private readonly IOrganizationApiKeyRepository _organizationApiKeyRepository; - - public ClientStore( - IInstallationRepository installationRepository, - IOrganizationRepository organizationRepository, - IUserRepository userRepository, - GlobalSettings globalSettings, - StaticClientStore staticClientStore, - ILicensingService licensingService, - ICurrentContext currentContext, - IOrganizationUserRepository organizationUserRepository, - IProviderUserRepository providerUserRepository, - IProviderOrganizationRepository providerOrganizationRepository, - IOrganizationApiKeyRepository organizationApiKeyRepository) + public class ClientStore : IClientStore { - _installationRepository = installationRepository; - _organizationRepository = organizationRepository; - _userRepository = userRepository; - _globalSettings = globalSettings; - _staticClientStore = staticClientStore; - _licensingService = licensingService; - _currentContext = currentContext; - _organizationUserRepository = organizationUserRepository; - _providerUserRepository = providerUserRepository; - _providerOrganizationRepository = providerOrganizationRepository; - _organizationApiKeyRepository = organizationApiKeyRepository; - } + private readonly IInstallationRepository _installationRepository; + private readonly IOrganizationRepository _organizationRepository; + private readonly IUserRepository _userRepository; + private readonly GlobalSettings _globalSettings; + private readonly StaticClientStore _staticClientStore; + private readonly ILicensingService _licensingService; + private readonly ICurrentContext _currentContext; + private readonly IOrganizationUserRepository _organizationUserRepository; + private readonly IProviderUserRepository _providerUserRepository; + private readonly IProviderOrganizationRepository _providerOrganizationRepository; + private readonly IOrganizationApiKeyRepository _organizationApiKeyRepository; - public async Task FindClientByIdAsync(string clientId) - { - if (!_globalSettings.SelfHosted && clientId.StartsWith("installation.")) + public ClientStore( + IInstallationRepository installationRepository, + IOrganizationRepository organizationRepository, + IUserRepository userRepository, + GlobalSettings globalSettings, + StaticClientStore staticClientStore, + ILicensingService licensingService, + ICurrentContext currentContext, + IOrganizationUserRepository organizationUserRepository, + IProviderUserRepository providerUserRepository, + IProviderOrganizationRepository providerOrganizationRepository, + IOrganizationApiKeyRepository organizationApiKeyRepository) { - var idParts = clientId.Split('.'); - if (idParts.Length > 1 && Guid.TryParse(idParts[1], out Guid id)) - { - var installation = await _installationRepository.GetByIdAsync(id); - if (installation != null) - { - return new Client - { - ClientId = $"installation.{installation.Id}", - RequireClientSecret = true, - ClientSecrets = { new Secret(installation.Key.Sha256()) }, - AllowedScopes = new string[] { "api.push", "api.licensing", "api.installation" }, - AllowedGrantTypes = GrantTypes.ClientCredentials, - AccessTokenLifetime = 3600 * 24, - Enabled = installation.Enabled, - Claims = new List - { - new ClientClaim(JwtClaimTypes.Subject, installation.Id.ToString()) - } - }; - } - } + _installationRepository = installationRepository; + _organizationRepository = organizationRepository; + _userRepository = userRepository; + _globalSettings = globalSettings; + _staticClientStore = staticClientStore; + _licensingService = licensingService; + _currentContext = currentContext; + _organizationUserRepository = organizationUserRepository; + _providerUserRepository = providerUserRepository; + _providerOrganizationRepository = providerOrganizationRepository; + _organizationApiKeyRepository = organizationApiKeyRepository; } - else if (_globalSettings.SelfHosted && clientId.StartsWith("internal.") && - CoreHelpers.SettingHasValue(_globalSettings.InternalIdentityKey)) + + public async Task FindClientByIdAsync(string clientId) { - var idParts = clientId.Split('.'); - if (idParts.Length > 1) + if (!_globalSettings.SelfHosted && clientId.StartsWith("installation.")) { - var id = idParts[1]; - if (!string.IsNullOrWhiteSpace(id)) + var idParts = clientId.Split('.'); + if (idParts.Length > 1 && Guid.TryParse(idParts[1], out Guid id)) { - return new Client + var installation = await _installationRepository.GetByIdAsync(id); + if (installation != null) { - ClientId = $"internal.{id}", - RequireClientSecret = true, - ClientSecrets = { new Secret(_globalSettings.InternalIdentityKey.Sha256()) }, - AllowedScopes = new string[] { "internal" }, - AllowedGrantTypes = GrantTypes.ClientCredentials, - AccessTokenLifetime = 3600 * 24, - Enabled = true, - Claims = new List + return new Client { - new ClientClaim(JwtClaimTypes.Subject, id) - } - }; - } - } - } - else if (clientId.StartsWith("organization.")) - { - var idParts = clientId.Split('.'); - if (idParts.Length > 1 && Guid.TryParse(idParts[1], out var id)) - { - var org = await _organizationRepository.GetByIdAsync(id); - if (org != null) - { - var orgApiKey = (await _organizationApiKeyRepository - .GetManyByOrganizationIdTypeAsync(org.Id, OrganizationApiKeyType.Default)) - .First(); - return new Client - { - ClientId = $"organization.{org.Id}", - RequireClientSecret = true, - ClientSecrets = { new Secret(orgApiKey.ApiKey.Sha256()) }, - AllowedScopes = new string[] { "api.organization" }, - AllowedGrantTypes = GrantTypes.ClientCredentials, - AccessTokenLifetime = 3600 * 1, - Enabled = org.Enabled && org.UseApi, - Claims = new List - { - new ClientClaim(JwtClaimTypes.Subject, org.Id.ToString()) - } - }; - } - } - } - else if (clientId.StartsWith("user.")) - { - var idParts = clientId.Split('.'); - if (idParts.Length > 1 && Guid.TryParse(idParts[1], out var id)) - { - var user = await _userRepository.GetByIdAsync(id); - if (user != null) - { - var claims = new Collection() - { - new ClientClaim(JwtClaimTypes.Subject, user.Id.ToString()), - new ClientClaim(JwtClaimTypes.AuthenticationMethod, "Application", "external") - }; - var orgs = await _currentContext.OrganizationMembershipAsync(_organizationUserRepository, user.Id); - var providers = await _currentContext.ProviderMembershipAsync(_providerUserRepository, user.Id); - var isPremium = await _licensingService.ValidateUserPremiumAsync(user); - foreach (var claim in CoreHelpers.BuildIdentityClaims(user, orgs, providers, isPremium)) - { - var upperValue = claim.Value.ToUpperInvariant(); - var isBool = upperValue == "TRUE" || upperValue == "FALSE"; - claims.Add(isBool ? - new ClientClaim(claim.Key, claim.Value, ClaimValueTypes.Boolean) : - new ClientClaim(claim.Key, claim.Value) - ); + ClientId = $"installation.{installation.Id}", + RequireClientSecret = true, + ClientSecrets = { new Secret(installation.Key.Sha256()) }, + AllowedScopes = new string[] { "api.push", "api.licensing", "api.installation" }, + AllowedGrantTypes = GrantTypes.ClientCredentials, + AccessTokenLifetime = 3600 * 24, + Enabled = installation.Enabled, + Claims = new List + { + new ClientClaim(JwtClaimTypes.Subject, installation.Id.ToString()) + } + }; } - - return new Client - { - ClientId = clientId, - RequireClientSecret = true, - ClientSecrets = { new Secret(user.ApiKey.Sha256()) }, - AllowedScopes = new string[] { "api" }, - AllowedGrantTypes = GrantTypes.ClientCredentials, - AccessTokenLifetime = 3600 * 1, - ClientClaimsPrefix = null, - Claims = claims - }; } } - } + else if (_globalSettings.SelfHosted && clientId.StartsWith("internal.") && + CoreHelpers.SettingHasValue(_globalSettings.InternalIdentityKey)) + { + var idParts = clientId.Split('.'); + if (idParts.Length > 1) + { + var id = idParts[1]; + if (!string.IsNullOrWhiteSpace(id)) + { + return new Client + { + ClientId = $"internal.{id}", + RequireClientSecret = true, + ClientSecrets = { new Secret(_globalSettings.InternalIdentityKey.Sha256()) }, + AllowedScopes = new string[] { "internal" }, + AllowedGrantTypes = GrantTypes.ClientCredentials, + AccessTokenLifetime = 3600 * 24, + Enabled = true, + Claims = new List + { + new ClientClaim(JwtClaimTypes.Subject, id) + } + }; + } + } + } + else if (clientId.StartsWith("organization.")) + { + var idParts = clientId.Split('.'); + if (idParts.Length > 1 && Guid.TryParse(idParts[1], out var id)) + { + var org = await _organizationRepository.GetByIdAsync(id); + if (org != null) + { + var orgApiKey = (await _organizationApiKeyRepository + .GetManyByOrganizationIdTypeAsync(org.Id, OrganizationApiKeyType.Default)) + .First(); + return new Client + { + ClientId = $"organization.{org.Id}", + RequireClientSecret = true, + ClientSecrets = { new Secret(orgApiKey.ApiKey.Sha256()) }, + AllowedScopes = new string[] { "api.organization" }, + AllowedGrantTypes = GrantTypes.ClientCredentials, + AccessTokenLifetime = 3600 * 1, + Enabled = org.Enabled && org.UseApi, + Claims = new List + { + new ClientClaim(JwtClaimTypes.Subject, org.Id.ToString()) + } + }; + } + } + } + else if (clientId.StartsWith("user.")) + { + var idParts = clientId.Split('.'); + if (idParts.Length > 1 && Guid.TryParse(idParts[1], out var id)) + { + var user = await _userRepository.GetByIdAsync(id); + if (user != null) + { + var claims = new Collection() + { + new ClientClaim(JwtClaimTypes.Subject, user.Id.ToString()), + new ClientClaim(JwtClaimTypes.AuthenticationMethod, "Application", "external") + }; + var orgs = await _currentContext.OrganizationMembershipAsync(_organizationUserRepository, user.Id); + var providers = await _currentContext.ProviderMembershipAsync(_providerUserRepository, user.Id); + var isPremium = await _licensingService.ValidateUserPremiumAsync(user); + foreach (var claim in CoreHelpers.BuildIdentityClaims(user, orgs, providers, isPremium)) + { + var upperValue = claim.Value.ToUpperInvariant(); + var isBool = upperValue == "TRUE" || upperValue == "FALSE"; + claims.Add(isBool ? + new ClientClaim(claim.Key, claim.Value, ClaimValueTypes.Boolean) : + new ClientClaim(claim.Key, claim.Value) + ); + } - return _staticClientStore.ApiClients.ContainsKey(clientId) ? - _staticClientStore.ApiClients[clientId] : null; + return new Client + { + ClientId = clientId, + RequireClientSecret = true, + ClientSecrets = { new Secret(user.ApiKey.Sha256()) }, + AllowedScopes = new string[] { "api" }, + AllowedGrantTypes = GrantTypes.ClientCredentials, + AccessTokenLifetime = 3600 * 1, + ClientClaimsPrefix = null, + Claims = claims + }; + } + } + } + + return _staticClientStore.ApiClients.ContainsKey(clientId) ? + _staticClientStore.ApiClients[clientId] : null; + } } } diff --git a/src/Core/IdentityServer/ConfigureOpenIdConnectDistributedOptions.cs b/src/Core/IdentityServer/ConfigureOpenIdConnectDistributedOptions.cs index 084f98a275..b3846e81f3 100644 --- a/src/Core/IdentityServer/ConfigureOpenIdConnectDistributedOptions.cs +++ b/src/Core/IdentityServer/ConfigureOpenIdConnectDistributedOptions.cs @@ -5,48 +5,49 @@ using Microsoft.AspNetCore.Http; using Microsoft.Extensions.Caching.StackExchangeRedis; using Microsoft.Extensions.Options; -namespace Bit.Core.IdentityServer; - -public class ConfigureOpenIdConnectDistributedOptions : IPostConfigureOptions +namespace Bit.Core.IdentityServer { - private readonly IdentityServerOptions _idsrv; - private readonly IHttpContextAccessor _httpContextAccessor; - private readonly GlobalSettings _globalSettings; - - public ConfigureOpenIdConnectDistributedOptions(IHttpContextAccessor httpContextAccessor, GlobalSettings globalSettings, - IdentityServerOptions idsrv) + public class ConfigureOpenIdConnectDistributedOptions : IPostConfigureOptions { - _httpContextAccessor = httpContextAccessor ?? throw new ArgumentNullException(nameof(httpContextAccessor)); - _globalSettings = globalSettings; - _idsrv = idsrv; - } + private readonly IdentityServerOptions _idsrv; + private readonly IHttpContextAccessor _httpContextAccessor; + private readonly GlobalSettings _globalSettings; - public void PostConfigure(string name, CookieAuthenticationOptions options) - { - options.CookieManager = new DistributedCacheCookieManager(); - - if (name != AuthenticationSchemes.BitwardenExternalCookieAuthenticationScheme) + public ConfigureOpenIdConnectDistributedOptions(IHttpContextAccessor httpContextAccessor, GlobalSettings globalSettings, + IdentityServerOptions idsrv) { - // Ignore - return; + _httpContextAccessor = httpContextAccessor ?? throw new ArgumentNullException(nameof(httpContextAccessor)); + _globalSettings = globalSettings; + _idsrv = idsrv; } - options.Cookie.Name = AuthenticationSchemes.BitwardenExternalCookieAuthenticationScheme; - options.Cookie.IsEssential = true; - options.Cookie.SameSite = _idsrv.Authentication.CookieSameSiteMode; - options.TicketDataFormat = new DistributedCacheTicketDataFormatter(_httpContextAccessor, name); + public void PostConfigure(string name, CookieAuthenticationOptions options) + { + options.CookieManager = new DistributedCacheCookieManager(); - if (string.IsNullOrWhiteSpace(_globalSettings.IdentityServer?.RedisConnectionString)) - { - options.SessionStore = new MemoryCacheTicketStore(); - } - else - { - var redisOptions = new RedisCacheOptions + if (name != AuthenticationSchemes.BitwardenExternalCookieAuthenticationScheme) { - Configuration = _globalSettings.IdentityServer.RedisConnectionString, - }; - options.SessionStore = new RedisCacheTicketStore(redisOptions); + // Ignore + return; + } + + options.Cookie.Name = AuthenticationSchemes.BitwardenExternalCookieAuthenticationScheme; + options.Cookie.IsEssential = true; + options.Cookie.SameSite = _idsrv.Authentication.CookieSameSiteMode; + options.TicketDataFormat = new DistributedCacheTicketDataFormatter(_httpContextAccessor, name); + + if (string.IsNullOrWhiteSpace(_globalSettings.IdentityServer?.RedisConnectionString)) + { + options.SessionStore = new MemoryCacheTicketStore(); + } + else + { + var redisOptions = new RedisCacheOptions + { + Configuration = _globalSettings.IdentityServer.RedisConnectionString, + }; + options.SessionStore = new RedisCacheTicketStore(redisOptions); + } } } } diff --git a/src/Core/IdentityServer/CustomTokenRequestValidator.cs b/src/Core/IdentityServer/CustomTokenRequestValidator.cs index 1354af70a1..f37e165f37 100644 --- a/src/Core/IdentityServer/CustomTokenRequestValidator.cs +++ b/src/Core/IdentityServer/CustomTokenRequestValidator.cs @@ -11,142 +11,143 @@ using IdentityServer4.Validation; using Microsoft.AspNetCore.Identity; using Microsoft.Extensions.Logging; -namespace Bit.Core.IdentityServer; - -public class CustomTokenRequestValidator : BaseRequestValidator, - ICustomTokenRequestValidator +namespace Bit.Core.IdentityServer { - private UserManager _userManager; - private readonly ISsoConfigRepository _ssoConfigRepository; - private readonly IOrganizationRepository _organizationRepository; - - public CustomTokenRequestValidator( - UserManager userManager, - IDeviceRepository deviceRepository, - IDeviceService deviceService, - IUserService userService, - IEventService eventService, - IOrganizationDuoWebTokenProvider organizationDuoWebTokenProvider, - IOrganizationRepository organizationRepository, - IOrganizationUserRepository organizationUserRepository, - IApplicationCacheService applicationCacheService, - IMailService mailService, - ILogger logger, - ICurrentContext currentContext, - GlobalSettings globalSettings, - IPolicyRepository policyRepository, - ISsoConfigRepository ssoConfigRepository, - IUserRepository userRepository, - ICaptchaValidationService captchaValidationService) - : base(userManager, deviceRepository, deviceService, userService, eventService, - organizationDuoWebTokenProvider, organizationRepository, organizationUserRepository, - applicationCacheService, mailService, logger, currentContext, globalSettings, policyRepository, - userRepository, captchaValidationService) + public class CustomTokenRequestValidator : BaseRequestValidator, + ICustomTokenRequestValidator { - _userManager = userManager; - _ssoConfigRepository = ssoConfigRepository; - _organizationRepository = organizationRepository; - } + private UserManager _userManager; + private readonly ISsoConfigRepository _ssoConfigRepository; + private readonly IOrganizationRepository _organizationRepository; - public async Task ValidateAsync(CustomTokenRequestValidationContext context) - { - string[] allowedGrantTypes = { "authorization_code", "client_credentials" }; - if (!allowedGrantTypes.Contains(context.Result.ValidatedRequest.GrantType) - || context.Result.ValidatedRequest.ClientId.StartsWith("organization") - || context.Result.ValidatedRequest.ClientId.StartsWith("installation")) + public CustomTokenRequestValidator( + UserManager userManager, + IDeviceRepository deviceRepository, + IDeviceService deviceService, + IUserService userService, + IEventService eventService, + IOrganizationDuoWebTokenProvider organizationDuoWebTokenProvider, + IOrganizationRepository organizationRepository, + IOrganizationUserRepository organizationUserRepository, + IApplicationCacheService applicationCacheService, + IMailService mailService, + ILogger logger, + ICurrentContext currentContext, + GlobalSettings globalSettings, + IPolicyRepository policyRepository, + ISsoConfigRepository ssoConfigRepository, + IUserRepository userRepository, + ICaptchaValidationService captchaValidationService) + : base(userManager, deviceRepository, deviceService, userService, eventService, + organizationDuoWebTokenProvider, organizationRepository, organizationUserRepository, + applicationCacheService, mailService, logger, currentContext, globalSettings, policyRepository, + userRepository, captchaValidationService) { - return; + _userManager = userManager; + _ssoConfigRepository = ssoConfigRepository; + _organizationRepository = organizationRepository; } - await ValidateAsync(context, context.Result.ValidatedRequest, - new CustomValidatorRequestContext { KnownDevice = true }); - } - protected async override Task ValidateContextAsync(CustomTokenRequestValidationContext context, - CustomValidatorRequestContext validatorContext) - { - var email = context.Result.ValidatedRequest.Subject?.GetDisplayName() - ?? context.Result.ValidatedRequest.ClientClaims?.FirstOrDefault(claim => claim.Type == JwtClaimTypes.Email)?.Value; - if (!string.IsNullOrWhiteSpace(email)) + public async Task ValidateAsync(CustomTokenRequestValidationContext context) { - validatorContext.User = await _userManager.FindByEmailAsync(email); - } - return validatorContext.User != null; - } - - protected override async Task SetSuccessResult(CustomTokenRequestValidationContext context, User user, - List claims, Dictionary customResponse) - { - context.Result.CustomResponse = customResponse; - if (claims?.Any() ?? false) - { - context.Result.ValidatedRequest.Client.AlwaysSendClientClaims = true; - context.Result.ValidatedRequest.Client.ClientClaimsPrefix = string.Empty; - foreach (var claim in claims) + string[] allowedGrantTypes = { "authorization_code", "client_credentials" }; + if (!allowedGrantTypes.Contains(context.Result.ValidatedRequest.GrantType) + || context.Result.ValidatedRequest.ClientId.StartsWith("organization") + || context.Result.ValidatedRequest.ClientId.StartsWith("installation")) { - context.Result.ValidatedRequest.ClientClaims.Add(claim); + return; + } + await ValidateAsync(context, context.Result.ValidatedRequest, + new CustomValidatorRequestContext { KnownDevice = true }); + } + + protected async override Task ValidateContextAsync(CustomTokenRequestValidationContext context, + CustomValidatorRequestContext validatorContext) + { + var email = context.Result.ValidatedRequest.Subject?.GetDisplayName() + ?? context.Result.ValidatedRequest.ClientClaims?.FirstOrDefault(claim => claim.Type == JwtClaimTypes.Email)?.Value; + if (!string.IsNullOrWhiteSpace(email)) + { + validatorContext.User = await _userManager.FindByEmailAsync(email); + } + return validatorContext.User != null; + } + + protected override async Task SetSuccessResult(CustomTokenRequestValidationContext context, User user, + List claims, Dictionary customResponse) + { + context.Result.CustomResponse = customResponse; + if (claims?.Any() ?? false) + { + context.Result.ValidatedRequest.Client.AlwaysSendClientClaims = true; + context.Result.ValidatedRequest.Client.ClientClaimsPrefix = string.Empty; + foreach (var claim in claims) + { + context.Result.ValidatedRequest.ClientClaims.Add(claim); + } + } + + if (context.Result.CustomResponse == null || user.MasterPassword != null) + { + return; + } + + // KeyConnector responses below + + // Apikey login + if (context.Result.ValidatedRequest.GrantType == "client_credentials") + { + if (user.UsesKeyConnector) + { + // KeyConnectorUrl is configured in the CLI client, we just need to tell the client to use it + context.Result.CustomResponse["ApiUseKeyConnector"] = true; + context.Result.CustomResponse["ResetMasterPassword"] = false; + } + return; + } + + // SSO login + var organizationClaim = context.Result.ValidatedRequest.Subject?.FindFirst(c => c.Type == "organizationId"); + if (organizationClaim?.Value != null) + { + var organizationId = new Guid(organizationClaim.Value); + + var ssoConfig = await _ssoConfigRepository.GetByOrganizationIdAsync(organizationId); + var ssoConfigData = ssoConfig.GetData(); + + if (ssoConfigData is { KeyConnectorEnabled: true } && !string.IsNullOrEmpty(ssoConfigData.KeyConnectorUrl)) + { + context.Result.CustomResponse["KeyConnectorUrl"] = ssoConfigData.KeyConnectorUrl; + // Prevent clients redirecting to set-password + context.Result.CustomResponse["ResetMasterPassword"] = false; + } } } - if (context.Result.CustomResponse == null || user.MasterPassword != null) + protected override void SetTwoFactorResult(CustomTokenRequestValidationContext context, + Dictionary customResponse) { - return; + context.Result.Error = "invalid_grant"; + context.Result.ErrorDescription = "Two factor required."; + context.Result.IsError = true; + context.Result.CustomResponse = customResponse; } - // KeyConnector responses below - - // Apikey login - if (context.Result.ValidatedRequest.GrantType == "client_credentials") + protected override void SetSsoResult(CustomTokenRequestValidationContext context, + Dictionary customResponse) { - if (user.UsesKeyConnector) - { - // KeyConnectorUrl is configured in the CLI client, we just need to tell the client to use it - context.Result.CustomResponse["ApiUseKeyConnector"] = true; - context.Result.CustomResponse["ResetMasterPassword"] = false; - } - return; + context.Result.Error = "invalid_grant"; + context.Result.ErrorDescription = "Single Sign on required."; + context.Result.IsError = true; + context.Result.CustomResponse = customResponse; } - // SSO login - var organizationClaim = context.Result.ValidatedRequest.Subject?.FindFirst(c => c.Type == "organizationId"); - if (organizationClaim?.Value != null) + protected override void SetErrorResult(CustomTokenRequestValidationContext context, + Dictionary customResponse) { - var organizationId = new Guid(organizationClaim.Value); - - var ssoConfig = await _ssoConfigRepository.GetByOrganizationIdAsync(organizationId); - var ssoConfigData = ssoConfig.GetData(); - - if (ssoConfigData is { KeyConnectorEnabled: true } && !string.IsNullOrEmpty(ssoConfigData.KeyConnectorUrl)) - { - context.Result.CustomResponse["KeyConnectorUrl"] = ssoConfigData.KeyConnectorUrl; - // Prevent clients redirecting to set-password - context.Result.CustomResponse["ResetMasterPassword"] = false; - } + context.Result.Error = "invalid_grant"; + context.Result.IsError = true; + context.Result.CustomResponse = customResponse; } } - - protected override void SetTwoFactorResult(CustomTokenRequestValidationContext context, - Dictionary customResponse) - { - context.Result.Error = "invalid_grant"; - context.Result.ErrorDescription = "Two factor required."; - context.Result.IsError = true; - context.Result.CustomResponse = customResponse; - } - - protected override void SetSsoResult(CustomTokenRequestValidationContext context, - Dictionary customResponse) - { - context.Result.Error = "invalid_grant"; - context.Result.ErrorDescription = "Single Sign on required."; - context.Result.IsError = true; - context.Result.CustomResponse = customResponse; - } - - protected override void SetErrorResult(CustomTokenRequestValidationContext context, - Dictionary customResponse) - { - context.Result.Error = "invalid_grant"; - context.Result.IsError = true; - context.Result.CustomResponse = customResponse; - } } diff --git a/src/Core/IdentityServer/CustomValidatorRequestContext.cs b/src/Core/IdentityServer/CustomValidatorRequestContext.cs index 66fdc1e7e9..f5e95aaa8c 100644 --- a/src/Core/IdentityServer/CustomValidatorRequestContext.cs +++ b/src/Core/IdentityServer/CustomValidatorRequestContext.cs @@ -1,11 +1,12 @@ using Bit.Core.Entities; using Bit.Core.Models.Business; -namespace Bit.Core.IdentityServer; - -public class CustomValidatorRequestContext +namespace Bit.Core.IdentityServer { - public User User { get; set; } - public bool KnownDevice { get; set; } - public CaptchaResponse CaptchaResponse { get; set; } + public class CustomValidatorRequestContext + { + public User User { get; set; } + public bool KnownDevice { get; set; } + public CaptchaResponse CaptchaResponse { get; set; } + } } diff --git a/src/Core/IdentityServer/DistributedCacheCookieManager.cs b/src/Core/IdentityServer/DistributedCacheCookieManager.cs index d202581bd2..988afc018e 100644 --- a/src/Core/IdentityServer/DistributedCacheCookieManager.cs +++ b/src/Core/IdentityServer/DistributedCacheCookieManager.cs @@ -4,65 +4,66 @@ using Microsoft.AspNetCore.Http; using Microsoft.Extensions.Caching.Distributed; using Microsoft.Extensions.DependencyInjection; -namespace Bit.Core.IdentityServer; - -public class DistributedCacheCookieManager : ICookieManager +namespace Bit.Core.IdentityServer { - private readonly ChunkingCookieManager _cookieManager; - - public DistributedCacheCookieManager() + public class DistributedCacheCookieManager : ICookieManager { - _cookieManager = new ChunkingCookieManager(); - } + private readonly ChunkingCookieManager _cookieManager; - private string CacheKeyPrefix => "cookie-data"; - - public void AppendResponseCookie(HttpContext context, string key, string value, CookieOptions options) - { - var id = Guid.NewGuid().ToString(); - var cacheKey = GetKey(key, id); - - var expiresUtc = options.Expires ?? DateTimeOffset.UtcNow.AddMinutes(15); - var cacheOptions = new DistributedCacheEntryOptions() - .SetAbsoluteExpiration(expiresUtc); - - var data = Encoding.UTF8.GetBytes(value); - - var cache = GetCache(context); - cache.Set(cacheKey, data, cacheOptions); - - // Write the cookie with the identifier as the body - _cookieManager.AppendResponseCookie(context, key, id, options); - } - - public void DeleteCookie(HttpContext context, string key, CookieOptions options) - { - _cookieManager.DeleteCookie(context, key, options); - var id = GetId(context, key); - if (!string.IsNullOrWhiteSpace(id)) + public DistributedCacheCookieManager() { + _cookieManager = new ChunkingCookieManager(); + } + + private string CacheKeyPrefix => "cookie-data"; + + public void AppendResponseCookie(HttpContext context, string key, string value, CookieOptions options) + { + var id = Guid.NewGuid().ToString(); var cacheKey = GetKey(key, id); - GetCache(context).Remove(cacheKey); - } - } - public string GetRequestCookie(HttpContext context, string key) - { - var id = GetId(context, key); - if (string.IsNullOrWhiteSpace(id)) + var expiresUtc = options.Expires ?? DateTimeOffset.UtcNow.AddMinutes(15); + var cacheOptions = new DistributedCacheEntryOptions() + .SetAbsoluteExpiration(expiresUtc); + + var data = Encoding.UTF8.GetBytes(value); + + var cache = GetCache(context); + cache.Set(cacheKey, data, cacheOptions); + + // Write the cookie with the identifier as the body + _cookieManager.AppendResponseCookie(context, key, id, options); + } + + public void DeleteCookie(HttpContext context, string key, CookieOptions options) { - return null; + _cookieManager.DeleteCookie(context, key, options); + var id = GetId(context, key); + if (!string.IsNullOrWhiteSpace(id)) + { + var cacheKey = GetKey(key, id); + GetCache(context).Remove(cacheKey); + } } - var cacheKey = GetKey(key, id); - return GetCache(context).GetString(cacheKey); + + public string GetRequestCookie(HttpContext context, string key) + { + var id = GetId(context, key); + if (string.IsNullOrWhiteSpace(id)) + { + return null; + } + var cacheKey = GetKey(key, id); + return GetCache(context).GetString(cacheKey); + } + + private IDistributedCache GetCache(HttpContext context) => + context.RequestServices.GetRequiredService(); + + private string GetKey(string key, string id) => $"{CacheKeyPrefix}-{key}-{id}"; + + private string GetId(HttpContext context, string key) => + context.Request.Cookies.ContainsKey(key) ? + context.Request.Cookies[key] : null; } - - private IDistributedCache GetCache(HttpContext context) => - context.RequestServices.GetRequiredService(); - - private string GetKey(string key, string id) => $"{CacheKeyPrefix}-{key}-{id}"; - - private string GetId(HttpContext context, string key) => - context.Request.Cookies.ContainsKey(key) ? - context.Request.Cookies[key] : null; } diff --git a/src/Core/IdentityServer/DistributedCacheTicketDataFormatter.cs b/src/Core/IdentityServer/DistributedCacheTicketDataFormatter.cs index ec47a0f7c0..bbd1d40876 100644 --- a/src/Core/IdentityServer/DistributedCacheTicketDataFormatter.cs +++ b/src/Core/IdentityServer/DistributedCacheTicketDataFormatter.cs @@ -4,61 +4,62 @@ using Microsoft.AspNetCore.Http; using Microsoft.Extensions.Caching.Distributed; using Microsoft.Extensions.DependencyInjection; -namespace Bit.Core.IdentityServer; - -public class DistributedCacheTicketDataFormatter : ISecureDataFormat +namespace Bit.Core.IdentityServer { - private readonly IHttpContextAccessor _httpContext; - private readonly string _name; - - public DistributedCacheTicketDataFormatter(IHttpContextAccessor httpContext, string name) + public class DistributedCacheTicketDataFormatter : ISecureDataFormat { - _httpContext = httpContext; - _name = name; - } + private readonly IHttpContextAccessor _httpContext; + private readonly string _name; - private string CacheKeyPrefix => "ticket-data"; - private IDistributedCache Cache => _httpContext.HttpContext.RequestServices.GetRequiredService(); - private IDataProtector Protector => _httpContext.HttpContext.RequestServices.GetRequiredService() - .CreateProtector(CacheKeyPrefix, _name); - - public string Protect(AuthenticationTicket data) => Protect(data, null); - public string Protect(AuthenticationTicket data, string purpose) - { - var key = Guid.NewGuid().ToString(); - var cacheKey = $"{CacheKeyPrefix}-{_name}-{purpose}-{key}"; - - var expiresUtc = data.Properties.ExpiresUtc ?? - DateTimeOffset.UtcNow.AddMinutes(15); - - var options = new DistributedCacheEntryOptions(); - options.SetAbsoluteExpiration(expiresUtc); - - var ticket = TicketSerializer.Default.Serialize(data); - Cache.Set(cacheKey, ticket, options); - - return Protector.Protect(key); - } - - public AuthenticationTicket Unprotect(string protectedText) => Unprotect(protectedText, null); - public AuthenticationTicket Unprotect(string protectedText, string purpose) - { - if (string.IsNullOrWhiteSpace(protectedText)) + public DistributedCacheTicketDataFormatter(IHttpContextAccessor httpContext, string name) { - return null; + _httpContext = httpContext; + _name = name; } - // Decrypt the key and retrieve the data from the cache. - var key = Protector.Unprotect(protectedText); - var cacheKey = $"{CacheKeyPrefix}-{_name}-{purpose}-{key}"; - var ticket = Cache.Get(cacheKey); + private string CacheKeyPrefix => "ticket-data"; + private IDistributedCache Cache => _httpContext.HttpContext.RequestServices.GetRequiredService(); + private IDataProtector Protector => _httpContext.HttpContext.RequestServices.GetRequiredService() + .CreateProtector(CacheKeyPrefix, _name); - if (ticket == null) + public string Protect(AuthenticationTicket data) => Protect(data, null); + public string Protect(AuthenticationTicket data, string purpose) { - return null; + var key = Guid.NewGuid().ToString(); + var cacheKey = $"{CacheKeyPrefix}-{_name}-{purpose}-{key}"; + + var expiresUtc = data.Properties.ExpiresUtc ?? + DateTimeOffset.UtcNow.AddMinutes(15); + + var options = new DistributedCacheEntryOptions(); + options.SetAbsoluteExpiration(expiresUtc); + + var ticket = TicketSerializer.Default.Serialize(data); + Cache.Set(cacheKey, ticket, options); + + return Protector.Protect(key); } - var data = TicketSerializer.Default.Deserialize(ticket); - return data; + public AuthenticationTicket Unprotect(string protectedText) => Unprotect(protectedText, null); + public AuthenticationTicket Unprotect(string protectedText, string purpose) + { + if (string.IsNullOrWhiteSpace(protectedText)) + { + return null; + } + + // Decrypt the key and retrieve the data from the cache. + var key = Protector.Unprotect(protectedText); + var cacheKey = $"{CacheKeyPrefix}-{_name}-{purpose}-{key}"; + var ticket = Cache.Get(cacheKey); + + if (ticket == null) + { + return null; + } + + var data = TicketSerializer.Default.Deserialize(ticket); + return data; + } } } diff --git a/src/Core/IdentityServer/MemoryCacheTicketStore.cs b/src/Core/IdentityServer/MemoryCacheTicketStore.cs index dc8d763c9c..7120aee07c 100644 --- a/src/Core/IdentityServer/MemoryCacheTicketStore.cs +++ b/src/Core/IdentityServer/MemoryCacheTicketStore.cs @@ -2,52 +2,53 @@ using Microsoft.AspNetCore.Authentication.Cookies; using Microsoft.Extensions.Caching.Memory; -namespace Bit.Core.IdentityServer; - -public class MemoryCacheTicketStore : ITicketStore +namespace Bit.Core.IdentityServer { - private const string _keyPrefix = "auth-"; - private readonly IMemoryCache _cache; - - public MemoryCacheTicketStore() + public class MemoryCacheTicketStore : ITicketStore { - _cache = new MemoryCache(new MemoryCacheOptions()); - } + private const string _keyPrefix = "auth-"; + private readonly IMemoryCache _cache; - public async Task StoreAsync(AuthenticationTicket ticket) - { - var key = $"{_keyPrefix}{Guid.NewGuid()}"; - await RenewAsync(key, ticket); - return key; - } - - public Task RenewAsync(string key, AuthenticationTicket ticket) - { - var options = new MemoryCacheEntryOptions(); - var expiresUtc = ticket.Properties.ExpiresUtc; - if (expiresUtc.HasValue) + public MemoryCacheTicketStore() { - options.SetAbsoluteExpiration(expiresUtc.Value); - } - else - { - options.SetSlidingExpiration(TimeSpan.FromMinutes(15)); + _cache = new MemoryCache(new MemoryCacheOptions()); } - _cache.Set(key, ticket, options); + public async Task StoreAsync(AuthenticationTicket ticket) + { + var key = $"{_keyPrefix}{Guid.NewGuid()}"; + await RenewAsync(key, ticket); + return key; + } - return Task.FromResult(0); - } + public Task RenewAsync(string key, AuthenticationTicket ticket) + { + var options = new MemoryCacheEntryOptions(); + var expiresUtc = ticket.Properties.ExpiresUtc; + if (expiresUtc.HasValue) + { + options.SetAbsoluteExpiration(expiresUtc.Value); + } + else + { + options.SetSlidingExpiration(TimeSpan.FromMinutes(15)); + } - public Task RetrieveAsync(string key) - { - _cache.TryGetValue(key, out AuthenticationTicket ticket); - return Task.FromResult(ticket); - } + _cache.Set(key, ticket, options); - public Task RemoveAsync(string key) - { - _cache.Remove(key); - return Task.FromResult(0); + return Task.FromResult(0); + } + + public Task RetrieveAsync(string key) + { + _cache.TryGetValue(key, out AuthenticationTicket ticket); + return Task.FromResult(ticket); + } + + public Task RemoveAsync(string key) + { + _cache.Remove(key); + return Task.FromResult(0); + } } } diff --git a/src/Core/IdentityServer/OidcIdentityClient.cs b/src/Core/IdentityServer/OidcIdentityClient.cs index 822ac56cd4..7f24f66e2f 100644 --- a/src/Core/IdentityServer/OidcIdentityClient.cs +++ b/src/Core/IdentityServer/OidcIdentityClient.cs @@ -2,24 +2,25 @@ using IdentityServer4; using IdentityServer4.Models; -namespace Bit.Core.IdentityServer; - -public class OidcIdentityClient : Client +namespace Bit.Core.IdentityServer { - public OidcIdentityClient(GlobalSettings globalSettings) + public class OidcIdentityClient : Client { - ClientId = "oidc-identity"; - RequireClientSecret = true; - RequirePkce = true; - ClientSecrets = new List { new Secret(globalSettings.OidcIdentityClientKey.Sha256()) }; - AllowedScopes = new string[] + public OidcIdentityClient(GlobalSettings globalSettings) { - IdentityServerConstants.StandardScopes.OpenId, - IdentityServerConstants.StandardScopes.Profile - }; - AllowedGrantTypes = GrantTypes.Code; - Enabled = true; - RedirectUris = new List { $"{globalSettings.BaseServiceUri.Identity}/signin-oidc" }; - RequireConsent = false; + ClientId = "oidc-identity"; + RequireClientSecret = true; + RequirePkce = true; + ClientSecrets = new List { new Secret(globalSettings.OidcIdentityClientKey.Sha256()) }; + AllowedScopes = new string[] + { + IdentityServerConstants.StandardScopes.OpenId, + IdentityServerConstants.StandardScopes.Profile + }; + AllowedGrantTypes = GrantTypes.Code; + Enabled = true; + RedirectUris = new List { $"{globalSettings.BaseServiceUri.Identity}/signin-oidc" }; + RequireConsent = false; + } } } diff --git a/src/Core/IdentityServer/PersistedGrantStore.cs b/src/Core/IdentityServer/PersistedGrantStore.cs index a1b3294ba0..7094265e7d 100644 --- a/src/Core/IdentityServer/PersistedGrantStore.cs +++ b/src/Core/IdentityServer/PersistedGrantStore.cs @@ -3,85 +3,86 @@ using IdentityServer4.Models; using IdentityServer4.Stores; using Grant = Bit.Core.Entities.Grant; -namespace Bit.Core.IdentityServer; - -public class PersistedGrantStore : IPersistedGrantStore +namespace Bit.Core.IdentityServer { - private readonly IGrantRepository _grantRepository; - - public PersistedGrantStore( - IGrantRepository grantRepository) + public class PersistedGrantStore : IPersistedGrantStore { - _grantRepository = grantRepository; - } + private readonly IGrantRepository _grantRepository; - public async Task GetAsync(string key) - { - var grant = await _grantRepository.GetByKeyAsync(key); - if (grant == null) + public PersistedGrantStore( + IGrantRepository grantRepository) { - return null; + _grantRepository = grantRepository; } - var pGrant = ToPersistedGrant(grant); - return pGrant; - } - - public async Task> GetAllAsync(PersistedGrantFilter filter) - { - var grants = await _grantRepository.GetManyAsync(filter.SubjectId, filter.SessionId, - filter.ClientId, filter.Type); - var pGrants = grants.Select(g => ToPersistedGrant(g)); - return pGrants; - } - - public async Task RemoveAllAsync(PersistedGrantFilter filter) - { - await _grantRepository.DeleteManyAsync(filter.SubjectId, filter.SessionId, filter.ClientId, filter.Type); - } - - public async Task RemoveAsync(string key) - { - await _grantRepository.DeleteByKeyAsync(key); - } - - public async Task StoreAsync(PersistedGrant pGrant) - { - var grant = ToGrant(pGrant); - await _grantRepository.SaveAsync(grant); - } - - private Grant ToGrant(PersistedGrant pGrant) - { - return new Grant + public async Task GetAsync(string key) { - Key = pGrant.Key, - Type = pGrant.Type, - SubjectId = pGrant.SubjectId, - SessionId = pGrant.SessionId, - ClientId = pGrant.ClientId, - Description = pGrant.Description, - CreationDate = pGrant.CreationTime, - ExpirationDate = pGrant.Expiration, - ConsumedDate = pGrant.ConsumedTime, - Data = pGrant.Data - }; - } + var grant = await _grantRepository.GetByKeyAsync(key); + if (grant == null) + { + return null; + } - private PersistedGrant ToPersistedGrant(Grant grant) - { - return new PersistedGrant + var pGrant = ToPersistedGrant(grant); + return pGrant; + } + + public async Task> GetAllAsync(PersistedGrantFilter filter) { - Key = grant.Key, - Type = grant.Type, - SubjectId = grant.SubjectId, - SessionId = grant.SessionId, - ClientId = grant.ClientId, - Description = grant.Description, - CreationTime = grant.CreationDate, - Expiration = grant.ExpirationDate, - ConsumedTime = grant.ConsumedDate, - Data = grant.Data - }; + var grants = await _grantRepository.GetManyAsync(filter.SubjectId, filter.SessionId, + filter.ClientId, filter.Type); + var pGrants = grants.Select(g => ToPersistedGrant(g)); + return pGrants; + } + + public async Task RemoveAllAsync(PersistedGrantFilter filter) + { + await _grantRepository.DeleteManyAsync(filter.SubjectId, filter.SessionId, filter.ClientId, filter.Type); + } + + public async Task RemoveAsync(string key) + { + await _grantRepository.DeleteByKeyAsync(key); + } + + public async Task StoreAsync(PersistedGrant pGrant) + { + var grant = ToGrant(pGrant); + await _grantRepository.SaveAsync(grant); + } + + private Grant ToGrant(PersistedGrant pGrant) + { + return new Grant + { + Key = pGrant.Key, + Type = pGrant.Type, + SubjectId = pGrant.SubjectId, + SessionId = pGrant.SessionId, + ClientId = pGrant.ClientId, + Description = pGrant.Description, + CreationDate = pGrant.CreationTime, + ExpirationDate = pGrant.Expiration, + ConsumedDate = pGrant.ConsumedTime, + Data = pGrant.Data + }; + } + + private PersistedGrant ToPersistedGrant(Grant grant) + { + return new PersistedGrant + { + Key = grant.Key, + Type = grant.Type, + SubjectId = grant.SubjectId, + SessionId = grant.SessionId, + ClientId = grant.ClientId, + Description = grant.Description, + CreationTime = grant.CreationDate, + Expiration = grant.ExpirationDate, + ConsumedTime = grant.ConsumedDate, + Data = grant.Data + }; + } } } diff --git a/src/Core/IdentityServer/ProfileService.cs b/src/Core/IdentityServer/ProfileService.cs index 873ad6b5ab..aa79c60d13 100644 --- a/src/Core/IdentityServer/ProfileService.cs +++ b/src/Core/IdentityServer/ProfileService.cs @@ -6,82 +6,83 @@ using Bit.Core.Utilities; using IdentityServer4.Models; using IdentityServer4.Services; -namespace Bit.Core.IdentityServer; - -public class ProfileService : IProfileService +namespace Bit.Core.IdentityServer { - private readonly IUserService _userService; - private readonly IOrganizationUserRepository _organizationUserRepository; - private readonly IProviderUserRepository _providerUserRepository; - private readonly IProviderOrganizationRepository _providerOrganizationRepository; - private readonly ILicensingService _licensingService; - private readonly ICurrentContext _currentContext; - - public ProfileService( - IUserService userService, - IOrganizationUserRepository organizationUserRepository, - IProviderUserRepository providerUserRepository, - IProviderOrganizationRepository providerOrganizationRepository, - ILicensingService licensingService, - ICurrentContext currentContext) + public class ProfileService : IProfileService { - _userService = userService; - _organizationUserRepository = organizationUserRepository; - _providerUserRepository = providerUserRepository; - _providerOrganizationRepository = providerOrganizationRepository; - _licensingService = licensingService; - _currentContext = currentContext; - } + private readonly IUserService _userService; + private readonly IOrganizationUserRepository _organizationUserRepository; + private readonly IProviderUserRepository _providerUserRepository; + private readonly IProviderOrganizationRepository _providerOrganizationRepository; + private readonly ILicensingService _licensingService; + private readonly ICurrentContext _currentContext; - public async Task GetProfileDataAsync(ProfileDataRequestContext context) - { - var existingClaims = context.Subject.Claims; - var newClaims = new List(); - - var user = await _userService.GetUserByPrincipalAsync(context.Subject); - if (user != null) + public ProfileService( + IUserService userService, + IOrganizationUserRepository organizationUserRepository, + IProviderUserRepository providerUserRepository, + IProviderOrganizationRepository providerOrganizationRepository, + ILicensingService licensingService, + ICurrentContext currentContext) { - var isPremium = await _licensingService.ValidateUserPremiumAsync(user); - var orgs = await _currentContext.OrganizationMembershipAsync(_organizationUserRepository, user.Id); - var providers = await _currentContext.ProviderMembershipAsync(_providerUserRepository, user.Id); - foreach (var claim in CoreHelpers.BuildIdentityClaims(user, orgs, providers, isPremium)) + _userService = userService; + _organizationUserRepository = organizationUserRepository; + _providerUserRepository = providerUserRepository; + _providerOrganizationRepository = providerOrganizationRepository; + _licensingService = licensingService; + _currentContext = currentContext; + } + + public async Task GetProfileDataAsync(ProfileDataRequestContext context) + { + var existingClaims = context.Subject.Claims; + var newClaims = new List(); + + var user = await _userService.GetUserByPrincipalAsync(context.Subject); + if (user != null) { - var upperValue = claim.Value.ToUpperInvariant(); - var isBool = upperValue == "TRUE" || upperValue == "FALSE"; - newClaims.Add(isBool ? - new Claim(claim.Key, claim.Value, ClaimValueTypes.Boolean) : - new Claim(claim.Key, claim.Value) - ); + var isPremium = await _licensingService.ValidateUserPremiumAsync(user); + var orgs = await _currentContext.OrganizationMembershipAsync(_organizationUserRepository, user.Id); + var providers = await _currentContext.ProviderMembershipAsync(_providerUserRepository, user.Id); + foreach (var claim in CoreHelpers.BuildIdentityClaims(user, orgs, providers, isPremium)) + { + var upperValue = claim.Value.ToUpperInvariant(); + var isBool = upperValue == "TRUE" || upperValue == "FALSE"; + newClaims.Add(isBool ? + new Claim(claim.Key, claim.Value, ClaimValueTypes.Boolean) : + new Claim(claim.Key, claim.Value) + ); + } + } + + // filter out any of the new claims + var existingClaimsToKeep = existingClaims + .Where(c => !c.Type.StartsWith("org") && + (newClaims.Count == 0 || !newClaims.Any(nc => nc.Type == c.Type))) + .ToList(); + + newClaims.AddRange(existingClaimsToKeep); + if (newClaims.Any()) + { + context.IssuedClaims.AddRange(newClaims); } } - // filter out any of the new claims - var existingClaimsToKeep = existingClaims - .Where(c => !c.Type.StartsWith("org") && - (newClaims.Count == 0 || !newClaims.Any(nc => nc.Type == c.Type))) - .ToList(); - - newClaims.AddRange(existingClaimsToKeep); - if (newClaims.Any()) + public async Task IsActiveAsync(IsActiveContext context) { - context.IssuedClaims.AddRange(newClaims); - } - } + var securityTokenClaim = context.Subject?.Claims.FirstOrDefault(c => c.Type == "sstamp"); + var user = await _userService.GetUserByPrincipalAsync(context.Subject); - public async Task IsActiveAsync(IsActiveContext context) - { - var securityTokenClaim = context.Subject?.Claims.FirstOrDefault(c => c.Type == "sstamp"); - var user = await _userService.GetUserByPrincipalAsync(context.Subject); - - if (user != null && securityTokenClaim != null) - { - context.IsActive = string.Equals(user.SecurityStamp, securityTokenClaim.Value, - StringComparison.InvariantCultureIgnoreCase); - return; - } - else - { - context.IsActive = true; + if (user != null && securityTokenClaim != null) + { + context.IsActive = string.Equals(user.SecurityStamp, securityTokenClaim.Value, + StringComparison.InvariantCultureIgnoreCase); + return; + } + else + { + context.IsActive = true; + } } } } diff --git a/src/Core/IdentityServer/RedisCacheTicketStore.cs b/src/Core/IdentityServer/RedisCacheTicketStore.cs index 139158c329..f7aa8c0a97 100644 --- a/src/Core/IdentityServer/RedisCacheTicketStore.cs +++ b/src/Core/IdentityServer/RedisCacheTicketStore.cs @@ -3,62 +3,63 @@ using Microsoft.AspNetCore.Authentication.Cookies; using Microsoft.Extensions.Caching.Distributed; using Microsoft.Extensions.Caching.StackExchangeRedis; -namespace Bit.Core.IdentityServer; - -public class RedisCacheTicketStore : ITicketStore +namespace Bit.Core.IdentityServer { - private const string _keyPrefix = "auth-"; - private readonly IDistributedCache _cache; - - public RedisCacheTicketStore(RedisCacheOptions options) + public class RedisCacheTicketStore : ITicketStore { - _cache = new RedisCache(options); - } + private const string _keyPrefix = "auth-"; + private readonly IDistributedCache _cache; - public async Task StoreAsync(AuthenticationTicket ticket) - { - var key = $"{_keyPrefix}{Guid.NewGuid()}"; - await RenewAsync(key, ticket); + public RedisCacheTicketStore(RedisCacheOptions options) + { + _cache = new RedisCache(options); + } - return key; - } + public async Task StoreAsync(AuthenticationTicket ticket) + { + var key = $"{_keyPrefix}{Guid.NewGuid()}"; + await RenewAsync(key, ticket); - public Task RenewAsync(string key, AuthenticationTicket ticket) - { - var options = new DistributedCacheEntryOptions(); - var expiresUtc = ticket.Properties.ExpiresUtc ?? - DateTimeOffset.UtcNow.AddMinutes(15); - options.SetAbsoluteExpiration(expiresUtc); + return key; + } - var val = SerializeToBytes(ticket); - _cache.Set(key, val, options); + public Task RenewAsync(string key, AuthenticationTicket ticket) + { + var options = new DistributedCacheEntryOptions(); + var expiresUtc = ticket.Properties.ExpiresUtc ?? + DateTimeOffset.UtcNow.AddMinutes(15); + options.SetAbsoluteExpiration(expiresUtc); - return Task.FromResult(0); - } + var val = SerializeToBytes(ticket); + _cache.Set(key, val, options); - public Task RetrieveAsync(string key) - { - AuthenticationTicket ticket; - var bytes = _cache.Get(key); - ticket = DeserializeFromBytes(bytes); + return Task.FromResult(0); + } - return Task.FromResult(ticket); - } + public Task RetrieveAsync(string key) + { + AuthenticationTicket ticket; + var bytes = _cache.Get(key); + ticket = DeserializeFromBytes(bytes); - public Task RemoveAsync(string key) - { - _cache.Remove(key); + return Task.FromResult(ticket); + } - return Task.FromResult(0); - } + public Task RemoveAsync(string key) + { + _cache.Remove(key); - private static byte[] SerializeToBytes(AuthenticationTicket source) - { - return TicketSerializer.Default.Serialize(source); - } + return Task.FromResult(0); + } - private static AuthenticationTicket DeserializeFromBytes(byte[] source) - { - return source == null ? null : TicketSerializer.Default.Deserialize(source); + private static byte[] SerializeToBytes(AuthenticationTicket source) + { + return TicketSerializer.Default.Serialize(source); + } + + private static AuthenticationTicket DeserializeFromBytes(byte[] source) + { + return source == null ? null : TicketSerializer.Default.Deserialize(source); + } } } diff --git a/src/Core/IdentityServer/ResourceOwnerPasswordValidator.cs b/src/Core/IdentityServer/ResourceOwnerPasswordValidator.cs index 82b3cf50a8..f831431416 100644 --- a/src/Core/IdentityServer/ResourceOwnerPasswordValidator.cs +++ b/src/Core/IdentityServer/ResourceOwnerPasswordValidator.cs @@ -11,162 +11,163 @@ using IdentityServer4.Validation; using Microsoft.AspNetCore.Identity; using Microsoft.Extensions.Logging; -namespace Bit.Core.IdentityServer; - -public class ResourceOwnerPasswordValidator : BaseRequestValidator, - IResourceOwnerPasswordValidator +namespace Bit.Core.IdentityServer { - private UserManager _userManager; - private readonly IUserService _userService; - private readonly ICurrentContext _currentContext; - private readonly ICaptchaValidationService _captchaValidationService; - public ResourceOwnerPasswordValidator( - UserManager userManager, - IDeviceRepository deviceRepository, - IDeviceService deviceService, - IUserService userService, - IEventService eventService, - IOrganizationDuoWebTokenProvider organizationDuoWebTokenProvider, - IOrganizationRepository organizationRepository, - IOrganizationUserRepository organizationUserRepository, - IApplicationCacheService applicationCacheService, - IMailService mailService, - ILogger logger, - ICurrentContext currentContext, - GlobalSettings globalSettings, - IPolicyRepository policyRepository, - ICaptchaValidationService captchaValidationService, - IUserRepository userRepository) - : base(userManager, deviceRepository, deviceService, userService, eventService, - organizationDuoWebTokenProvider, organizationRepository, organizationUserRepository, - applicationCacheService, mailService, logger, currentContext, globalSettings, policyRepository, - userRepository, captchaValidationService) + public class ResourceOwnerPasswordValidator : BaseRequestValidator, + IResourceOwnerPasswordValidator { - _userManager = userManager; - _userService = userService; - _currentContext = currentContext; - _captchaValidationService = captchaValidationService; - } - - public async Task ValidateAsync(ResourceOwnerPasswordValidationContext context) - { - if (!AuthEmailHeaderIsValid(context)) + private UserManager _userManager; + private readonly IUserService _userService; + private readonly ICurrentContext _currentContext; + private readonly ICaptchaValidationService _captchaValidationService; + public ResourceOwnerPasswordValidator( + UserManager userManager, + IDeviceRepository deviceRepository, + IDeviceService deviceService, + IUserService userService, + IEventService eventService, + IOrganizationDuoWebTokenProvider organizationDuoWebTokenProvider, + IOrganizationRepository organizationRepository, + IOrganizationUserRepository organizationUserRepository, + IApplicationCacheService applicationCacheService, + IMailService mailService, + ILogger logger, + ICurrentContext currentContext, + GlobalSettings globalSettings, + IPolicyRepository policyRepository, + ICaptchaValidationService captchaValidationService, + IUserRepository userRepository) + : base(userManager, deviceRepository, deviceService, userService, eventService, + organizationDuoWebTokenProvider, organizationRepository, organizationUserRepository, + applicationCacheService, mailService, logger, currentContext, globalSettings, policyRepository, + userRepository, captchaValidationService) { - context.Result = new GrantValidationResult(TokenRequestErrors.InvalidGrant, - "Auth-Email header invalid."); - return; + _userManager = userManager; + _userService = userService; + _currentContext = currentContext; + _captchaValidationService = captchaValidationService; } - var user = await _userManager.FindByEmailAsync(context.UserName.ToLowerInvariant()); - var validatorContext = new CustomValidatorRequestContext + public async Task ValidateAsync(ResourceOwnerPasswordValidationContext context) { - User = user, - KnownDevice = await KnownDeviceAsync(user, context.Request) - }; - string bypassToken = null; - if (!validatorContext.KnownDevice && - _captchaValidationService.RequireCaptchaValidation(_currentContext, user)) - { - var captchaResponse = context.Request.Raw["captchaResponse"]?.ToString(); - - if (string.IsNullOrWhiteSpace(captchaResponse)) + if (!AuthEmailHeaderIsValid(context)) { - context.Result = new GrantValidationResult(TokenRequestErrors.InvalidGrant, "Captcha required.", - new Dictionary - { - { _captchaValidationService.SiteKeyResponseKeyName, _captchaValidationService.SiteKey }, - }); + context.Result = new GrantValidationResult(TokenRequestErrors.InvalidGrant, + "Auth-Email header invalid."); return; } - validatorContext.CaptchaResponse = await _captchaValidationService.ValidateCaptchaResponseAsync( - captchaResponse, _currentContext.IpAddress, user); - if (!validatorContext.CaptchaResponse.Success) + var user = await _userManager.FindByEmailAsync(context.UserName.ToLowerInvariant()); + var validatorContext = new CustomValidatorRequestContext { - await BuildErrorResultAsync("Captcha is invalid. Please refresh and try again", false, context, null); - return; - } - bypassToken = _captchaValidationService.GenerateCaptchaBypassToken(user); - } - - await ValidateAsync(context, context.Request, validatorContext); - if (context.Result.CustomResponse != null && bypassToken != null) - { - context.Result.CustomResponse["CaptchaBypassToken"] = bypassToken; - } - } - - protected async override Task ValidateContextAsync(ResourceOwnerPasswordValidationContext context, - CustomValidatorRequestContext validatorContext) - { - if (string.IsNullOrWhiteSpace(context.UserName) || validatorContext.User == null) - { - return false; - } - - if (!await _userService.CheckPasswordAsync(validatorContext.User, context.Password)) - { - return false; - } - - return true; - } - - protected override Task SetSuccessResult(ResourceOwnerPasswordValidationContext context, User user, - List claims, Dictionary customResponse) - { - context.Result = new GrantValidationResult(user.Id.ToString(), "Application", - identityProvider: "bitwarden", - claims: claims.Count > 0 ? claims : null, - customResponse: customResponse); - return Task.CompletedTask; - } - - protected override void SetTwoFactorResult(ResourceOwnerPasswordValidationContext context, - Dictionary customResponse) - { - context.Result = new GrantValidationResult(TokenRequestErrors.InvalidGrant, "Two factor required.", - customResponse); - } - - protected override void SetSsoResult(ResourceOwnerPasswordValidationContext context, - Dictionary customResponse) - { - context.Result = new GrantValidationResult(TokenRequestErrors.InvalidGrant, "Sso authentication required.", - customResponse); - } - - protected override void SetErrorResult(ResourceOwnerPasswordValidationContext context, - Dictionary customResponse) - { - context.Result = new GrantValidationResult(TokenRequestErrors.InvalidGrant, customResponse: customResponse); - } - - private bool AuthEmailHeaderIsValid(ResourceOwnerPasswordValidationContext context) - { - if (!_currentContext.HttpContext.Request.Headers.ContainsKey("Auth-Email")) - { - return false; - } - else - { - try + User = user, + KnownDevice = await KnownDeviceAsync(user, context.Request) + }; + string bypassToken = null; + if (!validatorContext.KnownDevice && + _captchaValidationService.RequireCaptchaValidation(_currentContext, user)) { - var authEmailHeader = _currentContext.HttpContext.Request.Headers["Auth-Email"]; - var authEmailDecoded = CoreHelpers.Base64UrlDecodeString(authEmailHeader); + var captchaResponse = context.Request.Raw["captchaResponse"]?.ToString(); - if (authEmailDecoded != context.UserName) + if (string.IsNullOrWhiteSpace(captchaResponse)) { + context.Result = new GrantValidationResult(TokenRequestErrors.InvalidGrant, "Captcha required.", + new Dictionary + { + { _captchaValidationService.SiteKeyResponseKeyName, _captchaValidationService.SiteKey }, + }); + return; + } + + validatorContext.CaptchaResponse = await _captchaValidationService.ValidateCaptchaResponseAsync( + captchaResponse, _currentContext.IpAddress, user); + if (!validatorContext.CaptchaResponse.Success) + { + await BuildErrorResultAsync("Captcha is invalid. Please refresh and try again", false, context, null); + return; + } + bypassToken = _captchaValidationService.GenerateCaptchaBypassToken(user); + } + + await ValidateAsync(context, context.Request, validatorContext); + if (context.Result.CustomResponse != null && bypassToken != null) + { + context.Result.CustomResponse["CaptchaBypassToken"] = bypassToken; + } + } + + protected async override Task ValidateContextAsync(ResourceOwnerPasswordValidationContext context, + CustomValidatorRequestContext validatorContext) + { + if (string.IsNullOrWhiteSpace(context.UserName) || validatorContext.User == null) + { + return false; + } + + if (!await _userService.CheckPasswordAsync(validatorContext.User, context.Password)) + { + return false; + } + + return true; + } + + protected override Task SetSuccessResult(ResourceOwnerPasswordValidationContext context, User user, + List claims, Dictionary customResponse) + { + context.Result = new GrantValidationResult(user.Id.ToString(), "Application", + identityProvider: "bitwarden", + claims: claims.Count > 0 ? claims : null, + customResponse: customResponse); + return Task.CompletedTask; + } + + protected override void SetTwoFactorResult(ResourceOwnerPasswordValidationContext context, + Dictionary customResponse) + { + context.Result = new GrantValidationResult(TokenRequestErrors.InvalidGrant, "Two factor required.", + customResponse); + } + + protected override void SetSsoResult(ResourceOwnerPasswordValidationContext context, + Dictionary customResponse) + { + context.Result = new GrantValidationResult(TokenRequestErrors.InvalidGrant, "Sso authentication required.", + customResponse); + } + + protected override void SetErrorResult(ResourceOwnerPasswordValidationContext context, + Dictionary customResponse) + { + context.Result = new GrantValidationResult(TokenRequestErrors.InvalidGrant, customResponse: customResponse); + } + + private bool AuthEmailHeaderIsValid(ResourceOwnerPasswordValidationContext context) + { + if (!_currentContext.HttpContext.Request.Headers.ContainsKey("Auth-Email")) + { + return false; + } + else + { + try + { + var authEmailHeader = _currentContext.HttpContext.Request.Headers["Auth-Email"]; + var authEmailDecoded = CoreHelpers.Base64UrlDecodeString(authEmailHeader); + + if (authEmailDecoded != context.UserName) + { + return false; + } + } + catch (System.Exception e) when (e is System.InvalidOperationException || e is System.FormatException) + { + // Invalid B64 encoding return false; } } - catch (System.Exception e) when (e is System.InvalidOperationException || e is System.FormatException) - { - // Invalid B64 encoding - return false; - } - } - return true; + return true; + } } } diff --git a/src/Core/IdentityServer/StaticClientStore.cs b/src/Core/IdentityServer/StaticClientStore.cs index 92c124f26d..60bff26e70 100644 --- a/src/Core/IdentityServer/StaticClientStore.cs +++ b/src/Core/IdentityServer/StaticClientStore.cs @@ -2,22 +2,23 @@ using Bit.Core.Settings; using IdentityServer4.Models; -namespace Bit.Core.IdentityServer; - -public class StaticClientStore +namespace Bit.Core.IdentityServer { - public StaticClientStore(GlobalSettings globalSettings) + public class StaticClientStore { - ApiClients = new List + public StaticClientStore(GlobalSettings globalSettings) { - new ApiClient(globalSettings, BitwardenClient.Mobile, 90, 1), - new ApiClient(globalSettings, BitwardenClient.Web, 30, 1), - new ApiClient(globalSettings, BitwardenClient.Browser, 30, 1), - new ApiClient(globalSettings, BitwardenClient.Desktop, 30, 1), - new ApiClient(globalSettings, BitwardenClient.Cli, 30, 1), - new ApiClient(globalSettings, BitwardenClient.DirectoryConnector, 30, 24) - }.ToDictionary(c => c.ClientId); - } + ApiClients = new List + { + new ApiClient(globalSettings, BitwardenClient.Mobile, 90, 1), + new ApiClient(globalSettings, BitwardenClient.Web, 30, 1), + new ApiClient(globalSettings, BitwardenClient.Browser, 30, 1), + new ApiClient(globalSettings, BitwardenClient.Desktop, 30, 1), + new ApiClient(globalSettings, BitwardenClient.Cli, 30, 1), + new ApiClient(globalSettings, BitwardenClient.DirectoryConnector, 30, 24) + }.ToDictionary(c => c.ClientId); + } - public IDictionary ApiClients { get; private set; } + public IDictionary ApiClients { get; private set; } + } } diff --git a/src/Core/IdentityServer/TokenRetrieval.cs b/src/Core/IdentityServer/TokenRetrieval.cs index 8c8ecfbc45..7290576f02 100644 --- a/src/Core/IdentityServer/TokenRetrieval.cs +++ b/src/Core/IdentityServer/TokenRetrieval.cs @@ -1,29 +1,30 @@ using Microsoft.AspNetCore.Http; -namespace Bit.Core.IdentityServer; - -public static class TokenRetrieval +namespace Bit.Core.IdentityServer { - private static string _headerScheme = "Bearer "; - private static string _queuryScheme = "access_token"; - private static string _authHeader = "Authorization"; - - public static Func FromAuthorizationHeaderOrQueryString() + public static class TokenRetrieval { - return (request) => + private static string _headerScheme = "Bearer "; + private static string _queuryScheme = "access_token"; + private static string _authHeader = "Authorization"; + + public static Func FromAuthorizationHeaderOrQueryString() { - var authorization = request.Headers[_authHeader].FirstOrDefault(); - if (string.IsNullOrWhiteSpace(authorization)) + return (request) => { - return request.Query[_queuryScheme].FirstOrDefault(); - } + var authorization = request.Headers[_authHeader].FirstOrDefault(); + if (string.IsNullOrWhiteSpace(authorization)) + { + return request.Query[_queuryScheme].FirstOrDefault(); + } - if (authorization.StartsWith(_headerScheme, StringComparison.OrdinalIgnoreCase)) - { - return authorization.Substring(_headerScheme.Length).Trim(); - } + if (authorization.StartsWith(_headerScheme, StringComparison.OrdinalIgnoreCase)) + { + return authorization.Substring(_headerScheme.Length).Trim(); + } - return null; - }; + return null; + }; + } } } diff --git a/src/Core/IdentityServer/VaultCorsPolicyService.cs b/src/Core/IdentityServer/VaultCorsPolicyService.cs index 49abcb4aad..42b76135e0 100644 --- a/src/Core/IdentityServer/VaultCorsPolicyService.cs +++ b/src/Core/IdentityServer/VaultCorsPolicyService.cs @@ -2,19 +2,20 @@ using Bit.Core.Utilities; using IdentityServer4.Services; -namespace Bit.Core.IdentityServer; - -public class CustomCorsPolicyService : ICorsPolicyService +namespace Bit.Core.IdentityServer { - private readonly GlobalSettings _globalSettings; - - public CustomCorsPolicyService(GlobalSettings globalSettings) + public class CustomCorsPolicyService : ICorsPolicyService { - _globalSettings = globalSettings; - } + private readonly GlobalSettings _globalSettings; - public Task IsOriginAllowedAsync(string origin) - { - return Task.FromResult(CoreHelpers.IsCorsOriginAllowed(origin, _globalSettings)); + public CustomCorsPolicyService(GlobalSettings globalSettings) + { + _globalSettings = globalSettings; + } + + public Task IsOriginAllowedAsync(string origin) + { + return Task.FromResult(CoreHelpers.IsCorsOriginAllowed(origin, _globalSettings)); + } } } diff --git a/src/Core/Jobs/BaseJob.cs b/src/Core/Jobs/BaseJob.cs index 56c39014a7..c1eb7d264e 100644 --- a/src/Core/Jobs/BaseJob.cs +++ b/src/Core/Jobs/BaseJob.cs @@ -1,28 +1,29 @@ using Microsoft.Extensions.Logging; using Quartz; -namespace Bit.Core.Jobs; - -public abstract class BaseJob : IJob +namespace Bit.Core.Jobs { - protected readonly ILogger _logger; - - public BaseJob(ILogger logger) + public abstract class BaseJob : IJob { - _logger = logger; - } + protected readonly ILogger _logger; - public async Task Execute(IJobExecutionContext context) - { - try + public BaseJob(ILogger logger) { - await ExecuteJobAsync(context); + _logger = logger; } - catch (Exception e) - { - _logger.LogError(2, e, "Error performing {0}.", GetType().Name); - } - } - protected abstract Task ExecuteJobAsync(IJobExecutionContext context); + public async Task Execute(IJobExecutionContext context) + { + try + { + await ExecuteJobAsync(context); + } + catch (Exception e) + { + _logger.LogError(2, e, "Error performing {0}.", GetType().Name); + } + } + + protected abstract Task ExecuteJobAsync(IJobExecutionContext context); + } } diff --git a/src/Core/Jobs/BaseJobsHostedService.cs b/src/Core/Jobs/BaseJobsHostedService.cs index 897a382a2b..c9d2bda1c1 100644 --- a/src/Core/Jobs/BaseJobsHostedService.cs +++ b/src/Core/Jobs/BaseJobsHostedService.cs @@ -6,145 +6,146 @@ using Quartz; using Quartz.Impl; using Quartz.Impl.Matchers; -namespace Bit.Core.Jobs; - -public abstract class BaseJobsHostedService : IHostedService, IDisposable +namespace Bit.Core.Jobs { - private const int MaximumJobRetries = 10; - - private readonly IServiceProvider _serviceProvider; - private readonly ILogger _listenerLogger; - protected readonly ILogger _logger; - - private IScheduler _scheduler; - protected GlobalSettings _globalSettings; - - public BaseJobsHostedService( - GlobalSettings globalSettings, - IServiceProvider serviceProvider, - ILogger logger, - ILogger listenerLogger) + public abstract class BaseJobsHostedService : IHostedService, IDisposable { - _serviceProvider = serviceProvider; - _logger = logger; - _listenerLogger = listenerLogger; - _globalSettings = globalSettings; - } + private const int MaximumJobRetries = 10; - public IEnumerable> Jobs { get; protected set; } + private readonly IServiceProvider _serviceProvider; + private readonly ILogger _listenerLogger; + protected readonly ILogger _logger; - public virtual async Task StartAsync(CancellationToken cancellationToken) - { - var props = new NameValueCollection + private IScheduler _scheduler; + protected GlobalSettings _globalSettings; + + public BaseJobsHostedService( + GlobalSettings globalSettings, + IServiceProvider serviceProvider, + ILogger logger, + ILogger listenerLogger) { - {"quartz.serializer.type", "binary"}, - }; - - if (!string.IsNullOrEmpty(_globalSettings.SqlServer.JobSchedulerConnectionString)) - { - // Ensure each project has a unique instanceName - props.Add("quartz.scheduler.instanceName", GetType().FullName); - props.Add("quartz.scheduler.instanceId", "AUTO"); - props.Add("quartz.jobStore.type", "Quartz.Impl.AdoJobStore.JobStoreTX"); - props.Add("quartz.jobStore.driverDelegateType", "Quartz.Impl.AdoJobStore.SqlServerDelegate"); - props.Add("quartz.jobStore.useProperties", "true"); - props.Add("quartz.jobStore.dataSource", "default"); - props.Add("quartz.jobStore.tablePrefix", "QRTZ_"); - props.Add("quartz.jobStore.clustered", "true"); - props.Add("quartz.dataSource.default.provider", "SqlServer"); - props.Add("quartz.dataSource.default.connectionString", _globalSettings.SqlServer.JobSchedulerConnectionString); + _serviceProvider = serviceProvider; + _logger = logger; + _listenerLogger = listenerLogger; + _globalSettings = globalSettings; } - var factory = new StdSchedulerFactory(props); - _scheduler = await factory.GetScheduler(cancellationToken); - _scheduler.JobFactory = new JobFactory(_serviceProvider); - _scheduler.ListenerManager.AddJobListener(new JobListener(_listenerLogger), - GroupMatcher.AnyGroup()); - await _scheduler.Start(cancellationToken); - if (Jobs != null) + public IEnumerable> Jobs { get; protected set; } + + public virtual async Task StartAsync(CancellationToken cancellationToken) { - foreach (var (job, trigger) in Jobs) + var props = new NameValueCollection { - for (var retry = 0; retry < MaximumJobRetries; retry++) + {"quartz.serializer.type", "binary"}, + }; + + if (!string.IsNullOrEmpty(_globalSettings.SqlServer.JobSchedulerConnectionString)) + { + // Ensure each project has a unique instanceName + props.Add("quartz.scheduler.instanceName", GetType().FullName); + props.Add("quartz.scheduler.instanceId", "AUTO"); + props.Add("quartz.jobStore.type", "Quartz.Impl.AdoJobStore.JobStoreTX"); + props.Add("quartz.jobStore.driverDelegateType", "Quartz.Impl.AdoJobStore.SqlServerDelegate"); + props.Add("quartz.jobStore.useProperties", "true"); + props.Add("quartz.jobStore.dataSource", "default"); + props.Add("quartz.jobStore.tablePrefix", "QRTZ_"); + props.Add("quartz.jobStore.clustered", "true"); + props.Add("quartz.dataSource.default.provider", "SqlServer"); + props.Add("quartz.dataSource.default.connectionString", _globalSettings.SqlServer.JobSchedulerConnectionString); + } + + var factory = new StdSchedulerFactory(props); + _scheduler = await factory.GetScheduler(cancellationToken); + _scheduler.JobFactory = new JobFactory(_serviceProvider); + _scheduler.ListenerManager.AddJobListener(new JobListener(_listenerLogger), + GroupMatcher.AnyGroup()); + await _scheduler.Start(cancellationToken); + if (Jobs != null) + { + foreach (var (job, trigger) in Jobs) { - // There's a race condition when starting multiple containers simultaneously, retry until it succeeds.. - try + for (var retry = 0; retry < MaximumJobRetries; retry++) { - var dupeT = await _scheduler.GetTrigger(trigger.Key); - if (dupeT != null) + // There's a race condition when starting multiple containers simultaneously, retry until it succeeds.. + try { - await _scheduler.RescheduleJob(trigger.Key, trigger); + var dupeT = await _scheduler.GetTrigger(trigger.Key); + if (dupeT != null) + { + await _scheduler.RescheduleJob(trigger.Key, trigger); + } + + var jobDetail = JobBuilder.Create(job) + .WithIdentity(job.FullName) + .Build(); + + var dupeJ = await _scheduler.GetJobDetail(jobDetail.Key); + if (dupeJ != null) + { + await _scheduler.DeleteJob(jobDetail.Key); + } + + await _scheduler.ScheduleJob(jobDetail, trigger); + break; } - - var jobDetail = JobBuilder.Create(job) - .WithIdentity(job.FullName) - .Build(); - - var dupeJ = await _scheduler.GetJobDetail(jobDetail.Key); - if (dupeJ != null) + catch (Exception e) { - await _scheduler.DeleteJob(jobDetail.Key); - } + if (retry == MaximumJobRetries - 1) + { + throw new Exception("Job failed to start after 10 retries."); + } - await _scheduler.ScheduleJob(jobDetail, trigger); - break; - } - catch (Exception e) - { - if (retry == MaximumJobRetries - 1) - { - throw new Exception("Job failed to start after 10 retries."); + _logger.LogWarning($"Exception while trying to schedule job: {job.FullName}, {e}"); + var random = new Random(); + Thread.Sleep(random.Next(50, 250)); } - - _logger.LogWarning($"Exception while trying to schedule job: {job.FullName}, {e}"); - var random = new Random(); - Thread.Sleep(random.Next(50, 250)); } } } - } - // Delete old Jobs and Triggers - var existingJobKeys = await _scheduler.GetJobKeys(GroupMatcher.AnyGroup()); - var jobKeys = Jobs.Select(j => - { - var job = j.Item1; - return JobBuilder.Create(job) - .WithIdentity(job.FullName) - .Build().Key; - }); - - foreach (var key in existingJobKeys) - { - if (jobKeys.Contains(key)) + // Delete old Jobs and Triggers + var existingJobKeys = await _scheduler.GetJobKeys(GroupMatcher.AnyGroup()); + var jobKeys = Jobs.Select(j => { - continue; + var job = j.Item1; + return JobBuilder.Create(job) + .WithIdentity(job.FullName) + .Build().Key; + }); + + foreach (var key in existingJobKeys) + { + if (jobKeys.Contains(key)) + { + continue; + } + + _logger.LogInformation($"Deleting old job with key {key}"); + await _scheduler.DeleteJob(key); } - _logger.LogInformation($"Deleting old job with key {key}"); - await _scheduler.DeleteJob(key); - } + var existingTriggerKeys = await _scheduler.GetTriggerKeys(GroupMatcher.AnyGroup()); + var triggerKeys = Jobs.Select(j => j.Item2.Key); - var existingTriggerKeys = await _scheduler.GetTriggerKeys(GroupMatcher.AnyGroup()); - var triggerKeys = Jobs.Select(j => j.Item2.Key); - - foreach (var key in existingTriggerKeys) - { - if (triggerKeys.Contains(key)) + foreach (var key in existingTriggerKeys) { - continue; + if (triggerKeys.Contains(key)) + { + continue; + } + + _logger.LogInformation($"Unscheduling old trigger with key {key}"); + await _scheduler.UnscheduleJob(key); } - - _logger.LogInformation($"Unscheduling old trigger with key {key}"); - await _scheduler.UnscheduleJob(key); } - } - public virtual async Task StopAsync(CancellationToken cancellationToken) - { - await _scheduler?.Shutdown(cancellationToken); - } + public virtual async Task StopAsync(CancellationToken cancellationToken) + { + await _scheduler?.Shutdown(cancellationToken); + } - public virtual void Dispose() - { } + public virtual void Dispose() + { } + } } diff --git a/src/Core/Jobs/JobFactory.cs b/src/Core/Jobs/JobFactory.cs index ee95c6b2d6..00cf63b268 100644 --- a/src/Core/Jobs/JobFactory.cs +++ b/src/Core/Jobs/JobFactory.cs @@ -1,25 +1,26 @@ using Quartz; using Quartz.Spi; -namespace Bit.Core.Jobs; - -public class JobFactory : IJobFactory +namespace Bit.Core.Jobs { - private readonly IServiceProvider _container; - - public JobFactory(IServiceProvider container) + public class JobFactory : IJobFactory { - _container = container; - } + private readonly IServiceProvider _container; - public IJob NewJob(TriggerFiredBundle bundle, IScheduler scheduler) - { - return _container.GetService(bundle.JobDetail.JobType) as IJob; - } + public JobFactory(IServiceProvider container) + { + _container = container; + } - public void ReturnJob(IJob job) - { - var disposable = job as IDisposable; - disposable?.Dispose(); + public IJob NewJob(TriggerFiredBundle bundle, IScheduler scheduler) + { + return _container.GetService(bundle.JobDetail.JobType) as IJob; + } + + public void ReturnJob(IJob job) + { + var disposable = job as IDisposable; + disposable?.Dispose(); + } } } diff --git a/src/Core/Jobs/JobListener.cs b/src/Core/Jobs/JobListener.cs index e5e05e4b6b..8fb56828e8 100644 --- a/src/Core/Jobs/JobListener.cs +++ b/src/Core/Jobs/JobListener.cs @@ -1,38 +1,39 @@ using Microsoft.Extensions.Logging; using Quartz; -namespace Bit.Core.Jobs; - -public class JobListener : IJobListener +namespace Bit.Core.Jobs { - private readonly ILogger _logger; - - public JobListener(ILogger logger) + public class JobListener : IJobListener { - _logger = logger; - } + private readonly ILogger _logger; - public string Name => "JobListener"; + public JobListener(ILogger logger) + { + _logger = logger; + } - public Task JobExecutionVetoed(IJobExecutionContext context, - CancellationToken cancellationToken = default(CancellationToken)) - { - return Task.FromResult(0); - } + public string Name => "JobListener"; - public Task JobToBeExecuted(IJobExecutionContext context, - CancellationToken cancellationToken = default(CancellationToken)) - { - _logger.LogInformation(Constants.BypassFiltersEventId, null, "Starting job {0} at {1}.", - context.JobDetail.JobType.Name, DateTime.UtcNow); - return Task.FromResult(0); - } + public Task JobExecutionVetoed(IJobExecutionContext context, + CancellationToken cancellationToken = default(CancellationToken)) + { + return Task.FromResult(0); + } - public Task JobWasExecuted(IJobExecutionContext context, JobExecutionException jobException, - CancellationToken cancellationToken = default(CancellationToken)) - { - _logger.LogInformation(Constants.BypassFiltersEventId, null, "Finished job {0} at {1}.", - context.JobDetail.JobType.Name, DateTime.UtcNow); - return Task.FromResult(0); + public Task JobToBeExecuted(IJobExecutionContext context, + CancellationToken cancellationToken = default(CancellationToken)) + { + _logger.LogInformation(Constants.BypassFiltersEventId, null, "Starting job {0} at {1}.", + context.JobDetail.JobType.Name, DateTime.UtcNow); + return Task.FromResult(0); + } + + public Task JobWasExecuted(IJobExecutionContext context, JobExecutionException jobException, + CancellationToken cancellationToken = default(CancellationToken)) + { + _logger.LogInformation(Constants.BypassFiltersEventId, null, "Finished job {0} at {1}.", + context.JobDetail.JobType.Name, DateTime.UtcNow); + return Task.FromResult(0); + } } } diff --git a/src/Core/Models/Api/Request/Accounts/KeysRequestModel.cs b/src/Core/Models/Api/Request/Accounts/KeysRequestModel.cs index 18e6c1f5e7..77015a96ef 100644 --- a/src/Core/Models/Api/Request/Accounts/KeysRequestModel.cs +++ b/src/Core/Models/Api/Request/Accounts/KeysRequestModel.cs @@ -1,26 +1,27 @@ using System.ComponentModel.DataAnnotations; using Bit.Core.Entities; -namespace Bit.Core.Models.Api.Request.Accounts; - -public class KeysRequestModel +namespace Bit.Core.Models.Api.Request.Accounts { - public string PublicKey { get; set; } - [Required] - public string EncryptedPrivateKey { get; set; } - - public User ToUser(User existingUser) + public class KeysRequestModel { - if (string.IsNullOrWhiteSpace(existingUser.PublicKey) && !string.IsNullOrWhiteSpace(PublicKey)) - { - existingUser.PublicKey = PublicKey; - } + public string PublicKey { get; set; } + [Required] + public string EncryptedPrivateKey { get; set; } - if (string.IsNullOrWhiteSpace(existingUser.PrivateKey)) + public User ToUser(User existingUser) { - existingUser.PrivateKey = EncryptedPrivateKey; - } + if (string.IsNullOrWhiteSpace(existingUser.PublicKey) && !string.IsNullOrWhiteSpace(PublicKey)) + { + existingUser.PublicKey = PublicKey; + } - return existingUser; + if (string.IsNullOrWhiteSpace(existingUser.PrivateKey)) + { + existingUser.PrivateKey = EncryptedPrivateKey; + } + + return existingUser; + } } } diff --git a/src/Core/Models/Api/Request/Accounts/PreloginRequestModel.cs b/src/Core/Models/Api/Request/Accounts/PreloginRequestModel.cs index 43a391ab94..dca9e08bf7 100644 --- a/src/Core/Models/Api/Request/Accounts/PreloginRequestModel.cs +++ b/src/Core/Models/Api/Request/Accounts/PreloginRequestModel.cs @@ -1,11 +1,12 @@ using System.ComponentModel.DataAnnotations; -namespace Bit.Core.Models.Api.Request.Accounts; - -public class PreloginRequestModel +namespace Bit.Core.Models.Api.Request.Accounts { - [Required] - [EmailAddress] - [StringLength(256)] - public string Email { get; set; } + public class PreloginRequestModel + { + [Required] + [EmailAddress] + [StringLength(256)] + public string Email { get; set; } + } } diff --git a/src/Core/Models/Api/Request/Accounts/RegisterRequestModel.cs b/src/Core/Models/Api/Request/Accounts/RegisterRequestModel.cs index eac394b110..2b7c36a896 100644 --- a/src/Core/Models/Api/Request/Accounts/RegisterRequestModel.cs +++ b/src/Core/Models/Api/Request/Accounts/RegisterRequestModel.cs @@ -4,73 +4,74 @@ using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Utilities; -namespace Bit.Core.Models.Api.Request.Accounts; - -public class RegisterRequestModel : IValidatableObject, ICaptchaProtectedModel +namespace Bit.Core.Models.Api.Request.Accounts { - [StringLength(50)] - public string Name { get; set; } - [Required] - [StrictEmailAddress] - [StringLength(256)] - public string Email { get; set; } - [Required] - [StringLength(1000)] - public string MasterPasswordHash { get; set; } - [StringLength(50)] - public string MasterPasswordHint { get; set; } - public string CaptchaResponse { get; set; } - public string Key { get; set; } - public KeysRequestModel Keys { get; set; } - public string Token { get; set; } - public Guid? OrganizationUserId { get; set; } - public KdfType? Kdf { get; set; } - public int? KdfIterations { get; set; } - public Dictionary ReferenceData { get; set; } - - public User ToUser() + public class RegisterRequestModel : IValidatableObject, ICaptchaProtectedModel { - var user = new User - { - Name = Name, - Email = Email, - MasterPasswordHint = MasterPasswordHint, - Kdf = Kdf.GetValueOrDefault(KdfType.PBKDF2_SHA256), - KdfIterations = KdfIterations.GetValueOrDefault(5000), - }; + [StringLength(50)] + public string Name { get; set; } + [Required] + [StrictEmailAddress] + [StringLength(256)] + public string Email { get; set; } + [Required] + [StringLength(1000)] + public string MasterPasswordHash { get; set; } + [StringLength(50)] + public string MasterPasswordHint { get; set; } + public string CaptchaResponse { get; set; } + public string Key { get; set; } + public KeysRequestModel Keys { get; set; } + public string Token { get; set; } + public Guid? OrganizationUserId { get; set; } + public KdfType? Kdf { get; set; } + public int? KdfIterations { get; set; } + public Dictionary ReferenceData { get; set; } - if (ReferenceData != null) + public User ToUser() { - user.ReferenceData = JsonSerializer.Serialize(ReferenceData); - } - - if (Key != null) - { - user.Key = Key; - } - - if (Keys != null) - { - Keys.ToUser(user); - } - - return user; - } - - public IEnumerable Validate(ValidationContext validationContext) - { - if (Kdf.HasValue && KdfIterations.HasValue) - { - switch (Kdf.Value) + var user = new User { - case KdfType.PBKDF2_SHA256: - if (KdfIterations.Value < 5000 || KdfIterations.Value > 1_000_000) - { - yield return new ValidationResult("KDF iterations must be between 5000 and 1000000."); - } - break; - default: - break; + Name = Name, + Email = Email, + MasterPasswordHint = MasterPasswordHint, + Kdf = Kdf.GetValueOrDefault(KdfType.PBKDF2_SHA256), + KdfIterations = KdfIterations.GetValueOrDefault(5000), + }; + + if (ReferenceData != null) + { + user.ReferenceData = JsonSerializer.Serialize(ReferenceData); + } + + if (Key != null) + { + user.Key = Key; + } + + if (Keys != null) + { + Keys.ToUser(user); + } + + return user; + } + + public IEnumerable Validate(ValidationContext validationContext) + { + if (Kdf.HasValue && KdfIterations.HasValue) + { + switch (Kdf.Value) + { + case KdfType.PBKDF2_SHA256: + if (KdfIterations.Value < 5000 || KdfIterations.Value > 1_000_000) + { + yield return new ValidationResult("KDF iterations must be between 5000 and 1000000."); + } + break; + default: + break; + } } } } diff --git a/src/Core/Models/Api/Request/ICaptchaProtectedModel.cs b/src/Core/Models/Api/Request/ICaptchaProtectedModel.cs index f1c9771d15..9084ecc89e 100644 --- a/src/Core/Models/Api/Request/ICaptchaProtectedModel.cs +++ b/src/Core/Models/Api/Request/ICaptchaProtectedModel.cs @@ -1,6 +1,7 @@ -namespace Bit.Core.Models.Api; - -public interface ICaptchaProtectedModel +namespace Bit.Core.Models.Api { - string CaptchaResponse { get; set; } + public interface ICaptchaProtectedModel + { + string CaptchaResponse { get; set; } + } } diff --git a/src/Core/Models/Api/Request/OrganizationSponsorships/OrganizationSponsorshipRequestModel.cs b/src/Core/Models/Api/Request/OrganizationSponsorships/OrganizationSponsorshipRequestModel.cs index 8be4a672db..7440e7ba3f 100644 --- a/src/Core/Models/Api/Request/OrganizationSponsorships/OrganizationSponsorshipRequestModel.cs +++ b/src/Core/Models/Api/Request/OrganizationSponsorships/OrganizationSponsorshipRequestModel.cs @@ -2,54 +2,55 @@ using Bit.Core.Enums; using Bit.Core.Models.Data.Organizations.OrganizationSponsorships; -namespace Bit.Core.Models.Api.Request.OrganizationSponsorships; - -public class OrganizationSponsorshipRequestModel +namespace Bit.Core.Models.Api.Request.OrganizationSponsorships { - public Guid SponsoringOrganizationUserId { get; set; } - public string FriendlyName { get; set; } - public string OfferedToEmail { get; set; } - public PlanSponsorshipType PlanSponsorshipType { get; set; } - public DateTime? LastSyncDate { get; set; } - public DateTime? ValidUntil { get; set; } - public bool ToDelete { get; set; } - - public OrganizationSponsorshipRequestModel() { } - - public OrganizationSponsorshipRequestModel(OrganizationSponsorshipData sponsorshipData) + public class OrganizationSponsorshipRequestModel { - SponsoringOrganizationUserId = sponsorshipData.SponsoringOrganizationUserId; - FriendlyName = sponsorshipData.FriendlyName; - OfferedToEmail = sponsorshipData.OfferedToEmail; - PlanSponsorshipType = sponsorshipData.PlanSponsorshipType; - LastSyncDate = sponsorshipData.LastSyncDate; - ValidUntil = sponsorshipData.ValidUntil; - ToDelete = sponsorshipData.ToDelete; - } + public Guid SponsoringOrganizationUserId { get; set; } + public string FriendlyName { get; set; } + public string OfferedToEmail { get; set; } + public PlanSponsorshipType PlanSponsorshipType { get; set; } + public DateTime? LastSyncDate { get; set; } + public DateTime? ValidUntil { get; set; } + public bool ToDelete { get; set; } - public OrganizationSponsorshipRequestModel(OrganizationSponsorship sponsorship) - { - SponsoringOrganizationUserId = sponsorship.SponsoringOrganizationUserId; - FriendlyName = sponsorship.FriendlyName; - OfferedToEmail = sponsorship.OfferedToEmail; - PlanSponsorshipType = sponsorship.PlanSponsorshipType.GetValueOrDefault(); - LastSyncDate = sponsorship.LastSyncDate; - ValidUntil = sponsorship.ValidUntil; - ToDelete = sponsorship.ToDelete; - } + public OrganizationSponsorshipRequestModel() { } - public OrganizationSponsorshipData ToOrganizationSponsorship() - { - return new OrganizationSponsorshipData + public OrganizationSponsorshipRequestModel(OrganizationSponsorshipData sponsorshipData) { - SponsoringOrganizationUserId = SponsoringOrganizationUserId, - FriendlyName = FriendlyName, - OfferedToEmail = OfferedToEmail, - PlanSponsorshipType = PlanSponsorshipType, - LastSyncDate = LastSyncDate, - ValidUntil = ValidUntil, - ToDelete = ToDelete, - }; + SponsoringOrganizationUserId = sponsorshipData.SponsoringOrganizationUserId; + FriendlyName = sponsorshipData.FriendlyName; + OfferedToEmail = sponsorshipData.OfferedToEmail; + PlanSponsorshipType = sponsorshipData.PlanSponsorshipType; + LastSyncDate = sponsorshipData.LastSyncDate; + ValidUntil = sponsorshipData.ValidUntil; + ToDelete = sponsorshipData.ToDelete; + } + public OrganizationSponsorshipRequestModel(OrganizationSponsorship sponsorship) + { + SponsoringOrganizationUserId = sponsorship.SponsoringOrganizationUserId; + FriendlyName = sponsorship.FriendlyName; + OfferedToEmail = sponsorship.OfferedToEmail; + PlanSponsorshipType = sponsorship.PlanSponsorshipType.GetValueOrDefault(); + LastSyncDate = sponsorship.LastSyncDate; + ValidUntil = sponsorship.ValidUntil; + ToDelete = sponsorship.ToDelete; + } + + public OrganizationSponsorshipData ToOrganizationSponsorship() + { + return new OrganizationSponsorshipData + { + SponsoringOrganizationUserId = SponsoringOrganizationUserId, + FriendlyName = FriendlyName, + OfferedToEmail = OfferedToEmail, + PlanSponsorshipType = PlanSponsorshipType, + LastSyncDate = LastSyncDate, + ValidUntil = ValidUntil, + ToDelete = ToDelete, + }; + + } } } diff --git a/src/Core/Models/Api/Request/OrganizationSponsorships/OrganizationSponsorshipSyncRequestModel.cs b/src/Core/Models/Api/Request/OrganizationSponsorships/OrganizationSponsorshipSyncRequestModel.cs index 283c07d199..9def44d60b 100644 --- a/src/Core/Models/Api/Request/OrganizationSponsorships/OrganizationSponsorshipSyncRequestModel.cs +++ b/src/Core/Models/Api/Request/OrganizationSponsorships/OrganizationSponsorshipSyncRequestModel.cs @@ -1,39 +1,40 @@ using Bit.Core.Models.Data.Organizations.OrganizationSponsorships; -namespace Bit.Core.Models.Api.Request.OrganizationSponsorships; - -public class OrganizationSponsorshipSyncRequestModel +namespace Bit.Core.Models.Api.Request.OrganizationSponsorships { - public string BillingSyncKey { get; set; } - public Guid SponsoringOrganizationCloudId { get; set; } - public IEnumerable SponsorshipsBatch { get; set; } - - public OrganizationSponsorshipSyncRequestModel() { } - - public OrganizationSponsorshipSyncRequestModel(IEnumerable sponsorshipsBatch) + public class OrganizationSponsorshipSyncRequestModel { - SponsorshipsBatch = sponsorshipsBatch; - } + public string BillingSyncKey { get; set; } + public Guid SponsoringOrganizationCloudId { get; set; } + public IEnumerable SponsorshipsBatch { get; set; } - public OrganizationSponsorshipSyncRequestModel(OrganizationSponsorshipSyncData syncData) - { - if (syncData == null) + public OrganizationSponsorshipSyncRequestModel() { } + + public OrganizationSponsorshipSyncRequestModel(IEnumerable sponsorshipsBatch) { - return; + SponsorshipsBatch = sponsorshipsBatch; } - BillingSyncKey = syncData.BillingSyncKey; - SponsoringOrganizationCloudId = syncData.SponsoringOrganizationCloudId; - SponsorshipsBatch = syncData.SponsorshipsBatch.Select(o => new OrganizationSponsorshipRequestModel(o)); - } - public OrganizationSponsorshipSyncData ToOrganizationSponsorshipSync() - { - return new OrganizationSponsorshipSyncData() + public OrganizationSponsorshipSyncRequestModel(OrganizationSponsorshipSyncData syncData) { - BillingSyncKey = BillingSyncKey, - SponsoringOrganizationCloudId = SponsoringOrganizationCloudId, - SponsorshipsBatch = SponsorshipsBatch.Select(o => o.ToOrganizationSponsorship()) - }; - } + if (syncData == null) + { + return; + } + BillingSyncKey = syncData.BillingSyncKey; + SponsoringOrganizationCloudId = syncData.SponsoringOrganizationCloudId; + SponsorshipsBatch = syncData.SponsorshipsBatch.Select(o => new OrganizationSponsorshipRequestModel(o)); + } + public OrganizationSponsorshipSyncData ToOrganizationSponsorshipSync() + { + return new OrganizationSponsorshipSyncData() + { + BillingSyncKey = BillingSyncKey, + SponsoringOrganizationCloudId = SponsoringOrganizationCloudId, + SponsorshipsBatch = SponsorshipsBatch.Select(o => o.ToOrganizationSponsorship()) + }; + } + + } } diff --git a/src/Core/Models/Api/Request/PushRegistrationRequestModel.cs b/src/Core/Models/Api/Request/PushRegistrationRequestModel.cs index 580c1c3b60..fd74b50afd 100644 --- a/src/Core/Models/Api/Request/PushRegistrationRequestModel.cs +++ b/src/Core/Models/Api/Request/PushRegistrationRequestModel.cs @@ -1,18 +1,19 @@ using System.ComponentModel.DataAnnotations; using Bit.Core.Enums; -namespace Bit.Core.Models.Api; - -public class PushRegistrationRequestModel +namespace Bit.Core.Models.Api { - [Required] - public string DeviceId { get; set; } - [Required] - public string PushToken { get; set; } - [Required] - public string UserId { get; set; } - [Required] - public DeviceType Type { get; set; } - [Required] - public string Identifier { get; set; } + public class PushRegistrationRequestModel + { + [Required] + public string DeviceId { get; set; } + [Required] + public string PushToken { get; set; } + [Required] + public string UserId { get; set; } + [Required] + public DeviceType Type { get; set; } + [Required] + public string Identifier { get; set; } + } } diff --git a/src/Core/Models/Api/Request/PushSendRequestModel.cs b/src/Core/Models/Api/Request/PushSendRequestModel.cs index b85c8fb555..108db58048 100644 --- a/src/Core/Models/Api/Request/PushSendRequestModel.cs +++ b/src/Core/Models/Api/Request/PushSendRequestModel.cs @@ -1,24 +1,25 @@ using System.ComponentModel.DataAnnotations; using Bit.Core.Enums; -namespace Bit.Core.Models.Api; - -public class PushSendRequestModel : IValidatableObject +namespace Bit.Core.Models.Api { - public string UserId { get; set; } - public string OrganizationId { get; set; } - public string DeviceId { get; set; } - public string Identifier { get; set; } - [Required] - public PushType? Type { get; set; } - [Required] - public object Payload { get; set; } - - public IEnumerable Validate(ValidationContext validationContext) + public class PushSendRequestModel : IValidatableObject { - if (string.IsNullOrWhiteSpace(UserId) && string.IsNullOrWhiteSpace(OrganizationId)) + public string UserId { get; set; } + public string OrganizationId { get; set; } + public string DeviceId { get; set; } + public string Identifier { get; set; } + [Required] + public PushType? Type { get; set; } + [Required] + public object Payload { get; set; } + + public IEnumerable Validate(ValidationContext validationContext) { - yield return new ValidationResult($"{nameof(UserId)} or {nameof(OrganizationId)} is required."); + if (string.IsNullOrWhiteSpace(UserId) && string.IsNullOrWhiteSpace(OrganizationId)) + { + yield return new ValidationResult($"{nameof(UserId)} or {nameof(OrganizationId)} is required."); + } } } } diff --git a/src/Core/Models/Api/Request/PushUpdateRequestModel.cs b/src/Core/Models/Api/Request/PushUpdateRequestModel.cs index 2ccbf6eb00..ba5c3bf961 100644 --- a/src/Core/Models/Api/Request/PushUpdateRequestModel.cs +++ b/src/Core/Models/Api/Request/PushUpdateRequestModel.cs @@ -1,20 +1,21 @@ using System.ComponentModel.DataAnnotations; -namespace Bit.Core.Models.Api; - -public class PushUpdateRequestModel +namespace Bit.Core.Models.Api { - public PushUpdateRequestModel() - { } - - public PushUpdateRequestModel(IEnumerable deviceIds, string organizationId) + public class PushUpdateRequestModel { - DeviceIds = deviceIds; - OrganizationId = organizationId; - } + public PushUpdateRequestModel() + { } - [Required] - public IEnumerable DeviceIds { get; set; } - [Required] - public string OrganizationId { get; set; } + public PushUpdateRequestModel(IEnumerable deviceIds, string organizationId) + { + DeviceIds = deviceIds; + OrganizationId = organizationId; + } + + [Required] + public IEnumerable DeviceIds { get; set; } + [Required] + public string OrganizationId { get; set; } + } } diff --git a/src/Core/Models/Api/Response/Accounts/PreloginResponseModel.cs b/src/Core/Models/Api/Response/Accounts/PreloginResponseModel.cs index 9fb2de7de2..755182f765 100644 --- a/src/Core/Models/Api/Response/Accounts/PreloginResponseModel.cs +++ b/src/Core/Models/Api/Response/Accounts/PreloginResponseModel.cs @@ -1,16 +1,17 @@ using Bit.Core.Enums; using Bit.Core.Models.Data; -namespace Bit.Core.Models.Api.Response.Accounts; - -public class PreloginResponseModel +namespace Bit.Core.Models.Api.Response.Accounts { - public PreloginResponseModel(UserKdfInformation kdfInformation) + public class PreloginResponseModel { - Kdf = kdfInformation.Kdf; - KdfIterations = kdfInformation.KdfIterations; - } + public PreloginResponseModel(UserKdfInformation kdfInformation) + { + Kdf = kdfInformation.Kdf; + KdfIterations = kdfInformation.KdfIterations; + } - public KdfType Kdf { get; set; } - public int KdfIterations { get; set; } + public KdfType Kdf { get; set; } + public int KdfIterations { get; set; } + } } diff --git a/src/Core/Models/Api/Response/ErrorResponseModel.cs b/src/Core/Models/Api/Response/ErrorResponseModel.cs index 39d6adddb1..e7f77099c0 100644 --- a/src/Core/Models/Api/Response/ErrorResponseModel.cs +++ b/src/Core/Models/Api/Response/ErrorResponseModel.cs @@ -1,73 +1,74 @@ using Microsoft.AspNetCore.Mvc.ModelBinding; -namespace Bit.Core.Models.Api; - -public class ErrorResponseModel : ResponseModel +namespace Bit.Core.Models.Api { - public ErrorResponseModel() - : base("error") - { } - - public ErrorResponseModel(string message) - : this() + public class ErrorResponseModel : ResponseModel { - Message = message; - } + public ErrorResponseModel() + : base("error") + { } - public ErrorResponseModel(ModelStateDictionary modelState) - : this() - { - Message = "The model state is invalid."; - ValidationErrors = new Dictionary>(); - - var keys = modelState.Keys.ToList(); - var values = modelState.Values.ToList(); - - for (var i = 0; i < values.Count; i++) + public ErrorResponseModel(string message) + : this() { - var value = values[i]; - - if (keys.Count <= i) - { - // Keys not available for some reason. - break; - } - - var key = keys[i]; - - if (value.ValidationState != ModelValidationState.Invalid || value.Errors.Count == 0) - { - continue; - } - - var errors = value.Errors.Select(e => e.ErrorMessage); - ValidationErrors.Add(key, errors); + Message = message; } + + public ErrorResponseModel(ModelStateDictionary modelState) + : this() + { + Message = "The model state is invalid."; + ValidationErrors = new Dictionary>(); + + var keys = modelState.Keys.ToList(); + var values = modelState.Values.ToList(); + + for (var i = 0; i < values.Count; i++) + { + var value = values[i]; + + if (keys.Count <= i) + { + // Keys not available for some reason. + break; + } + + var key = keys[i]; + + if (value.ValidationState != ModelValidationState.Invalid || value.Errors.Count == 0) + { + continue; + } + + var errors = value.Errors.Select(e => e.ErrorMessage); + ValidationErrors.Add(key, errors); + } + } + + public ErrorResponseModel(Dictionary> errors) + : this("Errors have occurred.", errors) + { } + + public ErrorResponseModel(string errorKey, string errorValue) + : this(errorKey, new string[] { errorValue }) + { } + + public ErrorResponseModel(string errorKey, IEnumerable errorValues) + : this(new Dictionary> { { errorKey, errorValues } }) + { } + + public ErrorResponseModel(string message, Dictionary> errors) + : this() + { + Message = message; + ValidationErrors = errors; + } + + public string Message { get; set; } + public Dictionary> ValidationErrors { get; set; } + // For use in development environments. + public string ExceptionMessage { get; set; } + public string ExceptionStackTrace { get; set; } + public string InnerExceptionMessage { get; set; } } - - public ErrorResponseModel(Dictionary> errors) - : this("Errors have occurred.", errors) - { } - - public ErrorResponseModel(string errorKey, string errorValue) - : this(errorKey, new string[] { errorValue }) - { } - - public ErrorResponseModel(string errorKey, IEnumerable errorValues) - : this(new Dictionary> { { errorKey, errorValues } }) - { } - - public ErrorResponseModel(string message, Dictionary> errors) - : this() - { - Message = message; - ValidationErrors = errors; - } - - public string Message { get; set; } - public Dictionary> ValidationErrors { get; set; } - // For use in development environments. - public string ExceptionMessage { get; set; } - public string ExceptionStackTrace { get; set; } - public string InnerExceptionMessage { get; set; } } diff --git a/src/Core/Models/Api/Response/OrganizationSponsorships/OrganizationSponsorshipResponseModel.cs b/src/Core/Models/Api/Response/OrganizationSponsorships/OrganizationSponsorshipResponseModel.cs index 58c1b2cffb..fc5fbc70d5 100644 --- a/src/Core/Models/Api/Response/OrganizationSponsorships/OrganizationSponsorshipResponseModel.cs +++ b/src/Core/Models/Api/Response/OrganizationSponsorships/OrganizationSponsorshipResponseModel.cs @@ -1,47 +1,48 @@ using Bit.Core.Enums; using Bit.Core.Models.Data.Organizations.OrganizationSponsorships; -namespace Bit.Core.Models.Api.Response.OrganizationSponsorships; - -public class OrganizationSponsorshipResponseModel +namespace Bit.Core.Models.Api.Response.OrganizationSponsorships { - public Guid SponsoringOrganizationUserId { get; set; } - public string FriendlyName { get; set; } - public string OfferedToEmail { get; set; } - public PlanSponsorshipType PlanSponsorshipType { get; set; } - public DateTime? LastSyncDate { get; set; } - public DateTime? ValidUntil { get; set; } - public bool ToDelete { get; set; } - - public bool CloudSponsorshipRemoved { get; set; } - - public OrganizationSponsorshipResponseModel() { } - - public OrganizationSponsorshipResponseModel(OrganizationSponsorshipData sponsorshipData) + public class OrganizationSponsorshipResponseModel { - SponsoringOrganizationUserId = sponsorshipData.SponsoringOrganizationUserId; - FriendlyName = sponsorshipData.FriendlyName; - OfferedToEmail = sponsorshipData.OfferedToEmail; - PlanSponsorshipType = sponsorshipData.PlanSponsorshipType; - LastSyncDate = sponsorshipData.LastSyncDate; - ValidUntil = sponsorshipData.ValidUntil; - ToDelete = sponsorshipData.ToDelete; - CloudSponsorshipRemoved = sponsorshipData.CloudSponsorshipRemoved; - } + public Guid SponsoringOrganizationUserId { get; set; } + public string FriendlyName { get; set; } + public string OfferedToEmail { get; set; } + public PlanSponsorshipType PlanSponsorshipType { get; set; } + public DateTime? LastSyncDate { get; set; } + public DateTime? ValidUntil { get; set; } + public bool ToDelete { get; set; } - public OrganizationSponsorshipData ToOrganizationSponsorship() - { - return new OrganizationSponsorshipData + public bool CloudSponsorshipRemoved { get; set; } + + public OrganizationSponsorshipResponseModel() { } + + public OrganizationSponsorshipResponseModel(OrganizationSponsorshipData sponsorshipData) { - SponsoringOrganizationUserId = SponsoringOrganizationUserId, - FriendlyName = FriendlyName, - OfferedToEmail = OfferedToEmail, - PlanSponsorshipType = PlanSponsorshipType, - LastSyncDate = LastSyncDate, - ValidUntil = ValidUntil, - ToDelete = ToDelete, - CloudSponsorshipRemoved = CloudSponsorshipRemoved - }; + SponsoringOrganizationUserId = sponsorshipData.SponsoringOrganizationUserId; + FriendlyName = sponsorshipData.FriendlyName; + OfferedToEmail = sponsorshipData.OfferedToEmail; + PlanSponsorshipType = sponsorshipData.PlanSponsorshipType; + LastSyncDate = sponsorshipData.LastSyncDate; + ValidUntil = sponsorshipData.ValidUntil; + ToDelete = sponsorshipData.ToDelete; + CloudSponsorshipRemoved = sponsorshipData.CloudSponsorshipRemoved; + } + public OrganizationSponsorshipData ToOrganizationSponsorship() + { + return new OrganizationSponsorshipData + { + SponsoringOrganizationUserId = SponsoringOrganizationUserId, + FriendlyName = FriendlyName, + OfferedToEmail = OfferedToEmail, + PlanSponsorshipType = PlanSponsorshipType, + LastSyncDate = LastSyncDate, + ValidUntil = ValidUntil, + ToDelete = ToDelete, + CloudSponsorshipRemoved = CloudSponsorshipRemoved + }; + + } } } diff --git a/src/Core/Models/Api/Response/OrganizationSponsorships/OrganizationSponsorshipSyncResponseModel.cs b/src/Core/Models/Api/Response/OrganizationSponsorships/OrganizationSponsorshipSyncResponseModel.cs index 5a6b635c5a..4d44ab1653 100644 --- a/src/Core/Models/Api/Response/OrganizationSponsorships/OrganizationSponsorshipSyncResponseModel.cs +++ b/src/Core/Models/Api/Response/OrganizationSponsorships/OrganizationSponsorshipSyncResponseModel.cs @@ -1,29 +1,30 @@ using Bit.Core.Models.Data.Organizations.OrganizationSponsorships; -namespace Bit.Core.Models.Api.Response.OrganizationSponsorships; - -public class OrganizationSponsorshipSyncResponseModel +namespace Bit.Core.Models.Api.Response.OrganizationSponsorships { - public IEnumerable SponsorshipsBatch { get; set; } - - public OrganizationSponsorshipSyncResponseModel() { } - - public OrganizationSponsorshipSyncResponseModel(OrganizationSponsorshipSyncData syncData) + public class OrganizationSponsorshipSyncResponseModel { - if (syncData == null) + public IEnumerable SponsorshipsBatch { get; set; } + + public OrganizationSponsorshipSyncResponseModel() { } + + public OrganizationSponsorshipSyncResponseModel(OrganizationSponsorshipSyncData syncData) { - return; + if (syncData == null) + { + return; + } + SponsorshipsBatch = syncData.SponsorshipsBatch.Select(o => new OrganizationSponsorshipResponseModel(o)); + } - SponsorshipsBatch = syncData.SponsorshipsBatch.Select(o => new OrganizationSponsorshipResponseModel(o)); - } - - public OrganizationSponsorshipSyncData ToOrganizationSponsorshipSync() - { - return new OrganizationSponsorshipSyncData() + public OrganizationSponsorshipSyncData ToOrganizationSponsorshipSync() { - SponsorshipsBatch = SponsorshipsBatch.Select(o => o.ToOrganizationSponsorship()) - }; - } + return new OrganizationSponsorshipSyncData() + { + SponsorshipsBatch = SponsorshipsBatch.Select(o => o.ToOrganizationSponsorship()) + }; + } + } } diff --git a/src/Core/Models/Api/Response/ResponseModel.cs b/src/Core/Models/Api/Response/ResponseModel.cs index 22278b807a..539d52d107 100644 --- a/src/Core/Models/Api/Response/ResponseModel.cs +++ b/src/Core/Models/Api/Response/ResponseModel.cs @@ -1,19 +1,20 @@ using Newtonsoft.Json; -namespace Bit.Core.Models.Api; - -public abstract class ResponseModel +namespace Bit.Core.Models.Api { - public ResponseModel(string obj) + public abstract class ResponseModel { - if (string.IsNullOrWhiteSpace(obj)) + public ResponseModel(string obj) { - throw new ArgumentNullException(nameof(obj)); + if (string.IsNullOrWhiteSpace(obj)) + { + throw new ArgumentNullException(nameof(obj)); + } + + Object = obj; } - Object = obj; + [JsonProperty(Order = -200)] // Always the first property + public string Object { get; private set; } } - - [JsonProperty(Order = -200)] // Always the first property - public string Object { get; private set; } } diff --git a/src/Core/Models/Business/AppleReceiptStatus.cs b/src/Core/Models/Business/AppleReceiptStatus.cs index e54ce91e67..26f7537afd 100644 --- a/src/Core/Models/Business/AppleReceiptStatus.cs +++ b/src/Core/Models/Business/AppleReceiptStatus.cs @@ -3,132 +3,133 @@ using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Utilities; -namespace Bit.Billing.Models; - -public class AppleReceiptStatus +namespace Bit.Billing.Models { - [JsonPropertyName("status")] - public int? Status { get; set; } - [JsonPropertyName("environment")] - public string Environment { get; set; } - [JsonPropertyName("latest_receipt")] - public string LatestReceipt { get; set; } - [JsonPropertyName("receipt")] - public AppleReceipt Receipt { get; set; } - [JsonPropertyName("latest_receipt_info")] - public List LatestReceiptInfo { get; set; } - [JsonPropertyName("pending_renewal_info")] - public List PendingRenewalInfo { get; set; } - - public string GetOriginalTransactionId() + public class AppleReceiptStatus { - return LatestReceiptInfo?.LastOrDefault()?.OriginalTransactionId; - } + [JsonPropertyName("status")] + public int? Status { get; set; } + [JsonPropertyName("environment")] + public string Environment { get; set; } + [JsonPropertyName("latest_receipt")] + public string LatestReceipt { get; set; } + [JsonPropertyName("receipt")] + public AppleReceipt Receipt { get; set; } + [JsonPropertyName("latest_receipt_info")] + public List LatestReceiptInfo { get; set; } + [JsonPropertyName("pending_renewal_info")] + public List PendingRenewalInfo { get; set; } - public string GetLastTransactionId() - { - return LatestReceiptInfo?.LastOrDefault()?.TransactionId; - } - - public AppleTransaction GetLastTransaction() - { - return LatestReceiptInfo?.LastOrDefault(); - } - - public DateTime? GetLastExpiresDate() - { - return LatestReceiptInfo?.LastOrDefault()?.ExpiresDate; - } - - public string GetReceiptData() - { - return LatestReceipt; - } - - public DateTime? GetLastCancellationDate() - { - return LatestReceiptInfo?.LastOrDefault()?.CancellationDate; - } - - public bool IsRefunded() - { - var cancellationDate = GetLastCancellationDate(); - var expiresDate = GetLastCancellationDate(); - if (cancellationDate.HasValue && expiresDate.HasValue) + public string GetOriginalTransactionId() { - return cancellationDate.Value <= expiresDate.Value; + return LatestReceiptInfo?.LastOrDefault()?.OriginalTransactionId; } - return false; - } - public Transaction BuildTransactionFromLastTransaction(decimal amount, Guid userId) - { - return new Transaction + public string GetLastTransactionId() { - Amount = amount, - CreationDate = GetLastTransaction().PurchaseDate, - Gateway = GatewayType.AppStore, - GatewayId = GetLastTransactionId(), - UserId = userId, - PaymentMethodType = PaymentMethodType.AppleInApp, - Details = GetLastTransactionId() - }; - } + return LatestReceiptInfo?.LastOrDefault()?.TransactionId; + } - public class AppleReceipt - { - [JsonPropertyName("receipt_type")] - public string ReceiptType { get; set; } - [JsonPropertyName("bundle_id")] - public string BundleId { get; set; } - [JsonPropertyName("receipt_creation_date_ms")] - [JsonConverter(typeof(MsEpochConverter))] - public DateTime ReceiptCreationDate { get; set; } - [JsonPropertyName("in_app")] - public List InApp { get; set; } - } + public AppleTransaction GetLastTransaction() + { + return LatestReceiptInfo?.LastOrDefault(); + } - public class AppleRenewalInfo - { - [JsonPropertyName("expiration_intent")] - public string ExpirationIntent { get; set; } - [JsonPropertyName("auto_renew_product_id")] - public string AutoRenewProductId { get; set; } - [JsonPropertyName("original_transaction_id")] - public string OriginalTransactionId { get; set; } - [JsonPropertyName("is_in_billing_retry_period")] - public string IsInBillingRetryPeriod { get; set; } - [JsonPropertyName("product_id")] - public string ProductId { get; set; } - [JsonPropertyName("auto_renew_status")] - public string AutoRenewStatus { get; set; } - } + public DateTime? GetLastExpiresDate() + { + return LatestReceiptInfo?.LastOrDefault()?.ExpiresDate; + } - public class AppleTransaction - { - [JsonPropertyName("quantity")] - public string Quantity { get; set; } - [JsonPropertyName("product_id")] - public string ProductId { get; set; } - [JsonPropertyName("transaction_id")] - public string TransactionId { get; set; } - [JsonPropertyName("original_transaction_id")] - public string OriginalTransactionId { get; set; } - [JsonPropertyName("purchase_date_ms")] - [JsonConverter(typeof(MsEpochConverter))] - public DateTime PurchaseDate { get; set; } - [JsonPropertyName("original_purchase_date_ms")] - [JsonConverter(typeof(MsEpochConverter))] - public DateTime OriginalPurchaseDate { get; set; } - [JsonPropertyName("expires_date_ms")] - [JsonConverter(typeof(MsEpochConverter))] - public DateTime ExpiresDate { get; set; } - [JsonPropertyName("cancellation_date_ms")] - [JsonConverter(typeof(MsEpochConverter))] - public DateTime? CancellationDate { get; set; } - [JsonPropertyName("web_order_line_item_id")] - public string WebOrderLineItemId { get; set; } - [JsonPropertyName("cancellation_reason")] - public string CancellationReason { get; set; } + public string GetReceiptData() + { + return LatestReceipt; + } + + public DateTime? GetLastCancellationDate() + { + return LatestReceiptInfo?.LastOrDefault()?.CancellationDate; + } + + public bool IsRefunded() + { + var cancellationDate = GetLastCancellationDate(); + var expiresDate = GetLastCancellationDate(); + if (cancellationDate.HasValue && expiresDate.HasValue) + { + return cancellationDate.Value <= expiresDate.Value; + } + return false; + } + + public Transaction BuildTransactionFromLastTransaction(decimal amount, Guid userId) + { + return new Transaction + { + Amount = amount, + CreationDate = GetLastTransaction().PurchaseDate, + Gateway = GatewayType.AppStore, + GatewayId = GetLastTransactionId(), + UserId = userId, + PaymentMethodType = PaymentMethodType.AppleInApp, + Details = GetLastTransactionId() + }; + } + + public class AppleReceipt + { + [JsonPropertyName("receipt_type")] + public string ReceiptType { get; set; } + [JsonPropertyName("bundle_id")] + public string BundleId { get; set; } + [JsonPropertyName("receipt_creation_date_ms")] + [JsonConverter(typeof(MsEpochConverter))] + public DateTime ReceiptCreationDate { get; set; } + [JsonPropertyName("in_app")] + public List InApp { get; set; } + } + + public class AppleRenewalInfo + { + [JsonPropertyName("expiration_intent")] + public string ExpirationIntent { get; set; } + [JsonPropertyName("auto_renew_product_id")] + public string AutoRenewProductId { get; set; } + [JsonPropertyName("original_transaction_id")] + public string OriginalTransactionId { get; set; } + [JsonPropertyName("is_in_billing_retry_period")] + public string IsInBillingRetryPeriod { get; set; } + [JsonPropertyName("product_id")] + public string ProductId { get; set; } + [JsonPropertyName("auto_renew_status")] + public string AutoRenewStatus { get; set; } + } + + public class AppleTransaction + { + [JsonPropertyName("quantity")] + public string Quantity { get; set; } + [JsonPropertyName("product_id")] + public string ProductId { get; set; } + [JsonPropertyName("transaction_id")] + public string TransactionId { get; set; } + [JsonPropertyName("original_transaction_id")] + public string OriginalTransactionId { get; set; } + [JsonPropertyName("purchase_date_ms")] + [JsonConverter(typeof(MsEpochConverter))] + public DateTime PurchaseDate { get; set; } + [JsonPropertyName("original_purchase_date_ms")] + [JsonConverter(typeof(MsEpochConverter))] + public DateTime OriginalPurchaseDate { get; set; } + [JsonPropertyName("expires_date_ms")] + [JsonConverter(typeof(MsEpochConverter))] + public DateTime ExpiresDate { get; set; } + [JsonPropertyName("cancellation_date_ms")] + [JsonConverter(typeof(MsEpochConverter))] + public DateTime? CancellationDate { get; set; } + [JsonPropertyName("web_order_line_item_id")] + public string WebOrderLineItemId { get; set; } + [JsonPropertyName("cancellation_reason")] + public string CancellationReason { get; set; } + } } } diff --git a/src/Core/Models/Business/BillingInfo.cs b/src/Core/Models/Business/BillingInfo.cs index 1e1915566c..557a3288fb 100644 --- a/src/Core/Models/Business/BillingInfo.cs +++ b/src/Core/Models/Business/BillingInfo.cs @@ -2,154 +2,155 @@ using Bit.Core.Enums; using Stripe; -namespace Bit.Core.Models.Business; - -public class BillingInfo +namespace Bit.Core.Models.Business { - public decimal Balance { get; set; } - public BillingSource PaymentSource { get; set; } - public IEnumerable Invoices { get; set; } = new List(); - public IEnumerable Transactions { get; set; } = new List(); - - public class BillingSource + public class BillingInfo { - public BillingSource() { } + public decimal Balance { get; set; } + public BillingSource PaymentSource { get; set; } + public IEnumerable Invoices { get; set; } = new List(); + public IEnumerable Transactions { get; set; } = new List(); - public BillingSource(PaymentMethod method) + public class BillingSource { - if (method.Card != null) - { - Type = PaymentMethodType.Card; - Description = $"{method.Card.Brand?.ToUpperInvariant()}, *{method.Card.Last4}, " + - string.Format("{0}/{1}", - string.Concat(method.Card.ExpMonth < 10 ? - "0" : string.Empty, method.Card.ExpMonth), - method.Card.ExpYear); - CardBrand = method.Card.Brand; - } - } + public BillingSource() { } - public BillingSource(IPaymentSource source) - { - if (source is BankAccount bankAccount) + public BillingSource(PaymentMethod method) { - Type = PaymentMethodType.BankAccount; - Description = $"{bankAccount.BankName}, *{bankAccount.Last4} - " + - (bankAccount.Status == "verified" ? "verified" : - bankAccount.Status == "errored" ? "invalid" : - bankAccount.Status == "verification_failed" ? "verification failed" : "unverified"); - NeedsVerification = bankAccount.Status == "new" || bankAccount.Status == "validated"; + if (method.Card != null) + { + Type = PaymentMethodType.Card; + Description = $"{method.Card.Brand?.ToUpperInvariant()}, *{method.Card.Last4}, " + + string.Format("{0}/{1}", + string.Concat(method.Card.ExpMonth < 10 ? + "0" : string.Empty, method.Card.ExpMonth), + method.Card.ExpYear); + CardBrand = method.Card.Brand; + } } - else if (source is Card card) - { - Type = PaymentMethodType.Card; - Description = $"{card.Brand}, *{card.Last4}, " + - string.Format("{0}/{1}", - string.Concat(card.ExpMonth < 10 ? - "0" : string.Empty, card.ExpMonth), - card.ExpYear); - CardBrand = card.Brand; - } - else if (source is Source src && src.Card != null) - { - Type = PaymentMethodType.Card; - Description = $"{src.Card.Brand}, *{src.Card.Last4}, " + - string.Format("{0}/{1}", - string.Concat(src.Card.ExpMonth < 10 ? - "0" : string.Empty, src.Card.ExpMonth), - src.Card.ExpYear); - CardBrand = src.Card.Brand; - } - } - public BillingSource(Braintree.PaymentMethod method) - { - if (method is Braintree.PayPalAccount paypal) + public BillingSource(IPaymentSource source) { - Type = PaymentMethodType.PayPal; - Description = paypal.Email; + if (source is BankAccount bankAccount) + { + Type = PaymentMethodType.BankAccount; + Description = $"{bankAccount.BankName}, *{bankAccount.Last4} - " + + (bankAccount.Status == "verified" ? "verified" : + bankAccount.Status == "errored" ? "invalid" : + bankAccount.Status == "verification_failed" ? "verification failed" : "unverified"); + NeedsVerification = bankAccount.Status == "new" || bankAccount.Status == "validated"; + } + else if (source is Card card) + { + Type = PaymentMethodType.Card; + Description = $"{card.Brand}, *{card.Last4}, " + + string.Format("{0}/{1}", + string.Concat(card.ExpMonth < 10 ? + "0" : string.Empty, card.ExpMonth), + card.ExpYear); + CardBrand = card.Brand; + } + else if (source is Source src && src.Card != null) + { + Type = PaymentMethodType.Card; + Description = $"{src.Card.Brand}, *{src.Card.Last4}, " + + string.Format("{0}/{1}", + string.Concat(src.Card.ExpMonth < 10 ? + "0" : string.Empty, src.Card.ExpMonth), + src.Card.ExpYear); + CardBrand = src.Card.Brand; + } } - else if (method is Braintree.CreditCard card) + + public BillingSource(Braintree.PaymentMethod method) { - Type = PaymentMethodType.Card; - Description = $"{card.CardType.ToString()}, *{card.LastFour}, " + - string.Format("{0}/{1}", - string.Concat(card.ExpirationMonth.Length == 1 ? - "0" : string.Empty, card.ExpirationMonth), - card.ExpirationYear); - CardBrand = card.CardType.ToString(); + if (method is Braintree.PayPalAccount paypal) + { + Type = PaymentMethodType.PayPal; + Description = paypal.Email; + } + else if (method is Braintree.CreditCard card) + { + Type = PaymentMethodType.Card; + Description = $"{card.CardType.ToString()}, *{card.LastFour}, " + + string.Format("{0}/{1}", + string.Concat(card.ExpirationMonth.Length == 1 ? + "0" : string.Empty, card.ExpirationMonth), + card.ExpirationYear); + CardBrand = card.CardType.ToString(); + } + else if (method is Braintree.UsBankAccount bank) + { + Type = PaymentMethodType.BankAccount; + Description = $"{bank.BankName}, *{bank.Last4}"; + } + else + { + throw new NotSupportedException("Method not supported."); + } } - else if (method is Braintree.UsBankAccount bank) + + public BillingSource(Braintree.UsBankAccountDetails bank) { Type = PaymentMethodType.BankAccount; Description = $"{bank.BankName}, *{bank.Last4}"; } - else + + public BillingSource(Braintree.PayPalDetails paypal) { - throw new NotSupportedException("Method not supported."); + Type = PaymentMethodType.PayPal; + Description = paypal.PayerEmail; } + + public PaymentMethodType Type { get; set; } + public string CardBrand { get; set; } + public string Description { get; set; } + public bool NeedsVerification { get; set; } } - public BillingSource(Braintree.UsBankAccountDetails bank) + public class BillingTransaction { - Type = PaymentMethodType.BankAccount; - Description = $"{bank.BankName}, *{bank.Last4}"; + public BillingTransaction(Transaction transaction) + { + Id = transaction.Id; + CreatedDate = transaction.CreationDate; + Refunded = transaction.Refunded; + Type = transaction.Type; + PaymentMethodType = transaction.PaymentMethodType; + Details = transaction.Details; + Amount = transaction.Amount; + RefundedAmount = transaction.RefundedAmount; + } + + public Guid Id { get; set; } + public DateTime CreatedDate { get; set; } + public decimal Amount { get; set; } + public bool? Refunded { get; set; } + public bool? PartiallyRefunded => !Refunded.GetValueOrDefault() && RefundedAmount.GetValueOrDefault() > 0; + public decimal? RefundedAmount { get; set; } + public TransactionType Type { get; set; } + public PaymentMethodType? PaymentMethodType { get; set; } + public string Details { get; set; } } - public BillingSource(Braintree.PayPalDetails paypal) + public class BillingInvoice { - Type = PaymentMethodType.PayPal; - Description = paypal.PayerEmail; + public BillingInvoice(Invoice inv) + { + Date = inv.Created; + Url = inv.HostedInvoiceUrl; + PdfUrl = inv.InvoicePdf; + Number = inv.Number; + Paid = inv.Paid; + Amount = inv.Total / 100M; + } + + public decimal Amount { get; set; } + public DateTime? Date { get; set; } + public string Url { get; set; } + public string PdfUrl { get; set; } + public string Number { get; set; } + public bool Paid { get; set; } } - - public PaymentMethodType Type { get; set; } - public string CardBrand { get; set; } - public string Description { get; set; } - public bool NeedsVerification { get; set; } - } - - public class BillingTransaction - { - public BillingTransaction(Transaction transaction) - { - Id = transaction.Id; - CreatedDate = transaction.CreationDate; - Refunded = transaction.Refunded; - Type = transaction.Type; - PaymentMethodType = transaction.PaymentMethodType; - Details = transaction.Details; - Amount = transaction.Amount; - RefundedAmount = transaction.RefundedAmount; - } - - public Guid Id { get; set; } - public DateTime CreatedDate { get; set; } - public decimal Amount { get; set; } - public bool? Refunded { get; set; } - public bool? PartiallyRefunded => !Refunded.GetValueOrDefault() && RefundedAmount.GetValueOrDefault() > 0; - public decimal? RefundedAmount { get; set; } - public TransactionType Type { get; set; } - public PaymentMethodType? PaymentMethodType { get; set; } - public string Details { get; set; } - } - - public class BillingInvoice - { - public BillingInvoice(Invoice inv) - { - Date = inv.Created; - Url = inv.HostedInvoiceUrl; - PdfUrl = inv.InvoicePdf; - Number = inv.Number; - Paid = inv.Paid; - Amount = inv.Total / 100M; - } - - public decimal Amount { get; set; } - public DateTime? Date { get; set; } - public string Url { get; set; } - public string PdfUrl { get; set; } - public string Number { get; set; } - public bool Paid { get; set; } } } diff --git a/src/Core/Models/Business/CaptchaResponse.cs b/src/Core/Models/Business/CaptchaResponse.cs index aaafc8e7d3..c77330242f 100644 --- a/src/Core/Models/Business/CaptchaResponse.cs +++ b/src/Core/Models/Business/CaptchaResponse.cs @@ -1,9 +1,10 @@ -namespace Bit.Core.Models.Business; - -public class CaptchaResponse +namespace Bit.Core.Models.Business { - public bool Success { get; set; } - public bool MaybeBot { get; set; } - public bool IsBot { get; set; } - public double Score { get; set; } + public class CaptchaResponse + { + public bool Success { get; set; } + public bool MaybeBot { get; set; } + public bool IsBot { get; set; } + public double Score { get; set; } + } } diff --git a/src/Core/Models/Business/ExpiringToken.cs b/src/Core/Models/Business/ExpiringToken.cs index db09a540ff..3ed16a1cfc 100644 --- a/src/Core/Models/Business/ExpiringToken.cs +++ b/src/Core/Models/Business/ExpiringToken.cs @@ -1,13 +1,14 @@ -namespace Bit.Core.Models.Business; - -public class ExpiringToken +namespace Bit.Core.Models.Business { - public readonly string Token; - public readonly DateTime ExpirationDate; - - public ExpiringToken(string token, DateTime expirationDate) + public class ExpiringToken { - Token = token; - ExpirationDate = expirationDate; + public readonly string Token; + public readonly DateTime ExpirationDate; + + public ExpiringToken(string token, DateTime expirationDate) + { + Token = token; + ExpirationDate = expirationDate; + } } } diff --git a/src/Core/Models/Business/ILicense.cs b/src/Core/Models/Business/ILicense.cs index ad389b0a12..0e03e41c6d 100644 --- a/src/Core/Models/Business/ILicense.cs +++ b/src/Core/Models/Business/ILicense.cs @@ -1,20 +1,21 @@ using System.Security.Cryptography.X509Certificates; -namespace Bit.Core.Models.Business; - -public interface ILicense +namespace Bit.Core.Models.Business { - string LicenseKey { get; set; } - int Version { get; set; } - DateTime Issued { get; set; } - DateTime? Refresh { get; set; } - DateTime? Expires { get; set; } - bool Trial { get; set; } - string Hash { get; set; } - string Signature { get; set; } - byte[] SignatureBytes { get; } - byte[] GetDataBytes(bool forHash = false); - byte[] ComputeHash(); - bool VerifySignature(X509Certificate2 certificate); - byte[] Sign(X509Certificate2 certificate); + public interface ILicense + { + string LicenseKey { get; set; } + int Version { get; set; } + DateTime Issued { get; set; } + DateTime? Refresh { get; set; } + DateTime? Expires { get; set; } + bool Trial { get; set; } + string Hash { get; set; } + string Signature { get; set; } + byte[] SignatureBytes { get; } + byte[] GetDataBytes(bool forHash = false); + byte[] ComputeHash(); + bool VerifySignature(X509Certificate2 certificate); + byte[] Sign(X509Certificate2 certificate); + } } diff --git a/src/Core/Models/Business/ImportedGroup.cs b/src/Core/Models/Business/ImportedGroup.cs index bd0e389339..ee4589dfa2 100644 --- a/src/Core/Models/Business/ImportedGroup.cs +++ b/src/Core/Models/Business/ImportedGroup.cs @@ -1,9 +1,10 @@ using Bit.Core.Entities; -namespace Bit.Core.Models.Business; - -public class ImportedGroup +namespace Bit.Core.Models.Business { - public Group Group { get; set; } - public HashSet ExternalUserIds { get; set; } + public class ImportedGroup + { + public Group Group { get; set; } + public HashSet ExternalUserIds { get; set; } + } } diff --git a/src/Core/Models/Business/ImportedOrganizationUser.cs b/src/Core/Models/Business/ImportedOrganizationUser.cs index 967cdf253d..c57ce21230 100644 --- a/src/Core/Models/Business/ImportedOrganizationUser.cs +++ b/src/Core/Models/Business/ImportedOrganizationUser.cs @@ -1,7 +1,8 @@ -namespace Bit.Core.Models.Business; - -public class ImportedOrganizationUser +namespace Bit.Core.Models.Business { - public string Email { get; set; } - public string ExternalId { get; set; } + public class ImportedOrganizationUser + { + public string Email { get; set; } + public string ExternalId { get; set; } + } } diff --git a/src/Core/Models/Business/OrganizationLicense.cs b/src/Core/Models/Business/OrganizationLicense.cs index 6f13bc3583..6d1ff069db 100644 --- a/src/Core/Models/Business/OrganizationLicense.cs +++ b/src/Core/Models/Business/OrganizationLicense.cs @@ -8,303 +8,304 @@ using Bit.Core.Enums; using Bit.Core.Services; using Bit.Core.Settings; -namespace Bit.Core.Models.Business; - -public class OrganizationLicense : ILicense +namespace Bit.Core.Models.Business { - public OrganizationLicense() - { } - - public OrganizationLicense(Organization org, SubscriptionInfo subscriptionInfo, Guid installationId, - ILicensingService licenseService, int? version = null) + public class OrganizationLicense : ILicense { - Version = version.GetValueOrDefault(CURRENT_LICENSE_FILE_VERSION); // TODO: Remember to change the constant - LicenseType = Enums.LicenseType.Organization; - LicenseKey = org.LicenseKey; - InstallationId = installationId; - Id = org.Id; - Name = org.Name; - BillingEmail = org.BillingEmail; - BusinessName = org.BusinessName; - Enabled = org.Enabled; - Plan = org.Plan; - PlanType = org.PlanType; - Seats = org.Seats; - MaxCollections = org.MaxCollections; - UsePolicies = org.UsePolicies; - UseSso = org.UseSso; - UseKeyConnector = org.UseKeyConnector; - UseScim = org.UseScim; - UseGroups = org.UseGroups; - UseEvents = org.UseEvents; - UseDirectory = org.UseDirectory; - UseTotp = org.UseTotp; - Use2fa = org.Use2fa; - UseApi = org.UseApi; - UseResetPassword = org.UseResetPassword; - MaxStorageGb = org.MaxStorageGb; - SelfHost = org.SelfHost; - UsersGetPremium = org.UsersGetPremium; - Issued = DateTime.UtcNow; + public OrganizationLicense() + { } - if (subscriptionInfo?.Subscription == null) + public OrganizationLicense(Organization org, SubscriptionInfo subscriptionInfo, Guid installationId, + ILicensingService licenseService, int? version = null) { - if (org.PlanType == PlanType.Custom && org.ExpirationDate.HasValue) + Version = version.GetValueOrDefault(CURRENT_LICENSE_FILE_VERSION); // TODO: Remember to change the constant + LicenseType = Enums.LicenseType.Organization; + LicenseKey = org.LicenseKey; + InstallationId = installationId; + Id = org.Id; + Name = org.Name; + BillingEmail = org.BillingEmail; + BusinessName = org.BusinessName; + Enabled = org.Enabled; + Plan = org.Plan; + PlanType = org.PlanType; + Seats = org.Seats; + MaxCollections = org.MaxCollections; + UsePolicies = org.UsePolicies; + UseSso = org.UseSso; + UseKeyConnector = org.UseKeyConnector; + UseScim = org.UseScim; + UseGroups = org.UseGroups; + UseEvents = org.UseEvents; + UseDirectory = org.UseDirectory; + UseTotp = org.UseTotp; + Use2fa = org.Use2fa; + UseApi = org.UseApi; + UseResetPassword = org.UseResetPassword; + MaxStorageGb = org.MaxStorageGb; + SelfHost = org.SelfHost; + UsersGetPremium = org.UsersGetPremium; + Issued = DateTime.UtcNow; + + if (subscriptionInfo?.Subscription == null) { - Expires = Refresh = org.ExpirationDate.Value; - Trial = false; + if (org.PlanType == PlanType.Custom && org.ExpirationDate.HasValue) + { + Expires = Refresh = org.ExpirationDate.Value; + Trial = false; + } + else + { + Expires = Refresh = Issued.AddDays(7); + Trial = true; + } } - else + else if (subscriptionInfo.Subscription.TrialEndDate.HasValue && + subscriptionInfo.Subscription.TrialEndDate.Value > DateTime.UtcNow) { - Expires = Refresh = Issued.AddDays(7); + Expires = Refresh = subscriptionInfo.Subscription.TrialEndDate.Value; Trial = true; } - } - else if (subscriptionInfo.Subscription.TrialEndDate.HasValue && - subscriptionInfo.Subscription.TrialEndDate.Value > DateTime.UtcNow) - { - Expires = Refresh = subscriptionInfo.Subscription.TrialEndDate.Value; - Trial = true; - } - else - { - if (org.ExpirationDate.HasValue && org.ExpirationDate.Value < DateTime.UtcNow) + else { - // expired - Expires = Refresh = org.ExpirationDate.Value; + if (org.ExpirationDate.HasValue && org.ExpirationDate.Value < DateTime.UtcNow) + { + // expired + Expires = Refresh = org.ExpirationDate.Value; + } + else if (subscriptionInfo?.Subscription?.PeriodDuration != null && + subscriptionInfo.Subscription.PeriodDuration > TimeSpan.FromDays(180)) + { + Refresh = DateTime.UtcNow.AddDays(30); + Expires = subscriptionInfo?.Subscription.PeriodEndDate.Value.AddDays(60); + } + else + { + Expires = org.ExpirationDate.HasValue ? org.ExpirationDate.Value.AddMonths(11) : Issued.AddYears(1); + Refresh = DateTime.UtcNow - Expires > TimeSpan.FromDays(30) ? DateTime.UtcNow.AddDays(30) : Expires; + } + + Trial = false; } - else if (subscriptionInfo?.Subscription?.PeriodDuration != null && - subscriptionInfo.Subscription.PeriodDuration > TimeSpan.FromDays(180)) + + Hash = Convert.ToBase64String(ComputeHash()); + Signature = Convert.ToBase64String(licenseService.SignLicense(this)); + } + + public string LicenseKey { get; set; } + public Guid InstallationId { get; set; } + public Guid Id { get; set; } + public string Name { get; set; } + public string BillingEmail { get; set; } + public string BusinessName { get; set; } + public bool Enabled { get; set; } + public string Plan { get; set; } + public PlanType PlanType { get; set; } + public int? Seats { get; set; } + public short? MaxCollections { get; set; } + public bool UsePolicies { get; set; } + public bool UseSso { get; set; } + public bool UseKeyConnector { get; set; } + public bool UseScim { get; set; } + public bool UseGroups { get; set; } + public bool UseEvents { get; set; } + public bool UseDirectory { get; set; } + public bool UseTotp { get; set; } + public bool Use2fa { get; set; } + public bool UseApi { get; set; } + public bool UseResetPassword { get; set; } + public short? MaxStorageGb { get; set; } + public bool SelfHost { get; set; } + public bool UsersGetPremium { get; set; } + public int Version { get; set; } + public DateTime Issued { get; set; } + public DateTime? Refresh { get; set; } + public DateTime? Expires { get; set; } + public bool Trial { get; set; } + public LicenseType? LicenseType { get; set; } + public string Hash { get; set; } + public string Signature { get; set; } + [JsonIgnore] + public byte[] SignatureBytes => Convert.FromBase64String(Signature); + + /// + /// Represents the current version of the license format. Should be updated whenever new fields are added. + /// + private const int CURRENT_LICENSE_FILE_VERSION = 9; + private bool ValidLicenseVersion + { + get => Version is >= 1 and <= 10; + } + + public byte[] GetDataBytes(bool forHash = false) + { + string data = null; + if (ValidLicenseVersion) { - Refresh = DateTime.UtcNow.AddDays(30); - Expires = subscriptionInfo?.Subscription.PeriodEndDate.Value.AddDays(60); + var props = typeof(OrganizationLicense) + .GetProperties(BindingFlags.Public | BindingFlags.Instance) + .Where(p => + !p.Name.Equals(nameof(Signature)) && + !p.Name.Equals(nameof(SignatureBytes)) && + !p.Name.Equals(nameof(LicenseType)) && + // UsersGetPremium was added in Version 2 + (Version >= 2 || !p.Name.Equals(nameof(UsersGetPremium))) && + // UseEvents was added in Version 3 + (Version >= 3 || !p.Name.Equals(nameof(UseEvents))) && + // Use2fa was added in Version 4 + (Version >= 4 || !p.Name.Equals(nameof(Use2fa))) && + // UseApi was added in Version 5 + (Version >= 5 || !p.Name.Equals(nameof(UseApi))) && + // UsePolicies was added in Version 6 + (Version >= 6 || !p.Name.Equals(nameof(UsePolicies))) && + // UseSso was added in Version 7 + (Version >= 7 || !p.Name.Equals(nameof(UseSso))) && + // UseResetPassword was added in Version 8 + (Version >= 8 || !p.Name.Equals(nameof(UseResetPassword))) && + // UseKeyConnector was added in Version 9 + (Version >= 9 || !p.Name.Equals(nameof(UseKeyConnector))) && + // UseScim was added in Version 10 + (Version >= 10 || !p.Name.Equals(nameof(UseScim))) && + ( + !forHash || + ( + !p.Name.Equals(nameof(Hash)) && + !p.Name.Equals(nameof(Issued)) && + !p.Name.Equals(nameof(Refresh)) + ) + )) + .OrderBy(p => p.Name) + .Select(p => $"{p.Name}:{Utilities.CoreHelpers.FormatLicenseSignatureValue(p.GetValue(this, null))}") + .Aggregate((c, n) => $"{c}|{n}"); + data = $"license:organization|{props}"; } else { - Expires = org.ExpirationDate.HasValue ? org.ExpirationDate.Value.AddMonths(11) : Issued.AddYears(1); - Refresh = DateTime.UtcNow - Expires > TimeSpan.FromDays(30) ? DateTime.UtcNow.AddDays(30) : Expires; + throw new NotSupportedException($"Version {Version} is not supported."); } - Trial = false; + return Encoding.UTF8.GetBytes(data); } - Hash = Convert.ToBase64String(ComputeHash()); - Signature = Convert.ToBase64String(licenseService.SignLicense(this)); - } - - public string LicenseKey { get; set; } - public Guid InstallationId { get; set; } - public Guid Id { get; set; } - public string Name { get; set; } - public string BillingEmail { get; set; } - public string BusinessName { get; set; } - public bool Enabled { get; set; } - public string Plan { get; set; } - public PlanType PlanType { get; set; } - public int? Seats { get; set; } - public short? MaxCollections { get; set; } - public bool UsePolicies { get; set; } - public bool UseSso { get; set; } - public bool UseKeyConnector { get; set; } - public bool UseScim { get; set; } - public bool UseGroups { get; set; } - public bool UseEvents { get; set; } - public bool UseDirectory { get; set; } - public bool UseTotp { get; set; } - public bool Use2fa { get; set; } - public bool UseApi { get; set; } - public bool UseResetPassword { get; set; } - public short? MaxStorageGb { get; set; } - public bool SelfHost { get; set; } - public bool UsersGetPremium { get; set; } - public int Version { get; set; } - public DateTime Issued { get; set; } - public DateTime? Refresh { get; set; } - public DateTime? Expires { get; set; } - public bool Trial { get; set; } - public LicenseType? LicenseType { get; set; } - public string Hash { get; set; } - public string Signature { get; set; } - [JsonIgnore] - public byte[] SignatureBytes => Convert.FromBase64String(Signature); - - /// - /// Represents the current version of the license format. Should be updated whenever new fields are added. - /// - private const int CURRENT_LICENSE_FILE_VERSION = 9; - private bool ValidLicenseVersion - { - get => Version is >= 1 and <= 10; - } - - public byte[] GetDataBytes(bool forHash = false) - { - string data = null; - if (ValidLicenseVersion) + public byte[] ComputeHash() { - var props = typeof(OrganizationLicense) - .GetProperties(BindingFlags.Public | BindingFlags.Instance) - .Where(p => - !p.Name.Equals(nameof(Signature)) && - !p.Name.Equals(nameof(SignatureBytes)) && - !p.Name.Equals(nameof(LicenseType)) && - // UsersGetPremium was added in Version 2 - (Version >= 2 || !p.Name.Equals(nameof(UsersGetPremium))) && - // UseEvents was added in Version 3 - (Version >= 3 || !p.Name.Equals(nameof(UseEvents))) && - // Use2fa was added in Version 4 - (Version >= 4 || !p.Name.Equals(nameof(Use2fa))) && - // UseApi was added in Version 5 - (Version >= 5 || !p.Name.Equals(nameof(UseApi))) && - // UsePolicies was added in Version 6 - (Version >= 6 || !p.Name.Equals(nameof(UsePolicies))) && - // UseSso was added in Version 7 - (Version >= 7 || !p.Name.Equals(nameof(UseSso))) && - // UseResetPassword was added in Version 8 - (Version >= 8 || !p.Name.Equals(nameof(UseResetPassword))) && - // UseKeyConnector was added in Version 9 - (Version >= 9 || !p.Name.Equals(nameof(UseKeyConnector))) && - // UseScim was added in Version 10 - (Version >= 10 || !p.Name.Equals(nameof(UseScim))) && - ( - !forHash || - ( - !p.Name.Equals(nameof(Hash)) && - !p.Name.Equals(nameof(Issued)) && - !p.Name.Equals(nameof(Refresh)) - ) - )) - .OrderBy(p => p.Name) - .Select(p => $"{p.Name}:{Utilities.CoreHelpers.FormatLicenseSignatureValue(p.GetValue(this, null))}") - .Aggregate((c, n) => $"{c}|{n}"); - data = $"license:organization|{props}"; - } - else - { - throw new NotSupportedException($"Version {Version} is not supported."); - } - - return Encoding.UTF8.GetBytes(data); - } - - public byte[] ComputeHash() - { - using (var alg = SHA256.Create()) - { - return alg.ComputeHash(GetDataBytes(true)); - } - } - - public bool CanUse(IGlobalSettings globalSettings) - { - if (!Enabled || Issued > DateTime.UtcNow || Expires < DateTime.UtcNow) - { - return false; - } - - if (ValidLicenseVersion) - { - return InstallationId == globalSettings.Installation.Id && SelfHost; - } - else - { - throw new NotSupportedException($"Version {Version} is not supported."); - } - } - - public bool VerifyData(Organization organization, IGlobalSettings globalSettings) - { - if (Issued > DateTime.UtcNow || Expires < DateTime.UtcNow) - { - return false; - } - - if (ValidLicenseVersion) - { - var valid = - globalSettings.Installation.Id == InstallationId && - organization.LicenseKey != null && organization.LicenseKey.Equals(LicenseKey) && - organization.Enabled == Enabled && - organization.PlanType == PlanType && - organization.Seats == Seats && - organization.MaxCollections == MaxCollections && - organization.UseGroups == UseGroups && - organization.UseDirectory == UseDirectory && - organization.UseTotp == UseTotp && - organization.SelfHost == SelfHost && - organization.Name.Equals(Name); - - if (valid && Version >= 2) + using (var alg = SHA256.Create()) { - valid = organization.UsersGetPremium == UsersGetPremium; + return alg.ComputeHash(GetDataBytes(true)); } - - if (valid && Version >= 3) - { - valid = organization.UseEvents == UseEvents; - } - - if (valid && Version >= 4) - { - valid = organization.Use2fa == Use2fa; - } - - if (valid && Version >= 5) - { - valid = organization.UseApi == UseApi; - } - - if (valid && Version >= 6) - { - valid = organization.UsePolicies == UsePolicies; - } - - if (valid && Version >= 7) - { - valid = organization.UseSso == UseSso; - } - - if (valid && Version >= 8) - { - valid = organization.UseResetPassword == UseResetPassword; - } - - if (valid && Version >= 9) - { - valid = organization.UseKeyConnector == UseKeyConnector; - } - - if (valid && Version >= 10) - { - valid = organization.UseScim == UseScim; - } - - return valid; - } - else - { - throw new NotSupportedException($"Version {Version} is not supported."); - } - } - - public bool VerifySignature(X509Certificate2 certificate) - { - using (var rsa = certificate.GetRSAPublicKey()) - { - return rsa.VerifyData(GetDataBytes(), SignatureBytes, HashAlgorithmName.SHA256, RSASignaturePadding.Pkcs1); - } - } - - public byte[] Sign(X509Certificate2 certificate) - { - if (!certificate.HasPrivateKey) - { - throw new InvalidOperationException("You don't have the private key!"); } - using (var rsa = certificate.GetRSAPrivateKey()) + public bool CanUse(IGlobalSettings globalSettings) { - return rsa.SignData(GetDataBytes(), HashAlgorithmName.SHA256, RSASignaturePadding.Pkcs1); + if (!Enabled || Issued > DateTime.UtcNow || Expires < DateTime.UtcNow) + { + return false; + } + + if (ValidLicenseVersion) + { + return InstallationId == globalSettings.Installation.Id && SelfHost; + } + else + { + throw new NotSupportedException($"Version {Version} is not supported."); + } + } + + public bool VerifyData(Organization organization, IGlobalSettings globalSettings) + { + if (Issued > DateTime.UtcNow || Expires < DateTime.UtcNow) + { + return false; + } + + if (ValidLicenseVersion) + { + var valid = + globalSettings.Installation.Id == InstallationId && + organization.LicenseKey != null && organization.LicenseKey.Equals(LicenseKey) && + organization.Enabled == Enabled && + organization.PlanType == PlanType && + organization.Seats == Seats && + organization.MaxCollections == MaxCollections && + organization.UseGroups == UseGroups && + organization.UseDirectory == UseDirectory && + organization.UseTotp == UseTotp && + organization.SelfHost == SelfHost && + organization.Name.Equals(Name); + + if (valid && Version >= 2) + { + valid = organization.UsersGetPremium == UsersGetPremium; + } + + if (valid && Version >= 3) + { + valid = organization.UseEvents == UseEvents; + } + + if (valid && Version >= 4) + { + valid = organization.Use2fa == Use2fa; + } + + if (valid && Version >= 5) + { + valid = organization.UseApi == UseApi; + } + + if (valid && Version >= 6) + { + valid = organization.UsePolicies == UsePolicies; + } + + if (valid && Version >= 7) + { + valid = organization.UseSso == UseSso; + } + + if (valid && Version >= 8) + { + valid = organization.UseResetPassword == UseResetPassword; + } + + if (valid && Version >= 9) + { + valid = organization.UseKeyConnector == UseKeyConnector; + } + + if (valid && Version >= 10) + { + valid = organization.UseScim == UseScim; + } + + return valid; + } + else + { + throw new NotSupportedException($"Version {Version} is not supported."); + } + } + + public bool VerifySignature(X509Certificate2 certificate) + { + using (var rsa = certificate.GetRSAPublicKey()) + { + return rsa.VerifyData(GetDataBytes(), SignatureBytes, HashAlgorithmName.SHA256, RSASignaturePadding.Pkcs1); + } + } + + public byte[] Sign(X509Certificate2 certificate) + { + if (!certificate.HasPrivateKey) + { + throw new InvalidOperationException("You don't have the private key!"); + } + + using (var rsa = certificate.GetRSAPrivateKey()) + { + return rsa.SignData(GetDataBytes(), HashAlgorithmName.SHA256, RSASignaturePadding.Pkcs1); + } } } } diff --git a/src/Core/Models/Business/OrganizationSignup.cs b/src/Core/Models/Business/OrganizationSignup.cs index 970ede9afc..a257410fd5 100644 --- a/src/Core/Models/Business/OrganizationSignup.cs +++ b/src/Core/Models/Business/OrganizationSignup.cs @@ -1,16 +1,17 @@ using Bit.Core.Entities; using Bit.Core.Enums; -namespace Bit.Core.Models.Business; - -public class OrganizationSignup : OrganizationUpgrade +namespace Bit.Core.Models.Business { - public string Name { get; set; } - public string BillingEmail { get; set; } - public User Owner { get; set; } - public string OwnerKey { get; set; } - public string CollectionName { get; set; } - public PaymentMethodType? PaymentMethodType { get; set; } - public string PaymentToken { get; set; } - public int? MaxAutoscaleSeats { get; set; } = null; + public class OrganizationSignup : OrganizationUpgrade + { + public string Name { get; set; } + public string BillingEmail { get; set; } + public User Owner { get; set; } + public string OwnerKey { get; set; } + public string CollectionName { get; set; } + public PaymentMethodType? PaymentMethodType { get; set; } + public string PaymentToken { get; set; } + public int? MaxAutoscaleSeats { get; set; } = null; + } } diff --git a/src/Core/Models/Business/OrganizationUpgrade.cs b/src/Core/Models/Business/OrganizationUpgrade.cs index b77a9d012c..f6d8aa415a 100644 --- a/src/Core/Models/Business/OrganizationUpgrade.cs +++ b/src/Core/Models/Business/OrganizationUpgrade.cs @@ -1,15 +1,16 @@ using Bit.Core.Enums; -namespace Bit.Core.Models.Business; - -public class OrganizationUpgrade +namespace Bit.Core.Models.Business { - public string BusinessName { get; set; } - public PlanType Plan { get; set; } - public int AdditionalSeats { get; set; } - public short AdditionalStorageGb { get; set; } - public bool PremiumAccessAddon { get; set; } - public TaxInfo TaxInfo { get; set; } - public string PublicKey { get; set; } - public string PrivateKey { get; set; } + public class OrganizationUpgrade + { + public string BusinessName { get; set; } + public PlanType Plan { get; set; } + public int AdditionalSeats { get; set; } + public short AdditionalStorageGb { get; set; } + public bool PremiumAccessAddon { get; set; } + public TaxInfo TaxInfo { get; set; } + public string PublicKey { get; set; } + public string PrivateKey { get; set; } + } } diff --git a/src/Core/Models/Business/OrganizationUserInvite.cs b/src/Core/Models/Business/OrganizationUserInvite.cs index 4fa61d55c0..8e7f6f8657 100644 --- a/src/Core/Models/Business/OrganizationUserInvite.cs +++ b/src/Core/Models/Business/OrganizationUserInvite.cs @@ -1,24 +1,25 @@ using Bit.Core.Models.Data; using Bit.Core.Models.Data.Organizations.OrganizationUsers; -namespace Bit.Core.Models.Business; - -public class OrganizationUserInvite +namespace Bit.Core.Models.Business { - public IEnumerable Emails { get; set; } - public Enums.OrganizationUserType? Type { get; set; } - public bool AccessAll { get; set; } - public Permissions Permissions { get; set; } - public IEnumerable Collections { get; set; } - - public OrganizationUserInvite() { } - - public OrganizationUserInvite(OrganizationUserInviteData requestModel) + public class OrganizationUserInvite { - Emails = requestModel.Emails; - Type = requestModel.Type; - AccessAll = requestModel.AccessAll; - Collections = requestModel.Collections; - Permissions = requestModel.Permissions; + public IEnumerable Emails { get; set; } + public Enums.OrganizationUserType? Type { get; set; } + public bool AccessAll { get; set; } + public Permissions Permissions { get; set; } + public IEnumerable Collections { get; set; } + + public OrganizationUserInvite() { } + + public OrganizationUserInvite(OrganizationUserInviteData requestModel) + { + Emails = requestModel.Emails; + Type = requestModel.Type; + AccessAll = requestModel.AccessAll; + Collections = requestModel.Collections; + Permissions = requestModel.Permissions; + } } } diff --git a/src/Core/Models/Business/Provider/ProviderUserInvite.cs b/src/Core/Models/Business/Provider/ProviderUserInvite.cs index 72e87728d2..39f6094794 100644 --- a/src/Core/Models/Business/Provider/ProviderUserInvite.cs +++ b/src/Core/Models/Business/Provider/ProviderUserInvite.cs @@ -1,35 +1,36 @@ using Bit.Core.Enums.Provider; -namespace Bit.Core.Models.Business.Provider; - -public class ProviderUserInvite +namespace Bit.Core.Models.Business.Provider { - public IEnumerable UserIdentifiers { get; set; } - public ProviderUserType Type { get; set; } - public Guid InvitingUserId { get; set; } - public Guid ProviderId { get; set; } -} - -public static class ProviderUserInviteFactory -{ - public static ProviderUserInvite CreateIntialInvite(IEnumerable inviteeEmails, ProviderUserType type, Guid invitingUserId, Guid providerId) + public class ProviderUserInvite { - return new ProviderUserInvite - { - UserIdentifiers = inviteeEmails, - Type = type, - InvitingUserId = invitingUserId, - ProviderId = providerId - }; + public IEnumerable UserIdentifiers { get; set; } + public ProviderUserType Type { get; set; } + public Guid InvitingUserId { get; set; } + public Guid ProviderId { get; set; } } - public static ProviderUserInvite CreateReinvite(IEnumerable inviteeUserIds, Guid invitingUserId, Guid providerId) + public static class ProviderUserInviteFactory { - return new ProviderUserInvite + public static ProviderUserInvite CreateIntialInvite(IEnumerable inviteeEmails, ProviderUserType type, Guid invitingUserId, Guid providerId) { - UserIdentifiers = inviteeUserIds, - InvitingUserId = invitingUserId, - ProviderId = providerId - }; + return new ProviderUserInvite + { + UserIdentifiers = inviteeEmails, + Type = type, + InvitingUserId = invitingUserId, + ProviderId = providerId + }; + } + + public static ProviderUserInvite CreateReinvite(IEnumerable inviteeUserIds, Guid invitingUserId, Guid providerId) + { + return new ProviderUserInvite + { + UserIdentifiers = inviteeUserIds, + InvitingUserId = invitingUserId, + ProviderId = providerId + }; + } } } diff --git a/src/Core/Models/Business/ReferenceEvent.cs b/src/Core/Models/Business/ReferenceEvent.cs index 4f20b2455c..35cb5dc735 100644 --- a/src/Core/Models/Business/ReferenceEvent.cs +++ b/src/Core/Models/Business/ReferenceEvent.cs @@ -2,60 +2,61 @@ using Bit.Core.Entities; using Bit.Core.Enums; -namespace Bit.Core.Models.Business; - -public class ReferenceEvent +namespace Bit.Core.Models.Business { - public ReferenceEvent() { } - - public ReferenceEvent(ReferenceEventType type, IReferenceable source) + public class ReferenceEvent { - Type = type; - if (source != null) + public ReferenceEvent() { } + + public ReferenceEvent(ReferenceEventType type, IReferenceable source) { - Source = source.IsUser() ? ReferenceEventSource.User : ReferenceEventSource.Organization; - Id = source.Id; - ReferenceData = source.ReferenceData; + Type = type; + if (source != null) + { + Source = source.IsUser() ? ReferenceEventSource.User : ReferenceEventSource.Organization; + Id = source.Id; + ReferenceData = source.ReferenceData; + } } + + [JsonConverter(typeof(JsonStringEnumConverter))] + public ReferenceEventType Type { get; set; } + + [JsonConverter(typeof(JsonStringEnumConverter))] + public ReferenceEventSource Source { get; set; } + + public Guid Id { get; set; } + + public string ReferenceData { get; set; } + + public DateTime EventDate { get; set; } = DateTime.UtcNow; + + public int? Users { get; set; } + + public bool? EndOfPeriod { get; set; } + + public string PlanName { get; set; } + + public PlanType? PlanType { get; set; } + + public string OldPlanName { get; set; } + + public PlanType? OldPlanType { get; set; } + + public int? Seats { get; set; } + public int? PreviousSeats { get; set; } + + public short? Storage { get; set; } + + [JsonConverter(typeof(JsonStringEnumConverter))] + public SendType? SendType { get; set; } + + public int? MaxAccessCount { get; set; } + + public bool? HasPassword { get; set; } + + public string EventRaisedByUser { get; set; } + + public bool? SalesAssistedTrialStarted { get; set; } } - - [JsonConverter(typeof(JsonStringEnumConverter))] - public ReferenceEventType Type { get; set; } - - [JsonConverter(typeof(JsonStringEnumConverter))] - public ReferenceEventSource Source { get; set; } - - public Guid Id { get; set; } - - public string ReferenceData { get; set; } - - public DateTime EventDate { get; set; } = DateTime.UtcNow; - - public int? Users { get; set; } - - public bool? EndOfPeriod { get; set; } - - public string PlanName { get; set; } - - public PlanType? PlanType { get; set; } - - public string OldPlanName { get; set; } - - public PlanType? OldPlanType { get; set; } - - public int? Seats { get; set; } - public int? PreviousSeats { get; set; } - - public short? Storage { get; set; } - - [JsonConverter(typeof(JsonStringEnumConverter))] - public SendType? SendType { get; set; } - - public int? MaxAccessCount { get; set; } - - public bool? HasPassword { get; set; } - - public string EventRaisedByUser { get; set; } - - public bool? SalesAssistedTrialStarted { get; set; } } diff --git a/src/Core/Models/Business/SubscriptionCreateOptions.cs b/src/Core/Models/Business/SubscriptionCreateOptions.cs index 4964a625c8..e78aaeda0e 100644 --- a/src/Core/Models/Business/SubscriptionCreateOptions.cs +++ b/src/Core/Models/Business/SubscriptionCreateOptions.cs @@ -1,83 +1,84 @@ using Bit.Core.Entities; using Stripe; -namespace Bit.Core.Models.Business; - -public class OrganizationSubscriptionOptionsBase : Stripe.SubscriptionCreateOptions +namespace Bit.Core.Models.Business { - public OrganizationSubscriptionOptionsBase(Organization org, StaticStore.Plan plan, TaxInfo taxInfo, int additionalSeats, int additionalStorageGb, bool premiumAccessAddon) + public class OrganizationSubscriptionOptionsBase : Stripe.SubscriptionCreateOptions { - Items = new List(); - Metadata = new Dictionary + public OrganizationSubscriptionOptionsBase(Organization org, StaticStore.Plan plan, TaxInfo taxInfo, int additionalSeats, int additionalStorageGb, bool premiumAccessAddon) { - [org.GatewayIdField()] = org.Id.ToString() - }; - - if (plan.StripePlanId != null) - { - Items.Add(new SubscriptionItemOptions + Items = new List(); + Metadata = new Dictionary { - Plan = plan.StripePlanId, - Quantity = 1 - }); - } + [org.GatewayIdField()] = org.Id.ToString() + }; - if (additionalSeats > 0 && plan.StripeSeatPlanId != null) - { - Items.Add(new SubscriptionItemOptions + if (plan.StripePlanId != null) { - Plan = plan.StripeSeatPlanId, - Quantity = additionalSeats - }); - } + Items.Add(new SubscriptionItemOptions + { + Plan = plan.StripePlanId, + Quantity = 1 + }); + } - if (additionalStorageGb > 0) - { - Items.Add(new SubscriptionItemOptions + if (additionalSeats > 0 && plan.StripeSeatPlanId != null) { - Plan = plan.StripeStoragePlanId, - Quantity = additionalStorageGb - }); - } + Items.Add(new SubscriptionItemOptions + { + Plan = plan.StripeSeatPlanId, + Quantity = additionalSeats + }); + } - if (premiumAccessAddon && plan.StripePremiumAccessPlanId != null) - { - Items.Add(new SubscriptionItemOptions + if (additionalStorageGb > 0) { - Plan = plan.StripePremiumAccessPlanId, - Quantity = 1 - }); - } + Items.Add(new SubscriptionItemOptions + { + Plan = plan.StripeStoragePlanId, + Quantity = additionalStorageGb + }); + } - if (!string.IsNullOrWhiteSpace(taxInfo?.StripeTaxRateId)) + if (premiumAccessAddon && plan.StripePremiumAccessPlanId != null) + { + Items.Add(new SubscriptionItemOptions + { + Plan = plan.StripePremiumAccessPlanId, + Quantity = 1 + }); + } + + if (!string.IsNullOrWhiteSpace(taxInfo?.StripeTaxRateId)) + { + DefaultTaxRates = new List { taxInfo.StripeTaxRateId }; + } + } + } + + public class OrganizationPurchaseSubscriptionOptions : OrganizationSubscriptionOptionsBase + { + public OrganizationPurchaseSubscriptionOptions( + Organization org, StaticStore.Plan plan, + TaxInfo taxInfo, int additionalSeats = 0, + int additionalStorageGb = 0, bool premiumAccessAddon = false) : + base(org, plan, taxInfo, additionalSeats, additionalStorageGb, premiumAccessAddon) { - DefaultTaxRates = new List { taxInfo.StripeTaxRateId }; + OffSession = true; + TrialPeriodDays = plan.TrialPeriodDays; + } + } + + public class OrganizationUpgradeSubscriptionOptions : OrganizationSubscriptionOptionsBase + { + public OrganizationUpgradeSubscriptionOptions( + string customerId, Organization org, + StaticStore.Plan plan, TaxInfo taxInfo, + int additionalSeats = 0, int additionalStorageGb = 0, + bool premiumAccessAddon = false) : + base(org, plan, taxInfo, additionalSeats, additionalStorageGb, premiumAccessAddon) + { + Customer = customerId; } } } - -public class OrganizationPurchaseSubscriptionOptions : OrganizationSubscriptionOptionsBase -{ - public OrganizationPurchaseSubscriptionOptions( - Organization org, StaticStore.Plan plan, - TaxInfo taxInfo, int additionalSeats = 0, - int additionalStorageGb = 0, bool premiumAccessAddon = false) : - base(org, plan, taxInfo, additionalSeats, additionalStorageGb, premiumAccessAddon) - { - OffSession = true; - TrialPeriodDays = plan.TrialPeriodDays; - } -} - -public class OrganizationUpgradeSubscriptionOptions : OrganizationSubscriptionOptionsBase -{ - public OrganizationUpgradeSubscriptionOptions( - string customerId, Organization org, - StaticStore.Plan plan, TaxInfo taxInfo, - int additionalSeats = 0, int additionalStorageGb = 0, - bool premiumAccessAddon = false) : - base(org, plan, taxInfo, additionalSeats, additionalStorageGb, premiumAccessAddon) - { - Customer = customerId; - } -} diff --git a/src/Core/Models/Business/SubscriptionInfo.cs b/src/Core/Models/Business/SubscriptionInfo.cs index 61aa060cd4..e8e339db8d 100644 --- a/src/Core/Models/Business/SubscriptionInfo.cs +++ b/src/Core/Models/Business/SubscriptionInfo.cs @@ -1,86 +1,87 @@ using Stripe; -namespace Bit.Core.Models.Business; - -public class SubscriptionInfo +namespace Bit.Core.Models.Business { - public BillingSubscription Subscription { get; set; } - public BillingUpcomingInvoice UpcomingInvoice { get; set; } - public bool UsingInAppPurchase { get; set; } - - public class BillingSubscription + public class SubscriptionInfo { - public BillingSubscription(Subscription sub) - { - Status = sub.Status; - TrialStartDate = sub.TrialStart; - TrialEndDate = sub.TrialEnd; - PeriodStartDate = sub.CurrentPeriodStart; - PeriodEndDate = sub.CurrentPeriodEnd; - CancelledDate = sub.CanceledAt; - CancelAtEndDate = sub.CancelAtPeriodEnd; - Cancelled = sub.Status == "canceled" || sub.Status == "unpaid" || sub.Status == "incomplete_expired"; - if (sub.Items?.Data != null) - { - Items = sub.Items.Data.Select(i => new BillingSubscriptionItem(i)); - } - } + public BillingSubscription Subscription { get; set; } + public BillingUpcomingInvoice UpcomingInvoice { get; set; } + public bool UsingInAppPurchase { get; set; } - public DateTime? TrialStartDate { get; set; } - public DateTime? TrialEndDate { get; set; } - public DateTime? PeriodStartDate { get; set; } - public DateTime? PeriodEndDate { get; set; } - public TimeSpan? PeriodDuration => PeriodEndDate - PeriodStartDate; - public DateTime? CancelledDate { get; set; } - public bool CancelAtEndDate { get; set; } - public string Status { get; set; } - public bool Cancelled { get; set; } - public IEnumerable Items { get; set; } = new List(); - - public class BillingSubscriptionItem + public class BillingSubscription { - public BillingSubscriptionItem(SubscriptionItem item) + public BillingSubscription(Subscription sub) { - if (item.Plan != null) + Status = sub.Status; + TrialStartDate = sub.TrialStart; + TrialEndDate = sub.TrialEnd; + PeriodStartDate = sub.CurrentPeriodStart; + PeriodEndDate = sub.CurrentPeriodEnd; + CancelledDate = sub.CanceledAt; + CancelAtEndDate = sub.CancelAtPeriodEnd; + Cancelled = sub.Status == "canceled" || sub.Status == "unpaid" || sub.Status == "incomplete_expired"; + if (sub.Items?.Data != null) { - Name = item.Plan.Nickname; - Amount = item.Plan.Amount.GetValueOrDefault() / 100M; - Interval = item.Plan.Interval; + Items = sub.Items.Data.Select(i => new BillingSubscriptionItem(i)); + } + } + + public DateTime? TrialStartDate { get; set; } + public DateTime? TrialEndDate { get; set; } + public DateTime? PeriodStartDate { get; set; } + public DateTime? PeriodEndDate { get; set; } + public TimeSpan? PeriodDuration => PeriodEndDate - PeriodStartDate; + public DateTime? CancelledDate { get; set; } + public bool CancelAtEndDate { get; set; } + public string Status { get; set; } + public bool Cancelled { get; set; } + public IEnumerable Items { get; set; } = new List(); + + public class BillingSubscriptionItem + { + public BillingSubscriptionItem(SubscriptionItem item) + { + if (item.Plan != null) + { + Name = item.Plan.Nickname; + Amount = item.Plan.Amount.GetValueOrDefault() / 100M; + Interval = item.Plan.Interval; + } + + Quantity = (int)item.Quantity; + SponsoredSubscriptionItem = Utilities.StaticStore.SponsoredPlans.Any(p => p.StripePlanId == item.Plan.Id); } - Quantity = (int)item.Quantity; - SponsoredSubscriptionItem = Utilities.StaticStore.SponsoredPlans.Any(p => p.StripePlanId == item.Plan.Id); + public string Name { get; set; } + public decimal Amount { get; set; } + public int Quantity { get; set; } + public string Interval { get; set; } + public bool SponsoredSubscriptionItem { get; set; } } - - public string Name { get; set; } - public decimal Amount { get; set; } - public int Quantity { get; set; } - public string Interval { get; set; } - public bool SponsoredSubscriptionItem { get; set; } - } - } - - public class BillingUpcomingInvoice - { - public BillingUpcomingInvoice() { } - - public BillingUpcomingInvoice(Invoice inv) - { - Amount = inv.AmountDue / 100M; - Date = inv.Created; } - public BillingUpcomingInvoice(Braintree.Subscription sub) + public class BillingUpcomingInvoice { - Amount = sub.NextBillAmount.GetValueOrDefault() + sub.Balance.GetValueOrDefault(); - if (Amount < 0) + public BillingUpcomingInvoice() { } + + public BillingUpcomingInvoice(Invoice inv) { - Amount = 0; + Amount = inv.AmountDue / 100M; + Date = inv.Created; } - Date = sub.NextBillingDate; - } - public decimal Amount { get; set; } - public DateTime? Date { get; set; } + public BillingUpcomingInvoice(Braintree.Subscription sub) + { + Amount = sub.NextBillAmount.GetValueOrDefault() + sub.Balance.GetValueOrDefault(); + if (Amount < 0) + { + Amount = 0; + } + Date = sub.NextBillingDate; + } + + public decimal Amount { get; set; } + public DateTime? Date { get; set; } + } } } diff --git a/src/Core/Models/Business/SubscriptionUpdate.cs b/src/Core/Models/Business/SubscriptionUpdate.cs index 64b43a8de2..56902524d8 100644 --- a/src/Core/Models/Business/SubscriptionUpdate.cs +++ b/src/Core/Models/Business/SubscriptionUpdate.cs @@ -1,209 +1,210 @@ using Bit.Core.Entities; using Stripe; -namespace Bit.Core.Models.Business; - -public abstract class SubscriptionUpdate +namespace Bit.Core.Models.Business { - protected abstract List PlanIds { get; } - - public abstract List RevertItemsOptions(Subscription subscription); - public abstract List UpgradeItemsOptions(Subscription subscription); - - public bool UpdateNeeded(Subscription subscription) + public abstract class SubscriptionUpdate { - var upgradeItemsOptions = UpgradeItemsOptions(subscription); - foreach (var upgradeItemOptions in upgradeItemsOptions) + protected abstract List PlanIds { get; } + + public abstract List RevertItemsOptions(Subscription subscription); + public abstract List UpgradeItemsOptions(Subscription subscription); + + public bool UpdateNeeded(Subscription subscription) { - var upgradeQuantity = upgradeItemOptions.Quantity ?? 0; - var existingQuantity = SubscriptionItem(subscription, upgradeItemOptions.Plan)?.Quantity ?? 0; - if (upgradeQuantity != existingQuantity) + var upgradeItemsOptions = UpgradeItemsOptions(subscription); + foreach (var upgradeItemOptions in upgradeItemsOptions) { - return true; + var upgradeQuantity = upgradeItemOptions.Quantity ?? 0; + var existingQuantity = SubscriptionItem(subscription, upgradeItemOptions.Plan)?.Quantity ?? 0; + if (upgradeQuantity != existingQuantity) + { + return true; + } } + return false; } - return false; + + protected static SubscriptionItem SubscriptionItem(Subscription subscription, string planId) => + planId == null ? null : subscription.Items?.Data?.FirstOrDefault(i => i.Plan.Id == planId); } - protected static SubscriptionItem SubscriptionItem(Subscription subscription, string planId) => - planId == null ? null : subscription.Items?.Data?.FirstOrDefault(i => i.Plan.Id == planId); -} - -public class SeatSubscriptionUpdate : SubscriptionUpdate -{ - private readonly int _previousSeats; - private readonly StaticStore.Plan _plan; - private readonly long? _additionalSeats; - protected override List PlanIds => new() { _plan.StripeSeatPlanId }; - - - public SeatSubscriptionUpdate(Organization organization, StaticStore.Plan plan, long? additionalSeats) + public class SeatSubscriptionUpdate : SubscriptionUpdate { - _plan = plan; - _additionalSeats = additionalSeats; - _previousSeats = organization.Seats ?? 0; - } + private readonly int _previousSeats; + private readonly StaticStore.Plan _plan; + private readonly long? _additionalSeats; + protected override List PlanIds => new() { _plan.StripeSeatPlanId }; - public override List UpgradeItemsOptions(Subscription subscription) - { - var item = SubscriptionItem(subscription, PlanIds.Single()); - return new() + + public SeatSubscriptionUpdate(Organization organization, StaticStore.Plan plan, long? additionalSeats) { - new SubscriptionItemOptions + _plan = plan; + _additionalSeats = additionalSeats; + _previousSeats = organization.Seats ?? 0; + } + + public override List UpgradeItemsOptions(Subscription subscription) + { + var item = SubscriptionItem(subscription, PlanIds.Single()); + return new() { - Id = item?.Id, - Plan = PlanIds.Single(), - Quantity = _additionalSeats, - Deleted = (item?.Id != null && _additionalSeats == 0) ? true : (bool?)null, - } - }; + new SubscriptionItemOptions + { + Id = item?.Id, + Plan = PlanIds.Single(), + Quantity = _additionalSeats, + Deleted = (item?.Id != null && _additionalSeats == 0) ? true : (bool?)null, + } + }; + } + + public override List RevertItemsOptions(Subscription subscription) + { + + var item = SubscriptionItem(subscription, PlanIds.Single()); + return new() + { + new SubscriptionItemOptions + { + Id = item?.Id, + Plan = PlanIds.Single(), + Quantity = _previousSeats, + Deleted = _previousSeats == 0 ? true : (bool?)null, + } + }; + } } - public override List RevertItemsOptions(Subscription subscription) + public class StorageSubscriptionUpdate : SubscriptionUpdate { + private long? _prevStorage; + private readonly string _plan; + private readonly long? _additionalStorage; + protected override List PlanIds => new() { _plan }; - var item = SubscriptionItem(subscription, PlanIds.Single()); - return new() + public StorageSubscriptionUpdate(string plan, long? additionalStorage) { - new SubscriptionItemOptions + _plan = plan; + _additionalStorage = additionalStorage; + } + + public override List UpgradeItemsOptions(Subscription subscription) + { + var item = SubscriptionItem(subscription, PlanIds.Single()); + _prevStorage = item?.Quantity ?? 0; + return new() { - Id = item?.Id, - Plan = PlanIds.Single(), - Quantity = _previousSeats, - Deleted = _previousSeats == 0 ? true : (bool?)null, + new SubscriptionItemOptions + { + Id = item?.Id, + Plan = _plan, + Quantity = _additionalStorage, + Deleted = (item?.Id != null && _additionalStorage == 0) ? true : (bool?)null, + } + }; + } + + public override List RevertItemsOptions(Subscription subscription) + { + if (!_prevStorage.HasValue) + { + throw new Exception("Unknown previous value, must first call UpgradeItemsOptions"); } - }; + + var item = SubscriptionItem(subscription, PlanIds.Single()); + return new() + { + new SubscriptionItemOptions + { + Id = item?.Id, + Plan = _plan, + Quantity = _prevStorage.Value, + Deleted = _prevStorage.Value == 0 ? true : (bool?)null, + } + }; + } + } + + public class SponsorOrganizationSubscriptionUpdate : SubscriptionUpdate + { + private readonly string _existingPlanStripeId; + private readonly string _sponsoredPlanStripeId; + private readonly bool _applySponsorship; + protected override List PlanIds => new() { _existingPlanStripeId, _sponsoredPlanStripeId }; + + public SponsorOrganizationSubscriptionUpdate(StaticStore.Plan existingPlan, StaticStore.SponsoredPlan sponsoredPlan, bool applySponsorship) + { + _existingPlanStripeId = existingPlan.StripePlanId; + _sponsoredPlanStripeId = sponsoredPlan?.StripePlanId; + _applySponsorship = applySponsorship; + } + + public override List RevertItemsOptions(Subscription subscription) + { + var result = new List(); + if (!string.IsNullOrWhiteSpace(AddStripePlanId)) + { + result.Add(new SubscriptionItemOptions + { + Id = AddStripeItem(subscription)?.Id, + Plan = AddStripePlanId, + Quantity = 0, + Deleted = true, + }); + } + + if (!string.IsNullOrWhiteSpace(RemoveStripePlanId)) + { + result.Add(new SubscriptionItemOptions + { + Id = RemoveStripeItem(subscription)?.Id, + Plan = RemoveStripePlanId, + Quantity = 1, + Deleted = false, + }); + } + return result; + } + + public override List UpgradeItemsOptions(Subscription subscription) + { + var result = new List(); + if (RemoveStripeItem(subscription) != null) + { + result.Add(new SubscriptionItemOptions + { + Id = RemoveStripeItem(subscription)?.Id, + Plan = RemoveStripePlanId, + Quantity = 0, + Deleted = true, + }); + } + + if (!string.IsNullOrWhiteSpace(AddStripePlanId)) + { + result.Add(new SubscriptionItemOptions + { + Id = AddStripeItem(subscription)?.Id, + Plan = AddStripePlanId, + Quantity = 1, + Deleted = false, + }); + } + return result; + } + + private string RemoveStripePlanId => _applySponsorship ? _existingPlanStripeId : _sponsoredPlanStripeId; + private string AddStripePlanId => _applySponsorship ? _sponsoredPlanStripeId : _existingPlanStripeId; + private Stripe.SubscriptionItem RemoveStripeItem(Subscription subscription) => + _applySponsorship ? + SubscriptionItem(subscription, _existingPlanStripeId) : + SubscriptionItem(subscription, _sponsoredPlanStripeId); + private Stripe.SubscriptionItem AddStripeItem(Subscription subscription) => + _applySponsorship ? + SubscriptionItem(subscription, _sponsoredPlanStripeId) : + SubscriptionItem(subscription, _existingPlanStripeId); + } } - -public class StorageSubscriptionUpdate : SubscriptionUpdate -{ - private long? _prevStorage; - private readonly string _plan; - private readonly long? _additionalStorage; - protected override List PlanIds => new() { _plan }; - - public StorageSubscriptionUpdate(string plan, long? additionalStorage) - { - _plan = plan; - _additionalStorage = additionalStorage; - } - - public override List UpgradeItemsOptions(Subscription subscription) - { - var item = SubscriptionItem(subscription, PlanIds.Single()); - _prevStorage = item?.Quantity ?? 0; - return new() - { - new SubscriptionItemOptions - { - Id = item?.Id, - Plan = _plan, - Quantity = _additionalStorage, - Deleted = (item?.Id != null && _additionalStorage == 0) ? true : (bool?)null, - } - }; - } - - public override List RevertItemsOptions(Subscription subscription) - { - if (!_prevStorage.HasValue) - { - throw new Exception("Unknown previous value, must first call UpgradeItemsOptions"); - } - - var item = SubscriptionItem(subscription, PlanIds.Single()); - return new() - { - new SubscriptionItemOptions - { - Id = item?.Id, - Plan = _plan, - Quantity = _prevStorage.Value, - Deleted = _prevStorage.Value == 0 ? true : (bool?)null, - } - }; - } -} - -public class SponsorOrganizationSubscriptionUpdate : SubscriptionUpdate -{ - private readonly string _existingPlanStripeId; - private readonly string _sponsoredPlanStripeId; - private readonly bool _applySponsorship; - protected override List PlanIds => new() { _existingPlanStripeId, _sponsoredPlanStripeId }; - - public SponsorOrganizationSubscriptionUpdate(StaticStore.Plan existingPlan, StaticStore.SponsoredPlan sponsoredPlan, bool applySponsorship) - { - _existingPlanStripeId = existingPlan.StripePlanId; - _sponsoredPlanStripeId = sponsoredPlan?.StripePlanId; - _applySponsorship = applySponsorship; - } - - public override List RevertItemsOptions(Subscription subscription) - { - var result = new List(); - if (!string.IsNullOrWhiteSpace(AddStripePlanId)) - { - result.Add(new SubscriptionItemOptions - { - Id = AddStripeItem(subscription)?.Id, - Plan = AddStripePlanId, - Quantity = 0, - Deleted = true, - }); - } - - if (!string.IsNullOrWhiteSpace(RemoveStripePlanId)) - { - result.Add(new SubscriptionItemOptions - { - Id = RemoveStripeItem(subscription)?.Id, - Plan = RemoveStripePlanId, - Quantity = 1, - Deleted = false, - }); - } - return result; - } - - public override List UpgradeItemsOptions(Subscription subscription) - { - var result = new List(); - if (RemoveStripeItem(subscription) != null) - { - result.Add(new SubscriptionItemOptions - { - Id = RemoveStripeItem(subscription)?.Id, - Plan = RemoveStripePlanId, - Quantity = 0, - Deleted = true, - }); - } - - if (!string.IsNullOrWhiteSpace(AddStripePlanId)) - { - result.Add(new SubscriptionItemOptions - { - Id = AddStripeItem(subscription)?.Id, - Plan = AddStripePlanId, - Quantity = 1, - Deleted = false, - }); - } - return result; - } - - private string RemoveStripePlanId => _applySponsorship ? _existingPlanStripeId : _sponsoredPlanStripeId; - private string AddStripePlanId => _applySponsorship ? _sponsoredPlanStripeId : _existingPlanStripeId; - private Stripe.SubscriptionItem RemoveStripeItem(Subscription subscription) => - _applySponsorship ? - SubscriptionItem(subscription, _existingPlanStripeId) : - SubscriptionItem(subscription, _sponsoredPlanStripeId); - private Stripe.SubscriptionItem AddStripeItem(Subscription subscription) => - _applySponsorship ? - SubscriptionItem(subscription, _sponsoredPlanStripeId) : - SubscriptionItem(subscription, _existingPlanStripeId); - -} diff --git a/src/Core/Models/Business/TaxInfo.cs b/src/Core/Models/Business/TaxInfo.cs index e763b72235..62d30b8fe7 100644 --- a/src/Core/Models/Business/TaxInfo.cs +++ b/src/Core/Models/Business/TaxInfo.cs @@ -1,153 +1,154 @@ -namespace Bit.Core.Models.Business; - -public class TaxInfo +namespace Bit.Core.Models.Business { - private string _taxIdNumber = null; - private string _taxIdType = null; + public class TaxInfo + { + private string _taxIdNumber = null; + private string _taxIdType = null; - public string TaxIdNumber - { - get => _taxIdNumber; - set + public string TaxIdNumber { - _taxIdNumber = value; - _taxIdType = null; - } - } - public string StripeTaxRateId { get; set; } - public string BillingAddressLine1 { get; set; } - public string BillingAddressLine2 { get; set; } - public string BillingAddressCity { get; set; } - public string BillingAddressState { get; set; } - public string BillingAddressPostalCode { get; set; } - public string BillingAddressCountry { get; set; } = "US"; - public string TaxIdType - { - get - { - if (string.IsNullOrWhiteSpace(BillingAddressCountry) || - string.IsNullOrWhiteSpace(TaxIdNumber)) + get => _taxIdNumber; + set { - return null; + _taxIdNumber = value; + _taxIdType = null; } - if (!string.IsNullOrWhiteSpace(_taxIdType)) + } + public string StripeTaxRateId { get; set; } + public string BillingAddressLine1 { get; set; } + public string BillingAddressLine2 { get; set; } + public string BillingAddressCity { get; set; } + public string BillingAddressState { get; set; } + public string BillingAddressPostalCode { get; set; } + public string BillingAddressCountry { get; set; } = "US"; + public string TaxIdType + { + get { + if (string.IsNullOrWhiteSpace(BillingAddressCountry) || + string.IsNullOrWhiteSpace(TaxIdNumber)) + { + return null; + } + if (!string.IsNullOrWhiteSpace(_taxIdType)) + { + return _taxIdType; + } + + switch (BillingAddressCountry) + { + case "AE": + _taxIdType = "ae_trn"; + break; + case "AU": + _taxIdType = "au_abn"; + break; + case "BR": + _taxIdType = "br_cnpj"; + break; + case "CA": + // May break for those in Québec given the assumption of QST + if (BillingAddressState?.Contains("bec") ?? false) + { + _taxIdType = "ca_qst"; + break; + } + _taxIdType = "ca_bn"; + break; + case "CL": + _taxIdType = "cl_tin"; + break; + case "AT": + case "BE": + case "BG": + case "CY": + case "CZ": + case "DE": + case "DK": + case "EE": + case "ES": + case "FI": + case "FR": + case "GB": + case "GR": + case "HR": + case "HU": + case "IE": + case "IT": + case "LT": + case "LU": + case "LV": + case "MT": + case "NL": + case "PL": + case "PT": + case "RO": + case "SE": + case "SI": + case "SK": + _taxIdType = "eu_vat"; + break; + case "HK": + _taxIdType = "hk_br"; + break; + case "IN": + _taxIdType = "in_gst"; + break; + case "JP": + _taxIdType = "jp_cn"; + break; + case "KR": + _taxIdType = "kr_brn"; + break; + case "LI": + _taxIdType = "li_uid"; + break; + case "MX": + _taxIdType = "mx_rfc"; + break; + case "MY": + _taxIdType = "my_sst"; + break; + case "NO": + _taxIdType = "no_vat"; + break; + case "NZ": + _taxIdType = "nz_gst"; + break; + case "RU": + _taxIdType = "ru_inn"; + break; + case "SA": + _taxIdType = "sa_vat"; + break; + case "SG": + _taxIdType = "sg_gst"; + break; + case "TH": + _taxIdType = "th_vat"; + break; + case "TW": + _taxIdType = "tw_vat"; + break; + case "US": + _taxIdType = "us_ein"; + break; + case "ZA": + _taxIdType = "za_vat"; + break; + default: + _taxIdType = null; + break; + } + return _taxIdType; } + } - switch (BillingAddressCountry) - { - case "AE": - _taxIdType = "ae_trn"; - break; - case "AU": - _taxIdType = "au_abn"; - break; - case "BR": - _taxIdType = "br_cnpj"; - break; - case "CA": - // May break for those in Québec given the assumption of QST - if (BillingAddressState?.Contains("bec") ?? false) - { - _taxIdType = "ca_qst"; - break; - } - _taxIdType = "ca_bn"; - break; - case "CL": - _taxIdType = "cl_tin"; - break; - case "AT": - case "BE": - case "BG": - case "CY": - case "CZ": - case "DE": - case "DK": - case "EE": - case "ES": - case "FI": - case "FR": - case "GB": - case "GR": - case "HR": - case "HU": - case "IE": - case "IT": - case "LT": - case "LU": - case "LV": - case "MT": - case "NL": - case "PL": - case "PT": - case "RO": - case "SE": - case "SI": - case "SK": - _taxIdType = "eu_vat"; - break; - case "HK": - _taxIdType = "hk_br"; - break; - case "IN": - _taxIdType = "in_gst"; - break; - case "JP": - _taxIdType = "jp_cn"; - break; - case "KR": - _taxIdType = "kr_brn"; - break; - case "LI": - _taxIdType = "li_uid"; - break; - case "MX": - _taxIdType = "mx_rfc"; - break; - case "MY": - _taxIdType = "my_sst"; - break; - case "NO": - _taxIdType = "no_vat"; - break; - case "NZ": - _taxIdType = "nz_gst"; - break; - case "RU": - _taxIdType = "ru_inn"; - break; - case "SA": - _taxIdType = "sa_vat"; - break; - case "SG": - _taxIdType = "sg_gst"; - break; - case "TH": - _taxIdType = "th_vat"; - break; - case "TW": - _taxIdType = "tw_vat"; - break; - case "US": - _taxIdType = "us_ein"; - break; - case "ZA": - _taxIdType = "za_vat"; - break; - default: - _taxIdType = null; - break; - } - - return _taxIdType; + public bool HasTaxId + { + get => !string.IsNullOrWhiteSpace(TaxIdNumber) && + !string.IsNullOrWhiteSpace(TaxIdType); } } - - public bool HasTaxId - { - get => !string.IsNullOrWhiteSpace(TaxIdNumber) && - !string.IsNullOrWhiteSpace(TaxIdType); - } } diff --git a/src/Core/Models/Business/Tokenables/EmergencyAccessInviteTokenable.cs b/src/Core/Models/Business/Tokenables/EmergencyAccessInviteTokenable.cs index 9d0e6cafa6..f8d7b02b78 100644 --- a/src/Core/Models/Business/Tokenables/EmergencyAccessInviteTokenable.cs +++ b/src/Core/Models/Business/Tokenables/EmergencyAccessInviteTokenable.cs @@ -1,35 +1,36 @@ using System.Text.Json.Serialization; using Bit.Core.Entities; -namespace Bit.Core.Models.Business.Tokenables; - -public class EmergencyAccessInviteTokenable : Tokens.ExpiringTokenable +namespace Bit.Core.Models.Business.Tokenables { - public const string ClearTextPrefix = ""; - public const string DataProtectorPurpose = "EmergencyAccessServiceDataProtector"; - public const string TokenIdentifier = "EmergencyAccessInvite"; - public string Identifier { get; set; } = TokenIdentifier; - public Guid Id { get; set; } - public string Email { get; set; } - - [JsonConstructor] - public EmergencyAccessInviteTokenable(DateTime expirationDate) + public class EmergencyAccessInviteTokenable : Tokens.ExpiringTokenable { - ExpirationDate = expirationDate; - } + public const string ClearTextPrefix = ""; + public const string DataProtectorPurpose = "EmergencyAccessServiceDataProtector"; + public const string TokenIdentifier = "EmergencyAccessInvite"; + public string Identifier { get; set; } = TokenIdentifier; + public Guid Id { get; set; } + public string Email { get; set; } - public EmergencyAccessInviteTokenable(EmergencyAccess user, int hoursTillExpiration) - { - Id = user.Id; - Email = user.Email; - ExpirationDate = DateTime.UtcNow.AddHours(hoursTillExpiration); - } + [JsonConstructor] + public EmergencyAccessInviteTokenable(DateTime expirationDate) + { + ExpirationDate = expirationDate; + } - public bool IsValid(Guid id, string email) - { - return Id == id && - Email.Equals(email, StringComparison.InvariantCultureIgnoreCase); - } + public EmergencyAccessInviteTokenable(EmergencyAccess user, int hoursTillExpiration) + { + Id = user.Id; + Email = user.Email; + ExpirationDate = DateTime.UtcNow.AddHours(hoursTillExpiration); + } - protected override bool TokenIsValid() => Identifier == TokenIdentifier && Id != default && !string.IsNullOrWhiteSpace(Email); + public bool IsValid(Guid id, string email) + { + return Id == id && + Email.Equals(email, StringComparison.InvariantCultureIgnoreCase); + } + + protected override bool TokenIsValid() => Identifier == TokenIdentifier && Id != default && !string.IsNullOrWhiteSpace(Email); + } } diff --git a/src/Core/Models/Business/Tokenables/HCaptchaTokenable.cs b/src/Core/Models/Business/Tokenables/HCaptchaTokenable.cs index c62c7189a6..774df7d799 100644 --- a/src/Core/Models/Business/Tokenables/HCaptchaTokenable.cs +++ b/src/Core/Models/Business/Tokenables/HCaptchaTokenable.cs @@ -2,42 +2,43 @@ using Bit.Core.Entities; using Bit.Core.Tokens; -namespace Bit.Core.Models.Business.Tokenables; - -public class HCaptchaTokenable : ExpiringTokenable +namespace Bit.Core.Models.Business.Tokenables { - private const double _tokenLifetimeInHours = (double)5 / 60; // 5 minutes - public const string ClearTextPrefix = "BWCaptchaBypass_"; - public const string DataProtectorPurpose = "CaptchaServiceDataProtector"; - public const string TokenIdentifier = "CaptchaBypassToken"; - - public string Identifier { get; set; } = TokenIdentifier; - public Guid Id { get; set; } - public string Email { get; set; } - - [JsonConstructor] - public HCaptchaTokenable() + public class HCaptchaTokenable : ExpiringTokenable { - ExpirationDate = DateTime.UtcNow.AddHours(_tokenLifetimeInHours); - } + private const double _tokenLifetimeInHours = (double)5 / 60; // 5 minutes + public const string ClearTextPrefix = "BWCaptchaBypass_"; + public const string DataProtectorPurpose = "CaptchaServiceDataProtector"; + public const string TokenIdentifier = "CaptchaBypassToken"; - public HCaptchaTokenable(User user) : this() - { - Id = user?.Id ?? default; - Email = user?.Email; - } + public string Identifier { get; set; } = TokenIdentifier; + public Guid Id { get; set; } + public string Email { get; set; } - public bool TokenIsValid(User user) - { - if (Id == default || Email == default || user == null) + [JsonConstructor] + public HCaptchaTokenable() { - return false; + ExpirationDate = DateTime.UtcNow.AddHours(_tokenLifetimeInHours); } - return Id == user.Id && - Email.Equals(user.Email, StringComparison.InvariantCultureIgnoreCase); - } + public HCaptchaTokenable(User user) : this() + { + Id = user?.Id ?? default; + Email = user?.Email; + } - // Validates deserialized - protected override bool TokenIsValid() => Identifier == TokenIdentifier && Id != default && !string.IsNullOrWhiteSpace(Email); + public bool TokenIsValid(User user) + { + if (Id == default || Email == default || user == null) + { + return false; + } + + return Id == user.Id && + Email.Equals(user.Email, StringComparison.InvariantCultureIgnoreCase); + } + + // Validates deserialized + protected override bool TokenIsValid() => Identifier == TokenIdentifier && Id != default && !string.IsNullOrWhiteSpace(Email); + } } diff --git a/src/Core/Models/Business/Tokenables/OrganizationSponsorshipOfferTokenable.cs b/src/Core/Models/Business/Tokenables/OrganizationSponsorshipOfferTokenable.cs index 4bca8e1ca1..0360f35421 100644 --- a/src/Core/Models/Business/Tokenables/OrganizationSponsorshipOfferTokenable.cs +++ b/src/Core/Models/Business/Tokenables/OrganizationSponsorshipOfferTokenable.cs @@ -3,54 +3,55 @@ using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Tokens; -namespace Bit.Core.Models.Business.Tokenables; - -public class OrganizationSponsorshipOfferTokenable : Tokenable +namespace Bit.Core.Models.Business.Tokenables { - public const string ClearTextPrefix = "BWOrganizationSponsorship_"; - public const string DataProtectorPurpose = "OrganizationSponsorshipDataProtector"; - public const string TokenIdentifier = "OrganizationSponsorshipOfferToken"; - public string Identifier { get; set; } = TokenIdentifier; - public Guid Id { get; set; } - public PlanSponsorshipType SponsorshipType { get; set; } - public string Email { get; set; } - - public override bool Valid => !string.IsNullOrWhiteSpace(Email) && - Identifier == TokenIdentifier && - Id != default; - - - [JsonConstructor] - public OrganizationSponsorshipOfferTokenable() { } - - public OrganizationSponsorshipOfferTokenable(OrganizationSponsorship sponsorship) + public class OrganizationSponsorshipOfferTokenable : Tokenable { - if (string.IsNullOrWhiteSpace(sponsorship.OfferedToEmail)) - { - throw new ArgumentException("Invalid OrganizationSponsorship to create a token, OfferedToEmail is required", nameof(sponsorship)); - } - Email = sponsorship.OfferedToEmail; + public const string ClearTextPrefix = "BWOrganizationSponsorship_"; + public const string DataProtectorPurpose = "OrganizationSponsorshipDataProtector"; + public const string TokenIdentifier = "OrganizationSponsorshipOfferToken"; + public string Identifier { get; set; } = TokenIdentifier; + public Guid Id { get; set; } + public PlanSponsorshipType SponsorshipType { get; set; } + public string Email { get; set; } - if (!sponsorship.PlanSponsorshipType.HasValue) - { - throw new ArgumentException("Invalid OrganizationSponsorship to create a token, PlanSponsorshipType is required", nameof(sponsorship)); - } - SponsorshipType = sponsorship.PlanSponsorshipType.Value; + public override bool Valid => !string.IsNullOrWhiteSpace(Email) && + Identifier == TokenIdentifier && + Id != default; - if (sponsorship.Id == default) + + [JsonConstructor] + public OrganizationSponsorshipOfferTokenable() { } + + public OrganizationSponsorshipOfferTokenable(OrganizationSponsorship sponsorship) { - throw new ArgumentException("Invalid OrganizationSponsorship to create a token, Id is required", nameof(sponsorship)); + if (string.IsNullOrWhiteSpace(sponsorship.OfferedToEmail)) + { + throw new ArgumentException("Invalid OrganizationSponsorship to create a token, OfferedToEmail is required", nameof(sponsorship)); + } + Email = sponsorship.OfferedToEmail; + + if (!sponsorship.PlanSponsorshipType.HasValue) + { + throw new ArgumentException("Invalid OrganizationSponsorship to create a token, PlanSponsorshipType is required", nameof(sponsorship)); + } + SponsorshipType = sponsorship.PlanSponsorshipType.Value; + + if (sponsorship.Id == default) + { + throw new ArgumentException("Invalid OrganizationSponsorship to create a token, Id is required", nameof(sponsorship)); + } + Id = sponsorship.Id; } - Id = sponsorship.Id; + + public bool IsValid(OrganizationSponsorship sponsorship, string currentUserEmail) => + sponsorship != null && + sponsorship.PlanSponsorshipType.HasValue && + SponsorshipType == sponsorship.PlanSponsorshipType.Value && + Id == sponsorship.Id && + !string.IsNullOrWhiteSpace(sponsorship.OfferedToEmail) && + Email.Equals(currentUserEmail, StringComparison.InvariantCultureIgnoreCase) && + Email.Equals(sponsorship.OfferedToEmail, StringComparison.InvariantCultureIgnoreCase); + } - - public bool IsValid(OrganizationSponsorship sponsorship, string currentUserEmail) => - sponsorship != null && - sponsorship.PlanSponsorshipType.HasValue && - SponsorshipType == sponsorship.PlanSponsorshipType.Value && - Id == sponsorship.Id && - !string.IsNullOrWhiteSpace(sponsorship.OfferedToEmail) && - Email.Equals(currentUserEmail, StringComparison.InvariantCultureIgnoreCase) && - Email.Equals(sponsorship.OfferedToEmail, StringComparison.InvariantCultureIgnoreCase); - } diff --git a/src/Core/Models/Business/Tokenables/SsoTokenable.cs b/src/Core/Models/Business/Tokenables/SsoTokenable.cs index f6524d2c7f..765e6ce596 100644 --- a/src/Core/Models/Business/Tokenables/SsoTokenable.cs +++ b/src/Core/Models/Business/Tokenables/SsoTokenable.cs @@ -2,42 +2,43 @@ using Bit.Core.Entities; using Bit.Core.Tokens; -namespace Bit.Core.Models.Business.Tokenables; - -public class SsoTokenable : ExpiringTokenable +namespace Bit.Core.Models.Business.Tokenables { - public const string ClearTextPrefix = "BWUserPrefix_"; - public const string DataProtectorPurpose = "SsoTokenDataProtector"; - public const string TokenIdentifier = "ssoToken"; - - public Guid OrganizationId { get; set; } - public string DomainHint { get; set; } - public string Identifier { get; set; } = TokenIdentifier; - - [JsonConstructor] - public SsoTokenable() { } - - public SsoTokenable(Organization organization, double tokenLifetimeInSeconds) : this() + public class SsoTokenable : ExpiringTokenable { - OrganizationId = organization?.Id ?? default; - DomainHint = organization?.Identifier; - ExpirationDate = DateTime.UtcNow.AddSeconds(tokenLifetimeInSeconds); - } + public const string ClearTextPrefix = "BWUserPrefix_"; + public const string DataProtectorPurpose = "SsoTokenDataProtector"; + public const string TokenIdentifier = "ssoToken"; - public bool TokenIsValid(Organization organization) - { - if (OrganizationId == default || DomainHint == default || organization == null || !Valid) + public Guid OrganizationId { get; set; } + public string DomainHint { get; set; } + public string Identifier { get; set; } = TokenIdentifier; + + [JsonConstructor] + public SsoTokenable() { } + + public SsoTokenable(Organization organization, double tokenLifetimeInSeconds) : this() { - return false; + OrganizationId = organization?.Id ?? default; + DomainHint = organization?.Identifier; + ExpirationDate = DateTime.UtcNow.AddSeconds(tokenLifetimeInSeconds); } - return organization.Identifier.Equals(DomainHint, StringComparison.InvariantCultureIgnoreCase) - && organization.Id.Equals(OrganizationId); - } + public bool TokenIsValid(Organization organization) + { + if (OrganizationId == default || DomainHint == default || organization == null || !Valid) + { + return false; + } - // Validates deserialized - protected override bool TokenIsValid() => - Identifier == TokenIdentifier - && OrganizationId != default - && !string.IsNullOrWhiteSpace(DomainHint); + return organization.Identifier.Equals(DomainHint, StringComparison.InvariantCultureIgnoreCase) + && organization.Id.Equals(OrganizationId); + } + + // Validates deserialized + protected override bool TokenIsValid() => + Identifier == TokenIdentifier + && OrganizationId != default + && !string.IsNullOrWhiteSpace(DomainHint); + } } diff --git a/src/Core/Models/Business/UserLicense.cs b/src/Core/Models/Business/UserLicense.cs index f079a71839..183bf95764 100644 --- a/src/Core/Models/Business/UserLicense.cs +++ b/src/Core/Models/Business/UserLicense.cs @@ -7,167 +7,168 @@ using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Services; -namespace Bit.Core.Models.Business; - -public class UserLicense : ILicense +namespace Bit.Core.Models.Business { - public UserLicense() - { } - - public UserLicense(User user, SubscriptionInfo subscriptionInfo, ILicensingService licenseService, - int? version = null) + public class UserLicense : ILicense { - LicenseType = Enums.LicenseType.User; - LicenseKey = user.LicenseKey; - Id = user.Id; - Name = user.Name; - Email = user.Email; - Version = version.GetValueOrDefault(1); - Premium = user.Premium; - MaxStorageGb = user.MaxStorageGb; - Issued = DateTime.UtcNow; - Expires = subscriptionInfo?.UpcomingInvoice?.Date != null ? - subscriptionInfo.UpcomingInvoice.Date.Value.AddDays(7) : - user.PremiumExpirationDate?.AddDays(7); - Refresh = subscriptionInfo?.UpcomingInvoice?.Date; - Trial = (subscriptionInfo?.Subscription?.TrialEndDate.HasValue ?? false) && - subscriptionInfo.Subscription.TrialEndDate.Value > DateTime.UtcNow; + public UserLicense() + { } - Hash = Convert.ToBase64String(ComputeHash()); - Signature = Convert.ToBase64String(licenseService.SignLicense(this)); - } - - public UserLicense(User user, ILicensingService licenseService, int? version = null) - { - LicenseType = Enums.LicenseType.User; - LicenseKey = user.LicenseKey; - Id = user.Id; - Name = user.Name; - Email = user.Email; - Version = version.GetValueOrDefault(1); - Premium = user.Premium; - MaxStorageGb = user.MaxStorageGb; - Issued = DateTime.UtcNow; - Expires = user.PremiumExpirationDate?.AddDays(7); - Refresh = user.PremiumExpirationDate?.Date; - Trial = false; - - Hash = Convert.ToBase64String(ComputeHash()); - Signature = Convert.ToBase64String(licenseService.SignLicense(this)); - } - - public string LicenseKey { get; set; } - public Guid Id { get; set; } - public string Name { get; set; } - public string Email { get; set; } - public bool Premium { get; set; } - public short? MaxStorageGb { get; set; } - public int Version { get; set; } - public DateTime Issued { get; set; } - public DateTime? Refresh { get; set; } - public DateTime? Expires { get; set; } - public bool Trial { get; set; } - public LicenseType? LicenseType { get; set; } - public string Hash { get; set; } - public string Signature { get; set; } - [JsonIgnore] - public byte[] SignatureBytes => Convert.FromBase64String(Signature); - - public byte[] GetDataBytes(bool forHash = false) - { - string data = null; - if (Version == 1) + public UserLicense(User user, SubscriptionInfo subscriptionInfo, ILicensingService licenseService, + int? version = null) { - var props = typeof(UserLicense) - .GetProperties(BindingFlags.Public | BindingFlags.Instance) - .Where(p => - !p.Name.Equals(nameof(Signature)) && - !p.Name.Equals(nameof(SignatureBytes)) && - !p.Name.Equals(nameof(LicenseType)) && - ( - !forHash || + LicenseType = Enums.LicenseType.User; + LicenseKey = user.LicenseKey; + Id = user.Id; + Name = user.Name; + Email = user.Email; + Version = version.GetValueOrDefault(1); + Premium = user.Premium; + MaxStorageGb = user.MaxStorageGb; + Issued = DateTime.UtcNow; + Expires = subscriptionInfo?.UpcomingInvoice?.Date != null ? + subscriptionInfo.UpcomingInvoice.Date.Value.AddDays(7) : + user.PremiumExpirationDate?.AddDays(7); + Refresh = subscriptionInfo?.UpcomingInvoice?.Date; + Trial = (subscriptionInfo?.Subscription?.TrialEndDate.HasValue ?? false) && + subscriptionInfo.Subscription.TrialEndDate.Value > DateTime.UtcNow; + + Hash = Convert.ToBase64String(ComputeHash()); + Signature = Convert.ToBase64String(licenseService.SignLicense(this)); + } + + public UserLicense(User user, ILicensingService licenseService, int? version = null) + { + LicenseType = Enums.LicenseType.User; + LicenseKey = user.LicenseKey; + Id = user.Id; + Name = user.Name; + Email = user.Email; + Version = version.GetValueOrDefault(1); + Premium = user.Premium; + MaxStorageGb = user.MaxStorageGb; + Issued = DateTime.UtcNow; + Expires = user.PremiumExpirationDate?.AddDays(7); + Refresh = user.PremiumExpirationDate?.Date; + Trial = false; + + Hash = Convert.ToBase64String(ComputeHash()); + Signature = Convert.ToBase64String(licenseService.SignLicense(this)); + } + + public string LicenseKey { get; set; } + public Guid Id { get; set; } + public string Name { get; set; } + public string Email { get; set; } + public bool Premium { get; set; } + public short? MaxStorageGb { get; set; } + public int Version { get; set; } + public DateTime Issued { get; set; } + public DateTime? Refresh { get; set; } + public DateTime? Expires { get; set; } + public bool Trial { get; set; } + public LicenseType? LicenseType { get; set; } + public string Hash { get; set; } + public string Signature { get; set; } + [JsonIgnore] + public byte[] SignatureBytes => Convert.FromBase64String(Signature); + + public byte[] GetDataBytes(bool forHash = false) + { + string data = null; + if (Version == 1) + { + var props = typeof(UserLicense) + .GetProperties(BindingFlags.Public | BindingFlags.Instance) + .Where(p => + !p.Name.Equals(nameof(Signature)) && + !p.Name.Equals(nameof(SignatureBytes)) && + !p.Name.Equals(nameof(LicenseType)) && ( - !p.Name.Equals(nameof(Hash)) && - !p.Name.Equals(nameof(Issued)) && - !p.Name.Equals(nameof(Refresh)) - ) - )) - .OrderBy(p => p.Name) - .Select(p => $"{p.Name}:{Utilities.CoreHelpers.FormatLicenseSignatureValue(p.GetValue(this, null))}") - .Aggregate((c, n) => $"{c}|{n}"); - data = $"license:user|{props}"; - } - else - { - throw new NotSupportedException($"Version {Version} is not supported."); + !forHash || + ( + !p.Name.Equals(nameof(Hash)) && + !p.Name.Equals(nameof(Issued)) && + !p.Name.Equals(nameof(Refresh)) + ) + )) + .OrderBy(p => p.Name) + .Select(p => $"{p.Name}:{Utilities.CoreHelpers.FormatLicenseSignatureValue(p.GetValue(this, null))}") + .Aggregate((c, n) => $"{c}|{n}"); + data = $"license:user|{props}"; + } + else + { + throw new NotSupportedException($"Version {Version} is not supported."); + } + + return Encoding.UTF8.GetBytes(data); } - return Encoding.UTF8.GetBytes(data); - } - - public byte[] ComputeHash() - { - using (var alg = SHA256.Create()) + public byte[] ComputeHash() { - return alg.ComputeHash(GetDataBytes(true)); - } - } - - public bool CanUse(User user) - { - if (Issued > DateTime.UtcNow || Expires < DateTime.UtcNow) - { - return false; + using (var alg = SHA256.Create()) + { + return alg.ComputeHash(GetDataBytes(true)); + } } - if (Version == 1) + public bool CanUse(User user) { - return user.EmailVerified && user.Email.Equals(Email, StringComparison.InvariantCultureIgnoreCase); - } - else - { - throw new NotSupportedException($"Version {Version} is not supported."); - } - } + if (Issued > DateTime.UtcNow || Expires < DateTime.UtcNow) + { + return false; + } - public bool VerifyData(User user) - { - if (Issued > DateTime.UtcNow || Expires < DateTime.UtcNow) - { - return false; + if (Version == 1) + { + return user.EmailVerified && user.Email.Equals(Email, StringComparison.InvariantCultureIgnoreCase); + } + else + { + throw new NotSupportedException($"Version {Version} is not supported."); + } } - if (Version == 1) + public bool VerifyData(User user) { - return - user.LicenseKey != null && user.LicenseKey.Equals(LicenseKey) && - user.Premium == Premium && - user.Email.Equals(Email, StringComparison.InvariantCultureIgnoreCase); - } - else - { - throw new NotSupportedException($"Version {Version} is not supported."); - } - } + if (Issued > DateTime.UtcNow || Expires < DateTime.UtcNow) + { + return false; + } - public bool VerifySignature(X509Certificate2 certificate) - { - using (var rsa = certificate.GetRSAPublicKey()) - { - return rsa.VerifyData(GetDataBytes(), SignatureBytes, HashAlgorithmName.SHA256, RSASignaturePadding.Pkcs1); - } - } - - public byte[] Sign(X509Certificate2 certificate) - { - if (!certificate.HasPrivateKey) - { - throw new InvalidOperationException("You don't have the private key!"); + if (Version == 1) + { + return + user.LicenseKey != null && user.LicenseKey.Equals(LicenseKey) && + user.Premium == Premium && + user.Email.Equals(Email, StringComparison.InvariantCultureIgnoreCase); + } + else + { + throw new NotSupportedException($"Version {Version} is not supported."); + } } - using (var rsa = certificate.GetRSAPrivateKey()) + public bool VerifySignature(X509Certificate2 certificate) { - return rsa.SignData(GetDataBytes(), HashAlgorithmName.SHA256, RSASignaturePadding.Pkcs1); + using (var rsa = certificate.GetRSAPublicKey()) + { + return rsa.VerifyData(GetDataBytes(), SignatureBytes, HashAlgorithmName.SHA256, RSASignaturePadding.Pkcs1); + } + } + + public byte[] Sign(X509Certificate2 certificate) + { + if (!certificate.HasPrivateKey) + { + throw new InvalidOperationException("You don't have the private key!"); + } + + using (var rsa = certificate.GetRSAPrivateKey()) + { + return rsa.SignData(GetDataBytes(), HashAlgorithmName.SHA256, RSASignaturePadding.Pkcs1); + } } } } diff --git a/src/Core/Models/Data/AttachmentResponseData.cs b/src/Core/Models/Data/AttachmentResponseData.cs index f45125c3d6..1a5c0de433 100644 --- a/src/Core/Models/Data/AttachmentResponseData.cs +++ b/src/Core/Models/Data/AttachmentResponseData.cs @@ -1,11 +1,12 @@ using Bit.Core.Entities; -namespace Bit.Core.Models.Data; - -public class AttachmentResponseData +namespace Bit.Core.Models.Data { - public string Id { get; set; } - public CipherAttachment.MetaData Data { get; set; } - public Cipher Cipher { get; set; } - public string Url { get; set; } + public class AttachmentResponseData + { + public string Id { get; set; } + public CipherAttachment.MetaData Data { get; set; } + public Cipher Cipher { get; set; } + public string Url { get; set; } + } } diff --git a/src/Core/Models/Data/CipherAttachment.cs b/src/Core/Models/Data/CipherAttachment.cs index 62b46335ae..a306c76ad4 100644 --- a/src/Core/Models/Data/CipherAttachment.cs +++ b/src/Core/Models/Data/CipherAttachment.cs @@ -1,35 +1,36 @@ using System.Text.Json.Serialization; -namespace Bit.Core.Models.Data; - -public class CipherAttachment +namespace Bit.Core.Models.Data { - public Guid Id { get; set; } - public Guid? UserId { get; set; } - public Guid? OrganizationId { get; set; } - public string AttachmentId { get; set; } - public string AttachmentData { get; set; } - - public class MetaData + public class CipherAttachment { - private long _size; - - // We serialize Size as a string since JSON (or Javascript) doesn't support full precision for long numbers - [JsonNumberHandling(JsonNumberHandling.AllowReadingFromString | JsonNumberHandling.WriteAsString)] - public long Size - { - get { return _size; } - set { _size = value; } - } - - public string FileName { get; set; } - public string Key { get; set; } - - public string ContainerName { get; set; } = "attachments"; - public bool Validated { get; set; } = true; - - // This is stored alongside metadata as an identifier. It does not need repeating in serialization - [JsonIgnore] + public Guid Id { get; set; } + public Guid? UserId { get; set; } + public Guid? OrganizationId { get; set; } public string AttachmentId { get; set; } + public string AttachmentData { get; set; } + + public class MetaData + { + private long _size; + + // We serialize Size as a string since JSON (or Javascript) doesn't support full precision for long numbers + [JsonNumberHandling(JsonNumberHandling.AllowReadingFromString | JsonNumberHandling.WriteAsString)] + public long Size + { + get { return _size; } + set { _size = value; } + } + + public string FileName { get; set; } + public string Key { get; set; } + + public string ContainerName { get; set; } = "attachments"; + public bool Validated { get; set; } = true; + + // This is stored alongside metadata as an identifier. It does not need repeating in serialization + [JsonIgnore] + public string AttachmentId { get; set; } + } } } diff --git a/src/Core/Models/Data/CipherCardData.cs b/src/Core/Models/Data/CipherCardData.cs index fdfc604dae..0d8745eb9e 100644 --- a/src/Core/Models/Data/CipherCardData.cs +++ b/src/Core/Models/Data/CipherCardData.cs @@ -1,13 +1,14 @@ -namespace Bit.Core.Models.Data; - -public class CipherCardData : CipherData +namespace Bit.Core.Models.Data { - public CipherCardData() { } + public class CipherCardData : CipherData + { + public CipherCardData() { } - public string CardholderName { get; set; } - public string Brand { get; set; } - public string Number { get; set; } - public string ExpMonth { get; set; } - public string ExpYear { get; set; } - public string Code { get; set; } + public string CardholderName { get; set; } + public string Brand { get; set; } + public string Number { get; set; } + public string ExpMonth { get; set; } + public string ExpYear { get; set; } + public string Code { get; set; } + } } diff --git a/src/Core/Models/Data/CipherData.cs b/src/Core/Models/Data/CipherData.cs index 9881ed6ba1..3c7598f260 100644 --- a/src/Core/Models/Data/CipherData.cs +++ b/src/Core/Models/Data/CipherData.cs @@ -1,11 +1,12 @@ -namespace Bit.Core.Models.Data; - -public abstract class CipherData +namespace Bit.Core.Models.Data { - public CipherData() { } + public abstract class CipherData + { + public CipherData() { } - public string Name { get; set; } - public string Notes { get; set; } - public IEnumerable Fields { get; set; } - public IEnumerable PasswordHistory { get; set; } + public string Name { get; set; } + public string Notes { get; set; } + public IEnumerable Fields { get; set; } + public IEnumerable PasswordHistory { get; set; } + } } diff --git a/src/Core/Models/Data/CipherDetails.cs b/src/Core/Models/Data/CipherDetails.cs index 21a636bf76..e7276ac3c4 100644 --- a/src/Core/Models/Data/CipherDetails.cs +++ b/src/Core/Models/Data/CipherDetails.cs @@ -1,9 +1,10 @@ -namespace Core.Models.Data; - -public class CipherDetails : CipherOrganizationDetails +namespace Core.Models.Data { - public Guid? FolderId { get; set; } - public bool Favorite { get; set; } - public bool Edit { get; set; } - public bool ViewPassword { get; set; } + public class CipherDetails : CipherOrganizationDetails + { + public Guid? FolderId { get; set; } + public bool Favorite { get; set; } + public bool Edit { get; set; } + public bool ViewPassword { get; set; } + } } diff --git a/src/Core/Models/Data/CipherFieldData.cs b/src/Core/Models/Data/CipherFieldData.cs index 748a478cf9..b46d16099f 100644 --- a/src/Core/Models/Data/CipherFieldData.cs +++ b/src/Core/Models/Data/CipherFieldData.cs @@ -1,13 +1,14 @@ using Bit.Core.Enums; -namespace Bit.Core.Models.Data; - -public class CipherFieldData +namespace Bit.Core.Models.Data { - public CipherFieldData() { } + public class CipherFieldData + { + public CipherFieldData() { } - public FieldType Type { get; set; } - public string Name { get; set; } - public string Value { get; set; } - public int? LinkedId { get; set; } + public FieldType Type { get; set; } + public string Name { get; set; } + public string Value { get; set; } + public int? LinkedId { get; set; } + } } diff --git a/src/Core/Models/Data/CipherIdentityData.cs b/src/Core/Models/Data/CipherIdentityData.cs index 19773424a3..3a5aa70e83 100644 --- a/src/Core/Models/Data/CipherIdentityData.cs +++ b/src/Core/Models/Data/CipherIdentityData.cs @@ -1,25 +1,26 @@ -namespace Bit.Core.Models.Data; - -public class CipherIdentityData : CipherData +namespace Bit.Core.Models.Data { - public CipherIdentityData() { } + public class CipherIdentityData : CipherData + { + public CipherIdentityData() { } - public string Title { get; set; } - public string FirstName { get; set; } - public string MiddleName { get; set; } - public string LastName { get; set; } - public string Address1 { get; set; } - public string Address2 { get; set; } - public string Address3 { get; set; } - public string City { get; set; } - public string State { get; set; } - public string PostalCode { get; set; } - public string Country { get; set; } - public string Company { get; set; } - public string Email { get; set; } - public string Phone { get; set; } - public string SSN { get; set; } - public string Username { get; set; } - public string PassportNumber { get; set; } - public string LicenseNumber { get; set; } + public string Title { get; set; } + public string FirstName { get; set; } + public string MiddleName { get; set; } + public string LastName { get; set; } + public string Address1 { get; set; } + public string Address2 { get; set; } + public string Address3 { get; set; } + public string City { get; set; } + public string State { get; set; } + public string PostalCode { get; set; } + public string Country { get; set; } + public string Company { get; set; } + public string Email { get; set; } + public string Phone { get; set; } + public string SSN { get; set; } + public string Username { get; set; } + public string PassportNumber { get; set; } + public string LicenseNumber { get; set; } + } } diff --git a/src/Core/Models/Data/CipherLoginData.cs b/src/Core/Models/Data/CipherLoginData.cs index d266d7786b..2a98ff1557 100644 --- a/src/Core/Models/Data/CipherLoginData.cs +++ b/src/Core/Models/Data/CipherLoginData.cs @@ -1,30 +1,31 @@ using Bit.Core.Enums; -namespace Bit.Core.Models.Data; - -public class CipherLoginData : CipherData +namespace Bit.Core.Models.Data { - private string _uri; - - public CipherLoginData() { } - - public string Uri + public class CipherLoginData : CipherData { - get => Uris?.FirstOrDefault()?.Uri ?? _uri; - set { _uri = value; } - } - public IEnumerable Uris { get; set; } - public string Username { get; set; } - public string Password { get; set; } - public DateTime? PasswordRevisionDate { get; set; } - public string Totp { get; set; } - public bool? AutofillOnPageLoad { get; set; } + private string _uri; - public class CipherLoginUriData - { - public CipherLoginUriData() { } + public CipherLoginData() { } - public string Uri { get; set; } - public UriMatchType? Match { get; set; } = null; + public string Uri + { + get => Uris?.FirstOrDefault()?.Uri ?? _uri; + set { _uri = value; } + } + public IEnumerable Uris { get; set; } + public string Username { get; set; } + public string Password { get; set; } + public DateTime? PasswordRevisionDate { get; set; } + public string Totp { get; set; } + public bool? AutofillOnPageLoad { get; set; } + + public class CipherLoginUriData + { + public CipherLoginUriData() { } + + public string Uri { get; set; } + public UriMatchType? Match { get; set; } = null; + } } } diff --git a/src/Core/Models/Data/CipherOrganizationDetails.cs b/src/Core/Models/Data/CipherOrganizationDetails.cs index d2717b30fa..522ebdd2fe 100644 --- a/src/Core/Models/Data/CipherOrganizationDetails.cs +++ b/src/Core/Models/Data/CipherOrganizationDetails.cs @@ -1,8 +1,9 @@ using Bit.Core.Entities; -namespace Core.Models.Data; - -public class CipherOrganizationDetails : Cipher +namespace Core.Models.Data { - public bool OrganizationUseTotp { get; set; } + public class CipherOrganizationDetails : Cipher + { + public bool OrganizationUseTotp { get; set; } + } } diff --git a/src/Core/Models/Data/CipherPasswordHistoryData.cs b/src/Core/Models/Data/CipherPasswordHistoryData.cs index 3ea5edab40..2362572a1e 100644 --- a/src/Core/Models/Data/CipherPasswordHistoryData.cs +++ b/src/Core/Models/Data/CipherPasswordHistoryData.cs @@ -1,9 +1,10 @@ -namespace Bit.Core.Models.Data; - -public class CipherPasswordHistoryData +namespace Bit.Core.Models.Data { - public CipherPasswordHistoryData() { } + public class CipherPasswordHistoryData + { + public CipherPasswordHistoryData() { } - public string Password { get; set; } - public DateTime LastUsedDate { get; set; } + public string Password { get; set; } + public DateTime LastUsedDate { get; set; } + } } diff --git a/src/Core/Models/Data/CipherSecureNoteData.cs b/src/Core/Models/Data/CipherSecureNoteData.cs index 88b7384cd7..1287e71dfc 100644 --- a/src/Core/Models/Data/CipherSecureNoteData.cs +++ b/src/Core/Models/Data/CipherSecureNoteData.cs @@ -1,10 +1,11 @@ using Bit.Core.Enums; -namespace Bit.Core.Models.Data; - -public class CipherSecureNoteData : CipherData +namespace Bit.Core.Models.Data { - public CipherSecureNoteData() { } + public class CipherSecureNoteData : CipherData + { + public CipherSecureNoteData() { } - public SecureNoteType Type { get; set; } + public SecureNoteType Type { get; set; } + } } diff --git a/src/Core/Models/Data/CollectionDetails.cs b/src/Core/Models/Data/CollectionDetails.cs index 4b618749e8..110acc3e59 100644 --- a/src/Core/Models/Data/CollectionDetails.cs +++ b/src/Core/Models/Data/CollectionDetails.cs @@ -1,9 +1,10 @@ using Bit.Core.Entities; -namespace Bit.Core.Models.Data; - -public class CollectionDetails : Collection +namespace Bit.Core.Models.Data { - public bool ReadOnly { get; set; } - public bool HidePasswords { get; set; } + public class CollectionDetails : Collection + { + public bool ReadOnly { get; set; } + public bool HidePasswords { get; set; } + } } diff --git a/src/Core/Models/Data/DictionaryEntity.cs b/src/Core/Models/Data/DictionaryEntity.cs index 72e6c871c7..00b85d6a2d 100644 --- a/src/Core/Models/Data/DictionaryEntity.cs +++ b/src/Core/Models/Data/DictionaryEntity.cs @@ -1,134 +1,135 @@ using System.Collections; using Microsoft.Azure.Cosmos.Table; -namespace Bit.Core.Models.Data; - -public class DictionaryEntity : TableEntity, IDictionary +namespace Bit.Core.Models.Data { - private IDictionary _properties = new Dictionary(); - - public ICollection Values => _properties.Values; - - public EntityProperty this[string key] + public class DictionaryEntity : TableEntity, IDictionary { - get => _properties[key]; - set => _properties[key] = value; - } + private IDictionary _properties = new Dictionary(); - public int Count => _properties.Count; + public ICollection Values => _properties.Values; - public bool IsReadOnly => _properties.IsReadOnly; + public EntityProperty this[string key] + { + get => _properties[key]; + set => _properties[key] = value; + } - public ICollection Keys => _properties.Keys; + public int Count => _properties.Count; - public override void ReadEntity(IDictionary properties, - OperationContext operationContext) - { - _properties = properties; - } + public bool IsReadOnly => _properties.IsReadOnly; - public override IDictionary WriteEntity(OperationContext operationContext) - { - return _properties; - } + public ICollection Keys => _properties.Keys; - public void Add(string key, EntityProperty value) - { - _properties.Add(key, value); - } + public override void ReadEntity(IDictionary properties, + OperationContext operationContext) + { + _properties = properties; + } - public void Add(string key, bool value) - { - _properties.Add(key, new EntityProperty(value)); - } + public override IDictionary WriteEntity(OperationContext operationContext) + { + return _properties; + } - public void Add(string key, byte[] value) - { - _properties.Add(key, new EntityProperty(value)); - } + public void Add(string key, EntityProperty value) + { + _properties.Add(key, value); + } - public void Add(string key, DateTime? value) - { - _properties.Add(key, new EntityProperty(value)); - } + public void Add(string key, bool value) + { + _properties.Add(key, new EntityProperty(value)); + } - public void Add(string key, DateTimeOffset? value) - { - _properties.Add(key, new EntityProperty(value)); - } + public void Add(string key, byte[] value) + { + _properties.Add(key, new EntityProperty(value)); + } - public void Add(string key, double value) - { - _properties.Add(key, new EntityProperty(value)); - } + public void Add(string key, DateTime? value) + { + _properties.Add(key, new EntityProperty(value)); + } - public void Add(string key, Guid value) - { - _properties.Add(key, new EntityProperty(value)); - } + public void Add(string key, DateTimeOffset? value) + { + _properties.Add(key, new EntityProperty(value)); + } - public void Add(string key, int value) - { - _properties.Add(key, new EntityProperty(value)); - } + public void Add(string key, double value) + { + _properties.Add(key, new EntityProperty(value)); + } - public void Add(string key, long value) - { - _properties.Add(key, new EntityProperty(value)); - } + public void Add(string key, Guid value) + { + _properties.Add(key, new EntityProperty(value)); + } - public void Add(string key, string value) - { - _properties.Add(key, new EntityProperty(value)); - } + public void Add(string key, int value) + { + _properties.Add(key, new EntityProperty(value)); + } - public void Add(KeyValuePair item) - { - _properties.Add(item); - } + public void Add(string key, long value) + { + _properties.Add(key, new EntityProperty(value)); + } - public bool ContainsKey(string key) - { - return _properties.ContainsKey(key); - } + public void Add(string key, string value) + { + _properties.Add(key, new EntityProperty(value)); + } - public bool Remove(string key) - { - return _properties.Remove(key); - } + public void Add(KeyValuePair item) + { + _properties.Add(item); + } - public bool TryGetValue(string key, out EntityProperty value) - { - return _properties.TryGetValue(key, out value); - } + public bool ContainsKey(string key) + { + return _properties.ContainsKey(key); + } - public void Clear() - { - _properties.Clear(); - } + public bool Remove(string key) + { + return _properties.Remove(key); + } - public bool Contains(KeyValuePair item) - { - return _properties.Contains(item); - } + public bool TryGetValue(string key, out EntityProperty value) + { + return _properties.TryGetValue(key, out value); + } - public void CopyTo(KeyValuePair[] array, int arrayIndex) - { - _properties.CopyTo(array, arrayIndex); - } + public void Clear() + { + _properties.Clear(); + } - public bool Remove(KeyValuePair item) - { - return _properties.Remove(item); - } + public bool Contains(KeyValuePair item) + { + return _properties.Contains(item); + } - public IEnumerator> GetEnumerator() - { - return _properties.GetEnumerator(); - } + public void CopyTo(KeyValuePair[] array, int arrayIndex) + { + _properties.CopyTo(array, arrayIndex); + } - IEnumerator IEnumerable.GetEnumerator() - { - return _properties.GetEnumerator(); + public bool Remove(KeyValuePair item) + { + return _properties.Remove(item); + } + + public IEnumerator> GetEnumerator() + { + return _properties.GetEnumerator(); + } + + IEnumerator IEnumerable.GetEnumerator() + { + return _properties.GetEnumerator(); + } } } diff --git a/src/Core/Models/Data/EmergencyAccessDetails.cs b/src/Core/Models/Data/EmergencyAccessDetails.cs index 89b04e3fcf..54e5069a0e 100644 --- a/src/Core/Models/Data/EmergencyAccessDetails.cs +++ b/src/Core/Models/Data/EmergencyAccessDetails.cs @@ -1,11 +1,12 @@ using Bit.Core.Entities; -namespace Bit.Core.Models.Data; - -public class EmergencyAccessDetails : EmergencyAccess +namespace Bit.Core.Models.Data { - public string GranteeName { get; set; } - public string GranteeEmail { get; set; } - public string GrantorName { get; set; } - public string GrantorEmail { get; set; } + public class EmergencyAccessDetails : EmergencyAccess + { + public string GranteeName { get; set; } + public string GranteeEmail { get; set; } + public string GrantorName { get; set; } + public string GrantorEmail { get; set; } + } } diff --git a/src/Core/Models/Data/EmergencyAccessNotify.cs b/src/Core/Models/Data/EmergencyAccessNotify.cs index 6eaccd272d..4661a1b494 100644 --- a/src/Core/Models/Data/EmergencyAccessNotify.cs +++ b/src/Core/Models/Data/EmergencyAccessNotify.cs @@ -1,10 +1,11 @@ using Bit.Core.Entities; -namespace Bit.Core.Models.Data; - -public class EmergencyAccessNotify : EmergencyAccess +namespace Bit.Core.Models.Data { - public string GrantorEmail { get; set; } - public string GranteeName { get; set; } - public string GranteeEmail { get; set; } + public class EmergencyAccessNotify : EmergencyAccess + { + public string GrantorEmail { get; set; } + public string GranteeName { get; set; } + public string GranteeEmail { get; set; } + } } diff --git a/src/Core/Models/Data/EmergencyAccessViewData.cs b/src/Core/Models/Data/EmergencyAccessViewData.cs index ef9ffb0a21..86260e823e 100644 --- a/src/Core/Models/Data/EmergencyAccessViewData.cs +++ b/src/Core/Models/Data/EmergencyAccessViewData.cs @@ -1,10 +1,11 @@ using Bit.Core.Entities; using Core.Models.Data; -namespace Bit.Core.Models.Data; - -public class EmergencyAccessViewData +namespace Bit.Core.Models.Data { - public EmergencyAccess EmergencyAccess { get; set; } - public IEnumerable Ciphers { get; set; } + public class EmergencyAccessViewData + { + public EmergencyAccess EmergencyAccess { get; set; } + public IEnumerable Ciphers { get; set; } + } } diff --git a/src/Core/Models/Data/EventMessage.cs b/src/Core/Models/Data/EventMessage.cs index c77eceab08..f99330d013 100644 --- a/src/Core/Models/Data/EventMessage.cs +++ b/src/Core/Models/Data/EventMessage.cs @@ -1,34 +1,35 @@ using Bit.Core.Context; using Bit.Core.Enums; -namespace Bit.Core.Models.Data; - -public class EventMessage : IEvent +namespace Bit.Core.Models.Data { - public EventMessage() { } - - public EventMessage(ICurrentContext currentContext) - : base() + public class EventMessage : IEvent { - IpAddress = currentContext.IpAddress; - DeviceType = currentContext.DeviceType; - } + public EventMessage() { } - public DateTime Date { get; set; } - public EventType Type { get; set; } - public Guid? UserId { get; set; } - public Guid? OrganizationId { get; set; } - public Guid? InstallationId { get; set; } - public Guid? ProviderId { get; set; } - public Guid? CipherId { get; set; } - public Guid? CollectionId { get; set; } - public Guid? GroupId { get; set; } - public Guid? PolicyId { get; set; } - public Guid? OrganizationUserId { get; set; } - public Guid? ProviderUserId { get; set; } - public Guid? ProviderOrganizationId { get; set; } - public Guid? ActingUserId { get; set; } - public DeviceType? DeviceType { get; set; } - public string IpAddress { get; set; } - public Guid? IdempotencyId { get; private set; } = Guid.NewGuid(); + public EventMessage(ICurrentContext currentContext) + : base() + { + IpAddress = currentContext.IpAddress; + DeviceType = currentContext.DeviceType; + } + + public DateTime Date { get; set; } + public EventType Type { get; set; } + public Guid? UserId { get; set; } + public Guid? OrganizationId { get; set; } + public Guid? InstallationId { get; set; } + public Guid? ProviderId { get; set; } + public Guid? CipherId { get; set; } + public Guid? CollectionId { get; set; } + public Guid? GroupId { get; set; } + public Guid? PolicyId { get; set; } + public Guid? OrganizationUserId { get; set; } + public Guid? ProviderUserId { get; set; } + public Guid? ProviderOrganizationId { get; set; } + public Guid? ActingUserId { get; set; } + public DeviceType? DeviceType { get; set; } + public string IpAddress { get; set; } + public Guid? IdempotencyId { get; private set; } = Guid.NewGuid(); + } } diff --git a/src/Core/Models/Data/EventTableEntity.cs b/src/Core/Models/Data/EventTableEntity.cs index 182a3171de..83e25b296a 100644 --- a/src/Core/Models/Data/EventTableEntity.cs +++ b/src/Core/Models/Data/EventTableEntity.cs @@ -2,153 +2,154 @@ using Bit.Core.Utilities; using Microsoft.Azure.Cosmos.Table; -namespace Bit.Core.Models.Data; - -public class EventTableEntity : TableEntity, IEvent +namespace Bit.Core.Models.Data { - public EventTableEntity() { } - - private EventTableEntity(IEvent e) + public class EventTableEntity : TableEntity, IEvent { - Date = e.Date; - Type = e.Type; - UserId = e.UserId; - OrganizationId = e.OrganizationId; - InstallationId = e.InstallationId; - ProviderId = e.ProviderId; - CipherId = e.CipherId; - CollectionId = e.CollectionId; - PolicyId = e.PolicyId; - GroupId = e.GroupId; - OrganizationUserId = e.OrganizationUserId; - ProviderUserId = e.ProviderUserId; - ProviderOrganizationId = e.ProviderOrganizationId; - DeviceType = e.DeviceType; - IpAddress = e.IpAddress; - ActingUserId = e.ActingUserId; - } + public EventTableEntity() { } - public DateTime Date { get; set; } - public EventType Type { get; set; } - public Guid? UserId { get; set; } - public Guid? OrganizationId { get; set; } - public Guid? InstallationId { get; set; } - public Guid? ProviderId { get; set; } - public Guid? CipherId { get; set; } - public Guid? CollectionId { get; set; } - public Guid? PolicyId { get; set; } - public Guid? GroupId { get; set; } - public Guid? OrganizationUserId { get; set; } - public Guid? ProviderUserId { get; set; } - public Guid? ProviderOrganizationId { get; set; } - public DeviceType? DeviceType { get; set; } - public string IpAddress { get; set; } - public Guid? ActingUserId { get; set; } - - public override IDictionary WriteEntity(OperationContext operationContext) - { - var result = base.WriteEntity(operationContext); - - var typeName = nameof(Type); - if (result.ContainsKey(typeName)) + private EventTableEntity(IEvent e) { - result[typeName] = new EntityProperty((int)Type); - } - else - { - result.Add(typeName, new EntityProperty((int)Type)); + Date = e.Date; + Type = e.Type; + UserId = e.UserId; + OrganizationId = e.OrganizationId; + InstallationId = e.InstallationId; + ProviderId = e.ProviderId; + CipherId = e.CipherId; + CollectionId = e.CollectionId; + PolicyId = e.PolicyId; + GroupId = e.GroupId; + OrganizationUserId = e.OrganizationUserId; + ProviderUserId = e.ProviderUserId; + ProviderOrganizationId = e.ProviderOrganizationId; + DeviceType = e.DeviceType; + IpAddress = e.IpAddress; + ActingUserId = e.ActingUserId; } - var deviceTypeName = nameof(DeviceType); - if (result.ContainsKey(deviceTypeName)) + public DateTime Date { get; set; } + public EventType Type { get; set; } + public Guid? UserId { get; set; } + public Guid? OrganizationId { get; set; } + public Guid? InstallationId { get; set; } + public Guid? ProviderId { get; set; } + public Guid? CipherId { get; set; } + public Guid? CollectionId { get; set; } + public Guid? PolicyId { get; set; } + public Guid? GroupId { get; set; } + public Guid? OrganizationUserId { get; set; } + public Guid? ProviderUserId { get; set; } + public Guid? ProviderOrganizationId { get; set; } + public DeviceType? DeviceType { get; set; } + public string IpAddress { get; set; } + public Guid? ActingUserId { get; set; } + + public override IDictionary WriteEntity(OperationContext operationContext) { - result[deviceTypeName] = new EntityProperty((int?)DeviceType); - } - else - { - result.Add(deviceTypeName, new EntityProperty((int?)DeviceType)); - } + var result = base.WriteEntity(operationContext); - return result; - } - - public override void ReadEntity(IDictionary properties, - OperationContext operationContext) - { - base.ReadEntity(properties, operationContext); - - var typeName = nameof(Type); - if (properties.ContainsKey(typeName) && properties[typeName].Int32Value.HasValue) - { - Type = (EventType)properties[typeName].Int32Value.Value; - } - - var deviceTypeName = nameof(DeviceType); - if (properties.ContainsKey(deviceTypeName) && properties[deviceTypeName].Int32Value.HasValue) - { - DeviceType = (DeviceType)properties[deviceTypeName].Int32Value.Value; - } - } - - public static List IndexEvent(EventMessage e) - { - var uniquifier = e.IdempotencyId.GetValueOrDefault(Guid.NewGuid()); - - var pKey = GetPartitionKey(e); - - var dateKey = CoreHelpers.DateTimeToTableStorageKey(e.Date); - - var entities = new List - { - new EventTableEntity(e) + var typeName = nameof(Type); + if (result.ContainsKey(typeName)) { - PartitionKey = pKey, - RowKey = $"Date={dateKey}__Uniquifier={uniquifier}" + result[typeName] = new EntityProperty((int)Type); } - }; - - if (e.OrganizationId.HasValue && e.ActingUserId.HasValue) - { - entities.Add(new EventTableEntity(e) + else { - PartitionKey = pKey, - RowKey = $"ActingUserId={e.ActingUserId}__Date={dateKey}__Uniquifier={uniquifier}" - }); - } + result.Add(typeName, new EntityProperty((int)Type)); + } - if (!e.OrganizationId.HasValue && e.ProviderId.HasValue && e.ActingUserId.HasValue) - { - entities.Add(new EventTableEntity(e) + var deviceTypeName = nameof(DeviceType); + if (result.ContainsKey(deviceTypeName)) { - PartitionKey = pKey, - RowKey = $"ActingUserId={e.ActingUserId}__Date={dateKey}__Uniquifier={uniquifier}" - }); - } - - if (e.CipherId.HasValue) - { - entities.Add(new EventTableEntity(e) + result[deviceTypeName] = new EntityProperty((int?)DeviceType); + } + else { - PartitionKey = pKey, - RowKey = $"CipherId={e.CipherId}__Date={dateKey}__Uniquifier={uniquifier}" - }); + result.Add(deviceTypeName, new EntityProperty((int?)DeviceType)); + } + + return result; } - return entities; - } - - private static string GetPartitionKey(EventMessage e) - { - if (e.OrganizationId.HasValue) + public override void ReadEntity(IDictionary properties, + OperationContext operationContext) { - return $"OrganizationId={e.OrganizationId}"; + base.ReadEntity(properties, operationContext); + + var typeName = nameof(Type); + if (properties.ContainsKey(typeName) && properties[typeName].Int32Value.HasValue) + { + Type = (EventType)properties[typeName].Int32Value.Value; + } + + var deviceTypeName = nameof(DeviceType); + if (properties.ContainsKey(deviceTypeName) && properties[deviceTypeName].Int32Value.HasValue) + { + DeviceType = (DeviceType)properties[deviceTypeName].Int32Value.Value; + } } - if (e.ProviderId.HasValue) + public static List IndexEvent(EventMessage e) { - return $"ProviderId={e.ProviderId}"; + var uniquifier = e.IdempotencyId.GetValueOrDefault(Guid.NewGuid()); + + var pKey = GetPartitionKey(e); + + var dateKey = CoreHelpers.DateTimeToTableStorageKey(e.Date); + + var entities = new List + { + new EventTableEntity(e) + { + PartitionKey = pKey, + RowKey = $"Date={dateKey}__Uniquifier={uniquifier}" + } + }; + + if (e.OrganizationId.HasValue && e.ActingUserId.HasValue) + { + entities.Add(new EventTableEntity(e) + { + PartitionKey = pKey, + RowKey = $"ActingUserId={e.ActingUserId}__Date={dateKey}__Uniquifier={uniquifier}" + }); + } + + if (!e.OrganizationId.HasValue && e.ProviderId.HasValue && e.ActingUserId.HasValue) + { + entities.Add(new EventTableEntity(e) + { + PartitionKey = pKey, + RowKey = $"ActingUserId={e.ActingUserId}__Date={dateKey}__Uniquifier={uniquifier}" + }); + } + + if (e.CipherId.HasValue) + { + entities.Add(new EventTableEntity(e) + { + PartitionKey = pKey, + RowKey = $"CipherId={e.CipherId}__Date={dateKey}__Uniquifier={uniquifier}" + }); + } + + return entities; } - return $"UserId={e.UserId}"; + private static string GetPartitionKey(EventMessage e) + { + if (e.OrganizationId.HasValue) + { + return $"OrganizationId={e.OrganizationId}"; + } + + if (e.ProviderId.HasValue) + { + return $"ProviderId={e.ProviderId}"; + } + + return $"UserId={e.UserId}"; + } } } diff --git a/src/Core/Models/Data/GroupWithCollections.cs b/src/Core/Models/Data/GroupWithCollections.cs index 3fa08bc45b..958f70feb1 100644 --- a/src/Core/Models/Data/GroupWithCollections.cs +++ b/src/Core/Models/Data/GroupWithCollections.cs @@ -1,9 +1,10 @@ using System.Data; using Bit.Core.Entities; -namespace Bit.Core.Models.Data; - -public class GroupWithCollections : Group +namespace Bit.Core.Models.Data { - public DataTable Collections { get; set; } + public class GroupWithCollections : Group + { + public DataTable Collections { get; set; } + } } diff --git a/src/Core/Models/Data/IEvent.cs b/src/Core/Models/Data/IEvent.cs index 82d8f74bac..860c6d446a 100644 --- a/src/Core/Models/Data/IEvent.cs +++ b/src/Core/Models/Data/IEvent.cs @@ -1,23 +1,24 @@ using Bit.Core.Enums; -namespace Bit.Core.Models.Data; - -public interface IEvent +namespace Bit.Core.Models.Data { - EventType Type { get; set; } - Guid? UserId { get; set; } - Guid? OrganizationId { get; set; } - Guid? InstallationId { get; set; } - Guid? ProviderId { get; set; } - Guid? CipherId { get; set; } - Guid? CollectionId { get; set; } - Guid? GroupId { get; set; } - Guid? PolicyId { get; set; } - Guid? OrganizationUserId { get; set; } - Guid? ProviderUserId { get; set; } - Guid? ProviderOrganizationId { get; set; } - Guid? ActingUserId { get; set; } - DeviceType? DeviceType { get; set; } - string IpAddress { get; set; } - DateTime Date { get; set; } + public interface IEvent + { + EventType Type { get; set; } + Guid? UserId { get; set; } + Guid? OrganizationId { get; set; } + Guid? InstallationId { get; set; } + Guid? ProviderId { get; set; } + Guid? CipherId { get; set; } + Guid? CollectionId { get; set; } + Guid? GroupId { get; set; } + Guid? PolicyId { get; set; } + Guid? OrganizationUserId { get; set; } + Guid? ProviderUserId { get; set; } + Guid? ProviderOrganizationId { get; set; } + Guid? ActingUserId { get; set; } + DeviceType? DeviceType { get; set; } + string IpAddress { get; set; } + DateTime Date { get; set; } + } } diff --git a/src/Core/Models/Data/InstallationDeviceEntity.cs b/src/Core/Models/Data/InstallationDeviceEntity.cs index cb7bf00873..0fb81e3404 100644 --- a/src/Core/Models/Data/InstallationDeviceEntity.cs +++ b/src/Core/Models/Data/InstallationDeviceEntity.cs @@ -1,34 +1,35 @@ using Microsoft.Azure.Cosmos.Table; -namespace Bit.Core.Models.Data; - -public class InstallationDeviceEntity : TableEntity +namespace Bit.Core.Models.Data { - public InstallationDeviceEntity() { } - - public InstallationDeviceEntity(Guid installationId, Guid deviceId) + public class InstallationDeviceEntity : TableEntity { - PartitionKey = installationId.ToString(); - RowKey = deviceId.ToString(); - } + public InstallationDeviceEntity() { } - public InstallationDeviceEntity(string prefixedDeviceId) - { - var parts = prefixedDeviceId.Split("_"); - if (parts.Length < 2) + public InstallationDeviceEntity(Guid installationId, Guid deviceId) { - throw new ArgumentException("Not enough parts."); + PartitionKey = installationId.ToString(); + RowKey = deviceId.ToString(); } - if (!Guid.TryParse(parts[0], out var installationId) || !Guid.TryParse(parts[1], out var deviceId)) - { - throw new ArgumentException("Could not parse parts."); - } - PartitionKey = parts[0]; - RowKey = parts[1]; - } - public static bool IsInstallationDeviceId(string deviceId) - { - return deviceId != null && deviceId.Length == 73 && deviceId[36] == '_'; + public InstallationDeviceEntity(string prefixedDeviceId) + { + var parts = prefixedDeviceId.Split("_"); + if (parts.Length < 2) + { + throw new ArgumentException("Not enough parts."); + } + if (!Guid.TryParse(parts[0], out var installationId) || !Guid.TryParse(parts[1], out var deviceId)) + { + throw new ArgumentException("Could not parse parts."); + } + PartitionKey = parts[0]; + RowKey = parts[1]; + } + + public static bool IsInstallationDeviceId(string deviceId) + { + return deviceId != null && deviceId.Length == 73 && deviceId[36] == '_'; + } } } diff --git a/src/Core/Models/Data/Organizations/OrganizationAbility.cs b/src/Core/Models/Data/Organizations/OrganizationAbility.cs index 9b9ee85095..6ec693185a 100644 --- a/src/Core/Models/Data/Organizations/OrganizationAbility.cs +++ b/src/Core/Models/Data/Organizations/OrganizationAbility.cs @@ -1,34 +1,35 @@ using Bit.Core.Entities; -namespace Bit.Core.Models.Data.Organizations; - -public class OrganizationAbility +namespace Bit.Core.Models.Data.Organizations { - public OrganizationAbility() { } - - public OrganizationAbility(Organization organization) + public class OrganizationAbility { - Id = organization.Id; - UseEvents = organization.UseEvents; - Use2fa = organization.Use2fa; - Using2fa = organization.Use2fa && organization.TwoFactorProviders != null && - organization.TwoFactorProviders != "{}"; - UsersGetPremium = organization.UsersGetPremium; - Enabled = organization.Enabled; - UseSso = organization.UseSso; - UseKeyConnector = organization.UseKeyConnector; - UseScim = organization.UseScim; - UseResetPassword = organization.UseResetPassword; - } + public OrganizationAbility() { } - public Guid Id { get; set; } - public bool UseEvents { get; set; } - public bool Use2fa { get; set; } - public bool Using2fa { get; set; } - public bool UsersGetPremium { get; set; } - public bool Enabled { get; set; } - public bool UseSso { get; set; } - public bool UseKeyConnector { get; set; } - public bool UseScim { get; set; } - public bool UseResetPassword { get; set; } + public OrganizationAbility(Organization organization) + { + Id = organization.Id; + UseEvents = organization.UseEvents; + Use2fa = organization.Use2fa; + Using2fa = organization.Use2fa && organization.TwoFactorProviders != null && + organization.TwoFactorProviders != "{}"; + UsersGetPremium = organization.UsersGetPremium; + Enabled = organization.Enabled; + UseSso = organization.UseSso; + UseKeyConnector = organization.UseKeyConnector; + UseScim = organization.UseScim; + UseResetPassword = organization.UseResetPassword; + } + + public Guid Id { get; set; } + public bool UseEvents { get; set; } + public bool Use2fa { get; set; } + public bool Using2fa { get; set; } + public bool UsersGetPremium { get; set; } + public bool Enabled { get; set; } + public bool UseSso { get; set; } + public bool UseKeyConnector { get; set; } + public bool UseScim { get; set; } + public bool UseResetPassword { get; set; } + } } diff --git a/src/Core/Models/Data/Organizations/OrganizationConnections/OrganizationConnectionData.cs b/src/Core/Models/Data/Organizations/OrganizationConnections/OrganizationConnectionData.cs index 3a3edaed45..272f411c5f 100644 --- a/src/Core/Models/Data/Organizations/OrganizationConnections/OrganizationConnectionData.cs +++ b/src/Core/Models/Data/Organizations/OrganizationConnections/OrganizationConnectionData.cs @@ -1,31 +1,32 @@ using Bit.Core.Entities; using Bit.Core.Enums; -namespace Bit.Core.Models.Data.Organizations.OrganizationConnections; - -public class OrganizationConnectionData where T : new() +namespace Bit.Core.Models.Data.Organizations.OrganizationConnections { - public Guid? Id { get; set; } - public OrganizationConnectionType Type { get; set; } - public Guid OrganizationId { get; set; } - public bool Enabled { get; set; } - public T Config { get; set; } - - public OrganizationConnection ToEntity() + public class OrganizationConnectionData where T : new() { - var result = new OrganizationConnection() - { - Type = Type, - OrganizationId = OrganizationId, - Enabled = Enabled, - }; - result.SetConfig(Config); + public Guid? Id { get; set; } + public OrganizationConnectionType Type { get; set; } + public Guid OrganizationId { get; set; } + public bool Enabled { get; set; } + public T Config { get; set; } - if (Id.HasValue) + public OrganizationConnection ToEntity() { - result.Id = Id.Value; + var result = new OrganizationConnection() + { + Type = Type, + OrganizationId = OrganizationId, + Enabled = Enabled, + }; + result.SetConfig(Config); + + if (Id.HasValue) + { + result.Id = Id.Value; + } + + return result; } - - return result; } } diff --git a/src/Core/Models/Data/Organizations/OrganizationSponsorships/OrganizationSponsorshipData.cs b/src/Core/Models/Data/Organizations/OrganizationSponsorships/OrganizationSponsorshipData.cs index 927262957a..2a964ec99c 100644 --- a/src/Core/Models/Data/Organizations/OrganizationSponsorships/OrganizationSponsorshipData.cs +++ b/src/Core/Models/Data/Organizations/OrganizationSponsorships/OrganizationSponsorshipData.cs @@ -1,30 +1,31 @@ using Bit.Core.Entities; using Bit.Core.Enums; -namespace Bit.Core.Models.Data.Organizations.OrganizationSponsorships; - -public class OrganizationSponsorshipData +namespace Bit.Core.Models.Data.Organizations.OrganizationSponsorships { - public OrganizationSponsorshipData() { } - public OrganizationSponsorshipData(OrganizationSponsorship sponsorship) + public class OrganizationSponsorshipData { - SponsoringOrganizationUserId = sponsorship.SponsoringOrganizationUserId; - SponsoredOrganizationId = sponsorship.SponsoredOrganizationId; - FriendlyName = sponsorship.FriendlyName; - OfferedToEmail = sponsorship.OfferedToEmail; - PlanSponsorshipType = sponsorship.PlanSponsorshipType.GetValueOrDefault(); - LastSyncDate = sponsorship.LastSyncDate; - ValidUntil = sponsorship.ValidUntil; - ToDelete = sponsorship.ToDelete; - } - public Guid SponsoringOrganizationUserId { get; set; } - public Guid? SponsoredOrganizationId { get; set; } - public string FriendlyName { get; set; } - public string OfferedToEmail { get; set; } - public PlanSponsorshipType PlanSponsorshipType { get; set; } - public DateTime? LastSyncDate { get; set; } - public DateTime? ValidUntil { get; set; } - public bool ToDelete { get; set; } + public OrganizationSponsorshipData() { } + public OrganizationSponsorshipData(OrganizationSponsorship sponsorship) + { + SponsoringOrganizationUserId = sponsorship.SponsoringOrganizationUserId; + SponsoredOrganizationId = sponsorship.SponsoredOrganizationId; + FriendlyName = sponsorship.FriendlyName; + OfferedToEmail = sponsorship.OfferedToEmail; + PlanSponsorshipType = sponsorship.PlanSponsorshipType.GetValueOrDefault(); + LastSyncDate = sponsorship.LastSyncDate; + ValidUntil = sponsorship.ValidUntil; + ToDelete = sponsorship.ToDelete; + } + public Guid SponsoringOrganizationUserId { get; set; } + public Guid? SponsoredOrganizationId { get; set; } + public string FriendlyName { get; set; } + public string OfferedToEmail { get; set; } + public PlanSponsorshipType PlanSponsorshipType { get; set; } + public DateTime? LastSyncDate { get; set; } + public DateTime? ValidUntil { get; set; } + public bool ToDelete { get; set; } - public bool CloudSponsorshipRemoved { get; set; } + public bool CloudSponsorshipRemoved { get; set; } + } } diff --git a/src/Core/Models/Data/Organizations/OrganizationSponsorships/OrganizationSponsorshipSyncData.cs b/src/Core/Models/Data/Organizations/OrganizationSponsorships/OrganizationSponsorshipSyncData.cs index 8c10187116..29cd20030c 100644 --- a/src/Core/Models/Data/Organizations/OrganizationSponsorships/OrganizationSponsorshipSyncData.cs +++ b/src/Core/Models/Data/Organizations/OrganizationSponsorships/OrganizationSponsorshipSyncData.cs @@ -1,8 +1,9 @@ -namespace Bit.Core.Models.Data.Organizations.OrganizationSponsorships; - -public class OrganizationSponsorshipSyncData +namespace Bit.Core.Models.Data.Organizations.OrganizationSponsorships { - public string BillingSyncKey { get; set; } - public Guid SponsoringOrganizationCloudId { get; set; } - public IEnumerable SponsorshipsBatch { get; set; } + public class OrganizationSponsorshipSyncData + { + public string BillingSyncKey { get; set; } + public Guid SponsoringOrganizationCloudId { get; set; } + public IEnumerable SponsorshipsBatch { get; set; } + } } diff --git a/src/Core/Models/Data/Organizations/OrganizationUsers/OrganizationUserInviteData.cs b/src/Core/Models/Data/Organizations/OrganizationUsers/OrganizationUserInviteData.cs index ff360c10f1..e23be4468d 100644 --- a/src/Core/Models/Data/Organizations/OrganizationUsers/OrganizationUserInviteData.cs +++ b/src/Core/Models/Data/Organizations/OrganizationUsers/OrganizationUserInviteData.cs @@ -1,12 +1,13 @@ using Bit.Core.Enums; -namespace Bit.Core.Models.Data.Organizations.OrganizationUsers; - -public class OrganizationUserInviteData +namespace Bit.Core.Models.Data.Organizations.OrganizationUsers { - public IEnumerable Emails { get; set; } - public OrganizationUserType? Type { get; set; } - public bool AccessAll { get; set; } - public IEnumerable Collections { get; set; } - public Permissions Permissions { get; set; } + public class OrganizationUserInviteData + { + public IEnumerable Emails { get; set; } + public OrganizationUserType? Type { get; set; } + public bool AccessAll { get; set; } + public IEnumerable Collections { get; set; } + public Permissions Permissions { get; set; } + } } diff --git a/src/Core/Models/Data/Organizations/OrganizationUsers/OrganizationUserOrganizationDetails.cs b/src/Core/Models/Data/Organizations/OrganizationUsers/OrganizationUserOrganizationDetails.cs index c132aee64f..554e99e6c3 100644 --- a/src/Core/Models/Data/Organizations/OrganizationUsers/OrganizationUserOrganizationDetails.cs +++ b/src/Core/Models/Data/Organizations/OrganizationUsers/OrganizationUserOrganizationDetails.cs @@ -1,42 +1,43 @@ -namespace Bit.Core.Models.Data.Organizations.OrganizationUsers; - -public class OrganizationUserOrganizationDetails +namespace Bit.Core.Models.Data.Organizations.OrganizationUsers { - public Guid OrganizationId { get; set; } - public Guid? UserId { get; set; } - public string Name { get; set; } - public bool UsePolicies { get; set; } - public bool UseSso { get; set; } - public bool UseKeyConnector { get; set; } - public bool UseScim { get; set; } - public bool UseGroups { get; set; } - public bool UseDirectory { get; set; } - public bool UseEvents { get; set; } - public bool UseTotp { get; set; } - public bool Use2fa { get; set; } - public bool UseApi { get; set; } - public bool UseResetPassword { get; set; } - public bool SelfHost { get; set; } - public bool UsersGetPremium { get; set; } - public int? Seats { get; set; } - public short? MaxCollections { get; set; } - public short? MaxStorageGb { get; set; } - public string Key { get; set; } - public Enums.OrganizationUserStatusType Status { get; set; } - public Enums.OrganizationUserType Type { get; set; } - public bool Enabled { get; set; } - public Enums.PlanType PlanType { get; set; } - public string SsoExternalId { get; set; } - public string Identifier { get; set; } - public string Permissions { get; set; } - public string ResetPasswordKey { get; set; } - public string PublicKey { get; set; } - public string PrivateKey { get; set; } - public Guid? ProviderId { get; set; } - public string ProviderName { get; set; } - public string FamilySponsorshipFriendlyName { get; set; } - public string SsoConfig { get; set; } - public DateTime? FamilySponsorshipLastSyncDate { get; set; } - public DateTime? FamilySponsorshipValidUntil { get; set; } - public bool? FamilySponsorshipToDelete { get; set; } + public class OrganizationUserOrganizationDetails + { + public Guid OrganizationId { get; set; } + public Guid? UserId { get; set; } + public string Name { get; set; } + public bool UsePolicies { get; set; } + public bool UseSso { get; set; } + public bool UseKeyConnector { get; set; } + public bool UseScim { get; set; } + public bool UseGroups { get; set; } + public bool UseDirectory { get; set; } + public bool UseEvents { get; set; } + public bool UseTotp { get; set; } + public bool Use2fa { get; set; } + public bool UseApi { get; set; } + public bool UseResetPassword { get; set; } + public bool SelfHost { get; set; } + public bool UsersGetPremium { get; set; } + public int? Seats { get; set; } + public short? MaxCollections { get; set; } + public short? MaxStorageGb { get; set; } + public string Key { get; set; } + public Enums.OrganizationUserStatusType Status { get; set; } + public Enums.OrganizationUserType Type { get; set; } + public bool Enabled { get; set; } + public Enums.PlanType PlanType { get; set; } + public string SsoExternalId { get; set; } + public string Identifier { get; set; } + public string Permissions { get; set; } + public string ResetPasswordKey { get; set; } + public string PublicKey { get; set; } + public string PrivateKey { get; set; } + public Guid? ProviderId { get; set; } + public string ProviderName { get; set; } + public string FamilySponsorshipFriendlyName { get; set; } + public string SsoConfig { get; set; } + public DateTime? FamilySponsorshipLastSyncDate { get; set; } + public DateTime? FamilySponsorshipValidUntil { get; set; } + public bool? FamilySponsorshipToDelete { get; set; } + } } diff --git a/src/Core/Models/Data/Organizations/OrganizationUsers/OrganizationUserPublicKey.cs b/src/Core/Models/Data/Organizations/OrganizationUsers/OrganizationUserPublicKey.cs index 7c04967872..c465f49a09 100644 --- a/src/Core/Models/Data/Organizations/OrganizationUsers/OrganizationUserPublicKey.cs +++ b/src/Core/Models/Data/Organizations/OrganizationUsers/OrganizationUserPublicKey.cs @@ -1,8 +1,9 @@ -namespace Bit.Core.Models.Data.Organizations.OrganizationUsers; - -public class OrganizationUserPublicKey +namespace Bit.Core.Models.Data.Organizations.OrganizationUsers { - public Guid Id { get; set; } - public Guid UserId { get; set; } - public string PublicKey { get; set; } + public class OrganizationUserPublicKey + { + public Guid Id { get; set; } + public Guid UserId { get; set; } + public string PublicKey { get; set; } + } } diff --git a/src/Core/Models/Data/Organizations/OrganizationUsers/OrganizationUserResetPasswordDetails.cs b/src/Core/Models/Data/Organizations/OrganizationUsers/OrganizationUserResetPasswordDetails.cs index 66fa27dfd3..ccac4a587b 100644 --- a/src/Core/Models/Data/Organizations/OrganizationUsers/OrganizationUserResetPasswordDetails.cs +++ b/src/Core/Models/Data/Organizations/OrganizationUsers/OrganizationUserResetPasswordDetails.cs @@ -1,34 +1,35 @@ using Bit.Core.Entities; using Bit.Core.Enums; -namespace Bit.Core.Models.Data.Organizations.OrganizationUsers; - -public class OrganizationUserResetPasswordDetails +namespace Bit.Core.Models.Data.Organizations.OrganizationUsers { - public OrganizationUserResetPasswordDetails(OrganizationUser orgUser, User user, Organization org) + public class OrganizationUserResetPasswordDetails { - if (orgUser == null) + public OrganizationUserResetPasswordDetails(OrganizationUser orgUser, User user, Organization org) { - throw new ArgumentNullException(nameof(orgUser)); - } + if (orgUser == null) + { + throw new ArgumentNullException(nameof(orgUser)); + } - if (user == null) - { - throw new ArgumentNullException(nameof(user)); - } + if (user == null) + { + throw new ArgumentNullException(nameof(user)); + } - if (org == null) - { - throw new ArgumentNullException(nameof(org)); - } + if (org == null) + { + throw new ArgumentNullException(nameof(org)); + } - Kdf = user.Kdf; - KdfIterations = user.KdfIterations; - ResetPasswordKey = orgUser.ResetPasswordKey; - EncryptedPrivateKey = org.PrivateKey; + Kdf = user.Kdf; + KdfIterations = user.KdfIterations; + ResetPasswordKey = orgUser.ResetPasswordKey; + EncryptedPrivateKey = org.PrivateKey; + } + public KdfType Kdf { get; set; } + public int KdfIterations { get; set; } + public string ResetPasswordKey { get; set; } + public string EncryptedPrivateKey { get; set; } } - public KdfType Kdf { get; set; } - public int KdfIterations { get; set; } - public string ResetPasswordKey { get; set; } - public string EncryptedPrivateKey { get; set; } } diff --git a/src/Core/Models/Data/Organizations/OrganizationUsers/OrganizationUserUserDetails.cs b/src/Core/Models/Data/Organizations/OrganizationUsers/OrganizationUserUserDetails.cs index ff28d1f3cd..334ee74177 100644 --- a/src/Core/Models/Data/Organizations/OrganizationUsers/OrganizationUserUserDetails.cs +++ b/src/Core/Models/Data/Organizations/OrganizationUsers/OrganizationUserUserDetails.cs @@ -1,59 +1,60 @@ using Bit.Core.Enums; using Bit.Core.Utilities; -namespace Bit.Core.Models.Data.Organizations.OrganizationUsers; - -public class OrganizationUserUserDetails : IExternal, ITwoFactorProvidersUser +namespace Bit.Core.Models.Data.Organizations.OrganizationUsers { - private Dictionary _twoFactorProviders; - - public Guid Id { get; set; } - public Guid OrganizationId { get; set; } - public Guid? UserId { get; set; } - public string Name { get; set; } - public string Email { get; set; } - public string TwoFactorProviders { get; set; } - public bool? Premium { get; set; } - public OrganizationUserStatusType Status { get; set; } - public OrganizationUserType Type { get; set; } - public bool AccessAll { get; set; } - public string ExternalId { get; set; } - public string SsoExternalId { get; set; } - public string Permissions { get; set; } - public string ResetPasswordKey { get; set; } - public bool UsesKeyConnector { get; set; } - - public Dictionary GetTwoFactorProviders() + public class OrganizationUserUserDetails : IExternal, ITwoFactorProvidersUser { - if (string.IsNullOrWhiteSpace(TwoFactorProviders)) - { - return null; - } + private Dictionary _twoFactorProviders; - try + public Guid Id { get; set; } + public Guid OrganizationId { get; set; } + public Guid? UserId { get; set; } + public string Name { get; set; } + public string Email { get; set; } + public string TwoFactorProviders { get; set; } + public bool? Premium { get; set; } + public OrganizationUserStatusType Status { get; set; } + public OrganizationUserType Type { get; set; } + public bool AccessAll { get; set; } + public string ExternalId { get; set; } + public string SsoExternalId { get; set; } + public string Permissions { get; set; } + public string ResetPasswordKey { get; set; } + public bool UsesKeyConnector { get; set; } + + public Dictionary GetTwoFactorProviders() { - if (_twoFactorProviders == null) + if (string.IsNullOrWhiteSpace(TwoFactorProviders)) { - _twoFactorProviders = - JsonHelpers.LegacyDeserialize>( - TwoFactorProviders); + return null; } - return _twoFactorProviders; + try + { + if (_twoFactorProviders == null) + { + _twoFactorProviders = + JsonHelpers.LegacyDeserialize>( + TwoFactorProviders); + } + + return _twoFactorProviders; + } + catch (Newtonsoft.Json.JsonException) + { + return null; + } } - catch (Newtonsoft.Json.JsonException) + + public Guid? GetUserId() { - return null; + return UserId; } - } - public Guid? GetUserId() - { - return UserId; - } - - public bool GetPremium() - { - return Premium.GetValueOrDefault(false); + public bool GetPremium() + { + return Premium.GetValueOrDefault(false); + } } } diff --git a/src/Core/Models/Data/Organizations/OrganizationUsers/OrganizationUserWithCollections.cs b/src/Core/Models/Data/Organizations/OrganizationUsers/OrganizationUserWithCollections.cs index d86c6c1581..c96a49f561 100644 --- a/src/Core/Models/Data/Organizations/OrganizationUsers/OrganizationUserWithCollections.cs +++ b/src/Core/Models/Data/Organizations/OrganizationUsers/OrganizationUserWithCollections.cs @@ -1,9 +1,10 @@ using System.Data; using Bit.Core.Entities; -namespace Bit.Core.Models.Data.Organizations.OrganizationUsers; - -public class OrganizationUserWithCollections : OrganizationUser +namespace Bit.Core.Models.Data.Organizations.OrganizationUsers { - public DataTable Collections { get; set; } + public class OrganizationUserWithCollections : OrganizationUser + { + public DataTable Collections { get; set; } + } } diff --git a/src/Core/Models/Data/Organizations/Policies/IPolicyDataModel.cs b/src/Core/Models/Data/Organizations/Policies/IPolicyDataModel.cs index ef8789d483..1d263cedb4 100644 --- a/src/Core/Models/Data/Organizations/Policies/IPolicyDataModel.cs +++ b/src/Core/Models/Data/Organizations/Policies/IPolicyDataModel.cs @@ -1,5 +1,6 @@ -namespace Bit.Core.Models.Data.Organizations.Policies; - -public interface IPolicyDataModel +namespace Bit.Core.Models.Data.Organizations.Policies { + public interface IPolicyDataModel + { + } } diff --git a/src/Core/Models/Data/Organizations/Policies/ResetPasswordDataModel.cs b/src/Core/Models/Data/Organizations/Policies/ResetPasswordDataModel.cs index 1931cc5b79..c77d8ef01e 100644 --- a/src/Core/Models/Data/Organizations/Policies/ResetPasswordDataModel.cs +++ b/src/Core/Models/Data/Organizations/Policies/ResetPasswordDataModel.cs @@ -1,9 +1,10 @@ using System.ComponentModel.DataAnnotations; -namespace Bit.Core.Models.Data.Organizations.Policies; - -public class ResetPasswordDataModel : IPolicyDataModel +namespace Bit.Core.Models.Data.Organizations.Policies { - [Display(Name = "ResetPasswordAutoEnrollCheckbox")] - public bool AutoEnrollEnabled { get; set; } + public class ResetPasswordDataModel : IPolicyDataModel + { + [Display(Name = "ResetPasswordAutoEnrollCheckbox")] + public bool AutoEnrollEnabled { get; set; } + } } diff --git a/src/Core/Models/Data/Organizations/Policies/SendOptionsPolicyData.cs b/src/Core/Models/Data/Organizations/Policies/SendOptionsPolicyData.cs index aa9f651665..d9bb5ef9d7 100644 --- a/src/Core/Models/Data/Organizations/Policies/SendOptionsPolicyData.cs +++ b/src/Core/Models/Data/Organizations/Policies/SendOptionsPolicyData.cs @@ -1,9 +1,10 @@ using System.ComponentModel.DataAnnotations; -namespace Bit.Core.Models.Data.Organizations.Policies; - -public class SendOptionsPolicyData : IPolicyDataModel +namespace Bit.Core.Models.Data.Organizations.Policies { - [Display(Name = "DisableHideEmail")] - public bool DisableHideEmail { get; set; } + public class SendOptionsPolicyData : IPolicyDataModel + { + [Display(Name = "DisableHideEmail")] + public bool DisableHideEmail { get; set; } + } } diff --git a/src/Core/Models/Data/PageOptions.cs b/src/Core/Models/Data/PageOptions.cs index e9f12ece9a..1b354932e8 100644 --- a/src/Core/Models/Data/PageOptions.cs +++ b/src/Core/Models/Data/PageOptions.cs @@ -1,7 +1,8 @@ -namespace Bit.Core.Models.Data; - -public class PageOptions +namespace Bit.Core.Models.Data { - public string ContinuationToken { get; set; } - public int PageSize { get; set; } = 50; + public class PageOptions + { + public string ContinuationToken { get; set; } + public int PageSize { get; set; } = 50; + } } diff --git a/src/Core/Models/Data/PagedResult.cs b/src/Core/Models/Data/PagedResult.cs index b02044dd8c..1bb7e3cd26 100644 --- a/src/Core/Models/Data/PagedResult.cs +++ b/src/Core/Models/Data/PagedResult.cs @@ -1,7 +1,8 @@ -namespace Bit.Core.Models.Data; - -public class PagedResult +namespace Bit.Core.Models.Data { - public List Data { get; set; } = new List(); - public string ContinuationToken { get; set; } + public class PagedResult + { + public List Data { get; set; } = new List(); + public string ContinuationToken { get; set; } + } } diff --git a/src/Core/Models/Data/Permissions.cs b/src/Core/Models/Data/Permissions.cs index 49a7e37f08..5cb0149a3f 100644 --- a/src/Core/Models/Data/Permissions.cs +++ b/src/Core/Models/Data/Permissions.cs @@ -1,44 +1,45 @@ using System.Text.Json.Serialization; -namespace Bit.Core.Models.Data; - -public class Permissions +namespace Bit.Core.Models.Data { - public bool AccessEventLogs { get; set; } - public bool AccessImportExport { get; set; } - public bool AccessReports { get; set; } - [Obsolete("This permission exists for client backwards-compatibility. It should not be used to determine permissions in this repository", true)] - public bool ManageAllCollections => CreateNewCollections && EditAnyCollection && DeleteAnyCollection; - public bool CreateNewCollections { get; set; } - public bool EditAnyCollection { get; set; } - public bool DeleteAnyCollection { get; set; } - [Obsolete("This permission exists for client backwards-compatibility. It should not be used to determine permissions in this repository", true)] - public bool ManageAssignedCollections => EditAssignedCollections && DeleteAssignedCollections; - public bool EditAssignedCollections { get; set; } - public bool DeleteAssignedCollections { get; set; } - public bool ManageGroups { get; set; } - public bool ManagePolicies { get; set; } - public bool ManageSso { get; set; } - public bool ManageUsers { get; set; } - public bool ManageResetPassword { get; set; } - public bool ManageScim { get; set; } - - [JsonIgnore] - public List<(bool Permission, string ClaimName)> ClaimsMap => new() + public class Permissions { - (AccessEventLogs, "accesseventlogs"), - (AccessImportExport, "accessimportexport"), - (AccessReports, "accessreports"), - (CreateNewCollections, "createnewcollections"), - (EditAnyCollection, "editanycollection"), - (DeleteAnyCollection, "deleteanycollection"), - (EditAssignedCollections, "editassignedcollections"), - (DeleteAssignedCollections, "deleteassignedcollections"), - (ManageGroups, "managegroups"), - (ManagePolicies, "managepolicies"), - (ManageSso, "managesso"), - (ManageUsers, "manageusers"), - (ManageResetPassword, "manageresetpassword"), - (ManageScim, "managescim"), - }; + public bool AccessEventLogs { get; set; } + public bool AccessImportExport { get; set; } + public bool AccessReports { get; set; } + [Obsolete("This permission exists for client backwards-compatibility. It should not be used to determine permissions in this repository", true)] + public bool ManageAllCollections => CreateNewCollections && EditAnyCollection && DeleteAnyCollection; + public bool CreateNewCollections { get; set; } + public bool EditAnyCollection { get; set; } + public bool DeleteAnyCollection { get; set; } + [Obsolete("This permission exists for client backwards-compatibility. It should not be used to determine permissions in this repository", true)] + public bool ManageAssignedCollections => EditAssignedCollections && DeleteAssignedCollections; + public bool EditAssignedCollections { get; set; } + public bool DeleteAssignedCollections { get; set; } + public bool ManageGroups { get; set; } + public bool ManagePolicies { get; set; } + public bool ManageSso { get; set; } + public bool ManageUsers { get; set; } + public bool ManageResetPassword { get; set; } + public bool ManageScim { get; set; } + + [JsonIgnore] + public List<(bool Permission, string ClaimName)> ClaimsMap => new() + { + (AccessEventLogs, "accesseventlogs"), + (AccessImportExport, "accessimportexport"), + (AccessReports, "accessreports"), + (CreateNewCollections, "createnewcollections"), + (EditAnyCollection, "editanycollection"), + (DeleteAnyCollection, "deleteanycollection"), + (EditAssignedCollections, "editassignedcollections"), + (DeleteAssignedCollections, "deleteassignedcollections"), + (ManageGroups, "managegroups"), + (ManagePolicies, "managepolicies"), + (ManageSso, "managesso"), + (ManageUsers, "manageusers"), + (ManageResetPassword, "manageresetpassword"), + (ManageScim, "managescim"), + }; + } } diff --git a/src/Core/Models/Data/Provider/ProviderAbility.cs b/src/Core/Models/Data/Provider/ProviderAbility.cs index b7e45eaed4..a772030142 100644 --- a/src/Core/Models/Data/Provider/ProviderAbility.cs +++ b/src/Core/Models/Data/Provider/ProviderAbility.cs @@ -1,19 +1,20 @@ using Bit.Core.Entities.Provider; -namespace Bit.Core.Models.Data; - -public class ProviderAbility +namespace Bit.Core.Models.Data { - public ProviderAbility() { } - - public ProviderAbility(Provider provider) + public class ProviderAbility { - Id = provider.Id; - UseEvents = provider.UseEvents; - Enabled = provider.Enabled; - } + public ProviderAbility() { } - public Guid Id { get; set; } - public bool UseEvents { get; set; } - public bool Enabled { get; set; } + public ProviderAbility(Provider provider) + { + Id = provider.Id; + UseEvents = provider.UseEvents; + Enabled = provider.Enabled; + } + + public Guid Id { get; set; } + public bool UseEvents { get; set; } + public bool Enabled { get; set; } + } } diff --git a/src/Core/Models/Data/Provider/ProviderOrganizationOrganizationDetails.cs b/src/Core/Models/Data/Provider/ProviderOrganizationOrganizationDetails.cs index 923bba4af0..279994df4b 100644 --- a/src/Core/Models/Data/Provider/ProviderOrganizationOrganizationDetails.cs +++ b/src/Core/Models/Data/Provider/ProviderOrganizationOrganizationDetails.cs @@ -1,16 +1,17 @@ -namespace Bit.Core.Models.Data; - -public class ProviderOrganizationOrganizationDetails +namespace Bit.Core.Models.Data { - public Guid Id { get; set; } - public Guid ProviderId { get; set; } - public Guid OrganizationId { get; set; } - public string OrganizationName { get; set; } - public string Key { get; set; } - public string Settings { get; set; } - public DateTime CreationDate { get; set; } - public DateTime RevisionDate { get; set; } - public int UserCount { get; set; } - public int? Seats { get; set; } - public string Plan { get; set; } + public class ProviderOrganizationOrganizationDetails + { + public Guid Id { get; set; } + public Guid ProviderId { get; set; } + public Guid OrganizationId { get; set; } + public string OrganizationName { get; set; } + public string Key { get; set; } + public string Settings { get; set; } + public DateTime CreationDate { get; set; } + public DateTime RevisionDate { get; set; } + public int UserCount { get; set; } + public int? Seats { get; set; } + public string Plan { get; set; } + } } diff --git a/src/Core/Models/Data/Provider/ProviderUserOrganizationDetails.cs b/src/Core/Models/Data/Provider/ProviderUserOrganizationDetails.cs index 9d0740b73b..ab19931b69 100644 --- a/src/Core/Models/Data/Provider/ProviderUserOrganizationDetails.cs +++ b/src/Core/Models/Data/Provider/ProviderUserOrganizationDetails.cs @@ -1,36 +1,37 @@ using Bit.Core.Enums.Provider; -namespace Bit.Core.Models.Data; - -public class ProviderUserOrganizationDetails +namespace Bit.Core.Models.Data { - public Guid OrganizationId { get; set; } - public Guid? UserId { get; set; } - public string Name { get; set; } - public bool UsePolicies { get; set; } - public bool UseSso { get; set; } - public bool UseKeyConnector { get; set; } - public bool UseScim { get; set; } - public bool UseGroups { get; set; } - public bool UseDirectory { get; set; } - public bool UseEvents { get; set; } - public bool UseTotp { get; set; } - public bool Use2fa { get; set; } - public bool UseApi { get; set; } - public bool UseResetPassword { get; set; } - public bool SelfHost { get; set; } - public bool UsersGetPremium { get; set; } - public int? Seats { get; set; } - public short? MaxCollections { get; set; } - public short? MaxStorageGb { get; set; } - public string Key { get; set; } - public ProviderUserStatusType Status { get; set; } - public ProviderUserType Type { get; set; } - public bool Enabled { get; set; } - public string Identifier { get; set; } - public string PublicKey { get; set; } - public string PrivateKey { get; set; } - public Guid? ProviderId { get; set; } - public Guid? ProviderUserId { get; set; } - public string ProviderName { get; set; } + public class ProviderUserOrganizationDetails + { + public Guid OrganizationId { get; set; } + public Guid? UserId { get; set; } + public string Name { get; set; } + public bool UsePolicies { get; set; } + public bool UseSso { get; set; } + public bool UseKeyConnector { get; set; } + public bool UseScim { get; set; } + public bool UseGroups { get; set; } + public bool UseDirectory { get; set; } + public bool UseEvents { get; set; } + public bool UseTotp { get; set; } + public bool Use2fa { get; set; } + public bool UseApi { get; set; } + public bool UseResetPassword { get; set; } + public bool SelfHost { get; set; } + public bool UsersGetPremium { get; set; } + public int? Seats { get; set; } + public short? MaxCollections { get; set; } + public short? MaxStorageGb { get; set; } + public string Key { get; set; } + public ProviderUserStatusType Status { get; set; } + public ProviderUserType Type { get; set; } + public bool Enabled { get; set; } + public string Identifier { get; set; } + public string PublicKey { get; set; } + public string PrivateKey { get; set; } + public Guid? ProviderId { get; set; } + public Guid? ProviderUserId { get; set; } + public string ProviderName { get; set; } + } } diff --git a/src/Core/Models/Data/Provider/ProviderUserProviderDetails.cs b/src/Core/Models/Data/Provider/ProviderUserProviderDetails.cs index 16f2e1dda5..a14a455d9e 100644 --- a/src/Core/Models/Data/Provider/ProviderUserProviderDetails.cs +++ b/src/Core/Models/Data/Provider/ProviderUserProviderDetails.cs @@ -1,17 +1,18 @@ using Bit.Core.Enums.Provider; -namespace Bit.Core.Models.Data; - -public class ProviderUserProviderDetails +namespace Bit.Core.Models.Data { - public Guid ProviderId { get; set; } - public Guid? UserId { get; set; } - public string Name { get; set; } - public string Key { get; set; } - public ProviderUserStatusType Status { get; set; } - public ProviderUserType Type { get; set; } - public bool Enabled { get; set; } - public string Permissions { get; set; } - public bool UseEvents { get; set; } - public ProviderStatusType ProviderStatus { get; set; } + public class ProviderUserProviderDetails + { + public Guid ProviderId { get; set; } + public Guid? UserId { get; set; } + public string Name { get; set; } + public string Key { get; set; } + public ProviderUserStatusType Status { get; set; } + public ProviderUserType Type { get; set; } + public bool Enabled { get; set; } + public string Permissions { get; set; } + public bool UseEvents { get; set; } + public ProviderStatusType ProviderStatus { get; set; } + } } diff --git a/src/Core/Models/Data/Provider/ProviderUserPublicKey.cs b/src/Core/Models/Data/Provider/ProviderUserPublicKey.cs index 0b161fd860..0be26770c2 100644 --- a/src/Core/Models/Data/Provider/ProviderUserPublicKey.cs +++ b/src/Core/Models/Data/Provider/ProviderUserPublicKey.cs @@ -1,8 +1,9 @@ -namespace Bit.Core.Models.Data; - -public class ProviderUserPublicKey +namespace Bit.Core.Models.Data { - public Guid Id { get; set; } - public Guid UserId { get; set; } - public string PublicKey { get; set; } + public class ProviderUserPublicKey + { + public Guid Id { get; set; } + public Guid UserId { get; set; } + public string PublicKey { get; set; } + } } diff --git a/src/Core/Models/Data/Provider/ProviderUserUserDetails.cs b/src/Core/Models/Data/Provider/ProviderUserUserDetails.cs index 51df1d44ef..6d0c4daa61 100644 --- a/src/Core/Models/Data/Provider/ProviderUserUserDetails.cs +++ b/src/Core/Models/Data/Provider/ProviderUserUserDetails.cs @@ -1,15 +1,16 @@ using Bit.Core.Enums.Provider; -namespace Bit.Core.Models.Data; - -public class ProviderUserUserDetails +namespace Bit.Core.Models.Data { - public Guid Id { get; set; } - public Guid ProviderId { get; set; } - public Guid? UserId { get; set; } - public string Name { get; set; } - public string Email { get; set; } - public ProviderUserStatusType Status { get; set; } - public ProviderUserType Type { get; set; } - public string Permissions { get; set; } + public class ProviderUserUserDetails + { + public Guid Id { get; set; } + public Guid ProviderId { get; set; } + public Guid? UserId { get; set; } + public string Name { get; set; } + public string Email { get; set; } + public ProviderUserStatusType Status { get; set; } + public ProviderUserType Type { get; set; } + public string Permissions { get; set; } + } } diff --git a/src/Core/Models/Data/SelectionReadOnly.cs b/src/Core/Models/Data/SelectionReadOnly.cs index 426abb57f7..b1dd09d71e 100644 --- a/src/Core/Models/Data/SelectionReadOnly.cs +++ b/src/Core/Models/Data/SelectionReadOnly.cs @@ -1,8 +1,9 @@ -namespace Bit.Core.Models.Data; - -public class SelectionReadOnly +namespace Bit.Core.Models.Data { - public Guid Id { get; set; } - public bool ReadOnly { get; set; } - public bool HidePasswords { get; set; } + public class SelectionReadOnly + { + public Guid Id { get; set; } + public bool ReadOnly { get; set; } + public bool HidePasswords { get; set; } + } } diff --git a/src/Core/Models/Data/SendData.cs b/src/Core/Models/Data/SendData.cs index 7210caae6c..956f934ba7 100644 --- a/src/Core/Models/Data/SendData.cs +++ b/src/Core/Models/Data/SendData.cs @@ -1,15 +1,16 @@ -namespace Bit.Core.Models.Data; - -public abstract class SendData +namespace Bit.Core.Models.Data { - public SendData() { } - - public SendData(string name, string notes) + public abstract class SendData { - Name = name; - Notes = notes; - } + public SendData() { } - public string Name { get; set; } - public string Notes { get; set; } + public SendData(string name, string notes) + { + Name = name; + Notes = notes; + } + + public string Name { get; set; } + public string Notes { get; set; } + } } diff --git a/src/Core/Models/Data/SendFileData.cs b/src/Core/Models/Data/SendFileData.cs index 253ee01cee..8ec61ec792 100644 --- a/src/Core/Models/Data/SendFileData.cs +++ b/src/Core/Models/Data/SendFileData.cs @@ -1,22 +1,23 @@ using System.Text.Json.Serialization; -namespace Bit.Core.Models.Data; - -public class SendFileData : SendData +namespace Bit.Core.Models.Data { - public SendFileData() { } - - public SendFileData(string name, string notes, string fileName) - : base(name, notes) + public class SendFileData : SendData { - FileName = fileName; + public SendFileData() { } + + public SendFileData(string name, string notes, string fileName) + : base(name, notes) + { + FileName = fileName; + } + + // We serialize Size as a string since JSON (or Javascript) doesn't support full precision for long numbers + [JsonNumberHandling(JsonNumberHandling.WriteAsString | JsonNumberHandling.AllowReadingFromString)] + public long Size { get; set; } + + public string Id { get; set; } + public string FileName { get; set; } + public bool Validated { get; set; } = true; } - - // We serialize Size as a string since JSON (or Javascript) doesn't support full precision for long numbers - [JsonNumberHandling(JsonNumberHandling.WriteAsString | JsonNumberHandling.AllowReadingFromString)] - public long Size { get; set; } - - public string Id { get; set; } - public string FileName { get; set; } - public bool Validated { get; set; } = true; } diff --git a/src/Core/Models/Data/SendTextData.cs b/src/Core/Models/Data/SendTextData.cs index 2aa6d0481f..0e6d301152 100644 --- a/src/Core/Models/Data/SendTextData.cs +++ b/src/Core/Models/Data/SendTextData.cs @@ -1,16 +1,17 @@ -namespace Bit.Core.Models.Data; - -public class SendTextData : SendData +namespace Bit.Core.Models.Data { - public SendTextData() { } - - public SendTextData(string name, string notes, string text, bool hidden) - : base(name, notes) + public class SendTextData : SendData { - Text = text; - Hidden = hidden; - } + public SendTextData() { } - public string Text { get; set; } - public bool Hidden { get; set; } + public SendTextData(string name, string notes, string text, bool hidden) + : base(name, notes) + { + Text = text; + Hidden = hidden; + } + + public string Text { get; set; } + public bool Hidden { get; set; } + } } diff --git a/src/Core/Models/Data/SsoConfigurationData.cs b/src/Core/Models/Data/SsoConfigurationData.cs index 844c52146a..093ec9a8aa 100644 --- a/src/Core/Models/Data/SsoConfigurationData.cs +++ b/src/Core/Models/Data/SsoConfigurationData.cs @@ -2,124 +2,125 @@ using Bit.Core.Utilities; using Microsoft.AspNetCore.Authentication.OpenIdConnect; -namespace Bit.Core.Models.Data; - -public class SsoConfigurationData +namespace Bit.Core.Models.Data { - private static string _oidcSigninPath = "/oidc-signin"; - private static string _oidcSignedOutPath = "/oidc-signedout"; - private static string _saml2ModulePath = "/saml2"; - - public static SsoConfigurationData Deserialize(string data) + public class SsoConfigurationData { - return CoreHelpers.LoadClassFromJsonData(data); - } + private static string _oidcSigninPath = "/oidc-signin"; + private static string _oidcSignedOutPath = "/oidc-signedout"; + private static string _saml2ModulePath = "/saml2"; - public string Serialize() - { - return CoreHelpers.ClassToJsonData(this); - } - - public SsoType ConfigType { get; set; } - - public bool KeyConnectorEnabled { get; set; } - public string KeyConnectorUrl { get; set; } - - // OIDC - public string Authority { get; set; } - public string ClientId { get; set; } - public string ClientSecret { get; set; } - public string MetadataAddress { get; set; } - public OpenIdConnectRedirectBehavior RedirectBehavior { get; set; } - public bool GetClaimsFromUserInfoEndpoint { get; set; } - public string AdditionalScopes { get; set; } - public string AdditionalUserIdClaimTypes { get; set; } - public string AdditionalEmailClaimTypes { get; set; } - public string AdditionalNameClaimTypes { get; set; } - public string AcrValues { get; set; } - public string ExpectedReturnAcrValue { get; set; } - - // SAML2 IDP - public string IdpEntityId { get; set; } - public string IdpSingleSignOnServiceUrl { get; set; } - public string IdpSingleLogoutServiceUrl { get; set; } - public string IdpX509PublicCert { get; set; } - public Saml2BindingType IdpBindingType { get; set; } - public bool IdpAllowUnsolicitedAuthnResponse { get; set; } - public string IdpArtifactResolutionServiceUrl { get => null; set { /*IGNORE*/ } } - public bool IdpDisableOutboundLogoutRequests { get; set; } - public string IdpOutboundSigningAlgorithm { get; set; } - public bool IdpWantAuthnRequestsSigned { get; set; } - - // SAML2 SP - public Saml2NameIdFormat SpNameIdFormat { get; set; } - public string SpOutboundSigningAlgorithm { get; set; } - public Saml2SigningBehavior SpSigningBehavior { get; set; } - public bool SpWantAssertionsSigned { get; set; } - public bool SpValidateCertificates { get; set; } - public string SpMinIncomingSigningAlgorithm { get; set; } - - public static string BuildCallbackPath(string ssoUri = null) - { - return BuildSsoUrl(_oidcSigninPath, ssoUri); - } - - public static string BuildSignedOutCallbackPath(string ssoUri = null) - { - return BuildSsoUrl(_oidcSignedOutPath, ssoUri); - } - - public static string BuildSaml2ModulePath(string ssoUri = null, string scheme = null) - { - return string.Concat(BuildSsoUrl(_saml2ModulePath, ssoUri), - string.IsNullOrWhiteSpace(scheme) ? string.Empty : $"/{scheme}"); - } - - public static string BuildSaml2AcsUrl(string ssoUri = null, string scheme = null) - { - return string.Concat(BuildSaml2ModulePath(ssoUri, scheme), "/Acs"); - } - - public static string BuildSaml2MetadataUrl(string ssoUri = null, string scheme = null) - { - return BuildSaml2ModulePath(ssoUri, scheme); - } - - public IEnumerable GetAdditionalScopes() => AdditionalScopes? - .Split(',')? - .Where(c => !string.IsNullOrWhiteSpace(c))? - .Select(c => c.Trim()) ?? - Array.Empty(); - - public IEnumerable GetAdditionalUserIdClaimTypes() => AdditionalUserIdClaimTypes? - .Split(',')? - .Where(c => !string.IsNullOrWhiteSpace(c))? - .Select(c => c.Trim()) ?? - Array.Empty(); - - public IEnumerable GetAdditionalEmailClaimTypes() => AdditionalEmailClaimTypes? - .Split(',')? - .Where(c => !string.IsNullOrWhiteSpace(c))? - .Select(c => c.Trim()) ?? - Array.Empty(); - - public IEnumerable GetAdditionalNameClaimTypes() => AdditionalNameClaimTypes? - .Split(',')? - .Where(c => !string.IsNullOrWhiteSpace(c))? - .Select(c => c.Trim()) ?? - Array.Empty(); - - private static string BuildSsoUrl(string relativePath, string ssoUri) - { - if (string.IsNullOrWhiteSpace(ssoUri) || - !Uri.IsWellFormedUriString(ssoUri, UriKind.Absolute)) + public static SsoConfigurationData Deserialize(string data) { + return CoreHelpers.LoadClassFromJsonData(data); + } + + public string Serialize() + { + return CoreHelpers.ClassToJsonData(this); + } + + public SsoType ConfigType { get; set; } + + public bool KeyConnectorEnabled { get; set; } + public string KeyConnectorUrl { get; set; } + + // OIDC + public string Authority { get; set; } + public string ClientId { get; set; } + public string ClientSecret { get; set; } + public string MetadataAddress { get; set; } + public OpenIdConnectRedirectBehavior RedirectBehavior { get; set; } + public bool GetClaimsFromUserInfoEndpoint { get; set; } + public string AdditionalScopes { get; set; } + public string AdditionalUserIdClaimTypes { get; set; } + public string AdditionalEmailClaimTypes { get; set; } + public string AdditionalNameClaimTypes { get; set; } + public string AcrValues { get; set; } + public string ExpectedReturnAcrValue { get; set; } + + // SAML2 IDP + public string IdpEntityId { get; set; } + public string IdpSingleSignOnServiceUrl { get; set; } + public string IdpSingleLogoutServiceUrl { get; set; } + public string IdpX509PublicCert { get; set; } + public Saml2BindingType IdpBindingType { get; set; } + public bool IdpAllowUnsolicitedAuthnResponse { get; set; } + public string IdpArtifactResolutionServiceUrl { get => null; set { /*IGNORE*/ } } + public bool IdpDisableOutboundLogoutRequests { get; set; } + public string IdpOutboundSigningAlgorithm { get; set; } + public bool IdpWantAuthnRequestsSigned { get; set; } + + // SAML2 SP + public Saml2NameIdFormat SpNameIdFormat { get; set; } + public string SpOutboundSigningAlgorithm { get; set; } + public Saml2SigningBehavior SpSigningBehavior { get; set; } + public bool SpWantAssertionsSigned { get; set; } + public bool SpValidateCertificates { get; set; } + public string SpMinIncomingSigningAlgorithm { get; set; } + + public static string BuildCallbackPath(string ssoUri = null) + { + return BuildSsoUrl(_oidcSigninPath, ssoUri); + } + + public static string BuildSignedOutCallbackPath(string ssoUri = null) + { + return BuildSsoUrl(_oidcSignedOutPath, ssoUri); + } + + public static string BuildSaml2ModulePath(string ssoUri = null, string scheme = null) + { + return string.Concat(BuildSsoUrl(_saml2ModulePath, ssoUri), + string.IsNullOrWhiteSpace(scheme) ? string.Empty : $"/{scheme}"); + } + + public static string BuildSaml2AcsUrl(string ssoUri = null, string scheme = null) + { + return string.Concat(BuildSaml2ModulePath(ssoUri, scheme), "/Acs"); + } + + public static string BuildSaml2MetadataUrl(string ssoUri = null, string scheme = null) + { + return BuildSaml2ModulePath(ssoUri, scheme); + } + + public IEnumerable GetAdditionalScopes() => AdditionalScopes? + .Split(',')? + .Where(c => !string.IsNullOrWhiteSpace(c))? + .Select(c => c.Trim()) ?? + Array.Empty(); + + public IEnumerable GetAdditionalUserIdClaimTypes() => AdditionalUserIdClaimTypes? + .Split(',')? + .Where(c => !string.IsNullOrWhiteSpace(c))? + .Select(c => c.Trim()) ?? + Array.Empty(); + + public IEnumerable GetAdditionalEmailClaimTypes() => AdditionalEmailClaimTypes? + .Split(',')? + .Where(c => !string.IsNullOrWhiteSpace(c))? + .Select(c => c.Trim()) ?? + Array.Empty(); + + public IEnumerable GetAdditionalNameClaimTypes() => AdditionalNameClaimTypes? + .Split(',')? + .Where(c => !string.IsNullOrWhiteSpace(c))? + .Select(c => c.Trim()) ?? + Array.Empty(); + + private static string BuildSsoUrl(string relativePath, string ssoUri) + { + if (string.IsNullOrWhiteSpace(ssoUri) || + !Uri.IsWellFormedUriString(ssoUri, UriKind.Absolute)) + { + return relativePath; + } + if (Uri.TryCreate(string.Concat(ssoUri.TrimEnd('/'), relativePath), UriKind.Absolute, out var newUri)) + { + return newUri.ToString(); + } return relativePath; } - if (Uri.TryCreate(string.Concat(ssoUri.TrimEnd('/'), relativePath), UriKind.Absolute, out var newUri)) - { - return newUri.ToString(); - } - return relativePath; } } diff --git a/src/Core/Models/Data/UserKdfInformation.cs b/src/Core/Models/Data/UserKdfInformation.cs index 0fa3d6f83f..3825006d1d 100644 --- a/src/Core/Models/Data/UserKdfInformation.cs +++ b/src/Core/Models/Data/UserKdfInformation.cs @@ -1,9 +1,10 @@ using Bit.Core.Enums; -namespace Bit.Core.Models.Data; - -public class UserKdfInformation +namespace Bit.Core.Models.Data { - public KdfType Kdf { get; set; } - public int KdfIterations { get; set; } + public class UserKdfInformation + { + public KdfType Kdf { get; set; } + public int KdfIterations { get; set; } + } } diff --git a/src/Core/Models/IExternal.cs b/src/Core/Models/IExternal.cs index e81de1d47b..f6d51add24 100644 --- a/src/Core/Models/IExternal.cs +++ b/src/Core/Models/IExternal.cs @@ -1,6 +1,7 @@ -namespace Bit.Core.Models; - -public interface IExternal +namespace Bit.Core.Models { - string ExternalId { get; } + public interface IExternal + { + string ExternalId { get; } + } } diff --git a/src/Core/Models/ITwoFactorProvidersUser.cs b/src/Core/Models/ITwoFactorProvidersUser.cs index b056ba31ca..c617960ade 100644 --- a/src/Core/Models/ITwoFactorProvidersUser.cs +++ b/src/Core/Models/ITwoFactorProvidersUser.cs @@ -1,11 +1,12 @@ using Bit.Core.Enums; -namespace Bit.Core.Models; - -public interface ITwoFactorProvidersUser +namespace Bit.Core.Models { - string TwoFactorProviders { get; } - Dictionary GetTwoFactorProviders(); - Guid? GetUserId(); - bool GetPremium(); + public interface ITwoFactorProvidersUser + { + string TwoFactorProviders { get; } + Dictionary GetTwoFactorProviders(); + Guid? GetUserId(); + bool GetPremium(); + } } diff --git a/src/Core/Models/Mail/AddedCreditViewModel.cs b/src/Core/Models/Mail/AddedCreditViewModel.cs index 6ccfb5fcc8..fe3d995019 100644 --- a/src/Core/Models/Mail/AddedCreditViewModel.cs +++ b/src/Core/Models/Mail/AddedCreditViewModel.cs @@ -1,6 +1,7 @@ -namespace Bit.Core.Models.Mail; - -public class AddedCreditViewModel : BaseMailModel +namespace Bit.Core.Models.Mail { - public decimal Amount { get; set; } + public class AddedCreditViewModel : BaseMailModel + { + public decimal Amount { get; set; } + } } diff --git a/src/Core/Models/Mail/AdminResetPasswordViewModel.cs b/src/Core/Models/Mail/AdminResetPasswordViewModel.cs index 18e257fea7..5f5e859ac2 100644 --- a/src/Core/Models/Mail/AdminResetPasswordViewModel.cs +++ b/src/Core/Models/Mail/AdminResetPasswordViewModel.cs @@ -1,7 +1,8 @@ -namespace Bit.Core.Models.Mail; - -public class AdminResetPasswordViewModel : BaseMailModel +namespace Bit.Core.Models.Mail { - public string UserName { get; set; } - public string OrgName { get; set; } + public class AdminResetPasswordViewModel : BaseMailModel + { + public string UserName { get; set; } + public string OrgName { get; set; } + } } diff --git a/src/Core/Models/Mail/BaseMailModel.cs b/src/Core/Models/Mail/BaseMailModel.cs index e3aa4d2c41..416e50d252 100644 --- a/src/Core/Models/Mail/BaseMailModel.cs +++ b/src/Core/Models/Mail/BaseMailModel.cs @@ -1,26 +1,27 @@ -namespace Bit.Core.Models.Mail; - -public class BaseMailModel +namespace Bit.Core.Models.Mail { - public string SiteName { get; set; } - public string WebVaultUrl { get; set; } - public string WebVaultUrlHostname + public class BaseMailModel { - get + public string SiteName { get; set; } + public string WebVaultUrl { get; set; } + public string WebVaultUrlHostname { - if (Uri.TryCreate(WebVaultUrl, UriKind.Absolute, out Uri uri)) + get { - return uri.Host; - } + if (Uri.TryCreate(WebVaultUrl, UriKind.Absolute, out Uri uri)) + { + return uri.Host; + } - return WebVaultUrl; + return WebVaultUrl; + } } - } - public string CurrentYear - { - get + public string CurrentYear { - return DateTime.UtcNow.Year.ToString(); + get + { + return DateTime.UtcNow.Year.ToString(); + } } } } diff --git a/src/Core/Models/Mail/ChangeEmailExistsViewModel.cs b/src/Core/Models/Mail/ChangeEmailExistsViewModel.cs index 22367e8f27..8eda668828 100644 --- a/src/Core/Models/Mail/ChangeEmailExistsViewModel.cs +++ b/src/Core/Models/Mail/ChangeEmailExistsViewModel.cs @@ -1,7 +1,8 @@ -namespace Bit.Core.Models.Mail; - -public class ChangeEmailExistsViewModel : BaseMailModel +namespace Bit.Core.Models.Mail { - public string FromEmail { get; set; } - public string ToEmail { get; set; } + public class ChangeEmailExistsViewModel : BaseMailModel + { + public string FromEmail { get; set; } + public string ToEmail { get; set; } + } } diff --git a/src/Core/Models/Mail/EmailTokenViewModel.cs b/src/Core/Models/Mail/EmailTokenViewModel.cs index 561df580e8..596fc7c21c 100644 --- a/src/Core/Models/Mail/EmailTokenViewModel.cs +++ b/src/Core/Models/Mail/EmailTokenViewModel.cs @@ -1,6 +1,7 @@ -namespace Bit.Core.Models.Mail; - -public class EmailTokenViewModel : BaseMailModel +namespace Bit.Core.Models.Mail { - public string Token { get; set; } + public class EmailTokenViewModel : BaseMailModel + { + public string Token { get; set; } + } } diff --git a/src/Core/Models/Mail/EmergencyAccessAcceptedViewModel.cs b/src/Core/Models/Mail/EmergencyAccessAcceptedViewModel.cs index afe29b9843..1073ea8590 100644 --- a/src/Core/Models/Mail/EmergencyAccessAcceptedViewModel.cs +++ b/src/Core/Models/Mail/EmergencyAccessAcceptedViewModel.cs @@ -1,6 +1,7 @@ -namespace Bit.Core.Models.Mail; - -public class EmergencyAccessAcceptedViewModel : BaseMailModel +namespace Bit.Core.Models.Mail { - public string GranteeEmail { get; set; } + public class EmergencyAccessAcceptedViewModel : BaseMailModel + { + public string GranteeEmail { get; set; } + } } diff --git a/src/Core/Models/Mail/EmergencyAccessApprovedViewModel.cs b/src/Core/Models/Mail/EmergencyAccessApprovedViewModel.cs index 9ad446aab6..b8cb13b7f1 100644 --- a/src/Core/Models/Mail/EmergencyAccessApprovedViewModel.cs +++ b/src/Core/Models/Mail/EmergencyAccessApprovedViewModel.cs @@ -1,6 +1,7 @@ -namespace Bit.Core.Models.Mail; - -public class EmergencyAccessApprovedViewModel : BaseMailModel +namespace Bit.Core.Models.Mail { - public string Name { get; set; } + public class EmergencyAccessApprovedViewModel : BaseMailModel + { + public string Name { get; set; } + } } diff --git a/src/Core/Models/Mail/EmergencyAccessConfirmedViewModel.cs b/src/Core/Models/Mail/EmergencyAccessConfirmedViewModel.cs index 2ab55a05eb..c7f457e338 100644 --- a/src/Core/Models/Mail/EmergencyAccessConfirmedViewModel.cs +++ b/src/Core/Models/Mail/EmergencyAccessConfirmedViewModel.cs @@ -1,6 +1,7 @@ -namespace Bit.Core.Models.Mail; - -public class EmergencyAccessConfirmedViewModel : BaseMailModel +namespace Bit.Core.Models.Mail { - public string Name { get; set; } + public class EmergencyAccessConfirmedViewModel : BaseMailModel + { + public string Name { get; set; } + } } diff --git a/src/Core/Models/Mail/EmergencyAccessInvitedViewModel.cs b/src/Core/Models/Mail/EmergencyAccessInvitedViewModel.cs index fa432c5b70..a211208c47 100644 --- a/src/Core/Models/Mail/EmergencyAccessInvitedViewModel.cs +++ b/src/Core/Models/Mail/EmergencyAccessInvitedViewModel.cs @@ -1,10 +1,11 @@ -namespace Bit.Core.Models.Mail; - -public class EmergencyAccessInvitedViewModel : BaseMailModel +namespace Bit.Core.Models.Mail { - public string Name { get; set; } - public string Id { get; set; } - public string Email { get; set; } - public string Token { get; set; } - public string Url => $"{WebVaultUrl}/accept-emergency?id={Id}&name={Name}&email={Email}&token={Token}"; + public class EmergencyAccessInvitedViewModel : BaseMailModel + { + public string Name { get; set; } + public string Id { get; set; } + public string Email { get; set; } + public string Token { get; set; } + public string Url => $"{WebVaultUrl}/accept-emergency?id={Id}&name={Name}&email={Email}&token={Token}"; + } } diff --git a/src/Core/Models/Mail/EmergencyAccessRecoveryTimedOutViewModel.cs b/src/Core/Models/Mail/EmergencyAccessRecoveryTimedOutViewModel.cs index dd3ae3dd82..2c0a287ca1 100644 --- a/src/Core/Models/Mail/EmergencyAccessRecoveryTimedOutViewModel.cs +++ b/src/Core/Models/Mail/EmergencyAccessRecoveryTimedOutViewModel.cs @@ -1,7 +1,8 @@ -namespace Bit.Core.Models.Mail; - -public class EmergencyAccessRecoveryTimedOutViewModel : BaseMailModel +namespace Bit.Core.Models.Mail { - public string Name { get; set; } - public string Action { get; set; } + public class EmergencyAccessRecoveryTimedOutViewModel : BaseMailModel + { + public string Name { get; set; } + public string Action { get; set; } + } } diff --git a/src/Core/Models/Mail/EmergencyAccessRecoveryViewModel.cs b/src/Core/Models/Mail/EmergencyAccessRecoveryViewModel.cs index 3811b49ff0..bea6059fc4 100644 --- a/src/Core/Models/Mail/EmergencyAccessRecoveryViewModel.cs +++ b/src/Core/Models/Mail/EmergencyAccessRecoveryViewModel.cs @@ -1,8 +1,9 @@ -namespace Bit.Core.Models.Mail; - -public class EmergencyAccessRecoveryViewModel : BaseMailModel +namespace Bit.Core.Models.Mail { - public string Name { get; set; } - public string Action { get; set; } - public int DaysLeft { get; set; } + public class EmergencyAccessRecoveryViewModel : BaseMailModel + { + public string Name { get; set; } + public string Action { get; set; } + public int DaysLeft { get; set; } + } } diff --git a/src/Core/Models/Mail/EmergencyAccessRejectedViewModel.cs b/src/Core/Models/Mail/EmergencyAccessRejectedViewModel.cs index 101cb9c167..4cf1887261 100644 --- a/src/Core/Models/Mail/EmergencyAccessRejectedViewModel.cs +++ b/src/Core/Models/Mail/EmergencyAccessRejectedViewModel.cs @@ -1,6 +1,7 @@ -namespace Bit.Core.Models.Mail; - -public class EmergencyAccessRejectedViewModel : BaseMailModel +namespace Bit.Core.Models.Mail { - public string Name { get; set; } + public class EmergencyAccessRejectedViewModel : BaseMailModel + { + public string Name { get; set; } + } } diff --git a/src/Core/Models/Mail/FailedAuthAttemptsModel.cs b/src/Core/Models/Mail/FailedAuthAttemptsModel.cs index 8ef66061d8..030616d35b 100644 --- a/src/Core/Models/Mail/FailedAuthAttemptsModel.cs +++ b/src/Core/Models/Mail/FailedAuthAttemptsModel.cs @@ -1,6 +1,7 @@ -namespace Bit.Core.Models.Mail; - -public class FailedAuthAttemptsModel : NewDeviceLoggedInModel +namespace Bit.Core.Models.Mail { - public string AffectedEmail { get; set; } + public class FailedAuthAttemptsModel : NewDeviceLoggedInModel + { + public string AffectedEmail { get; set; } + } } diff --git a/src/Core/Models/Mail/FamiliesForEnterprise/FamiliesForEnterpriseOfferViewModel.cs b/src/Core/Models/Mail/FamiliesForEnterprise/FamiliesForEnterpriseOfferViewModel.cs index 7e9d8ee193..97f028253d 100644 --- a/src/Core/Models/Mail/FamiliesForEnterprise/FamiliesForEnterpriseOfferViewModel.cs +++ b/src/Core/Models/Mail/FamiliesForEnterprise/FamiliesForEnterpriseOfferViewModel.cs @@ -1,16 +1,17 @@ -namespace Bit.Core.Models.Mail.FamiliesForEnterprise; - -public class FamiliesForEnterpriseOfferViewModel : BaseMailModel +namespace Bit.Core.Models.Mail.FamiliesForEnterprise { - public string SponsorOrgName { get; set; } - public string SponsoredEmail { get; set; } - public string SponsorshipToken { get; set; } - public bool ExistingAccount { get; set; } - public string Url => string.Concat( - WebVaultUrl, - "/accept-families-for-enterprise", - $"?token={SponsorshipToken}", - $"&email={SponsoredEmail}", - ExistingAccount ? "" : "®ister=true" - ); + public class FamiliesForEnterpriseOfferViewModel : BaseMailModel + { + public string SponsorOrgName { get; set; } + public string SponsoredEmail { get; set; } + public string SponsorshipToken { get; set; } + public bool ExistingAccount { get; set; } + public string Url => string.Concat( + WebVaultUrl, + "/accept-families-for-enterprise", + $"?token={SponsorshipToken}", + $"&email={SponsoredEmail}", + ExistingAccount ? "" : "®ister=true" + ); + } } diff --git a/src/Core/Models/Mail/FamiliesForEnterprise/FamiliesForEnterpriseSponsorshipRevertingViewModel.cs b/src/Core/Models/Mail/FamiliesForEnterprise/FamiliesForEnterpriseSponsorshipRevertingViewModel.cs index 08c445b6f3..b15717c87c 100644 --- a/src/Core/Models/Mail/FamiliesForEnterprise/FamiliesForEnterpriseSponsorshipRevertingViewModel.cs +++ b/src/Core/Models/Mail/FamiliesForEnterprise/FamiliesForEnterpriseSponsorshipRevertingViewModel.cs @@ -1,6 +1,7 @@ -namespace Bit.Core.Models.Mail.FamiliesForEnterprise; - -public class FamiliesForEnterpriseSponsorshipRevertingViewModel : BaseMailModel +namespace Bit.Core.Models.Mail.FamiliesForEnterprise { - public DateTime ExpirationDate { get; set; } + public class FamiliesForEnterpriseSponsorshipRevertingViewModel : BaseMailModel + { + public DateTime ExpirationDate { get; set; } + } } diff --git a/src/Core/Models/Mail/IMailQueueMessage.cs b/src/Core/Models/Mail/IMailQueueMessage.cs index 085e811c50..37c09c90e3 100644 --- a/src/Core/Models/Mail/IMailQueueMessage.cs +++ b/src/Core/Models/Mail/IMailQueueMessage.cs @@ -1,11 +1,12 @@ -namespace Bit.Core.Models.Mail; - -public interface IMailQueueMessage +namespace Bit.Core.Models.Mail { - string Subject { get; set; } - IEnumerable ToEmails { get; set; } - IEnumerable BccEmails { get; set; } - string Category { get; set; } - string TemplateName { get; set; } - object Model { get; set; } + public interface IMailQueueMessage + { + string Subject { get; set; } + IEnumerable ToEmails { get; set; } + IEnumerable BccEmails { get; set; } + string Category { get; set; } + string TemplateName { get; set; } + object Model { get; set; } + } } diff --git a/src/Core/Models/Mail/InvoiceUpcomingViewModel.cs b/src/Core/Models/Mail/InvoiceUpcomingViewModel.cs index 29c40bf920..7a3bdacea7 100644 --- a/src/Core/Models/Mail/InvoiceUpcomingViewModel.cs +++ b/src/Core/Models/Mail/InvoiceUpcomingViewModel.cs @@ -1,9 +1,10 @@ -namespace Bit.Core.Models.Mail; - -public class InvoiceUpcomingViewModel : BaseMailModel +namespace Bit.Core.Models.Mail { - public decimal AmountDue { get; set; } - public DateTime DueDate { get; set; } - public List Items { get; set; } - public bool MentionInvoices { get; set; } + public class InvoiceUpcomingViewModel : BaseMailModel + { + public decimal AmountDue { get; set; } + public DateTime DueDate { get; set; } + public List Items { get; set; } + public bool MentionInvoices { get; set; } + } } diff --git a/src/Core/Models/Mail/LicenseExpiredViewModel.cs b/src/Core/Models/Mail/LicenseExpiredViewModel.cs index 922b35cfb1..70f5f32cd2 100644 --- a/src/Core/Models/Mail/LicenseExpiredViewModel.cs +++ b/src/Core/Models/Mail/LicenseExpiredViewModel.cs @@ -1,7 +1,8 @@ -namespace Bit.Core.Models.Mail; - -public class LicenseExpiredViewModel : BaseMailModel +namespace Bit.Core.Models.Mail { - public string OrganizationName { get; set; } - public bool IsOrganization => !string.IsNullOrWhiteSpace(OrganizationName); + public class LicenseExpiredViewModel : BaseMailModel + { + public string OrganizationName { get; set; } + public bool IsOrganization => !string.IsNullOrWhiteSpace(OrganizationName); + } } diff --git a/src/Core/Models/Mail/MailMessage.cs b/src/Core/Models/Mail/MailMessage.cs index df444c77f5..1ccb87acf4 100644 --- a/src/Core/Models/Mail/MailMessage.cs +++ b/src/Core/Models/Mail/MailMessage.cs @@ -1,12 +1,13 @@ -namespace Bit.Core.Models.Mail; - -public class MailMessage +namespace Bit.Core.Models.Mail { - public string Subject { get; set; } - public IEnumerable ToEmails { get; set; } - public IEnumerable BccEmails { get; set; } - public string HtmlContent { get; set; } - public string TextContent { get; set; } - public string Category { get; set; } - public IDictionary MetaData { get; set; } + public class MailMessage + { + public string Subject { get; set; } + public IEnumerable ToEmails { get; set; } + public IEnumerable BccEmails { get; set; } + public string HtmlContent { get; set; } + public string TextContent { get; set; } + public string Category { get; set; } + public IDictionary MetaData { get; set; } + } } diff --git a/src/Core/Models/Mail/MailQueueMessage.cs b/src/Core/Models/Mail/MailQueueMessage.cs index d413c5f1a5..2aa2b3c65a 100644 --- a/src/Core/Models/Mail/MailQueueMessage.cs +++ b/src/Core/Models/Mail/MailQueueMessage.cs @@ -1,28 +1,29 @@ using System.Text.Json.Serialization; using Bit.Core.Utilities; -namespace Bit.Core.Models.Mail; - -public class MailQueueMessage : IMailQueueMessage +namespace Bit.Core.Models.Mail { - public string Subject { get; set; } - public IEnumerable ToEmails { get; set; } - public IEnumerable BccEmails { get; set; } - public string Category { get; set; } - public string TemplateName { get; set; } - - [JsonConverter(typeof(HandlebarsObjectJsonConverter))] - public object Model { get; set; } - - public MailQueueMessage() { } - - public MailQueueMessage(MailMessage message, string templateName, object model) + public class MailQueueMessage : IMailQueueMessage { - Subject = message.Subject; - ToEmails = message.ToEmails; - BccEmails = message.BccEmails; - Category = string.IsNullOrEmpty(message.Category) ? templateName : message.Category; - TemplateName = templateName; - Model = model; + public string Subject { get; set; } + public IEnumerable ToEmails { get; set; } + public IEnumerable BccEmails { get; set; } + public string Category { get; set; } + public string TemplateName { get; set; } + + [JsonConverter(typeof(HandlebarsObjectJsonConverter))] + public object Model { get; set; } + + public MailQueueMessage() { } + + public MailQueueMessage(MailMessage message, string templateName, object model) + { + Subject = message.Subject; + ToEmails = message.ToEmails; + BccEmails = message.BccEmails; + Category = string.IsNullOrEmpty(message.Category) ? templateName : message.Category; + TemplateName = templateName; + Model = model; + } } } diff --git a/src/Core/Models/Mail/MasterPasswordHintViewModel.cs b/src/Core/Models/Mail/MasterPasswordHintViewModel.cs index 01eb883a28..d2cfff49ed 100644 --- a/src/Core/Models/Mail/MasterPasswordHintViewModel.cs +++ b/src/Core/Models/Mail/MasterPasswordHintViewModel.cs @@ -1,6 +1,7 @@ -namespace Bit.Core.Models.Mail; - -public class MasterPasswordHintViewModel : BaseMailModel +namespace Bit.Core.Models.Mail { - public string Hint { get; set; } + public class MasterPasswordHintViewModel : BaseMailModel + { + public string Hint { get; set; } + } } diff --git a/src/Core/Models/Mail/NewDeviceLoggedInModel.cs b/src/Core/Models/Mail/NewDeviceLoggedInModel.cs index 6d55a19b64..ee550fc4e6 100644 --- a/src/Core/Models/Mail/NewDeviceLoggedInModel.cs +++ b/src/Core/Models/Mail/NewDeviceLoggedInModel.cs @@ -1,10 +1,11 @@ -namespace Bit.Core.Models.Mail; - -public class NewDeviceLoggedInModel : BaseMailModel +namespace Bit.Core.Models.Mail { - public string TheDate { get; set; } - public string TheTime { get; set; } - public string TimeZone { get; set; } - public string IpAddress { get; set; } - public string DeviceType { get; set; } + public class NewDeviceLoggedInModel : BaseMailModel + { + public string TheDate { get; set; } + public string TheTime { get; set; } + public string TimeZone { get; set; } + public string IpAddress { get; set; } + public string DeviceType { get; set; } + } } diff --git a/src/Core/Models/Mail/OrganizationSeatsAutoscaledViewModel.cs b/src/Core/Models/Mail/OrganizationSeatsAutoscaledViewModel.cs index 87f87b1c69..44299c3905 100644 --- a/src/Core/Models/Mail/OrganizationSeatsAutoscaledViewModel.cs +++ b/src/Core/Models/Mail/OrganizationSeatsAutoscaledViewModel.cs @@ -1,8 +1,9 @@ -namespace Bit.Core.Models.Mail; - -public class OrganizationSeatsAutoscaledViewModel : BaseMailModel +namespace Bit.Core.Models.Mail { - public Guid OrganizationId { get; set; } - public int InitialSeatCount { get; set; } - public int CurrentSeatCount { get; set; } + public class OrganizationSeatsAutoscaledViewModel : BaseMailModel + { + public Guid OrganizationId { get; set; } + public int InitialSeatCount { get; set; } + public int CurrentSeatCount { get; set; } + } } diff --git a/src/Core/Models/Mail/OrganizationSeatsMaxReachedViewModel.cs b/src/Core/Models/Mail/OrganizationSeatsMaxReachedViewModel.cs index cdfb57b2dc..5fcdee7040 100644 --- a/src/Core/Models/Mail/OrganizationSeatsMaxReachedViewModel.cs +++ b/src/Core/Models/Mail/OrganizationSeatsMaxReachedViewModel.cs @@ -1,7 +1,8 @@ -namespace Bit.Core.Models.Mail; - -public class OrganizationSeatsMaxReachedViewModel : BaseMailModel +namespace Bit.Core.Models.Mail { - public Guid OrganizationId { get; set; } - public int MaxSeatCount { get; set; } + public class OrganizationSeatsMaxReachedViewModel : BaseMailModel + { + public Guid OrganizationId { get; set; } + public int MaxSeatCount { get; set; } + } } diff --git a/src/Core/Models/Mail/OrganizationUserAcceptedViewModel.cs b/src/Core/Models/Mail/OrganizationUserAcceptedViewModel.cs index 919463c2c3..5bfd502a50 100644 --- a/src/Core/Models/Mail/OrganizationUserAcceptedViewModel.cs +++ b/src/Core/Models/Mail/OrganizationUserAcceptedViewModel.cs @@ -1,8 +1,9 @@ -namespace Bit.Core.Models.Mail; - -public class OrganizationUserAcceptedViewModel : BaseMailModel +namespace Bit.Core.Models.Mail { - public Guid OrganizationId { get; set; } - public string OrganizationName { get; set; } - public string UserIdentifier { get; set; } + public class OrganizationUserAcceptedViewModel : BaseMailModel + { + public Guid OrganizationId { get; set; } + public string OrganizationName { get; set; } + public string UserIdentifier { get; set; } + } } diff --git a/src/Core/Models/Mail/OrganizationUserConfirmedViewModel.cs b/src/Core/Models/Mail/OrganizationUserConfirmedViewModel.cs index 61e7107742..e15cf54eed 100644 --- a/src/Core/Models/Mail/OrganizationUserConfirmedViewModel.cs +++ b/src/Core/Models/Mail/OrganizationUserConfirmedViewModel.cs @@ -1,6 +1,7 @@ -namespace Bit.Core.Models.Mail; - -public class OrganizationUserConfirmedViewModel : BaseMailModel +namespace Bit.Core.Models.Mail { - public string OrganizationName { get; set; } + public class OrganizationUserConfirmedViewModel : BaseMailModel + { + public string OrganizationName { get; set; } + } } diff --git a/src/Core/Models/Mail/OrganizationUserInvitedViewModel.cs b/src/Core/Models/Mail/OrganizationUserInvitedViewModel.cs index 4bf9fbb863..0e13fa6639 100644 --- a/src/Core/Models/Mail/OrganizationUserInvitedViewModel.cs +++ b/src/Core/Models/Mail/OrganizationUserInvitedViewModel.cs @@ -1,20 +1,21 @@ -namespace Bit.Core.Models.Mail; - -public class OrganizationUserInvitedViewModel : BaseMailModel +namespace Bit.Core.Models.Mail { - public string OrganizationName { get; set; } - public string OrganizationId { get; set; } - public string OrganizationUserId { get; set; } - public string Email { get; set; } - public string OrganizationNameUrlEncoded { get; set; } - public string Token { get; set; } - public string ExpirationDate { get; set; } - public string Url => string.Format("{0}/accept-organization?organizationId={1}&" + - "organizationUserId={2}&email={3}&organizationName={4}&token={5}", - WebVaultUrl, - OrganizationId, - OrganizationUserId, - Email, - OrganizationNameUrlEncoded, - Token); + public class OrganizationUserInvitedViewModel : BaseMailModel + { + public string OrganizationName { get; set; } + public string OrganizationId { get; set; } + public string OrganizationUserId { get; set; } + public string Email { get; set; } + public string OrganizationNameUrlEncoded { get; set; } + public string Token { get; set; } + public string ExpirationDate { get; set; } + public string Url => string.Format("{0}/accept-organization?organizationId={1}&" + + "organizationUserId={2}&email={3}&organizationName={4}&token={5}", + WebVaultUrl, + OrganizationId, + OrganizationUserId, + Email, + OrganizationNameUrlEncoded, + Token); + } } diff --git a/src/Core/Models/Mail/OrganizationUserRemovedForPolicySingleOrgViewModel.cs b/src/Core/Models/Mail/OrganizationUserRemovedForPolicySingleOrgViewModel.cs index 46020ae46a..9a92f0e0be 100644 --- a/src/Core/Models/Mail/OrganizationUserRemovedForPolicySingleOrgViewModel.cs +++ b/src/Core/Models/Mail/OrganizationUserRemovedForPolicySingleOrgViewModel.cs @@ -1,6 +1,7 @@ -namespace Bit.Core.Models.Mail; - -public class OrganizationUserRemovedForPolicySingleOrgViewModel : BaseMailModel +namespace Bit.Core.Models.Mail { - public string OrganizationName { get; set; } + public class OrganizationUserRemovedForPolicySingleOrgViewModel : BaseMailModel + { + public string OrganizationName { get; set; } + } } diff --git a/src/Core/Models/Mail/OrganizationUserRemovedForPolicyTwoStepViewModel.cs b/src/Core/Models/Mail/OrganizationUserRemovedForPolicyTwoStepViewModel.cs index cd4528ad50..10beaa5d7e 100644 --- a/src/Core/Models/Mail/OrganizationUserRemovedForPolicyTwoStepViewModel.cs +++ b/src/Core/Models/Mail/OrganizationUserRemovedForPolicyTwoStepViewModel.cs @@ -1,6 +1,7 @@ -namespace Bit.Core.Models.Mail; - -public class OrganizationUserRemovedForPolicyTwoStepViewModel : BaseMailModel +namespace Bit.Core.Models.Mail { - public string OrganizationName { get; set; } + public class OrganizationUserRemovedForPolicyTwoStepViewModel : BaseMailModel + { + public string OrganizationName { get; set; } + } } diff --git a/src/Core/Models/Mail/PasswordlessSignInModel.cs b/src/Core/Models/Mail/PasswordlessSignInModel.cs index 07754cf804..a09d5f7b0e 100644 --- a/src/Core/Models/Mail/PasswordlessSignInModel.cs +++ b/src/Core/Models/Mail/PasswordlessSignInModel.cs @@ -1,6 +1,7 @@ -namespace Bit.Core.Models.Mail; - -public class PasswordlessSignInModel +namespace Bit.Core.Models.Mail { - public string Url { get; set; } + public class PasswordlessSignInModel + { + public string Url { get; set; } + } } diff --git a/src/Core/Models/Mail/PaymentFailedViewModel.cs b/src/Core/Models/Mail/PaymentFailedViewModel.cs index 387feeb022..1eb5e6952c 100644 --- a/src/Core/Models/Mail/PaymentFailedViewModel.cs +++ b/src/Core/Models/Mail/PaymentFailedViewModel.cs @@ -1,7 +1,8 @@ -namespace Bit.Core.Models.Mail; - -public class PaymentFailedViewModel : BaseMailModel +namespace Bit.Core.Models.Mail { - public decimal Amount { get; set; } - public bool MentionInvoices { get; set; } + public class PaymentFailedViewModel : BaseMailModel + { + public decimal Amount { get; set; } + public bool MentionInvoices { get; set; } + } } diff --git a/src/Core/Models/Mail/Provider/ProviderSetupInviteViewModel.cs b/src/Core/Models/Mail/Provider/ProviderSetupInviteViewModel.cs index f351a5fe1b..daaba8a49d 100644 --- a/src/Core/Models/Mail/Provider/ProviderSetupInviteViewModel.cs +++ b/src/Core/Models/Mail/Provider/ProviderSetupInviteViewModel.cs @@ -1,13 +1,14 @@ -namespace Bit.Core.Models.Mail.Provider; - -public class ProviderSetupInviteViewModel : BaseMailModel +namespace Bit.Core.Models.Mail.Provider { - public string ProviderId { get; set; } - public string Email { get; set; } - public string Token { get; set; } - public string Url => string.Format("{0}/providers/setup-provider?providerId={1}&email={2}&token={3}", - WebVaultUrl, - ProviderId, - Email, - Token); + public class ProviderSetupInviteViewModel : BaseMailModel + { + public string ProviderId { get; set; } + public string Email { get; set; } + public string Token { get; set; } + public string Url => string.Format("{0}/providers/setup-provider?providerId={1}&email={2}&token={3}", + WebVaultUrl, + ProviderId, + Email, + Token); + } } diff --git a/src/Core/Models/Mail/Provider/ProviderUserConfirmedViewModel.cs b/src/Core/Models/Mail/Provider/ProviderUserConfirmedViewModel.cs index 30d24ad1e9..8a8716c46c 100644 --- a/src/Core/Models/Mail/Provider/ProviderUserConfirmedViewModel.cs +++ b/src/Core/Models/Mail/Provider/ProviderUserConfirmedViewModel.cs @@ -1,6 +1,7 @@ -namespace Bit.Core.Models.Mail.Provider; - -public class ProviderUserConfirmedViewModel : BaseMailModel +namespace Bit.Core.Models.Mail.Provider { - public string ProviderName { get; set; } + public class ProviderUserConfirmedViewModel : BaseMailModel + { + public string ProviderName { get; set; } + } } diff --git a/src/Core/Models/Mail/Provider/ProviderUserInvitedViewModel.cs b/src/Core/Models/Mail/Provider/ProviderUserInvitedViewModel.cs index e418d30f21..964c517593 100644 --- a/src/Core/Models/Mail/Provider/ProviderUserInvitedViewModel.cs +++ b/src/Core/Models/Mail/Provider/ProviderUserInvitedViewModel.cs @@ -1,19 +1,20 @@ -namespace Bit.Core.Models.Mail.Provider; - -public class ProviderUserInvitedViewModel : BaseMailModel +namespace Bit.Core.Models.Mail.Provider { - public string ProviderName { get; set; } - public string ProviderId { get; set; } - public string ProviderUserId { get; set; } - public string Email { get; set; } - public string ProviderNameUrlEncoded { get; set; } - public string Token { get; set; } - public string Url => string.Format("{0}/providers/accept-provider?providerId={1}&" + - "providerUserId={2}&email={3}&providerName={4}&token={5}", - WebVaultUrl, - ProviderId, - ProviderUserId, - Email, - ProviderNameUrlEncoded, - Token); + public class ProviderUserInvitedViewModel : BaseMailModel + { + public string ProviderName { get; set; } + public string ProviderId { get; set; } + public string ProviderUserId { get; set; } + public string Email { get; set; } + public string ProviderNameUrlEncoded { get; set; } + public string Token { get; set; } + public string Url => string.Format("{0}/providers/accept-provider?providerId={1}&" + + "providerUserId={2}&email={3}&providerName={4}&token={5}", + WebVaultUrl, + ProviderId, + ProviderUserId, + Email, + ProviderNameUrlEncoded, + Token); + } } diff --git a/src/Core/Models/Mail/Provider/ProviderUserRemovedViewModel.cs b/src/Core/Models/Mail/Provider/ProviderUserRemovedViewModel.cs index aef9d9c593..4d64ed3d70 100644 --- a/src/Core/Models/Mail/Provider/ProviderUserRemovedViewModel.cs +++ b/src/Core/Models/Mail/Provider/ProviderUserRemovedViewModel.cs @@ -1,6 +1,7 @@ -namespace Bit.Core.Models.Mail.Provider; - -public class ProviderUserRemovedViewModel : BaseMailModel +namespace Bit.Core.Models.Mail.Provider { - public string ProviderName { get; set; } + public class ProviderUserRemovedViewModel : BaseMailModel + { + public string ProviderName { get; set; } + } } diff --git a/src/Core/Models/Mail/RecoverTwoFactorModel.cs b/src/Core/Models/Mail/RecoverTwoFactorModel.cs index b62f076711..f9b8cb5d45 100644 --- a/src/Core/Models/Mail/RecoverTwoFactorModel.cs +++ b/src/Core/Models/Mail/RecoverTwoFactorModel.cs @@ -1,9 +1,10 @@ -namespace Bit.Core.Models.Mail; - -public class RecoverTwoFactorModel : BaseMailModel +namespace Bit.Core.Models.Mail { - public string TheDate { get; set; } - public string TheTime { get; set; } - public string TimeZone { get; set; } - public string IpAddress { get; set; } + public class RecoverTwoFactorModel : BaseMailModel + { + public string TheDate { get; set; } + public string TheTime { get; set; } + public string TimeZone { get; set; } + public string IpAddress { get; set; } + } } diff --git a/src/Core/Models/Mail/UpdateTempPasswordViewModel.cs b/src/Core/Models/Mail/UpdateTempPasswordViewModel.cs index 6e45df5305..ed35d3e97b 100644 --- a/src/Core/Models/Mail/UpdateTempPasswordViewModel.cs +++ b/src/Core/Models/Mail/UpdateTempPasswordViewModel.cs @@ -1,6 +1,7 @@ -namespace Bit.Core.Models.Mail; - -public class UpdateTempPasswordViewModel +namespace Bit.Core.Models.Mail { - public string UserName { get; set; } + public class UpdateTempPasswordViewModel + { + public string UserName { get; set; } + } } diff --git a/src/Core/Models/Mail/VerifyDeleteModel.cs b/src/Core/Models/Mail/VerifyDeleteModel.cs index 22775aae15..dbe7199818 100644 --- a/src/Core/Models/Mail/VerifyDeleteModel.cs +++ b/src/Core/Models/Mail/VerifyDeleteModel.cs @@ -1,15 +1,16 @@ -namespace Bit.Core.Models.Mail; - -public class VerifyDeleteModel : BaseMailModel +namespace Bit.Core.Models.Mail { - public string Url => string.Format("{0}/verify-recover-delete?userId={1}&token={2}&email={3}", - WebVaultUrl, - UserId, - Token, - EmailEncoded); + public class VerifyDeleteModel : BaseMailModel + { + public string Url => string.Format("{0}/verify-recover-delete?userId={1}&token={2}&email={3}", + WebVaultUrl, + UserId, + Token, + EmailEncoded); - public Guid UserId { get; set; } - public string Email { get; set; } - public string EmailEncoded { get; set; } - public string Token { get; set; } + public Guid UserId { get; set; } + public string Email { get; set; } + public string EmailEncoded { get; set; } + public string Token { get; set; } + } } diff --git a/src/Core/Models/Mail/VerifyEmailModel.cs b/src/Core/Models/Mail/VerifyEmailModel.cs index 17b2eba864..934ac590f9 100644 --- a/src/Core/Models/Mail/VerifyEmailModel.cs +++ b/src/Core/Models/Mail/VerifyEmailModel.cs @@ -1,12 +1,13 @@ -namespace Bit.Core.Models.Mail; - -public class VerifyEmailModel : BaseMailModel +namespace Bit.Core.Models.Mail { - public string Url => string.Format("{0}/verify-email?userId={1}&token={2}", - WebVaultUrl, - UserId, - Token); + public class VerifyEmailModel : BaseMailModel + { + public string Url => string.Format("{0}/verify-email?userId={1}&token={2}", + WebVaultUrl, + UserId, + Token); - public Guid UserId { get; set; } - public string Token { get; set; } + public Guid UserId { get; set; } + public string Token { get; set; } + } } diff --git a/src/Core/Models/OrganizationConnectionConfigs/BillingSyncConfig.cs b/src/Core/Models/OrganizationConnectionConfigs/BillingSyncConfig.cs index 204e165d05..8b46eb831a 100644 --- a/src/Core/Models/OrganizationConnectionConfigs/BillingSyncConfig.cs +++ b/src/Core/Models/OrganizationConnectionConfigs/BillingSyncConfig.cs @@ -1,7 +1,8 @@ -namespace Bit.Core.Models.OrganizationConnectionConfigs; - -public class BillingSyncConfig +namespace Bit.Core.Models.OrganizationConnectionConfigs { - public string BillingSyncKey { get; set; } - public Guid CloudOrganizationId { get; set; } + public class BillingSyncConfig + { + public string BillingSyncKey { get; set; } + public Guid CloudOrganizationId { get; set; } + } } diff --git a/src/Core/Models/OrganizationConnectionConfigs/ScimConfig.cs b/src/Core/Models/OrganizationConnectionConfigs/ScimConfig.cs index 63a1606cb2..a7eeb632ba 100644 --- a/src/Core/Models/OrganizationConnectionConfigs/ScimConfig.cs +++ b/src/Core/Models/OrganizationConnectionConfigs/ScimConfig.cs @@ -1,11 +1,12 @@ using System.Text.Json.Serialization; using Bit.Core.Enums; -namespace Bit.Core.Models.OrganizationConnectionConfigs; - -public class ScimConfig +namespace Bit.Core.Models.OrganizationConnectionConfigs { - public bool Enabled { get; set; } - [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] - public ScimProviderType? ScimProvider { get; set; } + public class ScimConfig + { + public bool Enabled { get; set; } + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public ScimProviderType? ScimProvider { get; set; } + } } diff --git a/src/Core/Models/PushNotification.cs b/src/Core/Models/PushNotification.cs index 4cbdae8b61..7f34b25625 100644 --- a/src/Core/Models/PushNotification.cs +++ b/src/Core/Models/PushNotification.cs @@ -1,46 +1,47 @@ using Bit.Core.Enums; -namespace Bit.Core.Models; - -public class PushNotificationData +namespace Bit.Core.Models { - public PushNotificationData(PushType type, T payload, string contextId) + public class PushNotificationData { - Type = type; - Payload = payload; - ContextId = contextId; + public PushNotificationData(PushType type, T payload, string contextId) + { + Type = type; + Payload = payload; + ContextId = contextId; + } + + public PushType Type { get; set; } + public T Payload { get; set; } + public string ContextId { get; set; } } - public PushType Type { get; set; } - public T Payload { get; set; } - public string ContextId { get; set; } -} + public class SyncCipherPushNotification + { + public Guid Id { get; set; } + public Guid? UserId { get; set; } + public Guid? OrganizationId { get; set; } + public IEnumerable CollectionIds { get; set; } + public DateTime RevisionDate { get; set; } + } -public class SyncCipherPushNotification -{ - public Guid Id { get; set; } - public Guid? UserId { get; set; } - public Guid? OrganizationId { get; set; } - public IEnumerable CollectionIds { get; set; } - public DateTime RevisionDate { get; set; } -} + public class SyncFolderPushNotification + { + public Guid Id { get; set; } + public Guid UserId { get; set; } + public DateTime RevisionDate { get; set; } + } -public class SyncFolderPushNotification -{ - public Guid Id { get; set; } - public Guid UserId { get; set; } - public DateTime RevisionDate { get; set; } -} + public class UserPushNotification + { + public Guid UserId { get; set; } + public DateTime Date { get; set; } + } -public class UserPushNotification -{ - public Guid UserId { get; set; } - public DateTime Date { get; set; } -} - -public class SyncSendPushNotification -{ - public Guid Id { get; set; } - public Guid UserId { get; set; } - public DateTime RevisionDate { get; set; } + public class SyncSendPushNotification + { + public Guid Id { get; set; } + public Guid UserId { get; set; } + public DateTime RevisionDate { get; set; } + } } diff --git a/src/Core/Models/StaticStore/Plan.cs b/src/Core/Models/StaticStore/Plan.cs index 25a947d185..98686f5ab4 100644 --- a/src/Core/Models/StaticStore/Plan.cs +++ b/src/Core/Models/StaticStore/Plan.cs @@ -1,54 +1,55 @@ using Bit.Core.Enums; -namespace Bit.Core.Models.StaticStore; - -public class Plan +namespace Bit.Core.Models.StaticStore { - public PlanType Type { get; set; } - public ProductType Product { get; set; } - public string Name { get; set; } - public bool IsAnnual { get; set; } - public string NameLocalizationKey { get; set; } - public string DescriptionLocalizationKey { get; set; } - public bool CanBeUsedByBusiness { get; set; } - public int BaseSeats { get; set; } - public short? BaseStorageGb { get; set; } - public short? MaxCollections { get; set; } - public short? MaxUsers { get; set; } - public bool AllowSeatAutoscale { get; set; } + public class Plan + { + public PlanType Type { get; set; } + public ProductType Product { get; set; } + public string Name { get; set; } + public bool IsAnnual { get; set; } + public string NameLocalizationKey { get; set; } + public string DescriptionLocalizationKey { get; set; } + public bool CanBeUsedByBusiness { get; set; } + public int BaseSeats { get; set; } + public short? BaseStorageGb { get; set; } + public short? MaxCollections { get; set; } + public short? MaxUsers { get; set; } + public bool AllowSeatAutoscale { get; set; } - public bool HasAdditionalSeatsOption { get; set; } - public int? MaxAdditionalSeats { get; set; } - public bool HasAdditionalStorageOption { get; set; } - public short? MaxAdditionalStorage { get; set; } - public bool HasPremiumAccessOption { get; set; } - public int? TrialPeriodDays { get; set; } + public bool HasAdditionalSeatsOption { get; set; } + public int? MaxAdditionalSeats { get; set; } + public bool HasAdditionalStorageOption { get; set; } + public short? MaxAdditionalStorage { get; set; } + public bool HasPremiumAccessOption { get; set; } + public int? TrialPeriodDays { get; set; } - public bool HasSelfHost { get; set; } - public bool HasPolicies { get; set; } - public bool HasGroups { get; set; } - public bool HasDirectory { get; set; } - public bool HasEvents { get; set; } - public bool HasTotp { get; set; } - public bool Has2fa { get; set; } - public bool HasApi { get; set; } - public bool HasSso { get; set; } - public bool HasKeyConnector { get; set; } - public bool HasScim { get; set; } - public bool HasResetPassword { get; set; } - public bool UsersGetPremium { get; set; } + public bool HasSelfHost { get; set; } + public bool HasPolicies { get; set; } + public bool HasGroups { get; set; } + public bool HasDirectory { get; set; } + public bool HasEvents { get; set; } + public bool HasTotp { get; set; } + public bool Has2fa { get; set; } + public bool HasApi { get; set; } + public bool HasSso { get; set; } + public bool HasKeyConnector { get; set; } + public bool HasScim { get; set; } + public bool HasResetPassword { get; set; } + public bool UsersGetPremium { get; set; } - public int UpgradeSortOrder { get; set; } - public int DisplaySortOrder { get; set; } - public int? LegacyYear { get; set; } - public bool Disabled { get; set; } + public int UpgradeSortOrder { get; set; } + public int DisplaySortOrder { get; set; } + public int? LegacyYear { get; set; } + public bool Disabled { get; set; } - public string StripePlanId { get; set; } - public string StripeSeatPlanId { get; set; } - public string StripeStoragePlanId { get; set; } - public string StripePremiumAccessPlanId { get; set; } - public decimal BasePrice { get; set; } - public decimal SeatPrice { get; set; } - public decimal AdditionalStoragePricePerGb { get; set; } - public decimal PremiumAccessOptionPrice { get; set; } + public string StripePlanId { get; set; } + public string StripeSeatPlanId { get; set; } + public string StripeStoragePlanId { get; set; } + public string StripePremiumAccessPlanId { get; set; } + public decimal BasePrice { get; set; } + public decimal SeatPrice { get; set; } + public decimal AdditionalStoragePricePerGb { get; set; } + public decimal PremiumAccessOptionPrice { get; set; } + } } diff --git a/src/Core/Models/StaticStore/SponsoredPlan.cs b/src/Core/Models/StaticStore/SponsoredPlan.cs index bcd23874a7..e1a8dbd96a 100644 --- a/src/Core/Models/StaticStore/SponsoredPlan.cs +++ b/src/Core/Models/StaticStore/SponsoredPlan.cs @@ -1,13 +1,14 @@ using Bit.Core.Enums; using Bit.Core.Models.Data.Organizations.OrganizationUsers; -namespace Bit.Core.Models.StaticStore; - -public class SponsoredPlan +namespace Bit.Core.Models.StaticStore { - public PlanSponsorshipType PlanSponsorshipType { get; set; } - public ProductType SponsoredProductType { get; set; } - public ProductType SponsoringProductType { get; set; } - public string StripePlanId { get; set; } - public Func UsersCanSponsor { get; set; } + public class SponsoredPlan + { + public PlanSponsorshipType PlanSponsorshipType { get; set; } + public ProductType SponsoredProductType { get; set; } + public ProductType SponsoringProductType { get; set; } + public string StripePlanId { get; set; } + public Func UsersCanSponsor { get; set; } + } } diff --git a/src/Core/Models/Stripe/StripeSubscriptionListOptions.cs b/src/Core/Models/Stripe/StripeSubscriptionListOptions.cs index f32576c407..1672dc407c 100644 --- a/src/Core/Models/Stripe/StripeSubscriptionListOptions.cs +++ b/src/Core/Models/Stripe/StripeSubscriptionListOptions.cs @@ -1,48 +1,49 @@ -namespace Bit.Core.Models.BitStripe; - -// Stripe's SubscriptionListOptions model has a complex input for date filters. -// It expects a dictionary, and has lots of validation rules around what can have a value and what can't. -// To simplify this a bit we are extending Stripe's model and using our own date inputs, and building the dictionary they expect JiT. -// ___ -// Our model also facilitates selecting all elements in a list, which is unsupported by Stripe's model. -public class StripeSubscriptionListOptions : Stripe.SubscriptionListOptions +namespace Bit.Core.Models.BitStripe { - public DateTime? CurrentPeriodEndDate { get; set; } - public string CurrentPeriodEndRange { get; set; } = "lt"; - public bool SelectAll { get; set; } - public new Stripe.DateRangeOptions CurrentPeriodEnd + // Stripe's SubscriptionListOptions model has a complex input for date filters. + // It expects a dictionary, and has lots of validation rules around what can have a value and what can't. + // To simplify this a bit we are extending Stripe's model and using our own date inputs, and building the dictionary they expect JiT. + // ___ + // Our model also facilitates selecting all elements in a list, which is unsupported by Stripe's model. + public class StripeSubscriptionListOptions : Stripe.SubscriptionListOptions { - get + public DateTime? CurrentPeriodEndDate { get; set; } + public string CurrentPeriodEndRange { get; set; } = "lt"; + public bool SelectAll { get; set; } + public new Stripe.DateRangeOptions CurrentPeriodEnd { - return CurrentPeriodEndDate.HasValue ? - new Stripe.DateRangeOptions() + get + { + return CurrentPeriodEndDate.HasValue ? + new Stripe.DateRangeOptions() + { + LessThan = CurrentPeriodEndRange == "lt" ? CurrentPeriodEndDate : null, + GreaterThan = CurrentPeriodEndRange == "gt" ? CurrentPeriodEndDate : null + } : + null; + } + } + + public Stripe.SubscriptionListOptions ToStripeApiOptions() + { + var stripeApiOptions = (Stripe.SubscriptionListOptions)this; + + if (SelectAll) + { + stripeApiOptions.EndingBefore = null; + stripeApiOptions.StartingAfter = null; + } + + if (CurrentPeriodEndDate.HasValue) + { + stripeApiOptions.CurrentPeriodEnd = new Stripe.DateRangeOptions() { LessThan = CurrentPeriodEndRange == "lt" ? CurrentPeriodEndDate : null, GreaterThan = CurrentPeriodEndRange == "gt" ? CurrentPeriodEndDate : null - } : - null; + }; + } + + return stripeApiOptions; } } - - public Stripe.SubscriptionListOptions ToStripeApiOptions() - { - var stripeApiOptions = (Stripe.SubscriptionListOptions)this; - - if (SelectAll) - { - stripeApiOptions.EndingBefore = null; - stripeApiOptions.StartingAfter = null; - } - - if (CurrentPeriodEndDate.HasValue) - { - stripeApiOptions.CurrentPeriodEnd = new Stripe.DateRangeOptions() - { - LessThan = CurrentPeriodEndRange == "lt" ? CurrentPeriodEndDate : null, - GreaterThan = CurrentPeriodEndRange == "gt" ? CurrentPeriodEndDate : null - }; - } - - return stripeApiOptions; - } } diff --git a/src/Core/Models/TwoFactorProvider.cs b/src/Core/Models/TwoFactorProvider.cs index 0ff791ff8d..7e48ed397a 100644 --- a/src/Core/Models/TwoFactorProvider.cs +++ b/src/Core/Models/TwoFactorProvider.cs @@ -2,65 +2,66 @@ using Bit.Core.Enums; using Fido2NetLib.Objects; -namespace Bit.Core.Models; - -public class TwoFactorProvider +namespace Bit.Core.Models { - public bool Enabled { get; set; } - public Dictionary MetaData { get; set; } = new Dictionary(); - - public class WebAuthnData + public class TwoFactorProvider { - public WebAuthnData() { } + public bool Enabled { get; set; } + public Dictionary MetaData { get; set; } = new Dictionary(); - public WebAuthnData(dynamic o) + public class WebAuthnData { - Name = o.Name; - try + public WebAuthnData() { } + + public WebAuthnData(dynamic o) { - Descriptor = o.Descriptor; - } - catch - { - // Fallback for older newtonsoft serialized tokens. - if (o.Descriptor.Type == 0) + Name = o.Name; + try { - o.Descriptor.Type = "public-key"; + Descriptor = o.Descriptor; } - Descriptor = JsonSerializer.Deserialize(o.Descriptor.ToString(), - new JsonSerializerOptions { PropertyNameCaseInsensitive = true }); + catch + { + // Fallback for older newtonsoft serialized tokens. + if (o.Descriptor.Type == 0) + { + o.Descriptor.Type = "public-key"; + } + Descriptor = JsonSerializer.Deserialize(o.Descriptor.ToString(), + new JsonSerializerOptions { PropertyNameCaseInsensitive = true }); + } + PublicKey = o.PublicKey; + UserHandle = o.UserHandle; + SignatureCounter = o.SignatureCounter; + CredType = o.CredType; + RegDate = o.RegDate; + AaGuid = o.AaGuid; + Migrated = o.Migrated; } - PublicKey = o.PublicKey; - UserHandle = o.UserHandle; - SignatureCounter = o.SignatureCounter; - CredType = o.CredType; - RegDate = o.RegDate; - AaGuid = o.AaGuid; - Migrated = o.Migrated; + + public string Name { get; set; } + public PublicKeyCredentialDescriptor Descriptor { get; internal set; } + public byte[] PublicKey { get; internal set; } + public byte[] UserHandle { get; internal set; } + public uint SignatureCounter { get; set; } + public string CredType { get; internal set; } + public DateTime RegDate { get; internal set; } + public Guid AaGuid { get; internal set; } + public bool Migrated { get; internal set; } } - public string Name { get; set; } - public PublicKeyCredentialDescriptor Descriptor { get; internal set; } - public byte[] PublicKey { get; internal set; } - public byte[] UserHandle { get; internal set; } - public uint SignatureCounter { get; set; } - public string CredType { get; internal set; } - public DateTime RegDate { get; internal set; } - public Guid AaGuid { get; internal set; } - public bool Migrated { get; internal set; } - } - - public static bool RequiresPremium(TwoFactorProviderType type) - { - switch (type) + public static bool RequiresPremium(TwoFactorProviderType type) { - case TwoFactorProviderType.Duo: - case TwoFactorProviderType.YubiKey: - case TwoFactorProviderType.U2f: // Keep to ensure old U2f keys are considered premium - case TwoFactorProviderType.WebAuthn: - return true; - default: - return false; + switch (type) + { + case TwoFactorProviderType.Duo: + case TwoFactorProviderType.YubiKey: + case TwoFactorProviderType.U2f: // Keep to ensure old U2f keys are considered premium + case TwoFactorProviderType.WebAuthn: + return true; + default: + return false; + } } } } diff --git a/src/Core/OrganizationFeatures/OrganizationApiKeys/GetOrganizationApiKeyCommand.cs b/src/Core/OrganizationFeatures/OrganizationApiKeys/GetOrganizationApiKeyCommand.cs index 1a01562417..93b4d15bda 100644 --- a/src/Core/OrganizationFeatures/OrganizationApiKeys/GetOrganizationApiKeyCommand.cs +++ b/src/Core/OrganizationFeatures/OrganizationApiKeys/GetOrganizationApiKeyCommand.cs @@ -4,42 +4,43 @@ using Bit.Core.OrganizationFeatures.OrganizationApiKeys.Interfaces; using Bit.Core.Repositories; using Bit.Core.Utilities; -namespace Bit.Core.OrganizationFeatures.OrganizationApiKeys; - -public class GetOrganizationApiKeyCommand : IGetOrganizationApiKeyCommand +namespace Bit.Core.OrganizationFeatures.OrganizationApiKeys { - private readonly IOrganizationApiKeyRepository _organizationApiKeyRepository; - - public GetOrganizationApiKeyCommand(IOrganizationApiKeyRepository organizationApiKeyRepository) + public class GetOrganizationApiKeyCommand : IGetOrganizationApiKeyCommand { - _organizationApiKeyRepository = organizationApiKeyRepository; - } + private readonly IOrganizationApiKeyRepository _organizationApiKeyRepository; - public async Task GetOrganizationApiKeyAsync(Guid organizationId, OrganizationApiKeyType organizationApiKeyType) - { - if (!Enum.IsDefined(organizationApiKeyType)) + public GetOrganizationApiKeyCommand(IOrganizationApiKeyRepository organizationApiKeyRepository) { - throw new ArgumentOutOfRangeException(nameof(organizationApiKeyType), $"Invalid value for enum {nameof(OrganizationApiKeyType)}"); + _organizationApiKeyRepository = organizationApiKeyRepository; } - var apiKeys = await _organizationApiKeyRepository - .GetManyByOrganizationIdTypeAsync(organizationId, organizationApiKeyType); - - if (apiKeys == null || !apiKeys.Any()) + public async Task GetOrganizationApiKeyAsync(Guid organizationId, OrganizationApiKeyType organizationApiKeyType) { - var apiKey = new OrganizationApiKey + if (!Enum.IsDefined(organizationApiKeyType)) { - OrganizationId = organizationId, - Type = organizationApiKeyType, - ApiKey = CoreHelpers.SecureRandomString(30), - RevisionDate = DateTime.UtcNow, - }; + throw new ArgumentOutOfRangeException(nameof(organizationApiKeyType), $"Invalid value for enum {nameof(OrganizationApiKeyType)}"); + } - await _organizationApiKeyRepository.CreateAsync(apiKey); - return apiKey; + var apiKeys = await _organizationApiKeyRepository + .GetManyByOrganizationIdTypeAsync(organizationId, organizationApiKeyType); + + if (apiKeys == null || !apiKeys.Any()) + { + var apiKey = new OrganizationApiKey + { + OrganizationId = organizationId, + Type = organizationApiKeyType, + ApiKey = CoreHelpers.SecureRandomString(30), + RevisionDate = DateTime.UtcNow, + }; + + await _organizationApiKeyRepository.CreateAsync(apiKey); + return apiKey; + } + + // NOTE: Currently we only allow one type of api key per organization + return apiKeys.Single(); } - - // NOTE: Currently we only allow one type of api key per organization - return apiKeys.Single(); } } diff --git a/src/Core/OrganizationFeatures/OrganizationApiKeys/Interfaces/IGetOrganizationApiKeyCommand.cs b/src/Core/OrganizationFeatures/OrganizationApiKeys/Interfaces/IGetOrganizationApiKeyCommand.cs index 5fcfdedd99..645fb10862 100644 --- a/src/Core/OrganizationFeatures/OrganizationApiKeys/Interfaces/IGetOrganizationApiKeyCommand.cs +++ b/src/Core/OrganizationFeatures/OrganizationApiKeys/Interfaces/IGetOrganizationApiKeyCommand.cs @@ -1,9 +1,10 @@ using Bit.Core.Entities; using Bit.Core.Enums; -namespace Bit.Core.OrganizationFeatures.OrganizationApiKeys.Interfaces; - -public interface IGetOrganizationApiKeyCommand +namespace Bit.Core.OrganizationFeatures.OrganizationApiKeys.Interfaces { - Task GetOrganizationApiKeyAsync(Guid organizationId, OrganizationApiKeyType organizationApiKeyType); + public interface IGetOrganizationApiKeyCommand + { + Task GetOrganizationApiKeyAsync(Guid organizationId, OrganizationApiKeyType organizationApiKeyType); + } } diff --git a/src/Core/OrganizationFeatures/OrganizationApiKeys/Interfaces/IRotateOrganizationApiKeyCommand.cs b/src/Core/OrganizationFeatures/OrganizationApiKeys/Interfaces/IRotateOrganizationApiKeyCommand.cs index a5cf51c3fe..85d0479870 100644 --- a/src/Core/OrganizationFeatures/OrganizationApiKeys/Interfaces/IRotateOrganizationApiKeyCommand.cs +++ b/src/Core/OrganizationFeatures/OrganizationApiKeys/Interfaces/IRotateOrganizationApiKeyCommand.cs @@ -1,8 +1,9 @@ using Bit.Core.Entities; -namespace Bit.Core.OrganizationFeatures.OrganizationApiKeys.Interfaces; - -public interface IRotateOrganizationApiKeyCommand +namespace Bit.Core.OrganizationFeatures.OrganizationApiKeys.Interfaces { - Task RotateApiKeyAsync(OrganizationApiKey organizationApiKey); + public interface IRotateOrganizationApiKeyCommand + { + Task RotateApiKeyAsync(OrganizationApiKey organizationApiKey); + } } diff --git a/src/Core/OrganizationFeatures/OrganizationApiKeys/RotateOrganizationApiKeyCommand.cs b/src/Core/OrganizationFeatures/OrganizationApiKeys/RotateOrganizationApiKeyCommand.cs index f43aaa5f34..967f399471 100644 --- a/src/Core/OrganizationFeatures/OrganizationApiKeys/RotateOrganizationApiKeyCommand.cs +++ b/src/Core/OrganizationFeatures/OrganizationApiKeys/RotateOrganizationApiKeyCommand.cs @@ -3,22 +3,23 @@ using Bit.Core.OrganizationFeatures.OrganizationApiKeys.Interfaces; using Bit.Core.Repositories; using Bit.Core.Utilities; -namespace Bit.Core.OrganizationFeatures.OrganizationApiKeys; - -public class RotateOrganizationApiKeyCommand : IRotateOrganizationApiKeyCommand +namespace Bit.Core.OrganizationFeatures.OrganizationApiKeys { - private readonly IOrganizationApiKeyRepository _organizationApiKeyRepository; - - public RotateOrganizationApiKeyCommand(IOrganizationApiKeyRepository organizationApiKeyRepository) + public class RotateOrganizationApiKeyCommand : IRotateOrganizationApiKeyCommand { - _organizationApiKeyRepository = organizationApiKeyRepository; - } + private readonly IOrganizationApiKeyRepository _organizationApiKeyRepository; - public async Task RotateApiKeyAsync(OrganizationApiKey organizationApiKey) - { - organizationApiKey.ApiKey = CoreHelpers.SecureRandomString(30); - organizationApiKey.RevisionDate = DateTime.UtcNow; - await _organizationApiKeyRepository.UpsertAsync(organizationApiKey); - return organizationApiKey; + public RotateOrganizationApiKeyCommand(IOrganizationApiKeyRepository organizationApiKeyRepository) + { + _organizationApiKeyRepository = organizationApiKeyRepository; + } + + public async Task RotateApiKeyAsync(OrganizationApiKey organizationApiKey) + { + organizationApiKey.ApiKey = CoreHelpers.SecureRandomString(30); + organizationApiKey.RevisionDate = DateTime.UtcNow; + await _organizationApiKeyRepository.UpsertAsync(organizationApiKey); + return organizationApiKey; + } } } diff --git a/src/Core/OrganizationFeatures/OrganizationConnections/CreateOrganizationConnectionCommand.cs b/src/Core/OrganizationFeatures/OrganizationConnections/CreateOrganizationConnectionCommand.cs index e3f308bc57..c54ef5dfc8 100644 --- a/src/Core/OrganizationFeatures/OrganizationConnections/CreateOrganizationConnectionCommand.cs +++ b/src/Core/OrganizationFeatures/OrganizationConnections/CreateOrganizationConnectionCommand.cs @@ -3,19 +3,20 @@ using Bit.Core.Models.Data.Organizations.OrganizationConnections; using Bit.Core.OrganizationFeatures.OrganizationConnections.Interfaces; using Bit.Core.Repositories; -namespace Bit.Core.OrganizationFeatures.OrganizationConnections; - -public class CreateOrganizationConnectionCommand : ICreateOrganizationConnectionCommand +namespace Bit.Core.OrganizationFeatures.OrganizationConnections { - private readonly IOrganizationConnectionRepository _organizationConnectionRepository; - - public CreateOrganizationConnectionCommand(IOrganizationConnectionRepository organizationConnectionRepository) + public class CreateOrganizationConnectionCommand : ICreateOrganizationConnectionCommand { - _organizationConnectionRepository = organizationConnectionRepository; - } + private readonly IOrganizationConnectionRepository _organizationConnectionRepository; - public async Task CreateAsync(OrganizationConnectionData connectionData) where T : new() - { - return await _organizationConnectionRepository.CreateAsync(connectionData.ToEntity()); + public CreateOrganizationConnectionCommand(IOrganizationConnectionRepository organizationConnectionRepository) + { + _organizationConnectionRepository = organizationConnectionRepository; + } + + public async Task CreateAsync(OrganizationConnectionData connectionData) where T : new() + { + return await _organizationConnectionRepository.CreateAsync(connectionData.ToEntity()); + } } } diff --git a/src/Core/OrganizationFeatures/OrganizationConnections/DeleteOrganizationConnectionCommand.cs b/src/Core/OrganizationFeatures/OrganizationConnections/DeleteOrganizationConnectionCommand.cs index 7166059db5..784975780f 100644 --- a/src/Core/OrganizationFeatures/OrganizationConnections/DeleteOrganizationConnectionCommand.cs +++ b/src/Core/OrganizationFeatures/OrganizationConnections/DeleteOrganizationConnectionCommand.cs @@ -2,19 +2,20 @@ using Bit.Core.OrganizationFeatures.OrganizationConnections.Interfaces; using Bit.Core.Repositories; -namespace Bit.Core.OrganizationFeatures.OrganizationConnections; - -public class DeleteOrganizationConnectionCommand : IDeleteOrganizationConnectionCommand +namespace Bit.Core.OrganizationFeatures.OrganizationConnections { - private readonly IOrganizationConnectionRepository _organizationConnectionRepository; - - public DeleteOrganizationConnectionCommand(IOrganizationConnectionRepository organizationConnectionRepository) + public class DeleteOrganizationConnectionCommand : IDeleteOrganizationConnectionCommand { - _organizationConnectionRepository = organizationConnectionRepository; - } + private readonly IOrganizationConnectionRepository _organizationConnectionRepository; - public async Task DeleteAsync(OrganizationConnection connection) - { - await _organizationConnectionRepository.DeleteAsync(connection); + public DeleteOrganizationConnectionCommand(IOrganizationConnectionRepository organizationConnectionRepository) + { + _organizationConnectionRepository = organizationConnectionRepository; + } + + public async Task DeleteAsync(OrganizationConnection connection) + { + await _organizationConnectionRepository.DeleteAsync(connection); + } } } diff --git a/src/Core/OrganizationFeatures/OrganizationConnections/Interfaces/ICreateOrganizationConnectionCommand.cs b/src/Core/OrganizationFeatures/OrganizationConnections/Interfaces/ICreateOrganizationConnectionCommand.cs index b31920b10a..c91985d75d 100644 --- a/src/Core/OrganizationFeatures/OrganizationConnections/Interfaces/ICreateOrganizationConnectionCommand.cs +++ b/src/Core/OrganizationFeatures/OrganizationConnections/Interfaces/ICreateOrganizationConnectionCommand.cs @@ -1,9 +1,10 @@ using Bit.Core.Entities; using Bit.Core.Models.Data.Organizations.OrganizationConnections; -namespace Bit.Core.OrganizationFeatures.OrganizationConnections.Interfaces; - -public interface ICreateOrganizationConnectionCommand +namespace Bit.Core.OrganizationFeatures.OrganizationConnections.Interfaces { - Task CreateAsync(OrganizationConnectionData connectionData) where T : new(); + public interface ICreateOrganizationConnectionCommand + { + Task CreateAsync(OrganizationConnectionData connectionData) where T : new(); + } } diff --git a/src/Core/OrganizationFeatures/OrganizationConnections/Interfaces/IDeleteOrganizationConnectionCommand.cs b/src/Core/OrganizationFeatures/OrganizationConnections/Interfaces/IDeleteOrganizationConnectionCommand.cs index 818609aef2..1b92a9fcf4 100644 --- a/src/Core/OrganizationFeatures/OrganizationConnections/Interfaces/IDeleteOrganizationConnectionCommand.cs +++ b/src/Core/OrganizationFeatures/OrganizationConnections/Interfaces/IDeleteOrganizationConnectionCommand.cs @@ -1,8 +1,9 @@ using Bit.Core.Entities; -namespace Bit.Core.OrganizationFeatures.OrganizationConnections.Interfaces; - -public interface IDeleteOrganizationConnectionCommand +namespace Bit.Core.OrganizationFeatures.OrganizationConnections.Interfaces { - Task DeleteAsync(OrganizationConnection connection); + public interface IDeleteOrganizationConnectionCommand + { + Task DeleteAsync(OrganizationConnection connection); + } } diff --git a/src/Core/OrganizationFeatures/OrganizationConnections/Interfaces/IUpdateOrganizationConnectionCommand.cs b/src/Core/OrganizationFeatures/OrganizationConnections/Interfaces/IUpdateOrganizationConnectionCommand.cs index 742e89c970..d01fd0b9a0 100644 --- a/src/Core/OrganizationFeatures/OrganizationConnections/Interfaces/IUpdateOrganizationConnectionCommand.cs +++ b/src/Core/OrganizationFeatures/OrganizationConnections/Interfaces/IUpdateOrganizationConnectionCommand.cs @@ -1,9 +1,10 @@ using Bit.Core.Entities; using Bit.Core.Models.Data.Organizations.OrganizationConnections; -namespace Bit.Core.OrganizationFeatures.OrganizationConnections.Interfaces; - -public interface IUpdateOrganizationConnectionCommand +namespace Bit.Core.OrganizationFeatures.OrganizationConnections.Interfaces { - Task UpdateAsync(OrganizationConnectionData connectionData) where T : new(); + public interface IUpdateOrganizationConnectionCommand + { + Task UpdateAsync(OrganizationConnectionData connectionData) where T : new(); + } } diff --git a/src/Core/OrganizationFeatures/OrganizationConnections/UpdateOrganizationConnectionCommand.cs b/src/Core/OrganizationFeatures/OrganizationConnections/UpdateOrganizationConnectionCommand.cs index 0d872b6f1f..74aa08bd35 100644 --- a/src/Core/OrganizationFeatures/OrganizationConnections/UpdateOrganizationConnectionCommand.cs +++ b/src/Core/OrganizationFeatures/OrganizationConnections/UpdateOrganizationConnectionCommand.cs @@ -4,33 +4,34 @@ using Bit.Core.Models.Data.Organizations.OrganizationConnections; using Bit.Core.OrganizationFeatures.OrganizationConnections.Interfaces; using Bit.Core.Repositories; -namespace Bit.Core.OrganizationFeatures.OrganizationConnections; - -public class UpdateOrganizationConnectionCommand : IUpdateOrganizationConnectionCommand +namespace Bit.Core.OrganizationFeatures.OrganizationConnections { - private readonly IOrganizationConnectionRepository _organizationConnectionRepository; - - public UpdateOrganizationConnectionCommand(IOrganizationConnectionRepository organizationConnectionRepository) + public class UpdateOrganizationConnectionCommand : IUpdateOrganizationConnectionCommand { - _organizationConnectionRepository = organizationConnectionRepository; - } + private readonly IOrganizationConnectionRepository _organizationConnectionRepository; - public async Task UpdateAsync(OrganizationConnectionData connectionData) where T : new() - { - if (!connectionData.Id.HasValue) + public UpdateOrganizationConnectionCommand(IOrganizationConnectionRepository organizationConnectionRepository) { - throw new Exception("Cannot update connection, Connection does not exist."); + _organizationConnectionRepository = organizationConnectionRepository; } - var connection = await _organizationConnectionRepository.GetByIdAsync(connectionData.Id.Value); - - if (connection == null) + public async Task UpdateAsync(OrganizationConnectionData connectionData) where T : new() { - throw new NotFoundException(); - } + if (!connectionData.Id.HasValue) + { + throw new Exception("Cannot update connection, Connection does not exist."); + } - var entity = connectionData.ToEntity(); - await _organizationConnectionRepository.UpsertAsync(entity); - return entity; + var connection = await _organizationConnectionRepository.GetByIdAsync(connectionData.Id.Value); + + if (connection == null) + { + throw new NotFoundException(); + } + + var entity = connectionData.ToEntity(); + await _organizationConnectionRepository.UpsertAsync(entity); + return entity; + } } } diff --git a/src/Core/OrganizationFeatures/OrganizationServiceCollectionExtensions.cs b/src/Core/OrganizationFeatures/OrganizationServiceCollectionExtensions.cs index e428318c50..94e59ab31b 100644 --- a/src/Core/OrganizationFeatures/OrganizationServiceCollectionExtensions.cs +++ b/src/Core/OrganizationFeatures/OrganizationServiceCollectionExtensions.cs @@ -13,64 +13,65 @@ using Bit.Core.Tokens; using Microsoft.AspNetCore.DataProtection; using Microsoft.Extensions.DependencyInjection; -namespace Bit.Core.OrganizationFeatures; - -public static class OrganizationServiceCollectionExtensions +namespace Bit.Core.OrganizationFeatures { - public static void AddOrganizationServices(this IServiceCollection services, IGlobalSettings globalSettings) + public static class OrganizationServiceCollectionExtensions { - services.AddScoped(); - services.AddTokenizers(); - services.AddOrganizationConnectionCommands(); - services.AddOrganizationSponsorshipCommands(globalSettings); - services.AddOrganizationApiKeyCommands(); - } - - private static void AddOrganizationConnectionCommands(this IServiceCollection services) - { - services.AddScoped(); - services.AddScoped(); - services.AddScoped(); - } - - private static void AddOrganizationSponsorshipCommands(this IServiceCollection services, IGlobalSettings globalSettings) - { - services.AddScoped(); - services.AddScoped(); - services.AddScoped(); - services.AddScoped(); - services.AddScoped(); - services.AddScoped(); - services.AddScoped(); - services.AddScoped(); - services.AddScoped(); - services.AddScoped(); - services.AddScoped(); - services.AddScoped(); - services.AddScoped(); - if (globalSettings.SelfHosted) + public static void AddOrganizationServices(this IServiceCollection services, IGlobalSettings globalSettings) { - services.AddScoped(); + services.AddScoped(); + services.AddTokenizers(); + services.AddOrganizationConnectionCommands(); + services.AddOrganizationSponsorshipCommands(globalSettings); + services.AddOrganizationApiKeyCommands(); } - else + + private static void AddOrganizationConnectionCommands(this IServiceCollection services) { - services.AddScoped(); + services.AddScoped(); + services.AddScoped(); + services.AddScoped(); } - } - private static void AddOrganizationApiKeyCommands(this IServiceCollection services) - { - services.AddScoped(); - services.AddScoped(); - } + private static void AddOrganizationSponsorshipCommands(this IServiceCollection services, IGlobalSettings globalSettings) + { + services.AddScoped(); + services.AddScoped(); + services.AddScoped(); + services.AddScoped(); + services.AddScoped(); + services.AddScoped(); + services.AddScoped(); + services.AddScoped(); + services.AddScoped(); + services.AddScoped(); + services.AddScoped(); + services.AddScoped(); + services.AddScoped(); + if (globalSettings.SelfHosted) + { + services.AddScoped(); + } + else + { + services.AddScoped(); + } + } - private static void AddTokenizers(this IServiceCollection services) - { - services.AddSingleton>(serviceProvider => - new DataProtectorTokenFactory( - OrganizationSponsorshipOfferTokenable.ClearTextPrefix, - OrganizationSponsorshipOfferTokenable.DataProtectorPurpose, - serviceProvider.GetDataProtectionProvider()) - ); + private static void AddOrganizationApiKeyCommands(this IServiceCollection services) + { + services.AddScoped(); + services.AddScoped(); + } + + private static void AddTokenizers(this IServiceCollection services) + { + services.AddSingleton>(serviceProvider => + new DataProtectorTokenFactory( + OrganizationSponsorshipOfferTokenable.ClearTextPrefix, + OrganizationSponsorshipOfferTokenable.DataProtectorPurpose, + serviceProvider.GetDataProtectionProvider()) + ); + } } } diff --git a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/CancelSponsorshipCommand.cs b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/CancelSponsorshipCommand.cs index 111cec395c..71ce0b4fa2 100644 --- a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/CancelSponsorshipCommand.cs +++ b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/CancelSponsorshipCommand.cs @@ -2,38 +2,39 @@ using Bit.Core.Exceptions; using Bit.Core.Repositories; -namespace Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise; - -public abstract class CancelSponsorshipCommand +namespace Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise { - protected readonly IOrganizationSponsorshipRepository _organizationSponsorshipRepository; - protected readonly IOrganizationRepository _organizationRepository; - - public CancelSponsorshipCommand(IOrganizationSponsorshipRepository organizationSponsorshipRepository, - IOrganizationRepository organizationRepository) + public abstract class CancelSponsorshipCommand { - _organizationSponsorshipRepository = organizationSponsorshipRepository; - _organizationRepository = organizationRepository; - } + protected readonly IOrganizationSponsorshipRepository _organizationSponsorshipRepository; + protected readonly IOrganizationRepository _organizationRepository; - protected virtual async Task DeleteSponsorshipAsync(OrganizationSponsorship sponsorship = null) - { - if (sponsorship == null) + public CancelSponsorshipCommand(IOrganizationSponsorshipRepository organizationSponsorshipRepository, + IOrganizationRepository organizationRepository) { - return; + _organizationSponsorshipRepository = organizationSponsorshipRepository; + _organizationRepository = organizationRepository; } - await _organizationSponsorshipRepository.DeleteAsync(sponsorship); - } - - protected async Task MarkToDeleteSponsorshipAsync(OrganizationSponsorship sponsorship) - { - if (sponsorship == null) + protected virtual async Task DeleteSponsorshipAsync(OrganizationSponsorship sponsorship = null) { - throw new BadRequestException("The sponsorship you are trying to cancel does not exist"); + if (sponsorship == null) + { + return; + } + + await _organizationSponsorshipRepository.DeleteAsync(sponsorship); } - sponsorship.ToDelete = true; - await _organizationSponsorshipRepository.UpsertAsync(sponsorship); + protected async Task MarkToDeleteSponsorshipAsync(OrganizationSponsorship sponsorship) + { + if (sponsorship == null) + { + throw new BadRequestException("The sponsorship you are trying to cancel does not exist"); + } + + sponsorship.ToDelete = true; + await _organizationSponsorshipRepository.UpsertAsync(sponsorship); + } } } diff --git a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/CloudRevokeSponsorshipCommand.cs b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/CloudRevokeSponsorshipCommand.cs index 76c180f74e..d12765eff8 100644 --- a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/CloudRevokeSponsorshipCommand.cs +++ b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/CloudRevokeSponsorshipCommand.cs @@ -3,30 +3,31 @@ using Bit.Core.Exceptions; using Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces; using Bit.Core.Repositories; -namespace Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Cloud; - -public class CloudRevokeSponsorshipCommand : CancelSponsorshipCommand, IRevokeSponsorshipCommand +namespace Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Cloud { - public CloudRevokeSponsorshipCommand( - IOrganizationSponsorshipRepository organizationSponsorshipRepository, - IOrganizationRepository organizationRepository) : base(organizationSponsorshipRepository, organizationRepository) + public class CloudRevokeSponsorshipCommand : CancelSponsorshipCommand, IRevokeSponsorshipCommand { - } - - public async Task RevokeSponsorshipAsync(OrganizationSponsorship sponsorship) - { - if (sponsorship == null) + public CloudRevokeSponsorshipCommand( + IOrganizationSponsorshipRepository organizationSponsorshipRepository, + IOrganizationRepository organizationRepository) : base(organizationSponsorshipRepository, organizationRepository) { - throw new BadRequestException("You are not currently sponsoring an organization."); } - if (sponsorship.SponsoredOrganizationId == null) + public async Task RevokeSponsorshipAsync(OrganizationSponsorship sponsorship) { - await base.DeleteSponsorshipAsync(sponsorship); - } - else - { - await MarkToDeleteSponsorshipAsync(sponsorship); + if (sponsorship == null) + { + throw new BadRequestException("You are not currently sponsoring an organization."); + } + + if (sponsorship.SponsoredOrganizationId == null) + { + await base.DeleteSponsorshipAsync(sponsorship); + } + else + { + await MarkToDeleteSponsorshipAsync(sponsorship); + } } } } diff --git a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/CloudSyncSponsorshipsCommand.cs b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/CloudSyncSponsorshipsCommand.cs index d0569278bb..c4da82c96c 100644 --- a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/CloudSyncSponsorshipsCommand.cs +++ b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/CloudSyncSponsorshipsCommand.cs @@ -7,127 +7,128 @@ using Bit.Core.Repositories; using Bit.Core.Services; using Bit.Core.Utilities; -namespace Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Cloud; - -public class CloudSyncSponsorshipsCommand : ICloudSyncSponsorshipsCommand +namespace Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Cloud { - private readonly IOrganizationSponsorshipRepository _organizationSponsorshipRepository; - private readonly IEventService _eventService; - - public CloudSyncSponsorshipsCommand( - IOrganizationSponsorshipRepository organizationSponsorshipRepository, - IEventService eventService) + public class CloudSyncSponsorshipsCommand : ICloudSyncSponsorshipsCommand { - _organizationSponsorshipRepository = organizationSponsorshipRepository; - _eventService = eventService; - } + private readonly IOrganizationSponsorshipRepository _organizationSponsorshipRepository; + private readonly IEventService _eventService; - public async Task<(OrganizationSponsorshipSyncData, IEnumerable)> SyncOrganization(Organization sponsoringOrg, IEnumerable sponsorshipsData) - { - if (sponsoringOrg == null) + public CloudSyncSponsorshipsCommand( + IOrganizationSponsorshipRepository organizationSponsorshipRepository, + IEventService eventService) { - throw new BadRequestException("Failed to sync sponsorship - missing organization."); + _organizationSponsorshipRepository = organizationSponsorshipRepository; + _eventService = eventService; } - var (processedSponsorshipsData, sponsorshipsToEmailOffer) = sponsorshipsData.Any() ? - await DoSyncAsync(sponsoringOrg, sponsorshipsData) : - (sponsorshipsData, Array.Empty()); - - await RecordEvent(sponsoringOrg); - - return (new OrganizationSponsorshipSyncData + public async Task<(OrganizationSponsorshipSyncData, IEnumerable)> SyncOrganization(Organization sponsoringOrg, IEnumerable sponsorshipsData) { - SponsorshipsBatch = processedSponsorshipsData - }, sponsorshipsToEmailOffer); - } + if (sponsoringOrg == null) + { + throw new BadRequestException("Failed to sync sponsorship - missing organization."); + } - private async Task<(IEnumerable data, IEnumerable toOffer)> DoSyncAsync(Organization sponsoringOrg, IEnumerable sponsorshipsData) - { - var existingSponsorshipsDict = (await _organizationSponsorshipRepository.GetManyBySponsoringOrganizationAsync(sponsoringOrg.Id)) - .ToDictionary(i => i.SponsoringOrganizationUserId); + var (processedSponsorshipsData, sponsorshipsToEmailOffer) = sponsorshipsData.Any() ? + await DoSyncAsync(sponsoringOrg, sponsorshipsData) : + (sponsorshipsData, Array.Empty()); - var sponsorshipsToUpsert = new List(); - var sponsorshipIdsToDelete = new List(); - var sponsorshipsToReturn = new List(); + await RecordEvent(sponsoringOrg); - foreach (var selfHostedSponsorship in sponsorshipsData) + return (new OrganizationSponsorshipSyncData + { + SponsorshipsBatch = processedSponsorshipsData + }, sponsorshipsToEmailOffer); + } + + private async Task<(IEnumerable data, IEnumerable toOffer)> DoSyncAsync(Organization sponsoringOrg, IEnumerable sponsorshipsData) { - var requiredSponsoringProductType = StaticStore.GetSponsoredPlan(selfHostedSponsorship.PlanSponsorshipType)?.SponsoringProductType; - if (requiredSponsoringProductType == null - || StaticStore.GetPlan(sponsoringOrg.PlanType).Product != requiredSponsoringProductType.Value) - { - continue; // prevent unsupported sponsorships - } + var existingSponsorshipsDict = (await _organizationSponsorshipRepository.GetManyBySponsoringOrganizationAsync(sponsoringOrg.Id)) + .ToDictionary(i => i.SponsoringOrganizationUserId); - if (!existingSponsorshipsDict.TryGetValue(selfHostedSponsorship.SponsoringOrganizationUserId, out var cloudSponsorship)) - { - if (selfHostedSponsorship.ToDelete && selfHostedSponsorship.LastSyncDate == null) - { - continue; // prevent invalid sponsorships in cloud. These should have been deleted by self hosted - } - if (OrgDisabledForMoreThanGracePeriod(sponsoringOrg)) - { - continue; // prevent new sponsorships from disabled orgs - } - cloudSponsorship = new OrganizationSponsorship - { - SponsoringOrganizationId = sponsoringOrg.Id, - SponsoringOrganizationUserId = selfHostedSponsorship.SponsoringOrganizationUserId, - FriendlyName = selfHostedSponsorship.FriendlyName, - OfferedToEmail = selfHostedSponsorship.OfferedToEmail, - PlanSponsorshipType = selfHostedSponsorship.PlanSponsorshipType, - LastSyncDate = DateTime.UtcNow, - }; - } - else - { - cloudSponsorship.LastSyncDate = DateTime.UtcNow; - } + var sponsorshipsToUpsert = new List(); + var sponsorshipIdsToDelete = new List(); + var sponsorshipsToReturn = new List(); - if (selfHostedSponsorship.ToDelete) + foreach (var selfHostedSponsorship in sponsorshipsData) { - if (cloudSponsorship.SponsoredOrganizationId == null) + var requiredSponsoringProductType = StaticStore.GetSponsoredPlan(selfHostedSponsorship.PlanSponsorshipType)?.SponsoringProductType; + if (requiredSponsoringProductType == null + || StaticStore.GetPlan(sponsoringOrg.PlanType).Product != requiredSponsoringProductType.Value) { - sponsorshipIdsToDelete.Add(cloudSponsorship.Id); - selfHostedSponsorship.CloudSponsorshipRemoved = true; + continue; // prevent unsupported sponsorships + } + + if (!existingSponsorshipsDict.TryGetValue(selfHostedSponsorship.SponsoringOrganizationUserId, out var cloudSponsorship)) + { + if (selfHostedSponsorship.ToDelete && selfHostedSponsorship.LastSyncDate == null) + { + continue; // prevent invalid sponsorships in cloud. These should have been deleted by self hosted + } + if (OrgDisabledForMoreThanGracePeriod(sponsoringOrg)) + { + continue; // prevent new sponsorships from disabled orgs + } + cloudSponsorship = new OrganizationSponsorship + { + SponsoringOrganizationId = sponsoringOrg.Id, + SponsoringOrganizationUserId = selfHostedSponsorship.SponsoringOrganizationUserId, + FriendlyName = selfHostedSponsorship.FriendlyName, + OfferedToEmail = selfHostedSponsorship.OfferedToEmail, + PlanSponsorshipType = selfHostedSponsorship.PlanSponsorshipType, + LastSyncDate = DateTime.UtcNow, + }; } else { - cloudSponsorship.ToDelete = true; + cloudSponsorship.LastSyncDate = DateTime.UtcNow; } + + if (selfHostedSponsorship.ToDelete) + { + if (cloudSponsorship.SponsoredOrganizationId == null) + { + sponsorshipIdsToDelete.Add(cloudSponsorship.Id); + selfHostedSponsorship.CloudSponsorshipRemoved = true; + } + else + { + cloudSponsorship.ToDelete = true; + } + } + sponsorshipsToUpsert.Add(cloudSponsorship); + + selfHostedSponsorship.ValidUntil = cloudSponsorship.ValidUntil; + selfHostedSponsorship.LastSyncDate = DateTime.UtcNow; + sponsorshipsToReturn.Add(selfHostedSponsorship); + } + var sponsorshipsToEmailOffer = sponsorshipsToUpsert.Where(s => s.Id == default).ToArray(); + if (sponsorshipsToUpsert.Any()) + { + await _organizationSponsorshipRepository.UpsertManyAsync(sponsorshipsToUpsert); + } + if (sponsorshipIdsToDelete.Any()) + { + await _organizationSponsorshipRepository.DeleteManyAsync(sponsorshipIdsToDelete); } - sponsorshipsToUpsert.Add(cloudSponsorship); - selfHostedSponsorship.ValidUntil = cloudSponsorship.ValidUntil; - selfHostedSponsorship.LastSyncDate = DateTime.UtcNow; - sponsorshipsToReturn.Add(selfHostedSponsorship); + return (sponsorshipsToReturn, sponsorshipsToEmailOffer); } - var sponsorshipsToEmailOffer = sponsorshipsToUpsert.Where(s => s.Id == default).ToArray(); - if (sponsorshipsToUpsert.Any()) + + /// + /// True if Organization is disabled and the expiration date is more than three months ago + /// + /// + private bool OrgDisabledForMoreThanGracePeriod(Organization organization) => + !organization.Enabled && + ( + !organization.ExpirationDate.HasValue || + DateTime.UtcNow.Subtract(organization.ExpirationDate.Value).TotalDays > 93 + ); + + private async Task RecordEvent(Organization organization) { - await _organizationSponsorshipRepository.UpsertManyAsync(sponsorshipsToUpsert); + await _eventService.LogOrganizationEventAsync(organization, EventType.Organization_SponsorshipsSynced); } - if (sponsorshipIdsToDelete.Any()) - { - await _organizationSponsorshipRepository.DeleteManyAsync(sponsorshipIdsToDelete); - } - - return (sponsorshipsToReturn, sponsorshipsToEmailOffer); - } - - /// - /// True if Organization is disabled and the expiration date is more than three months ago - /// - /// - private bool OrgDisabledForMoreThanGracePeriod(Organization organization) => - !organization.Enabled && - ( - !organization.ExpirationDate.HasValue || - DateTime.UtcNow.Subtract(organization.ExpirationDate.Value).TotalDays > 93 - ); - - private async Task RecordEvent(Organization organization) - { - await _eventService.LogOrganizationEventAsync(organization, EventType.Organization_SponsorshipsSynced); } } diff --git a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/OrganizationSponsorshipRenewCommand.cs b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/OrganizationSponsorshipRenewCommand.cs index 1d7b66a66d..148b525d73 100644 --- a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/OrganizationSponsorshipRenewCommand.cs +++ b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/OrganizationSponsorshipRenewCommand.cs @@ -1,27 +1,28 @@ using Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces; using Bit.Core.Repositories; -namespace Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Cloud; - -public class OrganizationSponsorshipRenewCommand : IOrganizationSponsorshipRenewCommand +namespace Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Cloud { - private readonly IOrganizationSponsorshipRepository _organizationSponsorshipRepository; - - public OrganizationSponsorshipRenewCommand(IOrganizationSponsorshipRepository organizationSponsorshipRepository) + public class OrganizationSponsorshipRenewCommand : IOrganizationSponsorshipRenewCommand { - _organizationSponsorshipRepository = organizationSponsorshipRepository; - } + private readonly IOrganizationSponsorshipRepository _organizationSponsorshipRepository; - public async Task UpdateExpirationDateAsync(Guid organizationId, DateTime expireDate) - { - var sponsorship = await _organizationSponsorshipRepository.GetBySponsoredOrganizationIdAsync(organizationId); - - if (sponsorship == null) + public OrganizationSponsorshipRenewCommand(IOrganizationSponsorshipRepository organizationSponsorshipRepository) { - return; + _organizationSponsorshipRepository = organizationSponsorshipRepository; } - sponsorship.ValidUntil = expireDate; - await _organizationSponsorshipRepository.UpsertAsync(sponsorship); + public async Task UpdateExpirationDateAsync(Guid organizationId, DateTime expireDate) + { + var sponsorship = await _organizationSponsorshipRepository.GetBySponsoredOrganizationIdAsync(organizationId); + + if (sponsorship == null) + { + return; + } + + sponsorship.ValidUntil = expireDate; + await _organizationSponsorshipRepository.UpsertAsync(sponsorship); + } } } diff --git a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/RemoveSponsorshipCommand.cs b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/RemoveSponsorshipCommand.cs index 1e05f8bc4f..136c1681b8 100644 --- a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/RemoveSponsorshipCommand.cs +++ b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/RemoveSponsorshipCommand.cs @@ -3,23 +3,24 @@ using Bit.Core.Exceptions; using Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces; using Bit.Core.Repositories; -namespace Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Cloud; - -public class RemoveSponsorshipCommand : CancelSponsorshipCommand, IRemoveSponsorshipCommand +namespace Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Cloud { - public RemoveSponsorshipCommand( - IOrganizationSponsorshipRepository organizationSponsorshipRepository, - IOrganizationRepository organizationRepository) : base(organizationSponsorshipRepository, organizationRepository) + public class RemoveSponsorshipCommand : CancelSponsorshipCommand, IRemoveSponsorshipCommand { - } - - public async Task RemoveSponsorshipAsync(OrganizationSponsorship sponsorship) - { - if (sponsorship == null || sponsorship.SponsoredOrganizationId == null) + public RemoveSponsorshipCommand( + IOrganizationSponsorshipRepository organizationSponsorshipRepository, + IOrganizationRepository organizationRepository) : base(organizationSponsorshipRepository, organizationRepository) { - throw new BadRequestException("The requested organization is not currently being sponsored."); } - await MarkToDeleteSponsorshipAsync(sponsorship); + public async Task RemoveSponsorshipAsync(OrganizationSponsorship sponsorship) + { + if (sponsorship == null || sponsorship.SponsoredOrganizationId == null) + { + throw new BadRequestException("The requested organization is not currently being sponsored."); + } + + await MarkToDeleteSponsorshipAsync(sponsorship); + } } } diff --git a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/SendSponsorshipOfferCommand.cs b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/SendSponsorshipOfferCommand.cs index 5f9a62d25f..b77706051d 100644 --- a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/SendSponsorshipOfferCommand.cs +++ b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/SendSponsorshipOfferCommand.cs @@ -7,63 +7,64 @@ using Bit.Core.Repositories; using Bit.Core.Services; using Bit.Core.Tokens; -namespace Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Cloud; - -public class SendSponsorshipOfferCommand : ISendSponsorshipOfferCommand +namespace Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Cloud { - private readonly IUserRepository _userRepository; - private readonly IMailService _mailService; - private readonly IDataProtectorTokenFactory _tokenFactory; - - public SendSponsorshipOfferCommand(IUserRepository userRepository, - IMailService mailService, - IDataProtectorTokenFactory tokenFactory) + public class SendSponsorshipOfferCommand : ISendSponsorshipOfferCommand { - _userRepository = userRepository; - _mailService = mailService; - _tokenFactory = tokenFactory; - } + private readonly IUserRepository _userRepository; + private readonly IMailService _mailService; + private readonly IDataProtectorTokenFactory _tokenFactory; - public async Task BulkSendSponsorshipOfferAsync(string sponsoringOrgName, IEnumerable sponsorships) - { - var invites = new List<(string, bool, string)>(); - foreach (var sponsorship in sponsorships) + public SendSponsorshipOfferCommand(IUserRepository userRepository, + IMailService mailService, + IDataProtectorTokenFactory tokenFactory) + { + _userRepository = userRepository; + _mailService = mailService; + _tokenFactory = tokenFactory; + } + + public async Task BulkSendSponsorshipOfferAsync(string sponsoringOrgName, IEnumerable sponsorships) + { + var invites = new List<(string, bool, string)>(); + foreach (var sponsorship in sponsorships) + { + var user = await _userRepository.GetByEmailAsync(sponsorship.OfferedToEmail); + var isExistingAccount = user != null; + invites.Add((sponsorship.OfferedToEmail, user != null, _tokenFactory.Protect(new OrganizationSponsorshipOfferTokenable(sponsorship)))); + } + + await _mailService.BulkSendFamiliesForEnterpriseOfferEmailAsync(sponsoringOrgName, invites); + } + + public async Task SendSponsorshipOfferAsync(OrganizationSponsorship sponsorship, string sponsoringOrgName) { var user = await _userRepository.GetByEmailAsync(sponsorship.OfferedToEmail); var isExistingAccount = user != null; - invites.Add((sponsorship.OfferedToEmail, user != null, _tokenFactory.Protect(new OrganizationSponsorshipOfferTokenable(sponsorship)))); + + await _mailService.SendFamiliesForEnterpriseOfferEmailAsync(sponsoringOrgName, sponsorship.OfferedToEmail, + isExistingAccount, _tokenFactory.Protect(new OrganizationSponsorshipOfferTokenable(sponsorship))); } - await _mailService.BulkSendFamiliesForEnterpriseOfferEmailAsync(sponsoringOrgName, invites); - } - - public async Task SendSponsorshipOfferAsync(OrganizationSponsorship sponsorship, string sponsoringOrgName) - { - var user = await _userRepository.GetByEmailAsync(sponsorship.OfferedToEmail); - var isExistingAccount = user != null; - - await _mailService.SendFamiliesForEnterpriseOfferEmailAsync(sponsoringOrgName, sponsorship.OfferedToEmail, - isExistingAccount, _tokenFactory.Protect(new OrganizationSponsorshipOfferTokenable(sponsorship))); - } - - public async Task SendSponsorshipOfferAsync(Organization sponsoringOrg, OrganizationUser sponsoringOrgUser, - OrganizationSponsorship sponsorship) - { - if (sponsoringOrg == null) + public async Task SendSponsorshipOfferAsync(Organization sponsoringOrg, OrganizationUser sponsoringOrgUser, + OrganizationSponsorship sponsorship) { - throw new BadRequestException("Cannot find the requested sponsoring organization."); - } + if (sponsoringOrg == null) + { + throw new BadRequestException("Cannot find the requested sponsoring organization."); + } - if (sponsoringOrgUser == null || sponsoringOrgUser.Status != OrganizationUserStatusType.Confirmed) - { - throw new BadRequestException("Only confirmed users can sponsor other organizations."); - } + if (sponsoringOrgUser == null || sponsoringOrgUser.Status != OrganizationUserStatusType.Confirmed) + { + throw new BadRequestException("Only confirmed users can sponsor other organizations."); + } - if (sponsorship == null || sponsorship.OfferedToEmail == null) - { - throw new BadRequestException("Cannot find an outstanding sponsorship offer for this organization."); - } + if (sponsorship == null || sponsorship.OfferedToEmail == null) + { + throw new BadRequestException("Cannot find an outstanding sponsorship offer for this organization."); + } - await SendSponsorshipOfferAsync(sponsorship, sponsoringOrg.Name); + await SendSponsorshipOfferAsync(sponsorship, sponsoringOrg.Name); + } } } diff --git a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/SetUpSponsorshipCommand.cs b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/SetUpSponsorshipCommand.cs index 9230e7d13d..698ec549d8 100644 --- a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/SetUpSponsorshipCommand.cs +++ b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/SetUpSponsorshipCommand.cs @@ -5,62 +5,63 @@ using Bit.Core.Repositories; using Bit.Core.Services; using Bit.Core.Utilities; -namespace Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Cloud; - -public class SetUpSponsorshipCommand : ISetUpSponsorshipCommand +namespace Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Cloud { - private readonly IOrganizationSponsorshipRepository _organizationSponsorshipRepository; - private readonly IOrganizationRepository _organizationRepository; - private readonly IPaymentService _paymentService; - - public SetUpSponsorshipCommand(IOrganizationSponsorshipRepository organizationSponsorshipRepository, IOrganizationRepository organizationRepository, IPaymentService paymentService) + public class SetUpSponsorshipCommand : ISetUpSponsorshipCommand { - _organizationSponsorshipRepository = organizationSponsorshipRepository; - _organizationRepository = organizationRepository; - _paymentService = paymentService; - } + private readonly IOrganizationSponsorshipRepository _organizationSponsorshipRepository; + private readonly IOrganizationRepository _organizationRepository; + private readonly IPaymentService _paymentService; - public async Task SetUpSponsorshipAsync(OrganizationSponsorship sponsorship, - Organization sponsoredOrganization) - { - if (sponsorship == null) + public SetUpSponsorshipCommand(IOrganizationSponsorshipRepository organizationSponsorshipRepository, IOrganizationRepository organizationRepository, IPaymentService paymentService) { - throw new BadRequestException("No unredeemed sponsorship offer exists for you."); + _organizationSponsorshipRepository = organizationSponsorshipRepository; + _organizationRepository = organizationRepository; + _paymentService = paymentService; } - var existingOrgSponsorship = await _organizationSponsorshipRepository - .GetBySponsoredOrganizationIdAsync(sponsoredOrganization.Id); - if (existingOrgSponsorship != null) + public async Task SetUpSponsorshipAsync(OrganizationSponsorship sponsorship, + Organization sponsoredOrganization) { - throw new BadRequestException("Cannot redeem a sponsorship offer for an organization that is already sponsored. Revoke existing sponsorship first."); + if (sponsorship == null) + { + throw new BadRequestException("No unredeemed sponsorship offer exists for you."); + } + + var existingOrgSponsorship = await _organizationSponsorshipRepository + .GetBySponsoredOrganizationIdAsync(sponsoredOrganization.Id); + if (existingOrgSponsorship != null) + { + throw new BadRequestException("Cannot redeem a sponsorship offer for an organization that is already sponsored. Revoke existing sponsorship first."); + } + + if (sponsorship.PlanSponsorshipType == null) + { + throw new BadRequestException("Cannot set up sponsorship without a known sponsorship type."); + } + + // Do not allow self-hosted sponsorships that haven't been synced for > 0.5 year + if (sponsorship.LastSyncDate != null && DateTime.UtcNow.Subtract(sponsorship.LastSyncDate.Value).TotalDays > 182.5) + { + await _organizationSponsorshipRepository.DeleteAsync(sponsorship); + throw new BadRequestException("This sponsorship offer is more than 6 months old and has expired."); + } + + // Check org to sponsor's product type + var requiredSponsoredProductType = StaticStore.GetSponsoredPlan(sponsorship.PlanSponsorshipType.Value)?.SponsoredProductType; + if (requiredSponsoredProductType == null || + sponsoredOrganization == null || + StaticStore.GetPlan(sponsoredOrganization.PlanType).Product != requiredSponsoredProductType.Value) + { + throw new BadRequestException("Can only redeem sponsorship offer on families organizations."); + } + + await _paymentService.SponsorOrganizationAsync(sponsoredOrganization, sponsorship); + await _organizationRepository.UpsertAsync(sponsoredOrganization); + + sponsorship.SponsoredOrganizationId = sponsoredOrganization.Id; + sponsorship.OfferedToEmail = null; + await _organizationSponsorshipRepository.UpsertAsync(sponsorship); } - - if (sponsorship.PlanSponsorshipType == null) - { - throw new BadRequestException("Cannot set up sponsorship without a known sponsorship type."); - } - - // Do not allow self-hosted sponsorships that haven't been synced for > 0.5 year - if (sponsorship.LastSyncDate != null && DateTime.UtcNow.Subtract(sponsorship.LastSyncDate.Value).TotalDays > 182.5) - { - await _organizationSponsorshipRepository.DeleteAsync(sponsorship); - throw new BadRequestException("This sponsorship offer is more than 6 months old and has expired."); - } - - // Check org to sponsor's product type - var requiredSponsoredProductType = StaticStore.GetSponsoredPlan(sponsorship.PlanSponsorshipType.Value)?.SponsoredProductType; - if (requiredSponsoredProductType == null || - sponsoredOrganization == null || - StaticStore.GetPlan(sponsoredOrganization.PlanType).Product != requiredSponsoredProductType.Value) - { - throw new BadRequestException("Can only redeem sponsorship offer on families organizations."); - } - - await _paymentService.SponsorOrganizationAsync(sponsoredOrganization, sponsorship); - await _organizationRepository.UpsertAsync(sponsoredOrganization); - - sponsorship.SponsoredOrganizationId = sponsoredOrganization.Id; - sponsorship.OfferedToEmail = null; - await _organizationSponsorshipRepository.UpsertAsync(sponsorship); } } diff --git a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/ValidateBillingSyncKeyCommand.cs b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/ValidateBillingSyncKeyCommand.cs index 19c4398a70..f1032f0b2e 100644 --- a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/ValidateBillingSyncKeyCommand.cs +++ b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/ValidateBillingSyncKeyCommand.cs @@ -3,37 +3,38 @@ using Bit.Core.Exceptions; using Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces; using Bit.Core.Repositories; -namespace Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Cloud; - -public class ValidateBillingSyncKeyCommand : IValidateBillingSyncKeyCommand +namespace Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Cloud { - private readonly IOrganizationSponsorshipRepository _organizationSponsorshipRepository; - private readonly IOrganizationApiKeyRepository _apiKeyRepository; - - public ValidateBillingSyncKeyCommand( - IOrganizationSponsorshipRepository organizationSponsorshipRepository, - IOrganizationApiKeyRepository organizationApiKeyRepository) + public class ValidateBillingSyncKeyCommand : IValidateBillingSyncKeyCommand { - _organizationSponsorshipRepository = organizationSponsorshipRepository; - _apiKeyRepository = organizationApiKeyRepository; - } + private readonly IOrganizationSponsorshipRepository _organizationSponsorshipRepository; + private readonly IOrganizationApiKeyRepository _apiKeyRepository; - public async Task ValidateBillingSyncKeyAsync(Organization organization, string billingSyncKey) - { - if (organization == null) + public ValidateBillingSyncKeyCommand( + IOrganizationSponsorshipRepository organizationSponsorshipRepository, + IOrganizationApiKeyRepository organizationApiKeyRepository) { - throw new BadRequestException("Invalid organization"); + _organizationSponsorshipRepository = organizationSponsorshipRepository; + _apiKeyRepository = organizationApiKeyRepository; } - if (string.IsNullOrWhiteSpace(billingSyncKey)) + + public async Task ValidateBillingSyncKeyAsync(Organization organization, string billingSyncKey) { + if (organization == null) + { + throw new BadRequestException("Invalid organization"); + } + if (string.IsNullOrWhiteSpace(billingSyncKey)) + { + return false; + } + + var orgApiKey = (await _apiKeyRepository.GetManyByOrganizationIdTypeAsync(organization.Id, Enums.OrganizationApiKeyType.BillingSync)).FirstOrDefault(); + if (string.Equals(orgApiKey.ApiKey, billingSyncKey)) + { + return true; + } return false; } - - var orgApiKey = (await _apiKeyRepository.GetManyByOrganizationIdTypeAsync(organization.Id, Enums.OrganizationApiKeyType.BillingSync)).FirstOrDefault(); - if (string.Equals(orgApiKey.ApiKey, billingSyncKey)) - { - return true; - } - return false; } } diff --git a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/ValidateRedemptionTokenCommand.cs b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/ValidateRedemptionTokenCommand.cs index fc3f5b1321..179f5b3ac5 100644 --- a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/ValidateRedemptionTokenCommand.cs +++ b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/ValidateRedemptionTokenCommand.cs @@ -4,33 +4,34 @@ using Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterpri using Bit.Core.Repositories; using Bit.Core.Tokens; -namespace Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Cloud; - -public class ValidateRedemptionTokenCommand : IValidateRedemptionTokenCommand +namespace Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Cloud { - private readonly IOrganizationSponsorshipRepository _organizationSponsorshipRepository; - private readonly IDataProtectorTokenFactory _dataProtectorTokenFactory; - - public ValidateRedemptionTokenCommand(IOrganizationSponsorshipRepository organizationSponsorshipRepository, - IDataProtectorTokenFactory dataProtectorTokenFactory) + public class ValidateRedemptionTokenCommand : IValidateRedemptionTokenCommand { - _organizationSponsorshipRepository = organizationSponsorshipRepository; - _dataProtectorTokenFactory = dataProtectorTokenFactory; - } + private readonly IOrganizationSponsorshipRepository _organizationSponsorshipRepository; + private readonly IDataProtectorTokenFactory _dataProtectorTokenFactory; - public async Task<(bool valid, OrganizationSponsorship sponsorship)> ValidateRedemptionTokenAsync(string encryptedToken, string sponsoredUserEmail) - { - - if (!_dataProtectorTokenFactory.TryUnprotect(encryptedToken, out var tokenable)) + public ValidateRedemptionTokenCommand(IOrganizationSponsorshipRepository organizationSponsorshipRepository, + IDataProtectorTokenFactory dataProtectorTokenFactory) { - return (false, null); + _organizationSponsorshipRepository = organizationSponsorshipRepository; + _dataProtectorTokenFactory = dataProtectorTokenFactory; } - var sponsorship = await _organizationSponsorshipRepository.GetByIdAsync(tokenable.Id); - if (!tokenable.IsValid(sponsorship, sponsoredUserEmail)) + public async Task<(bool valid, OrganizationSponsorship sponsorship)> ValidateRedemptionTokenAsync(string encryptedToken, string sponsoredUserEmail) { - return (false, sponsorship); + + if (!_dataProtectorTokenFactory.TryUnprotect(encryptedToken, out var tokenable)) + { + return (false, null); + } + + var sponsorship = await _organizationSponsorshipRepository.GetByIdAsync(tokenable.Id); + if (!tokenable.IsValid(sponsorship, sponsoredUserEmail)) + { + return (false, sponsorship); + } + return (true, sponsorship); } - return (true, sponsorship); } } diff --git a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/ValidateSponsorshipCommand.cs b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/ValidateSponsorshipCommand.cs index 3f2d7af5eb..3b0bf3f141 100644 --- a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/ValidateSponsorshipCommand.cs +++ b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/ValidateSponsorshipCommand.cs @@ -4,111 +4,112 @@ using Bit.Core.Repositories; using Bit.Core.Services; using Microsoft.Extensions.Logging; -namespace Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Cloud; - -public class ValidateSponsorshipCommand : CancelSponsorshipCommand, IValidateSponsorshipCommand +namespace Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Cloud { - private readonly IPaymentService _paymentService; - private readonly IMailService _mailService; - private readonly ILogger _logger; - - public ValidateSponsorshipCommand( - IOrganizationSponsorshipRepository organizationSponsorshipRepository, - IOrganizationRepository organizationRepository, - IPaymentService paymentService, - IMailService mailService, - ILogger logger) : base(organizationSponsorshipRepository, organizationRepository) + public class ValidateSponsorshipCommand : CancelSponsorshipCommand, IValidateSponsorshipCommand { - _paymentService = paymentService; - _mailService = mailService; - _logger = logger; - } + private readonly IPaymentService _paymentService; + private readonly IMailService _mailService; + private readonly ILogger _logger; - public async Task ValidateSponsorshipAsync(Guid sponsoredOrganizationId) - { - var sponsoredOrganization = await _organizationRepository.GetByIdAsync(sponsoredOrganizationId); - if (sponsoredOrganization == null) + public ValidateSponsorshipCommand( + IOrganizationSponsorshipRepository organizationSponsorshipRepository, + IOrganizationRepository organizationRepository, + IPaymentService paymentService, + IMailService mailService, + ILogger logger) : base(organizationSponsorshipRepository, organizationRepository) { - return false; + _paymentService = paymentService; + _mailService = mailService; + _logger = logger; } - var existingSponsorship = await _organizationSponsorshipRepository - .GetBySponsoredOrganizationIdAsync(sponsoredOrganizationId); - - if (existingSponsorship == null) + public async Task ValidateSponsorshipAsync(Guid sponsoredOrganizationId) { - await CancelSponsorshipAsync(sponsoredOrganization, null); - return false; - } - - if (existingSponsorship.SponsoringOrganizationId == null || existingSponsorship.SponsoringOrganizationUserId == default || existingSponsorship.PlanSponsorshipType == null) - { - await CancelSponsorshipAsync(sponsoredOrganization, existingSponsorship); - return false; - } - var sponsoredPlan = Utilities.StaticStore.GetSponsoredPlan(existingSponsorship.PlanSponsorshipType.Value); - - var sponsoringOrganization = await _organizationRepository - .GetByIdAsync(existingSponsorship.SponsoringOrganizationId.Value); - if (sponsoringOrganization == null) - { - await CancelSponsorshipAsync(sponsoredOrganization, existingSponsorship); - return false; - } - - var sponsoringOrgPlan = Utilities.StaticStore.GetPlan(sponsoringOrganization.PlanType); - if (OrgDisabledForMoreThanGracePeriod(sponsoringOrganization) || - sponsoredPlan.SponsoringProductType != sponsoringOrgPlan.Product || - existingSponsorship.ToDelete || - SponsorshipIsSelfHostedOutOfSync(existingSponsorship)) - { - await CancelSponsorshipAsync(sponsoredOrganization, existingSponsorship); - return false; - } - - return true; - } - - protected async Task CancelSponsorshipAsync(Organization sponsoredOrganization, OrganizationSponsorship sponsorship = null) - { - if (sponsoredOrganization != null) - { - await _paymentService.RemoveOrganizationSponsorshipAsync(sponsoredOrganization, sponsorship); - await _organizationRepository.UpsertAsync(sponsoredOrganization); - - try + var sponsoredOrganization = await _organizationRepository.GetByIdAsync(sponsoredOrganizationId); + if (sponsoredOrganization == null) { - if (sponsorship != null) + return false; + } + + var existingSponsorship = await _organizationSponsorshipRepository + .GetBySponsoredOrganizationIdAsync(sponsoredOrganizationId); + + if (existingSponsorship == null) + { + await CancelSponsorshipAsync(sponsoredOrganization, null); + return false; + } + + if (existingSponsorship.SponsoringOrganizationId == null || existingSponsorship.SponsoringOrganizationUserId == default || existingSponsorship.PlanSponsorshipType == null) + { + await CancelSponsorshipAsync(sponsoredOrganization, existingSponsorship); + return false; + } + var sponsoredPlan = Utilities.StaticStore.GetSponsoredPlan(existingSponsorship.PlanSponsorshipType.Value); + + var sponsoringOrganization = await _organizationRepository + .GetByIdAsync(existingSponsorship.SponsoringOrganizationId.Value); + if (sponsoringOrganization == null) + { + await CancelSponsorshipAsync(sponsoredOrganization, existingSponsorship); + return false; + } + + var sponsoringOrgPlan = Utilities.StaticStore.GetPlan(sponsoringOrganization.PlanType); + if (OrgDisabledForMoreThanGracePeriod(sponsoringOrganization) || + sponsoredPlan.SponsoringProductType != sponsoringOrgPlan.Product || + existingSponsorship.ToDelete || + SponsorshipIsSelfHostedOutOfSync(existingSponsorship)) + { + await CancelSponsorshipAsync(sponsoredOrganization, existingSponsorship); + return false; + } + + return true; + } + + protected async Task CancelSponsorshipAsync(Organization sponsoredOrganization, OrganizationSponsorship sponsorship = null) + { + if (sponsoredOrganization != null) + { + await _paymentService.RemoveOrganizationSponsorshipAsync(sponsoredOrganization, sponsorship); + await _organizationRepository.UpsertAsync(sponsoredOrganization); + + try { - await _mailService.SendFamiliesForEnterpriseSponsorshipRevertingEmailAsync( - sponsoredOrganization.BillingEmailAddress(), - sponsorship.ValidUntil ?? DateTime.UtcNow.AddDays(15)); + if (sponsorship != null) + { + await _mailService.SendFamiliesForEnterpriseSponsorshipRevertingEmailAsync( + sponsoredOrganization.BillingEmailAddress(), + sponsorship.ValidUntil ?? DateTime.UtcNow.AddDays(15)); + } + } + catch (Exception e) + { + _logger.LogError("Error sending Family sponsorship removed email.", e); } } - catch (Exception e) - { - _logger.LogError("Error sending Family sponsorship removed email.", e); - } + await base.DeleteSponsorshipAsync(sponsorship); } - await base.DeleteSponsorshipAsync(sponsorship); + + /// + /// True if Sponsorship is from a self-hosted instance that has failed to sync for more than 6 months + /// + /// + private bool SponsorshipIsSelfHostedOutOfSync(OrganizationSponsorship sponsorship) => + sponsorship.LastSyncDate.HasValue && + DateTime.UtcNow.Subtract(sponsorship.LastSyncDate.Value).TotalDays > 182.5; + + /// + /// True if Organization is disabled and the expiration date is more than three months ago + /// + /// + private bool OrgDisabledForMoreThanGracePeriod(Organization organization) => + !organization.Enabled && + ( + !organization.ExpirationDate.HasValue || + DateTime.UtcNow.Subtract(organization.ExpirationDate.Value).TotalDays > 93 + ); } - - /// - /// True if Sponsorship is from a self-hosted instance that has failed to sync for more than 6 months - /// - /// - private bool SponsorshipIsSelfHostedOutOfSync(OrganizationSponsorship sponsorship) => - sponsorship.LastSyncDate.HasValue && - DateTime.UtcNow.Subtract(sponsorship.LastSyncDate.Value).TotalDays > 182.5; - - /// - /// True if Organization is disabled and the expiration date is more than three months ago - /// - /// - private bool OrgDisabledForMoreThanGracePeriod(Organization organization) => - !organization.Enabled && - ( - !organization.ExpirationDate.HasValue || - DateTime.UtcNow.Subtract(organization.ExpirationDate.Value).TotalDays > 93 - ); } diff --git a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/CreateSponsorshipCommand.cs b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/CreateSponsorshipCommand.cs index 69e6c3232c..6d186726a8 100644 --- a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/CreateSponsorshipCommand.cs +++ b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/CreateSponsorshipCommand.cs @@ -6,76 +6,77 @@ using Bit.Core.Repositories; using Bit.Core.Services; using Bit.Core.Utilities; -namespace Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise; - -public class CreateSponsorshipCommand : ICreateSponsorshipCommand +namespace Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise { - private readonly IOrganizationSponsorshipRepository _organizationSponsorshipRepository; - private readonly IUserService _userService; - - public CreateSponsorshipCommand(IOrganizationSponsorshipRepository organizationSponsorshipRepository, - IUserService userService) + public class CreateSponsorshipCommand : ICreateSponsorshipCommand { - _organizationSponsorshipRepository = organizationSponsorshipRepository; - _userService = userService; - } + private readonly IOrganizationSponsorshipRepository _organizationSponsorshipRepository; + private readonly IUserService _userService; - public async Task CreateSponsorshipAsync(Organization sponsoringOrg, OrganizationUser sponsoringOrgUser, - PlanSponsorshipType sponsorshipType, string sponsoredEmail, string friendlyName) - { - var sponsoringUser = await _userService.GetUserByIdAsync(sponsoringOrgUser.UserId.Value); - if (sponsoringUser == null || string.Equals(sponsoringUser.Email, sponsoredEmail, System.StringComparison.InvariantCultureIgnoreCase)) + public CreateSponsorshipCommand(IOrganizationSponsorshipRepository organizationSponsorshipRepository, + IUserService userService) { - throw new BadRequestException("Cannot offer a Families Organization Sponsorship to yourself. Choose a different email."); + _organizationSponsorshipRepository = organizationSponsorshipRepository; + _userService = userService; } - var requiredSponsoringProductType = StaticStore.GetSponsoredPlan(sponsorshipType)?.SponsoringProductType; - if (requiredSponsoringProductType == null || - sponsoringOrg == null || - StaticStore.GetPlan(sponsoringOrg.PlanType).Product != requiredSponsoringProductType.Value) + public async Task CreateSponsorshipAsync(Organization sponsoringOrg, OrganizationUser sponsoringOrgUser, + PlanSponsorshipType sponsorshipType, string sponsoredEmail, string friendlyName) { - throw new BadRequestException("Specified Organization cannot sponsor other organizations."); - } - - if (sponsoringOrgUser == null || sponsoringOrgUser.Status != OrganizationUserStatusType.Confirmed) - { - throw new BadRequestException("Only confirmed users can sponsor other organizations."); - } - - var existingOrgSponsorship = await _organizationSponsorshipRepository - .GetBySponsoringOrganizationUserIdAsync(sponsoringOrgUser.Id); - if (existingOrgSponsorship?.SponsoredOrganizationId != null) - { - throw new BadRequestException("Can only sponsor one organization per Organization User."); - } - - var sponsorship = new OrganizationSponsorship - { - SponsoringOrganizationId = sponsoringOrg.Id, - SponsoringOrganizationUserId = sponsoringOrgUser.Id, - FriendlyName = friendlyName, - OfferedToEmail = sponsoredEmail, - PlanSponsorshipType = sponsorshipType, - }; - - if (existingOrgSponsorship != null) - { - // Replace existing invalid offer with our new sponsorship offer - sponsorship.Id = existingOrgSponsorship.Id; - } - - try - { - await _organizationSponsorshipRepository.UpsertAsync(sponsorship); - return sponsorship; - } - catch - { - if (sponsorship.Id != default) + var sponsoringUser = await _userService.GetUserByIdAsync(sponsoringOrgUser.UserId.Value); + if (sponsoringUser == null || string.Equals(sponsoringUser.Email, sponsoredEmail, System.StringComparison.InvariantCultureIgnoreCase)) { - await _organizationSponsorshipRepository.DeleteAsync(sponsorship); + throw new BadRequestException("Cannot offer a Families Organization Sponsorship to yourself. Choose a different email."); + } + + var requiredSponsoringProductType = StaticStore.GetSponsoredPlan(sponsorshipType)?.SponsoringProductType; + if (requiredSponsoringProductType == null || + sponsoringOrg == null || + StaticStore.GetPlan(sponsoringOrg.PlanType).Product != requiredSponsoringProductType.Value) + { + throw new BadRequestException("Specified Organization cannot sponsor other organizations."); + } + + if (sponsoringOrgUser == null || sponsoringOrgUser.Status != OrganizationUserStatusType.Confirmed) + { + throw new BadRequestException("Only confirmed users can sponsor other organizations."); + } + + var existingOrgSponsorship = await _organizationSponsorshipRepository + .GetBySponsoringOrganizationUserIdAsync(sponsoringOrgUser.Id); + if (existingOrgSponsorship?.SponsoredOrganizationId != null) + { + throw new BadRequestException("Can only sponsor one organization per Organization User."); + } + + var sponsorship = new OrganizationSponsorship + { + SponsoringOrganizationId = sponsoringOrg.Id, + SponsoringOrganizationUserId = sponsoringOrgUser.Id, + FriendlyName = friendlyName, + OfferedToEmail = sponsoredEmail, + PlanSponsorshipType = sponsorshipType, + }; + + if (existingOrgSponsorship != null) + { + // Replace existing invalid offer with our new sponsorship offer + sponsorship.Id = existingOrgSponsorship.Id; + } + + try + { + await _organizationSponsorshipRepository.UpsertAsync(sponsorship); + return sponsorship; + } + catch + { + if (sponsorship.Id != default) + { + await _organizationSponsorshipRepository.DeleteAsync(sponsorship); + } + throw; } - throw; } } } diff --git a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Interfaces/ICreateSponsorshipCommand.cs b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Interfaces/ICreateSponsorshipCommand.cs index 1ba4b36628..c321524e7e 100644 --- a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Interfaces/ICreateSponsorshipCommand.cs +++ b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Interfaces/ICreateSponsorshipCommand.cs @@ -1,10 +1,11 @@ using Bit.Core.Entities; using Bit.Core.Enums; -namespace Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces; - -public interface ICreateSponsorshipCommand +namespace Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces { - Task CreateSponsorshipAsync(Organization sponsoringOrg, OrganizationUser sponsoringOrgUser, - PlanSponsorshipType sponsorshipType, string sponsoredEmail, string friendlyName); + public interface ICreateSponsorshipCommand + { + Task CreateSponsorshipAsync(Organization sponsoringOrg, OrganizationUser sponsoringOrgUser, + PlanSponsorshipType sponsorshipType, string sponsoredEmail, string friendlyName); + } } diff --git a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Interfaces/IOrganizationSponsorshipRenewCommand.cs b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Interfaces/IOrganizationSponsorshipRenewCommand.cs index 9d04c280d0..762d166bd7 100644 --- a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Interfaces/IOrganizationSponsorshipRenewCommand.cs +++ b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Interfaces/IOrganizationSponsorshipRenewCommand.cs @@ -1,6 +1,7 @@ -namespace Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces; - -public interface IOrganizationSponsorshipRenewCommand +namespace Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces { - Task UpdateExpirationDateAsync(Guid organizationId, DateTime expireDate); + public interface IOrganizationSponsorshipRenewCommand + { + Task UpdateExpirationDateAsync(Guid organizationId, DateTime expireDate); + } } diff --git a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Interfaces/IRemoveSponsorshipCommand.cs b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Interfaces/IRemoveSponsorshipCommand.cs index a37e6cee90..21a8fec89a 100644 --- a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Interfaces/IRemoveSponsorshipCommand.cs +++ b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Interfaces/IRemoveSponsorshipCommand.cs @@ -1,8 +1,9 @@ using Bit.Core.Entities; -namespace Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces; - -public interface IRemoveSponsorshipCommand +namespace Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces { - Task RemoveSponsorshipAsync(OrganizationSponsorship sponsorship); + public interface IRemoveSponsorshipCommand + { + Task RemoveSponsorshipAsync(OrganizationSponsorship sponsorship); + } } diff --git a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Interfaces/IRevokeSponsorshipCommand.cs b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Interfaces/IRevokeSponsorshipCommand.cs index 48a4964944..18ca25d178 100644 --- a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Interfaces/IRevokeSponsorshipCommand.cs +++ b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Interfaces/IRevokeSponsorshipCommand.cs @@ -1,8 +1,9 @@ using Bit.Core.Entities; -namespace Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces; - -public interface IRevokeSponsorshipCommand +namespace Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces { - Task RevokeSponsorshipAsync(OrganizationSponsorship sponsorship); + public interface IRevokeSponsorshipCommand + { + Task RevokeSponsorshipAsync(OrganizationSponsorship sponsorship); + } } diff --git a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Interfaces/ISendSponsorshipOfferCommand.cs b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Interfaces/ISendSponsorshipOfferCommand.cs index 9795ed00f2..a047c4d668 100644 --- a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Interfaces/ISendSponsorshipOfferCommand.cs +++ b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Interfaces/ISendSponsorshipOfferCommand.cs @@ -1,11 +1,12 @@ using Bit.Core.Entities; -namespace Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces; - -public interface ISendSponsorshipOfferCommand +namespace Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces { - Task BulkSendSponsorshipOfferAsync(string sponsoringOrgName, IEnumerable invites); - Task SendSponsorshipOfferAsync(OrganizationSponsorship sponsorship, string sponsoringOrgName); - Task SendSponsorshipOfferAsync(Organization sponsoringOrg, OrganizationUser sponsoringOrgUser, - OrganizationSponsorship sponsorship); + public interface ISendSponsorshipOfferCommand + { + Task BulkSendSponsorshipOfferAsync(string sponsoringOrgName, IEnumerable invites); + Task SendSponsorshipOfferAsync(OrganizationSponsorship sponsorship, string sponsoringOrgName); + Task SendSponsorshipOfferAsync(Organization sponsoringOrg, OrganizationUser sponsoringOrgUser, + OrganizationSponsorship sponsorship); + } } diff --git a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Interfaces/ISetUpSponsorshipCommand.cs b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Interfaces/ISetUpSponsorshipCommand.cs index 4c57c90728..d4c5e9b0e3 100644 --- a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Interfaces/ISetUpSponsorshipCommand.cs +++ b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Interfaces/ISetUpSponsorshipCommand.cs @@ -1,9 +1,10 @@ using Bit.Core.Entities; -namespace Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces; - -public interface ISetUpSponsorshipCommand +namespace Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces { - Task SetUpSponsorshipAsync(OrganizationSponsorship sponsorship, - Organization sponsoredOrganization); + public interface ISetUpSponsorshipCommand + { + Task SetUpSponsorshipAsync(OrganizationSponsorship sponsorship, + Organization sponsoredOrganization); + } } diff --git a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Interfaces/ISyncOrganizationSponsorshipsCommand.cs b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Interfaces/ISyncOrganizationSponsorshipsCommand.cs index 0b8bb6444d..9533e4bfd9 100644 --- a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Interfaces/ISyncOrganizationSponsorshipsCommand.cs +++ b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Interfaces/ISyncOrganizationSponsorshipsCommand.cs @@ -1,14 +1,15 @@ using Bit.Core.Entities; using Bit.Core.Models.Data.Organizations.OrganizationSponsorships; -namespace Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces; - -public interface ISelfHostedSyncSponsorshipsCommand +namespace Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces { - Task SyncOrganization(Guid organizationId, Guid cloudOrganizationId, OrganizationConnection billingSyncConnection); -} + public interface ISelfHostedSyncSponsorshipsCommand + { + Task SyncOrganization(Guid organizationId, Guid cloudOrganizationId, OrganizationConnection billingSyncConnection); + } -public interface ICloudSyncSponsorshipsCommand -{ - Task<(OrganizationSponsorshipSyncData, IEnumerable)> SyncOrganization(Organization sponsoringOrg, IEnumerable sponsorshipsData); + public interface ICloudSyncSponsorshipsCommand + { + Task<(OrganizationSponsorshipSyncData, IEnumerable)> SyncOrganization(Organization sponsoringOrg, IEnumerable sponsorshipsData); + } } diff --git a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Interfaces/IValidateBillingSyncKeyCommand.cs b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Interfaces/IValidateBillingSyncKeyCommand.cs index 53e926903f..1ac3e1c0d7 100644 --- a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Interfaces/IValidateBillingSyncKeyCommand.cs +++ b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Interfaces/IValidateBillingSyncKeyCommand.cs @@ -1,8 +1,9 @@ using Bit.Core.Entities; -namespace Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces; - -public interface IValidateBillingSyncKeyCommand +namespace Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces { - Task ValidateBillingSyncKeyAsync(Organization organization, string billingSyncKey); + public interface IValidateBillingSyncKeyCommand + { + Task ValidateBillingSyncKeyAsync(Organization organization, string billingSyncKey); + } } diff --git a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Interfaces/IValidateRedemptionTokenCommand.cs b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Interfaces/IValidateRedemptionTokenCommand.cs index 714e9e2b52..a7db2ed2ee 100644 --- a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Interfaces/IValidateRedemptionTokenCommand.cs +++ b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Interfaces/IValidateRedemptionTokenCommand.cs @@ -1,8 +1,9 @@ using Bit.Core.Entities; -namespace Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces; - -public interface IValidateRedemptionTokenCommand +namespace Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces { - Task<(bool valid, OrganizationSponsorship sponsorship)> ValidateRedemptionTokenAsync(string encryptedToken, string sponsoredUserEmail); + public interface IValidateRedemptionTokenCommand + { + Task<(bool valid, OrganizationSponsorship sponsorship)> ValidateRedemptionTokenAsync(string encryptedToken, string sponsoredUserEmail); + } } diff --git a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Interfaces/IValidateSponsorshipCommand.cs b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Interfaces/IValidateSponsorshipCommand.cs index 47b2e47c22..0d01246778 100644 --- a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Interfaces/IValidateSponsorshipCommand.cs +++ b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Interfaces/IValidateSponsorshipCommand.cs @@ -1,6 +1,7 @@ -namespace Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces; - -public interface IValidateSponsorshipCommand +namespace Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces { - Task ValidateSponsorshipAsync(Guid sponsoredOrganizationId); + public interface IValidateSponsorshipCommand + { + Task ValidateSponsorshipAsync(Guid sponsoredOrganizationId); + } } diff --git a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/SelfHosted/SelfHostedRevokeSponsorshipCommand.cs b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/SelfHosted/SelfHostedRevokeSponsorshipCommand.cs index 820d277581..aad92f43cb 100644 --- a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/SelfHosted/SelfHostedRevokeSponsorshipCommand.cs +++ b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/SelfHosted/SelfHostedRevokeSponsorshipCommand.cs @@ -3,30 +3,31 @@ using Bit.Core.Exceptions; using Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces; using Bit.Core.Repositories; -namespace Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.SelfHosted; - -public class SelfHostedRevokeSponsorshipCommand : CancelSponsorshipCommand, IRevokeSponsorshipCommand +namespace Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.SelfHosted { - public SelfHostedRevokeSponsorshipCommand( - IOrganizationSponsorshipRepository organizationSponsorshipRepository, - IOrganizationRepository organizationRepository) : base(organizationSponsorshipRepository, organizationRepository) + public class SelfHostedRevokeSponsorshipCommand : CancelSponsorshipCommand, IRevokeSponsorshipCommand { - } - - public async Task RevokeSponsorshipAsync(OrganizationSponsorship sponsorship) - { - if (sponsorship == null) + public SelfHostedRevokeSponsorshipCommand( + IOrganizationSponsorshipRepository organizationSponsorshipRepository, + IOrganizationRepository organizationRepository) : base(organizationSponsorshipRepository, organizationRepository) { - throw new BadRequestException("You are not currently sponsoring an organization."); } - if (sponsorship.LastSyncDate == null) + public async Task RevokeSponsorshipAsync(OrganizationSponsorship sponsorship) { - await base.DeleteSponsorshipAsync(sponsorship); - } - else - { - await MarkToDeleteSponsorshipAsync(sponsorship); + if (sponsorship == null) + { + throw new BadRequestException("You are not currently sponsoring an organization."); + } + + if (sponsorship.LastSyncDate == null) + { + await base.DeleteSponsorshipAsync(sponsorship); + } + else + { + await MarkToDeleteSponsorshipAsync(sponsorship); + } } } } diff --git a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/SelfHosted/SelfHostedSyncSponsorshipsCommand.cs b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/SelfHosted/SelfHostedSyncSponsorshipsCommand.cs index df293c3a7b..4f12c1cf33 100644 --- a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/SelfHosted/SelfHostedSyncSponsorshipsCommand.cs +++ b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/SelfHosted/SelfHostedSyncSponsorshipsCommand.cs @@ -11,120 +11,121 @@ using Bit.Core.Settings; using Bit.Core.Utilities; using Microsoft.Extensions.Logging; -namespace Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.SelfHosted; - -public class SelfHostedSyncSponsorshipsCommand : BaseIdentityClientService, ISelfHostedSyncSponsorshipsCommand +namespace Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.SelfHosted { - private readonly IGlobalSettings _globalSettings; - private readonly IOrganizationSponsorshipRepository _organizationSponsorshipRepository; - private readonly IOrganizationUserRepository _organizationUserRepository; - private readonly IOrganizationConnectionRepository _organizationConnectionRepository; - - public SelfHostedSyncSponsorshipsCommand( - IHttpClientFactory httpFactory, - IOrganizationSponsorshipRepository organizationSponsorshipRepository, - IOrganizationUserRepository organizationUserRepository, - IOrganizationConnectionRepository organizationConnectionRepository, - IGlobalSettings globalSettings, - ILogger logger) - : base( - httpFactory, - globalSettings.Installation.ApiUri, - globalSettings.Installation.IdentityUri, - "api.installation", - $"installation.{globalSettings.Installation.Id}", - globalSettings.Installation.Key, - logger) + public class SelfHostedSyncSponsorshipsCommand : BaseIdentityClientService, ISelfHostedSyncSponsorshipsCommand { - _globalSettings = globalSettings; - _organizationUserRepository = organizationUserRepository; - _organizationSponsorshipRepository = organizationSponsorshipRepository; - _organizationConnectionRepository = organizationConnectionRepository; - } + private readonly IGlobalSettings _globalSettings; + private readonly IOrganizationSponsorshipRepository _organizationSponsorshipRepository; + private readonly IOrganizationUserRepository _organizationUserRepository; + private readonly IOrganizationConnectionRepository _organizationConnectionRepository; - public async Task SyncOrganization(Guid organizationId, Guid cloudOrganizationId, OrganizationConnection billingSyncConnection) - { - if (!_globalSettings.EnableCloudCommunication) + public SelfHostedSyncSponsorshipsCommand( + IHttpClientFactory httpFactory, + IOrganizationSponsorshipRepository organizationSponsorshipRepository, + IOrganizationUserRepository organizationUserRepository, + IOrganizationConnectionRepository organizationConnectionRepository, + IGlobalSettings globalSettings, + ILogger logger) + : base( + httpFactory, + globalSettings.Installation.ApiUri, + globalSettings.Installation.IdentityUri, + "api.installation", + $"installation.{globalSettings.Installation.Id}", + globalSettings.Installation.Key, + logger) { - throw new BadRequestException("Failed to sync instance with cloud - Cloud communication is disabled in global settings"); - } - if (!billingSyncConnection.Enabled) - { - throw new BadRequestException($"Billing Sync Key disabled for organization {organizationId}"); - } - if (string.IsNullOrWhiteSpace(billingSyncConnection.Config)) - { - throw new BadRequestException($"No Billing Sync Key known for organization {organizationId}"); - } - var billingSyncConfig = billingSyncConnection.GetConfig(); - if (billingSyncConfig == null || string.IsNullOrWhiteSpace(billingSyncConfig.BillingSyncKey)) - { - throw new BadRequestException($"Failed to get Billing Sync Key for organization {organizationId}"); + _globalSettings = globalSettings; + _organizationUserRepository = organizationUserRepository; + _organizationSponsorshipRepository = organizationSponsorshipRepository; + _organizationConnectionRepository = organizationConnectionRepository; } - var organizationSponsorshipsDict = (await _organizationSponsorshipRepository.GetManyBySponsoringOrganizationAsync(organizationId)) - .ToDictionary(i => i.SponsoringOrganizationUserId); - if (!organizationSponsorshipsDict.Any()) + public async Task SyncOrganization(Guid organizationId, Guid cloudOrganizationId, OrganizationConnection billingSyncConnection) { - _logger.LogInformation($"No existing sponsorships to sync for organization {organizationId}"); - return; - } - var syncedSponsorships = new List(); - - foreach (var orgSponsorshipsBatch in CoreHelpers.Batch(organizationSponsorshipsDict.Values, 1000)) - { - var response = await SendAsync(HttpMethod.Post, "organization/sponsorship/sync", new OrganizationSponsorshipSyncRequestModel + if (!_globalSettings.EnableCloudCommunication) { - BillingSyncKey = billingSyncConfig.BillingSyncKey, - SponsoringOrganizationCloudId = cloudOrganizationId, - SponsorshipsBatch = orgSponsorshipsBatch.Select(s => new OrganizationSponsorshipRequestModel(s)) + throw new BadRequestException("Failed to sync instance with cloud - Cloud communication is disabled in global settings"); + } + if (!billingSyncConnection.Enabled) + { + throw new BadRequestException($"Billing Sync Key disabled for organization {organizationId}"); + } + if (string.IsNullOrWhiteSpace(billingSyncConnection.Config)) + { + throw new BadRequestException($"No Billing Sync Key known for organization {organizationId}"); + } + var billingSyncConfig = billingSyncConnection.GetConfig(); + if (billingSyncConfig == null || string.IsNullOrWhiteSpace(billingSyncConfig.BillingSyncKey)) + { + throw new BadRequestException($"Failed to get Billing Sync Key for organization {organizationId}"); + } + + var organizationSponsorshipsDict = (await _organizationSponsorshipRepository.GetManyBySponsoringOrganizationAsync(organizationId)) + .ToDictionary(i => i.SponsoringOrganizationUserId); + if (!organizationSponsorshipsDict.Any()) + { + _logger.LogInformation($"No existing sponsorships to sync for organization {organizationId}"); + return; + } + var syncedSponsorships = new List(); + + foreach (var orgSponsorshipsBatch in CoreHelpers.Batch(organizationSponsorshipsDict.Values, 1000)) + { + var response = await SendAsync(HttpMethod.Post, "organization/sponsorship/sync", new OrganizationSponsorshipSyncRequestModel + { + BillingSyncKey = billingSyncConfig.BillingSyncKey, + SponsoringOrganizationCloudId = cloudOrganizationId, + SponsorshipsBatch = orgSponsorshipsBatch.Select(s => new OrganizationSponsorshipRequestModel(s)) + }); + + if (response == null) + { + _logger.LogDebug("Organization sync failed for '{OrgId}'", organizationId); + throw new BadRequestException("Organization sync failed"); + } + + syncedSponsorships.AddRange(response.ToOrganizationSponsorshipSync().SponsorshipsBatch); + } + + var sponsorshipsToDelete = syncedSponsorships.Where(s => s.CloudSponsorshipRemoved).Select(i => organizationSponsorshipsDict[i.SponsoringOrganizationUserId].Id); + var sponsorshipsToUpsert = syncedSponsorships.Where(s => !s.CloudSponsorshipRemoved).Select(i => + { + var existingSponsorship = organizationSponsorshipsDict[i.SponsoringOrganizationUserId]; + if (existingSponsorship != null) + { + existingSponsorship.LastSyncDate = i.LastSyncDate; + existingSponsorship.ValidUntil = i.ValidUntil; + existingSponsorship.ToDelete = i.ToDelete; + } + else + { + // shouldn't occur, added in case self hosted loses a sponsorship + existingSponsorship = new OrganizationSponsorship + { + SponsoringOrganizationId = organizationId, + SponsoringOrganizationUserId = i.SponsoringOrganizationUserId, + FriendlyName = i.FriendlyName, + OfferedToEmail = i.OfferedToEmail, + PlanSponsorshipType = i.PlanSponsorshipType, + LastSyncDate = i.LastSyncDate, + ValidUntil = i.ValidUntil, + ToDelete = i.ToDelete + }; + } + return existingSponsorship; }); - if (response == null) + if (sponsorshipsToDelete.Any()) { - _logger.LogDebug("Organization sync failed for '{OrgId}'", organizationId); - throw new BadRequestException("Organization sync failed"); + await _organizationSponsorshipRepository.DeleteManyAsync(sponsorshipsToDelete); + } + if (sponsorshipsToUpsert.Any()) + { + await _organizationSponsorshipRepository.UpsertManyAsync(sponsorshipsToUpsert); } - - syncedSponsorships.AddRange(response.ToOrganizationSponsorshipSync().SponsorshipsBatch); } - var sponsorshipsToDelete = syncedSponsorships.Where(s => s.CloudSponsorshipRemoved).Select(i => organizationSponsorshipsDict[i.SponsoringOrganizationUserId].Id); - var sponsorshipsToUpsert = syncedSponsorships.Where(s => !s.CloudSponsorshipRemoved).Select(i => - { - var existingSponsorship = organizationSponsorshipsDict[i.SponsoringOrganizationUserId]; - if (existingSponsorship != null) - { - existingSponsorship.LastSyncDate = i.LastSyncDate; - existingSponsorship.ValidUntil = i.ValidUntil; - existingSponsorship.ToDelete = i.ToDelete; - } - else - { - // shouldn't occur, added in case self hosted loses a sponsorship - existingSponsorship = new OrganizationSponsorship - { - SponsoringOrganizationId = organizationId, - SponsoringOrganizationUserId = i.SponsoringOrganizationUserId, - FriendlyName = i.FriendlyName, - OfferedToEmail = i.OfferedToEmail, - PlanSponsorshipType = i.PlanSponsorshipType, - LastSyncDate = i.LastSyncDate, - ValidUntil = i.ValidUntil, - ToDelete = i.ToDelete - }; - } - return existingSponsorship; - }); - - if (sponsorshipsToDelete.Any()) - { - await _organizationSponsorshipRepository.DeleteManyAsync(sponsorshipsToDelete); - } - if (sponsorshipsToUpsert.Any()) - { - await _organizationSponsorshipRepository.UpsertManyAsync(sponsorshipsToUpsert); - } } - } diff --git a/src/Core/Repositories/ICipherRepository.cs b/src/Core/Repositories/ICipherRepository.cs index 56f7619350..5e071a55f6 100644 --- a/src/Core/Repositories/ICipherRepository.cs +++ b/src/Core/Repositories/ICipherRepository.cs @@ -2,37 +2,38 @@ using Bit.Core.Models.Data; using Core.Models.Data; -namespace Bit.Core.Repositories; - -public interface ICipherRepository : IRepository +namespace Bit.Core.Repositories { - Task GetByIdAsync(Guid id, Guid userId); - Task GetOrganizationDetailsByIdAsync(Guid id); - Task> GetManyOrganizationDetailsByOrganizationIdAsync(Guid organizationId); - Task GetCanEditByIdAsync(Guid userId, Guid cipherId); - Task> GetManyByUserIdAsync(Guid userId, bool withOrganizations = true); - Task> GetManyByOrganizationIdAsync(Guid organizationId); - Task CreateAsync(Cipher cipher, IEnumerable collectionIds); - Task CreateAsync(CipherDetails cipher); - Task CreateAsync(CipherDetails cipher, IEnumerable collectionIds); - Task ReplaceAsync(CipherDetails cipher); - Task UpsertAsync(CipherDetails cipher); - Task ReplaceAsync(Cipher obj, IEnumerable collectionIds); - Task UpdatePartialAsync(Guid id, Guid userId, Guid? folderId, bool favorite); - Task UpdateAttachmentAsync(CipherAttachment attachment); - Task DeleteAttachmentAsync(Guid cipherId, string attachmentId); - Task DeleteAsync(IEnumerable ids, Guid userId); - Task DeleteByIdsOrganizationIdAsync(IEnumerable ids, Guid organizationId); - Task MoveAsync(IEnumerable ids, Guid? folderId, Guid userId); - Task DeleteByUserIdAsync(Guid userId); - Task DeleteByOrganizationIdAsync(Guid organizationId); - Task UpdateUserKeysAndCiphersAsync(User user, IEnumerable ciphers, IEnumerable folders, IEnumerable sends); - Task UpdateCiphersAsync(Guid userId, IEnumerable ciphers); - Task CreateAsync(IEnumerable ciphers, IEnumerable folders); - Task CreateAsync(IEnumerable ciphers, IEnumerable collections, - IEnumerable collectionCiphers); - Task SoftDeleteAsync(IEnumerable ids, Guid userId); - Task SoftDeleteByIdsOrganizationIdAsync(IEnumerable ids, Guid organizationId); - Task RestoreAsync(IEnumerable ids, Guid userId); - Task DeleteDeletedAsync(DateTime deletedDateBefore); + public interface ICipherRepository : IRepository + { + Task GetByIdAsync(Guid id, Guid userId); + Task GetOrganizationDetailsByIdAsync(Guid id); + Task> GetManyOrganizationDetailsByOrganizationIdAsync(Guid organizationId); + Task GetCanEditByIdAsync(Guid userId, Guid cipherId); + Task> GetManyByUserIdAsync(Guid userId, bool withOrganizations = true); + Task> GetManyByOrganizationIdAsync(Guid organizationId); + Task CreateAsync(Cipher cipher, IEnumerable collectionIds); + Task CreateAsync(CipherDetails cipher); + Task CreateAsync(CipherDetails cipher, IEnumerable collectionIds); + Task ReplaceAsync(CipherDetails cipher); + Task UpsertAsync(CipherDetails cipher); + Task ReplaceAsync(Cipher obj, IEnumerable collectionIds); + Task UpdatePartialAsync(Guid id, Guid userId, Guid? folderId, bool favorite); + Task UpdateAttachmentAsync(CipherAttachment attachment); + Task DeleteAttachmentAsync(Guid cipherId, string attachmentId); + Task DeleteAsync(IEnumerable ids, Guid userId); + Task DeleteByIdsOrganizationIdAsync(IEnumerable ids, Guid organizationId); + Task MoveAsync(IEnumerable ids, Guid? folderId, Guid userId); + Task DeleteByUserIdAsync(Guid userId); + Task DeleteByOrganizationIdAsync(Guid organizationId); + Task UpdateUserKeysAndCiphersAsync(User user, IEnumerable ciphers, IEnumerable folders, IEnumerable sends); + Task UpdateCiphersAsync(Guid userId, IEnumerable ciphers); + Task CreateAsync(IEnumerable ciphers, IEnumerable folders); + Task CreateAsync(IEnumerable ciphers, IEnumerable collections, + IEnumerable collectionCiphers); + Task SoftDeleteAsync(IEnumerable ids, Guid userId); + Task SoftDeleteByIdsOrganizationIdAsync(IEnumerable ids, Guid organizationId); + Task RestoreAsync(IEnumerable ids, Guid userId); + Task DeleteDeletedAsync(DateTime deletedDateBefore); + } } diff --git a/src/Core/Repositories/ICollectionCipherRepository.cs b/src/Core/Repositories/ICollectionCipherRepository.cs index 2721288100..b79c65737f 100644 --- a/src/Core/Repositories/ICollectionCipherRepository.cs +++ b/src/Core/Repositories/ICollectionCipherRepository.cs @@ -1,14 +1,15 @@ using Bit.Core.Entities; -namespace Bit.Core.Repositories; - -public interface ICollectionCipherRepository +namespace Bit.Core.Repositories { - Task> GetManyByUserIdAsync(Guid userId); - Task> GetManyByOrganizationIdAsync(Guid organizationId); - Task> GetManyByUserIdCipherIdAsync(Guid userId, Guid cipherId); - Task UpdateCollectionsAsync(Guid cipherId, Guid userId, IEnumerable collectionIds); - Task UpdateCollectionsForAdminAsync(Guid cipherId, Guid organizationId, IEnumerable collectionIds); - Task UpdateCollectionsForCiphersAsync(IEnumerable cipherIds, Guid userId, Guid organizationId, - IEnumerable collectionIds); + public interface ICollectionCipherRepository + { + Task> GetManyByUserIdAsync(Guid userId); + Task> GetManyByOrganizationIdAsync(Guid organizationId); + Task> GetManyByUserIdCipherIdAsync(Guid userId, Guid cipherId); + Task UpdateCollectionsAsync(Guid cipherId, Guid userId, IEnumerable collectionIds); + Task UpdateCollectionsForAdminAsync(Guid cipherId, Guid organizationId, IEnumerable collectionIds); + Task UpdateCollectionsForCiphersAsync(IEnumerable cipherIds, Guid userId, Guid organizationId, + IEnumerable collectionIds); + } } diff --git a/src/Core/Repositories/ICollectionRepository.cs b/src/Core/Repositories/ICollectionRepository.cs index dda042aa89..e533799971 100644 --- a/src/Core/Repositories/ICollectionRepository.cs +++ b/src/Core/Repositories/ICollectionRepository.cs @@ -1,19 +1,20 @@ using Bit.Core.Entities; using Bit.Core.Models.Data; -namespace Bit.Core.Repositories; - -public interface ICollectionRepository : IRepository +namespace Bit.Core.Repositories { - Task GetCountByOrganizationIdAsync(Guid organizationId); - Task>> GetByIdWithGroupsAsync(Guid id); - Task>> GetByIdWithGroupsAsync(Guid id, Guid userId); - Task> GetManyByOrganizationIdAsync(Guid organizationId); - Task GetByIdAsync(Guid id, Guid userId); - Task> GetManyByUserIdAsync(Guid userId); - Task CreateAsync(Collection obj, IEnumerable groups); - Task ReplaceAsync(Collection obj, IEnumerable groups); - Task DeleteUserAsync(Guid collectionId, Guid organizationUserId); - Task UpdateUsersAsync(Guid id, IEnumerable users); - Task> GetManyUsersByIdAsync(Guid id); + public interface ICollectionRepository : IRepository + { + Task GetCountByOrganizationIdAsync(Guid organizationId); + Task>> GetByIdWithGroupsAsync(Guid id); + Task>> GetByIdWithGroupsAsync(Guid id, Guid userId); + Task> GetManyByOrganizationIdAsync(Guid organizationId); + Task GetByIdAsync(Guid id, Guid userId); + Task> GetManyByUserIdAsync(Guid userId); + Task CreateAsync(Collection obj, IEnumerable groups); + Task ReplaceAsync(Collection obj, IEnumerable groups); + Task DeleteUserAsync(Guid collectionId, Guid organizationUserId); + Task UpdateUsersAsync(Guid id, IEnumerable users); + Task> GetManyUsersByIdAsync(Guid id); + } } diff --git a/src/Core/Repositories/IDeviceRepository.cs b/src/Core/Repositories/IDeviceRepository.cs index 5424d5fe36..85221d446b 100644 --- a/src/Core/Repositories/IDeviceRepository.cs +++ b/src/Core/Repositories/IDeviceRepository.cs @@ -1,12 +1,13 @@ using Bit.Core.Entities; -namespace Bit.Core.Repositories; - -public interface IDeviceRepository : IRepository +namespace Bit.Core.Repositories { - Task GetByIdAsync(Guid id, Guid userId); - Task GetByIdentifierAsync(string identifier); - Task GetByIdentifierAsync(string identifier, Guid userId); - Task> GetManyByUserIdAsync(Guid userId); - Task ClearPushTokenAsync(Guid id); + public interface IDeviceRepository : IRepository + { + Task GetByIdAsync(Guid id, Guid userId); + Task GetByIdentifierAsync(string identifier); + Task GetByIdentifierAsync(string identifier, Guid userId); + Task> GetManyByUserIdAsync(Guid userId); + Task ClearPushTokenAsync(Guid id); + } } diff --git a/src/Core/Repositories/IEmergencyAccessRepository.cs b/src/Core/Repositories/IEmergencyAccessRepository.cs index 790f7191c5..449bfe631c 100644 --- a/src/Core/Repositories/IEmergencyAccessRepository.cs +++ b/src/Core/Repositories/IEmergencyAccessRepository.cs @@ -1,14 +1,15 @@ using Bit.Core.Entities; using Bit.Core.Models.Data; -namespace Bit.Core.Repositories; - -public interface IEmergencyAccessRepository : IRepository +namespace Bit.Core.Repositories { - Task GetCountByGrantorIdEmailAsync(Guid grantorId, string email, bool onlyRegisteredUsers); - Task> GetManyDetailsByGrantorIdAsync(Guid grantorId); - Task> GetManyDetailsByGranteeIdAsync(Guid granteeId); - Task GetDetailsByIdGrantorIdAsync(Guid id, Guid grantorId); - Task> GetManyToNotifyAsync(); - Task> GetExpiredRecoveriesAsync(); + public interface IEmergencyAccessRepository : IRepository + { + Task GetCountByGrantorIdEmailAsync(Guid grantorId, string email, bool onlyRegisteredUsers); + Task> GetManyDetailsByGrantorIdAsync(Guid grantorId); + Task> GetManyDetailsByGranteeIdAsync(Guid granteeId); + Task GetDetailsByIdGrantorIdAsync(Guid id, Guid grantorId); + Task> GetManyToNotifyAsync(); + Task> GetExpiredRecoveriesAsync(); + } } diff --git a/src/Core/Repositories/IEventRepository.cs b/src/Core/Repositories/IEventRepository.cs index bac3cb5345..c2af5c0e0f 100644 --- a/src/Core/Repositories/IEventRepository.cs +++ b/src/Core/Repositories/IEventRepository.cs @@ -1,22 +1,23 @@ using Bit.Core.Entities; using Bit.Core.Models.Data; -namespace Bit.Core.Repositories; - -public interface IEventRepository +namespace Bit.Core.Repositories { - Task> GetManyByUserAsync(Guid userId, DateTime startDate, DateTime endDate, - PageOptions pageOptions); - Task> GetManyByOrganizationAsync(Guid organizationId, DateTime startDate, DateTime endDate, - PageOptions pageOptions); - Task> GetManyByOrganizationActingUserAsync(Guid organizationId, Guid actingUserId, - DateTime startDate, DateTime endDate, PageOptions pageOptions); - Task> GetManyByProviderAsync(Guid providerId, DateTime startDate, DateTime endDate, - PageOptions pageOptions); - Task> GetManyByProviderActingUserAsync(Guid providerId, Guid actingUserId, - DateTime startDate, DateTime endDate, PageOptions pageOptions); - Task> GetManyByCipherAsync(Cipher cipher, DateTime startDate, DateTime endDate, - PageOptions pageOptions); - Task CreateAsync(IEvent e); - Task CreateManyAsync(IEnumerable e); + public interface IEventRepository + { + Task> GetManyByUserAsync(Guid userId, DateTime startDate, DateTime endDate, + PageOptions pageOptions); + Task> GetManyByOrganizationAsync(Guid organizationId, DateTime startDate, DateTime endDate, + PageOptions pageOptions); + Task> GetManyByOrganizationActingUserAsync(Guid organizationId, Guid actingUserId, + DateTime startDate, DateTime endDate, PageOptions pageOptions); + Task> GetManyByProviderAsync(Guid providerId, DateTime startDate, DateTime endDate, + PageOptions pageOptions); + Task> GetManyByProviderActingUserAsync(Guid providerId, Guid actingUserId, + DateTime startDate, DateTime endDate, PageOptions pageOptions); + Task> GetManyByCipherAsync(Cipher cipher, DateTime startDate, DateTime endDate, + PageOptions pageOptions); + Task CreateAsync(IEvent e); + Task CreateManyAsync(IEnumerable e); + } } diff --git a/src/Core/Repositories/IFolderRepository.cs b/src/Core/Repositories/IFolderRepository.cs index b93ca097b6..c174f4fb12 100644 --- a/src/Core/Repositories/IFolderRepository.cs +++ b/src/Core/Repositories/IFolderRepository.cs @@ -1,9 +1,10 @@ using Bit.Core.Entities; -namespace Bit.Core.Repositories; - -public interface IFolderRepository : IRepository +namespace Bit.Core.Repositories { - Task GetByIdAsync(Guid id, Guid userId); - Task> GetManyByUserIdAsync(Guid userId); + public interface IFolderRepository : IRepository + { + Task GetByIdAsync(Guid id, Guid userId); + Task> GetManyByUserIdAsync(Guid userId); + } } diff --git a/src/Core/Repositories/IGrantRepository.cs b/src/Core/Repositories/IGrantRepository.cs index 14f4fcb03b..edab4c8153 100644 --- a/src/Core/Repositories/IGrantRepository.cs +++ b/src/Core/Repositories/IGrantRepository.cs @@ -1,12 +1,13 @@ using Bit.Core.Entities; -namespace Bit.Core.Repositories; - -public interface IGrantRepository +namespace Bit.Core.Repositories { - Task GetByKeyAsync(string key); - Task> GetManyAsync(string subjectId, string sessionId, string clientId, string type); - Task SaveAsync(Grant obj); - Task DeleteByKeyAsync(string key); - Task DeleteManyAsync(string subjectId, string sessionId, string clientId, string type); + public interface IGrantRepository + { + Task GetByKeyAsync(string key); + Task> GetManyAsync(string subjectId, string sessionId, string clientId, string type); + Task SaveAsync(Grant obj); + Task DeleteByKeyAsync(string key); + Task DeleteManyAsync(string subjectId, string sessionId, string clientId, string type); + } } diff --git a/src/Core/Repositories/IGroupRepository.cs b/src/Core/Repositories/IGroupRepository.cs index d7b9b664d7..e8cdc43bca 100644 --- a/src/Core/Repositories/IGroupRepository.cs +++ b/src/Core/Repositories/IGroupRepository.cs @@ -1,17 +1,18 @@ using Bit.Core.Entities; using Bit.Core.Models.Data; -namespace Bit.Core.Repositories; - -public interface IGroupRepository : IRepository +namespace Bit.Core.Repositories { - Task>> GetByIdWithCollectionsAsync(Guid id); - Task> GetManyByOrganizationIdAsync(Guid organizationId); - Task> GetManyIdsByUserIdAsync(Guid organizationUserId); - Task> GetManyUserIdsByIdAsync(Guid id); - Task> GetManyGroupUsersByOrganizationIdAsync(Guid organizationId); - Task CreateAsync(Group obj, IEnumerable collections); - Task ReplaceAsync(Group obj, IEnumerable collections); - Task DeleteUserAsync(Guid groupId, Guid organizationUserId); - Task UpdateUsersAsync(Guid groupId, IEnumerable organizationUserIds); + public interface IGroupRepository : IRepository + { + Task>> GetByIdWithCollectionsAsync(Guid id); + Task> GetManyByOrganizationIdAsync(Guid organizationId); + Task> GetManyIdsByUserIdAsync(Guid organizationUserId); + Task> GetManyUserIdsByIdAsync(Guid id); + Task> GetManyGroupUsersByOrganizationIdAsync(Guid organizationId); + Task CreateAsync(Group obj, IEnumerable collections); + Task ReplaceAsync(Group obj, IEnumerable collections); + Task DeleteUserAsync(Guid groupId, Guid organizationUserId); + Task UpdateUsersAsync(Guid groupId, IEnumerable organizationUserIds); + } } diff --git a/src/Core/Repositories/IInstallationDeviceRepository.cs b/src/Core/Repositories/IInstallationDeviceRepository.cs index bdbeaf2975..394b80837d 100644 --- a/src/Core/Repositories/IInstallationDeviceRepository.cs +++ b/src/Core/Repositories/IInstallationDeviceRepository.cs @@ -1,10 +1,11 @@ using Bit.Core.Models.Data; -namespace Bit.Core.Repositories; - -public interface IInstallationDeviceRepository +namespace Bit.Core.Repositories { - Task UpsertAsync(InstallationDeviceEntity entity); - Task UpsertManyAsync(IList entities); - Task DeleteAsync(InstallationDeviceEntity entity); + public interface IInstallationDeviceRepository + { + Task UpsertAsync(InstallationDeviceEntity entity); + Task UpsertManyAsync(IList entities); + Task DeleteAsync(InstallationDeviceEntity entity); + } } diff --git a/src/Core/Repositories/IInstallationRepository.cs b/src/Core/Repositories/IInstallationRepository.cs index 65ee34aafe..f88e81e5f1 100644 --- a/src/Core/Repositories/IInstallationRepository.cs +++ b/src/Core/Repositories/IInstallationRepository.cs @@ -1,7 +1,8 @@ using Bit.Core.Entities; -namespace Bit.Core.Repositories; - -public interface IInstallationRepository : IRepository +namespace Bit.Core.Repositories { + public interface IInstallationRepository : IRepository + { + } } diff --git a/src/Core/Repositories/IMaintenanceRepository.cs b/src/Core/Repositories/IMaintenanceRepository.cs index a89c38bd02..c1dc098c6b 100644 --- a/src/Core/Repositories/IMaintenanceRepository.cs +++ b/src/Core/Repositories/IMaintenanceRepository.cs @@ -1,10 +1,11 @@ -namespace Bit.Core.Repositories; - -public interface IMaintenanceRepository +namespace Bit.Core.Repositories { - Task UpdateStatisticsAsync(); - Task DisableCipherAutoStatsAsync(); - Task RebuildIndexesAsync(); - Task DeleteExpiredGrantsAsync(); - Task DeleteExpiredSponsorshipsAsync(DateTime validUntilBeforeDate); + public interface IMaintenanceRepository + { + Task UpdateStatisticsAsync(); + Task DisableCipherAutoStatsAsync(); + Task RebuildIndexesAsync(); + Task DeleteExpiredGrantsAsync(); + Task DeleteExpiredSponsorshipsAsync(DateTime validUntilBeforeDate); + } } diff --git a/src/Core/Repositories/IMetaDataRepository.cs b/src/Core/Repositories/IMetaDataRepository.cs index e087234da6..69895b9c85 100644 --- a/src/Core/Repositories/IMetaDataRepository.cs +++ b/src/Core/Repositories/IMetaDataRepository.cs @@ -1,10 +1,11 @@ -namespace Bit.Core.Repositories; - -public interface IMetaDataRepository +namespace Bit.Core.Repositories { - Task DeleteAsync(string objectName, string id); - Task> GetAsync(string objectName, string id); - Task GetAsync(string objectName, string id, string prop); - Task UpsertAsync(string objectName, string id, IDictionary dict); - Task UpsertAsync(string objectName, string id, KeyValuePair keyValuePair); + public interface IMetaDataRepository + { + Task DeleteAsync(string objectName, string id); + Task> GetAsync(string objectName, string id); + Task GetAsync(string objectName, string id, string prop); + Task UpsertAsync(string objectName, string id, IDictionary dict); + Task UpsertAsync(string objectName, string id, KeyValuePair keyValuePair); + } } diff --git a/src/Core/Repositories/IOrganizationApiKeyRepository.cs b/src/Core/Repositories/IOrganizationApiKeyRepository.cs index 778db9d734..8b1b24978a 100644 --- a/src/Core/Repositories/IOrganizationApiKeyRepository.cs +++ b/src/Core/Repositories/IOrganizationApiKeyRepository.cs @@ -1,9 +1,10 @@ using Bit.Core.Entities; using Bit.Core.Enums; -namespace Bit.Core.Repositories; - -public interface IOrganizationApiKeyRepository : IRepository +namespace Bit.Core.Repositories { - Task> GetManyByOrganizationIdTypeAsync(Guid organizationId, OrganizationApiKeyType? type = null); + public interface IOrganizationApiKeyRepository : IRepository + { + Task> GetManyByOrganizationIdTypeAsync(Guid organizationId, OrganizationApiKeyType? type = null); + } } diff --git a/src/Core/Repositories/IOrganizationConnectionRepository.cs b/src/Core/Repositories/IOrganizationConnectionRepository.cs index a3bdbb0370..b87a82d14c 100644 --- a/src/Core/Repositories/IOrganizationConnectionRepository.cs +++ b/src/Core/Repositories/IOrganizationConnectionRepository.cs @@ -1,10 +1,11 @@ using Bit.Core.Entities; using Bit.Core.Enums; -namespace Bit.Core.Repositories; - -public interface IOrganizationConnectionRepository : IRepository +namespace Bit.Core.Repositories { - Task> GetByOrganizationIdTypeAsync(Guid organizationId, OrganizationConnectionType type); - Task> GetEnabledByOrganizationIdTypeAsync(Guid organizationId, OrganizationConnectionType type); + public interface IOrganizationConnectionRepository : IRepository + { + Task> GetByOrganizationIdTypeAsync(Guid organizationId, OrganizationConnectionType type); + Task> GetEnabledByOrganizationIdTypeAsync(Guid organizationId, OrganizationConnectionType type); + } } diff --git a/src/Core/Repositories/IOrganizationRepository.cs b/src/Core/Repositories/IOrganizationRepository.cs index 690bff9139..392a925c37 100644 --- a/src/Core/Repositories/IOrganizationRepository.cs +++ b/src/Core/Repositories/IOrganizationRepository.cs @@ -1,14 +1,15 @@ using Bit.Core.Entities; using Bit.Core.Models.Data.Organizations; -namespace Bit.Core.Repositories; - -public interface IOrganizationRepository : IRepository +namespace Bit.Core.Repositories { - Task GetByIdentifierAsync(string identifier); - Task> GetManyByEnabledAsync(); - Task> GetManyByUserIdAsync(Guid userId); - Task> SearchAsync(string name, string userEmail, bool? paid, int skip, int take); - Task UpdateStorageAsync(Guid id); - Task> GetManyAbilitiesAsync(); + public interface IOrganizationRepository : IRepository + { + Task GetByIdentifierAsync(string identifier); + Task> GetManyByEnabledAsync(); + Task> GetManyByUserIdAsync(Guid userId); + Task> SearchAsync(string name, string userEmail, bool? paid, int skip, int take); + Task UpdateStorageAsync(Guid id); + Task> GetManyAbilitiesAsync(); + } } diff --git a/src/Core/Repositories/IOrganizationSponsorshipRepository.cs b/src/Core/Repositories/IOrganizationSponsorshipRepository.cs index 232fd1b9dd..2ef24580cd 100644 --- a/src/Core/Repositories/IOrganizationSponsorshipRepository.cs +++ b/src/Core/Repositories/IOrganizationSponsorshipRepository.cs @@ -1,15 +1,16 @@ using Bit.Core.Entities; -namespace Bit.Core.Repositories; - -public interface IOrganizationSponsorshipRepository : IRepository +namespace Bit.Core.Repositories { - Task> CreateManyAsync(IEnumerable organizationSponsorships); - Task ReplaceManyAsync(IEnumerable organizationSponsorships); - Task UpsertManyAsync(IEnumerable organizationSponsorships); - Task DeleteManyAsync(IEnumerable organizationSponsorshipIds); - Task> GetManyBySponsoringOrganizationAsync(Guid sponsoringOrganizationId); - Task GetBySponsoringOrganizationUserIdAsync(Guid sponsoringOrganizationUserId); - Task GetBySponsoredOrganizationIdAsync(Guid sponsoredOrganizationId); - Task GetLatestSyncDateBySponsoringOrganizationIdAsync(Guid sponsoringOrganizationId); + public interface IOrganizationSponsorshipRepository : IRepository + { + Task> CreateManyAsync(IEnumerable organizationSponsorships); + Task ReplaceManyAsync(IEnumerable organizationSponsorships); + Task UpsertManyAsync(IEnumerable organizationSponsorships); + Task DeleteManyAsync(IEnumerable organizationSponsorshipIds); + Task> GetManyBySponsoringOrganizationAsync(Guid sponsoringOrganizationId); + Task GetBySponsoringOrganizationUserIdAsync(Guid sponsoringOrganizationUserId); + Task GetBySponsoredOrganizationIdAsync(Guid sponsoredOrganizationId); + Task GetLatestSyncDateBySponsoringOrganizationIdAsync(Guid sponsoringOrganizationId); + } } diff --git a/src/Core/Repositories/IOrganizationUserRepository.cs b/src/Core/Repositories/IOrganizationUserRepository.cs index f8909f7844..8597adb522 100644 --- a/src/Core/Repositories/IOrganizationUserRepository.cs +++ b/src/Core/Repositories/IOrganizationUserRepository.cs @@ -3,39 +3,40 @@ using Bit.Core.Enums; using Bit.Core.Models.Data; using Bit.Core.Models.Data.Organizations.OrganizationUsers; -namespace Bit.Core.Repositories; - -public interface IOrganizationUserRepository : IRepository +namespace Bit.Core.Repositories { - Task GetCountByOrganizationIdAsync(Guid organizationId); - Task GetCountByFreeOrganizationAdminUserAsync(Guid userId); - Task GetCountByOnlyOwnerAsync(Guid userId); - Task> GetManyByUserAsync(Guid userId); - Task> GetManyByOrganizationAsync(Guid organizationId, OrganizationUserType? type); - Task GetCountByOrganizationAsync(Guid organizationId, string email, bool onlyRegisteredUsers); - Task> SelectKnownEmailsAsync(Guid organizationId, IEnumerable emails, bool onlyRegisteredUsers); - Task GetByOrganizationAsync(Guid organizationId, Guid userId); - Task>> GetByIdWithCollectionsAsync(Guid id); - Task GetDetailsByIdAsync(Guid id); - Task>> - GetDetailsByIdWithCollectionsAsync(Guid id); - Task> GetManyDetailsByOrganizationAsync(Guid organizationId); - Task> GetManyDetailsByUserAsync(Guid userId, - OrganizationUserStatusType? status = null); - Task GetDetailsByUserAsync(Guid userId, Guid organizationId, - OrganizationUserStatusType? status = null); - Task UpdateGroupsAsync(Guid orgUserId, IEnumerable groupIds); - Task UpsertManyAsync(IEnumerable organizationUsers); - Task CreateAsync(OrganizationUser obj, IEnumerable collections); - Task> CreateManyAsync(IEnumerable organizationIdUsers); - Task ReplaceAsync(OrganizationUser obj, IEnumerable collections); - Task ReplaceManyAsync(IEnumerable organizationUsers); - Task> GetManyByManyUsersAsync(IEnumerable userIds); - Task> GetManyAsync(IEnumerable Ids); - Task DeleteManyAsync(IEnumerable userIds); - Task GetByOrganizationEmailAsync(Guid organizationId, string email); - Task> GetManyPublicKeysByOrganizationUserAsync(Guid organizationId, IEnumerable Ids); - Task> GetManyByMinimumRoleAsync(Guid organizationId, OrganizationUserType minRole); - Task RevokeAsync(Guid id); - Task RestoreAsync(Guid id, OrganizationUserStatusType status); + public interface IOrganizationUserRepository : IRepository + { + Task GetCountByOrganizationIdAsync(Guid organizationId); + Task GetCountByFreeOrganizationAdminUserAsync(Guid userId); + Task GetCountByOnlyOwnerAsync(Guid userId); + Task> GetManyByUserAsync(Guid userId); + Task> GetManyByOrganizationAsync(Guid organizationId, OrganizationUserType? type); + Task GetCountByOrganizationAsync(Guid organizationId, string email, bool onlyRegisteredUsers); + Task> SelectKnownEmailsAsync(Guid organizationId, IEnumerable emails, bool onlyRegisteredUsers); + Task GetByOrganizationAsync(Guid organizationId, Guid userId); + Task>> GetByIdWithCollectionsAsync(Guid id); + Task GetDetailsByIdAsync(Guid id); + Task>> + GetDetailsByIdWithCollectionsAsync(Guid id); + Task> GetManyDetailsByOrganizationAsync(Guid organizationId); + Task> GetManyDetailsByUserAsync(Guid userId, + OrganizationUserStatusType? status = null); + Task GetDetailsByUserAsync(Guid userId, Guid organizationId, + OrganizationUserStatusType? status = null); + Task UpdateGroupsAsync(Guid orgUserId, IEnumerable groupIds); + Task UpsertManyAsync(IEnumerable organizationUsers); + Task CreateAsync(OrganizationUser obj, IEnumerable collections); + Task> CreateManyAsync(IEnumerable organizationIdUsers); + Task ReplaceAsync(OrganizationUser obj, IEnumerable collections); + Task ReplaceManyAsync(IEnumerable organizationUsers); + Task> GetManyByManyUsersAsync(IEnumerable userIds); + Task> GetManyAsync(IEnumerable Ids); + Task DeleteManyAsync(IEnumerable userIds); + Task GetByOrganizationEmailAsync(Guid organizationId, string email); + Task> GetManyPublicKeysByOrganizationUserAsync(Guid organizationId, IEnumerable Ids); + Task> GetManyByMinimumRoleAsync(Guid organizationId, OrganizationUserType minRole); + Task RevokeAsync(Guid id); + Task RestoreAsync(Guid id, OrganizationUserStatusType status); + } } diff --git a/src/Core/Repositories/IPolicyRepository.cs b/src/Core/Repositories/IPolicyRepository.cs index ce965e1745..34206770e5 100644 --- a/src/Core/Repositories/IPolicyRepository.cs +++ b/src/Core/Repositories/IPolicyRepository.cs @@ -1,15 +1,16 @@ using Bit.Core.Entities; using Bit.Core.Enums; -namespace Bit.Core.Repositories; - -public interface IPolicyRepository : IRepository +namespace Bit.Core.Repositories { - Task GetByOrganizationIdTypeAsync(Guid organizationId, PolicyType type); - Task> GetManyByOrganizationIdAsync(Guid organizationId); - Task> GetManyByUserIdAsync(Guid userId); - Task> GetManyByTypeApplicableToUserIdAsync(Guid userId, PolicyType policyType, - OrganizationUserStatusType minStatus = OrganizationUserStatusType.Accepted); - Task GetCountByTypeApplicableToUserIdAsync(Guid userId, PolicyType policyType, - OrganizationUserStatusType minStatus = OrganizationUserStatusType.Accepted); + public interface IPolicyRepository : IRepository + { + Task GetByOrganizationIdTypeAsync(Guid organizationId, PolicyType type); + Task> GetManyByOrganizationIdAsync(Guid organizationId); + Task> GetManyByUserIdAsync(Guid userId); + Task> GetManyByTypeApplicableToUserIdAsync(Guid userId, PolicyType policyType, + OrganizationUserStatusType minStatus = OrganizationUserStatusType.Accepted); + Task GetCountByTypeApplicableToUserIdAsync(Guid userId, PolicyType policyType, + OrganizationUserStatusType minStatus = OrganizationUserStatusType.Accepted); + } } diff --git a/src/Core/Repositories/IProviderOrganizationRepository.cs b/src/Core/Repositories/IProviderOrganizationRepository.cs index b546d8d2ef..7c2cfb3b19 100644 --- a/src/Core/Repositories/IProviderOrganizationRepository.cs +++ b/src/Core/Repositories/IProviderOrganizationRepository.cs @@ -1,10 +1,11 @@ using Bit.Core.Entities.Provider; using Bit.Core.Models.Data; -namespace Bit.Core.Repositories; - -public interface IProviderOrganizationRepository : IRepository +namespace Bit.Core.Repositories { - Task> GetManyDetailsByProviderAsync(Guid providerId); - Task GetByOrganizationId(Guid organizationId); + public interface IProviderOrganizationRepository : IRepository + { + Task> GetManyDetailsByProviderAsync(Guid providerId); + Task GetByOrganizationId(Guid organizationId); + } } diff --git a/src/Core/Repositories/IProviderRepository.cs b/src/Core/Repositories/IProviderRepository.cs index 8d92fb6d2e..5b3700d818 100644 --- a/src/Core/Repositories/IProviderRepository.cs +++ b/src/Core/Repositories/IProviderRepository.cs @@ -1,10 +1,11 @@ using Bit.Core.Entities.Provider; using Bit.Core.Models.Data; -namespace Bit.Core.Repositories; - -public interface IProviderRepository : IRepository +namespace Bit.Core.Repositories { - Task> SearchAsync(string name, string userEmail, int skip, int take); - Task> GetManyAbilitiesAsync(); + public interface IProviderRepository : IRepository + { + Task> SearchAsync(string name, string userEmail, int skip, int take); + Task> GetManyAbilitiesAsync(); + } } diff --git a/src/Core/Repositories/IProviderUserRepository.cs b/src/Core/Repositories/IProviderUserRepository.cs index 4a5db368ee..14882a0a5f 100644 --- a/src/Core/Repositories/IProviderUserRepository.cs +++ b/src/Core/Repositories/IProviderUserRepository.cs @@ -2,20 +2,21 @@ using Bit.Core.Enums.Provider; using Bit.Core.Models.Data; -namespace Bit.Core.Repositories; - -public interface IProviderUserRepository : IRepository +namespace Bit.Core.Repositories { - Task GetCountByProviderAsync(Guid providerId, string email, bool onlyRegisteredUsers); - Task> GetManyAsync(IEnumerable ids); - Task> GetManyByUserAsync(Guid userId); - Task GetByProviderUserAsync(Guid providerId, Guid userId); - Task> GetManyByProviderAsync(Guid providerId, ProviderUserType? type = null); - Task> GetManyDetailsByProviderAsync(Guid providerId); - Task> GetManyDetailsByUserAsync(Guid userId, - ProviderUserStatusType? status = null); - Task> GetManyOrganizationDetailsByUserAsync(Guid userId, ProviderUserStatusType? status = null); - Task DeleteManyAsync(IEnumerable userIds); - Task> GetManyPublicKeysByProviderUserAsync(Guid providerId, IEnumerable Ids); - Task GetCountByOnlyOwnerAsync(Guid userId); + public interface IProviderUserRepository : IRepository + { + Task GetCountByProviderAsync(Guid providerId, string email, bool onlyRegisteredUsers); + Task> GetManyAsync(IEnumerable ids); + Task> GetManyByUserAsync(Guid userId); + Task GetByProviderUserAsync(Guid providerId, Guid userId); + Task> GetManyByProviderAsync(Guid providerId, ProviderUserType? type = null); + Task> GetManyDetailsByProviderAsync(Guid providerId); + Task> GetManyDetailsByUserAsync(Guid userId, + ProviderUserStatusType? status = null); + Task> GetManyOrganizationDetailsByUserAsync(Guid userId, ProviderUserStatusType? status = null); + Task DeleteManyAsync(IEnumerable userIds); + Task> GetManyPublicKeysByProviderUserAsync(Guid providerId, IEnumerable Ids); + Task GetCountByOnlyOwnerAsync(Guid userId); + } } diff --git a/src/Core/Repositories/IRepository.cs b/src/Core/Repositories/IRepository.cs index 18bb81ff8f..3316bef518 100644 --- a/src/Core/Repositories/IRepository.cs +++ b/src/Core/Repositories/IRepository.cs @@ -1,12 +1,13 @@ using Bit.Core.Entities; -namespace Bit.Core.Repositories; - -public interface IRepository where TId : IEquatable where T : class, ITableObject +namespace Bit.Core.Repositories { - Task GetByIdAsync(TId id); - Task CreateAsync(T obj); - Task ReplaceAsync(T obj); - Task UpsertAsync(T obj); - Task DeleteAsync(T obj); + public interface IRepository where TId : IEquatable where T : class, ITableObject + { + Task GetByIdAsync(TId id); + Task CreateAsync(T obj); + Task ReplaceAsync(T obj); + Task UpsertAsync(T obj); + Task DeleteAsync(T obj); + } } diff --git a/src/Core/Repositories/ISendRepository.cs b/src/Core/Repositories/ISendRepository.cs index b35a059d31..4a4fe5ebf2 100644 --- a/src/Core/Repositories/ISendRepository.cs +++ b/src/Core/Repositories/ISendRepository.cs @@ -1,9 +1,10 @@ using Bit.Core.Entities; -namespace Bit.Core.Repositories; - -public interface ISendRepository : IRepository +namespace Bit.Core.Repositories { - Task> GetManyByUserIdAsync(Guid userId); - Task> GetManyByDeletionDateAsync(DateTime deletionDateBefore); + public interface ISendRepository : IRepository + { + Task> GetManyByUserIdAsync(Guid userId); + Task> GetManyByDeletionDateAsync(DateTime deletionDateBefore); + } } diff --git a/src/Core/Repositories/ISsoConfigRepository.cs b/src/Core/Repositories/ISsoConfigRepository.cs index 8f65618190..2350e0a4ad 100644 --- a/src/Core/Repositories/ISsoConfigRepository.cs +++ b/src/Core/Repositories/ISsoConfigRepository.cs @@ -1,10 +1,11 @@ using Bit.Core.Entities; -namespace Bit.Core.Repositories; - -public interface ISsoConfigRepository : IRepository +namespace Bit.Core.Repositories { - Task GetByOrganizationIdAsync(Guid organizationId); - Task GetByIdentifierAsync(string identifier); - Task> GetManyByRevisionNotBeforeDate(DateTime? notBefore); + public interface ISsoConfigRepository : IRepository + { + Task GetByOrganizationIdAsync(Guid organizationId); + Task GetByIdentifierAsync(string identifier); + Task> GetManyByRevisionNotBeforeDate(DateTime? notBefore); + } } diff --git a/src/Core/Repositories/ISsoUserRepository.cs b/src/Core/Repositories/ISsoUserRepository.cs index 653734450a..6dcada9296 100644 --- a/src/Core/Repositories/ISsoUserRepository.cs +++ b/src/Core/Repositories/ISsoUserRepository.cs @@ -1,9 +1,10 @@ using Bit.Core.Entities; -namespace Bit.Core.Repositories; - -public interface ISsoUserRepository : IRepository +namespace Bit.Core.Repositories { - Task DeleteAsync(Guid userId, Guid? organizationId); - Task GetByUserIdOrganizationIdAsync(Guid organizationId, Guid userId); + public interface ISsoUserRepository : IRepository + { + Task DeleteAsync(Guid userId, Guid? organizationId); + Task GetByUserIdOrganizationIdAsync(Guid organizationId, Guid userId); + } } diff --git a/src/Core/Repositories/ITaxRateRepository.cs b/src/Core/Repositories/ITaxRateRepository.cs index a8557a7890..779c2c7146 100644 --- a/src/Core/Repositories/ITaxRateRepository.cs +++ b/src/Core/Repositories/ITaxRateRepository.cs @@ -1,11 +1,12 @@ using Bit.Core.Entities; -namespace Bit.Core.Repositories; - -public interface ITaxRateRepository : IRepository +namespace Bit.Core.Repositories { - Task> SearchAsync(int skip, int count); - Task> GetAllActiveAsync(); - Task ArchiveAsync(TaxRate model); - Task> GetByLocationAsync(TaxRate taxRate); + public interface ITaxRateRepository : IRepository + { + Task> SearchAsync(int skip, int count); + Task> GetAllActiveAsync(); + Task ArchiveAsync(TaxRate model); + Task> GetByLocationAsync(TaxRate taxRate); + } } diff --git a/src/Core/Repositories/ITransactionRepository.cs b/src/Core/Repositories/ITransactionRepository.cs index 82b6f961b8..6fb9b27b76 100644 --- a/src/Core/Repositories/ITransactionRepository.cs +++ b/src/Core/Repositories/ITransactionRepository.cs @@ -1,11 +1,12 @@ using Bit.Core.Entities; using Bit.Core.Enums; -namespace Bit.Core.Repositories; - -public interface ITransactionRepository : IRepository +namespace Bit.Core.Repositories { - Task> GetManyByUserIdAsync(Guid userId); - Task> GetManyByOrganizationIdAsync(Guid organizationId); - Task GetByGatewayIdAsync(GatewayType gatewayType, string gatewayId); + public interface ITransactionRepository : IRepository + { + Task> GetManyByUserIdAsync(Guid userId); + Task> GetManyByOrganizationIdAsync(Guid organizationId); + Task GetByGatewayIdAsync(GatewayType gatewayType, string gatewayId); + } } diff --git a/src/Core/Repositories/IUserRepository.cs b/src/Core/Repositories/IUserRepository.cs index 0c6ee85712..8ed89f7e09 100644 --- a/src/Core/Repositories/IUserRepository.cs +++ b/src/Core/Repositories/IUserRepository.cs @@ -1,18 +1,19 @@ using Bit.Core.Entities; using Bit.Core.Models.Data; -namespace Bit.Core.Repositories; - -public interface IUserRepository : IRepository +namespace Bit.Core.Repositories { - Task GetByEmailAsync(string email); - Task GetBySsoUserAsync(string externalId, Guid? organizationId); - Task GetKdfInformationByEmailAsync(string email); - Task> SearchAsync(string email, int skip, int take); - Task> GetManyByPremiumAsync(bool premium); - Task GetPublicKeyAsync(Guid id); - Task GetAccountRevisionDateAsync(Guid id); - Task UpdateStorageAsync(Guid id); - Task UpdateRenewalReminderDateAsync(Guid id, DateTime renewalReminderDate); - Task> GetManyAsync(IEnumerable ids); + public interface IUserRepository : IRepository + { + Task GetByEmailAsync(string email); + Task GetBySsoUserAsync(string externalId, Guid? organizationId); + Task GetKdfInformationByEmailAsync(string email); + Task> SearchAsync(string email, int skip, int take); + Task> GetManyByPremiumAsync(bool premium); + Task GetPublicKeyAsync(Guid id); + Task GetAccountRevisionDateAsync(Guid id); + Task UpdateStorageAsync(Guid id); + Task UpdateRenewalReminderDateAsync(Guid id, DateTime renewalReminderDate); + Task> GetManyAsync(IEnumerable ids); + } } diff --git a/src/Core/Repositories/Noop/InstallationDeviceRepository.cs b/src/Core/Repositories/Noop/InstallationDeviceRepository.cs index b704459013..eb446547ad 100644 --- a/src/Core/Repositories/Noop/InstallationDeviceRepository.cs +++ b/src/Core/Repositories/Noop/InstallationDeviceRepository.cs @@ -1,21 +1,22 @@ using Bit.Core.Models.Data; -namespace Bit.Core.Repositories.Noop; - -public class InstallationDeviceRepository : IInstallationDeviceRepository +namespace Bit.Core.Repositories.Noop { - public Task UpsertAsync(InstallationDeviceEntity entity) + public class InstallationDeviceRepository : IInstallationDeviceRepository { - return Task.FromResult(0); - } + public Task UpsertAsync(InstallationDeviceEntity entity) + { + return Task.FromResult(0); + } - public Task UpsertManyAsync(IList entities) - { - return Task.FromResult(0); - } + public Task UpsertManyAsync(IList entities) + { + return Task.FromResult(0); + } - public Task DeleteAsync(InstallationDeviceEntity entity) - { - return Task.FromResult(0); + public Task DeleteAsync(InstallationDeviceEntity entity) + { + return Task.FromResult(0); + } } } diff --git a/src/Core/Repositories/Noop/MetaDataRepository.cs b/src/Core/Repositories/Noop/MetaDataRepository.cs index bc235c683c..1f46584552 100644 --- a/src/Core/Repositories/Noop/MetaDataRepository.cs +++ b/src/Core/Repositories/Noop/MetaDataRepository.cs @@ -1,29 +1,30 @@ -namespace Bit.Core.Repositories.Noop; - -public class MetaDataRepository : IMetaDataRepository +namespace Bit.Core.Repositories.Noop { - public Task DeleteAsync(string objectName, string id) + public class MetaDataRepository : IMetaDataRepository { - return Task.FromResult(0); - } + public Task DeleteAsync(string objectName, string id) + { + return Task.FromResult(0); + } - public Task> GetAsync(string objectName, string id) - { - return Task.FromResult(null as IDictionary); - } + public Task> GetAsync(string objectName, string id) + { + return Task.FromResult(null as IDictionary); + } - public Task GetAsync(string objectName, string id, string prop) - { - return Task.FromResult(null as string); - } + public Task GetAsync(string objectName, string id, string prop) + { + return Task.FromResult(null as string); + } - public Task UpsertAsync(string objectName, string id, IDictionary dict) - { - return Task.FromResult(0); - } + public Task UpsertAsync(string objectName, string id, IDictionary dict) + { + return Task.FromResult(0); + } - public Task UpsertAsync(string objectName, string id, KeyValuePair keyValuePair) - { - return Task.FromResult(0); + public Task UpsertAsync(string objectName, string id, KeyValuePair keyValuePair) + { + return Task.FromResult(0); + } } } diff --git a/src/Core/Repositories/TableStorage/EventRepository.cs b/src/Core/Repositories/TableStorage/EventRepository.cs index 514b61099b..9ee541b8a3 100644 --- a/src/Core/Repositories/TableStorage/EventRepository.cs +++ b/src/Core/Repositories/TableStorage/EventRepository.cs @@ -4,184 +4,185 @@ using Bit.Core.Settings; using Bit.Core.Utilities; using Microsoft.Azure.Cosmos.Table; -namespace Bit.Core.Repositories.TableStorage; - -public class EventRepository : IEventRepository +namespace Bit.Core.Repositories.TableStorage { - private readonly CloudTable _table; - - public EventRepository(GlobalSettings globalSettings) - : this(globalSettings.Events.ConnectionString) - { } - - public EventRepository(string storageConnectionString) + public class EventRepository : IEventRepository { - var storageAccount = CloudStorageAccount.Parse(storageConnectionString); - var tableClient = storageAccount.CreateCloudTableClient(); - _table = tableClient.GetTableReference("event"); - } + private readonly CloudTable _table; - public async Task> GetManyByUserAsync(Guid userId, DateTime startDate, DateTime endDate, - PageOptions pageOptions) - { - return await GetManyAsync($"UserId={userId}", "Date={{0}}", startDate, endDate, pageOptions); - } + public EventRepository(GlobalSettings globalSettings) + : this(globalSettings.Events.ConnectionString) + { } - public async Task> GetManyByOrganizationAsync(Guid organizationId, - DateTime startDate, DateTime endDate, PageOptions pageOptions) - { - return await GetManyAsync($"OrganizationId={organizationId}", "Date={0}", startDate, endDate, pageOptions); - } - - public async Task> GetManyByOrganizationActingUserAsync(Guid organizationId, Guid actingUserId, - DateTime startDate, DateTime endDate, PageOptions pageOptions) - { - return await GetManyAsync($"OrganizationId={organizationId}", - $"ActingUserId={actingUserId}__Date={{0}}", startDate, endDate, pageOptions); - } - - public async Task> GetManyByProviderAsync(Guid providerId, - DateTime startDate, DateTime endDate, PageOptions pageOptions) - { - return await GetManyAsync($"ProviderId={providerId}", "Date={0}", startDate, endDate, pageOptions); - } - - public async Task> GetManyByProviderActingUserAsync(Guid providerId, Guid actingUserId, - DateTime startDate, DateTime endDate, PageOptions pageOptions) - { - return await GetManyAsync($"ProviderId={providerId}", - $"ActingUserId={actingUserId}__Date={{0}}", startDate, endDate, pageOptions); - } - - public async Task> GetManyByCipherAsync(Cipher cipher, DateTime startDate, DateTime endDate, - PageOptions pageOptions) - { - var partitionKey = cipher.OrganizationId.HasValue ? - $"OrganizationId={cipher.OrganizationId}" : $"UserId={cipher.UserId}"; - return await GetManyAsync(partitionKey, $"CipherId={cipher.Id}__Date={{0}}", startDate, endDate, pageOptions); - } - - public async Task CreateAsync(IEvent e) - { - if (!(e is EventTableEntity entity)) + public EventRepository(string storageConnectionString) { - throw new ArgumentException(nameof(e)); + var storageAccount = CloudStorageAccount.Parse(storageConnectionString); + var tableClient = storageAccount.CreateCloudTableClient(); + _table = tableClient.GetTableReference("event"); } - await CreateEntityAsync(entity); - } - - public async Task CreateManyAsync(IEnumerable e) - { - if (!e?.Any() ?? true) + public async Task> GetManyByUserAsync(Guid userId, DateTime startDate, DateTime endDate, + PageOptions pageOptions) { - return; + return await GetManyAsync($"UserId={userId}", "Date={{0}}", startDate, endDate, pageOptions); } - if (!e.Skip(1).Any()) + public async Task> GetManyByOrganizationAsync(Guid organizationId, + DateTime startDate, DateTime endDate, PageOptions pageOptions) { - await CreateAsync(e.First()); - return; + return await GetManyAsync($"OrganizationId={organizationId}", "Date={0}", startDate, endDate, pageOptions); } - var entities = e.Where(ev => ev is EventTableEntity).Select(ev => ev as EventTableEntity); - var entityGroups = entities.GroupBy(ent => ent.PartitionKey); - foreach (var group in entityGroups) + public async Task> GetManyByOrganizationActingUserAsync(Guid organizationId, Guid actingUserId, + DateTime startDate, DateTime endDate, PageOptions pageOptions) { - var groupEntities = group.ToList(); - if (groupEntities.Count == 1) + return await GetManyAsync($"OrganizationId={organizationId}", + $"ActingUserId={actingUserId}__Date={{0}}", startDate, endDate, pageOptions); + } + + public async Task> GetManyByProviderAsync(Guid providerId, + DateTime startDate, DateTime endDate, PageOptions pageOptions) + { + return await GetManyAsync($"ProviderId={providerId}", "Date={0}", startDate, endDate, pageOptions); + } + + public async Task> GetManyByProviderActingUserAsync(Guid providerId, Guid actingUserId, + DateTime startDate, DateTime endDate, PageOptions pageOptions) + { + return await GetManyAsync($"ProviderId={providerId}", + $"ActingUserId={actingUserId}__Date={{0}}", startDate, endDate, pageOptions); + } + + public async Task> GetManyByCipherAsync(Cipher cipher, DateTime startDate, DateTime endDate, + PageOptions pageOptions) + { + var partitionKey = cipher.OrganizationId.HasValue ? + $"OrganizationId={cipher.OrganizationId}" : $"UserId={cipher.UserId}"; + return await GetManyAsync(partitionKey, $"CipherId={cipher.Id}__Date={{0}}", startDate, endDate, pageOptions); + } + + public async Task CreateAsync(IEvent e) + { + if (!(e is EventTableEntity entity)) { - await CreateEntityAsync(groupEntities.First()); - continue; + throw new ArgumentException(nameof(e)); } - // A batch insert can only contain 100 entities at a time - var iterations = groupEntities.Count / 100; - for (var i = 0; i <= iterations; i++) + await CreateEntityAsync(entity); + } + + public async Task CreateManyAsync(IEnumerable e) + { + if (!e?.Any() ?? true) { - var batch = new TableBatchOperation(); - var batchEntities = groupEntities.Skip(i * 100).Take(100); - if (!batchEntities.Any()) + return; + } + + if (!e.Skip(1).Any()) + { + await CreateAsync(e.First()); + return; + } + + var entities = e.Where(ev => ev is EventTableEntity).Select(ev => ev as EventTableEntity); + var entityGroups = entities.GroupBy(ent => ent.PartitionKey); + foreach (var group in entityGroups) + { + var groupEntities = group.ToList(); + if (groupEntities.Count == 1) { - break; + await CreateEntityAsync(groupEntities.First()); + continue; } - foreach (var entity in batchEntities) + // A batch insert can only contain 100 entities at a time + var iterations = groupEntities.Count / 100; + for (var i = 0; i <= iterations; i++) { - batch.InsertOrReplace(entity); - } + var batch = new TableBatchOperation(); + var batchEntities = groupEntities.Skip(i * 100).Take(100); + if (!batchEntities.Any()) + { + break; + } - await _table.ExecuteBatchAsync(batch); + foreach (var entity in batchEntities) + { + batch.InsertOrReplace(entity); + } + + await _table.ExecuteBatchAsync(batch); + } } } - } - public async Task CreateEntityAsync(ITableEntity entity) - { - await _table.ExecuteAsync(TableOperation.InsertOrReplace(entity)); - } - - public async Task> GetManyAsync(string partitionKey, string rowKey, - DateTime startDate, DateTime endDate, PageOptions pageOptions) - { - var start = CoreHelpers.DateTimeToTableStorageKey(startDate); - var end = CoreHelpers.DateTimeToTableStorageKey(endDate); - var filter = MakeFilter(partitionKey, string.Format(rowKey, start), string.Format(rowKey, end)); - - var query = new TableQuery().Where(filter).Take(pageOptions.PageSize); - var result = new PagedResult(); - var continuationToken = DeserializeContinuationToken(pageOptions?.ContinuationToken); - - var queryResults = await _table.ExecuteQuerySegmentedAsync(query, continuationToken); - result.ContinuationToken = SerializeContinuationToken(queryResults.ContinuationToken); - result.Data.AddRange(queryResults.Results); - - return result; - } - - private string MakeFilter(string partitionKey, string rowStart, string rowEnd) - { - var rowFilter = TableQuery.CombineFilters( - TableQuery.GenerateFilterCondition("RowKey", QueryComparisons.LessThanOrEqual, $"{rowStart}`"), - TableOperators.And, - TableQuery.GenerateFilterCondition("RowKey", QueryComparisons.GreaterThanOrEqual, $"{rowEnd}_")); - - return TableQuery.CombineFilters( - TableQuery.GenerateFilterCondition("PartitionKey", QueryComparisons.Equal, partitionKey), - TableOperators.And, - rowFilter); - } - - private string SerializeContinuationToken(TableContinuationToken token) - { - if (token == null) + public async Task CreateEntityAsync(ITableEntity entity) { - return null; + await _table.ExecuteAsync(TableOperation.InsertOrReplace(entity)); } - return string.Format("{0}__{1}__{2}__{3}", (int)token.TargetLocation, token.NextTableName, - token.NextPartitionKey, token.NextRowKey); - } - - private TableContinuationToken DeserializeContinuationToken(string token) - { - if (string.IsNullOrWhiteSpace(token)) + public async Task> GetManyAsync(string partitionKey, string rowKey, + DateTime startDate, DateTime endDate, PageOptions pageOptions) { - return null; + var start = CoreHelpers.DateTimeToTableStorageKey(startDate); + var end = CoreHelpers.DateTimeToTableStorageKey(endDate); + var filter = MakeFilter(partitionKey, string.Format(rowKey, start), string.Format(rowKey, end)); + + var query = new TableQuery().Where(filter).Take(pageOptions.PageSize); + var result = new PagedResult(); + var continuationToken = DeserializeContinuationToken(pageOptions?.ContinuationToken); + + var queryResults = await _table.ExecuteQuerySegmentedAsync(query, continuationToken); + result.ContinuationToken = SerializeContinuationToken(queryResults.ContinuationToken); + result.Data.AddRange(queryResults.Results); + + return result; } - var tokenParts = token.Split(new string[] { "__" }, StringSplitOptions.None); - if (tokenParts.Length < 4 || !Enum.TryParse(tokenParts[0], out StorageLocation tLoc)) + private string MakeFilter(string partitionKey, string rowStart, string rowEnd) { - return null; + var rowFilter = TableQuery.CombineFilters( + TableQuery.GenerateFilterCondition("RowKey", QueryComparisons.LessThanOrEqual, $"{rowStart}`"), + TableOperators.And, + TableQuery.GenerateFilterCondition("RowKey", QueryComparisons.GreaterThanOrEqual, $"{rowEnd}_")); + + return TableQuery.CombineFilters( + TableQuery.GenerateFilterCondition("PartitionKey", QueryComparisons.Equal, partitionKey), + TableOperators.And, + rowFilter); } - return new TableContinuationToken + private string SerializeContinuationToken(TableContinuationToken token) { - TargetLocation = tLoc, - NextTableName = string.IsNullOrWhiteSpace(tokenParts[1]) ? null : tokenParts[1], - NextPartitionKey = string.IsNullOrWhiteSpace(tokenParts[2]) ? null : tokenParts[2], - NextRowKey = string.IsNullOrWhiteSpace(tokenParts[3]) ? null : tokenParts[3] - }; + if (token == null) + { + return null; + } + + return string.Format("{0}__{1}__{2}__{3}", (int)token.TargetLocation, token.NextTableName, + token.NextPartitionKey, token.NextRowKey); + } + + private TableContinuationToken DeserializeContinuationToken(string token) + { + if (string.IsNullOrWhiteSpace(token)) + { + return null; + } + + var tokenParts = token.Split(new string[] { "__" }, StringSplitOptions.None); + if (tokenParts.Length < 4 || !Enum.TryParse(tokenParts[0], out StorageLocation tLoc)) + { + return null; + } + + return new TableContinuationToken + { + TargetLocation = tLoc, + NextTableName = string.IsNullOrWhiteSpace(tokenParts[1]) ? null : tokenParts[1], + NextPartitionKey = string.IsNullOrWhiteSpace(tokenParts[2]) ? null : tokenParts[2], + NextRowKey = string.IsNullOrWhiteSpace(tokenParts[3]) ? null : tokenParts[3] + }; + } } } diff --git a/src/Core/Repositories/TableStorage/InstallationDeviceRepository.cs b/src/Core/Repositories/TableStorage/InstallationDeviceRepository.cs index 32b466d1b3..125360e6bd 100644 --- a/src/Core/Repositories/TableStorage/InstallationDeviceRepository.cs +++ b/src/Core/Repositories/TableStorage/InstallationDeviceRepository.cs @@ -3,82 +3,83 @@ using Bit.Core.Models.Data; using Bit.Core.Settings; using Microsoft.Azure.Cosmos.Table; -namespace Bit.Core.Repositories.TableStorage; - -public class InstallationDeviceRepository : IInstallationDeviceRepository +namespace Bit.Core.Repositories.TableStorage { - private readonly CloudTable _table; - - public InstallationDeviceRepository(GlobalSettings globalSettings) - : this(globalSettings.Events.ConnectionString) - { } - - public InstallationDeviceRepository(string storageConnectionString) + public class InstallationDeviceRepository : IInstallationDeviceRepository { - var storageAccount = CloudStorageAccount.Parse(storageConnectionString); - var tableClient = storageAccount.CreateCloudTableClient(); - _table = tableClient.GetTableReference("installationdevice"); - } + private readonly CloudTable _table; - public async Task UpsertAsync(InstallationDeviceEntity entity) - { - await _table.ExecuteAsync(TableOperation.InsertOrReplace(entity)); - } + public InstallationDeviceRepository(GlobalSettings globalSettings) + : this(globalSettings.Events.ConnectionString) + { } - public async Task UpsertManyAsync(IList entities) - { - if (!entities?.Any() ?? true) + public InstallationDeviceRepository(string storageConnectionString) { - return; + var storageAccount = CloudStorageAccount.Parse(storageConnectionString); + var tableClient = storageAccount.CreateCloudTableClient(); + _table = tableClient.GetTableReference("installationdevice"); } - if (entities.Count == 1) + public async Task UpsertAsync(InstallationDeviceEntity entity) { - await UpsertAsync(entities.First()); - return; + await _table.ExecuteAsync(TableOperation.InsertOrReplace(entity)); } - var entityGroups = entities.GroupBy(ent => ent.PartitionKey); - foreach (var group in entityGroups) + public async Task UpsertManyAsync(IList entities) { - var groupEntities = group.ToList(); - if (groupEntities.Count == 1) + if (!entities?.Any() ?? true) { - await UpsertAsync(groupEntities.First()); - continue; + return; } - // A batch insert can only contain 100 entities at a time - var iterations = groupEntities.Count / 100; - for (var i = 0; i <= iterations; i++) + if (entities.Count == 1) { - var batch = new TableBatchOperation(); - var batchEntities = groupEntities.Skip(i * 100).Take(100); - if (!batchEntities.Any()) + await UpsertAsync(entities.First()); + return; + } + + var entityGroups = entities.GroupBy(ent => ent.PartitionKey); + foreach (var group in entityGroups) + { + var groupEntities = group.ToList(); + if (groupEntities.Count == 1) { - break; + await UpsertAsync(groupEntities.First()); + continue; } - foreach (var entity in batchEntities) + // A batch insert can only contain 100 entities at a time + var iterations = groupEntities.Count / 100; + for (var i = 0; i <= iterations; i++) { - batch.InsertOrReplace(entity); - } + var batch = new TableBatchOperation(); + var batchEntities = groupEntities.Skip(i * 100).Take(100); + if (!batchEntities.Any()) + { + break; + } - await _table.ExecuteBatchAsync(batch); + foreach (var entity in batchEntities) + { + batch.InsertOrReplace(entity); + } + + await _table.ExecuteBatchAsync(batch); + } } } - } - public async Task DeleteAsync(InstallationDeviceEntity entity) - { - try + public async Task DeleteAsync(InstallationDeviceEntity entity) { - entity.ETag = "*"; - await _table.ExecuteAsync(TableOperation.Delete(entity)); - } - catch (StorageException e) when (e.RequestInformation.HttpStatusCode != (int)HttpStatusCode.NotFound) - { - throw; + try + { + entity.ETag = "*"; + await _table.ExecuteAsync(TableOperation.Delete(entity)); + } + catch (StorageException e) when (e.RequestInformation.HttpStatusCode != (int)HttpStatusCode.NotFound) + { + throw; + } } } } diff --git a/src/Core/Repositories/TableStorage/MetaDataRepository.cs b/src/Core/Repositories/TableStorage/MetaDataRepository.cs index c70426e2a2..83ae04e4be 100644 --- a/src/Core/Repositories/TableStorage/MetaDataRepository.cs +++ b/src/Core/Repositories/TableStorage/MetaDataRepository.cs @@ -3,91 +3,92 @@ using Bit.Core.Models.Data; using Bit.Core.Settings; using Microsoft.Azure.Cosmos.Table; -namespace Bit.Core.Repositories.TableStorage; - -public class MetaDataRepository : IMetaDataRepository +namespace Bit.Core.Repositories.TableStorage { - private readonly CloudTable _table; - - public MetaDataRepository(GlobalSettings globalSettings) - : this(globalSettings.Events.ConnectionString) - { } - - public MetaDataRepository(string storageConnectionString) + public class MetaDataRepository : IMetaDataRepository { - var storageAccount = CloudStorageAccount.Parse(storageConnectionString); - var tableClient = storageAccount.CreateCloudTableClient(); - _table = tableClient.GetTableReference("metadata"); - } + private readonly CloudTable _table; - public async Task> GetAsync(string objectName, string id) - { - var query = new TableQuery().Where( - TableQuery.GenerateFilterCondition("PartitionKey", QueryComparisons.Equal, $"{objectName}_{id}")); - var queryResults = await _table.ExecuteQuerySegmentedAsync(query, null); - return queryResults.Results.FirstOrDefault()?.ToDictionary(d => d.Key, d => d.Value.StringValue); - } + public MetaDataRepository(GlobalSettings globalSettings) + : this(globalSettings.Events.ConnectionString) + { } - public async Task GetAsync(string objectName, string id, string prop) - { - var dict = await GetAsync(objectName, id); - if (dict != null && dict.ContainsKey(prop)) + public MetaDataRepository(string storageConnectionString) { - return dict[prop]; + var storageAccount = CloudStorageAccount.Parse(storageConnectionString); + var tableClient = storageAccount.CreateCloudTableClient(); + _table = tableClient.GetTableReference("metadata"); } - return null; - } - public async Task UpsertAsync(string objectName, string id, KeyValuePair keyValuePair) - { - var query = new TableQuery().Where( - TableQuery.GenerateFilterCondition("PartitionKey", QueryComparisons.Equal, $"{objectName}_{id}")); - var queryResults = await _table.ExecuteQuerySegmentedAsync(query, null); - var entity = queryResults.Results.FirstOrDefault(); - if (entity == null) + public async Task> GetAsync(string objectName, string id) { - entity = new DictionaryEntity + var query = new TableQuery().Where( + TableQuery.GenerateFilterCondition("PartitionKey", QueryComparisons.Equal, $"{objectName}_{id}")); + var queryResults = await _table.ExecuteQuerySegmentedAsync(query, null); + return queryResults.Results.FirstOrDefault()?.ToDictionary(d => d.Key, d => d.Value.StringValue); + } + + public async Task GetAsync(string objectName, string id, string prop) + { + var dict = await GetAsync(objectName, id); + if (dict != null && dict.ContainsKey(prop)) + { + return dict[prop]; + } + return null; + } + + public async Task UpsertAsync(string objectName, string id, KeyValuePair keyValuePair) + { + var query = new TableQuery().Where( + TableQuery.GenerateFilterCondition("PartitionKey", QueryComparisons.Equal, $"{objectName}_{id}")); + var queryResults = await _table.ExecuteQuerySegmentedAsync(query, null); + var entity = queryResults.Results.FirstOrDefault(); + if (entity == null) + { + entity = new DictionaryEntity + { + PartitionKey = $"{objectName}_{id}", + RowKey = string.Empty + }; + } + if (entity.ContainsKey(keyValuePair.Key)) + { + entity.Remove(keyValuePair.Key); + } + entity.Add(keyValuePair.Key, keyValuePair.Value); + await _table.ExecuteAsync(TableOperation.InsertOrReplace(entity)); + } + + public async Task UpsertAsync(string objectName, string id, IDictionary dict) + { + var entity = new DictionaryEntity { PartitionKey = $"{objectName}_{id}", RowKey = string.Empty }; - } - if (entity.ContainsKey(keyValuePair.Key)) - { - entity.Remove(keyValuePair.Key); - } - entity.Add(keyValuePair.Key, keyValuePair.Value); - await _table.ExecuteAsync(TableOperation.InsertOrReplace(entity)); - } - - public async Task UpsertAsync(string objectName, string id, IDictionary dict) - { - var entity = new DictionaryEntity - { - PartitionKey = $"{objectName}_{id}", - RowKey = string.Empty - }; - foreach (var item in dict) - { - entity.Add(item.Key, item.Value); - } - await _table.ExecuteAsync(TableOperation.InsertOrReplace(entity)); - } - - public async Task DeleteAsync(string objectName, string id) - { - try - { - await _table.ExecuteAsync(TableOperation.Delete(new DictionaryEntity + foreach (var item in dict) { - PartitionKey = $"{objectName}_{id}", - RowKey = string.Empty, - ETag = "*" - })); + entity.Add(item.Key, item.Value); + } + await _table.ExecuteAsync(TableOperation.InsertOrReplace(entity)); } - catch (StorageException e) when (e.RequestInformation.HttpStatusCode != (int)HttpStatusCode.NotFound) + + public async Task DeleteAsync(string objectName, string id) { - throw; + try + { + await _table.ExecuteAsync(TableOperation.Delete(new DictionaryEntity + { + PartitionKey = $"{objectName}_{id}", + RowKey = string.Empty, + ETag = "*" + })); + } + catch (StorageException e) when (e.RequestInformation.HttpStatusCode != (int)HttpStatusCode.NotFound) + { + throw; + } } } } diff --git a/src/Core/Resources/SharedResources.cs b/src/Core/Resources/SharedResources.cs index 39eea7c6c4..543ec227aa 100644 --- a/src/Core/Resources/SharedResources.cs +++ b/src/Core/Resources/SharedResources.cs @@ -1,5 +1,6 @@ -namespace Bit.Core.Resources; - -public class SharedResources +namespace Bit.Core.Resources { + public class SharedResources + { + } } diff --git a/src/Core/Services/IAppleIapService.cs b/src/Core/Services/IAppleIapService.cs index b258b9e3b3..aef7e2c88e 100644 --- a/src/Core/Services/IAppleIapService.cs +++ b/src/Core/Services/IAppleIapService.cs @@ -1,10 +1,11 @@ using Bit.Billing.Models; -namespace Bit.Core.Services; - -public interface IAppleIapService +namespace Bit.Core.Services { - Task GetVerifiedReceiptStatusAsync(string receiptData); - Task SaveReceiptAsync(AppleReceiptStatus receiptStatus, Guid userId); - Task> GetReceiptAsync(string originalTransactionId); + public interface IAppleIapService + { + Task GetVerifiedReceiptStatusAsync(string receiptData); + Task SaveReceiptAsync(AppleReceiptStatus receiptStatus, Guid userId); + Task> GetReceiptAsync(string originalTransactionId); + } } diff --git a/src/Core/Services/IApplicationCacheService.cs b/src/Core/Services/IApplicationCacheService.cs index 7c21fac76f..08efe7b7ce 100644 --- a/src/Core/Services/IApplicationCacheService.cs +++ b/src/Core/Services/IApplicationCacheService.cs @@ -3,13 +3,14 @@ using Bit.Core.Entities.Provider; using Bit.Core.Models.Data; using Bit.Core.Models.Data.Organizations; -namespace Bit.Core.Services; - -public interface IApplicationCacheService +namespace Bit.Core.Services { - Task> GetOrganizationAbilitiesAsync(); - Task> GetProviderAbilitiesAsync(); - Task UpsertOrganizationAbilityAsync(Organization organization); - Task UpsertProviderAbilityAsync(Provider provider); - Task DeleteOrganizationAbilityAsync(Guid organizationId); + public interface IApplicationCacheService + { + Task> GetOrganizationAbilitiesAsync(); + Task> GetProviderAbilitiesAsync(); + Task UpsertOrganizationAbilityAsync(Organization organization); + Task UpsertProviderAbilityAsync(Provider provider); + Task DeleteOrganizationAbilityAsync(Guid organizationId); + } } diff --git a/src/Core/Services/IAttachmentStorageService.cs b/src/Core/Services/IAttachmentStorageService.cs index 964b711f05..c0b11a0217 100644 --- a/src/Core/Services/IAttachmentStorageService.cs +++ b/src/Core/Services/IAttachmentStorageService.cs @@ -2,21 +2,22 @@ using Bit.Core.Enums; using Bit.Core.Models.Data; -namespace Bit.Core.Services; - -public interface IAttachmentStorageService +namespace Bit.Core.Services { - FileUploadType FileUploadType { get; } - Task UploadNewAttachmentAsync(Stream stream, Cipher cipher, CipherAttachment.MetaData attachmentData); - Task UploadShareAttachmentAsync(Stream stream, Guid cipherId, Guid organizationId, CipherAttachment.MetaData attachmentData); - Task StartShareAttachmentAsync(Guid cipherId, Guid organizationId, CipherAttachment.MetaData attachmentData); - Task RollbackShareAttachmentAsync(Guid cipherId, Guid organizationId, CipherAttachment.MetaData attachmentData, string originalContainer); - Task CleanupAsync(Guid cipherId); - Task DeleteAttachmentAsync(Guid cipherId, CipherAttachment.MetaData attachmentData); - Task DeleteAttachmentsForCipherAsync(Guid cipherId); - Task DeleteAttachmentsForOrganizationAsync(Guid organizationId); - Task DeleteAttachmentsForUserAsync(Guid userId); - Task GetAttachmentUploadUrlAsync(Cipher cipher, CipherAttachment.MetaData attachmentData); - Task GetAttachmentDownloadUrlAsync(Cipher cipher, CipherAttachment.MetaData attachmentData); - Task<(bool, long?)> ValidateFileAsync(Cipher cipher, CipherAttachment.MetaData attachmentData, long leeway); + public interface IAttachmentStorageService + { + FileUploadType FileUploadType { get; } + Task UploadNewAttachmentAsync(Stream stream, Cipher cipher, CipherAttachment.MetaData attachmentData); + Task UploadShareAttachmentAsync(Stream stream, Guid cipherId, Guid organizationId, CipherAttachment.MetaData attachmentData); + Task StartShareAttachmentAsync(Guid cipherId, Guid organizationId, CipherAttachment.MetaData attachmentData); + Task RollbackShareAttachmentAsync(Guid cipherId, Guid organizationId, CipherAttachment.MetaData attachmentData, string originalContainer); + Task CleanupAsync(Guid cipherId); + Task DeleteAttachmentAsync(Guid cipherId, CipherAttachment.MetaData attachmentData); + Task DeleteAttachmentsForCipherAsync(Guid cipherId); + Task DeleteAttachmentsForOrganizationAsync(Guid organizationId); + Task DeleteAttachmentsForUserAsync(Guid userId); + Task GetAttachmentUploadUrlAsync(Cipher cipher, CipherAttachment.MetaData attachmentData); + Task GetAttachmentDownloadUrlAsync(Cipher cipher, CipherAttachment.MetaData attachmentData); + Task<(bool, long?)> ValidateFileAsync(Cipher cipher, CipherAttachment.MetaData attachmentData, long leeway); + } } diff --git a/src/Core/Services/IBlockIpService.cs b/src/Core/Services/IBlockIpService.cs index 87af1a2ce6..547a7cedeb 100644 --- a/src/Core/Services/IBlockIpService.cs +++ b/src/Core/Services/IBlockIpService.cs @@ -1,6 +1,7 @@ -namespace Bit.Core.Services; - -public interface IBlockIpService +namespace Bit.Core.Services { - Task BlockIpAsync(string ipAddress, bool permanentBlock); + public interface IBlockIpService + { + Task BlockIpAsync(string ipAddress, bool permanentBlock); + } } diff --git a/src/Core/Services/ICaptchaValidationService.cs b/src/Core/Services/ICaptchaValidationService.cs index 50faad31f8..d908be7c24 100644 --- a/src/Core/Services/ICaptchaValidationService.cs +++ b/src/Core/Services/ICaptchaValidationService.cs @@ -2,14 +2,15 @@ using Bit.Core.Entities; using Bit.Core.Models.Business; -namespace Bit.Core.Services; - -public interface ICaptchaValidationService +namespace Bit.Core.Services { - string SiteKey { get; } - string SiteKeyResponseKeyName { get; } - bool RequireCaptchaValidation(ICurrentContext currentContext, User user = null); - Task ValidateCaptchaResponseAsync(string captchResponse, string clientIpAddress, - User user = null); - string GenerateCaptchaBypassToken(User user); + public interface ICaptchaValidationService + { + string SiteKey { get; } + string SiteKeyResponseKeyName { get; } + bool RequireCaptchaValidation(ICurrentContext currentContext, User user = null); + Task ValidateCaptchaResponseAsync(string captchResponse, string clientIpAddress, + User user = null); + string GenerateCaptchaBypassToken(User user); + } } diff --git a/src/Core/Services/ICipherService.cs b/src/Core/Services/ICipherService.cs index ad93990c2d..9afeb5926b 100644 --- a/src/Core/Services/ICipherService.cs +++ b/src/Core/Services/ICipherService.cs @@ -2,42 +2,43 @@ using Bit.Core.Models.Data; using Core.Models.Data; -namespace Bit.Core.Services; - -public interface ICipherService +namespace Bit.Core.Services { - Task SaveAsync(Cipher cipher, Guid savingUserId, DateTime? lastKnownRevisionDate, IEnumerable collectionIds = null, - bool skipPermissionCheck = false, bool limitCollectionScope = true); - Task SaveDetailsAsync(CipherDetails cipher, Guid savingUserId, DateTime? lastKnownRevisionDate, - IEnumerable collectionIds = null, bool skipPermissionCheck = false); - Task<(string attachmentId, string uploadUrl)> CreateAttachmentForDelayedUploadAsync(Cipher cipher, - string key, string fileName, long fileSize, bool adminRequest, Guid savingUserId); - Task CreateAttachmentAsync(Cipher cipher, Stream stream, string fileName, string key, - long requestLength, Guid savingUserId, bool orgAdmin = false); - Task CreateAttachmentShareAsync(Cipher cipher, Stream stream, long requestLength, string attachmentId, - Guid organizationShareId); - Task DeleteAsync(Cipher cipher, Guid deletingUserId, bool orgAdmin = false); - Task DeleteManyAsync(IEnumerable cipherIds, Guid deletingUserId, Guid? organizationId = null, bool orgAdmin = false); - Task DeleteAttachmentAsync(Cipher cipher, string attachmentId, Guid deletingUserId, bool orgAdmin = false); - Task PurgeAsync(Guid organizationId); - Task MoveManyAsync(IEnumerable cipherIds, Guid? destinationFolderId, Guid movingUserId); - Task SaveFolderAsync(Folder folder); - Task DeleteFolderAsync(Folder folder); - Task ShareAsync(Cipher originalCipher, Cipher cipher, Guid organizationId, IEnumerable collectionIds, - Guid userId, DateTime? lastKnownRevisionDate); - Task ShareManyAsync(IEnumerable<(Cipher cipher, DateTime? lastKnownRevisionDate)> ciphers, Guid organizationId, - IEnumerable collectionIds, Guid sharingUserId); - Task SaveCollectionsAsync(Cipher cipher, IEnumerable collectionIds, Guid savingUserId, bool orgAdmin); - Task ImportCiphersAsync(List folders, List ciphers, - IEnumerable> folderRelationships); - Task ImportCiphersAsync(List collections, List ciphers, - IEnumerable> collectionRelationships, Guid importingUserId); - Task SoftDeleteAsync(Cipher cipher, Guid deletingUserId, bool orgAdmin = false); - Task SoftDeleteManyAsync(IEnumerable cipherIds, Guid deletingUserId, Guid? organizationId = null, bool orgAdmin = false); - Task RestoreAsync(Cipher cipher, Guid restoringUserId, bool orgAdmin = false); - Task RestoreManyAsync(IEnumerable ciphers, Guid restoringUserId); - Task UploadFileForExistingAttachmentAsync(Stream stream, Cipher cipher, CipherAttachment.MetaData attachmentId); - Task GetAttachmentDownloadDataAsync(Cipher cipher, string attachmentId); - Task ValidateCipherAttachmentFile(Cipher cipher, CipherAttachment.MetaData attachmentData); - Task<(IEnumerable, Dictionary>)> GetOrganizationCiphers(Guid userId, Guid organizationId); + public interface ICipherService + { + Task SaveAsync(Cipher cipher, Guid savingUserId, DateTime? lastKnownRevisionDate, IEnumerable collectionIds = null, + bool skipPermissionCheck = false, bool limitCollectionScope = true); + Task SaveDetailsAsync(CipherDetails cipher, Guid savingUserId, DateTime? lastKnownRevisionDate, + IEnumerable collectionIds = null, bool skipPermissionCheck = false); + Task<(string attachmentId, string uploadUrl)> CreateAttachmentForDelayedUploadAsync(Cipher cipher, + string key, string fileName, long fileSize, bool adminRequest, Guid savingUserId); + Task CreateAttachmentAsync(Cipher cipher, Stream stream, string fileName, string key, + long requestLength, Guid savingUserId, bool orgAdmin = false); + Task CreateAttachmentShareAsync(Cipher cipher, Stream stream, long requestLength, string attachmentId, + Guid organizationShareId); + Task DeleteAsync(Cipher cipher, Guid deletingUserId, bool orgAdmin = false); + Task DeleteManyAsync(IEnumerable cipherIds, Guid deletingUserId, Guid? organizationId = null, bool orgAdmin = false); + Task DeleteAttachmentAsync(Cipher cipher, string attachmentId, Guid deletingUserId, bool orgAdmin = false); + Task PurgeAsync(Guid organizationId); + Task MoveManyAsync(IEnumerable cipherIds, Guid? destinationFolderId, Guid movingUserId); + Task SaveFolderAsync(Folder folder); + Task DeleteFolderAsync(Folder folder); + Task ShareAsync(Cipher originalCipher, Cipher cipher, Guid organizationId, IEnumerable collectionIds, + Guid userId, DateTime? lastKnownRevisionDate); + Task ShareManyAsync(IEnumerable<(Cipher cipher, DateTime? lastKnownRevisionDate)> ciphers, Guid organizationId, + IEnumerable collectionIds, Guid sharingUserId); + Task SaveCollectionsAsync(Cipher cipher, IEnumerable collectionIds, Guid savingUserId, bool orgAdmin); + Task ImportCiphersAsync(List folders, List ciphers, + IEnumerable> folderRelationships); + Task ImportCiphersAsync(List collections, List ciphers, + IEnumerable> collectionRelationships, Guid importingUserId); + Task SoftDeleteAsync(Cipher cipher, Guid deletingUserId, bool orgAdmin = false); + Task SoftDeleteManyAsync(IEnumerable cipherIds, Guid deletingUserId, Guid? organizationId = null, bool orgAdmin = false); + Task RestoreAsync(Cipher cipher, Guid restoringUserId, bool orgAdmin = false); + Task RestoreManyAsync(IEnumerable ciphers, Guid restoringUserId); + Task UploadFileForExistingAttachmentAsync(Stream stream, Cipher cipher, CipherAttachment.MetaData attachmentId); + Task GetAttachmentDownloadDataAsync(Cipher cipher, string attachmentId); + Task ValidateCipherAttachmentFile(Cipher cipher, CipherAttachment.MetaData attachmentData); + Task<(IEnumerable, Dictionary>)> GetOrganizationCiphers(Guid userId, Guid organizationId); + } } diff --git a/src/Core/Services/ICollectionService.cs b/src/Core/Services/ICollectionService.cs index 7ae3562ea0..015474b6f5 100644 --- a/src/Core/Services/ICollectionService.cs +++ b/src/Core/Services/ICollectionService.cs @@ -1,12 +1,13 @@ using Bit.Core.Entities; using Bit.Core.Models.Data; -namespace Bit.Core.Services; - -public interface ICollectionService +namespace Bit.Core.Services { - Task SaveAsync(Collection collection, IEnumerable groups = null, Guid? assignUserId = null); - Task DeleteAsync(Collection collection); - Task DeleteUserAsync(Collection collection, Guid organizationUserId); - Task> GetOrganizationCollections(Guid organizationId); + public interface ICollectionService + { + Task SaveAsync(Collection collection, IEnumerable groups = null, Guid? assignUserId = null); + Task DeleteAsync(Collection collection); + Task DeleteUserAsync(Collection collection, Guid organizationUserId); + Task> GetOrganizationCollections(Guid organizationId); + } } diff --git a/src/Core/Services/IDeviceService.cs b/src/Core/Services/IDeviceService.cs index 3109cc107a..6455e6a32e 100644 --- a/src/Core/Services/IDeviceService.cs +++ b/src/Core/Services/IDeviceService.cs @@ -1,10 +1,11 @@ using Bit.Core.Entities; -namespace Bit.Core.Services; - -public interface IDeviceService +namespace Bit.Core.Services { - Task SaveAsync(Device device); - Task ClearTokenAsync(Device device); - Task DeleteAsync(Device device); + public interface IDeviceService + { + Task SaveAsync(Device device); + Task ClearTokenAsync(Device device); + Task DeleteAsync(Device device); + } } diff --git a/src/Core/Services/IEmergencyAccessService.cs b/src/Core/Services/IEmergencyAccessService.cs index 96edb752c5..f975bfe76e 100644 --- a/src/Core/Services/IEmergencyAccessService.cs +++ b/src/Core/Services/IEmergencyAccessService.cs @@ -2,25 +2,26 @@ using Bit.Core.Enums; using Bit.Core.Models.Data; -namespace Bit.Core.Services; - -public interface IEmergencyAccessService +namespace Bit.Core.Services { - Task InviteAsync(User invitingUser, string email, EmergencyAccessType type, int waitTime); - Task ResendInviteAsync(User invitingUser, Guid emergencyAccessId); - Task AcceptUserAsync(Guid emergencyAccessId, User user, string token, IUserService userService); - Task DeleteAsync(Guid emergencyAccessId, Guid grantorId); - Task ConfirmUserAsync(Guid emergencyAccessId, string key, Guid grantorId); - Task GetAsync(Guid emergencyAccessId, Guid userId); - Task SaveAsync(EmergencyAccess emergencyAccess, User savingUser); - Task InitiateAsync(Guid id, User initiatingUser); - Task ApproveAsync(Guid id, User approvingUser); - Task RejectAsync(Guid id, User rejectingUser); - Task> GetPoliciesAsync(Guid id, User requestingUser); - Task<(EmergencyAccess, User)> TakeoverAsync(Guid id, User initiatingUser); - Task PasswordAsync(Guid id, User user, string newMasterPasswordHash, string key); - Task SendNotificationsAsync(); - Task HandleTimedOutRequestsAsync(); - Task ViewAsync(Guid id, User user); - Task GetAttachmentDownloadAsync(Guid id, Guid cipherId, string attachmentId, User user); + public interface IEmergencyAccessService + { + Task InviteAsync(User invitingUser, string email, EmergencyAccessType type, int waitTime); + Task ResendInviteAsync(User invitingUser, Guid emergencyAccessId); + Task AcceptUserAsync(Guid emergencyAccessId, User user, string token, IUserService userService); + Task DeleteAsync(Guid emergencyAccessId, Guid grantorId); + Task ConfirmUserAsync(Guid emergencyAccessId, string key, Guid grantorId); + Task GetAsync(Guid emergencyAccessId, Guid userId); + Task SaveAsync(EmergencyAccess emergencyAccess, User savingUser); + Task InitiateAsync(Guid id, User initiatingUser); + Task ApproveAsync(Guid id, User approvingUser); + Task RejectAsync(Guid id, User rejectingUser); + Task> GetPoliciesAsync(Guid id, User requestingUser); + Task<(EmergencyAccess, User)> TakeoverAsync(Guid id, User initiatingUser); + Task PasswordAsync(Guid id, User user, string newMasterPasswordHash, string key); + Task SendNotificationsAsync(); + Task HandleTimedOutRequestsAsync(); + Task ViewAsync(Guid id, User user); + Task GetAttachmentDownloadAsync(Guid id, Guid cipherId, string attachmentId, User user); + } } diff --git a/src/Core/Services/IEventService.cs b/src/Core/Services/IEventService.cs index fd0ca44918..fa98485842 100644 --- a/src/Core/Services/IEventService.cs +++ b/src/Core/Services/IEventService.cs @@ -2,20 +2,21 @@ using Bit.Core.Entities.Provider; using Bit.Core.Enums; -namespace Bit.Core.Services; - -public interface IEventService +namespace Bit.Core.Services { - Task LogUserEventAsync(Guid userId, EventType type, DateTime? date = null); - Task LogCipherEventAsync(Cipher cipher, EventType type, DateTime? date = null); - Task LogCipherEventsAsync(IEnumerable> events); - Task LogCollectionEventAsync(Collection collection, EventType type, DateTime? date = null); - Task LogGroupEventAsync(Group group, EventType type, DateTime? date = null); - Task LogPolicyEventAsync(Policy policy, EventType type, DateTime? date = null); - Task LogOrganizationUserEventAsync(OrganizationUser organizationUser, EventType type, DateTime? date = null); - Task LogOrganizationUserEventsAsync(IEnumerable<(OrganizationUser, EventType, DateTime?)> events); - Task LogOrganizationEventAsync(Organization organization, EventType type, DateTime? date = null); - Task LogProviderUserEventAsync(ProviderUser providerUser, EventType type, DateTime? date = null); - Task LogProviderUsersEventAsync(IEnumerable<(ProviderUser, EventType, DateTime?)> events); - Task LogProviderOrganizationEventAsync(ProviderOrganization providerOrganization, EventType type, DateTime? date = null); + public interface IEventService + { + Task LogUserEventAsync(Guid userId, EventType type, DateTime? date = null); + Task LogCipherEventAsync(Cipher cipher, EventType type, DateTime? date = null); + Task LogCipherEventsAsync(IEnumerable> events); + Task LogCollectionEventAsync(Collection collection, EventType type, DateTime? date = null); + Task LogGroupEventAsync(Group group, EventType type, DateTime? date = null); + Task LogPolicyEventAsync(Policy policy, EventType type, DateTime? date = null); + Task LogOrganizationUserEventAsync(OrganizationUser organizationUser, EventType type, DateTime? date = null); + Task LogOrganizationUserEventsAsync(IEnumerable<(OrganizationUser, EventType, DateTime?)> events); + Task LogOrganizationEventAsync(Organization organization, EventType type, DateTime? date = null); + Task LogProviderUserEventAsync(ProviderUser providerUser, EventType type, DateTime? date = null); + Task LogProviderUsersEventAsync(IEnumerable<(ProviderUser, EventType, DateTime?)> events); + Task LogProviderOrganizationEventAsync(ProviderOrganization providerOrganization, EventType type, DateTime? date = null); + } } diff --git a/src/Core/Services/IEventWriteService.cs b/src/Core/Services/IEventWriteService.cs index cbe8790d31..dc33189376 100644 --- a/src/Core/Services/IEventWriteService.cs +++ b/src/Core/Services/IEventWriteService.cs @@ -1,9 +1,10 @@ using Bit.Core.Models.Data; -namespace Bit.Core.Services; - -public interface IEventWriteService +namespace Bit.Core.Services { - Task CreateAsync(IEvent e); - Task CreateManyAsync(IEnumerable e); + public interface IEventWriteService + { + Task CreateAsync(IEvent e); + Task CreateManyAsync(IEnumerable e); + } } diff --git a/src/Core/Services/IGroupService.cs b/src/Core/Services/IGroupService.cs index 494d3e6c08..82fd9792a3 100644 --- a/src/Core/Services/IGroupService.cs +++ b/src/Core/Services/IGroupService.cs @@ -1,11 +1,12 @@ using Bit.Core.Entities; using Bit.Core.Models.Data; -namespace Bit.Core.Services; - -public interface IGroupService +namespace Bit.Core.Services { - Task SaveAsync(Group group, IEnumerable collections = null); - Task DeleteAsync(Group group); - Task DeleteUserAsync(Group group, Guid organizationUserId); + public interface IGroupService + { + Task SaveAsync(Group group, IEnumerable collections = null); + Task DeleteAsync(Group group); + Task DeleteUserAsync(Group group, Guid organizationUserId); + } } diff --git a/src/Core/Services/II18nService.cs b/src/Core/Services/II18nService.cs index ee92664d88..a66e148833 100644 --- a/src/Core/Services/II18nService.cs +++ b/src/Core/Services/II18nService.cs @@ -1,11 +1,12 @@ using Microsoft.Extensions.Localization; -namespace Bit.Core.Services; - -public interface II18nService +namespace Bit.Core.Services { - LocalizedString GetLocalizedHtmlString(string key); - LocalizedString GetLocalizedHtmlString(string key, params object[] args); - string Translate(string key, params object[] args); - string T(string key, params object[] args); + public interface II18nService + { + LocalizedString GetLocalizedHtmlString(string key); + LocalizedString GetLocalizedHtmlString(string key, params object[] args); + string Translate(string key, params object[] args); + string T(string key, params object[] args); + } } diff --git a/src/Core/Services/ILicensingService.cs b/src/Core/Services/ILicensingService.cs index bf3b5ee425..fd3ad9afef 100644 --- a/src/Core/Services/ILicensingService.cs +++ b/src/Core/Services/ILicensingService.cs @@ -1,16 +1,17 @@ using Bit.Core.Entities; using Bit.Core.Models.Business; -namespace Bit.Core.Services; - -public interface ILicensingService +namespace Bit.Core.Services { - Task ValidateOrganizationsAsync(); - Task ValidateUsersAsync(); - Task ValidateUserPremiumAsync(User user); - bool VerifyLicense(ILicense license); - byte[] SignLicense(ILicense license); - Task ReadOrganizationLicenseAsync(Organization organization); - Task ReadOrganizationLicenseAsync(Guid organizationId); + public interface ILicensingService + { + Task ValidateOrganizationsAsync(); + Task ValidateUsersAsync(); + Task ValidateUserPremiumAsync(User user); + bool VerifyLicense(ILicense license); + byte[] SignLicense(ILicense license); + Task ReadOrganizationLicenseAsync(Organization organization); + Task ReadOrganizationLicenseAsync(Guid organizationId); + } } diff --git a/src/Core/Services/IMailDeliveryService.cs b/src/Core/Services/IMailDeliveryService.cs index 9247367221..1c42e39e24 100644 --- a/src/Core/Services/IMailDeliveryService.cs +++ b/src/Core/Services/IMailDeliveryService.cs @@ -1,8 +1,9 @@ using Bit.Core.Models.Mail; -namespace Bit.Core.Services; - -public interface IMailDeliveryService +namespace Bit.Core.Services { - Task SendEmailAsync(MailMessage message); + public interface IMailDeliveryService + { + Task SendEmailAsync(MailMessage message); + } } diff --git a/src/Core/Services/IMailEnqueuingService.cs b/src/Core/Services/IMailEnqueuingService.cs index 19dc33f19e..1e681b6ca4 100644 --- a/src/Core/Services/IMailEnqueuingService.cs +++ b/src/Core/Services/IMailEnqueuingService.cs @@ -1,9 +1,10 @@ using Bit.Core.Models.Mail; -namespace Bit.Core.Services; - -public interface IMailEnqueuingService +namespace Bit.Core.Services { - Task EnqueueAsync(IMailQueueMessage message, Func fallback); - Task EnqueueManyAsync(IEnumerable messages, Func fallback); + public interface IMailEnqueuingService + { + Task EnqueueAsync(IMailQueueMessage message, Func fallback); + Task EnqueueManyAsync(IEnumerable messages, Func fallback); + } } diff --git a/src/Core/Services/IMailService.cs b/src/Core/Services/IMailService.cs index 3af89108c6..7be31e82a3 100644 --- a/src/Core/Services/IMailService.cs +++ b/src/Core/Services/IMailService.cs @@ -3,55 +3,56 @@ using Bit.Core.Entities.Provider; using Bit.Core.Models.Business; using Bit.Core.Models.Mail; -namespace Bit.Core.Services; - -public interface IMailService +namespace Bit.Core.Services { - Task SendWelcomeEmailAsync(User user); - Task SendVerifyEmailEmailAsync(string email, Guid userId, string token); - Task SendVerifyDeleteEmailAsync(string email, Guid userId, string token); - Task SendChangeEmailAlreadyExistsEmailAsync(string fromEmail, string toEmail); - Task SendChangeEmailEmailAsync(string newEmailAddress, string token); - Task SendTwoFactorEmailAsync(string email, string token); - Task SendNewDeviceLoginTwoFactorEmailAsync(string email, string token); - Task SendNoMasterPasswordHintEmailAsync(string email); - Task SendMasterPasswordHintEmailAsync(string email, string hint); - Task SendOrganizationInviteEmailAsync(string organizationName, OrganizationUser orgUser, ExpiringToken token); - Task BulkSendOrganizationInviteEmailAsync(string organizationName, IEnumerable<(OrganizationUser orgUser, ExpiringToken token)> invites); - Task SendOrganizationMaxSeatLimitReachedEmailAsync(Organization organization, int maxSeatCount, IEnumerable ownerEmails); - Task SendOrganizationAutoscaledEmailAsync(Organization organization, int initialSeatCount, IEnumerable ownerEmails); - Task SendOrganizationAcceptedEmailAsync(Organization organization, string userIdentifier, IEnumerable adminEmails); - Task SendOrganizationConfirmedEmailAsync(string organizationName, string email); - Task SendOrganizationUserRemovedForPolicyTwoStepEmailAsync(string organizationName, string email); - Task SendPasswordlessSignInAsync(string returnUrl, string token, string email); - Task SendInvoiceUpcomingAsync(string email, decimal amount, DateTime dueDate, List items, - bool mentionInvoices); - Task SendPaymentFailedAsync(string email, decimal amount, bool mentionInvoices); - Task SendAddedCreditAsync(string email, decimal amount); - Task SendLicenseExpiredAsync(IEnumerable emails, string organizationName = null); - Task SendNewDeviceLoggedInEmail(string email, string deviceType, DateTime timestamp, string ip); - Task SendRecoverTwoFactorEmail(string email, DateTime timestamp, string ip); - Task SendOrganizationUserRemovedForPolicySingleOrgEmailAsync(string organizationName, string email); - Task SendEmergencyAccessInviteEmailAsync(EmergencyAccess emergencyAccess, string name, string token); - Task SendEmergencyAccessAcceptedEmailAsync(string granteeEmail, string email); - Task SendEmergencyAccessConfirmedEmailAsync(string grantorName, string email); - Task SendEmergencyAccessRecoveryInitiated(EmergencyAccess emergencyAccess, string initiatingName, string email); - Task SendEmergencyAccessRecoveryApproved(EmergencyAccess emergencyAccess, string approvingName, string email); - Task SendEmergencyAccessRecoveryRejected(EmergencyAccess emergencyAccess, string rejectingName, string email); - Task SendEmergencyAccessRecoveryReminder(EmergencyAccess emergencyAccess, string initiatingName, string email); - Task SendEmergencyAccessRecoveryTimedOut(EmergencyAccess ea, string initiatingName, string email); - Task SendEnqueuedMailMessageAsync(IMailQueueMessage queueMessage); - Task SendAdminResetPasswordEmailAsync(string email, string userName, string orgName); - Task SendProviderSetupInviteEmailAsync(Provider provider, string token, string email); - Task SendProviderInviteEmailAsync(string providerName, ProviderUser providerUser, string token, string email); - Task SendProviderConfirmedEmailAsync(string providerName, string email); - Task SendProviderUserRemoved(string providerName, string email); - Task SendUpdatedTempPasswordEmailAsync(string email, string userName); - Task SendFamiliesForEnterpriseOfferEmailAsync(string sponsorOrgName, string email, bool existingAccount, string token); - Task BulkSendFamiliesForEnterpriseOfferEmailAsync(string SponsorOrgName, IEnumerable<(string Email, bool ExistingAccount, string Token)> invites); - Task SendFamiliesForEnterpriseRedeemedEmailsAsync(string familyUserEmail, string sponsorEmail); - Task SendFamiliesForEnterpriseSponsorshipRevertingEmailAsync(string email, DateTime expirationDate); - Task SendOTPEmailAsync(string email, string token); - Task SendFailedLoginAttemptsEmailAsync(string email, DateTime utcNow, string ip); - Task SendFailedTwoFactorAttemptsEmailAsync(string email, DateTime utcNow, string ip); + public interface IMailService + { + Task SendWelcomeEmailAsync(User user); + Task SendVerifyEmailEmailAsync(string email, Guid userId, string token); + Task SendVerifyDeleteEmailAsync(string email, Guid userId, string token); + Task SendChangeEmailAlreadyExistsEmailAsync(string fromEmail, string toEmail); + Task SendChangeEmailEmailAsync(string newEmailAddress, string token); + Task SendTwoFactorEmailAsync(string email, string token); + Task SendNewDeviceLoginTwoFactorEmailAsync(string email, string token); + Task SendNoMasterPasswordHintEmailAsync(string email); + Task SendMasterPasswordHintEmailAsync(string email, string hint); + Task SendOrganizationInviteEmailAsync(string organizationName, OrganizationUser orgUser, ExpiringToken token); + Task BulkSendOrganizationInviteEmailAsync(string organizationName, IEnumerable<(OrganizationUser orgUser, ExpiringToken token)> invites); + Task SendOrganizationMaxSeatLimitReachedEmailAsync(Organization organization, int maxSeatCount, IEnumerable ownerEmails); + Task SendOrganizationAutoscaledEmailAsync(Organization organization, int initialSeatCount, IEnumerable ownerEmails); + Task SendOrganizationAcceptedEmailAsync(Organization organization, string userIdentifier, IEnumerable adminEmails); + Task SendOrganizationConfirmedEmailAsync(string organizationName, string email); + Task SendOrganizationUserRemovedForPolicyTwoStepEmailAsync(string organizationName, string email); + Task SendPasswordlessSignInAsync(string returnUrl, string token, string email); + Task SendInvoiceUpcomingAsync(string email, decimal amount, DateTime dueDate, List items, + bool mentionInvoices); + Task SendPaymentFailedAsync(string email, decimal amount, bool mentionInvoices); + Task SendAddedCreditAsync(string email, decimal amount); + Task SendLicenseExpiredAsync(IEnumerable emails, string organizationName = null); + Task SendNewDeviceLoggedInEmail(string email, string deviceType, DateTime timestamp, string ip); + Task SendRecoverTwoFactorEmail(string email, DateTime timestamp, string ip); + Task SendOrganizationUserRemovedForPolicySingleOrgEmailAsync(string organizationName, string email); + Task SendEmergencyAccessInviteEmailAsync(EmergencyAccess emergencyAccess, string name, string token); + Task SendEmergencyAccessAcceptedEmailAsync(string granteeEmail, string email); + Task SendEmergencyAccessConfirmedEmailAsync(string grantorName, string email); + Task SendEmergencyAccessRecoveryInitiated(EmergencyAccess emergencyAccess, string initiatingName, string email); + Task SendEmergencyAccessRecoveryApproved(EmergencyAccess emergencyAccess, string approvingName, string email); + Task SendEmergencyAccessRecoveryRejected(EmergencyAccess emergencyAccess, string rejectingName, string email); + Task SendEmergencyAccessRecoveryReminder(EmergencyAccess emergencyAccess, string initiatingName, string email); + Task SendEmergencyAccessRecoveryTimedOut(EmergencyAccess ea, string initiatingName, string email); + Task SendEnqueuedMailMessageAsync(IMailQueueMessage queueMessage); + Task SendAdminResetPasswordEmailAsync(string email, string userName, string orgName); + Task SendProviderSetupInviteEmailAsync(Provider provider, string token, string email); + Task SendProviderInviteEmailAsync(string providerName, ProviderUser providerUser, string token, string email); + Task SendProviderConfirmedEmailAsync(string providerName, string email); + Task SendProviderUserRemoved(string providerName, string email); + Task SendUpdatedTempPasswordEmailAsync(string email, string userName); + Task SendFamiliesForEnterpriseOfferEmailAsync(string sponsorOrgName, string email, bool existingAccount, string token); + Task BulkSendFamiliesForEnterpriseOfferEmailAsync(string SponsorOrgName, IEnumerable<(string Email, bool ExistingAccount, string Token)> invites); + Task SendFamiliesForEnterpriseRedeemedEmailsAsync(string familyUserEmail, string sponsorEmail); + Task SendFamiliesForEnterpriseSponsorshipRevertingEmailAsync(string email, DateTime expirationDate); + Task SendOTPEmailAsync(string email, string token); + Task SendFailedLoginAttemptsEmailAsync(string email, DateTime utcNow, string ip); + Task SendFailedTwoFactorAttemptsEmailAsync(string email, DateTime utcNow, string ip); + } } diff --git a/src/Core/Services/IOrganizationService.cs b/src/Core/Services/IOrganizationService.cs index 3bd3e1f6eb..076cd3eb8b 100644 --- a/src/Core/Services/IOrganizationService.cs +++ b/src/Core/Services/IOrganizationService.cs @@ -3,65 +3,66 @@ using Bit.Core.Enums; using Bit.Core.Models.Business; using Bit.Core.Models.Data; -namespace Bit.Core.Services; - -public interface IOrganizationService +namespace Bit.Core.Services { - Task ReplacePaymentMethodAsync(Guid organizationId, string paymentToken, PaymentMethodType paymentMethodType, - TaxInfo taxInfo); - Task CancelSubscriptionAsync(Guid organizationId, bool? endOfPeriod = null); - Task ReinstateSubscriptionAsync(Guid organizationId); - Task> UpgradePlanAsync(Guid organizationId, OrganizationUpgrade upgrade); - Task AdjustStorageAsync(Guid organizationId, short storageAdjustmentGb); - Task UpdateSubscription(Guid organizationId, int seatAdjustment, int? maxAutoscaleSeats); - Task AutoAddSeatsAsync(Organization organization, int seatsToAdd, DateTime? prorationDate = null); - Task AdjustSeatsAsync(Guid organizationId, int seatAdjustment, DateTime? prorationDate = null); - Task VerifyBankAsync(Guid organizationId, int amount1, int amount2); - Task> SignUpAsync(OrganizationSignup organizationSignup, bool provider = false); - Task> SignUpAsync(OrganizationLicense license, User owner, - string ownerKey, string collectionName, string publicKey, string privateKey); - Task UpdateLicenseAsync(Guid organizationId, OrganizationLicense license); - Task DeleteAsync(Organization organization); - Task EnableAsync(Guid organizationId, DateTime? expirationDate); - Task DisableAsync(Guid organizationId, DateTime? expirationDate); - Task UpdateExpirationDateAsync(Guid organizationId, DateTime? expirationDate); - Task EnableAsync(Guid organizationId); - Task UpdateAsync(Organization organization, bool updateBilling = false); - Task UpdateTwoFactorProviderAsync(Organization organization, TwoFactorProviderType type); - Task DisableTwoFactorProviderAsync(Organization organization, TwoFactorProviderType type); - Task> InviteUsersAsync(Guid organizationId, Guid? invitingUserId, - IEnumerable<(OrganizationUserInvite invite, string externalId)> invites); - Task InviteUserAsync(Guid organizationId, Guid? invitingUserId, string email, - OrganizationUserType type, bool accessAll, string externalId, IEnumerable collections); - Task>> ResendInvitesAsync(Guid organizationId, Guid? invitingUserId, IEnumerable organizationUsersId); - Task ResendInviteAsync(Guid organizationId, Guid? invitingUserId, Guid organizationUserId); - Task AcceptUserAsync(Guid organizationUserId, User user, string token, - IUserService userService); - Task AcceptUserAsync(string orgIdentifier, User user, IUserService userService); - Task ConfirmUserAsync(Guid organizationId, Guid organizationUserId, string key, - Guid confirmingUserId, IUserService userService); - Task>> ConfirmUsersAsync(Guid organizationId, Dictionary keys, - Guid confirmingUserId, IUserService userService); - Task SaveUserAsync(OrganizationUser user, Guid? savingUserId, IEnumerable collections); - Task DeleteUserAsync(Guid organizationId, Guid organizationUserId, Guid? deletingUserId); - Task DeleteUserAsync(Guid organizationId, Guid userId); - Task>> DeleteUsersAsync(Guid organizationId, - IEnumerable organizationUserIds, Guid? deletingUserId); - Task UpdateUserGroupsAsync(OrganizationUser organizationUser, IEnumerable groupIds, Guid? loggedInUserId); - Task UpdateUserResetPasswordEnrollmentAsync(Guid organizationId, Guid userId, string resetPasswordKey, Guid? callingUserId); - Task GenerateLicenseAsync(Guid organizationId, Guid installationId); - Task GenerateLicenseAsync(Organization organization, Guid installationId, - int? version = null); - Task ImportAsync(Guid organizationId, Guid? importingUserId, IEnumerable groups, - IEnumerable newUsers, IEnumerable removeUserExternalIds, - bool overwriteExisting); - Task DeleteSsoUserAsync(Guid userId, Guid? organizationId); - Task UpdateOrganizationKeysAsync(Guid orgId, string publicKey, string privateKey); - Task HasConfirmedOwnersExceptAsync(Guid organizationId, IEnumerable organizationUsersId, bool includeProvider = true); - Task RevokeUserAsync(OrganizationUser organizationUser, Guid? revokingUserId); - Task>> RevokeUsersAsync(Guid organizationId, - IEnumerable organizationUserIds, Guid? revokingUserId); - Task RestoreUserAsync(OrganizationUser organizationUser, Guid? restoringUserId, IUserService userService); - Task>> RestoreUsersAsync(Guid organizationId, - IEnumerable organizationUserIds, Guid? restoringUserId, IUserService userService); + public interface IOrganizationService + { + Task ReplacePaymentMethodAsync(Guid organizationId, string paymentToken, PaymentMethodType paymentMethodType, + TaxInfo taxInfo); + Task CancelSubscriptionAsync(Guid organizationId, bool? endOfPeriod = null); + Task ReinstateSubscriptionAsync(Guid organizationId); + Task> UpgradePlanAsync(Guid organizationId, OrganizationUpgrade upgrade); + Task AdjustStorageAsync(Guid organizationId, short storageAdjustmentGb); + Task UpdateSubscription(Guid organizationId, int seatAdjustment, int? maxAutoscaleSeats); + Task AutoAddSeatsAsync(Organization organization, int seatsToAdd, DateTime? prorationDate = null); + Task AdjustSeatsAsync(Guid organizationId, int seatAdjustment, DateTime? prorationDate = null); + Task VerifyBankAsync(Guid organizationId, int amount1, int amount2); + Task> SignUpAsync(OrganizationSignup organizationSignup, bool provider = false); + Task> SignUpAsync(OrganizationLicense license, User owner, + string ownerKey, string collectionName, string publicKey, string privateKey); + Task UpdateLicenseAsync(Guid organizationId, OrganizationLicense license); + Task DeleteAsync(Organization organization); + Task EnableAsync(Guid organizationId, DateTime? expirationDate); + Task DisableAsync(Guid organizationId, DateTime? expirationDate); + Task UpdateExpirationDateAsync(Guid organizationId, DateTime? expirationDate); + Task EnableAsync(Guid organizationId); + Task UpdateAsync(Organization organization, bool updateBilling = false); + Task UpdateTwoFactorProviderAsync(Organization organization, TwoFactorProviderType type); + Task DisableTwoFactorProviderAsync(Organization organization, TwoFactorProviderType type); + Task> InviteUsersAsync(Guid organizationId, Guid? invitingUserId, + IEnumerable<(OrganizationUserInvite invite, string externalId)> invites); + Task InviteUserAsync(Guid organizationId, Guid? invitingUserId, string email, + OrganizationUserType type, bool accessAll, string externalId, IEnumerable collections); + Task>> ResendInvitesAsync(Guid organizationId, Guid? invitingUserId, IEnumerable organizationUsersId); + Task ResendInviteAsync(Guid organizationId, Guid? invitingUserId, Guid organizationUserId); + Task AcceptUserAsync(Guid organizationUserId, User user, string token, + IUserService userService); + Task AcceptUserAsync(string orgIdentifier, User user, IUserService userService); + Task ConfirmUserAsync(Guid organizationId, Guid organizationUserId, string key, + Guid confirmingUserId, IUserService userService); + Task>> ConfirmUsersAsync(Guid organizationId, Dictionary keys, + Guid confirmingUserId, IUserService userService); + Task SaveUserAsync(OrganizationUser user, Guid? savingUserId, IEnumerable collections); + Task DeleteUserAsync(Guid organizationId, Guid organizationUserId, Guid? deletingUserId); + Task DeleteUserAsync(Guid organizationId, Guid userId); + Task>> DeleteUsersAsync(Guid organizationId, + IEnumerable organizationUserIds, Guid? deletingUserId); + Task UpdateUserGroupsAsync(OrganizationUser organizationUser, IEnumerable groupIds, Guid? loggedInUserId); + Task UpdateUserResetPasswordEnrollmentAsync(Guid organizationId, Guid userId, string resetPasswordKey, Guid? callingUserId); + Task GenerateLicenseAsync(Guid organizationId, Guid installationId); + Task GenerateLicenseAsync(Organization organization, Guid installationId, + int? version = null); + Task ImportAsync(Guid organizationId, Guid? importingUserId, IEnumerable groups, + IEnumerable newUsers, IEnumerable removeUserExternalIds, + bool overwriteExisting); + Task DeleteSsoUserAsync(Guid userId, Guid? organizationId); + Task UpdateOrganizationKeysAsync(Guid orgId, string publicKey, string privateKey); + Task HasConfirmedOwnersExceptAsync(Guid organizationId, IEnumerable organizationUsersId, bool includeProvider = true); + Task RevokeUserAsync(OrganizationUser organizationUser, Guid? revokingUserId); + Task>> RevokeUsersAsync(Guid organizationId, + IEnumerable organizationUserIds, Guid? revokingUserId); + Task RestoreUserAsync(OrganizationUser organizationUser, Guid? restoringUserId, IUserService userService); + Task>> RestoreUsersAsync(Guid organizationId, + IEnumerable organizationUserIds, Guid? restoringUserId, IUserService userService); + } } diff --git a/src/Core/Services/IPaymentService.cs b/src/Core/Services/IPaymentService.cs index a9091808d2..562c70e3ad 100644 --- a/src/Core/Services/IPaymentService.cs +++ b/src/Core/Services/IPaymentService.cs @@ -3,35 +3,36 @@ using Bit.Core.Enums; using Bit.Core.Models.Business; using Bit.Core.Models.StaticStore; -namespace Bit.Core.Services; - -public interface IPaymentService +namespace Bit.Core.Services { - Task CancelAndRecoverChargesAsync(ISubscriber subscriber); - Task PurchaseOrganizationAsync(Organization org, PaymentMethodType paymentMethodType, - string paymentToken, Plan plan, short additionalStorageGb, int additionalSeats, - bool premiumAccessAddon, TaxInfo taxInfo); - Task SponsorOrganizationAsync(Organization org, OrganizationSponsorship sponsorship); - Task RemoveOrganizationSponsorshipAsync(Organization org, OrganizationSponsorship sponsorship); - Task UpgradeFreeOrganizationAsync(Organization org, Plan plan, - short additionalStorageGb, int additionalSeats, bool premiumAccessAddon, TaxInfo taxInfo); - Task PurchasePremiumAsync(User user, PaymentMethodType paymentMethodType, string paymentToken, - short additionalStorageGb, TaxInfo taxInfo); - Task AdjustSeatsAsync(Organization organization, Plan plan, int additionalSeats, DateTime? prorationDate = null); - Task AdjustStorageAsync(IStorableSubscriber storableSubscriber, int additionalStorage, string storagePlanId, DateTime? prorationDate = null); - Task CancelSubscriptionAsync(ISubscriber subscriber, bool endOfPeriod = false, - bool skipInAppPurchaseCheck = false); - Task ReinstateSubscriptionAsync(ISubscriber subscriber); - Task UpdatePaymentMethodAsync(ISubscriber subscriber, PaymentMethodType paymentMethodType, - string paymentToken, bool allowInAppPurchases = false, TaxInfo taxInfo = null); - Task CreditAccountAsync(ISubscriber subscriber, decimal creditAmount); - Task GetBillingAsync(ISubscriber subscriber); - Task GetBillingHistoryAsync(ISubscriber subscriber); - Task GetBillingBalanceAndSourceAsync(ISubscriber subscriber); - Task GetSubscriptionAsync(ISubscriber subscriber); - Task GetTaxInfoAsync(ISubscriber subscriber); - Task SaveTaxInfoAsync(ISubscriber subscriber, TaxInfo taxInfo); - Task CreateTaxRateAsync(TaxRate taxRate); - Task UpdateTaxRateAsync(TaxRate taxRate); - Task ArchiveTaxRateAsync(TaxRate taxRate); + public interface IPaymentService + { + Task CancelAndRecoverChargesAsync(ISubscriber subscriber); + Task PurchaseOrganizationAsync(Organization org, PaymentMethodType paymentMethodType, + string paymentToken, Plan plan, short additionalStorageGb, int additionalSeats, + bool premiumAccessAddon, TaxInfo taxInfo); + Task SponsorOrganizationAsync(Organization org, OrganizationSponsorship sponsorship); + Task RemoveOrganizationSponsorshipAsync(Organization org, OrganizationSponsorship sponsorship); + Task UpgradeFreeOrganizationAsync(Organization org, Plan plan, + short additionalStorageGb, int additionalSeats, bool premiumAccessAddon, TaxInfo taxInfo); + Task PurchasePremiumAsync(User user, PaymentMethodType paymentMethodType, string paymentToken, + short additionalStorageGb, TaxInfo taxInfo); + Task AdjustSeatsAsync(Organization organization, Plan plan, int additionalSeats, DateTime? prorationDate = null); + Task AdjustStorageAsync(IStorableSubscriber storableSubscriber, int additionalStorage, string storagePlanId, DateTime? prorationDate = null); + Task CancelSubscriptionAsync(ISubscriber subscriber, bool endOfPeriod = false, + bool skipInAppPurchaseCheck = false); + Task ReinstateSubscriptionAsync(ISubscriber subscriber); + Task UpdatePaymentMethodAsync(ISubscriber subscriber, PaymentMethodType paymentMethodType, + string paymentToken, bool allowInAppPurchases = false, TaxInfo taxInfo = null); + Task CreditAccountAsync(ISubscriber subscriber, decimal creditAmount); + Task GetBillingAsync(ISubscriber subscriber); + Task GetBillingHistoryAsync(ISubscriber subscriber); + Task GetBillingBalanceAndSourceAsync(ISubscriber subscriber); + Task GetSubscriptionAsync(ISubscriber subscriber); + Task GetTaxInfoAsync(ISubscriber subscriber); + Task SaveTaxInfoAsync(ISubscriber subscriber, TaxInfo taxInfo); + Task CreateTaxRateAsync(TaxRate taxRate); + Task UpdateTaxRateAsync(TaxRate taxRate); + Task ArchiveTaxRateAsync(TaxRate taxRate); + } } diff --git a/src/Core/Services/IPolicyService.cs b/src/Core/Services/IPolicyService.cs index 5f1b4d3664..d7487cbd45 100644 --- a/src/Core/Services/IPolicyService.cs +++ b/src/Core/Services/IPolicyService.cs @@ -1,9 +1,10 @@ using Bit.Core.Entities; -namespace Bit.Core.Services; - -public interface IPolicyService +namespace Bit.Core.Services { - Task SaveAsync(Policy policy, IUserService userService, IOrganizationService organizationService, - Guid? savingUserId); + public interface IPolicyService + { + Task SaveAsync(Policy policy, IUserService userService, IOrganizationService organizationService, + Guid? savingUserId); + } } diff --git a/src/Core/Services/IProviderService.cs b/src/Core/Services/IProviderService.cs index c5cf039b28..eb38afad2c 100644 --- a/src/Core/Services/IProviderService.cs +++ b/src/Core/Services/IProviderService.cs @@ -3,28 +3,29 @@ using Bit.Core.Entities.Provider; using Bit.Core.Models.Business; using Bit.Core.Models.Business.Provider; -namespace Bit.Core.Services; - -public interface IProviderService +namespace Bit.Core.Services { - Task CreateAsync(string ownerEmail); - Task CompleteSetupAsync(Provider provider, Guid ownerUserId, string token, string key); - Task UpdateAsync(Provider provider, bool updateBilling = false); + public interface IProviderService + { + Task CreateAsync(string ownerEmail); + Task CompleteSetupAsync(Provider provider, Guid ownerUserId, string token, string key); + Task UpdateAsync(Provider provider, bool updateBilling = false); - Task> InviteUserAsync(ProviderUserInvite invite); - Task>> ResendInvitesAsync(ProviderUserInvite invite); - Task AcceptUserAsync(Guid providerUserId, User user, string token); - Task>> ConfirmUsersAsync(Guid providerId, Dictionary keys, Guid confirmingUserId); + Task> InviteUserAsync(ProviderUserInvite invite); + Task>> ResendInvitesAsync(ProviderUserInvite invite); + Task AcceptUserAsync(Guid providerUserId, User user, string token); + Task>> ConfirmUsersAsync(Guid providerId, Dictionary keys, Guid confirmingUserId); - Task SaveUserAsync(ProviderUser user, Guid savingUserId); - Task>> DeleteUsersAsync(Guid providerId, IEnumerable providerUserIds, - Guid deletingUserId); + Task SaveUserAsync(ProviderUser user, Guid savingUserId); + Task>> DeleteUsersAsync(Guid providerId, IEnumerable providerUserIds, + Guid deletingUserId); - Task AddOrganization(Guid providerId, Guid organizationId, Guid addingUserId, string key); - Task CreateOrganizationAsync(Guid providerId, OrganizationSignup organizationSignup, - string clientOwnerEmail, User user); - Task RemoveOrganizationAsync(Guid providerId, Guid providerOrganizationId, Guid removingUserId); - Task LogProviderAccessToOrganizationAsync(Guid organizationId); - Task ResendProviderSetupInviteEmailAsync(Guid providerId, Guid ownerId); + Task AddOrganization(Guid providerId, Guid organizationId, Guid addingUserId, string key); + Task CreateOrganizationAsync(Guid providerId, OrganizationSignup organizationSignup, + string clientOwnerEmail, User user); + Task RemoveOrganizationAsync(Guid providerId, Guid providerOrganizationId, Guid removingUserId); + Task LogProviderAccessToOrganizationAsync(Guid organizationId); + Task ResendProviderSetupInviteEmailAsync(Guid providerId, Guid ownerId); + } } diff --git a/src/Core/Services/IPushNotificationService.cs b/src/Core/Services/IPushNotificationService.cs index 34e98515ff..9707b93dc6 100644 --- a/src/Core/Services/IPushNotificationService.cs +++ b/src/Core/Services/IPushNotificationService.cs @@ -1,25 +1,26 @@ using Bit.Core.Entities; using Bit.Core.Enums; -namespace Bit.Core.Services; - -public interface IPushNotificationService +namespace Bit.Core.Services { - Task PushSyncCipherCreateAsync(Cipher cipher, IEnumerable collectionIds); - Task PushSyncCipherUpdateAsync(Cipher cipher, IEnumerable collectionIds); - Task PushSyncCipherDeleteAsync(Cipher cipher); - Task PushSyncFolderCreateAsync(Folder folder); - Task PushSyncFolderUpdateAsync(Folder folder); - Task PushSyncFolderDeleteAsync(Folder folder); - Task PushSyncCiphersAsync(Guid userId); - Task PushSyncVaultAsync(Guid userId); - Task PushSyncOrgKeysAsync(Guid userId); - Task PushSyncSettingsAsync(Guid userId); - Task PushLogOutAsync(Guid userId); - Task PushSyncSendCreateAsync(Send send); - Task PushSyncSendUpdateAsync(Send send); - Task PushSyncSendDeleteAsync(Send send); - Task SendPayloadToUserAsync(string userId, PushType type, object payload, string identifier, string deviceId = null); - Task SendPayloadToOrganizationAsync(string orgId, PushType type, object payload, string identifier, - string deviceId = null); + public interface IPushNotificationService + { + Task PushSyncCipherCreateAsync(Cipher cipher, IEnumerable collectionIds); + Task PushSyncCipherUpdateAsync(Cipher cipher, IEnumerable collectionIds); + Task PushSyncCipherDeleteAsync(Cipher cipher); + Task PushSyncFolderCreateAsync(Folder folder); + Task PushSyncFolderUpdateAsync(Folder folder); + Task PushSyncFolderDeleteAsync(Folder folder); + Task PushSyncCiphersAsync(Guid userId); + Task PushSyncVaultAsync(Guid userId); + Task PushSyncOrgKeysAsync(Guid userId); + Task PushSyncSettingsAsync(Guid userId); + Task PushLogOutAsync(Guid userId); + Task PushSyncSendCreateAsync(Send send); + Task PushSyncSendUpdateAsync(Send send); + Task PushSyncSendDeleteAsync(Send send); + Task SendPayloadToUserAsync(string userId, PushType type, object payload, string identifier, string deviceId = null); + Task SendPayloadToOrganizationAsync(string orgId, PushType type, object payload, string identifier, + string deviceId = null); + } } diff --git a/src/Core/Services/IPushRegistrationService.cs b/src/Core/Services/IPushRegistrationService.cs index 985246de0c..14d2c82ef1 100644 --- a/src/Core/Services/IPushRegistrationService.cs +++ b/src/Core/Services/IPushRegistrationService.cs @@ -1,12 +1,13 @@ using Bit.Core.Enums; -namespace Bit.Core.Services; - -public interface IPushRegistrationService +namespace Bit.Core.Services { - Task CreateOrUpdateRegistrationAsync(string pushToken, string deviceId, string userId, - string identifier, DeviceType type); - Task DeleteRegistrationAsync(string deviceId); - Task AddUserRegistrationOrganizationAsync(IEnumerable deviceIds, string organizationId); - Task DeleteUserRegistrationOrganizationAsync(IEnumerable deviceIds, string organizationId); + public interface IPushRegistrationService + { + Task CreateOrUpdateRegistrationAsync(string pushToken, string deviceId, string userId, + string identifier, DeviceType type); + Task DeleteRegistrationAsync(string deviceId); + Task AddUserRegistrationOrganizationAsync(IEnumerable deviceIds, string organizationId); + Task DeleteUserRegistrationOrganizationAsync(IEnumerable deviceIds, string organizationId); + } } diff --git a/src/Core/Services/IReferenceEventService.cs b/src/Core/Services/IReferenceEventService.cs index 03339f08c4..fa85a2a3de 100644 --- a/src/Core/Services/IReferenceEventService.cs +++ b/src/Core/Services/IReferenceEventService.cs @@ -1,8 +1,9 @@ using Bit.Core.Models.Business; -namespace Bit.Core.Services; - -public interface IReferenceEventService +namespace Bit.Core.Services { - Task RaiseEventAsync(ReferenceEvent referenceEvent); + public interface IReferenceEventService + { + Task RaiseEventAsync(ReferenceEvent referenceEvent); + } } diff --git a/src/Core/Services/ISendService.cs b/src/Core/Services/ISendService.cs index a2b6b8c35c..8ee97a6298 100644 --- a/src/Core/Services/ISendService.cs +++ b/src/Core/Services/ISendService.cs @@ -1,16 +1,17 @@ using Bit.Core.Entities; using Bit.Core.Models.Data; -namespace Bit.Core.Services; - -public interface ISendService +namespace Bit.Core.Services { - Task DeleteSendAsync(Send send); - Task SaveSendAsync(Send send); - Task SaveFileSendAsync(Send send, SendFileData data, long fileLength); - Task UploadFileToExistingSendAsync(Stream stream, Send send); - Task<(Send, bool, bool)> AccessAsync(Guid sendId, string password); - string HashPassword(string password); - Task<(string, bool, bool)> GetSendFileDownloadUrlAsync(Send send, string fileId, string password); - Task ValidateSendFile(Send send); + public interface ISendService + { + Task DeleteSendAsync(Send send); + Task SaveSendAsync(Send send); + Task SaveFileSendAsync(Send send, SendFileData data, long fileLength); + Task UploadFileToExistingSendAsync(Stream stream, Send send); + Task<(Send, bool, bool)> AccessAsync(Guid sendId, string password); + string HashPassword(string password); + Task<(string, bool, bool)> GetSendFileDownloadUrlAsync(Send send, string fileId, string password); + Task ValidateSendFile(Send send); + } } diff --git a/src/Core/Services/ISendStorageService.cs b/src/Core/Services/ISendStorageService.cs index 63c0d44ca9..f671d00777 100644 --- a/src/Core/Services/ISendStorageService.cs +++ b/src/Core/Services/ISendStorageService.cs @@ -1,16 +1,17 @@ using Bit.Core.Entities; using Bit.Core.Enums; -namespace Bit.Core.Services; - -public interface ISendFileStorageService +namespace Bit.Core.Services { - FileUploadType FileUploadType { get; } - Task UploadNewFileAsync(Stream stream, Send send, string fileId); - Task DeleteFileAsync(Send send, string fileId); - Task DeleteFilesForOrganizationAsync(Guid organizationId); - Task DeleteFilesForUserAsync(Guid userId); - Task GetSendFileDownloadUrlAsync(Send send, string fileId); - Task GetSendFileUploadUrlAsync(Send send, string fileId); - Task<(bool, long?)> ValidateFileAsync(Send send, string fileId, long expectedFileSize, long leeway); + public interface ISendFileStorageService + { + FileUploadType FileUploadType { get; } + Task UploadNewFileAsync(Stream stream, Send send, string fileId); + Task DeleteFileAsync(Send send, string fileId); + Task DeleteFilesForOrganizationAsync(Guid organizationId); + Task DeleteFilesForUserAsync(Guid userId); + Task GetSendFileDownloadUrlAsync(Send send, string fileId); + Task GetSendFileUploadUrlAsync(Send send, string fileId); + Task<(bool, long?)> ValidateFileAsync(Send send, string fileId, long expectedFileSize, long leeway); + } } diff --git a/src/Core/Services/ISsoConfigService.cs b/src/Core/Services/ISsoConfigService.cs index c25127d956..d4d2cfcefc 100644 --- a/src/Core/Services/ISsoConfigService.cs +++ b/src/Core/Services/ISsoConfigService.cs @@ -1,8 +1,9 @@ using Bit.Core.Entities; -namespace Bit.Core.Services; - -public interface ISsoConfigService +namespace Bit.Core.Services { - Task SaveAsync(SsoConfig config, Organization organization); + public interface ISsoConfigService + { + Task SaveAsync(SsoConfig config, Organization organization); + } } diff --git a/src/Core/Services/IStripeAdapter.cs b/src/Core/Services/IStripeAdapter.cs index ff922161cc..ffb0e2a1c5 100644 --- a/src/Core/Services/IStripeAdapter.cs +++ b/src/Core/Services/IStripeAdapter.cs @@ -1,39 +1,40 @@ using Bit.Core.Models.BitStripe; -namespace Bit.Core.Services; - -public interface IStripeAdapter +namespace Bit.Core.Services { - Task CustomerCreateAsync(Stripe.CustomerCreateOptions customerCreateOptions); - Task CustomerGetAsync(string id, Stripe.CustomerGetOptions options = null); - Task CustomerUpdateAsync(string id, Stripe.CustomerUpdateOptions options = null); - Task CustomerDeleteAsync(string id); - Task SubscriptionCreateAsync(Stripe.SubscriptionCreateOptions subscriptionCreateOptions); - Task SubscriptionGetAsync(string id, Stripe.SubscriptionGetOptions options = null); - Task> SubscriptionListAsync(StripeSubscriptionListOptions subscriptionSearchOptions); - Task SubscriptionUpdateAsync(string id, Stripe.SubscriptionUpdateOptions options = null); - Task SubscriptionCancelAsync(string Id, Stripe.SubscriptionCancelOptions options = null); - Task InvoiceUpcomingAsync(Stripe.UpcomingInvoiceOptions options); - Task InvoiceGetAsync(string id, Stripe.InvoiceGetOptions options); - Task> InvoiceListAsync(Stripe.InvoiceListOptions options); - Task InvoiceUpdateAsync(string id, Stripe.InvoiceUpdateOptions options); - Task InvoiceFinalizeInvoiceAsync(string id, Stripe.InvoiceFinalizeOptions options); - Task InvoiceSendInvoiceAsync(string id, Stripe.InvoiceSendOptions options); - Task InvoicePayAsync(string id, Stripe.InvoicePayOptions options = null); - Task InvoiceDeleteAsync(string id, Stripe.InvoiceDeleteOptions options = null); - Task InvoiceVoidInvoiceAsync(string id, Stripe.InvoiceVoidOptions options = null); - IEnumerable PaymentMethodListAutoPaging(Stripe.PaymentMethodListOptions options); - Task PaymentMethodAttachAsync(string id, Stripe.PaymentMethodAttachOptions options = null); - Task PaymentMethodDetachAsync(string id, Stripe.PaymentMethodDetachOptions options = null); - Task TaxRateCreateAsync(Stripe.TaxRateCreateOptions options); - Task TaxRateUpdateAsync(string id, Stripe.TaxRateUpdateOptions options); - Task TaxIdCreateAsync(string id, Stripe.TaxIdCreateOptions options); - Task TaxIdDeleteAsync(string customerId, string taxIdId, Stripe.TaxIdDeleteOptions options = null); - Task> ChargeListAsync(Stripe.ChargeListOptions options); - Task RefundCreateAsync(Stripe.RefundCreateOptions options); - Task CardDeleteAsync(string customerId, string cardId, Stripe.CardDeleteOptions options = null); - Task BankAccountCreateAsync(string customerId, Stripe.BankAccountCreateOptions options = null); - Task BankAccountDeleteAsync(string customerId, string bankAccount, Stripe.BankAccountDeleteOptions options = null); - Task> PriceListAsync(Stripe.PriceListOptions options = null); - Task> TestClockListAsync(); + public interface IStripeAdapter + { + Task CustomerCreateAsync(Stripe.CustomerCreateOptions customerCreateOptions); + Task CustomerGetAsync(string id, Stripe.CustomerGetOptions options = null); + Task CustomerUpdateAsync(string id, Stripe.CustomerUpdateOptions options = null); + Task CustomerDeleteAsync(string id); + Task SubscriptionCreateAsync(Stripe.SubscriptionCreateOptions subscriptionCreateOptions); + Task SubscriptionGetAsync(string id, Stripe.SubscriptionGetOptions options = null); + Task> SubscriptionListAsync(StripeSubscriptionListOptions subscriptionSearchOptions); + Task SubscriptionUpdateAsync(string id, Stripe.SubscriptionUpdateOptions options = null); + Task SubscriptionCancelAsync(string Id, Stripe.SubscriptionCancelOptions options = null); + Task InvoiceUpcomingAsync(Stripe.UpcomingInvoiceOptions options); + Task InvoiceGetAsync(string id, Stripe.InvoiceGetOptions options); + Task> InvoiceListAsync(Stripe.InvoiceListOptions options); + Task InvoiceUpdateAsync(string id, Stripe.InvoiceUpdateOptions options); + Task InvoiceFinalizeInvoiceAsync(string id, Stripe.InvoiceFinalizeOptions options); + Task InvoiceSendInvoiceAsync(string id, Stripe.InvoiceSendOptions options); + Task InvoicePayAsync(string id, Stripe.InvoicePayOptions options = null); + Task InvoiceDeleteAsync(string id, Stripe.InvoiceDeleteOptions options = null); + Task InvoiceVoidInvoiceAsync(string id, Stripe.InvoiceVoidOptions options = null); + IEnumerable PaymentMethodListAutoPaging(Stripe.PaymentMethodListOptions options); + Task PaymentMethodAttachAsync(string id, Stripe.PaymentMethodAttachOptions options = null); + Task PaymentMethodDetachAsync(string id, Stripe.PaymentMethodDetachOptions options = null); + Task TaxRateCreateAsync(Stripe.TaxRateCreateOptions options); + Task TaxRateUpdateAsync(string id, Stripe.TaxRateUpdateOptions options); + Task TaxIdCreateAsync(string id, Stripe.TaxIdCreateOptions options); + Task TaxIdDeleteAsync(string customerId, string taxIdId, Stripe.TaxIdDeleteOptions options = null); + Task> ChargeListAsync(Stripe.ChargeListOptions options); + Task RefundCreateAsync(Stripe.RefundCreateOptions options); + Task CardDeleteAsync(string customerId, string cardId, Stripe.CardDeleteOptions options = null); + Task BankAccountCreateAsync(string customerId, Stripe.BankAccountCreateOptions options = null); + Task BankAccountDeleteAsync(string customerId, string bankAccount, Stripe.BankAccountDeleteOptions options = null); + Task> PriceListAsync(Stripe.PriceListOptions options = null); + Task> TestClockListAsync(); + } } diff --git a/src/Core/Services/IStripeSyncService.cs b/src/Core/Services/IStripeSyncService.cs index 655998805e..0219bc5d2e 100644 --- a/src/Core/Services/IStripeSyncService.cs +++ b/src/Core/Services/IStripeSyncService.cs @@ -1,6 +1,7 @@ -namespace Bit.Core.Services; - -public interface IStripeSyncService +namespace Bit.Core.Services { - Task UpdateCustomerEmailAddress(string gatewayCustomerId, string emailAddress); + public interface IStripeSyncService + { + Task UpdateCustomerEmailAddress(string gatewayCustomerId, string emailAddress); + } } diff --git a/src/Core/Services/IUserService.cs b/src/Core/Services/IUserService.cs index 077f66756e..989bea85dc 100644 --- a/src/Core/Services/IUserService.cs +++ b/src/Core/Services/IUserService.cs @@ -6,76 +6,77 @@ using Bit.Core.Models.Business; using Fido2NetLib; using Microsoft.AspNetCore.Identity; -namespace Bit.Core.Services; - -public interface IUserService +namespace Bit.Core.Services { - Guid? GetProperUserId(ClaimsPrincipal principal); - Task GetUserByIdAsync(string userId); - Task GetUserByIdAsync(Guid userId); - Task GetUserByPrincipalAsync(ClaimsPrincipal principal); - Task GetAccountRevisionDateByIdAsync(Guid userId); - Task SaveUserAsync(User user, bool push = false); - Task RegisterUserAsync(User user, string masterPassword, string token, Guid? orgUserId); - Task RegisterUserAsync(User user); - Task SendMasterPasswordHintAsync(string email); - Task SendTwoFactorEmailAsync(User user, bool isBecauseNewDeviceLogin = false); - Task VerifyTwoFactorEmailAsync(User user, string token); - Task StartWebAuthnRegistrationAsync(User user); - Task DeleteWebAuthnKeyAsync(User user, int id); - Task CompleteWebAuthRegistrationAsync(User user, int value, string name, AuthenticatorAttestationRawResponse attestationResponse); - Task SendEmailVerificationAsync(User user); - Task ConfirmEmailAsync(User user, string token); - Task InitiateEmailChangeAsync(User user, string newEmail); - Task ChangeEmailAsync(User user, string masterPassword, string newEmail, string newMasterPassword, - string token, string key); - Task ChangePasswordAsync(User user, string masterPassword, string newMasterPassword, string passwordHint, string key); - Task SetPasswordAsync(User user, string newMasterPassword, string key, string orgIdentifier = null); - Task SetKeyConnectorKeyAsync(User user, string key, string orgIdentifier); - Task ConvertToKeyConnectorAsync(User user); - Task AdminResetPasswordAsync(OrganizationUserType type, Guid orgId, Guid id, string newMasterPassword, string key); - Task UpdateTempPasswordAsync(User user, string newMasterPassword, string key, string hint); - Task ChangeKdfAsync(User user, string masterPassword, string newMasterPassword, string key, - KdfType kdf, int kdfIterations); - Task UpdateKeyAsync(User user, string masterPassword, string key, string privateKey, - IEnumerable ciphers, IEnumerable folders, IEnumerable sends); - Task RefreshSecurityStampAsync(User user, string masterPasswordHash); - Task UpdateTwoFactorProviderAsync(User user, TwoFactorProviderType type, bool setEnabled = true, bool logEvent = true); - Task DisableTwoFactorProviderAsync(User user, TwoFactorProviderType type, - IOrganizationService organizationService); - Task RecoverTwoFactorAsync(string email, string masterPassword, string recoveryCode, - IOrganizationService organizationService); - Task GenerateUserTokenAsync(User user, string tokenProvider, string purpose); - Task DeleteAsync(User user); - Task DeleteAsync(User user, string token); - Task SendDeleteConfirmationAsync(string email); - Task> SignUpPremiumAsync(User user, string paymentToken, - PaymentMethodType paymentMethodType, short additionalStorageGb, UserLicense license, - TaxInfo taxInfo); - Task IapCheckAsync(User user, PaymentMethodType paymentMethodType); - Task UpdateLicenseAsync(User user, UserLicense license); - Task AdjustStorageAsync(User user, short storageAdjustmentGb); - Task ReplacePaymentMethodAsync(User user, string paymentToken, PaymentMethodType paymentMethodType, TaxInfo taxInfo); - Task CancelPremiumAsync(User user, bool? endOfPeriod = null, bool accountDelete = false); - Task ReinstatePremiumAsync(User user); - Task EnablePremiumAsync(Guid userId, DateTime? expirationDate); - Task EnablePremiumAsync(User user, DateTime? expirationDate); - Task DisablePremiumAsync(Guid userId, DateTime? expirationDate); - Task DisablePremiumAsync(User user, DateTime? expirationDate); - Task UpdatePremiumExpirationAsync(Guid userId, DateTime? expirationDate); - Task GenerateLicenseAsync(User user, SubscriptionInfo subscriptionInfo = null, - int? version = null); - Task CheckPasswordAsync(User user, string password); - Task CanAccessPremium(ITwoFactorProvidersUser user); - Task HasPremiumFromOrganization(ITwoFactorProvidersUser user); - Task TwoFactorIsEnabledAsync(ITwoFactorProvidersUser user); - Task TwoFactorProviderIsEnabledAsync(TwoFactorProviderType provider, ITwoFactorProvidersUser user); - Task GenerateSignInTokenAsync(User user, string purpose); - Task RotateApiKeyAsync(User user); - string GetUserName(ClaimsPrincipal principal); - Task SendOTPAsync(User user); - Task VerifyOTPAsync(User user, string token); - Task VerifySecretAsync(User user, string secret); - Task Needs2FABecauseNewDeviceAsync(User user, string deviceIdentifier, string grantType); - bool CanEditDeviceVerificationSettings(User user); + public interface IUserService + { + Guid? GetProperUserId(ClaimsPrincipal principal); + Task GetUserByIdAsync(string userId); + Task GetUserByIdAsync(Guid userId); + Task GetUserByPrincipalAsync(ClaimsPrincipal principal); + Task GetAccountRevisionDateByIdAsync(Guid userId); + Task SaveUserAsync(User user, bool push = false); + Task RegisterUserAsync(User user, string masterPassword, string token, Guid? orgUserId); + Task RegisterUserAsync(User user); + Task SendMasterPasswordHintAsync(string email); + Task SendTwoFactorEmailAsync(User user, bool isBecauseNewDeviceLogin = false); + Task VerifyTwoFactorEmailAsync(User user, string token); + Task StartWebAuthnRegistrationAsync(User user); + Task DeleteWebAuthnKeyAsync(User user, int id); + Task CompleteWebAuthRegistrationAsync(User user, int value, string name, AuthenticatorAttestationRawResponse attestationResponse); + Task SendEmailVerificationAsync(User user); + Task ConfirmEmailAsync(User user, string token); + Task InitiateEmailChangeAsync(User user, string newEmail); + Task ChangeEmailAsync(User user, string masterPassword, string newEmail, string newMasterPassword, + string token, string key); + Task ChangePasswordAsync(User user, string masterPassword, string newMasterPassword, string passwordHint, string key); + Task SetPasswordAsync(User user, string newMasterPassword, string key, string orgIdentifier = null); + Task SetKeyConnectorKeyAsync(User user, string key, string orgIdentifier); + Task ConvertToKeyConnectorAsync(User user); + Task AdminResetPasswordAsync(OrganizationUserType type, Guid orgId, Guid id, string newMasterPassword, string key); + Task UpdateTempPasswordAsync(User user, string newMasterPassword, string key, string hint); + Task ChangeKdfAsync(User user, string masterPassword, string newMasterPassword, string key, + KdfType kdf, int kdfIterations); + Task UpdateKeyAsync(User user, string masterPassword, string key, string privateKey, + IEnumerable ciphers, IEnumerable folders, IEnumerable sends); + Task RefreshSecurityStampAsync(User user, string masterPasswordHash); + Task UpdateTwoFactorProviderAsync(User user, TwoFactorProviderType type, bool setEnabled = true, bool logEvent = true); + Task DisableTwoFactorProviderAsync(User user, TwoFactorProviderType type, + IOrganizationService organizationService); + Task RecoverTwoFactorAsync(string email, string masterPassword, string recoveryCode, + IOrganizationService organizationService); + Task GenerateUserTokenAsync(User user, string tokenProvider, string purpose); + Task DeleteAsync(User user); + Task DeleteAsync(User user, string token); + Task SendDeleteConfirmationAsync(string email); + Task> SignUpPremiumAsync(User user, string paymentToken, + PaymentMethodType paymentMethodType, short additionalStorageGb, UserLicense license, + TaxInfo taxInfo); + Task IapCheckAsync(User user, PaymentMethodType paymentMethodType); + Task UpdateLicenseAsync(User user, UserLicense license); + Task AdjustStorageAsync(User user, short storageAdjustmentGb); + Task ReplacePaymentMethodAsync(User user, string paymentToken, PaymentMethodType paymentMethodType, TaxInfo taxInfo); + Task CancelPremiumAsync(User user, bool? endOfPeriod = null, bool accountDelete = false); + Task ReinstatePremiumAsync(User user); + Task EnablePremiumAsync(Guid userId, DateTime? expirationDate); + Task EnablePremiumAsync(User user, DateTime? expirationDate); + Task DisablePremiumAsync(Guid userId, DateTime? expirationDate); + Task DisablePremiumAsync(User user, DateTime? expirationDate); + Task UpdatePremiumExpirationAsync(Guid userId, DateTime? expirationDate); + Task GenerateLicenseAsync(User user, SubscriptionInfo subscriptionInfo = null, + int? version = null); + Task CheckPasswordAsync(User user, string password); + Task CanAccessPremium(ITwoFactorProvidersUser user); + Task HasPremiumFromOrganization(ITwoFactorProvidersUser user); + Task TwoFactorIsEnabledAsync(ITwoFactorProvidersUser user); + Task TwoFactorProviderIsEnabledAsync(TwoFactorProviderType provider, ITwoFactorProvidersUser user); + Task GenerateSignInTokenAsync(User user, string purpose); + Task RotateApiKeyAsync(User user); + string GetUserName(ClaimsPrincipal principal); + Task SendOTPAsync(User user); + Task VerifyOTPAsync(User user, string token); + Task VerifySecretAsync(User user, string secret); + Task Needs2FABecauseNewDeviceAsync(User user, string deviceIdentifier, string grantType); + bool CanEditDeviceVerificationSettings(User user); + } } diff --git a/src/Core/Services/Implementations/AmazonSesMailDeliveryService.cs b/src/Core/Services/Implementations/AmazonSesMailDeliveryService.cs index adf406cf0b..98275b4ab9 100644 --- a/src/Core/Services/Implementations/AmazonSesMailDeliveryService.cs +++ b/src/Core/Services/Implementations/AmazonSesMailDeliveryService.cs @@ -7,136 +7,137 @@ using Bit.Core.Utilities; using Microsoft.AspNetCore.Hosting; using Microsoft.Extensions.Logging; -namespace Bit.Core.Services; - -public class AmazonSesMailDeliveryService : IMailDeliveryService, IDisposable +namespace Bit.Core.Services { - private readonly GlobalSettings _globalSettings; - private readonly IWebHostEnvironment _hostingEnvironment; - private readonly ILogger _logger; - private readonly IAmazonSimpleEmailService _client; - private readonly string _source; - private readonly string _senderTag; - private readonly string _configSetName; - - public AmazonSesMailDeliveryService( - GlobalSettings globalSettings, - IWebHostEnvironment hostingEnvironment, - ILogger logger) - : this(globalSettings, hostingEnvironment, logger, - new AmazonSimpleEmailServiceClient( - globalSettings.Amazon.AccessKeyId, - globalSettings.Amazon.AccessKeySecret, - RegionEndpoint.GetBySystemName(globalSettings.Amazon.Region)) - ) + public class AmazonSesMailDeliveryService : IMailDeliveryService, IDisposable { - } + private readonly GlobalSettings _globalSettings; + private readonly IWebHostEnvironment _hostingEnvironment; + private readonly ILogger _logger; + private readonly IAmazonSimpleEmailService _client; + private readonly string _source; + private readonly string _senderTag; + private readonly string _configSetName; - public AmazonSesMailDeliveryService( - GlobalSettings globalSettings, - IWebHostEnvironment hostingEnvironment, - ILogger logger, - IAmazonSimpleEmailService amazonSimpleEmailService) - { - if (string.IsNullOrWhiteSpace(globalSettings.Amazon?.AccessKeyId)) + public AmazonSesMailDeliveryService( + GlobalSettings globalSettings, + IWebHostEnvironment hostingEnvironment, + ILogger logger) + : this(globalSettings, hostingEnvironment, logger, + new AmazonSimpleEmailServiceClient( + globalSettings.Amazon.AccessKeyId, + globalSettings.Amazon.AccessKeySecret, + RegionEndpoint.GetBySystemName(globalSettings.Amazon.Region)) + ) { - throw new ArgumentNullException(nameof(globalSettings.Amazon.AccessKeyId)); - } - if (string.IsNullOrWhiteSpace(globalSettings.Amazon?.AccessKeySecret)) - { - throw new ArgumentNullException(nameof(globalSettings.Amazon.AccessKeySecret)); - } - if (string.IsNullOrWhiteSpace(globalSettings.Amazon?.Region)) - { - throw new ArgumentNullException(nameof(globalSettings.Amazon.Region)); } - var replyToEmail = CoreHelpers.PunyEncode(globalSettings.Mail.ReplyToEmail); - - _globalSettings = globalSettings; - _hostingEnvironment = hostingEnvironment; - _logger = logger; - _client = amazonSimpleEmailService; - _source = $"\"{globalSettings.SiteName}\" <{replyToEmail}>"; - _senderTag = $"Server_{globalSettings.ProjectName?.Replace(' ', '_')}"; - if (!string.IsNullOrWhiteSpace(_globalSettings.Mail.AmazonConfigSetName)) + public AmazonSesMailDeliveryService( + GlobalSettings globalSettings, + IWebHostEnvironment hostingEnvironment, + ILogger logger, + IAmazonSimpleEmailService amazonSimpleEmailService) { - _configSetName = _globalSettings.Mail.AmazonConfigSetName; - } - } - - public void Dispose() - { - _client?.Dispose(); - } - - public async Task SendEmailAsync(MailMessage message) - { - var request = new SendEmailRequest - { - ConfigurationSetName = _configSetName, - Source = _source, - Destination = new Destination + if (string.IsNullOrWhiteSpace(globalSettings.Amazon?.AccessKeyId)) { - ToAddresses = message.ToEmails - .Select(email => CoreHelpers.PunyEncode(email)) - .ToList() - }, - Message = new Message - { - Subject = new Content(message.Subject), - Body = new Body - { - Html = new Content - { - Charset = "UTF-8", - Data = message.HtmlContent - }, - Text = new Content - { - Charset = "UTF-8", - Data = message.TextContent - } - } - }, - Tags = new List - { - new MessageTag { Name = "Environment", Value = _hostingEnvironment.EnvironmentName }, - new MessageTag { Name = "Sender", Value = _senderTag } + throw new ArgumentNullException(nameof(globalSettings.Amazon.AccessKeyId)); + } + if (string.IsNullOrWhiteSpace(globalSettings.Amazon?.AccessKeySecret)) + { + throw new ArgumentNullException(nameof(globalSettings.Amazon.AccessKeySecret)); + } + if (string.IsNullOrWhiteSpace(globalSettings.Amazon?.Region)) + { + throw new ArgumentNullException(nameof(globalSettings.Amazon.Region)); } - }; - if (message.BccEmails?.Any() ?? false) - { - request.Destination.BccAddresses = message.BccEmails - .Select(email => CoreHelpers.PunyEncode(email)) - .ToList(); + var replyToEmail = CoreHelpers.PunyEncode(globalSettings.Mail.ReplyToEmail); + + _globalSettings = globalSettings; + _hostingEnvironment = hostingEnvironment; + _logger = logger; + _client = amazonSimpleEmailService; + _source = $"\"{globalSettings.SiteName}\" <{replyToEmail}>"; + _senderTag = $"Server_{globalSettings.ProjectName?.Replace(' ', '_')}"; + if (!string.IsNullOrWhiteSpace(_globalSettings.Mail.AmazonConfigSetName)) + { + _configSetName = _globalSettings.Mail.AmazonConfigSetName; + } } - if (!string.IsNullOrWhiteSpace(message.Category)) + public void Dispose() { - request.Tags.Add(new MessageTag { Name = "Category", Value = message.Category }); + _client?.Dispose(); } - try + public async Task SendEmailAsync(MailMessage message) { - await SendAsync(request, false); - } - catch (Exception e) - { - _logger.LogWarning(e, "Failed to send email. Retrying..."); - await SendAsync(request, true); - throw; - } - } + var request = new SendEmailRequest + { + ConfigurationSetName = _configSetName, + Source = _source, + Destination = new Destination + { + ToAddresses = message.ToEmails + .Select(email => CoreHelpers.PunyEncode(email)) + .ToList() + }, + Message = new Message + { + Subject = new Content(message.Subject), + Body = new Body + { + Html = new Content + { + Charset = "UTF-8", + Data = message.HtmlContent + }, + Text = new Content + { + Charset = "UTF-8", + Data = message.TextContent + } + } + }, + Tags = new List + { + new MessageTag { Name = "Environment", Value = _hostingEnvironment.EnvironmentName }, + new MessageTag { Name = "Sender", Value = _senderTag } + } + }; - private async Task SendAsync(SendEmailRequest request, bool retry) - { - if (retry) - { - // wait and try again - await Task.Delay(2000); + if (message.BccEmails?.Any() ?? false) + { + request.Destination.BccAddresses = message.BccEmails + .Select(email => CoreHelpers.PunyEncode(email)) + .ToList(); + } + + if (!string.IsNullOrWhiteSpace(message.Category)) + { + request.Tags.Add(new MessageTag { Name = "Category", Value = message.Category }); + } + + try + { + await SendAsync(request, false); + } + catch (Exception e) + { + _logger.LogWarning(e, "Failed to send email. Retrying..."); + await SendAsync(request, true); + throw; + } + } + + private async Task SendAsync(SendEmailRequest request, bool retry) + { + if (retry) + { + // wait and try again + await Task.Delay(2000); + } + await _client.SendEmailAsync(request); } - await _client.SendEmailAsync(request); } } diff --git a/src/Core/Services/Implementations/AmazonSqsBlockIpService.cs b/src/Core/Services/Implementations/AmazonSqsBlockIpService.cs index ac5dfb45c7..1e6dcf9356 100644 --- a/src/Core/Services/Implementations/AmazonSqsBlockIpService.cs +++ b/src/Core/Services/Implementations/AmazonSqsBlockIpService.cs @@ -2,80 +2,81 @@ using Amazon.SQS; using Bit.Core.Settings; -namespace Bit.Core.Services; - -public class AmazonSqsBlockIpService : IBlockIpService, IDisposable +namespace Bit.Core.Services { - private readonly IAmazonSQS _client; - private string _blockIpQueueUrl; - private string _unblockIpQueueUrl; - private bool _didInit = false; - private Tuple _lastBlock; - - public AmazonSqsBlockIpService( - GlobalSettings globalSettings) - : this(globalSettings, new AmazonSQSClient( - globalSettings.Amazon.AccessKeyId, - globalSettings.Amazon.AccessKeySecret, - RegionEndpoint.GetBySystemName(globalSettings.Amazon.Region)) - ) + public class AmazonSqsBlockIpService : IBlockIpService, IDisposable { - } + private readonly IAmazonSQS _client; + private string _blockIpQueueUrl; + private string _unblockIpQueueUrl; + private bool _didInit = false; + private Tuple _lastBlock; - public AmazonSqsBlockIpService( - GlobalSettings globalSettings, - IAmazonSQS amazonSqs) - { - if (string.IsNullOrWhiteSpace(globalSettings.Amazon?.AccessKeyId)) + public AmazonSqsBlockIpService( + GlobalSettings globalSettings) + : this(globalSettings, new AmazonSQSClient( + globalSettings.Amazon.AccessKeyId, + globalSettings.Amazon.AccessKeySecret, + RegionEndpoint.GetBySystemName(globalSettings.Amazon.Region)) + ) { - throw new ArgumentNullException(nameof(globalSettings.Amazon.AccessKeyId)); - } - if (string.IsNullOrWhiteSpace(globalSettings.Amazon?.AccessKeySecret)) - { - throw new ArgumentNullException(nameof(globalSettings.Amazon.AccessKeySecret)); - } - if (string.IsNullOrWhiteSpace(globalSettings.Amazon?.Region)) - { - throw new ArgumentNullException(nameof(globalSettings.Amazon.Region)); } - _client = amazonSqs; - } - - public void Dispose() - { - _client?.Dispose(); - } - - public async Task BlockIpAsync(string ipAddress, bool permanentBlock) - { - var now = DateTime.UtcNow; - if (_lastBlock != null && _lastBlock.Item1 == ipAddress && _lastBlock.Item2 == permanentBlock && - (now - _lastBlock.Item3) < TimeSpan.FromMinutes(1)) + public AmazonSqsBlockIpService( + GlobalSettings globalSettings, + IAmazonSQS amazonSqs) { - // Already blocked this IP recently. - return; + if (string.IsNullOrWhiteSpace(globalSettings.Amazon?.AccessKeyId)) + { + throw new ArgumentNullException(nameof(globalSettings.Amazon.AccessKeyId)); + } + if (string.IsNullOrWhiteSpace(globalSettings.Amazon?.AccessKeySecret)) + { + throw new ArgumentNullException(nameof(globalSettings.Amazon.AccessKeySecret)); + } + if (string.IsNullOrWhiteSpace(globalSettings.Amazon?.Region)) + { + throw new ArgumentNullException(nameof(globalSettings.Amazon.Region)); + } + + _client = amazonSqs; } - _lastBlock = new Tuple(ipAddress, permanentBlock, now); - await _client.SendMessageAsync(_blockIpQueueUrl, ipAddress); - if (!permanentBlock) + public void Dispose() { - await _client.SendMessageAsync(_unblockIpQueueUrl, ipAddress); - } - } - - private async Task InitAsync() - { - if (_didInit) - { - return; + _client?.Dispose(); } - var blockIpQueue = await _client.GetQueueUrlAsync("block-ip"); - _blockIpQueueUrl = blockIpQueue.QueueUrl; - var unblockIpQueue = await _client.GetQueueUrlAsync("unblock-ip"); - _unblockIpQueueUrl = unblockIpQueue.QueueUrl; - _didInit = true; + public async Task BlockIpAsync(string ipAddress, bool permanentBlock) + { + var now = DateTime.UtcNow; + if (_lastBlock != null && _lastBlock.Item1 == ipAddress && _lastBlock.Item2 == permanentBlock && + (now - _lastBlock.Item3) < TimeSpan.FromMinutes(1)) + { + // Already blocked this IP recently. + return; + } + + _lastBlock = new Tuple(ipAddress, permanentBlock, now); + await _client.SendMessageAsync(_blockIpQueueUrl, ipAddress); + if (!permanentBlock) + { + await _client.SendMessageAsync(_unblockIpQueueUrl, ipAddress); + } + } + + private async Task InitAsync() + { + if (_didInit) + { + return; + } + + var blockIpQueue = await _client.GetQueueUrlAsync("block-ip"); + _blockIpQueueUrl = blockIpQueue.QueueUrl; + var unblockIpQueue = await _client.GetQueueUrlAsync("unblock-ip"); + _unblockIpQueueUrl = unblockIpQueue.QueueUrl; + _didInit = true; + } } } diff --git a/src/Core/Services/Implementations/AppleIapService.cs b/src/Core/Services/Implementations/AppleIapService.cs index 35cd2ac113..2fa8edfd7e 100644 --- a/src/Core/Services/Implementations/AppleIapService.cs +++ b/src/Core/Services/Implementations/AppleIapService.cs @@ -7,126 +7,127 @@ using Microsoft.AspNetCore.Hosting; using Microsoft.Extensions.Hosting; using Microsoft.Extensions.Logging; -namespace Bit.Core.Services; - -public class AppleIapService : IAppleIapService +namespace Bit.Core.Services { - private readonly HttpClient _httpClient = new HttpClient(); - - private readonly GlobalSettings _globalSettings; - private readonly IWebHostEnvironment _hostingEnvironment; - private readonly IMetaDataRepository _metaDataRespository; - private readonly ILogger _logger; - - public AppleIapService( - GlobalSettings globalSettings, - IWebHostEnvironment hostingEnvironment, - IMetaDataRepository metaDataRespository, - ILogger logger) + public class AppleIapService : IAppleIapService { - _globalSettings = globalSettings; - _hostingEnvironment = hostingEnvironment; - _metaDataRespository = metaDataRespository; - _logger = logger; - } + private readonly HttpClient _httpClient = new HttpClient(); - public async Task GetVerifiedReceiptStatusAsync(string receiptData) - { - var receiptStatus = await GetReceiptStatusAsync(receiptData); - if (receiptStatus?.Status != 0) - { - return null; - } - var validEnvironment = _globalSettings.AppleIap.AppInReview || - (!(_hostingEnvironment.IsProduction() || _hostingEnvironment.IsEnvironment("QA")) && receiptStatus.Environment == "Sandbox") || - ((_hostingEnvironment.IsProduction() || _hostingEnvironment.IsEnvironment("QA")) && receiptStatus.Environment != "Sandbox"); - var validProductBundle = receiptStatus.Receipt.BundleId == "com.bitwarden.desktop" || - receiptStatus.Receipt.BundleId == "com.8bit.bitwarden"; - var validProduct = receiptStatus.LatestReceiptInfo.LastOrDefault()?.ProductId == "premium_annually"; - var validIds = receiptStatus.GetOriginalTransactionId() != null && - receiptStatus.GetLastTransactionId() != null; - var validTransaction = receiptStatus.GetLastExpiresDate() - .GetValueOrDefault(DateTime.MinValue) > DateTime.UtcNow; - if (validEnvironment && validProductBundle && validProduct && validIds && validTransaction) - { - return receiptStatus; - } - return null; - } + private readonly GlobalSettings _globalSettings; + private readonly IWebHostEnvironment _hostingEnvironment; + private readonly IMetaDataRepository _metaDataRespository; + private readonly ILogger _logger; - public async Task SaveReceiptAsync(AppleReceiptStatus receiptStatus, Guid userId) - { - var originalTransactionId = receiptStatus.GetOriginalTransactionId(); - if (string.IsNullOrWhiteSpace(originalTransactionId)) + public AppleIapService( + GlobalSettings globalSettings, + IWebHostEnvironment hostingEnvironment, + IMetaDataRepository metaDataRespository, + ILogger logger) { - throw new Exception("OriginalTransactionId is null"); + _globalSettings = globalSettings; + _hostingEnvironment = hostingEnvironment; + _metaDataRespository = metaDataRespository; + _logger = logger; } - await _metaDataRespository.UpsertAsync("AppleReceipt", originalTransactionId, - new Dictionary + + public async Task GetVerifiedReceiptStatusAsync(string receiptData) + { + var receiptStatus = await GetReceiptStatusAsync(receiptData); + if (receiptStatus?.Status != 0) { - ["Data"] = receiptStatus.GetReceiptData(), - ["UserId"] = userId.ToString() - }); - } - - public async Task> GetReceiptAsync(string originalTransactionId) - { - var receipt = await _metaDataRespository.GetAsync("AppleReceipt", originalTransactionId); - if (receipt == null) - { - return null; - } - return new Tuple(receipt.ContainsKey("Data") ? receipt["Data"] : null, - receipt.ContainsKey("UserId") ? new Guid(receipt["UserId"]) : (Guid?)null); - } - - // Internal for testing - internal async Task GetReceiptStatusAsync(string receiptData, bool prod = true, - int attempt = 0, AppleReceiptStatus lastReceiptStatus = null) - { - try - { - if (attempt > 4) - { - throw new Exception( - $"Failed verifying Apple IAP after too many attempts. Last attempt status: {lastReceiptStatus?.Status.ToString() ?? "null"}"); + return null; } - - var url = string.Format("https://{0}.itunes.apple.com/verifyReceipt", prod ? "buy" : "sandbox"); - - var response = await _httpClient.PostAsJsonAsync(url, new AppleVerifyReceiptRequestModel + var validEnvironment = _globalSettings.AppleIap.AppInReview || + (!(_hostingEnvironment.IsProduction() || _hostingEnvironment.IsEnvironment("QA")) && receiptStatus.Environment == "Sandbox") || + ((_hostingEnvironment.IsProduction() || _hostingEnvironment.IsEnvironment("QA")) && receiptStatus.Environment != "Sandbox"); + var validProductBundle = receiptStatus.Receipt.BundleId == "com.bitwarden.desktop" || + receiptStatus.Receipt.BundleId == "com.8bit.bitwarden"; + var validProduct = receiptStatus.LatestReceiptInfo.LastOrDefault()?.ProductId == "premium_annually"; + var validIds = receiptStatus.GetOriginalTransactionId() != null && + receiptStatus.GetLastTransactionId() != null; + var validTransaction = receiptStatus.GetLastExpiresDate() + .GetValueOrDefault(DateTime.MinValue) > DateTime.UtcNow; + if (validEnvironment && validProductBundle && validProduct && validIds && validTransaction) { - ReceiptData = receiptData, - Password = _globalSettings.AppleIap.Password - }); - - if (response.IsSuccessStatusCode) - { - var receiptStatus = await response.Content.ReadFromJsonAsync(); - if (receiptStatus.Status == 21007) - { - return await GetReceiptStatusAsync(receiptData, false, attempt + 1, receiptStatus); - } - else if (receiptStatus.Status == 21005) - { - await Task.Delay(2000); - return await GetReceiptStatusAsync(receiptData, prod, attempt + 1, receiptStatus); - } return receiptStatus; } + return null; } - catch (Exception e) + + public async Task SaveReceiptAsync(AppleReceiptStatus receiptStatus, Guid userId) { - _logger.LogWarning(e, "Error verifying Apple IAP receipt."); + var originalTransactionId = receiptStatus.GetOriginalTransactionId(); + if (string.IsNullOrWhiteSpace(originalTransactionId)) + { + throw new Exception("OriginalTransactionId is null"); + } + await _metaDataRespository.UpsertAsync("AppleReceipt", originalTransactionId, + new Dictionary + { + ["Data"] = receiptStatus.GetReceiptData(), + ["UserId"] = userId.ToString() + }); } - return null; + + public async Task> GetReceiptAsync(string originalTransactionId) + { + var receipt = await _metaDataRespository.GetAsync("AppleReceipt", originalTransactionId); + if (receipt == null) + { + return null; + } + return new Tuple(receipt.ContainsKey("Data") ? receipt["Data"] : null, + receipt.ContainsKey("UserId") ? new Guid(receipt["UserId"]) : (Guid?)null); + } + + // Internal for testing + internal async Task GetReceiptStatusAsync(string receiptData, bool prod = true, + int attempt = 0, AppleReceiptStatus lastReceiptStatus = null) + { + try + { + if (attempt > 4) + { + throw new Exception( + $"Failed verifying Apple IAP after too many attempts. Last attempt status: {lastReceiptStatus?.Status.ToString() ?? "null"}"); + } + + var url = string.Format("https://{0}.itunes.apple.com/verifyReceipt", prod ? "buy" : "sandbox"); + + var response = await _httpClient.PostAsJsonAsync(url, new AppleVerifyReceiptRequestModel + { + ReceiptData = receiptData, + Password = _globalSettings.AppleIap.Password + }); + + if (response.IsSuccessStatusCode) + { + var receiptStatus = await response.Content.ReadFromJsonAsync(); + if (receiptStatus.Status == 21007) + { + return await GetReceiptStatusAsync(receiptData, false, attempt + 1, receiptStatus); + } + else if (receiptStatus.Status == 21005) + { + await Task.Delay(2000); + return await GetReceiptStatusAsync(receiptData, prod, attempt + 1, receiptStatus); + } + return receiptStatus; + } + } + catch (Exception e) + { + _logger.LogWarning(e, "Error verifying Apple IAP receipt."); + } + return null; + } + } + + public class AppleVerifyReceiptRequestModel + { + [JsonPropertyName("receipt-data")] + public string ReceiptData { get; set; } + [JsonPropertyName("password")] + public string Password { get; set; } } } - -public class AppleVerifyReceiptRequestModel -{ - [JsonPropertyName("receipt-data")] - public string ReceiptData { get; set; } - [JsonPropertyName("password")] - public string Password { get; set; } -} diff --git a/src/Core/Services/Implementations/AzureAttachmentStorageService.cs b/src/Core/Services/Implementations/AzureAttachmentStorageService.cs index edc35e03a3..6a9e8f77fb 100644 --- a/src/Core/Services/Implementations/AzureAttachmentStorageService.cs +++ b/src/Core/Services/Implementations/AzureAttachmentStorageService.cs @@ -7,259 +7,260 @@ using Bit.Core.Models.Data; using Bit.Core.Settings; using Microsoft.Extensions.Logging; -namespace Bit.Core.Services; - -public class AzureAttachmentStorageService : IAttachmentStorageService +namespace Bit.Core.Services { - public FileUploadType FileUploadType => FileUploadType.Azure; - public const string EventGridEnabledContainerName = "attachments-v2"; - private const string _defaultContainerName = "attachments"; - private readonly static string[] _attachmentContainerName = { "attachments", "attachments-v2" }; - private static readonly TimeSpan blobLinkLiveTime = TimeSpan.FromMinutes(1); - private readonly BlobServiceClient _blobServiceClient; - private readonly Dictionary _attachmentContainers = new Dictionary(); - private readonly ILogger _logger; - - private string BlobName(Guid cipherId, CipherAttachment.MetaData attachmentData, Guid? organizationId = null, bool temp = false) => - string.Concat( - temp ? "temp/" : "", - $"{cipherId}/", - organizationId != null ? $"{organizationId.Value}/" : "", - attachmentData.AttachmentId - ); - - public static (string cipherId, string organizationId, string attachmentId) IdentifiersFromBlobName(string blobName) + public class AzureAttachmentStorageService : IAttachmentStorageService { - var parts = blobName.Split('/'); - switch (parts.Length) + public FileUploadType FileUploadType => FileUploadType.Azure; + public const string EventGridEnabledContainerName = "attachments-v2"; + private const string _defaultContainerName = "attachments"; + private readonly static string[] _attachmentContainerName = { "attachments", "attachments-v2" }; + private static readonly TimeSpan blobLinkLiveTime = TimeSpan.FromMinutes(1); + private readonly BlobServiceClient _blobServiceClient; + private readonly Dictionary _attachmentContainers = new Dictionary(); + private readonly ILogger _logger; + + private string BlobName(Guid cipherId, CipherAttachment.MetaData attachmentData, Guid? organizationId = null, bool temp = false) => + string.Concat( + temp ? "temp/" : "", + $"{cipherId}/", + organizationId != null ? $"{organizationId.Value}/" : "", + attachmentData.AttachmentId + ); + + public static (string cipherId, string organizationId, string attachmentId) IdentifiersFromBlobName(string blobName) { - case 4: - return (parts[1], parts[2], parts[3]); - case 3: - if (parts[0] == "temp") - { - return (parts[1], null, parts[2]); - } - else - { - return (parts[0], parts[1], parts[2]); - } - case 2: - return (parts[0], null, parts[1]); - default: - throw new Exception("Cannot determine cipher information from blob name"); - } - } - - public AzureAttachmentStorageService( - GlobalSettings globalSettings, - ILogger logger) - { - _blobServiceClient = new BlobServiceClient(globalSettings.Attachment.ConnectionString); - _logger = logger; - } - - public async Task GetAttachmentDownloadUrlAsync(Cipher cipher, CipherAttachment.MetaData attachmentData) - { - await InitAsync(attachmentData.ContainerName); - var blobClient = _attachmentContainers[attachmentData.ContainerName].GetBlobClient(BlobName(cipher.Id, attachmentData)); - var sasUri = blobClient.GenerateSasUri(BlobSasPermissions.Read, DateTime.UtcNow.Add(blobLinkLiveTime)); - return sasUri.ToString(); - } - - public async Task GetAttachmentUploadUrlAsync(Cipher cipher, CipherAttachment.MetaData attachmentData) - { - await InitAsync(EventGridEnabledContainerName); - var blobClient = _attachmentContainers[EventGridEnabledContainerName].GetBlobClient(BlobName(cipher.Id, attachmentData)); - attachmentData.ContainerName = EventGridEnabledContainerName; - var sasUri = blobClient.GenerateSasUri(BlobSasPermissions.Create | BlobSasPermissions.Write, DateTime.UtcNow.Add(blobLinkLiveTime)); - return sasUri.ToString(); - } - - public async Task UploadNewAttachmentAsync(Stream stream, Cipher cipher, CipherAttachment.MetaData attachmentData) - { - attachmentData.ContainerName = _defaultContainerName; - await InitAsync(_defaultContainerName); - var blobClient = _attachmentContainers[_defaultContainerName].GetBlobClient(BlobName(cipher.Id, attachmentData)); - - var metadata = new Dictionary(); - metadata.Add("cipherId", cipher.Id.ToString()); - if (cipher.UserId.HasValue) - { - metadata.Add("userId", cipher.UserId.Value.ToString()); - } - else - { - metadata.Add("organizationId", cipher.OrganizationId.Value.ToString()); + var parts = blobName.Split('/'); + switch (parts.Length) + { + case 4: + return (parts[1], parts[2], parts[3]); + case 3: + if (parts[0] == "temp") + { + return (parts[1], null, parts[2]); + } + else + { + return (parts[0], parts[1], parts[2]); + } + case 2: + return (parts[0], null, parts[1]); + default: + throw new Exception("Cannot determine cipher information from blob name"); + } } - var headers = new BlobHttpHeaders + public AzureAttachmentStorageService( + GlobalSettings globalSettings, + ILogger logger) { - ContentDisposition = $"attachment; filename=\"{attachmentData.AttachmentId}\"" - }; - await blobClient.UploadAsync(stream, new BlobUploadOptions { Metadata = metadata, HttpHeaders = headers }); - } - - public async Task UploadShareAttachmentAsync(Stream stream, Guid cipherId, Guid organizationId, CipherAttachment.MetaData attachmentData) - { - attachmentData.ContainerName = _defaultContainerName; - await InitAsync(_defaultContainerName); - var blobClient = _attachmentContainers[_defaultContainerName].GetBlobClient( - BlobName(cipherId, attachmentData, organizationId, temp: true)); - - var metadata = new Dictionary(); - metadata.Add("cipherId", cipherId.ToString()); - metadata.Add("organizationId", organizationId.ToString()); - - var headers = new BlobHttpHeaders - { - ContentDisposition = $"attachment; filename=\"{attachmentData.AttachmentId}\"" - }; - await blobClient.UploadAsync(stream, new BlobUploadOptions { Metadata = metadata, HttpHeaders = headers }); - } - - public async Task StartShareAttachmentAsync(Guid cipherId, Guid organizationId, CipherAttachment.MetaData data) - { - await InitAsync(data.ContainerName); - var source = _attachmentContainers[data.ContainerName].GetBlobClient( - BlobName(cipherId, data, organizationId, temp: true)); - if (!await source.ExistsAsync()) - { - return; + _blobServiceClient = new BlobServiceClient(globalSettings.Attachment.ConnectionString); + _logger = logger; } - await InitAsync(_defaultContainerName); - var dest = _attachmentContainers[_defaultContainerName].GetBlobClient(BlobName(cipherId, data)); - if (!await dest.ExistsAsync()) + public async Task GetAttachmentDownloadUrlAsync(Cipher cipher, CipherAttachment.MetaData attachmentData) { - return; + await InitAsync(attachmentData.ContainerName); + var blobClient = _attachmentContainers[attachmentData.ContainerName].GetBlobClient(BlobName(cipher.Id, attachmentData)); + var sasUri = blobClient.GenerateSasUri(BlobSasPermissions.Read, DateTime.UtcNow.Add(blobLinkLiveTime)); + return sasUri.ToString(); } - var original = _attachmentContainers[_defaultContainerName].GetBlobClient( - BlobName(cipherId, data, temp: true)); - await original.DeleteIfExistsAsync(); - await original.StartCopyFromUriAsync(dest.Uri); - - await dest.DeleteIfExistsAsync(); - await dest.StartCopyFromUriAsync(source.Uri); - } - - public async Task RollbackShareAttachmentAsync(Guid cipherId, Guid organizationId, CipherAttachment.MetaData attachmentData, string originalContainer) - { - await InitAsync(attachmentData.ContainerName); - var source = _attachmentContainers[attachmentData.ContainerName].GetBlobClient( - BlobName(cipherId, attachmentData, organizationId, temp: true)); - await source.DeleteIfExistsAsync(); - - await InitAsync(originalContainer); - var original = _attachmentContainers[originalContainer].GetBlobClient( - BlobName(cipherId, attachmentData, temp: true)); - if (!await original.ExistsAsync()) + public async Task GetAttachmentUploadUrlAsync(Cipher cipher, CipherAttachment.MetaData attachmentData) { - return; + await InitAsync(EventGridEnabledContainerName); + var blobClient = _attachmentContainers[EventGridEnabledContainerName].GetBlobClient(BlobName(cipher.Id, attachmentData)); + attachmentData.ContainerName = EventGridEnabledContainerName; + var sasUri = blobClient.GenerateSasUri(BlobSasPermissions.Create | BlobSasPermissions.Write, DateTime.UtcNow.Add(blobLinkLiveTime)); + return sasUri.ToString(); } - var dest = _attachmentContainers[originalContainer].GetBlobClient( - BlobName(cipherId, attachmentData)); - await dest.DeleteIfExistsAsync(); - await dest.StartCopyFromUriAsync(original.Uri); - await original.DeleteIfExistsAsync(); - } - - public async Task DeleteAttachmentAsync(Guid cipherId, CipherAttachment.MetaData attachmentData) - { - await InitAsync(attachmentData.ContainerName); - var blobClient = _attachmentContainers[attachmentData.ContainerName].GetBlobClient( - BlobName(cipherId, attachmentData)); - await blobClient.DeleteIfExistsAsync(); - } - - public async Task CleanupAsync(Guid cipherId) => await DeleteAttachmentsForPathAsync($"temp/{cipherId}"); - - public async Task DeleteAttachmentsForCipherAsync(Guid cipherId) => - await DeleteAttachmentsForPathAsync(cipherId.ToString()); - - public async Task DeleteAttachmentsForOrganizationAsync(Guid organizationId) - { - await InitAsync(_defaultContainerName); - } - - public async Task DeleteAttachmentsForUserAsync(Guid userId) - { - await InitAsync(_defaultContainerName); - } - - public async Task<(bool, long?)> ValidateFileAsync(Cipher cipher, CipherAttachment.MetaData attachmentData, long leeway) - { - await InitAsync(attachmentData.ContainerName); - - var blobClient = _attachmentContainers[attachmentData.ContainerName].GetBlobClient(BlobName(cipher.Id, attachmentData)); - - try + public async Task UploadNewAttachmentAsync(Stream stream, Cipher cipher, CipherAttachment.MetaData attachmentData) { - var blobProperties = await blobClient.GetPropertiesAsync(); + attachmentData.ContainerName = _defaultContainerName; + await InitAsync(_defaultContainerName); + var blobClient = _attachmentContainers[_defaultContainerName].GetBlobClient(BlobName(cipher.Id, attachmentData)); - var metadata = blobProperties.Value.Metadata; - metadata["cipherId"] = cipher.Id.ToString(); + var metadata = new Dictionary(); + metadata.Add("cipherId", cipher.Id.ToString()); if (cipher.UserId.HasValue) { - metadata["userId"] = cipher.UserId.Value.ToString(); + metadata.Add("userId", cipher.UserId.Value.ToString()); } else { - metadata["organizationId"] = cipher.OrganizationId.Value.ToString(); + metadata.Add("organizationId", cipher.OrganizationId.Value.ToString()); } - await blobClient.SetMetadataAsync(metadata); var headers = new BlobHttpHeaders { ContentDisposition = $"attachment; filename=\"{attachmentData.AttachmentId}\"" }; - await blobClient.SetHttpHeadersAsync(headers); + await blobClient.UploadAsync(stream, new BlobUploadOptions { Metadata = metadata, HttpHeaders = headers }); + } - var length = blobProperties.Value.ContentLength; - if (length < attachmentData.Size - leeway || length > attachmentData.Size + leeway) + public async Task UploadShareAttachmentAsync(Stream stream, Guid cipherId, Guid organizationId, CipherAttachment.MetaData attachmentData) + { + attachmentData.ContainerName = _defaultContainerName; + await InitAsync(_defaultContainerName); + var blobClient = _attachmentContainers[_defaultContainerName].GetBlobClient( + BlobName(cipherId, attachmentData, organizationId, temp: true)); + + var metadata = new Dictionary(); + metadata.Add("cipherId", cipherId.ToString()); + metadata.Add("organizationId", organizationId.ToString()); + + var headers = new BlobHttpHeaders { - return (false, length); + ContentDisposition = $"attachment; filename=\"{attachmentData.AttachmentId}\"" + }; + await blobClient.UploadAsync(stream, new BlobUploadOptions { Metadata = metadata, HttpHeaders = headers }); + } + + public async Task StartShareAttachmentAsync(Guid cipherId, Guid organizationId, CipherAttachment.MetaData data) + { + await InitAsync(data.ContainerName); + var source = _attachmentContainers[data.ContainerName].GetBlobClient( + BlobName(cipherId, data, organizationId, temp: true)); + if (!await source.ExistsAsync()) + { + return; } - return (true, length); - } - catch (Exception ex) - { - _logger.LogError(ex, "Unhandled error in ValidateFileAsync"); - return (false, null); - } - } - - private async Task DeleteAttachmentsForPathAsync(string path) - { - foreach (var container in _attachmentContainerName) - { - await InitAsync(container); - var blobContainerClient = _attachmentContainers[container]; - - var blobItems = blobContainerClient.GetBlobsAsync(BlobTraits.None, BlobStates.None, prefix: path); - await foreach (var blobItem in blobItems) + await InitAsync(_defaultContainerName); + var dest = _attachmentContainers[_defaultContainerName].GetBlobClient(BlobName(cipherId, data)); + if (!await dest.ExistsAsync()) { - BlobClient blobClient = blobContainerClient.GetBlobClient(blobItem.Name); - await blobClient.DeleteIfExistsAsync(); + return; + } + + var original = _attachmentContainers[_defaultContainerName].GetBlobClient( + BlobName(cipherId, data, temp: true)); + await original.DeleteIfExistsAsync(); + await original.StartCopyFromUriAsync(dest.Uri); + + await dest.DeleteIfExistsAsync(); + await dest.StartCopyFromUriAsync(source.Uri); + } + + public async Task RollbackShareAttachmentAsync(Guid cipherId, Guid organizationId, CipherAttachment.MetaData attachmentData, string originalContainer) + { + await InitAsync(attachmentData.ContainerName); + var source = _attachmentContainers[attachmentData.ContainerName].GetBlobClient( + BlobName(cipherId, attachmentData, organizationId, temp: true)); + await source.DeleteIfExistsAsync(); + + await InitAsync(originalContainer); + var original = _attachmentContainers[originalContainer].GetBlobClient( + BlobName(cipherId, attachmentData, temp: true)); + if (!await original.ExistsAsync()) + { + return; + } + + var dest = _attachmentContainers[originalContainer].GetBlobClient( + BlobName(cipherId, attachmentData)); + await dest.DeleteIfExistsAsync(); + await dest.StartCopyFromUriAsync(original.Uri); + await original.DeleteIfExistsAsync(); + } + + public async Task DeleteAttachmentAsync(Guid cipherId, CipherAttachment.MetaData attachmentData) + { + await InitAsync(attachmentData.ContainerName); + var blobClient = _attachmentContainers[attachmentData.ContainerName].GetBlobClient( + BlobName(cipherId, attachmentData)); + await blobClient.DeleteIfExistsAsync(); + } + + public async Task CleanupAsync(Guid cipherId) => await DeleteAttachmentsForPathAsync($"temp/{cipherId}"); + + public async Task DeleteAttachmentsForCipherAsync(Guid cipherId) => + await DeleteAttachmentsForPathAsync(cipherId.ToString()); + + public async Task DeleteAttachmentsForOrganizationAsync(Guid organizationId) + { + await InitAsync(_defaultContainerName); + } + + public async Task DeleteAttachmentsForUserAsync(Guid userId) + { + await InitAsync(_defaultContainerName); + } + + public async Task<(bool, long?)> ValidateFileAsync(Cipher cipher, CipherAttachment.MetaData attachmentData, long leeway) + { + await InitAsync(attachmentData.ContainerName); + + var blobClient = _attachmentContainers[attachmentData.ContainerName].GetBlobClient(BlobName(cipher.Id, attachmentData)); + + try + { + var blobProperties = await blobClient.GetPropertiesAsync(); + + var metadata = blobProperties.Value.Metadata; + metadata["cipherId"] = cipher.Id.ToString(); + if (cipher.UserId.HasValue) + { + metadata["userId"] = cipher.UserId.Value.ToString(); + } + else + { + metadata["organizationId"] = cipher.OrganizationId.Value.ToString(); + } + await blobClient.SetMetadataAsync(metadata); + + var headers = new BlobHttpHeaders + { + ContentDisposition = $"attachment; filename=\"{attachmentData.AttachmentId}\"" + }; + await blobClient.SetHttpHeadersAsync(headers); + + var length = blobProperties.Value.ContentLength; + if (length < attachmentData.Size - leeway || length > attachmentData.Size + leeway) + { + return (false, length); + } + + return (true, length); + } + catch (Exception ex) + { + _logger.LogError(ex, "Unhandled error in ValidateFileAsync"); + return (false, null); } } - } - private async Task InitAsync(string containerName) - { - if (!_attachmentContainers.ContainsKey(containerName) || _attachmentContainers[containerName] == null) + private async Task DeleteAttachmentsForPathAsync(string path) { - _attachmentContainers[containerName] = _blobServiceClient.GetBlobContainerClient(containerName); - if (containerName == "attachments") + foreach (var container in _attachmentContainerName) { - await _attachmentContainers[containerName].CreateIfNotExistsAsync(PublicAccessType.Blob, null, null); + await InitAsync(container); + var blobContainerClient = _attachmentContainers[container]; + + var blobItems = blobContainerClient.GetBlobsAsync(BlobTraits.None, BlobStates.None, prefix: path); + await foreach (var blobItem in blobItems) + { + BlobClient blobClient = blobContainerClient.GetBlobClient(blobItem.Name); + await blobClient.DeleteIfExistsAsync(); + } } - else + } + + private async Task InitAsync(string containerName) + { + if (!_attachmentContainers.ContainsKey(containerName) || _attachmentContainers[containerName] == null) { - await _attachmentContainers[containerName].CreateIfNotExistsAsync(PublicAccessType.None, null, null); + _attachmentContainers[containerName] = _blobServiceClient.GetBlobContainerClient(containerName); + if (containerName == "attachments") + { + await _attachmentContainers[containerName].CreateIfNotExistsAsync(PublicAccessType.Blob, null, null); + } + else + { + await _attachmentContainers[containerName].CreateIfNotExistsAsync(PublicAccessType.None, null, null); + } } } } diff --git a/src/Core/Services/Implementations/AzureQueueBlockIpService.cs b/src/Core/Services/Implementations/AzureQueueBlockIpService.cs index ab78c46548..8682b0c499 100644 --- a/src/Core/Services/Implementations/AzureQueueBlockIpService.cs +++ b/src/Core/Services/Implementations/AzureQueueBlockIpService.cs @@ -1,36 +1,37 @@ using Azure.Storage.Queues; using Bit.Core.Settings; -namespace Bit.Core.Services; - -public class AzureQueueBlockIpService : IBlockIpService +namespace Bit.Core.Services { - private readonly QueueClient _blockIpQueueClient; - private readonly QueueClient _unblockIpQueueClient; - private Tuple _lastBlock; - - public AzureQueueBlockIpService( - GlobalSettings globalSettings) + public class AzureQueueBlockIpService : IBlockIpService { - _blockIpQueueClient = new QueueClient(globalSettings.Storage.ConnectionString, "blockip"); - _unblockIpQueueClient = new QueueClient(globalSettings.Storage.ConnectionString, "unblockip"); - } + private readonly QueueClient _blockIpQueueClient; + private readonly QueueClient _unblockIpQueueClient; + private Tuple _lastBlock; - public async Task BlockIpAsync(string ipAddress, bool permanentBlock) - { - var now = DateTime.UtcNow; - if (_lastBlock != null && _lastBlock.Item1 == ipAddress && _lastBlock.Item2 == permanentBlock && - (now - _lastBlock.Item3) < TimeSpan.FromMinutes(1)) + public AzureQueueBlockIpService( + GlobalSettings globalSettings) { - // Already blocked this IP recently. - return; + _blockIpQueueClient = new QueueClient(globalSettings.Storage.ConnectionString, "blockip"); + _unblockIpQueueClient = new QueueClient(globalSettings.Storage.ConnectionString, "unblockip"); } - _lastBlock = new Tuple(ipAddress, permanentBlock, now); - await _blockIpQueueClient.SendMessageAsync(ipAddress); - if (!permanentBlock) + public async Task BlockIpAsync(string ipAddress, bool permanentBlock) { - await _unblockIpQueueClient.SendMessageAsync(ipAddress, new TimeSpan(0, 15, 0)); + var now = DateTime.UtcNow; + if (_lastBlock != null && _lastBlock.Item1 == ipAddress && _lastBlock.Item2 == permanentBlock && + (now - _lastBlock.Item3) < TimeSpan.FromMinutes(1)) + { + // Already blocked this IP recently. + return; + } + + _lastBlock = new Tuple(ipAddress, permanentBlock, now); + await _blockIpQueueClient.SendMessageAsync(ipAddress); + if (!permanentBlock) + { + await _unblockIpQueueClient.SendMessageAsync(ipAddress, new TimeSpan(0, 15, 0)); + } } } } diff --git a/src/Core/Services/Implementations/AzureQueueEventWriteService.cs b/src/Core/Services/Implementations/AzureQueueEventWriteService.cs index f81175f7b5..bf74677f1d 100644 --- a/src/Core/Services/Implementations/AzureQueueEventWriteService.cs +++ b/src/Core/Services/Implementations/AzureQueueEventWriteService.cs @@ -3,14 +3,15 @@ using Bit.Core.Models.Data; using Bit.Core.Settings; using Bit.Core.Utilities; -namespace Bit.Core.Services; - -public class AzureQueueEventWriteService : AzureQueueService, IEventWriteService +namespace Bit.Core.Services { - public AzureQueueEventWriteService(GlobalSettings globalSettings) : base( - new QueueClient(globalSettings.Events.ConnectionString, "event"), - JsonHelpers.IgnoreWritingNull) - { } + public class AzureQueueEventWriteService : AzureQueueService, IEventWriteService + { + public AzureQueueEventWriteService(GlobalSettings globalSettings) : base( + new QueueClient(globalSettings.Events.ConnectionString, "event"), + JsonHelpers.IgnoreWritingNull) + { } - public Task CreateAsync(IEvent e) => CreateManyAsync(new[] { e }); + public Task CreateAsync(IEvent e) => CreateManyAsync(new[] { e }); + } } diff --git a/src/Core/Services/Implementations/AzureQueueMailService.cs b/src/Core/Services/Implementations/AzureQueueMailService.cs index 92d6fd17bb..e05c106ea5 100644 --- a/src/Core/Services/Implementations/AzureQueueMailService.cs +++ b/src/Core/Services/Implementations/AzureQueueMailService.cs @@ -3,18 +3,19 @@ using Bit.Core.Models.Mail; using Bit.Core.Settings; using Bit.Core.Utilities; -namespace Bit.Core.Services; - -public class AzureQueueMailService : AzureQueueService, IMailEnqueuingService +namespace Bit.Core.Services { - public AzureQueueMailService(GlobalSettings globalSettings) : base( - new QueueClient(globalSettings.Mail.ConnectionString, "mail"), - JsonHelpers.IgnoreWritingNull) - { } + public class AzureQueueMailService : AzureQueueService, IMailEnqueuingService + { + public AzureQueueMailService(GlobalSettings globalSettings) : base( + new QueueClient(globalSettings.Mail.ConnectionString, "mail"), + JsonHelpers.IgnoreWritingNull) + { } - public Task EnqueueAsync(IMailQueueMessage message, Func fallback) => - CreateManyAsync(new[] { message }); + public Task EnqueueAsync(IMailQueueMessage message, Func fallback) => + CreateManyAsync(new[] { message }); - public Task EnqueueManyAsync(IEnumerable messages, Func fallback) => - CreateManyAsync(messages); + public Task EnqueueManyAsync(IEnumerable messages, Func fallback) => + CreateManyAsync(messages); + } } diff --git a/src/Core/Services/Implementations/AzureQueuePushNotificationService.cs b/src/Core/Services/Implementations/AzureQueuePushNotificationService.cs index fb7bcafca2..7062c6c188 100644 --- a/src/Core/Services/Implementations/AzureQueuePushNotificationService.cs +++ b/src/Core/Services/Implementations/AzureQueuePushNotificationService.cs @@ -8,189 +8,190 @@ using Bit.Core.Settings; using Bit.Core.Utilities; using Microsoft.AspNetCore.Http; -namespace Bit.Core.Services; - -public class AzureQueuePushNotificationService : IPushNotificationService +namespace Bit.Core.Services { - private readonly QueueClient _queueClient; - private readonly GlobalSettings _globalSettings; - private readonly IHttpContextAccessor _httpContextAccessor; - - public AzureQueuePushNotificationService( - GlobalSettings globalSettings, - IHttpContextAccessor httpContextAccessor) + public class AzureQueuePushNotificationService : IPushNotificationService { - _queueClient = new QueueClient(globalSettings.Notifications.ConnectionString, "notifications"); - _globalSettings = globalSettings; - _httpContextAccessor = httpContextAccessor; - } + private readonly QueueClient _queueClient; + private readonly GlobalSettings _globalSettings; + private readonly IHttpContextAccessor _httpContextAccessor; - public async Task PushSyncCipherCreateAsync(Cipher cipher, IEnumerable collectionIds) - { - await PushCipherAsync(cipher, PushType.SyncCipherCreate, collectionIds); - } - - public async Task PushSyncCipherUpdateAsync(Cipher cipher, IEnumerable collectionIds) - { - await PushCipherAsync(cipher, PushType.SyncCipherUpdate, collectionIds); - } - - public async Task PushSyncCipherDeleteAsync(Cipher cipher) - { - await PushCipherAsync(cipher, PushType.SyncLoginDelete, null); - } - - private async Task PushCipherAsync(Cipher cipher, PushType type, IEnumerable collectionIds) - { - if (cipher.OrganizationId.HasValue) + public AzureQueuePushNotificationService( + GlobalSettings globalSettings, + IHttpContextAccessor httpContextAccessor) { - var message = new SyncCipherPushNotification + _queueClient = new QueueClient(globalSettings.Notifications.ConnectionString, "notifications"); + _globalSettings = globalSettings; + _httpContextAccessor = httpContextAccessor; + } + + public async Task PushSyncCipherCreateAsync(Cipher cipher, IEnumerable collectionIds) + { + await PushCipherAsync(cipher, PushType.SyncCipherCreate, collectionIds); + } + + public async Task PushSyncCipherUpdateAsync(Cipher cipher, IEnumerable collectionIds) + { + await PushCipherAsync(cipher, PushType.SyncCipherUpdate, collectionIds); + } + + public async Task PushSyncCipherDeleteAsync(Cipher cipher) + { + await PushCipherAsync(cipher, PushType.SyncLoginDelete, null); + } + + private async Task PushCipherAsync(Cipher cipher, PushType type, IEnumerable collectionIds) + { + if (cipher.OrganizationId.HasValue) { - Id = cipher.Id, - OrganizationId = cipher.OrganizationId, - RevisionDate = cipher.RevisionDate, - CollectionIds = collectionIds, + var message = new SyncCipherPushNotification + { + Id = cipher.Id, + OrganizationId = cipher.OrganizationId, + RevisionDate = cipher.RevisionDate, + CollectionIds = collectionIds, + }; + + await SendMessageAsync(type, message, true); + } + else if (cipher.UserId.HasValue) + { + var message = new SyncCipherPushNotification + { + Id = cipher.Id, + UserId = cipher.UserId, + RevisionDate = cipher.RevisionDate, + }; + + await SendMessageAsync(type, message, true); + } + } + + public async Task PushSyncFolderCreateAsync(Folder folder) + { + await PushFolderAsync(folder, PushType.SyncFolderCreate); + } + + public async Task PushSyncFolderUpdateAsync(Folder folder) + { + await PushFolderAsync(folder, PushType.SyncFolderUpdate); + } + + public async Task PushSyncFolderDeleteAsync(Folder folder) + { + await PushFolderAsync(folder, PushType.SyncFolderDelete); + } + + private async Task PushFolderAsync(Folder folder, PushType type) + { + var message = new SyncFolderPushNotification + { + Id = folder.Id, + UserId = folder.UserId, + RevisionDate = folder.RevisionDate }; await SendMessageAsync(type, message, true); } - else if (cipher.UserId.HasValue) + + public async Task PushSyncCiphersAsync(Guid userId) { - var message = new SyncCipherPushNotification + await PushUserAsync(userId, PushType.SyncCiphers); + } + + public async Task PushSyncVaultAsync(Guid userId) + { + await PushUserAsync(userId, PushType.SyncVault); + } + + public async Task PushSyncOrgKeysAsync(Guid userId) + { + await PushUserAsync(userId, PushType.SyncOrgKeys); + } + + public async Task PushSyncSettingsAsync(Guid userId) + { + await PushUserAsync(userId, PushType.SyncSettings); + } + + public async Task PushLogOutAsync(Guid userId) + { + await PushUserAsync(userId, PushType.LogOut); + } + + private async Task PushUserAsync(Guid userId, PushType type) + { + var message = new UserPushNotification { - Id = cipher.Id, - UserId = cipher.UserId, - RevisionDate = cipher.RevisionDate, + UserId = userId, + Date = DateTime.UtcNow }; - await SendMessageAsync(type, message, true); + await SendMessageAsync(type, message, false); } - } - public async Task PushSyncFolderCreateAsync(Folder folder) - { - await PushFolderAsync(folder, PushType.SyncFolderCreate); - } - - public async Task PushSyncFolderUpdateAsync(Folder folder) - { - await PushFolderAsync(folder, PushType.SyncFolderUpdate); - } - - public async Task PushSyncFolderDeleteAsync(Folder folder) - { - await PushFolderAsync(folder, PushType.SyncFolderDelete); - } - - private async Task PushFolderAsync(Folder folder, PushType type) - { - var message = new SyncFolderPushNotification + public async Task PushSyncSendCreateAsync(Send send) { - Id = folder.Id, - UserId = folder.UserId, - RevisionDate = folder.RevisionDate - }; + await PushSendAsync(send, PushType.SyncSendCreate); + } - await SendMessageAsync(type, message, true); - } - - public async Task PushSyncCiphersAsync(Guid userId) - { - await PushUserAsync(userId, PushType.SyncCiphers); - } - - public async Task PushSyncVaultAsync(Guid userId) - { - await PushUserAsync(userId, PushType.SyncVault); - } - - public async Task PushSyncOrgKeysAsync(Guid userId) - { - await PushUserAsync(userId, PushType.SyncOrgKeys); - } - - public async Task PushSyncSettingsAsync(Guid userId) - { - await PushUserAsync(userId, PushType.SyncSettings); - } - - public async Task PushLogOutAsync(Guid userId) - { - await PushUserAsync(userId, PushType.LogOut); - } - - private async Task PushUserAsync(Guid userId, PushType type) - { - var message = new UserPushNotification + public async Task PushSyncSendUpdateAsync(Send send) { - UserId = userId, - Date = DateTime.UtcNow - }; + await PushSendAsync(send, PushType.SyncSendUpdate); + } - await SendMessageAsync(type, message, false); - } - - public async Task PushSyncSendCreateAsync(Send send) - { - await PushSendAsync(send, PushType.SyncSendCreate); - } - - public async Task PushSyncSendUpdateAsync(Send send) - { - await PushSendAsync(send, PushType.SyncSendUpdate); - } - - public async Task PushSyncSendDeleteAsync(Send send) - { - await PushSendAsync(send, PushType.SyncSendDelete); - } - - private async Task PushSendAsync(Send send, PushType type) - { - if (send.UserId.HasValue) + public async Task PushSyncSendDeleteAsync(Send send) { - var message = new SyncSendPushNotification + await PushSendAsync(send, PushType.SyncSendDelete); + } + + private async Task PushSendAsync(Send send, PushType type) + { + if (send.UserId.HasValue) { - Id = send.Id, - UserId = send.UserId.Value, - RevisionDate = send.RevisionDate - }; + var message = new SyncSendPushNotification + { + Id = send.Id, + UserId = send.UserId.Value, + RevisionDate = send.RevisionDate + }; - await SendMessageAsync(type, message, true); + await SendMessageAsync(type, message, true); + } } - } - private async Task SendMessageAsync(PushType type, T payload, bool excludeCurrentContext) - { - var contextId = GetContextIdentifier(excludeCurrentContext); - var message = JsonSerializer.Serialize(new PushNotificationData(type, payload, contextId), - JsonHelpers.IgnoreWritingNull); - await _queueClient.SendMessageAsync(message); - } - - private string GetContextIdentifier(bool excludeCurrentContext) - { - if (!excludeCurrentContext) + private async Task SendMessageAsync(PushType type, T payload, bool excludeCurrentContext) { - return null; + var contextId = GetContextIdentifier(excludeCurrentContext); + var message = JsonSerializer.Serialize(new PushNotificationData(type, payload, contextId), + JsonHelpers.IgnoreWritingNull); + await _queueClient.SendMessageAsync(message); } - var currentContext = _httpContextAccessor?.HttpContext?. - RequestServices.GetService(typeof(ICurrentContext)) as ICurrentContext; - return currentContext?.DeviceIdentifier; - } + private string GetContextIdentifier(bool excludeCurrentContext) + { + if (!excludeCurrentContext) + { + return null; + } - public Task SendPayloadToUserAsync(string userId, PushType type, object payload, string identifier, - string deviceId = null) - { - // Noop - return Task.FromResult(0); - } + var currentContext = _httpContextAccessor?.HttpContext?. + RequestServices.GetService(typeof(ICurrentContext)) as ICurrentContext; + return currentContext?.DeviceIdentifier; + } - public Task SendPayloadToOrganizationAsync(string orgId, PushType type, object payload, string identifier, - string deviceId = null) - { - // Noop - return Task.FromResult(0); + public Task SendPayloadToUserAsync(string userId, PushType type, object payload, string identifier, + string deviceId = null) + { + // Noop + return Task.FromResult(0); + } + + public Task SendPayloadToOrganizationAsync(string orgId, PushType type, object payload, string identifier, + string deviceId = null) + { + // Noop + return Task.FromResult(0); + } } } diff --git a/src/Core/Services/Implementations/AzureQueueReferenceEventService.cs b/src/Core/Services/Implementations/AzureQueueReferenceEventService.cs index 6abbe97836..e3b0f0ecfc 100644 --- a/src/Core/Services/Implementations/AzureQueueReferenceEventService.cs +++ b/src/Core/Services/Implementations/AzureQueueReferenceEventService.cs @@ -5,44 +5,45 @@ using Bit.Core.Models.Business; using Bit.Core.Settings; using Bit.Core.Utilities; -namespace Bit.Core.Services; - -public class AzureQueueReferenceEventService : IReferenceEventService +namespace Bit.Core.Services { - private const string _queueName = "reference-events"; - - private readonly QueueClient _queueClient; - private readonly GlobalSettings _globalSettings; - - public AzureQueueReferenceEventService( - GlobalSettings globalSettings) + public class AzureQueueReferenceEventService : IReferenceEventService { - _queueClient = new QueueClient(globalSettings.Events.ConnectionString, _queueName); - _globalSettings = globalSettings; - } + private const string _queueName = "reference-events"; - public async Task RaiseEventAsync(ReferenceEvent referenceEvent) - { - await SendMessageAsync(referenceEvent); - } + private readonly QueueClient _queueClient; + private readonly GlobalSettings _globalSettings; - private async Task SendMessageAsync(ReferenceEvent referenceEvent) - { - if (_globalSettings.SelfHosted) + public AzureQueueReferenceEventService( + GlobalSettings globalSettings) { - // Ignore for self-hosted - return; + _queueClient = new QueueClient(globalSettings.Events.ConnectionString, _queueName); + _globalSettings = globalSettings; } - try + + public async Task RaiseEventAsync(ReferenceEvent referenceEvent) { - var message = JsonSerializer.Serialize(referenceEvent, JsonHelpers.IgnoreWritingNullAndCamelCase); - // Messages need to be base64 encoded - var encodedMessage = Convert.ToBase64String(Encoding.UTF8.GetBytes(message)); - await _queueClient.SendMessageAsync(encodedMessage); + await SendMessageAsync(referenceEvent); } - catch + + private async Task SendMessageAsync(ReferenceEvent referenceEvent) { - // Ignore failure + if (_globalSettings.SelfHosted) + { + // Ignore for self-hosted + return; + } + try + { + var message = JsonSerializer.Serialize(referenceEvent, JsonHelpers.IgnoreWritingNullAndCamelCase); + // Messages need to be base64 encoded + var encodedMessage = Convert.ToBase64String(Encoding.UTF8.GetBytes(message)); + await _queueClient.SendMessageAsync(encodedMessage); + } + catch + { + // Ignore failure + } } } } diff --git a/src/Core/Services/Implementations/AzureQueueService.cs b/src/Core/Services/Implementations/AzureQueueService.cs index 11c1a58ae3..942be2680e 100644 --- a/src/Core/Services/Implementations/AzureQueueService.cs +++ b/src/Core/Services/Implementations/AzureQueueService.cs @@ -3,75 +3,76 @@ using System.Text.Json; using Azure.Storage.Queues; using Bit.Core.Utilities; -namespace Bit.Core.Services; - -public abstract class AzureQueueService +namespace Bit.Core.Services { - protected QueueClient _queueClient; - protected JsonSerializerOptions _jsonOptions; - - protected AzureQueueService(QueueClient queueClient, JsonSerializerOptions jsonOptions) + public abstract class AzureQueueService { - _queueClient = queueClient; - _jsonOptions = jsonOptions; - } + protected QueueClient _queueClient; + protected JsonSerializerOptions _jsonOptions; - public async Task CreateManyAsync(IEnumerable messages) - { - if (messages?.Any() != true) + protected AzureQueueService(QueueClient queueClient, JsonSerializerOptions jsonOptions) { - return; + _queueClient = queueClient; + _jsonOptions = jsonOptions; } - foreach (var json in SerializeMany(messages, _jsonOptions)) + public async Task CreateManyAsync(IEnumerable messages) { - await _queueClient.SendMessageAsync(json); - } - } - - protected IEnumerable SerializeMany(IEnumerable messages, JsonSerializerOptions jsonOptions) - { - // Calculate Base-64 encoded text with padding - int getBase64Size(int byteCount) => ((4 * byteCount / 3) + 3) & ~3; - - var messagesList = new List(); - var messagesListSize = 0; - - int calculateByteSize(int totalSize, int toAdd) => - // Calculate the total length this would be w/ "[]" and commas - getBase64Size(totalSize + toAdd + messagesList.Count + 2); - - // Format the final array string, i.e. [{...},{...}] - string getArrayString() - { - if (messagesList.Count == 1) + if (messages?.Any() != true) { - return CoreHelpers.Base64EncodeString(messagesList[0]); + return; + } + + foreach (var json in SerializeMany(messages, _jsonOptions)) + { + await _queueClient.SendMessageAsync(json); } - return CoreHelpers.Base64EncodeString( - string.Concat("[", string.Join(',', messagesList), "]")); } - var serializedMessages = messages.Select(message => - JsonSerializer.Serialize(message, jsonOptions)); - - foreach (var message in serializedMessages) + protected IEnumerable SerializeMany(IEnumerable messages, JsonSerializerOptions jsonOptions) { - var messageSize = Encoding.UTF8.GetByteCount(message); - if (calculateByteSize(messagesListSize, messageSize) > _queueClient.MessageMaxBytes) + // Calculate Base-64 encoded text with padding + int getBase64Size(int byteCount) => ((4 * byteCount / 3) + 3) & ~3; + + var messagesList = new List(); + var messagesListSize = 0; + + int calculateByteSize(int totalSize, int toAdd) => + // Calculate the total length this would be w/ "[]" and commas + getBase64Size(totalSize + toAdd + messagesList.Count + 2); + + // Format the final array string, i.e. [{...},{...}] + string getArrayString() + { + if (messagesList.Count == 1) + { + return CoreHelpers.Base64EncodeString(messagesList[0]); + } + return CoreHelpers.Base64EncodeString( + string.Concat("[", string.Join(',', messagesList), "]")); + } + + var serializedMessages = messages.Select(message => + JsonSerializer.Serialize(message, jsonOptions)); + + foreach (var message in serializedMessages) + { + var messageSize = Encoding.UTF8.GetByteCount(message); + if (calculateByteSize(messagesListSize, messageSize) > _queueClient.MessageMaxBytes) + { + yield return getArrayString(); + messagesListSize = 0; + messagesList.Clear(); + } + + messagesList.Add(message); + messagesListSize += messageSize; + } + + if (messagesList.Any()) { yield return getArrayString(); - messagesListSize = 0; - messagesList.Clear(); } - - messagesList.Add(message); - messagesListSize += messageSize; - } - - if (messagesList.Any()) - { - yield return getArrayString(); } } } diff --git a/src/Core/Services/Implementations/AzureSendFileStorageService.cs b/src/Core/Services/Implementations/AzureSendFileStorageService.cs index d1d7822f28..94a0aaaee5 100644 --- a/src/Core/Services/Implementations/AzureSendFileStorageService.cs +++ b/src/Core/Services/Implementations/AzureSendFileStorageService.cs @@ -6,136 +6,137 @@ using Bit.Core.Enums; using Bit.Core.Settings; using Microsoft.Extensions.Logging; -namespace Bit.Core.Services; - -public class AzureSendFileStorageService : ISendFileStorageService +namespace Bit.Core.Services { - public const string FilesContainerName = "sendfiles"; - private static readonly TimeSpan _downloadLinkLiveTime = TimeSpan.FromMinutes(1); - private readonly BlobServiceClient _blobServiceClient; - private readonly ILogger _logger; - private BlobContainerClient _sendFilesContainerClient; - - public FileUploadType FileUploadType => FileUploadType.Azure; - - public static string SendIdFromBlobName(string blobName) => blobName.Split('/')[0]; - public static string BlobName(Send send, string fileId) => $"{send.Id}/{fileId}"; - - public AzureSendFileStorageService( - GlobalSettings globalSettings, - ILogger logger) + public class AzureSendFileStorageService : ISendFileStorageService { - _blobServiceClient = new BlobServiceClient(globalSettings.Send.ConnectionString); - _logger = logger; - } + public const string FilesContainerName = "sendfiles"; + private static readonly TimeSpan _downloadLinkLiveTime = TimeSpan.FromMinutes(1); + private readonly BlobServiceClient _blobServiceClient; + private readonly ILogger _logger; + private BlobContainerClient _sendFilesContainerClient; - public async Task UploadNewFileAsync(Stream stream, Send send, string fileId) - { - await InitAsync(); + public FileUploadType FileUploadType => FileUploadType.Azure; - var blobClient = _sendFilesContainerClient.GetBlobClient(BlobName(send, fileId)); + public static string SendIdFromBlobName(string blobName) => blobName.Split('/')[0]; + public static string BlobName(Send send, string fileId) => $"{send.Id}/{fileId}"; - var metadata = new Dictionary(); - if (send.UserId.HasValue) + public AzureSendFileStorageService( + GlobalSettings globalSettings, + ILogger logger) { - metadata.Add("userId", send.UserId.Value.ToString()); - } - else - { - metadata.Add("organizationId", send.OrganizationId.Value.ToString()); + _blobServiceClient = new BlobServiceClient(globalSettings.Send.ConnectionString); + _logger = logger; } - var headers = new BlobHttpHeaders + public async Task UploadNewFileAsync(Stream stream, Send send, string fileId) { - ContentDisposition = $"attachment; filename=\"{fileId}\"" - }; + await InitAsync(); - await blobClient.UploadAsync(stream, new BlobUploadOptions { Metadata = metadata, HttpHeaders = headers }); - } - - public async Task DeleteFileAsync(Send send, string fileId) => await DeleteBlobAsync(BlobName(send, fileId)); - - public async Task DeleteBlobAsync(string blobName) - { - await InitAsync(); - var blobClient = _sendFilesContainerClient.GetBlobClient(blobName); - await blobClient.DeleteIfExistsAsync(); - } - - public async Task DeleteFilesForOrganizationAsync(Guid organizationId) - { - await InitAsync(); - } - - public async Task DeleteFilesForUserAsync(Guid userId) - { - await InitAsync(); - } - - public async Task GetSendFileDownloadUrlAsync(Send send, string fileId) - { - await InitAsync(); - var blobClient = _sendFilesContainerClient.GetBlobClient(BlobName(send, fileId)); - var sasUri = blobClient.GenerateSasUri(BlobSasPermissions.Read, DateTime.UtcNow.Add(_downloadLinkLiveTime)); - return sasUri.ToString(); - } - - public async Task GetSendFileUploadUrlAsync(Send send, string fileId) - { - await InitAsync(); - var blobClient = _sendFilesContainerClient.GetBlobClient(BlobName(send, fileId)); - var sasUri = blobClient.GenerateSasUri(BlobSasPermissions.Create | BlobSasPermissions.Write, DateTime.UtcNow.Add(_downloadLinkLiveTime)); - return sasUri.ToString(); - } - - public async Task<(bool, long?)> ValidateFileAsync(Send send, string fileId, long expectedFileSize, long leeway) - { - await InitAsync(); - - var blobClient = _sendFilesContainerClient.GetBlobClient(BlobName(send, fileId)); - - try - { - var blobProperties = await blobClient.GetPropertiesAsync(); - var metadata = blobProperties.Value.Metadata; + var blobClient = _sendFilesContainerClient.GetBlobClient(BlobName(send, fileId)); + var metadata = new Dictionary(); if (send.UserId.HasValue) { - metadata["userId"] = send.UserId.Value.ToString(); + metadata.Add("userId", send.UserId.Value.ToString()); } else { - metadata["organizationId"] = send.OrganizationId.Value.ToString(); + metadata.Add("organizationId", send.OrganizationId.Value.ToString()); } - await blobClient.SetMetadataAsync(metadata); var headers = new BlobHttpHeaders { ContentDisposition = $"attachment; filename=\"{fileId}\"" }; - await blobClient.SetHttpHeadersAsync(headers); - var length = blobProperties.Value.ContentLength; - if (length < expectedFileSize - leeway || length > expectedFileSize + leeway) + await blobClient.UploadAsync(stream, new BlobUploadOptions { Metadata = metadata, HttpHeaders = headers }); + } + + public async Task DeleteFileAsync(Send send, string fileId) => await DeleteBlobAsync(BlobName(send, fileId)); + + public async Task DeleteBlobAsync(string blobName) + { + await InitAsync(); + var blobClient = _sendFilesContainerClient.GetBlobClient(blobName); + await blobClient.DeleteIfExistsAsync(); + } + + public async Task DeleteFilesForOrganizationAsync(Guid organizationId) + { + await InitAsync(); + } + + public async Task DeleteFilesForUserAsync(Guid userId) + { + await InitAsync(); + } + + public async Task GetSendFileDownloadUrlAsync(Send send, string fileId) + { + await InitAsync(); + var blobClient = _sendFilesContainerClient.GetBlobClient(BlobName(send, fileId)); + var sasUri = blobClient.GenerateSasUri(BlobSasPermissions.Read, DateTime.UtcNow.Add(_downloadLinkLiveTime)); + return sasUri.ToString(); + } + + public async Task GetSendFileUploadUrlAsync(Send send, string fileId) + { + await InitAsync(); + var blobClient = _sendFilesContainerClient.GetBlobClient(BlobName(send, fileId)); + var sasUri = blobClient.GenerateSasUri(BlobSasPermissions.Create | BlobSasPermissions.Write, DateTime.UtcNow.Add(_downloadLinkLiveTime)); + return sasUri.ToString(); + } + + public async Task<(bool, long?)> ValidateFileAsync(Send send, string fileId, long expectedFileSize, long leeway) + { + await InitAsync(); + + var blobClient = _sendFilesContainerClient.GetBlobClient(BlobName(send, fileId)); + + try { - return (false, length); + var blobProperties = await blobClient.GetPropertiesAsync(); + var metadata = blobProperties.Value.Metadata; + + if (send.UserId.HasValue) + { + metadata["userId"] = send.UserId.Value.ToString(); + } + else + { + metadata["organizationId"] = send.OrganizationId.Value.ToString(); + } + await blobClient.SetMetadataAsync(metadata); + + var headers = new BlobHttpHeaders + { + ContentDisposition = $"attachment; filename=\"{fileId}\"" + }; + await blobClient.SetHttpHeadersAsync(headers); + + var length = blobProperties.Value.ContentLength; + if (length < expectedFileSize - leeway || length > expectedFileSize + leeway) + { + return (false, length); + } + + return (true, length); + } + catch (Exception ex) + { + _logger.LogError(ex, "Unhandled error in ValidateFileAsync"); + return (false, null); } - - return (true, length); } - catch (Exception ex) - { - _logger.LogError(ex, "Unhandled error in ValidateFileAsync"); - return (false, null); - } - } - private async Task InitAsync() - { - if (_sendFilesContainerClient == null) + private async Task InitAsync() { - _sendFilesContainerClient = _blobServiceClient.GetBlobContainerClient(FilesContainerName); - await _sendFilesContainerClient.CreateIfNotExistsAsync(PublicAccessType.None, null, null); + if (_sendFilesContainerClient == null) + { + _sendFilesContainerClient = _blobServiceClient.GetBlobContainerClient(FilesContainerName); + await _sendFilesContainerClient.CreateIfNotExistsAsync(PublicAccessType.None, null, null); + } } } } diff --git a/src/Core/Services/Implementations/BaseIdentityClientService.cs b/src/Core/Services/Implementations/BaseIdentityClientService.cs index fd9be533b3..2115eba243 100644 --- a/src/Core/Services/Implementations/BaseIdentityClientService.cs +++ b/src/Core/Services/Implementations/BaseIdentityClientService.cs @@ -5,202 +5,203 @@ using System.Text.Json; using Bit.Core.Utilities; using Microsoft.Extensions.Logging; -namespace Bit.Core.Services; - -public abstract class BaseIdentityClientService : IDisposable +namespace Bit.Core.Services { - private readonly IHttpClientFactory _httpFactory; - private readonly string _identityScope; - private readonly string _identityClientId; - private readonly string _identityClientSecret; - protected readonly ILogger _logger; - - private JsonDocument _decodedToken; - private DateTime? _nextAuthAttempt = null; - - public BaseIdentityClientService( - IHttpClientFactory httpFactory, - string baseClientServerUri, - string baseIdentityServerUri, - string identityScope, - string identityClientId, - string identityClientSecret, - ILogger logger) + public abstract class BaseIdentityClientService : IDisposable { - _httpFactory = httpFactory; - _identityScope = identityScope; - _identityClientId = identityClientId; - _identityClientSecret = identityClientSecret; - _logger = logger; + private readonly IHttpClientFactory _httpFactory; + private readonly string _identityScope; + private readonly string _identityClientId; + private readonly string _identityClientSecret; + protected readonly ILogger _logger; - Client = _httpFactory.CreateClient("client"); - Client.BaseAddress = new Uri(baseClientServerUri); - Client.DefaultRequestHeaders.Accept.Add(new MediaTypeWithQualityHeaderValue("application/json")); + private JsonDocument _decodedToken; + private DateTime? _nextAuthAttempt = null; - IdentityClient = _httpFactory.CreateClient("identity"); - IdentityClient.BaseAddress = new Uri(baseIdentityServerUri); - IdentityClient.DefaultRequestHeaders.Accept.Add(new MediaTypeWithQualityHeaderValue("application/json")); - } - - protected HttpClient Client { get; private set; } - protected HttpClient IdentityClient { get; private set; } - protected string AccessToken { get; private set; } - - protected Task SendAsync(HttpMethod method, string path) => - SendAsync(method, path, null); - - protected Task SendAsync(HttpMethod method, string path, TRequest body) => - SendAsync(method, path, body); - - protected async Task SendAsync(HttpMethod method, string path, TRequest requestModel) - { - var tokenStateResponse = await HandleTokenStateAsync(); - if (!tokenStateResponse) + public BaseIdentityClientService( + IHttpClientFactory httpFactory, + string baseClientServerUri, + string baseIdentityServerUri, + string identityScope, + string identityClientId, + string identityClientSecret, + ILogger logger) { - return default; + _httpFactory = httpFactory; + _identityScope = identityScope; + _identityClientId = identityClientId; + _identityClientSecret = identityClientSecret; + _logger = logger; + + Client = _httpFactory.CreateClient("client"); + Client.BaseAddress = new Uri(baseClientServerUri); + Client.DefaultRequestHeaders.Accept.Add(new MediaTypeWithQualityHeaderValue("application/json")); + + IdentityClient = _httpFactory.CreateClient("identity"); + IdentityClient.BaseAddress = new Uri(baseIdentityServerUri); + IdentityClient.DefaultRequestHeaders.Accept.Add(new MediaTypeWithQualityHeaderValue("application/json")); } - var message = new TokenHttpRequestMessage(requestModel, AccessToken) - { - Method = method, - RequestUri = new Uri(string.Concat(Client.BaseAddress, path)) - }; - try - { - var response = await Client.SendAsync(message); - return await response.Content.ReadFromJsonAsync(); - } - catch (Exception e) - { - _logger.LogError(12334, e, "Failed to send to {0}.", message.RequestUri.ToString()); - return default; - } - } + protected HttpClient Client { get; private set; } + protected HttpClient IdentityClient { get; private set; } + protected string AccessToken { get; private set; } - protected async Task HandleTokenStateAsync() - { - if (_nextAuthAttempt.HasValue && DateTime.UtcNow > _nextAuthAttempt.Value) - { - return false; - } - _nextAuthAttempt = null; + protected Task SendAsync(HttpMethod method, string path) => + SendAsync(method, path, null); - if (!string.IsNullOrWhiteSpace(AccessToken) && !TokenNeedsRefresh()) + protected Task SendAsync(HttpMethod method, string path, TRequest body) => + SendAsync(method, path, body); + + protected async Task SendAsync(HttpMethod method, string path, TRequest requestModel) { + var tokenStateResponse = await HandleTokenStateAsync(); + if (!tokenStateResponse) + { + return default; + } + + var message = new TokenHttpRequestMessage(requestModel, AccessToken) + { + Method = method, + RequestUri = new Uri(string.Concat(Client.BaseAddress, path)) + }; + try + { + var response = await Client.SendAsync(message); + return await response.Content.ReadFromJsonAsync(); + } + catch (Exception e) + { + _logger.LogError(12334, e, "Failed to send to {0}.", message.RequestUri.ToString()); + return default; + } + } + + protected async Task HandleTokenStateAsync() + { + if (_nextAuthAttempt.HasValue && DateTime.UtcNow > _nextAuthAttempt.Value) + { + return false; + } + _nextAuthAttempt = null; + + if (!string.IsNullOrWhiteSpace(AccessToken) && !TokenNeedsRefresh()) + { + return true; + } + + var requestMessage = new HttpRequestMessage + { + Method = HttpMethod.Post, + RequestUri = new Uri(string.Concat(IdentityClient.BaseAddress, "connect/token")), + Content = new FormUrlEncodedContent(new Dictionary + { + { "grant_type", "client_credentials" }, + { "scope", _identityScope }, + { "client_id", _identityClientId }, + { "client_secret", _identityClientSecret } + }) + }; + + HttpResponseMessage response = null; + try + { + response = await IdentityClient.SendAsync(requestMessage); + } + catch (Exception e) + { + _logger.LogError(12339, e, "Unable to authenticate with identity server."); + } + + if (response == null) + { + return false; + } + + if (!response.IsSuccessStatusCode) + { + _logger.LogInformation("Unsuccessful token response with status code {StatusCode}", response.StatusCode); + + if (response.StatusCode == HttpStatusCode.BadRequest) + { + _nextAuthAttempt = DateTime.UtcNow.AddDays(1); + } + + if (_logger.IsEnabled(LogLevel.Debug)) + { + var responseBody = await response.Content.ReadAsStringAsync(); + _logger.LogDebug("Error response body:\n{ResponseBody}", responseBody); + } + + return false; + } + + using var jsonDocument = await JsonDocument.ParseAsync(await response.Content.ReadAsStreamAsync()); + + AccessToken = jsonDocument.RootElement.GetProperty("access_token").GetString(); return true; } - var requestMessage = new HttpRequestMessage + protected class TokenHttpRequestMessage : HttpRequestMessage { - Method = HttpMethod.Post, - RequestUri = new Uri(string.Concat(IdentityClient.BaseAddress, "connect/token")), - Content = new FormUrlEncodedContent(new Dictionary + public TokenHttpRequestMessage(string token) { - { "grant_type", "client_credentials" }, - { "scope", _identityScope }, - { "client_id", _identityClientId }, - { "client_secret", _identityClientSecret } - }) - }; - - HttpResponseMessage response = null; - try - { - response = await IdentityClient.SendAsync(requestMessage); - } - catch (Exception e) - { - _logger.LogError(12339, e, "Unable to authenticate with identity server."); - } - - if (response == null) - { - return false; - } - - if (!response.IsSuccessStatusCode) - { - _logger.LogInformation("Unsuccessful token response with status code {StatusCode}", response.StatusCode); - - if (response.StatusCode == HttpStatusCode.BadRequest) - { - _nextAuthAttempt = DateTime.UtcNow.AddDays(1); + Headers.Add("Authorization", $"Bearer {token}"); } - if (_logger.IsEnabled(LogLevel.Debug)) + public TokenHttpRequestMessage(object requestObject, string token) + : this(token) { - var responseBody = await response.Content.ReadAsStringAsync(); - _logger.LogDebug("Error response body:\n{ResponseBody}", responseBody); - } - - return false; - } - - using var jsonDocument = await JsonDocument.ParseAsync(await response.Content.ReadAsStreamAsync()); - - AccessToken = jsonDocument.RootElement.GetProperty("access_token").GetString(); - return true; - } - - protected class TokenHttpRequestMessage : HttpRequestMessage - { - public TokenHttpRequestMessage(string token) - { - Headers.Add("Authorization", $"Bearer {token}"); - } - - public TokenHttpRequestMessage(object requestObject, string token) - : this(token) - { - if (requestObject != null) - { - Content = JsonContent.Create(requestObject); + if (requestObject != null) + { + Content = JsonContent.Create(requestObject); + } } } - } - protected bool TokenNeedsRefresh(int minutes = 5) - { - var decoded = DecodeToken(); - if (!decoded.RootElement.TryGetProperty("exp", out var expProp)) + protected bool TokenNeedsRefresh(int minutes = 5) { - throw new InvalidOperationException("No exp in token."); + var decoded = DecodeToken(); + if (!decoded.RootElement.TryGetProperty("exp", out var expProp)) + { + throw new InvalidOperationException("No exp in token."); + } + + var expiration = CoreHelpers.FromEpocSeconds(expProp.GetInt64()); + return DateTime.UtcNow.AddMinutes(-1 * minutes) > expiration; } - var expiration = CoreHelpers.FromEpocSeconds(expProp.GetInt64()); - return DateTime.UtcNow.AddMinutes(-1 * minutes) > expiration; - } - - protected JsonDocument DecodeToken() - { - if (_decodedToken != null) + protected JsonDocument DecodeToken() { + if (_decodedToken != null) + { + return _decodedToken; + } + + if (AccessToken == null) + { + throw new InvalidOperationException($"{nameof(AccessToken)} not found."); + } + + var parts = AccessToken.Split('.'); + if (parts.Length != 3) + { + throw new InvalidOperationException($"{nameof(AccessToken)} must have 3 parts"); + } + + var decodedBytes = CoreHelpers.Base64UrlDecode(parts[1]); + if (decodedBytes == null || decodedBytes.Length < 1) + { + throw new InvalidOperationException($"{nameof(AccessToken)} must have 3 parts"); + } + + _decodedToken = JsonDocument.Parse(decodedBytes); return _decodedToken; } - if (AccessToken == null) + public void Dispose() { - throw new InvalidOperationException($"{nameof(AccessToken)} not found."); + _decodedToken?.Dispose(); } - - var parts = AccessToken.Split('.'); - if (parts.Length != 3) - { - throw new InvalidOperationException($"{nameof(AccessToken)} must have 3 parts"); - } - - var decodedBytes = CoreHelpers.Base64UrlDecode(parts[1]); - if (decodedBytes == null || decodedBytes.Length < 1) - { - throw new InvalidOperationException($"{nameof(AccessToken)} must have 3 parts"); - } - - _decodedToken = JsonDocument.Parse(decodedBytes); - return _decodedToken; - } - - public void Dispose() - { - _decodedToken?.Dispose(); } } diff --git a/src/Core/Services/Implementations/BlockingMailQueueService.cs b/src/Core/Services/Implementations/BlockingMailQueueService.cs index 0323b09af7..0a1a99b858 100644 --- a/src/Core/Services/Implementations/BlockingMailQueueService.cs +++ b/src/Core/Services/Implementations/BlockingMailQueueService.cs @@ -1,19 +1,20 @@ using Bit.Core.Models.Mail; -namespace Bit.Core.Services; - -public class BlockingMailEnqueuingService : IMailEnqueuingService +namespace Bit.Core.Services { - public async Task EnqueueAsync(IMailQueueMessage message, Func fallback) + public class BlockingMailEnqueuingService : IMailEnqueuingService { - await fallback(message); - } - - public async Task EnqueueManyAsync(IEnumerable messages, Func fallback) - { - foreach (var message in messages) + public async Task EnqueueAsync(IMailQueueMessage message, Func fallback) { await fallback(message); } + + public async Task EnqueueManyAsync(IEnumerable messages, Func fallback) + { + foreach (var message in messages) + { + await fallback(message); + } + } } } diff --git a/src/Core/Services/Implementations/CipherService.cs b/src/Core/Services/Implementations/CipherService.cs index e2679e6283..dfa156974f 100644 --- a/src/Core/Services/Implementations/CipherService.cs +++ b/src/Core/Services/Implementations/CipherService.cs @@ -10,1021 +10,1022 @@ using Bit.Core.Settings; using Bit.Core.Utilities; using Core.Models.Data; -namespace Bit.Core.Services; - -public class CipherService : ICipherService +namespace Bit.Core.Services { - public const long MAX_FILE_SIZE = Constants.FileSize501mb; - public const string MAX_FILE_SIZE_READABLE = "500 MB"; - private readonly ICipherRepository _cipherRepository; - private readonly IFolderRepository _folderRepository; - private readonly ICollectionRepository _collectionRepository; - private readonly IUserRepository _userRepository; - private readonly IOrganizationRepository _organizationRepository; - private readonly ICollectionCipherRepository _collectionCipherRepository; - private readonly IPushNotificationService _pushService; - private readonly IAttachmentStorageService _attachmentStorageService; - private readonly IEventService _eventService; - private readonly IUserService _userService; - private readonly IPolicyRepository _policyRepository; - private readonly GlobalSettings _globalSettings; - private const long _fileSizeLeeway = 1024L * 1024L; // 1MB - private readonly IReferenceEventService _referenceEventService; - private readonly ICurrentContext _currentContext; - private readonly IProviderService _providerService; - - public CipherService( - ICipherRepository cipherRepository, - IFolderRepository folderRepository, - ICollectionRepository collectionRepository, - IUserRepository userRepository, - IOrganizationRepository organizationRepository, - ICollectionCipherRepository collectionCipherRepository, - IPushNotificationService pushService, - IAttachmentStorageService attachmentStorageService, - IEventService eventService, - IUserService userService, - IPolicyRepository policyRepository, - GlobalSettings globalSettings, - IReferenceEventService referenceEventService, - ICurrentContext currentContext) + public class CipherService : ICipherService { - _cipherRepository = cipherRepository; - _folderRepository = folderRepository; - _collectionRepository = collectionRepository; - _userRepository = userRepository; - _organizationRepository = organizationRepository; - _collectionCipherRepository = collectionCipherRepository; - _pushService = pushService; - _attachmentStorageService = attachmentStorageService; - _eventService = eventService; - _userService = userService; - _policyRepository = policyRepository; - _globalSettings = globalSettings; - _referenceEventService = referenceEventService; - _currentContext = currentContext; - } + public const long MAX_FILE_SIZE = Constants.FileSize501mb; + public const string MAX_FILE_SIZE_READABLE = "500 MB"; + private readonly ICipherRepository _cipherRepository; + private readonly IFolderRepository _folderRepository; + private readonly ICollectionRepository _collectionRepository; + private readonly IUserRepository _userRepository; + private readonly IOrganizationRepository _organizationRepository; + private readonly ICollectionCipherRepository _collectionCipherRepository; + private readonly IPushNotificationService _pushService; + private readonly IAttachmentStorageService _attachmentStorageService; + private readonly IEventService _eventService; + private readonly IUserService _userService; + private readonly IPolicyRepository _policyRepository; + private readonly GlobalSettings _globalSettings; + private const long _fileSizeLeeway = 1024L * 1024L; // 1MB + private readonly IReferenceEventService _referenceEventService; + private readonly ICurrentContext _currentContext; + private readonly IProviderService _providerService; - public async Task SaveAsync(Cipher cipher, Guid savingUserId, DateTime? lastKnownRevisionDate, - IEnumerable collectionIds = null, bool skipPermissionCheck = false, bool limitCollectionScope = true) - { - if (!skipPermissionCheck && !(await UserCanEditAsync(cipher, savingUserId))) + public CipherService( + ICipherRepository cipherRepository, + IFolderRepository folderRepository, + ICollectionRepository collectionRepository, + IUserRepository userRepository, + IOrganizationRepository organizationRepository, + ICollectionCipherRepository collectionCipherRepository, + IPushNotificationService pushService, + IAttachmentStorageService attachmentStorageService, + IEventService eventService, + IUserService userService, + IPolicyRepository policyRepository, + GlobalSettings globalSettings, + IReferenceEventService referenceEventService, + ICurrentContext currentContext) { - throw new BadRequestException("You do not have permissions to edit this."); + _cipherRepository = cipherRepository; + _folderRepository = folderRepository; + _collectionRepository = collectionRepository; + _userRepository = userRepository; + _organizationRepository = organizationRepository; + _collectionCipherRepository = collectionCipherRepository; + _pushService = pushService; + _attachmentStorageService = attachmentStorageService; + _eventService = eventService; + _userService = userService; + _policyRepository = policyRepository; + _globalSettings = globalSettings; + _referenceEventService = referenceEventService; + _currentContext = currentContext; } - if (cipher.Id == default(Guid)) + public async Task SaveAsync(Cipher cipher, Guid savingUserId, DateTime? lastKnownRevisionDate, + IEnumerable collectionIds = null, bool skipPermissionCheck = false, bool limitCollectionScope = true) { - if (cipher.OrganizationId.HasValue && collectionIds != null) + if (!skipPermissionCheck && !(await UserCanEditAsync(cipher, savingUserId))) { - if (limitCollectionScope) - { - // Set user ID to limit scope of collection ids in the create sproc - cipher.UserId = savingUserId; - } - await _cipherRepository.CreateAsync(cipher, collectionIds); + throw new BadRequestException("You do not have permissions to edit this."); + } - await _referenceEventService.RaiseEventAsync( - new ReferenceEvent(ReferenceEventType.CipherCreated, await _organizationRepository.GetByIdAsync(cipher.OrganizationId.Value))); + if (cipher.Id == default(Guid)) + { + if (cipher.OrganizationId.HasValue && collectionIds != null) + { + if (limitCollectionScope) + { + // Set user ID to limit scope of collection ids in the create sproc + cipher.UserId = savingUserId; + } + await _cipherRepository.CreateAsync(cipher, collectionIds); + + await _referenceEventService.RaiseEventAsync( + new ReferenceEvent(ReferenceEventType.CipherCreated, await _organizationRepository.GetByIdAsync(cipher.OrganizationId.Value))); + } + else + { + await _cipherRepository.CreateAsync(cipher); + } + await _eventService.LogCipherEventAsync(cipher, Enums.EventType.Cipher_Created); + + // push + await _pushService.PushSyncCipherCreateAsync(cipher, null); } else { - await _cipherRepository.CreateAsync(cipher); + ValidateCipherLastKnownRevisionDateAsync(cipher, lastKnownRevisionDate); + cipher.RevisionDate = DateTime.UtcNow; + await _cipherRepository.ReplaceAsync(cipher); + await _eventService.LogCipherEventAsync(cipher, Enums.EventType.Cipher_Updated); + + // push + await _pushService.PushSyncCipherUpdateAsync(cipher, collectionIds); } - await _eventService.LogCipherEventAsync(cipher, Enums.EventType.Cipher_Created); - - // push - await _pushService.PushSyncCipherCreateAsync(cipher, null); - } - else - { - ValidateCipherLastKnownRevisionDateAsync(cipher, lastKnownRevisionDate); - cipher.RevisionDate = DateTime.UtcNow; - await _cipherRepository.ReplaceAsync(cipher); - await _eventService.LogCipherEventAsync(cipher, Enums.EventType.Cipher_Updated); - - // push - await _pushService.PushSyncCipherUpdateAsync(cipher, collectionIds); - } - } - - public async Task SaveDetailsAsync(CipherDetails cipher, Guid savingUserId, DateTime? lastKnownRevisionDate, - IEnumerable collectionIds = null, bool skipPermissionCheck = false) - { - if (!skipPermissionCheck && !(await UserCanEditAsync(cipher, savingUserId))) - { - throw new BadRequestException("You do not have permissions to edit this."); } - cipher.UserId = savingUserId; - if (cipher.Id == default(Guid)) + public async Task SaveDetailsAsync(CipherDetails cipher, Guid savingUserId, DateTime? lastKnownRevisionDate, + IEnumerable collectionIds = null, bool skipPermissionCheck = false) { - if (cipher.OrganizationId.HasValue && collectionIds != null) + if (!skipPermissionCheck && !(await UserCanEditAsync(cipher, savingUserId))) { - var existingCollectionIds = (await _collectionRepository.GetManyByOrganizationIdAsync(cipher.OrganizationId.Value)).Select(c => c.Id); - if (collectionIds.Except(existingCollectionIds).Any()) + throw new BadRequestException("You do not have permissions to edit this."); + } + + cipher.UserId = savingUserId; + if (cipher.Id == default(Guid)) + { + if (cipher.OrganizationId.HasValue && collectionIds != null) { - throw new BadRequestException("Specified CollectionId does not exist on the specified Organization."); + var existingCollectionIds = (await _collectionRepository.GetManyByOrganizationIdAsync(cipher.OrganizationId.Value)).Select(c => c.Id); + if (collectionIds.Except(existingCollectionIds).Any()) + { + throw new BadRequestException("Specified CollectionId does not exist on the specified Organization."); + } + await _cipherRepository.CreateAsync(cipher, collectionIds); } - await _cipherRepository.CreateAsync(cipher, collectionIds); + else + { + // Make sure the user can save new ciphers to their personal vault + var personalOwnershipPolicyCount = await _policyRepository.GetCountByTypeApplicableToUserIdAsync(savingUserId, + PolicyType.PersonalOwnership); + if (personalOwnershipPolicyCount > 0) + { + throw new BadRequestException("Due to an Enterprise Policy, you are restricted from saving items to your personal vault."); + } + await _cipherRepository.CreateAsync(cipher); + } + await _eventService.LogCipherEventAsync(cipher, Enums.EventType.Cipher_Created); + + if (cipher.OrganizationId.HasValue) + { + var org = await _organizationRepository.GetByIdAsync(cipher.OrganizationId.Value); + cipher.OrganizationUseTotp = org.UseTotp; + } + + // push + await _pushService.PushSyncCipherCreateAsync(cipher, null); } else { - // Make sure the user can save new ciphers to their personal vault - var personalOwnershipPolicyCount = await _policyRepository.GetCountByTypeApplicableToUserIdAsync(savingUserId, - PolicyType.PersonalOwnership); - if (personalOwnershipPolicyCount > 0) - { - throw new BadRequestException("Due to an Enterprise Policy, you are restricted from saving items to your personal vault."); - } - await _cipherRepository.CreateAsync(cipher); - } - await _eventService.LogCipherEventAsync(cipher, Enums.EventType.Cipher_Created); + ValidateCipherLastKnownRevisionDateAsync(cipher, lastKnownRevisionDate); + cipher.RevisionDate = DateTime.UtcNow; + await _cipherRepository.ReplaceAsync(cipher); + await _eventService.LogCipherEventAsync(cipher, Enums.EventType.Cipher_Updated); - if (cipher.OrganizationId.HasValue) + // push + await _pushService.PushSyncCipherUpdateAsync(cipher, collectionIds); + } + } + + public async Task UploadFileForExistingAttachmentAsync(Stream stream, Cipher cipher, CipherAttachment.MetaData attachment) + { + if (attachment == null) { - var org = await _organizationRepository.GetByIdAsync(cipher.OrganizationId.Value); - cipher.OrganizationUseTotp = org.UseTotp; + throw new BadRequestException("Cipher attachment does not exist"); } - // push - await _pushService.PushSyncCipherCreateAsync(cipher, null); - } - else - { - ValidateCipherLastKnownRevisionDateAsync(cipher, lastKnownRevisionDate); - cipher.RevisionDate = DateTime.UtcNow; - await _cipherRepository.ReplaceAsync(cipher); - await _eventService.LogCipherEventAsync(cipher, Enums.EventType.Cipher_Updated); + await _attachmentStorageService.UploadNewAttachmentAsync(stream, cipher, attachment); - // push - await _pushService.PushSyncCipherUpdateAsync(cipher, collectionIds); - } - } - - public async Task UploadFileForExistingAttachmentAsync(Stream stream, Cipher cipher, CipherAttachment.MetaData attachment) - { - if (attachment == null) - { - throw new BadRequestException("Cipher attachment does not exist"); + if (!await ValidateCipherAttachmentFile(cipher, attachment)) + { + throw new BadRequestException("File received does not match expected file length."); + } } - await _attachmentStorageService.UploadNewAttachmentAsync(stream, cipher, attachment); - - if (!await ValidateCipherAttachmentFile(cipher, attachment)) + public async Task<(string attachmentId, string uploadUrl)> CreateAttachmentForDelayedUploadAsync(Cipher cipher, + string key, string fileName, long fileSize, bool adminRequest, Guid savingUserId) { - throw new BadRequestException("File received does not match expected file length."); - } - } + await ValidateCipherEditForAttachmentAsync(cipher, savingUserId, adminRequest, fileSize); - public async Task<(string attachmentId, string uploadUrl)> CreateAttachmentForDelayedUploadAsync(Cipher cipher, - string key, string fileName, long fileSize, bool adminRequest, Guid savingUserId) - { - await ValidateCipherEditForAttachmentAsync(cipher, savingUserId, adminRequest, fileSize); + var attachmentId = Utilities.CoreHelpers.SecureRandomString(32, upper: false, special: false); + var data = new CipherAttachment.MetaData + { + AttachmentId = attachmentId, + FileName = fileName, + Key = key, + Size = fileSize, + Validated = false, + }; - var attachmentId = Utilities.CoreHelpers.SecureRandomString(32, upper: false, special: false); - var data = new CipherAttachment.MetaData - { - AttachmentId = attachmentId, - FileName = fileName, - Key = key, - Size = fileSize, - Validated = false, - }; + var uploadUrl = await _attachmentStorageService.GetAttachmentUploadUrlAsync(cipher, data); - var uploadUrl = await _attachmentStorageService.GetAttachmentUploadUrlAsync(cipher, data); - - await _cipherRepository.UpdateAttachmentAsync(new CipherAttachment - { - Id = cipher.Id, - UserId = cipher.UserId, - OrganizationId = cipher.OrganizationId, - AttachmentId = attachmentId, - AttachmentData = JsonSerializer.Serialize(data) - }); - cipher.AddAttachment(attachmentId, data); - await _pushService.PushSyncCipherUpdateAsync(cipher, null); - - return (attachmentId, uploadUrl); - } - - public async Task CreateAttachmentAsync(Cipher cipher, Stream stream, string fileName, string key, - long requestLength, Guid savingUserId, bool orgAdmin = false) - { - await ValidateCipherEditForAttachmentAsync(cipher, savingUserId, orgAdmin, requestLength); - - var attachmentId = Utilities.CoreHelpers.SecureRandomString(32, upper: false, special: false); - var data = new CipherAttachment.MetaData - { - AttachmentId = attachmentId, - FileName = fileName, - Key = key, - }; - - await _attachmentStorageService.UploadNewAttachmentAsync(stream, cipher, data); - // Must read stream length after it has been saved, otherwise it's 0 - data.Size = stream.Length; - - try - { - var attachment = new CipherAttachment + await _cipherRepository.UpdateAttachmentAsync(new CipherAttachment { Id = cipher.Id, UserId = cipher.UserId, OrganizationId = cipher.OrganizationId, AttachmentId = attachmentId, AttachmentData = JsonSerializer.Serialize(data) + }); + cipher.AddAttachment(attachmentId, data); + await _pushService.PushSyncCipherUpdateAsync(cipher, null); + + return (attachmentId, uploadUrl); + } + + public async Task CreateAttachmentAsync(Cipher cipher, Stream stream, string fileName, string key, + long requestLength, Guid savingUserId, bool orgAdmin = false) + { + await ValidateCipherEditForAttachmentAsync(cipher, savingUserId, orgAdmin, requestLength); + + var attachmentId = Utilities.CoreHelpers.SecureRandomString(32, upper: false, special: false); + var data = new CipherAttachment.MetaData + { + AttachmentId = attachmentId, + FileName = fileName, + Key = key, }; - await _cipherRepository.UpdateAttachmentAsync(attachment); - await _eventService.LogCipherEventAsync(cipher, Enums.EventType.Cipher_AttachmentCreated); - cipher.AddAttachment(attachmentId, data); + await _attachmentStorageService.UploadNewAttachmentAsync(stream, cipher, data); + // Must read stream length after it has been saved, otherwise it's 0 + data.Size = stream.Length; - if (!await ValidateCipherAttachmentFile(cipher, data)) + try { - throw new Exception("Content-Length does not match uploaded file size"); + var attachment = new CipherAttachment + { + Id = cipher.Id, + UserId = cipher.UserId, + OrganizationId = cipher.OrganizationId, + AttachmentId = attachmentId, + AttachmentData = JsonSerializer.Serialize(data) + }; + + await _cipherRepository.UpdateAttachmentAsync(attachment); + await _eventService.LogCipherEventAsync(cipher, Enums.EventType.Cipher_AttachmentCreated); + cipher.AddAttachment(attachmentId, data); + + if (!await ValidateCipherAttachmentFile(cipher, data)) + { + throw new Exception("Content-Length does not match uploaded file size"); + } } - } - catch - { - // Clean up since this is not transactional - await _attachmentStorageService.DeleteAttachmentAsync(cipher.Id, data); - throw; + catch + { + // Clean up since this is not transactional + await _attachmentStorageService.DeleteAttachmentAsync(cipher.Id, data); + throw; + } + + // push + await _pushService.PushSyncCipherUpdateAsync(cipher, null); } - // push - await _pushService.PushSyncCipherUpdateAsync(cipher, null); - } - - public async Task CreateAttachmentShareAsync(Cipher cipher, Stream stream, long requestLength, - string attachmentId, Guid organizationId) - { - try + public async Task CreateAttachmentShareAsync(Cipher cipher, Stream stream, long requestLength, + string attachmentId, Guid organizationId) { - if (requestLength < 1) + try { - throw new BadRequestException("No data to attach."); - } + if (requestLength < 1) + { + throw new BadRequestException("No data to attach."); + } - if (cipher.Id == default(Guid)) + if (cipher.Id == default(Guid)) + { + throw new BadRequestException(nameof(cipher.Id)); + } + + if (cipher.OrganizationId.HasValue) + { + throw new BadRequestException("Cipher belongs to an organization already."); + } + + var org = await _organizationRepository.GetByIdAsync(organizationId); + if (org == null || !org.MaxStorageGb.HasValue) + { + throw new BadRequestException("This organization cannot use attachments."); + } + + var storageBytesRemaining = org.StorageBytesRemaining(); + if (storageBytesRemaining < requestLength) + { + throw new BadRequestException("Not enough storage available for this organization."); + } + + var attachments = cipher.GetAttachments(); + if (!attachments.ContainsKey(attachmentId)) + { + throw new BadRequestException($"Cipher does not own specified attachment"); + } + + await _attachmentStorageService.UploadShareAttachmentAsync(stream, cipher.Id, organizationId, + attachments[attachmentId]); + + // Previous call may alter metadata + var updatedAttachment = new CipherAttachment + { + Id = cipher.Id, + UserId = cipher.UserId, + OrganizationId = cipher.OrganizationId, + AttachmentId = attachmentId, + AttachmentData = JsonSerializer.Serialize(attachments[attachmentId]) + }; + + await _cipherRepository.UpdateAttachmentAsync(updatedAttachment); + } + catch { - throw new BadRequestException(nameof(cipher.Id)); + await _attachmentStorageService.CleanupAsync(cipher.Id); + throw; } + } - if (cipher.OrganizationId.HasValue) + public async Task ValidateCipherAttachmentFile(Cipher cipher, CipherAttachment.MetaData attachmentData) + { + var (valid, realSize) = await _attachmentStorageService.ValidateFileAsync(cipher, attachmentData, _fileSizeLeeway); + + if (!valid || realSize > MAX_FILE_SIZE) { - throw new BadRequestException("Cipher belongs to an organization already."); + // File reported differs in size from that promised. Must be a rogue client. Delete Send + await DeleteAttachmentAsync(cipher, attachmentData); + return false; } - - var org = await _organizationRepository.GetByIdAsync(organizationId); - if (org == null || !org.MaxStorageGb.HasValue) + // Update Send data if necessary + if (realSize != attachmentData.Size) { - throw new BadRequestException("This organization cannot use attachments."); + attachmentData.Size = realSize.Value; } + attachmentData.Validated = true; - var storageBytesRemaining = org.StorageBytesRemaining(); - if (storageBytesRemaining < requestLength) - { - throw new BadRequestException("Not enough storage available for this organization."); - } - - var attachments = cipher.GetAttachments(); - if (!attachments.ContainsKey(attachmentId)) - { - throw new BadRequestException($"Cipher does not own specified attachment"); - } - - await _attachmentStorageService.UploadShareAttachmentAsync(stream, cipher.Id, organizationId, - attachments[attachmentId]); - - // Previous call may alter metadata var updatedAttachment = new CipherAttachment { Id = cipher.Id, UserId = cipher.UserId, OrganizationId = cipher.OrganizationId, - AttachmentId = attachmentId, - AttachmentData = JsonSerializer.Serialize(attachments[attachmentId]) + AttachmentId = attachmentData.AttachmentId, + AttachmentData = JsonSerializer.Serialize(attachmentData) }; + await _cipherRepository.UpdateAttachmentAsync(updatedAttachment); - } - catch - { - await _attachmentStorageService.CleanupAsync(cipher.Id); - throw; - } - } - public async Task ValidateCipherAttachmentFile(Cipher cipher, CipherAttachment.MetaData attachmentData) - { - var (valid, realSize) = await _attachmentStorageService.ValidateFileAsync(cipher, attachmentData, _fileSizeLeeway); - - if (!valid || realSize > MAX_FILE_SIZE) - { - // File reported differs in size from that promised. Must be a rogue client. Delete Send - await DeleteAttachmentAsync(cipher, attachmentData); - return false; - } - // Update Send data if necessary - if (realSize != attachmentData.Size) - { - attachmentData.Size = realSize.Value; - } - attachmentData.Validated = true; - - var updatedAttachment = new CipherAttachment - { - Id = cipher.Id, - UserId = cipher.UserId, - OrganizationId = cipher.OrganizationId, - AttachmentId = attachmentData.AttachmentId, - AttachmentData = JsonSerializer.Serialize(attachmentData) - }; - - - await _cipherRepository.UpdateAttachmentAsync(updatedAttachment); - - return valid; - } - - public async Task GetAttachmentDownloadDataAsync(Cipher cipher, string attachmentId) - { - var attachments = cipher?.GetAttachments() ?? new Dictionary(); - - if (!attachments.ContainsKey(attachmentId)) - { - throw new NotFoundException(); + return valid; } - var data = attachments[attachmentId]; - var response = new AttachmentResponseData + public async Task GetAttachmentDownloadDataAsync(Cipher cipher, string attachmentId) { - Cipher = cipher, - Data = data, - Id = attachmentId, - Url = await _attachmentStorageService.GetAttachmentDownloadUrlAsync(cipher, data), - }; + var attachments = cipher?.GetAttachments() ?? new Dictionary(); - return response; - } - - public async Task DeleteAsync(Cipher cipher, Guid deletingUserId, bool orgAdmin = false) - { - if (!orgAdmin && !(await UserCanEditAsync(cipher, deletingUserId))) - { - throw new BadRequestException("You do not have permissions to delete this."); - } - - await _cipherRepository.DeleteAsync(cipher); - await _attachmentStorageService.DeleteAttachmentsForCipherAsync(cipher.Id); - await _eventService.LogCipherEventAsync(cipher, EventType.Cipher_Deleted); - - // push - await _pushService.PushSyncCipherDeleteAsync(cipher); - } - - public async Task DeleteManyAsync(IEnumerable cipherIds, Guid deletingUserId, Guid? organizationId = null, bool orgAdmin = false) - { - var cipherIdsSet = new HashSet(cipherIds); - var deletingCiphers = new List(); - - if (orgAdmin && organizationId.HasValue) - { - var ciphers = await _cipherRepository.GetManyByOrganizationIdAsync(organizationId.Value); - deletingCiphers = ciphers.Where(c => cipherIdsSet.Contains(c.Id)).ToList(); - await _cipherRepository.DeleteByIdsOrganizationIdAsync(deletingCiphers.Select(c => c.Id), organizationId.Value); - } - else - { - var ciphers = await _cipherRepository.GetManyByUserIdAsync(deletingUserId); - deletingCiphers = ciphers.Where(c => cipherIdsSet.Contains(c.Id) && c.Edit).Select(x => (Cipher)x).ToList(); - await _cipherRepository.DeleteAsync(deletingCiphers.Select(c => c.Id), deletingUserId); - } - - var events = deletingCiphers.Select(c => - new Tuple(c, EventType.Cipher_Deleted, null)); - foreach (var eventsBatch in events.Batch(100)) - { - await _eventService.LogCipherEventsAsync(eventsBatch); - } - - // push - await _pushService.PushSyncCiphersAsync(deletingUserId); - } - - public async Task DeleteAttachmentAsync(Cipher cipher, string attachmentId, Guid deletingUserId, - bool orgAdmin = false) - { - if (!orgAdmin && !(await UserCanEditAsync(cipher, deletingUserId))) - { - throw new BadRequestException("You do not have permissions to delete this."); - } - - if (!cipher.ContainsAttachment(attachmentId)) - { - throw new NotFoundException(); - } - - await DeleteAttachmentAsync(cipher, cipher.GetAttachments()[attachmentId]); - } - - public async Task PurgeAsync(Guid organizationId) - { - var org = await _organizationRepository.GetByIdAsync(organizationId); - if (org == null) - { - throw new NotFoundException(); - } - await _cipherRepository.DeleteByOrganizationIdAsync(organizationId); - await _eventService.LogOrganizationEventAsync(org, Enums.EventType.Organization_PurgedVault); - } - - public async Task MoveManyAsync(IEnumerable cipherIds, Guid? destinationFolderId, Guid movingUserId) - { - if (destinationFolderId.HasValue) - { - var folder = await _folderRepository.GetByIdAsync(destinationFolderId.Value); - if (folder == null || folder.UserId != movingUserId) + if (!attachments.ContainsKey(attachmentId)) { - throw new BadRequestException("Invalid folder."); - } - } - - await _cipherRepository.MoveAsync(cipherIds, destinationFolderId, movingUserId); - // push - await _pushService.PushSyncCiphersAsync(movingUserId); - } - - public async Task SaveFolderAsync(Folder folder) - { - if (folder.Id == default(Guid)) - { - await _folderRepository.CreateAsync(folder); - - // push - await _pushService.PushSyncFolderCreateAsync(folder); - } - else - { - folder.RevisionDate = DateTime.UtcNow; - await _folderRepository.UpsertAsync(folder); - - // push - await _pushService.PushSyncFolderUpdateAsync(folder); - } - } - - public async Task DeleteFolderAsync(Folder folder) - { - await _folderRepository.DeleteAsync(folder); - - // push - await _pushService.PushSyncFolderDeleteAsync(folder); - } - - public async Task ShareAsync(Cipher originalCipher, Cipher cipher, Guid organizationId, - IEnumerable collectionIds, Guid sharingUserId, DateTime? lastKnownRevisionDate) - { - var attachments = cipher.GetAttachments(); - var hasOldAttachments = attachments?.Any(a => a.Key == null) ?? false; - var updatedCipher = false; - var migratedAttachments = false; - var originalAttachments = CoreHelpers.CloneObject(attachments); - - try - { - await ValidateCipherCanBeShared(cipher, sharingUserId, organizationId, lastKnownRevisionDate); - - // Sproc will not save this UserId on the cipher. It is used limit scope of the collectionIds. - cipher.UserId = sharingUserId; - cipher.OrganizationId = organizationId; - cipher.RevisionDate = DateTime.UtcNow; - if (!await _cipherRepository.ReplaceAsync(cipher, collectionIds)) - { - throw new BadRequestException("Unable to save."); + throw new NotFoundException(); } - updatedCipher = true; - await _eventService.LogCipherEventAsync(cipher, Enums.EventType.Cipher_Shared); - - if (hasOldAttachments) + var data = attachments[attachmentId]; + var response = new AttachmentResponseData { - // migrate old attachments - foreach (var attachment in attachments.Where(a => a.Key == null)) + Cipher = cipher, + Data = data, + Id = attachmentId, + Url = await _attachmentStorageService.GetAttachmentDownloadUrlAsync(cipher, data), + }; + + return response; + } + + public async Task DeleteAsync(Cipher cipher, Guid deletingUserId, bool orgAdmin = false) + { + if (!orgAdmin && !(await UserCanEditAsync(cipher, deletingUserId))) + { + throw new BadRequestException("You do not have permissions to delete this."); + } + + await _cipherRepository.DeleteAsync(cipher); + await _attachmentStorageService.DeleteAttachmentsForCipherAsync(cipher.Id); + await _eventService.LogCipherEventAsync(cipher, EventType.Cipher_Deleted); + + // push + await _pushService.PushSyncCipherDeleteAsync(cipher); + } + + public async Task DeleteManyAsync(IEnumerable cipherIds, Guid deletingUserId, Guid? organizationId = null, bool orgAdmin = false) + { + var cipherIdsSet = new HashSet(cipherIds); + var deletingCiphers = new List(); + + if (orgAdmin && organizationId.HasValue) + { + var ciphers = await _cipherRepository.GetManyByOrganizationIdAsync(organizationId.Value); + deletingCiphers = ciphers.Where(c => cipherIdsSet.Contains(c.Id)).ToList(); + await _cipherRepository.DeleteByIdsOrganizationIdAsync(deletingCiphers.Select(c => c.Id), organizationId.Value); + } + else + { + var ciphers = await _cipherRepository.GetManyByUserIdAsync(deletingUserId); + deletingCiphers = ciphers.Where(c => cipherIdsSet.Contains(c.Id) && c.Edit).Select(x => (Cipher)x).ToList(); + await _cipherRepository.DeleteAsync(deletingCiphers.Select(c => c.Id), deletingUserId); + } + + var events = deletingCiphers.Select(c => + new Tuple(c, EventType.Cipher_Deleted, null)); + foreach (var eventsBatch in events.Batch(100)) + { + await _eventService.LogCipherEventsAsync(eventsBatch); + } + + // push + await _pushService.PushSyncCiphersAsync(deletingUserId); + } + + public async Task DeleteAttachmentAsync(Cipher cipher, string attachmentId, Guid deletingUserId, + bool orgAdmin = false) + { + if (!orgAdmin && !(await UserCanEditAsync(cipher, deletingUserId))) + { + throw new BadRequestException("You do not have permissions to delete this."); + } + + if (!cipher.ContainsAttachment(attachmentId)) + { + throw new NotFoundException(); + } + + await DeleteAttachmentAsync(cipher, cipher.GetAttachments()[attachmentId]); + } + + public async Task PurgeAsync(Guid organizationId) + { + var org = await _organizationRepository.GetByIdAsync(organizationId); + if (org == null) + { + throw new NotFoundException(); + } + await _cipherRepository.DeleteByOrganizationIdAsync(organizationId); + await _eventService.LogOrganizationEventAsync(org, Enums.EventType.Organization_PurgedVault); + } + + public async Task MoveManyAsync(IEnumerable cipherIds, Guid? destinationFolderId, Guid movingUserId) + { + if (destinationFolderId.HasValue) + { + var folder = await _folderRepository.GetByIdAsync(destinationFolderId.Value); + if (folder == null || folder.UserId != movingUserId) { - await _attachmentStorageService.StartShareAttachmentAsync(cipher.Id, organizationId, - attachment.Value); - migratedAttachments = true; + throw new BadRequestException("Invalid folder."); + } + } + + await _cipherRepository.MoveAsync(cipherIds, destinationFolderId, movingUserId); + // push + await _pushService.PushSyncCiphersAsync(movingUserId); + } + + public async Task SaveFolderAsync(Folder folder) + { + if (folder.Id == default(Guid)) + { + await _folderRepository.CreateAsync(folder); + + // push + await _pushService.PushSyncFolderCreateAsync(folder); + } + else + { + folder.RevisionDate = DateTime.UtcNow; + await _folderRepository.UpsertAsync(folder); + + // push + await _pushService.PushSyncFolderUpdateAsync(folder); + } + } + + public async Task DeleteFolderAsync(Folder folder) + { + await _folderRepository.DeleteAsync(folder); + + // push + await _pushService.PushSyncFolderDeleteAsync(folder); + } + + public async Task ShareAsync(Cipher originalCipher, Cipher cipher, Guid organizationId, + IEnumerable collectionIds, Guid sharingUserId, DateTime? lastKnownRevisionDate) + { + var attachments = cipher.GetAttachments(); + var hasOldAttachments = attachments?.Any(a => a.Key == null) ?? false; + var updatedCipher = false; + var migratedAttachments = false; + var originalAttachments = CoreHelpers.CloneObject(attachments); + + try + { + await ValidateCipherCanBeShared(cipher, sharingUserId, organizationId, lastKnownRevisionDate); + + // Sproc will not save this UserId on the cipher. It is used limit scope of the collectionIds. + cipher.UserId = sharingUserId; + cipher.OrganizationId = organizationId; + cipher.RevisionDate = DateTime.UtcNow; + if (!await _cipherRepository.ReplaceAsync(cipher, collectionIds)) + { + throw new BadRequestException("Unable to save."); } - // commit attachment migration - await _attachmentStorageService.CleanupAsync(cipher.Id); + updatedCipher = true; + await _eventService.LogCipherEventAsync(cipher, Enums.EventType.Cipher_Shared); + + if (hasOldAttachments) + { + // migrate old attachments + foreach (var attachment in attachments.Where(a => a.Key == null)) + { + await _attachmentStorageService.StartShareAttachmentAsync(cipher.Id, organizationId, + attachment.Value); + migratedAttachments = true; + } + + // commit attachment migration + await _attachmentStorageService.CleanupAsync(cipher.Id); + } + + // push + await _pushService.PushSyncCipherUpdateAsync(cipher, collectionIds); } + catch + { + // roll everything back + if (updatedCipher) + { + await _cipherRepository.ReplaceAsync(originalCipher); + } + + if (!hasOldAttachments || !migratedAttachments) + { + throw; + } + + if (updatedCipher) + { + await _userRepository.UpdateStorageAsync(sharingUserId); + await _organizationRepository.UpdateStorageAsync(organizationId); + } + + foreach (var attachment in attachments.Where(a => a.Key == null)) + { + await _attachmentStorageService.RollbackShareAttachmentAsync(cipher.Id, organizationId, + attachment.Value, originalAttachments[attachment.Key].ContainerName); + } + + await _attachmentStorageService.CleanupAsync(cipher.Id); + throw; + } + } + + public async Task ShareManyAsync(IEnumerable<(Cipher cipher, DateTime? lastKnownRevisionDate)> cipherInfos, + Guid organizationId, IEnumerable collectionIds, Guid sharingUserId) + { + var cipherIds = new List(); + foreach (var (cipher, lastKnownRevisionDate) in cipherInfos) + { + await ValidateCipherCanBeShared(cipher, sharingUserId, organizationId, lastKnownRevisionDate); + + cipher.UserId = null; + cipher.OrganizationId = organizationId; + cipher.RevisionDate = DateTime.UtcNow; + cipherIds.Add(cipher.Id); + } + + await _cipherRepository.UpdateCiphersAsync(sharingUserId, cipherInfos.Select(c => c.cipher)); + await _collectionCipherRepository.UpdateCollectionsForCiphersAsync(cipherIds, sharingUserId, + organizationId, collectionIds); + + var events = cipherInfos.Select(c => + new Tuple(c.cipher, EventType.Cipher_Shared, null)); + foreach (var eventsBatch in events.Batch(100)) + { + await _eventService.LogCipherEventsAsync(eventsBatch); + } + + // push + await _pushService.PushSyncCiphersAsync(sharingUserId); + } + + public async Task SaveCollectionsAsync(Cipher cipher, IEnumerable collectionIds, Guid savingUserId, + bool orgAdmin) + { + if (cipher.Id == default(Guid)) + { + throw new BadRequestException(nameof(cipher.Id)); + } + + if (!cipher.OrganizationId.HasValue) + { + throw new BadRequestException("Cipher must belong to an organization."); + } + + cipher.RevisionDate = DateTime.UtcNow; + + // The sprocs will validate that all collections belong to this org/user and that they have + // proper write permissions. + if (orgAdmin) + { + await _collectionCipherRepository.UpdateCollectionsForAdminAsync(cipher.Id, + cipher.OrganizationId.Value, collectionIds); + } + else + { + if (!(await UserCanEditAsync(cipher, savingUserId))) + { + throw new BadRequestException("You do not have permissions to edit this."); + } + await _collectionCipherRepository.UpdateCollectionsAsync(cipher.Id, savingUserId, collectionIds); + } + + await _eventService.LogCipherEventAsync(cipher, Enums.EventType.Cipher_UpdatedCollections); // push await _pushService.PushSyncCipherUpdateAsync(cipher, collectionIds); } - catch + + public async Task ImportCiphersAsync( + List folders, + List ciphers, + IEnumerable> folderRelationships) { - // roll everything back - if (updatedCipher) + var userId = folders.FirstOrDefault()?.UserId ?? ciphers.FirstOrDefault()?.UserId; + + // Make sure the user can save new ciphers to their personal vault + var personalOwnershipPolicyCount = await _policyRepository.GetCountByTypeApplicableToUserIdAsync(userId.Value, + PolicyType.PersonalOwnership); + if (personalOwnershipPolicyCount > 0) { - await _cipherRepository.ReplaceAsync(originalCipher); + throw new BadRequestException("You cannot import items into your personal vault because you are " + + "a member of an organization which forbids it."); } - if (!hasOldAttachments || !migratedAttachments) + foreach (var cipher in ciphers) { - throw; + cipher.SetNewId(); + + if (cipher.UserId.HasValue && cipher.Favorite) + { + cipher.Favorites = $"{{\"{cipher.UserId.ToString().ToUpperInvariant()}\":\"true\"}}"; + } } - if (updatedCipher) + // Init. ids for folders + foreach (var folder in folders) { - await _userRepository.UpdateStorageAsync(sharingUserId); - await _organizationRepository.UpdateStorageAsync(organizationId); + folder.SetNewId(); } - foreach (var attachment in attachments.Where(a => a.Key == null)) + // Create the folder associations based on the newly created folder ids + foreach (var relationship in folderRelationships) { - await _attachmentStorageService.RollbackShareAttachmentAsync(cipher.Id, organizationId, - attachment.Value, originalAttachments[attachment.Key].ContainerName); + var cipher = ciphers.ElementAtOrDefault(relationship.Key); + var folder = folders.ElementAtOrDefault(relationship.Value); + + if (cipher == null || folder == null) + { + continue; + } + + cipher.Folders = $"{{\"{cipher.UserId.ToString().ToUpperInvariant()}\":" + + $"\"{folder.Id.ToString().ToUpperInvariant()}\"}}"; } - await _attachmentStorageService.CleanupAsync(cipher.Id); - throw; - } - } + // Create it all + await _cipherRepository.CreateAsync(ciphers, folders); - public async Task ShareManyAsync(IEnumerable<(Cipher cipher, DateTime? lastKnownRevisionDate)> cipherInfos, - Guid organizationId, IEnumerable collectionIds, Guid sharingUserId) - { - var cipherIds = new List(); - foreach (var (cipher, lastKnownRevisionDate) in cipherInfos) - { - await ValidateCipherCanBeShared(cipher, sharingUserId, organizationId, lastKnownRevisionDate); - - cipher.UserId = null; - cipher.OrganizationId = organizationId; - cipher.RevisionDate = DateTime.UtcNow; - cipherIds.Add(cipher.Id); - } - - await _cipherRepository.UpdateCiphersAsync(sharingUserId, cipherInfos.Select(c => c.cipher)); - await _collectionCipherRepository.UpdateCollectionsForCiphersAsync(cipherIds, sharingUserId, - organizationId, collectionIds); - - var events = cipherInfos.Select(c => - new Tuple(c.cipher, EventType.Cipher_Shared, null)); - foreach (var eventsBatch in events.Batch(100)) - { - await _eventService.LogCipherEventsAsync(eventsBatch); - } - - // push - await _pushService.PushSyncCiphersAsync(sharingUserId); - } - - public async Task SaveCollectionsAsync(Cipher cipher, IEnumerable collectionIds, Guid savingUserId, - bool orgAdmin) - { - if (cipher.Id == default(Guid)) - { - throw new BadRequestException(nameof(cipher.Id)); - } - - if (!cipher.OrganizationId.HasValue) - { - throw new BadRequestException("Cipher must belong to an organization."); - } - - cipher.RevisionDate = DateTime.UtcNow; - - // The sprocs will validate that all collections belong to this org/user and that they have - // proper write permissions. - if (orgAdmin) - { - await _collectionCipherRepository.UpdateCollectionsForAdminAsync(cipher.Id, - cipher.OrganizationId.Value, collectionIds); - } - else - { - if (!(await UserCanEditAsync(cipher, savingUserId))) + // push + if (userId.HasValue) { - throw new BadRequestException("You do not have permissions to edit this."); - } - await _collectionCipherRepository.UpdateCollectionsAsync(cipher.Id, savingUserId, collectionIds); - } - - await _eventService.LogCipherEventAsync(cipher, Enums.EventType.Cipher_UpdatedCollections); - - // push - await _pushService.PushSyncCipherUpdateAsync(cipher, collectionIds); - } - - public async Task ImportCiphersAsync( - List folders, - List ciphers, - IEnumerable> folderRelationships) - { - var userId = folders.FirstOrDefault()?.UserId ?? ciphers.FirstOrDefault()?.UserId; - - // Make sure the user can save new ciphers to their personal vault - var personalOwnershipPolicyCount = await _policyRepository.GetCountByTypeApplicableToUserIdAsync(userId.Value, - PolicyType.PersonalOwnership); - if (personalOwnershipPolicyCount > 0) - { - throw new BadRequestException("You cannot import items into your personal vault because you are " + - "a member of an organization which forbids it."); - } - - foreach (var cipher in ciphers) - { - cipher.SetNewId(); - - if (cipher.UserId.HasValue && cipher.Favorite) - { - cipher.Favorites = $"{{\"{cipher.UserId.ToString().ToUpperInvariant()}\":\"true\"}}"; + await _pushService.PushSyncVaultAsync(userId.Value); } } - // Init. ids for folders - foreach (var folder in folders) + public async Task ImportCiphersAsync( + List collections, + List ciphers, + IEnumerable> collectionRelationships, + Guid importingUserId) { - folder.SetNewId(); - } + var org = collections.Count > 0 ? + await _organizationRepository.GetByIdAsync(collections[0].OrganizationId) : + await _organizationRepository.GetByIdAsync(ciphers.FirstOrDefault(c => c.OrganizationId.HasValue).OrganizationId.Value); - // Create the folder associations based on the newly created folder ids - foreach (var relationship in folderRelationships) - { - var cipher = ciphers.ElementAtOrDefault(relationship.Key); - var folder = folders.ElementAtOrDefault(relationship.Value); - - if (cipher == null || folder == null) + if (collections.Count > 0 && org != null && org.MaxCollections.HasValue) { - continue; + var collectionCount = await _collectionRepository.GetCountByOrganizationIdAsync(org.Id); + if (org.MaxCollections.Value < (collectionCount + collections.Count)) + { + throw new BadRequestException("This organization can only have a maximum of " + + $"{org.MaxCollections.Value} collections."); + } } - cipher.Folders = $"{{\"{cipher.UserId.ToString().ToUpperInvariant()}\":" + - $"\"{folder.Id.ToString().ToUpperInvariant()}\"}}"; - } - - // Create it all - await _cipherRepository.CreateAsync(ciphers, folders); - - // push - if (userId.HasValue) - { - await _pushService.PushSyncVaultAsync(userId.Value); - } - } - - public async Task ImportCiphersAsync( - List collections, - List ciphers, - IEnumerable> collectionRelationships, - Guid importingUserId) - { - var org = collections.Count > 0 ? - await _organizationRepository.GetByIdAsync(collections[0].OrganizationId) : - await _organizationRepository.GetByIdAsync(ciphers.FirstOrDefault(c => c.OrganizationId.HasValue).OrganizationId.Value); - - if (collections.Count > 0 && org != null && org.MaxCollections.HasValue) - { - var collectionCount = await _collectionRepository.GetCountByOrganizationIdAsync(org.Id); - if (org.MaxCollections.Value < (collectionCount + collections.Count)) + // Init. ids for ciphers + foreach (var cipher in ciphers) { - throw new BadRequestException("This organization can only have a maximum of " + - $"{org.MaxCollections.Value} collections."); + cipher.SetNewId(); + } + + // Init. ids for collections + foreach (var collection in collections) + { + collection.SetNewId(); + } + + // Create associations based on the newly assigned ids + var collectionCiphers = new List(); + foreach (var relationship in collectionRelationships) + { + var cipher = ciphers.ElementAtOrDefault(relationship.Key); + var collection = collections.ElementAtOrDefault(relationship.Value); + + if (cipher == null || collection == null) + { + continue; + } + + collectionCiphers.Add(new CollectionCipher + { + CipherId = cipher.Id, + CollectionId = collection.Id + }); + } + + // Create it all + await _cipherRepository.CreateAsync(ciphers, collections, collectionCiphers); + + // push + await _pushService.PushSyncVaultAsync(importingUserId); + + + if (org != null) + { + await _referenceEventService.RaiseEventAsync( + new ReferenceEvent(ReferenceEventType.VaultImported, org)); } } - // Init. ids for ciphers - foreach (var cipher in ciphers) + public async Task SoftDeleteAsync(Cipher cipher, Guid deletingUserId, bool orgAdmin = false) { - cipher.SetNewId(); - } - - // Init. ids for collections - foreach (var collection in collections) - { - collection.SetNewId(); - } - - // Create associations based on the newly assigned ids - var collectionCiphers = new List(); - foreach (var relationship in collectionRelationships) - { - var cipher = ciphers.ElementAtOrDefault(relationship.Key); - var collection = collections.ElementAtOrDefault(relationship.Value); - - if (cipher == null || collection == null) + if (!orgAdmin && !(await UserCanEditAsync(cipher, deletingUserId))) { - continue; + throw new BadRequestException("You do not have permissions to soft delete this."); } - collectionCiphers.Add(new CollectionCipher + if (cipher.DeletedDate.HasValue) { - CipherId = cipher.Id, - CollectionId = collection.Id - }); - } - - // Create it all - await _cipherRepository.CreateAsync(ciphers, collections, collectionCiphers); - - // push - await _pushService.PushSyncVaultAsync(importingUserId); - - - if (org != null) - { - await _referenceEventService.RaiseEventAsync( - new ReferenceEvent(ReferenceEventType.VaultImported, org)); - } - } - - public async Task SoftDeleteAsync(Cipher cipher, Guid deletingUserId, bool orgAdmin = false) - { - if (!orgAdmin && !(await UserCanEditAsync(cipher, deletingUserId))) - { - throw new BadRequestException("You do not have permissions to soft delete this."); - } - - if (cipher.DeletedDate.HasValue) - { - // Already soft-deleted, we can safely ignore this - return; - } - - cipher.DeletedDate = cipher.RevisionDate = DateTime.UtcNow; - - if (cipher is CipherDetails details) - { - await _cipherRepository.UpsertAsync(details); - } - else - { - await _cipherRepository.UpsertAsync(cipher); - } - await _eventService.LogCipherEventAsync(cipher, EventType.Cipher_SoftDeleted); - - // push - await _pushService.PushSyncCipherUpdateAsync(cipher, null); - } - - public async Task SoftDeleteManyAsync(IEnumerable cipherIds, Guid deletingUserId, Guid? organizationId, bool orgAdmin) - { - var cipherIdsSet = new HashSet(cipherIds); - var deletingCiphers = new List(); - - if (orgAdmin && organizationId.HasValue) - { - var ciphers = await _cipherRepository.GetManyByOrganizationIdAsync(organizationId.Value); - deletingCiphers = ciphers.Where(c => cipherIdsSet.Contains(c.Id)).ToList(); - await _cipherRepository.SoftDeleteByIdsOrganizationIdAsync(deletingCiphers.Select(c => c.Id), organizationId.Value); - } - else - { - var ciphers = await _cipherRepository.GetManyByUserIdAsync(deletingUserId); - deletingCiphers = ciphers.Where(c => cipherIdsSet.Contains(c.Id) && c.Edit).Select(x => (Cipher)x).ToList(); - await _cipherRepository.SoftDeleteAsync(deletingCiphers.Select(c => c.Id), deletingUserId); - } - - var events = deletingCiphers.Select(c => - new Tuple(c, EventType.Cipher_SoftDeleted, null)); - foreach (var eventsBatch in events.Batch(100)) - { - await _eventService.LogCipherEventsAsync(eventsBatch); - } - - // push - await _pushService.PushSyncCiphersAsync(deletingUserId); - } - - public async Task RestoreAsync(Cipher cipher, Guid restoringUserId, bool orgAdmin = false) - { - if (!orgAdmin && !(await UserCanEditAsync(cipher, restoringUserId))) - { - throw new BadRequestException("You do not have permissions to delete this."); - } - - if (!cipher.DeletedDate.HasValue) - { - // Already restored, we can safely ignore this - return; - } - - cipher.DeletedDate = null; - cipher.RevisionDate = DateTime.UtcNow; - - if (cipher is CipherDetails details) - { - await _cipherRepository.UpsertAsync(details); - } - else - { - await _cipherRepository.UpsertAsync(cipher); - } - await _eventService.LogCipherEventAsync(cipher, EventType.Cipher_Restored); - - // push - await _pushService.PushSyncCipherUpdateAsync(cipher, null); - } - - public async Task RestoreManyAsync(IEnumerable ciphers, Guid restoringUserId) - { - var revisionDate = await _cipherRepository.RestoreAsync(ciphers.Select(c => c.Id), restoringUserId); - - var events = ciphers.Select(c => - { - c.RevisionDate = revisionDate; - c.DeletedDate = null; - return new Tuple(c, EventType.Cipher_Restored, null); - }); - foreach (var eventsBatch in events.Batch(100)) - { - await _eventService.LogCipherEventsAsync(eventsBatch); - } - - // push - await _pushService.PushSyncCiphersAsync(restoringUserId); - } - - public async Task<(IEnumerable, Dictionary>)> GetOrganizationCiphers(Guid userId, Guid organizationId) - { - if (!await _currentContext.ViewAllCollections(organizationId) && !await _currentContext.AccessReports(organizationId)) - { - throw new NotFoundException(); - } - - IEnumerable orgCiphers; - if (await _currentContext.OrganizationAdmin(organizationId)) - { - // Admins, Owners and Providers can access all items even if not assigned to them - orgCiphers = await _cipherRepository.GetManyOrganizationDetailsByOrganizationIdAsync(organizationId); - } - else - { - var ciphers = await _cipherRepository.GetManyByUserIdAsync(userId, true); - orgCiphers = ciphers.Where(c => c.OrganizationId == organizationId); - } - - var orgCipherIds = orgCiphers.Select(c => c.Id); - - var collectionCiphers = await _collectionCipherRepository.GetManyByOrganizationIdAsync(organizationId); - var collectionCiphersGroupDict = collectionCiphers - .Where(c => orgCipherIds.Contains(c.CipherId)) - .GroupBy(c => c.CipherId).ToDictionary(s => s.Key); - - var providerId = await _currentContext.ProviderIdForOrg(organizationId); - if (providerId.HasValue) - { - await _providerService.LogProviderAccessToOrganizationAsync(organizationId); - } - - return (orgCiphers, collectionCiphersGroupDict); - } - - private async Task UserCanEditAsync(Cipher cipher, Guid userId) - { - if (!cipher.OrganizationId.HasValue && cipher.UserId.HasValue && cipher.UserId.Value == userId) - { - return true; - } - - return await _cipherRepository.GetCanEditByIdAsync(userId, cipher.Id); - } - - private void ValidateCipherLastKnownRevisionDateAsync(Cipher cipher, DateTime? lastKnownRevisionDate) - { - if (cipher.Id == default || !lastKnownRevisionDate.HasValue) - { - return; - } - - if ((cipher.RevisionDate - lastKnownRevisionDate.Value).Duration() > TimeSpan.FromSeconds(1)) - { - throw new BadRequestException( - "The cipher you are updating is out of date. Please save your work, sync your vault, and try again." - ); - } - } - - private async Task DeleteAttachmentAsync(Cipher cipher, CipherAttachment.MetaData attachmentData) - { - if (attachmentData == null || string.IsNullOrWhiteSpace(attachmentData.AttachmentId)) - { - return; - } - - await _cipherRepository.DeleteAttachmentAsync(cipher.Id, attachmentData.AttachmentId); - cipher.DeleteAttachment(attachmentData.AttachmentId); - await _attachmentStorageService.DeleteAttachmentAsync(cipher.Id, attachmentData); - await _eventService.LogCipherEventAsync(cipher, Enums.EventType.Cipher_AttachmentDeleted); - - // push - await _pushService.PushSyncCipherUpdateAsync(cipher, null); - } - - private async Task ValidateCipherEditForAttachmentAsync(Cipher cipher, Guid savingUserId, bool orgAdmin, - long requestLength) - { - if (!orgAdmin && !(await UserCanEditAsync(cipher, savingUserId))) - { - throw new BadRequestException("You do not have permissions to edit this."); - } - - if (requestLength < 1) - { - throw new BadRequestException("No data to attach."); - } - - var storageBytesRemaining = await StorageBytesRemainingForCipherAsync(cipher); - - if (storageBytesRemaining < requestLength) - { - throw new BadRequestException("Not enough storage available."); - } - } - - private async Task StorageBytesRemainingForCipherAsync(Cipher cipher) - { - var storageBytesRemaining = 0L; - if (cipher.UserId.HasValue) - { - var user = await _userRepository.GetByIdAsync(cipher.UserId.Value); - if (!(await _userService.CanAccessPremium(user))) - { - throw new BadRequestException("You must have premium status to use attachments."); + // Already soft-deleted, we can safely ignore this + return; } - if (user.Premium) + cipher.DeletedDate = cipher.RevisionDate = DateTime.UtcNow; + + if (cipher is CipherDetails details) { - storageBytesRemaining = user.StorageBytesRemaining(); + await _cipherRepository.UpsertAsync(details); } else { - // Users that get access to file storage/premium from their organization get the default - // 1 GB max storage. - storageBytesRemaining = user.StorageBytesRemaining( - _globalSettings.SelfHosted ? (short)10240 : (short)1); + await _cipherRepository.UpsertAsync(cipher); + } + await _eventService.LogCipherEventAsync(cipher, EventType.Cipher_SoftDeleted); + + // push + await _pushService.PushSyncCipherUpdateAsync(cipher, null); + } + + public async Task SoftDeleteManyAsync(IEnumerable cipherIds, Guid deletingUserId, Guid? organizationId, bool orgAdmin) + { + var cipherIdsSet = new HashSet(cipherIds); + var deletingCiphers = new List(); + + if (orgAdmin && organizationId.HasValue) + { + var ciphers = await _cipherRepository.GetManyByOrganizationIdAsync(organizationId.Value); + deletingCiphers = ciphers.Where(c => cipherIdsSet.Contains(c.Id)).ToList(); + await _cipherRepository.SoftDeleteByIdsOrganizationIdAsync(deletingCiphers.Select(c => c.Id), organizationId.Value); + } + else + { + var ciphers = await _cipherRepository.GetManyByUserIdAsync(deletingUserId); + deletingCiphers = ciphers.Where(c => cipherIdsSet.Contains(c.Id) && c.Edit).Select(x => (Cipher)x).ToList(); + await _cipherRepository.SoftDeleteAsync(deletingCiphers.Select(c => c.Id), deletingUserId); + } + + var events = deletingCiphers.Select(c => + new Tuple(c, EventType.Cipher_SoftDeleted, null)); + foreach (var eventsBatch in events.Batch(100)) + { + await _eventService.LogCipherEventsAsync(eventsBatch); + } + + // push + await _pushService.PushSyncCiphersAsync(deletingUserId); + } + + public async Task RestoreAsync(Cipher cipher, Guid restoringUserId, bool orgAdmin = false) + { + if (!orgAdmin && !(await UserCanEditAsync(cipher, restoringUserId))) + { + throw new BadRequestException("You do not have permissions to delete this."); + } + + if (!cipher.DeletedDate.HasValue) + { + // Already restored, we can safely ignore this + return; + } + + cipher.DeletedDate = null; + cipher.RevisionDate = DateTime.UtcNow; + + if (cipher is CipherDetails details) + { + await _cipherRepository.UpsertAsync(details); + } + else + { + await _cipherRepository.UpsertAsync(cipher); + } + await _eventService.LogCipherEventAsync(cipher, EventType.Cipher_Restored); + + // push + await _pushService.PushSyncCipherUpdateAsync(cipher, null); + } + + public async Task RestoreManyAsync(IEnumerable ciphers, Guid restoringUserId) + { + var revisionDate = await _cipherRepository.RestoreAsync(ciphers.Select(c => c.Id), restoringUserId); + + var events = ciphers.Select(c => + { + c.RevisionDate = revisionDate; + c.DeletedDate = null; + return new Tuple(c, EventType.Cipher_Restored, null); + }); + foreach (var eventsBatch in events.Batch(100)) + { + await _eventService.LogCipherEventsAsync(eventsBatch); + } + + // push + await _pushService.PushSyncCiphersAsync(restoringUserId); + } + + public async Task<(IEnumerable, Dictionary>)> GetOrganizationCiphers(Guid userId, Guid organizationId) + { + if (!await _currentContext.ViewAllCollections(organizationId) && !await _currentContext.AccessReports(organizationId)) + { + throw new NotFoundException(); + } + + IEnumerable orgCiphers; + if (await _currentContext.OrganizationAdmin(organizationId)) + { + // Admins, Owners and Providers can access all items even if not assigned to them + orgCiphers = await _cipherRepository.GetManyOrganizationDetailsByOrganizationIdAsync(organizationId); + } + else + { + var ciphers = await _cipherRepository.GetManyByUserIdAsync(userId, true); + orgCiphers = ciphers.Where(c => c.OrganizationId == organizationId); + } + + var orgCipherIds = orgCiphers.Select(c => c.Id); + + var collectionCiphers = await _collectionCipherRepository.GetManyByOrganizationIdAsync(organizationId); + var collectionCiphersGroupDict = collectionCiphers + .Where(c => orgCipherIds.Contains(c.CipherId)) + .GroupBy(c => c.CipherId).ToDictionary(s => s.Key); + + var providerId = await _currentContext.ProviderIdForOrg(organizationId); + if (providerId.HasValue) + { + await _providerService.LogProviderAccessToOrganizationAsync(organizationId); + } + + return (orgCiphers, collectionCiphersGroupDict); + } + + private async Task UserCanEditAsync(Cipher cipher, Guid userId) + { + if (!cipher.OrganizationId.HasValue && cipher.UserId.HasValue && cipher.UserId.Value == userId) + { + return true; + } + + return await _cipherRepository.GetCanEditByIdAsync(userId, cipher.Id); + } + + private void ValidateCipherLastKnownRevisionDateAsync(Cipher cipher, DateTime? lastKnownRevisionDate) + { + if (cipher.Id == default || !lastKnownRevisionDate.HasValue) + { + return; + } + + if ((cipher.RevisionDate - lastKnownRevisionDate.Value).Duration() > TimeSpan.FromSeconds(1)) + { + throw new BadRequestException( + "The cipher you are updating is out of date. Please save your work, sync your vault, and try again." + ); } } - else if (cipher.OrganizationId.HasValue) + + private async Task DeleteAttachmentAsync(Cipher cipher, CipherAttachment.MetaData attachmentData) { - var org = await _organizationRepository.GetByIdAsync(cipher.OrganizationId.Value); - if (!org.MaxStorageGb.HasValue) + if (attachmentData == null || string.IsNullOrWhiteSpace(attachmentData.AttachmentId)) + { + return; + } + + await _cipherRepository.DeleteAttachmentAsync(cipher.Id, attachmentData.AttachmentId); + cipher.DeleteAttachment(attachmentData.AttachmentId); + await _attachmentStorageService.DeleteAttachmentAsync(cipher.Id, attachmentData); + await _eventService.LogCipherEventAsync(cipher, Enums.EventType.Cipher_AttachmentDeleted); + + // push + await _pushService.PushSyncCipherUpdateAsync(cipher, null); + } + + private async Task ValidateCipherEditForAttachmentAsync(Cipher cipher, Guid savingUserId, bool orgAdmin, + long requestLength) + { + if (!orgAdmin && !(await UserCanEditAsync(cipher, savingUserId))) + { + throw new BadRequestException("You do not have permissions to edit this."); + } + + if (requestLength < 1) + { + throw new BadRequestException("No data to attach."); + } + + var storageBytesRemaining = await StorageBytesRemainingForCipherAsync(cipher); + + if (storageBytesRemaining < requestLength) + { + throw new BadRequestException("Not enough storage available."); + } + } + + private async Task StorageBytesRemainingForCipherAsync(Cipher cipher) + { + var storageBytesRemaining = 0L; + if (cipher.UserId.HasValue) + { + var user = await _userRepository.GetByIdAsync(cipher.UserId.Value); + if (!(await _userService.CanAccessPremium(user))) + { + throw new BadRequestException("You must have premium status to use attachments."); + } + + if (user.Premium) + { + storageBytesRemaining = user.StorageBytesRemaining(); + } + else + { + // Users that get access to file storage/premium from their organization get the default + // 1 GB max storage. + storageBytesRemaining = user.StorageBytesRemaining( + _globalSettings.SelfHosted ? (short)10240 : (short)1); + } + } + else if (cipher.OrganizationId.HasValue) + { + var org = await _organizationRepository.GetByIdAsync(cipher.OrganizationId.Value); + if (!org.MaxStorageGb.HasValue) + { + throw new BadRequestException("This organization cannot use attachments."); + } + + storageBytesRemaining = org.StorageBytesRemaining(); + } + + return storageBytesRemaining; + } + + private async Task ValidateCipherCanBeShared( + Cipher cipher, + Guid sharingUserId, + Guid organizationId, + DateTime? lastKnownRevisionDate) + { + if (cipher.Id == default(Guid)) + { + throw new BadRequestException("Cipher must already exist."); + } + + if (cipher.OrganizationId.HasValue) + { + throw new BadRequestException("One or more ciphers already belong to an organization."); + } + + if (!cipher.UserId.HasValue || cipher.UserId.Value != sharingUserId) + { + throw new BadRequestException("One or more ciphers do not belong to you."); + } + + var attachments = cipher.GetAttachments(); + var hasAttachments = attachments?.Any() ?? false; + var org = await _organizationRepository.GetByIdAsync(organizationId); + + if (org == null) + { + throw new BadRequestException("Could not find organization."); + } + + if (hasAttachments && !org.MaxStorageGb.HasValue) { throw new BadRequestException("This organization cannot use attachments."); } - storageBytesRemaining = org.StorageBytesRemaining(); + var storageAdjustment = attachments?.Sum(a => a.Value.Size) ?? 0; + if (org.StorageBytesRemaining() < storageAdjustment) + { + throw new BadRequestException("Not enough storage available for this organization."); + } + + ValidateCipherLastKnownRevisionDateAsync(cipher, lastKnownRevisionDate); } - - return storageBytesRemaining; - } - - private async Task ValidateCipherCanBeShared( - Cipher cipher, - Guid sharingUserId, - Guid organizationId, - DateTime? lastKnownRevisionDate) - { - if (cipher.Id == default(Guid)) - { - throw new BadRequestException("Cipher must already exist."); - } - - if (cipher.OrganizationId.HasValue) - { - throw new BadRequestException("One or more ciphers already belong to an organization."); - } - - if (!cipher.UserId.HasValue || cipher.UserId.Value != sharingUserId) - { - throw new BadRequestException("One or more ciphers do not belong to you."); - } - - var attachments = cipher.GetAttachments(); - var hasAttachments = attachments?.Any() ?? false; - var org = await _organizationRepository.GetByIdAsync(organizationId); - - if (org == null) - { - throw new BadRequestException("Could not find organization."); - } - - if (hasAttachments && !org.MaxStorageGb.HasValue) - { - throw new BadRequestException("This organization cannot use attachments."); - } - - var storageAdjustment = attachments?.Sum(a => a.Value.Size) ?? 0; - if (org.StorageBytesRemaining() < storageAdjustment) - { - throw new BadRequestException("Not enough storage available for this organization."); - } - - ValidateCipherLastKnownRevisionDateAsync(cipher, lastKnownRevisionDate); } } diff --git a/src/Core/Services/Implementations/CollectionService.cs b/src/Core/Services/Implementations/CollectionService.cs index 699f38925d..e41532c1ec 100644 --- a/src/Core/Services/Implementations/CollectionService.cs +++ b/src/Core/Services/Implementations/CollectionService.cs @@ -6,135 +6,136 @@ using Bit.Core.Models.Business; using Bit.Core.Models.Data; using Bit.Core.Repositories; -namespace Bit.Core.Services; - -public class CollectionService : ICollectionService +namespace Bit.Core.Services { - private readonly IEventService _eventService; - private readonly IOrganizationRepository _organizationRepository; - private readonly IOrganizationUserRepository _organizationUserRepository; - private readonly ICollectionRepository _collectionRepository; - private readonly IUserRepository _userRepository; - private readonly IMailService _mailService; - private readonly IReferenceEventService _referenceEventService; - private readonly ICurrentContext _currentContext; - - public CollectionService( - IEventService eventService, - IOrganizationRepository organizationRepository, - IOrganizationUserRepository organizationUserRepository, - ICollectionRepository collectionRepository, - IUserRepository userRepository, - IMailService mailService, - IReferenceEventService referenceEventService, - ICurrentContext currentContext) + public class CollectionService : ICollectionService { - _eventService = eventService; - _organizationRepository = organizationRepository; - _organizationUserRepository = organizationUserRepository; - _collectionRepository = collectionRepository; - _userRepository = userRepository; - _mailService = mailService; - _referenceEventService = referenceEventService; - _currentContext = currentContext; - } + private readonly IEventService _eventService; + private readonly IOrganizationRepository _organizationRepository; + private readonly IOrganizationUserRepository _organizationUserRepository; + private readonly ICollectionRepository _collectionRepository; + private readonly IUserRepository _userRepository; + private readonly IMailService _mailService; + private readonly IReferenceEventService _referenceEventService; + private readonly ICurrentContext _currentContext; - public async Task SaveAsync(Collection collection, IEnumerable groups = null, - Guid? assignUserId = null) - { - var org = await _organizationRepository.GetByIdAsync(collection.OrganizationId); - if (org == null) + public CollectionService( + IEventService eventService, + IOrganizationRepository organizationRepository, + IOrganizationUserRepository organizationUserRepository, + ICollectionRepository collectionRepository, + IUserRepository userRepository, + IMailService mailService, + IReferenceEventService referenceEventService, + ICurrentContext currentContext) { - throw new BadRequestException("Organization not found"); + _eventService = eventService; + _organizationRepository = organizationRepository; + _organizationUserRepository = organizationUserRepository; + _collectionRepository = collectionRepository; + _userRepository = userRepository; + _mailService = mailService; + _referenceEventService = referenceEventService; + _currentContext = currentContext; } - if (collection.Id == default(Guid)) + public async Task SaveAsync(Collection collection, IEnumerable groups = null, + Guid? assignUserId = null) { - if (org.MaxCollections.HasValue) + var org = await _organizationRepository.GetByIdAsync(collection.OrganizationId); + if (org == null) { - var collectionCount = await _collectionRepository.GetCountByOrganizationIdAsync(org.Id); - if (org.MaxCollections.Value <= collectionCount) - { - throw new BadRequestException("You have reached the maximum number of collections " + - $"({org.MaxCollections.Value}) for this organization."); - } + throw new BadRequestException("Organization not found"); } - if (groups == null || !org.UseGroups) + if (collection.Id == default(Guid)) { - await _collectionRepository.CreateAsync(collection); + if (org.MaxCollections.HasValue) + { + var collectionCount = await _collectionRepository.GetCountByOrganizationIdAsync(org.Id); + if (org.MaxCollections.Value <= collectionCount) + { + throw new BadRequestException("You have reached the maximum number of collections " + + $"({org.MaxCollections.Value}) for this organization."); + } + } + + if (groups == null || !org.UseGroups) + { + await _collectionRepository.CreateAsync(collection); + } + else + { + await _collectionRepository.CreateAsync(collection, groups); + } + + // Assign a user to the newly created collection. + if (assignUserId.HasValue) + { + var orgUser = await _organizationUserRepository.GetByOrganizationAsync(org.Id, assignUserId.Value); + if (orgUser != null && orgUser.Status == Enums.OrganizationUserStatusType.Confirmed) + { + await _collectionRepository.UpdateUsersAsync(collection.Id, + new List { + new SelectionReadOnly { Id = orgUser.Id, ReadOnly = false } }); + } + } + + await _eventService.LogCollectionEventAsync(collection, Enums.EventType.Collection_Created); + await _referenceEventService.RaiseEventAsync(new ReferenceEvent(ReferenceEventType.CollectionCreated, org)); } else { - await _collectionRepository.CreateAsync(collection, groups); - } - - // Assign a user to the newly created collection. - if (assignUserId.HasValue) - { - var orgUser = await _organizationUserRepository.GetByOrganizationAsync(org.Id, assignUserId.Value); - if (orgUser != null && orgUser.Status == Enums.OrganizationUserStatusType.Confirmed) + if (!org.UseGroups) { - await _collectionRepository.UpdateUsersAsync(collection.Id, - new List { - new SelectionReadOnly { Id = orgUser.Id, ReadOnly = false } }); + await _collectionRepository.ReplaceAsync(collection); } + else + { + await _collectionRepository.ReplaceAsync(collection, groups ?? new List()); + } + + await _eventService.LogCollectionEventAsync(collection, Enums.EventType.Collection_Updated); + } + } + + public async Task DeleteAsync(Collection collection) + { + await _collectionRepository.DeleteAsync(collection); + await _eventService.LogCollectionEventAsync(collection, Enums.EventType.Collection_Deleted); + } + + public async Task DeleteUserAsync(Collection collection, Guid organizationUserId) + { + var orgUser = await _organizationUserRepository.GetByIdAsync(organizationUserId); + if (orgUser == null || orgUser.OrganizationId != collection.OrganizationId) + { + throw new NotFoundException(); + } + await _collectionRepository.DeleteUserAsync(collection.Id, organizationUserId); + await _eventService.LogOrganizationUserEventAsync(orgUser, Enums.EventType.OrganizationUser_Updated); + } + + public async Task> GetOrganizationCollections(Guid organizationId) + { + if (!await _currentContext.ViewAllCollections(organizationId) && !await _currentContext.ManageUsers(organizationId)) + { + throw new NotFoundException(); } - await _eventService.LogCollectionEventAsync(collection, Enums.EventType.Collection_Created); - await _referenceEventService.RaiseEventAsync(new ReferenceEvent(ReferenceEventType.CollectionCreated, org)); - } - else - { - if (!org.UseGroups) + IEnumerable orgCollections; + if (await _currentContext.OrganizationAdmin(organizationId) || await _currentContext.ViewAllCollections(organizationId)) { - await _collectionRepository.ReplaceAsync(collection); + // Admins, Owners, Providers and Custom (with collection management permissions) can access all items even if not assigned to them + orgCollections = await _collectionRepository.GetManyByOrganizationIdAsync(organizationId); } else { - await _collectionRepository.ReplaceAsync(collection, groups ?? new List()); + var collections = await _collectionRepository.GetManyByUserIdAsync(_currentContext.UserId.Value); + orgCollections = collections.Where(c => c.OrganizationId == organizationId); } - await _eventService.LogCollectionEventAsync(collection, Enums.EventType.Collection_Updated); + return orgCollections; } } - - public async Task DeleteAsync(Collection collection) - { - await _collectionRepository.DeleteAsync(collection); - await _eventService.LogCollectionEventAsync(collection, Enums.EventType.Collection_Deleted); - } - - public async Task DeleteUserAsync(Collection collection, Guid organizationUserId) - { - var orgUser = await _organizationUserRepository.GetByIdAsync(organizationUserId); - if (orgUser == null || orgUser.OrganizationId != collection.OrganizationId) - { - throw new NotFoundException(); - } - await _collectionRepository.DeleteUserAsync(collection.Id, organizationUserId); - await _eventService.LogOrganizationUserEventAsync(orgUser, Enums.EventType.OrganizationUser_Updated); - } - - public async Task> GetOrganizationCollections(Guid organizationId) - { - if (!await _currentContext.ViewAllCollections(organizationId) && !await _currentContext.ManageUsers(organizationId)) - { - throw new NotFoundException(); - } - - IEnumerable orgCollections; - if (await _currentContext.OrganizationAdmin(organizationId) || await _currentContext.ViewAllCollections(organizationId)) - { - // Admins, Owners, Providers and Custom (with collection management permissions) can access all items even if not assigned to them - orgCollections = await _collectionRepository.GetManyByOrganizationIdAsync(organizationId); - } - else - { - var collections = await _collectionRepository.GetManyByUserIdAsync(_currentContext.UserId.Value); - orgCollections = collections.Where(c => c.OrganizationId == organizationId); - } - - return orgCollections; - } } diff --git a/src/Core/Services/Implementations/DeviceService.cs b/src/Core/Services/Implementations/DeviceService.cs index 99f4648a3e..a65a49bdde 100644 --- a/src/Core/Services/Implementations/DeviceService.cs +++ b/src/Core/Services/Implementations/DeviceService.cs @@ -1,46 +1,47 @@ using Bit.Core.Entities; using Bit.Core.Repositories; -namespace Bit.Core.Services; - -public class DeviceService : IDeviceService +namespace Bit.Core.Services { - private readonly IDeviceRepository _deviceRepository; - private readonly IPushRegistrationService _pushRegistrationService; - - public DeviceService( - IDeviceRepository deviceRepository, - IPushRegistrationService pushRegistrationService) + public class DeviceService : IDeviceService { - _deviceRepository = deviceRepository; - _pushRegistrationService = pushRegistrationService; - } + private readonly IDeviceRepository _deviceRepository; + private readonly IPushRegistrationService _pushRegistrationService; - public async Task SaveAsync(Device device) - { - if (device.Id == default(Guid)) + public DeviceService( + IDeviceRepository deviceRepository, + IPushRegistrationService pushRegistrationService) { - await _deviceRepository.CreateAsync(device); - } - else - { - device.RevisionDate = DateTime.UtcNow; - await _deviceRepository.ReplaceAsync(device); + _deviceRepository = deviceRepository; + _pushRegistrationService = pushRegistrationService; } - await _pushRegistrationService.CreateOrUpdateRegistrationAsync(device.PushToken, device.Id.ToString(), - device.UserId.ToString(), device.Identifier, device.Type); - } + public async Task SaveAsync(Device device) + { + if (device.Id == default(Guid)) + { + await _deviceRepository.CreateAsync(device); + } + else + { + device.RevisionDate = DateTime.UtcNow; + await _deviceRepository.ReplaceAsync(device); + } - public async Task ClearTokenAsync(Device device) - { - await _deviceRepository.ClearPushTokenAsync(device.Id); - await _pushRegistrationService.DeleteRegistrationAsync(device.Id.ToString()); - } + await _pushRegistrationService.CreateOrUpdateRegistrationAsync(device.PushToken, device.Id.ToString(), + device.UserId.ToString(), device.Identifier, device.Type); + } - public async Task DeleteAsync(Device device) - { - await _deviceRepository.DeleteAsync(device); - await _pushRegistrationService.DeleteRegistrationAsync(device.Id.ToString()); + public async Task ClearTokenAsync(Device device) + { + await _deviceRepository.ClearPushTokenAsync(device.Id); + await _pushRegistrationService.DeleteRegistrationAsync(device.Id.ToString()); + } + + public async Task DeleteAsync(Device device) + { + await _deviceRepository.DeleteAsync(device); + await _pushRegistrationService.DeleteRegistrationAsync(device.Id.ToString()); + } } } diff --git a/src/Core/Services/Implementations/EmergencyAccessService.cs b/src/Core/Services/Implementations/EmergencyAccessService.cs index e48000b525..06a5e5a85e 100644 --- a/src/Core/Services/Implementations/EmergencyAccessService.cs +++ b/src/Core/Services/Implementations/EmergencyAccessService.cs @@ -9,415 +9,416 @@ using Bit.Core.Settings; using Bit.Core.Tokens; using Microsoft.AspNetCore.Identity; -namespace Bit.Core.Services; - -public class EmergencyAccessService : IEmergencyAccessService +namespace Bit.Core.Services { - private readonly IEmergencyAccessRepository _emergencyAccessRepository; - private readonly IOrganizationUserRepository _organizationUserRepository; - private readonly IUserRepository _userRepository; - private readonly ICipherRepository _cipherRepository; - private readonly IPolicyRepository _policyRepository; - private readonly ICipherService _cipherService; - private readonly IMailService _mailService; - private readonly IUserService _userService; - private readonly GlobalSettings _globalSettings; - private readonly IPasswordHasher _passwordHasher; - private readonly IOrganizationService _organizationService; - private readonly IDataProtectorTokenFactory _dataProtectorTokenizer; - - public EmergencyAccessService( - IEmergencyAccessRepository emergencyAccessRepository, - IOrganizationUserRepository organizationUserRepository, - IUserRepository userRepository, - ICipherRepository cipherRepository, - IPolicyRepository policyRepository, - ICipherService cipherService, - IMailService mailService, - IUserService userService, - IPasswordHasher passwordHasher, - GlobalSettings globalSettings, - IOrganizationService organizationService, - IDataProtectorTokenFactory dataProtectorTokenizer) + public class EmergencyAccessService : IEmergencyAccessService { - _emergencyAccessRepository = emergencyAccessRepository; - _organizationUserRepository = organizationUserRepository; - _userRepository = userRepository; - _cipherRepository = cipherRepository; - _policyRepository = policyRepository; - _cipherService = cipherService; - _mailService = mailService; - _userService = userService; - _passwordHasher = passwordHasher; - _globalSettings = globalSettings; - _organizationService = organizationService; - _dataProtectorTokenizer = dataProtectorTokenizer; - } + private readonly IEmergencyAccessRepository _emergencyAccessRepository; + private readonly IOrganizationUserRepository _organizationUserRepository; + private readonly IUserRepository _userRepository; + private readonly ICipherRepository _cipherRepository; + private readonly IPolicyRepository _policyRepository; + private readonly ICipherService _cipherService; + private readonly IMailService _mailService; + private readonly IUserService _userService; + private readonly GlobalSettings _globalSettings; + private readonly IPasswordHasher _passwordHasher; + private readonly IOrganizationService _organizationService; + private readonly IDataProtectorTokenFactory _dataProtectorTokenizer; - public async Task InviteAsync(User invitingUser, string email, EmergencyAccessType type, int waitTime) - { - if (!await _userService.CanAccessPremium(invitingUser)) + public EmergencyAccessService( + IEmergencyAccessRepository emergencyAccessRepository, + IOrganizationUserRepository organizationUserRepository, + IUserRepository userRepository, + ICipherRepository cipherRepository, + IPolicyRepository policyRepository, + ICipherService cipherService, + IMailService mailService, + IUserService userService, + IPasswordHasher passwordHasher, + GlobalSettings globalSettings, + IOrganizationService organizationService, + IDataProtectorTokenFactory dataProtectorTokenizer) { - throw new BadRequestException("Not a premium user."); + _emergencyAccessRepository = emergencyAccessRepository; + _organizationUserRepository = organizationUserRepository; + _userRepository = userRepository; + _cipherRepository = cipherRepository; + _policyRepository = policyRepository; + _cipherService = cipherService; + _mailService = mailService; + _userService = userService; + _passwordHasher = passwordHasher; + _globalSettings = globalSettings; + _organizationService = organizationService; + _dataProtectorTokenizer = dataProtectorTokenizer; } - if (type == EmergencyAccessType.Takeover && invitingUser.UsesKeyConnector) + public async Task InviteAsync(User invitingUser, string email, EmergencyAccessType type, int waitTime) { - throw new BadRequestException("You cannot use Emergency Access Takeover because you are using Key Connector."); - } + if (!await _userService.CanAccessPremium(invitingUser)) + { + throw new BadRequestException("Not a premium user."); + } - var emergencyAccess = new EmergencyAccess - { - GrantorId = invitingUser.Id, - Email = email.ToLowerInvariant(), - Status = EmergencyAccessStatusType.Invited, - Type = type, - WaitTimeDays = waitTime, - CreationDate = DateTime.UtcNow, - RevisionDate = DateTime.UtcNow, - }; - - await _emergencyAccessRepository.CreateAsync(emergencyAccess); - await SendInviteAsync(emergencyAccess, NameOrEmail(invitingUser)); - - return emergencyAccess; - } - - public async Task GetAsync(Guid emergencyAccessId, Guid userId) - { - var emergencyAccess = await _emergencyAccessRepository.GetDetailsByIdGrantorIdAsync(emergencyAccessId, userId); - if (emergencyAccess == null) - { - throw new BadRequestException("Emergency Access not valid."); - } - - return emergencyAccess; - } - - public async Task ResendInviteAsync(User invitingUser, Guid emergencyAccessId) - { - var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(emergencyAccessId); - if (emergencyAccess == null || emergencyAccess.GrantorId != invitingUser.Id || - emergencyAccess.Status != EmergencyAccessStatusType.Invited) - { - throw new BadRequestException("Emergency Access not valid."); - } - - await SendInviteAsync(emergencyAccess, NameOrEmail(invitingUser)); - } - - public async Task AcceptUserAsync(Guid emergencyAccessId, User user, string token, IUserService userService) - { - var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(emergencyAccessId); - if (emergencyAccess == null) - { - throw new BadRequestException("Emergency Access not valid."); - } - - if (!_dataProtectorTokenizer.TryUnprotect(token, out var data) && data.IsValid(emergencyAccessId, user.Email)) - { - throw new BadRequestException("Invalid token."); - } - - if (emergencyAccess.Status == EmergencyAccessStatusType.Accepted) - { - throw new BadRequestException("Invitation already accepted. You will receive an email when the grantor confirms you as an emergency access contact."); - } - else if (emergencyAccess.Status != EmergencyAccessStatusType.Invited) - { - throw new BadRequestException("Invitation already accepted."); - } - - if (string.IsNullOrWhiteSpace(emergencyAccess.Email) || - !emergencyAccess.Email.Equals(user.Email, StringComparison.InvariantCultureIgnoreCase)) - { - throw new BadRequestException("User email does not match invite."); - } - - var granteeEmail = emergencyAccess.Email; - - emergencyAccess.Status = EmergencyAccessStatusType.Accepted; - emergencyAccess.GranteeId = user.Id; - emergencyAccess.Email = null; - - var grantor = await userService.GetUserByIdAsync(emergencyAccess.GrantorId); - - await _emergencyAccessRepository.ReplaceAsync(emergencyAccess); - await _mailService.SendEmergencyAccessAcceptedEmailAsync(granteeEmail, grantor.Email); - - return emergencyAccess; - } - - public async Task DeleteAsync(Guid emergencyAccessId, Guid grantorId) - { - var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(emergencyAccessId); - if (emergencyAccess == null || (emergencyAccess.GrantorId != grantorId && emergencyAccess.GranteeId != grantorId)) - { - throw new BadRequestException("Emergency Access not valid."); - } - - await _emergencyAccessRepository.DeleteAsync(emergencyAccess); - } - - public async Task ConfirmUserAsync(Guid emergencyAcccessId, string key, Guid confirmingUserId) - { - var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(emergencyAcccessId); - if (emergencyAccess == null || emergencyAccess.Status != EmergencyAccessStatusType.Accepted || - emergencyAccess.GrantorId != confirmingUserId) - { - throw new BadRequestException("Emergency Access not valid."); - } - - var grantor = await _userRepository.GetByIdAsync(confirmingUserId); - if (emergencyAccess.Type == EmergencyAccessType.Takeover && grantor.UsesKeyConnector) - { - throw new BadRequestException("You cannot use Emergency Access Takeover because you are using Key Connector."); - } - - var grantee = await _userRepository.GetByIdAsync(emergencyAccess.GranteeId.Value); - - emergencyAccess.Status = EmergencyAccessStatusType.Confirmed; - emergencyAccess.KeyEncrypted = key; - emergencyAccess.Email = null; - await _emergencyAccessRepository.ReplaceAsync(emergencyAccess); - await _mailService.SendEmergencyAccessConfirmedEmailAsync(NameOrEmail(grantor), grantee.Email); - - return emergencyAccess; - } - - public async Task SaveAsync(EmergencyAccess emergencyAccess, User savingUser) - { - if (!await _userService.CanAccessPremium(savingUser)) - { - throw new BadRequestException("Not a premium user."); - } - - if (emergencyAccess.GrantorId != savingUser.Id) - { - throw new BadRequestException("Emergency Access not valid."); - } - - if (emergencyAccess.Type == EmergencyAccessType.Takeover) - { - var grantor = await _userService.GetUserByIdAsync(emergencyAccess.GrantorId); - if (grantor.UsesKeyConnector) + if (type == EmergencyAccessType.Takeover && invitingUser.UsesKeyConnector) { throw new BadRequestException("You cannot use Emergency Access Takeover because you are using Key Connector."); } - } - await _emergencyAccessRepository.ReplaceAsync(emergencyAccess); - } - - public async Task InitiateAsync(Guid id, User initiatingUser) - { - var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(id); - - if (emergencyAccess == null || emergencyAccess.GranteeId != initiatingUser.Id || - emergencyAccess.Status != EmergencyAccessStatusType.Confirmed) - { - throw new BadRequestException("Emergency Access not valid."); - } - - var grantor = await _userRepository.GetByIdAsync(emergencyAccess.GrantorId); - - if (emergencyAccess.Type == EmergencyAccessType.Takeover && grantor.UsesKeyConnector) - { - throw new BadRequestException("You cannot takeover an account that is using Key Connector."); - } - - var now = DateTime.UtcNow; - emergencyAccess.Status = EmergencyAccessStatusType.RecoveryInitiated; - emergencyAccess.RevisionDate = now; - emergencyAccess.RecoveryInitiatedDate = now; - emergencyAccess.LastNotificationDate = now; - await _emergencyAccessRepository.ReplaceAsync(emergencyAccess); - - await _mailService.SendEmergencyAccessRecoveryInitiated(emergencyAccess, NameOrEmail(initiatingUser), grantor.Email); - } - - public async Task ApproveAsync(Guid id, User approvingUser) - { - var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(id); - - if (emergencyAccess == null || emergencyAccess.GrantorId != approvingUser.Id || - emergencyAccess.Status != EmergencyAccessStatusType.RecoveryInitiated) - { - throw new BadRequestException("Emergency Access not valid."); - } - - emergencyAccess.Status = EmergencyAccessStatusType.RecoveryApproved; - await _emergencyAccessRepository.ReplaceAsync(emergencyAccess); - - var grantee = await _userRepository.GetByIdAsync(emergencyAccess.GranteeId.Value); - await _mailService.SendEmergencyAccessRecoveryApproved(emergencyAccess, NameOrEmail(approvingUser), grantee.Email); - } - - public async Task RejectAsync(Guid id, User rejectingUser) - { - var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(id); - - if (emergencyAccess == null || emergencyAccess.GrantorId != rejectingUser.Id || - (emergencyAccess.Status != EmergencyAccessStatusType.RecoveryInitiated && - emergencyAccess.Status != EmergencyAccessStatusType.RecoveryApproved)) - { - throw new BadRequestException("Emergency Access not valid."); - } - - emergencyAccess.Status = EmergencyAccessStatusType.Confirmed; - await _emergencyAccessRepository.ReplaceAsync(emergencyAccess); - - var grantee = await _userRepository.GetByIdAsync(emergencyAccess.GranteeId.Value); - await _mailService.SendEmergencyAccessRecoveryRejected(emergencyAccess, NameOrEmail(rejectingUser), grantee.Email); - } - - public async Task> GetPoliciesAsync(Guid id, User requestingUser) - { - var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(id); - - if (!IsValidRequest(emergencyAccess, requestingUser, EmergencyAccessType.Takeover)) - { - throw new BadRequestException("Emergency Access not valid."); - } - - var grantor = await _userRepository.GetByIdAsync(emergencyAccess.GrantorId); - - var grantorOrganizations = await _organizationUserRepository.GetManyByUserAsync(grantor.Id); - var isOrganizationOwner = grantorOrganizations.Any(organization => organization.Type == OrganizationUserType.Owner); - var policies = isOrganizationOwner ? await _policyRepository.GetManyByUserIdAsync(grantor.Id) : null; - - return policies; - } - - public async Task<(EmergencyAccess, User)> TakeoverAsync(Guid id, User requestingUser) - { - var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(id); - - if (!IsValidRequest(emergencyAccess, requestingUser, EmergencyAccessType.Takeover)) - { - throw new BadRequestException("Emergency Access not valid."); - } - - var grantor = await _userRepository.GetByIdAsync(emergencyAccess.GrantorId); - - if (emergencyAccess.Type == EmergencyAccessType.Takeover && grantor.UsesKeyConnector) - { - throw new BadRequestException("You cannot takeover an account that is using Key Connector."); - } - - return (emergencyAccess, grantor); - } - - public async Task PasswordAsync(Guid id, User requestingUser, string newMasterPasswordHash, string key) - { - var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(id); - - if (!IsValidRequest(emergencyAccess, requestingUser, EmergencyAccessType.Takeover)) - { - throw new BadRequestException("Emergency Access not valid."); - } - - var grantor = await _userRepository.GetByIdAsync(emergencyAccess.GrantorId); - - grantor.MasterPassword = _passwordHasher.HashPassword(grantor, newMasterPasswordHash); - grantor.Key = key; - // Disable TwoFactor providers since they will otherwise block logins - grantor.SetTwoFactorProviders(new Dictionary()); - grantor.UnknownDeviceVerificationEnabled = false; - await _userRepository.ReplaceAsync(grantor); - - // Remove grantor from all organizations unless Owner - var orgUser = await _organizationUserRepository.GetManyByUserAsync(grantor.Id); - foreach (var o in orgUser) - { - if (o.Type != OrganizationUserType.Owner) + var emergencyAccess = new EmergencyAccess { - await _organizationService.DeleteUserAsync(o.OrganizationId, grantor.Id); + GrantorId = invitingUser.Id, + Email = email.ToLowerInvariant(), + Status = EmergencyAccessStatusType.Invited, + Type = type, + WaitTimeDays = waitTime, + CreationDate = DateTime.UtcNow, + RevisionDate = DateTime.UtcNow, + }; + + await _emergencyAccessRepository.CreateAsync(emergencyAccess); + await SendInviteAsync(emergencyAccess, NameOrEmail(invitingUser)); + + return emergencyAccess; + } + + public async Task GetAsync(Guid emergencyAccessId, Guid userId) + { + var emergencyAccess = await _emergencyAccessRepository.GetDetailsByIdGrantorIdAsync(emergencyAccessId, userId); + if (emergencyAccess == null) + { + throw new BadRequestException("Emergency Access not valid."); + } + + return emergencyAccess; + } + + public async Task ResendInviteAsync(User invitingUser, Guid emergencyAccessId) + { + var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(emergencyAccessId); + if (emergencyAccess == null || emergencyAccess.GrantorId != invitingUser.Id || + emergencyAccess.Status != EmergencyAccessStatusType.Invited) + { + throw new BadRequestException("Emergency Access not valid."); + } + + await SendInviteAsync(emergencyAccess, NameOrEmail(invitingUser)); + } + + public async Task AcceptUserAsync(Guid emergencyAccessId, User user, string token, IUserService userService) + { + var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(emergencyAccessId); + if (emergencyAccess == null) + { + throw new BadRequestException("Emergency Access not valid."); + } + + if (!_dataProtectorTokenizer.TryUnprotect(token, out var data) && data.IsValid(emergencyAccessId, user.Email)) + { + throw new BadRequestException("Invalid token."); + } + + if (emergencyAccess.Status == EmergencyAccessStatusType.Accepted) + { + throw new BadRequestException("Invitation already accepted. You will receive an email when the grantor confirms you as an emergency access contact."); + } + else if (emergencyAccess.Status != EmergencyAccessStatusType.Invited) + { + throw new BadRequestException("Invitation already accepted."); + } + + if (string.IsNullOrWhiteSpace(emergencyAccess.Email) || + !emergencyAccess.Email.Equals(user.Email, StringComparison.InvariantCultureIgnoreCase)) + { + throw new BadRequestException("User email does not match invite."); + } + + var granteeEmail = emergencyAccess.Email; + + emergencyAccess.Status = EmergencyAccessStatusType.Accepted; + emergencyAccess.GranteeId = user.Id; + emergencyAccess.Email = null; + + var grantor = await userService.GetUserByIdAsync(emergencyAccess.GrantorId); + + await _emergencyAccessRepository.ReplaceAsync(emergencyAccess); + await _mailService.SendEmergencyAccessAcceptedEmailAsync(granteeEmail, grantor.Email); + + return emergencyAccess; + } + + public async Task DeleteAsync(Guid emergencyAccessId, Guid grantorId) + { + var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(emergencyAccessId); + if (emergencyAccess == null || (emergencyAccess.GrantorId != grantorId && emergencyAccess.GranteeId != grantorId)) + { + throw new BadRequestException("Emergency Access not valid."); + } + + await _emergencyAccessRepository.DeleteAsync(emergencyAccess); + } + + public async Task ConfirmUserAsync(Guid emergencyAcccessId, string key, Guid confirmingUserId) + { + var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(emergencyAcccessId); + if (emergencyAccess == null || emergencyAccess.Status != EmergencyAccessStatusType.Accepted || + emergencyAccess.GrantorId != confirmingUserId) + { + throw new BadRequestException("Emergency Access not valid."); + } + + var grantor = await _userRepository.GetByIdAsync(confirmingUserId); + if (emergencyAccess.Type == EmergencyAccessType.Takeover && grantor.UsesKeyConnector) + { + throw new BadRequestException("You cannot use Emergency Access Takeover because you are using Key Connector."); + } + + var grantee = await _userRepository.GetByIdAsync(emergencyAccess.GranteeId.Value); + + emergencyAccess.Status = EmergencyAccessStatusType.Confirmed; + emergencyAccess.KeyEncrypted = key; + emergencyAccess.Email = null; + await _emergencyAccessRepository.ReplaceAsync(emergencyAccess); + await _mailService.SendEmergencyAccessConfirmedEmailAsync(NameOrEmail(grantor), grantee.Email); + + return emergencyAccess; + } + + public async Task SaveAsync(EmergencyAccess emergencyAccess, User savingUser) + { + if (!await _userService.CanAccessPremium(savingUser)) + { + throw new BadRequestException("Not a premium user."); + } + + if (emergencyAccess.GrantorId != savingUser.Id) + { + throw new BadRequestException("Emergency Access not valid."); + } + + if (emergencyAccess.Type == EmergencyAccessType.Takeover) + { + var grantor = await _userService.GetUserByIdAsync(emergencyAccess.GrantorId); + if (grantor.UsesKeyConnector) + { + throw new BadRequestException("You cannot use Emergency Access Takeover because you are using Key Connector."); + } + } + + await _emergencyAccessRepository.ReplaceAsync(emergencyAccess); + } + + public async Task InitiateAsync(Guid id, User initiatingUser) + { + var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(id); + + if (emergencyAccess == null || emergencyAccess.GranteeId != initiatingUser.Id || + emergencyAccess.Status != EmergencyAccessStatusType.Confirmed) + { + throw new BadRequestException("Emergency Access not valid."); + } + + var grantor = await _userRepository.GetByIdAsync(emergencyAccess.GrantorId); + + if (emergencyAccess.Type == EmergencyAccessType.Takeover && grantor.UsesKeyConnector) + { + throw new BadRequestException("You cannot takeover an account that is using Key Connector."); + } + + var now = DateTime.UtcNow; + emergencyAccess.Status = EmergencyAccessStatusType.RecoveryInitiated; + emergencyAccess.RevisionDate = now; + emergencyAccess.RecoveryInitiatedDate = now; + emergencyAccess.LastNotificationDate = now; + await _emergencyAccessRepository.ReplaceAsync(emergencyAccess); + + await _mailService.SendEmergencyAccessRecoveryInitiated(emergencyAccess, NameOrEmail(initiatingUser), grantor.Email); + } + + public async Task ApproveAsync(Guid id, User approvingUser) + { + var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(id); + + if (emergencyAccess == null || emergencyAccess.GrantorId != approvingUser.Id || + emergencyAccess.Status != EmergencyAccessStatusType.RecoveryInitiated) + { + throw new BadRequestException("Emergency Access not valid."); + } + + emergencyAccess.Status = EmergencyAccessStatusType.RecoveryApproved; + await _emergencyAccessRepository.ReplaceAsync(emergencyAccess); + + var grantee = await _userRepository.GetByIdAsync(emergencyAccess.GranteeId.Value); + await _mailService.SendEmergencyAccessRecoveryApproved(emergencyAccess, NameOrEmail(approvingUser), grantee.Email); + } + + public async Task RejectAsync(Guid id, User rejectingUser) + { + var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(id); + + if (emergencyAccess == null || emergencyAccess.GrantorId != rejectingUser.Id || + (emergencyAccess.Status != EmergencyAccessStatusType.RecoveryInitiated && + emergencyAccess.Status != EmergencyAccessStatusType.RecoveryApproved)) + { + throw new BadRequestException("Emergency Access not valid."); + } + + emergencyAccess.Status = EmergencyAccessStatusType.Confirmed; + await _emergencyAccessRepository.ReplaceAsync(emergencyAccess); + + var grantee = await _userRepository.GetByIdAsync(emergencyAccess.GranteeId.Value); + await _mailService.SendEmergencyAccessRecoveryRejected(emergencyAccess, NameOrEmail(rejectingUser), grantee.Email); + } + + public async Task> GetPoliciesAsync(Guid id, User requestingUser) + { + var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(id); + + if (!IsValidRequest(emergencyAccess, requestingUser, EmergencyAccessType.Takeover)) + { + throw new BadRequestException("Emergency Access not valid."); + } + + var grantor = await _userRepository.GetByIdAsync(emergencyAccess.GrantorId); + + var grantorOrganizations = await _organizationUserRepository.GetManyByUserAsync(grantor.Id); + var isOrganizationOwner = grantorOrganizations.Any(organization => organization.Type == OrganizationUserType.Owner); + var policies = isOrganizationOwner ? await _policyRepository.GetManyByUserIdAsync(grantor.Id) : null; + + return policies; + } + + public async Task<(EmergencyAccess, User)> TakeoverAsync(Guid id, User requestingUser) + { + var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(id); + + if (!IsValidRequest(emergencyAccess, requestingUser, EmergencyAccessType.Takeover)) + { + throw new BadRequestException("Emergency Access not valid."); + } + + var grantor = await _userRepository.GetByIdAsync(emergencyAccess.GrantorId); + + if (emergencyAccess.Type == EmergencyAccessType.Takeover && grantor.UsesKeyConnector) + { + throw new BadRequestException("You cannot takeover an account that is using Key Connector."); + } + + return (emergencyAccess, grantor); + } + + public async Task PasswordAsync(Guid id, User requestingUser, string newMasterPasswordHash, string key) + { + var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(id); + + if (!IsValidRequest(emergencyAccess, requestingUser, EmergencyAccessType.Takeover)) + { + throw new BadRequestException("Emergency Access not valid."); + } + + var grantor = await _userRepository.GetByIdAsync(emergencyAccess.GrantorId); + + grantor.MasterPassword = _passwordHasher.HashPassword(grantor, newMasterPasswordHash); + grantor.Key = key; + // Disable TwoFactor providers since they will otherwise block logins + grantor.SetTwoFactorProviders(new Dictionary()); + grantor.UnknownDeviceVerificationEnabled = false; + await _userRepository.ReplaceAsync(grantor); + + // Remove grantor from all organizations unless Owner + var orgUser = await _organizationUserRepository.GetManyByUserAsync(grantor.Id); + foreach (var o in orgUser) + { + if (o.Type != OrganizationUserType.Owner) + { + await _organizationService.DeleteUserAsync(o.OrganizationId, grantor.Id); + } } } - } - public async Task SendNotificationsAsync() - { - var toNotify = await _emergencyAccessRepository.GetManyToNotifyAsync(); - - foreach (var notify in toNotify) + public async Task SendNotificationsAsync() { - var ea = notify.ToEmergencyAccess(); - ea.LastNotificationDate = DateTime.UtcNow; - await _emergencyAccessRepository.ReplaceAsync(ea); + var toNotify = await _emergencyAccessRepository.GetManyToNotifyAsync(); - var granteeNameOrEmail = string.IsNullOrWhiteSpace(notify.GranteeName) ? notify.GranteeEmail : notify.GranteeName; + foreach (var notify in toNotify) + { + var ea = notify.ToEmergencyAccess(); + ea.LastNotificationDate = DateTime.UtcNow; + await _emergencyAccessRepository.ReplaceAsync(ea); - await _mailService.SendEmergencyAccessRecoveryReminder(ea, granteeNameOrEmail, notify.GrantorEmail); - } - } + var granteeNameOrEmail = string.IsNullOrWhiteSpace(notify.GranteeName) ? notify.GranteeEmail : notify.GranteeName; - public async Task HandleTimedOutRequestsAsync() - { - var expired = await _emergencyAccessRepository.GetExpiredRecoveriesAsync(); - - foreach (var details in expired) - { - var ea = details.ToEmergencyAccess(); - ea.Status = EmergencyAccessStatusType.RecoveryApproved; - await _emergencyAccessRepository.ReplaceAsync(ea); - - var grantorNameOrEmail = string.IsNullOrWhiteSpace(details.GrantorName) ? details.GrantorEmail : details.GrantorName; - var granteeNameOrEmail = string.IsNullOrWhiteSpace(details.GranteeName) ? details.GranteeEmail : details.GranteeName; - - await _mailService.SendEmergencyAccessRecoveryApproved(ea, grantorNameOrEmail, details.GranteeEmail); - await _mailService.SendEmergencyAccessRecoveryTimedOut(ea, granteeNameOrEmail, details.GrantorEmail); - } - } - - public async Task ViewAsync(Guid id, User requestingUser) - { - var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(id); - - if (!IsValidRequest(emergencyAccess, requestingUser, EmergencyAccessType.View)) - { - throw new BadRequestException("Emergency Access not valid."); + await _mailService.SendEmergencyAccessRecoveryReminder(ea, granteeNameOrEmail, notify.GrantorEmail); + } } - var ciphers = await _cipherRepository.GetManyByUserIdAsync(emergencyAccess.GrantorId, false); - - return new EmergencyAccessViewData + public async Task HandleTimedOutRequestsAsync() { - EmergencyAccess = emergencyAccess, - Ciphers = ciphers, - }; - } + var expired = await _emergencyAccessRepository.GetExpiredRecoveriesAsync(); - public async Task GetAttachmentDownloadAsync(Guid id, Guid cipherId, string attachmentId, User requestingUser) - { - var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(id); + foreach (var details in expired) + { + var ea = details.ToEmergencyAccess(); + ea.Status = EmergencyAccessStatusType.RecoveryApproved; + await _emergencyAccessRepository.ReplaceAsync(ea); - if (!IsValidRequest(emergencyAccess, requestingUser, EmergencyAccessType.View)) - { - throw new BadRequestException("Emergency Access not valid."); + var grantorNameOrEmail = string.IsNullOrWhiteSpace(details.GrantorName) ? details.GrantorEmail : details.GrantorName; + var granteeNameOrEmail = string.IsNullOrWhiteSpace(details.GranteeName) ? details.GranteeEmail : details.GranteeName; + + await _mailService.SendEmergencyAccessRecoveryApproved(ea, grantorNameOrEmail, details.GranteeEmail); + await _mailService.SendEmergencyAccessRecoveryTimedOut(ea, granteeNameOrEmail, details.GrantorEmail); + } } - var cipher = await _cipherRepository.GetByIdAsync(cipherId, emergencyAccess.GrantorId); - return await _cipherService.GetAttachmentDownloadDataAsync(cipher, attachmentId); - } + public async Task ViewAsync(Guid id, User requestingUser) + { + var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(id); - private async Task SendInviteAsync(EmergencyAccess emergencyAccess, string invitingUsersName) - { - var token = _dataProtectorTokenizer.Protect(new EmergencyAccessInviteTokenable(emergencyAccess, _globalSettings.OrganizationInviteExpirationHours)); - await _mailService.SendEmergencyAccessInviteEmailAsync(emergencyAccess, invitingUsersName, token); - } + if (!IsValidRequest(emergencyAccess, requestingUser, EmergencyAccessType.View)) + { + throw new BadRequestException("Emergency Access not valid."); + } - private string NameOrEmail(User user) - { - return string.IsNullOrWhiteSpace(user.Name) ? user.Email : user.Name; - } + var ciphers = await _cipherRepository.GetManyByUserIdAsync(emergencyAccess.GrantorId, false); - private bool IsValidRequest(EmergencyAccess availibleAccess, User requestingUser, EmergencyAccessType requestedAccessType) - { - return availibleAccess != null && - availibleAccess.GranteeId == requestingUser.Id && - availibleAccess.Status == EmergencyAccessStatusType.RecoveryApproved && - availibleAccess.Type == requestedAccessType; + return new EmergencyAccessViewData + { + EmergencyAccess = emergencyAccess, + Ciphers = ciphers, + }; + } + + public async Task GetAttachmentDownloadAsync(Guid id, Guid cipherId, string attachmentId, User requestingUser) + { + var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(id); + + if (!IsValidRequest(emergencyAccess, requestingUser, EmergencyAccessType.View)) + { + throw new BadRequestException("Emergency Access not valid."); + } + + var cipher = await _cipherRepository.GetByIdAsync(cipherId, emergencyAccess.GrantorId); + return await _cipherService.GetAttachmentDownloadDataAsync(cipher, attachmentId); + } + + private async Task SendInviteAsync(EmergencyAccess emergencyAccess, string invitingUsersName) + { + var token = _dataProtectorTokenizer.Protect(new EmergencyAccessInviteTokenable(emergencyAccess, _globalSettings.OrganizationInviteExpirationHours)); + await _mailService.SendEmergencyAccessInviteEmailAsync(emergencyAccess, invitingUsersName, token); + } + + private string NameOrEmail(User user) + { + return string.IsNullOrWhiteSpace(user.Name) ? user.Email : user.Name; + } + + private bool IsValidRequest(EmergencyAccess availibleAccess, User requestingUser, EmergencyAccessType requestedAccessType) + { + return availibleAccess != null && + availibleAccess.GranteeId == requestingUser.Id && + availibleAccess.Status == EmergencyAccessStatusType.RecoveryApproved && + availibleAccess.Type == requestedAccessType; + } } } diff --git a/src/Core/Services/Implementations/EventService.cs b/src/Core/Services/Implementations/EventService.cs index 18a4b19cf1..3a35555d1b 100644 --- a/src/Core/Services/Implementations/EventService.cs +++ b/src/Core/Services/Implementations/EventService.cs @@ -7,321 +7,322 @@ using Bit.Core.Models.Data.Organizations; using Bit.Core.Repositories; using Bit.Core.Settings; -namespace Bit.Core.Services; - -public class EventService : IEventService +namespace Bit.Core.Services { - private readonly IEventWriteService _eventWriteService; - private readonly IOrganizationUserRepository _organizationUserRepository; - private readonly IProviderUserRepository _providerUserRepository; - private readonly IApplicationCacheService _applicationCacheService; - private readonly ICurrentContext _currentContext; - private readonly GlobalSettings _globalSettings; - - public EventService( - IEventWriteService eventWriteService, - IOrganizationUserRepository organizationUserRepository, - IProviderUserRepository providerUserRepository, - IApplicationCacheService applicationCacheService, - ICurrentContext currentContext, - GlobalSettings globalSettings) + public class EventService : IEventService { - _eventWriteService = eventWriteService; - _organizationUserRepository = organizationUserRepository; - _providerUserRepository = providerUserRepository; - _applicationCacheService = applicationCacheService; - _currentContext = currentContext; - _globalSettings = globalSettings; - } + private readonly IEventWriteService _eventWriteService; + private readonly IOrganizationUserRepository _organizationUserRepository; + private readonly IProviderUserRepository _providerUserRepository; + private readonly IApplicationCacheService _applicationCacheService; + private readonly ICurrentContext _currentContext; + private readonly GlobalSettings _globalSettings; - public async Task LogUserEventAsync(Guid userId, EventType type, DateTime? date = null) - { - var events = new List + public EventService( + IEventWriteService eventWriteService, + IOrganizationUserRepository organizationUserRepository, + IProviderUserRepository providerUserRepository, + IApplicationCacheService applicationCacheService, + ICurrentContext currentContext, + GlobalSettings globalSettings) { - new EventMessage(_currentContext) + _eventWriteService = eventWriteService; + _organizationUserRepository = organizationUserRepository; + _providerUserRepository = providerUserRepository; + _applicationCacheService = applicationCacheService; + _currentContext = currentContext; + _globalSettings = globalSettings; + } + + public async Task LogUserEventAsync(Guid userId, EventType type, DateTime? date = null) + { + var events = new List { - UserId = userId, - ActingUserId = userId, - Type = type, - Date = date.GetValueOrDefault(DateTime.UtcNow) + new EventMessage(_currentContext) + { + UserId = userId, + ActingUserId = userId, + Type = type, + Date = date.GetValueOrDefault(DateTime.UtcNow) + } + }; + + var orgAbilities = await _applicationCacheService.GetOrganizationAbilitiesAsync(); + var orgs = await _currentContext.OrganizationMembershipAsync(_organizationUserRepository, userId); + var orgEvents = orgs.Where(o => CanUseEvents(orgAbilities, o.Id)) + .Select(o => new EventMessage(_currentContext) + { + OrganizationId = o.Id, + UserId = userId, + ActingUserId = userId, + Type = type, + Date = DateTime.UtcNow + }); + + var providerAbilities = await _applicationCacheService.GetProviderAbilitiesAsync(); + var providers = await _currentContext.ProviderMembershipAsync(_providerUserRepository, userId); + var providerEvents = providers.Where(o => CanUseProviderEvents(providerAbilities, o.Id)) + .Select(p => new EventMessage(_currentContext) + { + ProviderId = p.Id, + UserId = userId, + ActingUserId = userId, + Type = type, + Date = DateTime.UtcNow + }); + + if (orgEvents.Any() || providerEvents.Any()) + { + events.AddRange(orgEvents); + events.AddRange(providerEvents); + await _eventWriteService.CreateManyAsync(events); } - }; - - var orgAbilities = await _applicationCacheService.GetOrganizationAbilitiesAsync(); - var orgs = await _currentContext.OrganizationMembershipAsync(_organizationUserRepository, userId); - var orgEvents = orgs.Where(o => CanUseEvents(orgAbilities, o.Id)) - .Select(o => new EventMessage(_currentContext) + else { - OrganizationId = o.Id, - UserId = userId, - ActingUserId = userId, - Type = type, - Date = DateTime.UtcNow - }); - - var providerAbilities = await _applicationCacheService.GetProviderAbilitiesAsync(); - var providers = await _currentContext.ProviderMembershipAsync(_providerUserRepository, userId); - var providerEvents = providers.Where(o => CanUseProviderEvents(providerAbilities, o.Id)) - .Select(p => new EventMessage(_currentContext) - { - ProviderId = p.Id, - UserId = userId, - ActingUserId = userId, - Type = type, - Date = DateTime.UtcNow - }); - - if (orgEvents.Any() || providerEvents.Any()) - { - events.AddRange(orgEvents); - events.AddRange(providerEvents); - await _eventWriteService.CreateManyAsync(events); + await _eventWriteService.CreateAsync(events.First()); + } } - else - { - await _eventWriteService.CreateAsync(events.First()); - } - } - public async Task LogCipherEventAsync(Cipher cipher, EventType type, DateTime? date = null) - { - var e = await BuildCipherEventMessageAsync(cipher, type, date); - if (e != null) + public async Task LogCipherEventAsync(Cipher cipher, EventType type, DateTime? date = null) { - await _eventWriteService.CreateAsync(e); - } - } - - public async Task LogCipherEventsAsync(IEnumerable> events) - { - var cipherEvents = new List(); - foreach (var ev in events) - { - var e = await BuildCipherEventMessageAsync(ev.Item1, ev.Item2, ev.Item3); + var e = await BuildCipherEventMessageAsync(cipher, type, date); if (e != null) { - cipherEvents.Add(e); + await _eventWriteService.CreateAsync(e); } } - await _eventWriteService.CreateManyAsync(cipherEvents); - } - private async Task BuildCipherEventMessageAsync(Cipher cipher, EventType type, DateTime? date = null) - { - // Only logging organization cipher events for now. - if (!cipher.OrganizationId.HasValue || (!_currentContext?.UserId.HasValue ?? true)) + public async Task LogCipherEventsAsync(IEnumerable> events) { - return null; + var cipherEvents = new List(); + foreach (var ev in events) + { + var e = await BuildCipherEventMessageAsync(ev.Item1, ev.Item2, ev.Item3); + if (e != null) + { + cipherEvents.Add(e); + } + } + await _eventWriteService.CreateManyAsync(cipherEvents); } - if (cipher.OrganizationId.HasValue) + private async Task BuildCipherEventMessageAsync(Cipher cipher, EventType type, DateTime? date = null) { - var orgAbilities = await _applicationCacheService.GetOrganizationAbilitiesAsync(); - if (!CanUseEvents(orgAbilities, cipher.OrganizationId.Value)) + // Only logging organization cipher events for now. + if (!cipher.OrganizationId.HasValue || (!_currentContext?.UserId.HasValue ?? true)) { return null; } - } - return new EventMessage(_currentContext) - { - OrganizationId = cipher.OrganizationId, - UserId = cipher.OrganizationId.HasValue ? null : cipher.UserId, - CipherId = cipher.Id, - Type = type, - ActingUserId = _currentContext?.UserId, - ProviderId = await GetProviderIdAsync(cipher.OrganizationId), - Date = date.GetValueOrDefault(DateTime.UtcNow) - }; - } - - public async Task LogCollectionEventAsync(Collection collection, EventType type, DateTime? date = null) - { - var orgAbilities = await _applicationCacheService.GetOrganizationAbilitiesAsync(); - if (!CanUseEvents(orgAbilities, collection.OrganizationId)) - { - return; - } - - var e = new EventMessage(_currentContext) - { - OrganizationId = collection.OrganizationId, - CollectionId = collection.Id, - Type = type, - ActingUserId = _currentContext?.UserId, - ProviderId = await GetProviderIdAsync(collection.OrganizationId), - Date = date.GetValueOrDefault(DateTime.UtcNow) - }; - await _eventWriteService.CreateAsync(e); - } - - public async Task LogGroupEventAsync(Group group, EventType type, DateTime? date = null) - { - var orgAbilities = await _applicationCacheService.GetOrganizationAbilitiesAsync(); - if (!CanUseEvents(orgAbilities, group.OrganizationId)) - { - return; - } - - var e = new EventMessage(_currentContext) - { - OrganizationId = group.OrganizationId, - GroupId = group.Id, - Type = type, - ActingUserId = _currentContext?.UserId, - ProviderId = await GetProviderIdAsync(@group.OrganizationId), - Date = date.GetValueOrDefault(DateTime.UtcNow) - }; - await _eventWriteService.CreateAsync(e); - } - - public async Task LogPolicyEventAsync(Policy policy, EventType type, DateTime? date = null) - { - var orgAbilities = await _applicationCacheService.GetOrganizationAbilitiesAsync(); - if (!CanUseEvents(orgAbilities, policy.OrganizationId)) - { - return; - } - - var e = new EventMessage(_currentContext) - { - OrganizationId = policy.OrganizationId, - PolicyId = policy.Id, - Type = type, - ActingUserId = _currentContext?.UserId, - ProviderId = await GetProviderIdAsync(policy.OrganizationId), - Date = date.GetValueOrDefault(DateTime.UtcNow) - }; - await _eventWriteService.CreateAsync(e); - } - - public async Task LogOrganizationUserEventAsync(OrganizationUser organizationUser, EventType type, - DateTime? date = null) => - await LogOrganizationUserEventsAsync(new[] { (organizationUser, type, date) }); - - public async Task LogOrganizationUserEventsAsync(IEnumerable<(OrganizationUser, EventType, DateTime?)> events) - { - var orgAbilities = await _applicationCacheService.GetOrganizationAbilitiesAsync(); - var eventMessages = new List(); - foreach (var (organizationUser, type, date) in events) - { - if (!CanUseEvents(orgAbilities, organizationUser.OrganizationId)) + if (cipher.OrganizationId.HasValue) { - continue; + var orgAbilities = await _applicationCacheService.GetOrganizationAbilitiesAsync(); + if (!CanUseEvents(orgAbilities, cipher.OrganizationId.Value)) + { + return null; + } } - eventMessages.Add(new EventMessage(_currentContext) + return new EventMessage(_currentContext) { - OrganizationId = organizationUser.OrganizationId, - UserId = organizationUser.UserId, - OrganizationUserId = organizationUser.Id, - ProviderId = await GetProviderIdAsync(organizationUser.OrganizationId), + OrganizationId = cipher.OrganizationId, + UserId = cipher.OrganizationId.HasValue ? null : cipher.UserId, + CipherId = cipher.Id, + Type = type, + ActingUserId = _currentContext?.UserId, + ProviderId = await GetProviderIdAsync(cipher.OrganizationId), + Date = date.GetValueOrDefault(DateTime.UtcNow) + }; + } + + public async Task LogCollectionEventAsync(Collection collection, EventType type, DateTime? date = null) + { + var orgAbilities = await _applicationCacheService.GetOrganizationAbilitiesAsync(); + if (!CanUseEvents(orgAbilities, collection.OrganizationId)) + { + return; + } + + var e = new EventMessage(_currentContext) + { + OrganizationId = collection.OrganizationId, + CollectionId = collection.Id, + Type = type, + ActingUserId = _currentContext?.UserId, + ProviderId = await GetProviderIdAsync(collection.OrganizationId), + Date = date.GetValueOrDefault(DateTime.UtcNow) + }; + await _eventWriteService.CreateAsync(e); + } + + public async Task LogGroupEventAsync(Group group, EventType type, DateTime? date = null) + { + var orgAbilities = await _applicationCacheService.GetOrganizationAbilitiesAsync(); + if (!CanUseEvents(orgAbilities, group.OrganizationId)) + { + return; + } + + var e = new EventMessage(_currentContext) + { + OrganizationId = group.OrganizationId, + GroupId = group.Id, + Type = type, + ActingUserId = _currentContext?.UserId, + ProviderId = await GetProviderIdAsync(@group.OrganizationId), + Date = date.GetValueOrDefault(DateTime.UtcNow) + }; + await _eventWriteService.CreateAsync(e); + } + + public async Task LogPolicyEventAsync(Policy policy, EventType type, DateTime? date = null) + { + var orgAbilities = await _applicationCacheService.GetOrganizationAbilitiesAsync(); + if (!CanUseEvents(orgAbilities, policy.OrganizationId)) + { + return; + } + + var e = new EventMessage(_currentContext) + { + OrganizationId = policy.OrganizationId, + PolicyId = policy.Id, + Type = type, + ActingUserId = _currentContext?.UserId, + ProviderId = await GetProviderIdAsync(policy.OrganizationId), + Date = date.GetValueOrDefault(DateTime.UtcNow) + }; + await _eventWriteService.CreateAsync(e); + } + + public async Task LogOrganizationUserEventAsync(OrganizationUser organizationUser, EventType type, + DateTime? date = null) => + await LogOrganizationUserEventsAsync(new[] { (organizationUser, type, date) }); + + public async Task LogOrganizationUserEventsAsync(IEnumerable<(OrganizationUser, EventType, DateTime?)> events) + { + var orgAbilities = await _applicationCacheService.GetOrganizationAbilitiesAsync(); + var eventMessages = new List(); + foreach (var (organizationUser, type, date) in events) + { + if (!CanUseEvents(orgAbilities, organizationUser.OrganizationId)) + { + continue; + } + + eventMessages.Add(new EventMessage(_currentContext) + { + OrganizationId = organizationUser.OrganizationId, + UserId = organizationUser.UserId, + OrganizationUserId = organizationUser.Id, + ProviderId = await GetProviderIdAsync(organizationUser.OrganizationId), + Type = type, + ActingUserId = _currentContext?.UserId, + Date = date.GetValueOrDefault(DateTime.UtcNow) + }); + } + + await _eventWriteService.CreateManyAsync(eventMessages); + } + + public async Task LogOrganizationEventAsync(Organization organization, EventType type, DateTime? date = null) + { + if (!organization.Enabled || !organization.UseEvents) + { + return; + } + + var e = new EventMessage(_currentContext) + { + OrganizationId = organization.Id, + ProviderId = await GetProviderIdAsync(organization.Id), + Type = type, + ActingUserId = _currentContext?.UserId, + Date = date.GetValueOrDefault(DateTime.UtcNow), + InstallationId = GetInstallationId(), + }; + await _eventWriteService.CreateAsync(e); + } + + public async Task LogProviderUserEventAsync(ProviderUser providerUser, EventType type, DateTime? date = null) + { + await LogProviderUsersEventAsync(new[] { (providerUser, type, date) }); + } + + public async Task LogProviderUsersEventAsync(IEnumerable<(ProviderUser, EventType, DateTime?)> events) + { + var providerAbilities = await _applicationCacheService.GetProviderAbilitiesAsync(); + var eventMessages = new List(); + foreach (var (providerUser, type, date) in events) + { + if (!CanUseProviderEvents(providerAbilities, providerUser.ProviderId)) + { + continue; + } + eventMessages.Add(new EventMessage(_currentContext) + { + ProviderId = providerUser.ProviderId, + UserId = providerUser.UserId, + ProviderUserId = providerUser.Id, + Type = type, + ActingUserId = _currentContext?.UserId, + Date = date.GetValueOrDefault(DateTime.UtcNow) + }); + } + + await _eventWriteService.CreateManyAsync(eventMessages); + } + + public async Task LogProviderOrganizationEventAsync(ProviderOrganization providerOrganization, EventType type, + DateTime? date = null) + { + var providerAbilities = await _applicationCacheService.GetProviderAbilitiesAsync(); + if (!CanUseProviderEvents(providerAbilities, providerOrganization.ProviderId)) + { + return; + } + + var e = new EventMessage(_currentContext) + { + ProviderId = providerOrganization.ProviderId, + ProviderOrganizationId = providerOrganization.Id, Type = type, ActingUserId = _currentContext?.UserId, Date = date.GetValueOrDefault(DateTime.UtcNow) - }); + }; + await _eventWriteService.CreateAsync(e); } - await _eventWriteService.CreateManyAsync(eventMessages); - } - - public async Task LogOrganizationEventAsync(Organization organization, EventType type, DateTime? date = null) - { - if (!organization.Enabled || !organization.UseEvents) + private async Task GetProviderIdAsync(Guid? orgId) { - return; - } - - var e = new EventMessage(_currentContext) - { - OrganizationId = organization.Id, - ProviderId = await GetProviderIdAsync(organization.Id), - Type = type, - ActingUserId = _currentContext?.UserId, - Date = date.GetValueOrDefault(DateTime.UtcNow), - InstallationId = GetInstallationId(), - }; - await _eventWriteService.CreateAsync(e); - } - - public async Task LogProviderUserEventAsync(ProviderUser providerUser, EventType type, DateTime? date = null) - { - await LogProviderUsersEventAsync(new[] { (providerUser, type, date) }); - } - - public async Task LogProviderUsersEventAsync(IEnumerable<(ProviderUser, EventType, DateTime?)> events) - { - var providerAbilities = await _applicationCacheService.GetProviderAbilitiesAsync(); - var eventMessages = new List(); - foreach (var (providerUser, type, date) in events) - { - if (!CanUseProviderEvents(providerAbilities, providerUser.ProviderId)) + if (_currentContext == null || !orgId.HasValue) { - continue; + return null; } - eventMessages.Add(new EventMessage(_currentContext) + + return await _currentContext.ProviderIdForOrg(orgId.Value); + } + + private Guid? GetInstallationId() + { + if (_currentContext == null) { - ProviderId = providerUser.ProviderId, - UserId = providerUser.UserId, - ProviderUserId = providerUser.Id, - Type = type, - ActingUserId = _currentContext?.UserId, - Date = date.GetValueOrDefault(DateTime.UtcNow) - }); + return null; + } + + return _currentContext.InstallationId; } - await _eventWriteService.CreateManyAsync(eventMessages); - } - - public async Task LogProviderOrganizationEventAsync(ProviderOrganization providerOrganization, EventType type, - DateTime? date = null) - { - var providerAbilities = await _applicationCacheService.GetProviderAbilitiesAsync(); - if (!CanUseProviderEvents(providerAbilities, providerOrganization.ProviderId)) + private bool CanUseEvents(IDictionary orgAbilities, Guid orgId) { - return; + return orgAbilities != null && orgAbilities.ContainsKey(orgId) && + orgAbilities[orgId].Enabled && orgAbilities[orgId].UseEvents; } - var e = new EventMessage(_currentContext) + private bool CanUseProviderEvents(IDictionary providerAbilities, Guid providerId) { - ProviderId = providerOrganization.ProviderId, - ProviderOrganizationId = providerOrganization.Id, - Type = type, - ActingUserId = _currentContext?.UserId, - Date = date.GetValueOrDefault(DateTime.UtcNow) - }; - await _eventWriteService.CreateAsync(e); - } - - private async Task GetProviderIdAsync(Guid? orgId) - { - if (_currentContext == null || !orgId.HasValue) - { - return null; + return providerAbilities != null && providerAbilities.ContainsKey(providerId) && + providerAbilities[providerId].Enabled && providerAbilities[providerId].UseEvents; } - - return await _currentContext.ProviderIdForOrg(orgId.Value); - } - - private Guid? GetInstallationId() - { - if (_currentContext == null) - { - return null; - } - - return _currentContext.InstallationId; - } - - private bool CanUseEvents(IDictionary orgAbilities, Guid orgId) - { - return orgAbilities != null && orgAbilities.ContainsKey(orgId) && - orgAbilities[orgId].Enabled && orgAbilities[orgId].UseEvents; - } - - private bool CanUseProviderEvents(IDictionary providerAbilities, Guid providerId) - { - return providerAbilities != null && providerAbilities.ContainsKey(providerId) && - providerAbilities[providerId].Enabled && providerAbilities[providerId].UseEvents; } } diff --git a/src/Core/Services/Implementations/GroupService.cs b/src/Core/Services/Implementations/GroupService.cs index c637fd0ce0..3d7872f707 100644 --- a/src/Core/Services/Implementations/GroupService.cs +++ b/src/Core/Services/Implementations/GroupService.cs @@ -5,81 +5,82 @@ using Bit.Core.Models.Business; using Bit.Core.Models.Data; using Bit.Core.Repositories; -namespace Bit.Core.Services; - -public class GroupService : IGroupService +namespace Bit.Core.Services { - private readonly IEventService _eventService; - private readonly IOrganizationRepository _organizationRepository; - private readonly IOrganizationUserRepository _organizationUserRepository; - private readonly IGroupRepository _groupRepository; - private readonly IReferenceEventService _referenceEventService; - - public GroupService( - IEventService eventService, - IOrganizationRepository organizationRepository, - IOrganizationUserRepository organizationUserRepository, - IGroupRepository groupRepository, - IReferenceEventService referenceEventService) + public class GroupService : IGroupService { - _eventService = eventService; - _organizationRepository = organizationRepository; - _organizationUserRepository = organizationUserRepository; - _groupRepository = groupRepository; - _referenceEventService = referenceEventService; - } + private readonly IEventService _eventService; + private readonly IOrganizationRepository _organizationRepository; + private readonly IOrganizationUserRepository _organizationUserRepository; + private readonly IGroupRepository _groupRepository; + private readonly IReferenceEventService _referenceEventService; - public async Task SaveAsync(Group group, IEnumerable collections = null) - { - var org = await _organizationRepository.GetByIdAsync(group.OrganizationId); - if (org == null) + public GroupService( + IEventService eventService, + IOrganizationRepository organizationRepository, + IOrganizationUserRepository organizationUserRepository, + IGroupRepository groupRepository, + IReferenceEventService referenceEventService) { - throw new BadRequestException("Organization not found"); + _eventService = eventService; + _organizationRepository = organizationRepository; + _organizationUserRepository = organizationUserRepository; + _groupRepository = groupRepository; + _referenceEventService = referenceEventService; } - if (!org.UseGroups) + public async Task SaveAsync(Group group, IEnumerable collections = null) { - throw new BadRequestException("This organization cannot use groups."); - } - - if (group.Id == default(Guid)) - { - group.CreationDate = group.RevisionDate = DateTime.UtcNow; - - if (collections == null) + var org = await _organizationRepository.GetByIdAsync(group.OrganizationId); + if (org == null) { - await _groupRepository.CreateAsync(group); + throw new BadRequestException("Organization not found"); + } + + if (!org.UseGroups) + { + throw new BadRequestException("This organization cannot use groups."); + } + + if (group.Id == default(Guid)) + { + group.CreationDate = group.RevisionDate = DateTime.UtcNow; + + if (collections == null) + { + await _groupRepository.CreateAsync(group); + } + else + { + await _groupRepository.CreateAsync(group, collections); + } + + await _eventService.LogGroupEventAsync(group, Enums.EventType.Group_Created); + await _referenceEventService.RaiseEventAsync(new ReferenceEvent(ReferenceEventType.GroupCreated, org)); } else { - await _groupRepository.CreateAsync(group, collections); + group.RevisionDate = DateTime.UtcNow; + await _groupRepository.ReplaceAsync(group, collections ?? new List()); + await _eventService.LogGroupEventAsync(group, Enums.EventType.Group_Updated); } - - await _eventService.LogGroupEventAsync(group, Enums.EventType.Group_Created); - await _referenceEventService.RaiseEventAsync(new ReferenceEvent(ReferenceEventType.GroupCreated, org)); } - else + + public async Task DeleteAsync(Group group) { - group.RevisionDate = DateTime.UtcNow; - await _groupRepository.ReplaceAsync(group, collections ?? new List()); - await _eventService.LogGroupEventAsync(group, Enums.EventType.Group_Updated); + await _groupRepository.DeleteAsync(group); + await _eventService.LogGroupEventAsync(group, Enums.EventType.Group_Deleted); } - } - public async Task DeleteAsync(Group group) - { - await _groupRepository.DeleteAsync(group); - await _eventService.LogGroupEventAsync(group, Enums.EventType.Group_Deleted); - } - - public async Task DeleteUserAsync(Group group, Guid organizationUserId) - { - var orgUser = await _organizationUserRepository.GetByIdAsync(organizationUserId); - if (orgUser == null || orgUser.OrganizationId != group.OrganizationId) + public async Task DeleteUserAsync(Group group, Guid organizationUserId) { - throw new NotFoundException(); + var orgUser = await _organizationUserRepository.GetByIdAsync(organizationUserId); + if (orgUser == null || orgUser.OrganizationId != group.OrganizationId) + { + throw new NotFoundException(); + } + await _groupRepository.DeleteUserAsync(group.Id, organizationUserId); + await _eventService.LogOrganizationUserEventAsync(orgUser, Enums.EventType.OrganizationUser_UpdatedGroups); } - await _groupRepository.DeleteUserAsync(group.Id, organizationUserId); - await _eventService.LogOrganizationUserEventAsync(orgUser, Enums.EventType.OrganizationUser_UpdatedGroups); } } diff --git a/src/Core/Services/Implementations/HCaptchaValidationService.cs b/src/Core/Services/Implementations/HCaptchaValidationService.cs index b8a63c642c..0b72d52865 100644 --- a/src/Core/Services/Implementations/HCaptchaValidationService.cs +++ b/src/Core/Services/Implementations/HCaptchaValidationService.cs @@ -8,124 +8,125 @@ using Bit.Core.Settings; using Bit.Core.Tokens; using Microsoft.Extensions.Logging; -namespace Bit.Core.Services; - -public class HCaptchaValidationService : ICaptchaValidationService +namespace Bit.Core.Services { - private readonly ILogger _logger; - private readonly IHttpClientFactory _httpClientFactory; - private readonly GlobalSettings _globalSettings; - private readonly IDataProtectorTokenFactory _tokenizer; - - public HCaptchaValidationService( - ILogger logger, - IHttpClientFactory httpClientFactory, - IDataProtectorTokenFactory tokenizer, - GlobalSettings globalSettings) + public class HCaptchaValidationService : ICaptchaValidationService { - _logger = logger; - _httpClientFactory = httpClientFactory; - _globalSettings = globalSettings; - _tokenizer = tokenizer; - } + private readonly ILogger _logger; + private readonly IHttpClientFactory _httpClientFactory; + private readonly GlobalSettings _globalSettings; + private readonly IDataProtectorTokenFactory _tokenizer; - public string SiteKeyResponseKeyName => "HCaptcha_SiteKey"; - public string SiteKey => _globalSettings.Captcha.HCaptchaSiteKey; - - public string GenerateCaptchaBypassToken(User user) => _tokenizer.Protect(new HCaptchaTokenable(user)); - - public async Task ValidateCaptchaResponseAsync(string captchaResponse, string clientIpAddress, - User user = null) - { - var response = new CaptchaResponse { Success = false }; - if (string.IsNullOrWhiteSpace(captchaResponse)) + public HCaptchaValidationService( + ILogger logger, + IHttpClientFactory httpClientFactory, + IDataProtectorTokenFactory tokenizer, + GlobalSettings globalSettings) { - return response; + _logger = logger; + _httpClientFactory = httpClientFactory; + _globalSettings = globalSettings; + _tokenizer = tokenizer; } - if (user != null && ValidateCaptchaBypassToken(captchaResponse, user)) - { - response.Success = true; - return response; - } + public string SiteKeyResponseKeyName => "HCaptcha_SiteKey"; + public string SiteKey => _globalSettings.Captcha.HCaptchaSiteKey; - var httpClient = _httpClientFactory.CreateClient("HCaptchaValidationService"); + public string GenerateCaptchaBypassToken(User user) => _tokenizer.Protect(new HCaptchaTokenable(user)); - var requestMessage = new HttpRequestMessage + public async Task ValidateCaptchaResponseAsync(string captchaResponse, string clientIpAddress, + User user = null) { - Method = HttpMethod.Post, - RequestUri = new Uri("https://hcaptcha.com/siteverify"), - Content = new FormUrlEncodedContent(new Dictionary + var response = new CaptchaResponse { Success = false }; + if (string.IsNullOrWhiteSpace(captchaResponse)) { - { "response", captchaResponse.TrimStart("hcaptcha|".ToCharArray()) }, - { "secret", _globalSettings.Captcha.HCaptchaSecretKey }, - { "sitekey", SiteKey }, - { "remoteip", clientIpAddress } - }) - }; + return response; + } - HttpResponseMessage responseMessage; - try - { - responseMessage = await httpClient.SendAsync(requestMessage); - } - catch (Exception e) - { - _logger.LogError(11389, e, "Unable to verify with HCaptcha."); + if (user != null && ValidateCaptchaBypassToken(captchaResponse, user)) + { + response.Success = true; + return response; + } + + var httpClient = _httpClientFactory.CreateClient("HCaptchaValidationService"); + + var requestMessage = new HttpRequestMessage + { + Method = HttpMethod.Post, + RequestUri = new Uri("https://hcaptcha.com/siteverify"), + Content = new FormUrlEncodedContent(new Dictionary + { + { "response", captchaResponse.TrimStart("hcaptcha|".ToCharArray()) }, + { "secret", _globalSettings.Captcha.HCaptchaSecretKey }, + { "sitekey", SiteKey }, + { "remoteip", clientIpAddress } + }) + }; + + HttpResponseMessage responseMessage; + try + { + responseMessage = await httpClient.SendAsync(requestMessage); + } + catch (Exception e) + { + _logger.LogError(11389, e, "Unable to verify with HCaptcha."); + return response; + } + + if (!responseMessage.IsSuccessStatusCode) + { + return response; + } + + using var hcaptchaResponse = await responseMessage.Content.ReadFromJsonAsync(); + response.Success = hcaptchaResponse.Success; + var score = hcaptchaResponse.Score.GetValueOrDefault(); + response.MaybeBot = score >= _globalSettings.Captcha.MaybeBotScoreThreshold; + response.IsBot = score >= _globalSettings.Captcha.IsBotScoreThreshold; + response.Score = score; return response; } - if (!responseMessage.IsSuccessStatusCode) + public bool RequireCaptchaValidation(ICurrentContext currentContext, User user = null) { - return response; + if (user == null) + { + return currentContext.IsBot || _globalSettings.Captcha.ForceCaptchaRequired; + } + + var failedLoginCeiling = _globalSettings.Captcha.MaximumFailedLoginAttempts; + var failedLoginCount = user?.FailedLoginCount ?? 0; + var cloudEmailUnverified = !_globalSettings.SelfHosted && !user.EmailVerified; + return currentContext.IsBot || + _globalSettings.Captcha.ForceCaptchaRequired || + cloudEmailUnverified || + failedLoginCeiling > 0 && failedLoginCount >= failedLoginCeiling; } - using var hcaptchaResponse = await responseMessage.Content.ReadFromJsonAsync(); - response.Success = hcaptchaResponse.Success; - var score = hcaptchaResponse.Score.GetValueOrDefault(); - response.MaybeBot = score >= _globalSettings.Captcha.MaybeBotScoreThreshold; - response.IsBot = score >= _globalSettings.Captcha.IsBotScoreThreshold; - response.Score = score; - return response; - } + private static bool TokenIsValidApiKey(string bypassToken, User user) => + !string.IsNullOrWhiteSpace(bypassToken) && user != null && user.ApiKey == bypassToken; - public bool RequireCaptchaValidation(ICurrentContext currentContext, User user = null) - { - if (user == null) + private bool TokenIsValidCaptchaBypassToken(string encryptedToken, User user) { - return currentContext.IsBot || _globalSettings.Captcha.ForceCaptchaRequired; + return _tokenizer.TryUnprotect(encryptedToken, out var data) && + data.Valid && data.TokenIsValid(user); } - var failedLoginCeiling = _globalSettings.Captcha.MaximumFailedLoginAttempts; - var failedLoginCount = user?.FailedLoginCount ?? 0; - var cloudEmailUnverified = !_globalSettings.SelfHosted && !user.EmailVerified; - return currentContext.IsBot || - _globalSettings.Captcha.ForceCaptchaRequired || - cloudEmailUnverified || - failedLoginCeiling > 0 && failedLoginCount >= failedLoginCeiling; - } + private bool ValidateCaptchaBypassToken(string bypassToken, User user) => + TokenIsValidApiKey(bypassToken, user) || TokenIsValidCaptchaBypassToken(bypassToken, user); - private static bool TokenIsValidApiKey(string bypassToken, User user) => - !string.IsNullOrWhiteSpace(bypassToken) && user != null && user.ApiKey == bypassToken; + public class HCaptchaResponse : IDisposable + { + [JsonPropertyName("success")] + public bool Success { get; set; } + [JsonPropertyName("score")] + public double? Score { get; set; } + [JsonPropertyName("score_reason")] + public List ScoreReason { get; set; } - private bool TokenIsValidCaptchaBypassToken(string encryptedToken, User user) - { - return _tokenizer.TryUnprotect(encryptedToken, out var data) && - data.Valid && data.TokenIsValid(user); - } - - private bool ValidateCaptchaBypassToken(string bypassToken, User user) => - TokenIsValidApiKey(bypassToken, user) || TokenIsValidCaptchaBypassToken(bypassToken, user); - - public class HCaptchaResponse : IDisposable - { - [JsonPropertyName("success")] - public bool Success { get; set; } - [JsonPropertyName("score")] - public double? Score { get; set; } - [JsonPropertyName("score_reason")] - public List ScoreReason { get; set; } - - public void Dispose() { } + public void Dispose() { } + } } } diff --git a/src/Core/Services/Implementations/HandlebarsMailService.cs b/src/Core/Services/Implementations/HandlebarsMailService.cs index 3688c8f153..a1cfb61cec 100644 --- a/src/Core/Services/Implementations/HandlebarsMailService.cs +++ b/src/Core/Services/Implementations/HandlebarsMailService.cs @@ -10,879 +10,880 @@ using Bit.Core.Settings; using Bit.Core.Utilities; using HandlebarsDotNet; -namespace Bit.Core.Services; - -public class HandlebarsMailService : IMailService +namespace Bit.Core.Services { - private const string Namespace = "Bit.Core.MailTemplates.Handlebars"; - - private readonly GlobalSettings _globalSettings; - private readonly IMailDeliveryService _mailDeliveryService; - private readonly IMailEnqueuingService _mailEnqueuingService; - private readonly Dictionary> _templateCache = - new Dictionary>(); - - private bool _registeredHelpersAndPartials = false; - - public HandlebarsMailService( - GlobalSettings globalSettings, - IMailDeliveryService mailDeliveryService, - IMailEnqueuingService mailEnqueuingService) + public class HandlebarsMailService : IMailService { - _globalSettings = globalSettings; - _mailDeliveryService = mailDeliveryService; - _mailEnqueuingService = mailEnqueuingService; - } + private const string Namespace = "Bit.Core.MailTemplates.Handlebars"; - public async Task SendVerifyEmailEmailAsync(string email, Guid userId, string token) - { - var message = CreateDefaultMessage("Verify Your Email", email); - var model = new VerifyEmailModel + private readonly GlobalSettings _globalSettings; + private readonly IMailDeliveryService _mailDeliveryService; + private readonly IMailEnqueuingService _mailEnqueuingService; + private readonly Dictionary> _templateCache = + new Dictionary>(); + + private bool _registeredHelpersAndPartials = false; + + public HandlebarsMailService( + GlobalSettings globalSettings, + IMailDeliveryService mailDeliveryService, + IMailEnqueuingService mailEnqueuingService) { - Token = WebUtility.UrlEncode(token), - UserId = userId, - WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, - SiteName = _globalSettings.SiteName - }; - await AddMessageContentAsync(message, "VerifyEmail", model); - message.MetaData.Add("SendGridBypassListManagement", true); - message.Category = "VerifyEmail"; - await _mailDeliveryService.SendEmailAsync(message); - } - - public async Task SendVerifyDeleteEmailAsync(string email, Guid userId, string token) - { - var message = CreateDefaultMessage("Delete Your Account", email); - var model = new VerifyDeleteModel - { - Token = WebUtility.UrlEncode(token), - UserId = userId, - WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, - SiteName = _globalSettings.SiteName, - Email = email, - EmailEncoded = WebUtility.UrlEncode(email) - }; - await AddMessageContentAsync(message, "VerifyDelete", model); - message.MetaData.Add("SendGridBypassListManagement", true); - message.Category = "VerifyDelete"; - await _mailDeliveryService.SendEmailAsync(message); - } - - public async Task SendChangeEmailAlreadyExistsEmailAsync(string fromEmail, string toEmail) - { - var message = CreateDefaultMessage("Your Email Change", toEmail); - var model = new ChangeEmailExistsViewModel - { - FromEmail = fromEmail, - ToEmail = toEmail, - WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, - SiteName = _globalSettings.SiteName - }; - await AddMessageContentAsync(message, "ChangeEmailAlreadyExists", model); - message.Category = "ChangeEmailAlreadyExists"; - await _mailDeliveryService.SendEmailAsync(message); - } - - public async Task SendChangeEmailEmailAsync(string newEmailAddress, string token) - { - var message = CreateDefaultMessage("Your Email Change", newEmailAddress); - var model = new EmailTokenViewModel - { - Token = token, - WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, - SiteName = _globalSettings.SiteName - }; - await AddMessageContentAsync(message, "ChangeEmail", model); - message.MetaData.Add("SendGridBypassListManagement", true); - message.Category = "ChangeEmail"; - await _mailDeliveryService.SendEmailAsync(message); - } - - public async Task SendTwoFactorEmailAsync(string email, string token) - { - var message = CreateDefaultMessage("Your Two-step Login Verification Code", email); - var model = new EmailTokenViewModel - { - Token = token, - WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, - SiteName = _globalSettings.SiteName - }; - await AddMessageContentAsync(message, "TwoFactorEmail", model); - message.MetaData.Add("SendGridBypassListManagement", true); - message.Category = "TwoFactorEmail"; - await _mailDeliveryService.SendEmailAsync(message); - } - - public async Task SendNewDeviceLoginTwoFactorEmailAsync(string email, string token) - { - var message = CreateDefaultMessage("New Device Login Verification Code", email); - var model = new EmailTokenViewModel - { - Token = token, - WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, - SiteName = _globalSettings.SiteName - }; - await AddMessageContentAsync(message, "NewDeviceLoginTwoFactorEmail", model); - message.MetaData.Add("SendGridBypassListManagement", true); - message.Category = "TwoFactorEmail"; - await _mailDeliveryService.SendEmailAsync(message); - } - - public async Task SendMasterPasswordHintEmailAsync(string email, string hint) - { - var message = CreateDefaultMessage("Your Master Password Hint", email); - var model = new MasterPasswordHintViewModel - { - Hint = CoreHelpers.SanitizeForEmail(hint, false), - WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, - SiteName = _globalSettings.SiteName - }; - await AddMessageContentAsync(message, "MasterPasswordHint", model); - message.Category = "MasterPasswordHint"; - await _mailDeliveryService.SendEmailAsync(message); - } - - public async Task SendNoMasterPasswordHintEmailAsync(string email) - { - var message = CreateDefaultMessage("Your Master Password Hint", email); - var model = new BaseMailModel - { - WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, - SiteName = _globalSettings.SiteName - }; - await AddMessageContentAsync(message, "NoMasterPasswordHint", model); - message.Category = "NoMasterPasswordHint"; - await _mailDeliveryService.SendEmailAsync(message); - } - - public async Task SendOrganizationAutoscaledEmailAsync(Organization organization, int initialSeatCount, IEnumerable ownerEmails) - { - var message = CreateDefaultMessage($"{organization.Name} Seat Count Has Increased", ownerEmails); - var model = new OrganizationSeatsAutoscaledViewModel - { - OrganizationId = organization.Id, - InitialSeatCount = initialSeatCount, - CurrentSeatCount = organization.Seats.Value, - }; - - await AddMessageContentAsync(message, "OrganizationSeatsAutoscaled", model); - message.Category = "OrganizationSeatsAutoscaled"; - await _mailDeliveryService.SendEmailAsync(message); - } - - public async Task SendOrganizationMaxSeatLimitReachedEmailAsync(Organization organization, int maxSeatCount, IEnumerable ownerEmails) - { - var message = CreateDefaultMessage($"{organization.Name} Seat Limit Reached", ownerEmails); - var model = new OrganizationSeatsMaxReachedViewModel - { - OrganizationId = organization.Id, - MaxSeatCount = maxSeatCount, - }; - - await AddMessageContentAsync(message, "OrganizationSeatsMaxReached", model); - message.Category = "OrganizationSeatsMaxReached"; - await _mailDeliveryService.SendEmailAsync(message); - } - - public async Task SendOrganizationAcceptedEmailAsync(Organization organization, string userIdentifier, - IEnumerable adminEmails) - { - var message = CreateDefaultMessage($"Action Required: {userIdentifier} Needs to Be Confirmed", adminEmails); - var model = new OrganizationUserAcceptedViewModel - { - OrganizationId = organization.Id, - OrganizationName = CoreHelpers.SanitizeForEmail(organization.Name, false), - UserIdentifier = userIdentifier, - WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, - SiteName = _globalSettings.SiteName - }; - await AddMessageContentAsync(message, "OrganizationUserAccepted", model); - message.Category = "OrganizationUserAccepted"; - await _mailDeliveryService.SendEmailAsync(message); - } - - public async Task SendOrganizationConfirmedEmailAsync(string organizationName, string email) - { - var message = CreateDefaultMessage($"You Have Been Confirmed To {organizationName}", email); - var model = new OrganizationUserConfirmedViewModel - { - OrganizationName = CoreHelpers.SanitizeForEmail(organizationName, false), - WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, - SiteName = _globalSettings.SiteName - }; - await AddMessageContentAsync(message, "OrganizationUserConfirmed", model); - message.Category = "OrganizationUserConfirmed"; - await _mailDeliveryService.SendEmailAsync(message); - } - - public Task SendOrganizationInviteEmailAsync(string organizationName, OrganizationUser orgUser, ExpiringToken token) => - BulkSendOrganizationInviteEmailAsync(organizationName, new[] { (orgUser, token) }); - - public async Task BulkSendOrganizationInviteEmailAsync(string organizationName, IEnumerable<(OrganizationUser orgUser, ExpiringToken token)> invites) - { - MailQueueMessage CreateMessage(string email, object model) - { - var message = CreateDefaultMessage($"Join {organizationName}", email); - return new MailQueueMessage(message, "OrganizationUserInvited", model); + _globalSettings = globalSettings; + _mailDeliveryService = mailDeliveryService; + _mailEnqueuingService = mailEnqueuingService; } - var messageModels = invites.Select(invite => CreateMessage(invite.orgUser.Email, - new OrganizationUserInvitedViewModel + public async Task SendVerifyEmailEmailAsync(string email, Guid userId, string token) + { + var message = CreateDefaultMessage("Verify Your Email", email); + var model = new VerifyEmailModel + { + Token = WebUtility.UrlEncode(token), + UserId = userId, + WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, + SiteName = _globalSettings.SiteName + }; + await AddMessageContentAsync(message, "VerifyEmail", model); + message.MetaData.Add("SendGridBypassListManagement", true); + message.Category = "VerifyEmail"; + await _mailDeliveryService.SendEmailAsync(message); + } + + public async Task SendVerifyDeleteEmailAsync(string email, Guid userId, string token) + { + var message = CreateDefaultMessage("Delete Your Account", email); + var model = new VerifyDeleteModel + { + Token = WebUtility.UrlEncode(token), + UserId = userId, + WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, + SiteName = _globalSettings.SiteName, + Email = email, + EmailEncoded = WebUtility.UrlEncode(email) + }; + await AddMessageContentAsync(message, "VerifyDelete", model); + message.MetaData.Add("SendGridBypassListManagement", true); + message.Category = "VerifyDelete"; + await _mailDeliveryService.SendEmailAsync(message); + } + + public async Task SendChangeEmailAlreadyExistsEmailAsync(string fromEmail, string toEmail) + { + var message = CreateDefaultMessage("Your Email Change", toEmail); + var model = new ChangeEmailExistsViewModel + { + FromEmail = fromEmail, + ToEmail = toEmail, + WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, + SiteName = _globalSettings.SiteName + }; + await AddMessageContentAsync(message, "ChangeEmailAlreadyExists", model); + message.Category = "ChangeEmailAlreadyExists"; + await _mailDeliveryService.SendEmailAsync(message); + } + + public async Task SendChangeEmailEmailAsync(string newEmailAddress, string token) + { + var message = CreateDefaultMessage("Your Email Change", newEmailAddress); + var model = new EmailTokenViewModel + { + Token = token, + WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, + SiteName = _globalSettings.SiteName + }; + await AddMessageContentAsync(message, "ChangeEmail", model); + message.MetaData.Add("SendGridBypassListManagement", true); + message.Category = "ChangeEmail"; + await _mailDeliveryService.SendEmailAsync(message); + } + + public async Task SendTwoFactorEmailAsync(string email, string token) + { + var message = CreateDefaultMessage("Your Two-step Login Verification Code", email); + var model = new EmailTokenViewModel + { + Token = token, + WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, + SiteName = _globalSettings.SiteName + }; + await AddMessageContentAsync(message, "TwoFactorEmail", model); + message.MetaData.Add("SendGridBypassListManagement", true); + message.Category = "TwoFactorEmail"; + await _mailDeliveryService.SendEmailAsync(message); + } + + public async Task SendNewDeviceLoginTwoFactorEmailAsync(string email, string token) + { + var message = CreateDefaultMessage("New Device Login Verification Code", email); + var model = new EmailTokenViewModel + { + Token = token, + WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, + SiteName = _globalSettings.SiteName + }; + await AddMessageContentAsync(message, "NewDeviceLoginTwoFactorEmail", model); + message.MetaData.Add("SendGridBypassListManagement", true); + message.Category = "TwoFactorEmail"; + await _mailDeliveryService.SendEmailAsync(message); + } + + public async Task SendMasterPasswordHintEmailAsync(string email, string hint) + { + var message = CreateDefaultMessage("Your Master Password Hint", email); + var model = new MasterPasswordHintViewModel + { + Hint = CoreHelpers.SanitizeForEmail(hint, false), + WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, + SiteName = _globalSettings.SiteName + }; + await AddMessageContentAsync(message, "MasterPasswordHint", model); + message.Category = "MasterPasswordHint"; + await _mailDeliveryService.SendEmailAsync(message); + } + + public async Task SendNoMasterPasswordHintEmailAsync(string email) + { + var message = CreateDefaultMessage("Your Master Password Hint", email); + var model = new BaseMailModel + { + WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, + SiteName = _globalSettings.SiteName + }; + await AddMessageContentAsync(message, "NoMasterPasswordHint", model); + message.Category = "NoMasterPasswordHint"; + await _mailDeliveryService.SendEmailAsync(message); + } + + public async Task SendOrganizationAutoscaledEmailAsync(Organization organization, int initialSeatCount, IEnumerable ownerEmails) + { + var message = CreateDefaultMessage($"{organization.Name} Seat Count Has Increased", ownerEmails); + var model = new OrganizationSeatsAutoscaledViewModel + { + OrganizationId = organization.Id, + InitialSeatCount = initialSeatCount, + CurrentSeatCount = organization.Seats.Value, + }; + + await AddMessageContentAsync(message, "OrganizationSeatsAutoscaled", model); + message.Category = "OrganizationSeatsAutoscaled"; + await _mailDeliveryService.SendEmailAsync(message); + } + + public async Task SendOrganizationMaxSeatLimitReachedEmailAsync(Organization organization, int maxSeatCount, IEnumerable ownerEmails) + { + var message = CreateDefaultMessage($"{organization.Name} Seat Limit Reached", ownerEmails); + var model = new OrganizationSeatsMaxReachedViewModel + { + OrganizationId = organization.Id, + MaxSeatCount = maxSeatCount, + }; + + await AddMessageContentAsync(message, "OrganizationSeatsMaxReached", model); + message.Category = "OrganizationSeatsMaxReached"; + await _mailDeliveryService.SendEmailAsync(message); + } + + public async Task SendOrganizationAcceptedEmailAsync(Organization organization, string userIdentifier, + IEnumerable adminEmails) + { + var message = CreateDefaultMessage($"Action Required: {userIdentifier} Needs to Be Confirmed", adminEmails); + var model = new OrganizationUserAcceptedViewModel + { + OrganizationId = organization.Id, + OrganizationName = CoreHelpers.SanitizeForEmail(organization.Name, false), + UserIdentifier = userIdentifier, + WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, + SiteName = _globalSettings.SiteName + }; + await AddMessageContentAsync(message, "OrganizationUserAccepted", model); + message.Category = "OrganizationUserAccepted"; + await _mailDeliveryService.SendEmailAsync(message); + } + + public async Task SendOrganizationConfirmedEmailAsync(string organizationName, string email) + { + var message = CreateDefaultMessage($"You Have Been Confirmed To {organizationName}", email); + var model = new OrganizationUserConfirmedViewModel { OrganizationName = CoreHelpers.SanitizeForEmail(organizationName, false), - Email = WebUtility.UrlEncode(invite.orgUser.Email), - OrganizationId = invite.orgUser.OrganizationId.ToString(), - OrganizationUserId = invite.orgUser.Id.ToString(), - Token = WebUtility.UrlEncode(invite.token.Token), - ExpirationDate = $"{invite.token.ExpirationDate.ToLongDateString()} {invite.token.ExpirationDate.ToShortTimeString()} UTC", - OrganizationNameUrlEncoded = WebUtility.UrlEncode(organizationName), WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, - SiteName = _globalSettings.SiteName, - } - )); - - await EnqueueMailAsync(messageModels); - } - - public async Task SendOrganizationUserRemovedForPolicyTwoStepEmailAsync(string organizationName, string email) - { - var message = CreateDefaultMessage($"You have been removed from {organizationName}", email); - var model = new OrganizationUserRemovedForPolicyTwoStepViewModel - { - OrganizationName = CoreHelpers.SanitizeForEmail(organizationName, false), - WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, - SiteName = _globalSettings.SiteName - }; - await AddMessageContentAsync(message, "OrganizationUserRemovedForPolicyTwoStep", model); - message.Category = "OrganizationUserRemovedForPolicyTwoStep"; - await _mailDeliveryService.SendEmailAsync(message); - } - - public async Task SendWelcomeEmailAsync(User user) - { - var message = CreateDefaultMessage("Welcome to Bitwarden!", user.Email); - var model = new BaseMailModel - { - WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, - SiteName = _globalSettings.SiteName - }; - await AddMessageContentAsync(message, "Welcome", model); - message.Category = "Welcome"; - await _mailDeliveryService.SendEmailAsync(message); - } - - public async Task SendPasswordlessSignInAsync(string returnUrl, string token, string email) - { - var message = CreateDefaultMessage("[Admin] Continue Logging In", email); - var url = CoreHelpers.ExtendQuery(new Uri($"{_globalSettings.BaseServiceUri.Admin}/login/confirm"), - new Dictionary - { - ["returnUrl"] = returnUrl, - ["email"] = email, - ["token"] = token, - }); - var model = new PasswordlessSignInModel - { - Url = url.ToString() - }; - await AddMessageContentAsync(message, "PasswordlessSignIn", model); - message.Category = "PasswordlessSignIn"; - await _mailDeliveryService.SendEmailAsync(message); - } - - public async Task SendInvoiceUpcomingAsync(string email, decimal amount, DateTime dueDate, - List items, bool mentionInvoices) - { - var message = CreateDefaultMessage("Your Subscription Will Renew Soon", email); - var model = new InvoiceUpcomingViewModel - { - WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, - SiteName = _globalSettings.SiteName, - AmountDue = amount, - DueDate = dueDate, - Items = items, - MentionInvoices = mentionInvoices - }; - await AddMessageContentAsync(message, "InvoiceUpcoming", model); - message.Category = "InvoiceUpcoming"; - await _mailDeliveryService.SendEmailAsync(message); - } - - public async Task SendPaymentFailedAsync(string email, decimal amount, bool mentionInvoices) - { - var message = CreateDefaultMessage("Payment Failed", email); - var model = new PaymentFailedViewModel - { - WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, - SiteName = _globalSettings.SiteName, - Amount = amount, - MentionInvoices = mentionInvoices - }; - await AddMessageContentAsync(message, "PaymentFailed", model); - message.Category = "PaymentFailed"; - await _mailDeliveryService.SendEmailAsync(message); - } - - public async Task SendAddedCreditAsync(string email, decimal amount) - { - var message = CreateDefaultMessage("Account Credit Payment Processed", email); - var model = new AddedCreditViewModel - { - WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, - SiteName = _globalSettings.SiteName, - Amount = amount - }; - await AddMessageContentAsync(message, "AddedCredit", model); - message.Category = "AddedCredit"; - await _mailDeliveryService.SendEmailAsync(message); - } - - public async Task SendLicenseExpiredAsync(IEnumerable emails, string organizationName = null) - { - var message = CreateDefaultMessage("License Expired", emails); - var model = new LicenseExpiredViewModel - { - OrganizationName = CoreHelpers.SanitizeForEmail(organizationName, false), - }; - await AddMessageContentAsync(message, "LicenseExpired", model); - message.Category = "LicenseExpired"; - await _mailDeliveryService.SendEmailAsync(message); - } - - public async Task SendNewDeviceLoggedInEmail(string email, string deviceType, DateTime timestamp, string ip) - { - var message = CreateDefaultMessage($"New Device Logged In From {deviceType}", email); - var model = new NewDeviceLoggedInModel - { - WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, - SiteName = _globalSettings.SiteName, - DeviceType = deviceType, - TheDate = timestamp.ToLongDateString(), - TheTime = timestamp.ToShortTimeString(), - TimeZone = "UTC", - IpAddress = ip - }; - await AddMessageContentAsync(message, "NewDeviceLoggedIn", model); - message.Category = "NewDeviceLoggedIn"; - await _mailDeliveryService.SendEmailAsync(message); - } - - public async Task SendRecoverTwoFactorEmail(string email, DateTime timestamp, string ip) - { - var message = CreateDefaultMessage($"Recover 2FA From {ip}", email); - var model = new RecoverTwoFactorModel - { - WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, - SiteName = _globalSettings.SiteName, - TheDate = timestamp.ToLongDateString(), - TheTime = timestamp.ToShortTimeString(), - TimeZone = "UTC", - IpAddress = ip - }; - await AddMessageContentAsync(message, "RecoverTwoFactor", model); - message.Category = "RecoverTwoFactor"; - await _mailDeliveryService.SendEmailAsync(message); - } - - public async Task SendOrganizationUserRemovedForPolicySingleOrgEmailAsync(string organizationName, string email) - { - var message = CreateDefaultMessage($"You have been removed from {organizationName}", email); - var model = new OrganizationUserRemovedForPolicySingleOrgViewModel - { - OrganizationName = CoreHelpers.SanitizeForEmail(organizationName, false), - WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, - SiteName = _globalSettings.SiteName - }; - await AddMessageContentAsync(message, "OrganizationUserRemovedForPolicySingleOrg", model); - message.Category = "OrganizationUserRemovedForPolicySingleOrg"; - await _mailDeliveryService.SendEmailAsync(message); - } - - public async Task SendEnqueuedMailMessageAsync(IMailQueueMessage queueMessage) - { - var message = CreateDefaultMessage(queueMessage.Subject, queueMessage.ToEmails); - message.BccEmails = queueMessage.BccEmails; - message.Category = queueMessage.Category; - await AddMessageContentAsync(message, queueMessage.TemplateName, queueMessage.Model); - await _mailDeliveryService.SendEmailAsync(message); - } - - public async Task SendAdminResetPasswordEmailAsync(string email, string userName, string orgName) - { - var message = CreateDefaultMessage("Master Password Has Been Changed", email); - var model = new AdminResetPasswordViewModel() - { - UserName = GetUserIdentifier(email, userName), - OrgName = CoreHelpers.SanitizeForEmail(orgName, false), - }; - await AddMessageContentAsync(message, "AdminResetPassword", model); - message.Category = "AdminResetPassword"; - await _mailDeliveryService.SendEmailAsync(message); - } - - private Task EnqueueMailAsync(IMailQueueMessage queueMessage) => - _mailEnqueuingService.EnqueueAsync(queueMessage, SendEnqueuedMailMessageAsync); - - private Task EnqueueMailAsync(IEnumerable queueMessages) => - _mailEnqueuingService.EnqueueManyAsync(queueMessages, SendEnqueuedMailMessageAsync); - - private MailMessage CreateDefaultMessage(string subject, string toEmail) - { - return CreateDefaultMessage(subject, new List { toEmail }); - } - - private MailMessage CreateDefaultMessage(string subject, IEnumerable toEmails) - { - return new MailMessage - { - ToEmails = toEmails, - Subject = subject, - MetaData = new Dictionary() - }; - } - - private async Task AddMessageContentAsync(MailMessage message, string templateName, T model) - { - message.HtmlContent = await RenderAsync($"{templateName}.html", model); - message.TextContent = await RenderAsync($"{templateName}.text", model); - } - - private async Task RenderAsync(string templateName, T model) - { - await RegisterHelpersAndPartialsAsync(); - if (!_templateCache.TryGetValue(templateName, out var template)) - { - var source = await ReadSourceAsync(templateName); - if (source != null) - { - template = Handlebars.Compile(source); - _templateCache.Add(templateName, template); - } - } - return template != null ? template(model) : null; - } - - private async Task ReadSourceAsync(string templateName) - { - var assembly = typeof(HandlebarsMailService).GetTypeInfo().Assembly; - var fullTemplateName = $"{Namespace}.{templateName}.hbs"; - if (!assembly.GetManifestResourceNames().Any(f => f == fullTemplateName)) - { - return null; - } - using (var s = assembly.GetManifestResourceStream(fullTemplateName)) - using (var sr = new StreamReader(s)) - { - return await sr.ReadToEndAsync(); - } - } - - private async Task RegisterHelpersAndPartialsAsync() - { - if (_registeredHelpersAndPartials) - { - return; - } - _registeredHelpersAndPartials = true; - - var basicHtmlLayoutSource = await ReadSourceAsync("Layouts.Basic.html"); - Handlebars.RegisterTemplate("BasicHtmlLayout", basicHtmlLayoutSource); - var basicTextLayoutSource = await ReadSourceAsync("Layouts.Basic.text"); - Handlebars.RegisterTemplate("BasicTextLayout", basicTextLayoutSource); - var fullHtmlLayoutSource = await ReadSourceAsync("Layouts.Full.html"); - Handlebars.RegisterTemplate("FullHtmlLayout", fullHtmlLayoutSource); - var fullTextLayoutSource = await ReadSourceAsync("Layouts.Full.text"); - Handlebars.RegisterTemplate("FullTextLayout", fullTextLayoutSource); - - Handlebars.RegisterHelper("date", (writer, context, parameters) => - { - if (parameters.Length == 0 || !(parameters[0] is DateTime)) - { - writer.WriteSafeString(string.Empty); - return; - } - if (parameters.Length > 0 && parameters[1] is string) - { - writer.WriteSafeString(((DateTime)parameters[0]).ToString(parameters[1].ToString())); - } - else - { - writer.WriteSafeString(((DateTime)parameters[0]).ToString()); - } - }); - - Handlebars.RegisterHelper("usd", (writer, context, parameters) => - { - if (parameters.Length == 0 || !(parameters[0] is decimal)) - { - writer.WriteSafeString(string.Empty); - return; - } - writer.WriteSafeString(((decimal)parameters[0]).ToString("C")); - }); - - Handlebars.RegisterHelper("link", (writer, context, parameters) => - { - if (parameters.Length == 0) - { - writer.WriteSafeString(string.Empty); - return; - } - - var text = parameters[0].ToString(); - var href = text; - var clickTrackingOff = false; - if (parameters.Length == 2) - { - if (parameters[1] is string) - { - var p1 = parameters[1].ToString(); - if (p1 == "true" || p1 == "false") - { - clickTrackingOff = p1 == "true"; - } - else - { - href = p1; - } - } - else if (parameters[1] is bool) - { - clickTrackingOff = (bool)parameters[1]; - } - } - else if (parameters.Length > 2) - { - if (parameters[1] is string) - { - href = parameters[1].ToString(); - } - if (parameters[2] is string) - { - var p2 = parameters[2].ToString(); - if (p2 == "true" || p2 == "false") - { - clickTrackingOff = p2 == "true"; - } - } - else if (parameters[2] is bool) - { - clickTrackingOff = (bool)parameters[2]; - } - } - - var clickTrackingText = (clickTrackingOff ? "clicktracking=off" : string.Empty); - writer.WriteSafeString($"{text}"); - }); - } - - public async Task SendEmergencyAccessInviteEmailAsync(EmergencyAccess emergencyAccess, string name, string token) - { - var message = CreateDefaultMessage($"Emergency Access Contact Invite", emergencyAccess.Email); - var model = new EmergencyAccessInvitedViewModel - { - Name = CoreHelpers.SanitizeForEmail(name), - Email = WebUtility.UrlEncode(emergencyAccess.Email), - Id = emergencyAccess.Id.ToString(), - Token = WebUtility.UrlEncode(token), - WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, - SiteName = _globalSettings.SiteName - }; - await AddMessageContentAsync(message, "EmergencyAccessInvited", model); - message.Category = "EmergencyAccessInvited"; - await _mailDeliveryService.SendEmailAsync(message); - } - - public async Task SendEmergencyAccessAcceptedEmailAsync(string granteeEmail, string email) - { - var message = CreateDefaultMessage($"Accepted Emergency Access", email); - var model = new EmergencyAccessAcceptedViewModel - { - GranteeEmail = granteeEmail, - WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, - SiteName = _globalSettings.SiteName - }; - await AddMessageContentAsync(message, "EmergencyAccessAccepted", model); - message.Category = "EmergencyAccessAccepted"; - await _mailDeliveryService.SendEmailAsync(message); - } - - public async Task SendEmergencyAccessConfirmedEmailAsync(string grantorName, string email) - { - var message = CreateDefaultMessage($"You Have Been Confirmed as Emergency Access Contact", email); - var model = new EmergencyAccessConfirmedViewModel - { - Name = CoreHelpers.SanitizeForEmail(grantorName), - WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, - SiteName = _globalSettings.SiteName - }; - await AddMessageContentAsync(message, "EmergencyAccessConfirmed", model); - message.Category = "EmergencyAccessConfirmed"; - await _mailDeliveryService.SendEmailAsync(message); - } - - public async Task SendEmergencyAccessRecoveryInitiated(EmergencyAccess emergencyAccess, string initiatingName, string email) - { - var message = CreateDefaultMessage("Emergency Access Initiated", email); - - var remainingTime = DateTime.UtcNow - emergencyAccess.RecoveryInitiatedDate.GetValueOrDefault(); - - var model = new EmergencyAccessRecoveryViewModel - { - Name = CoreHelpers.SanitizeForEmail(initiatingName), - Action = emergencyAccess.Type.ToString(), - DaysLeft = emergencyAccess.WaitTimeDays - Convert.ToInt32((remainingTime).TotalDays), - }; - await AddMessageContentAsync(message, "EmergencyAccessRecovery", model); - message.Category = "EmergencyAccessRecovery"; - await _mailDeliveryService.SendEmailAsync(message); - } - - public async Task SendEmergencyAccessRecoveryApproved(EmergencyAccess emergencyAccess, string approvingName, string email) - { - var message = CreateDefaultMessage("Emergency Access Approved", email); - var model = new EmergencyAccessApprovedViewModel - { - Name = CoreHelpers.SanitizeForEmail(approvingName), - }; - await AddMessageContentAsync(message, "EmergencyAccessApproved", model); - message.Category = "EmergencyAccessApproved"; - await _mailDeliveryService.SendEmailAsync(message); - } - - public async Task SendEmergencyAccessRecoveryRejected(EmergencyAccess emergencyAccess, string rejectingName, string email) - { - var message = CreateDefaultMessage("Emergency Access Rejected", email); - var model = new EmergencyAccessRejectedViewModel - { - Name = CoreHelpers.SanitizeForEmail(rejectingName), - }; - await AddMessageContentAsync(message, "EmergencyAccessRejected", model); - message.Category = "EmergencyAccessRejected"; - await _mailDeliveryService.SendEmailAsync(message); - } - - public async Task SendEmergencyAccessRecoveryReminder(EmergencyAccess emergencyAccess, string initiatingName, string email) - { - var message = CreateDefaultMessage("Pending Emergency Access Request", email); - - var remainingTime = DateTime.UtcNow - emergencyAccess.RecoveryInitiatedDate.GetValueOrDefault(); - - var model = new EmergencyAccessRecoveryViewModel - { - Name = CoreHelpers.SanitizeForEmail(initiatingName), - Action = emergencyAccess.Type.ToString(), - DaysLeft = emergencyAccess.WaitTimeDays - Convert.ToInt32((remainingTime).TotalDays), - }; - await AddMessageContentAsync(message, "EmergencyAccessRecoveryReminder", model); - message.Category = "EmergencyAccessRecoveryReminder"; - await _mailDeliveryService.SendEmailAsync(message); - } - - public async Task SendEmergencyAccessRecoveryTimedOut(EmergencyAccess emergencyAccess, string initiatingName, string email) - { - var message = CreateDefaultMessage("Emergency Access Granted", email); - var model = new EmergencyAccessRecoveryTimedOutViewModel - { - Name = CoreHelpers.SanitizeForEmail(initiatingName), - Action = emergencyAccess.Type.ToString(), - }; - await AddMessageContentAsync(message, "EmergencyAccessRecoveryTimedOut", model); - message.Category = "EmergencyAccessRecoveryTimedOut"; - await _mailDeliveryService.SendEmailAsync(message); - } - - public async Task SendProviderSetupInviteEmailAsync(Provider provider, string token, string email) - { - var message = CreateDefaultMessage($"Create a Provider", email); - var model = new ProviderSetupInviteViewModel - { - WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, - SiteName = _globalSettings.SiteName, - ProviderId = provider.Id.ToString(), - Email = WebUtility.UrlEncode(email), - Token = WebUtility.UrlEncode(token), - }; - await AddMessageContentAsync(message, "Provider.ProviderSetupInvite", model); - message.Category = "ProviderSetupInvite"; - await _mailDeliveryService.SendEmailAsync(message); - } - - public async Task SendProviderInviteEmailAsync(string providerName, ProviderUser providerUser, string token, string email) - { - var message = CreateDefaultMessage($"Join {providerName}", email); - var model = new ProviderUserInvitedViewModel - { - ProviderName = CoreHelpers.SanitizeForEmail(providerName), - Email = WebUtility.UrlEncode(providerUser.Email), - ProviderId = providerUser.ProviderId.ToString(), - ProviderUserId = providerUser.Id.ToString(), - ProviderNameUrlEncoded = WebUtility.UrlEncode(providerName), - Token = WebUtility.UrlEncode(token), - WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, - SiteName = _globalSettings.SiteName, - }; - await AddMessageContentAsync(message, "Provider.ProviderUserInvited", model); - message.Category = "ProviderSetupInvite"; - await _mailDeliveryService.SendEmailAsync(message); - } - - public async Task SendProviderConfirmedEmailAsync(string providerName, string email) - { - var message = CreateDefaultMessage($"You Have Been Confirmed To {providerName}", email); - var model = new ProviderUserConfirmedViewModel - { - ProviderName = CoreHelpers.SanitizeForEmail(providerName), - WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, - SiteName = _globalSettings.SiteName - }; - await AddMessageContentAsync(message, "Provider.ProviderUserConfirmed", model); - message.Category = "ProviderUserConfirmed"; - await _mailDeliveryService.SendEmailAsync(message); - } - - public async Task SendProviderUserRemoved(string providerName, string email) - { - var message = CreateDefaultMessage($"You Have Been Removed from {providerName}", email); - var model = new ProviderUserRemovedViewModel - { - ProviderName = CoreHelpers.SanitizeForEmail(providerName), - WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, - SiteName = _globalSettings.SiteName - }; - await AddMessageContentAsync(message, "Provider.ProviderUserRemoved", model); - message.Category = "ProviderUserRemoved"; - await _mailDeliveryService.SendEmailAsync(message); - } - - public async Task SendUpdatedTempPasswordEmailAsync(string email, string userName) - { - var message = CreateDefaultMessage("Master Password Has Been Changed", email); - var model = new UpdateTempPasswordViewModel() - { - UserName = GetUserIdentifier(email, userName) - }; - await AddMessageContentAsync(message, "UpdatedTempPassword", model); - message.Category = "UpdatedTempPassword"; - await _mailDeliveryService.SendEmailAsync(message); - } - - public async Task SendFamiliesForEnterpriseOfferEmailAsync(string sponsorOrgName, string email, bool existingAccount, string token) => - await BulkSendFamiliesForEnterpriseOfferEmailAsync(sponsorOrgName, new[] { (email, existingAccount, token) }); - - public async Task BulkSendFamiliesForEnterpriseOfferEmailAsync(string sponsorOrgName, IEnumerable<(string Email, bool ExistingAccount, string Token)> invites) - { - MailQueueMessage CreateMessage((string Email, bool ExistingAccount, string Token) invite) - { - var message = CreateDefaultMessage("Accept Your Free Families Subscription", invite.Email); - message.Category = "FamiliesForEnterpriseOffer"; - var model = new FamiliesForEnterpriseOfferViewModel - { - SponsorOrgName = sponsorOrgName, - SponsoredEmail = WebUtility.UrlEncode(invite.Email), - ExistingAccount = invite.ExistingAccount, - WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, - SiteName = _globalSettings.SiteName, - SponsorshipToken = invite.Token, + SiteName = _globalSettings.SiteName }; - var templateName = invite.ExistingAccount ? - "FamiliesForEnterprise.FamiliesForEnterpriseOfferExistingAccount" : - "FamiliesForEnterprise.FamiliesForEnterpriseOfferNewAccount"; - - return new MailQueueMessage(message, templateName, model); + await AddMessageContentAsync(message, "OrganizationUserConfirmed", model); + message.Category = "OrganizationUserConfirmed"; + await _mailDeliveryService.SendEmailAsync(message); } - var messageModels = invites.Select(invite => CreateMessage(invite)); - await EnqueueMailAsync(messageModels); - } - public async Task SendFamiliesForEnterpriseRedeemedEmailsAsync(string familyUserEmail, string sponsorEmail) - { - // Email family user - await SendFamiliesForEnterpriseInviteRedeemedToFamilyUserEmailAsync(familyUserEmail); + public Task SendOrganizationInviteEmailAsync(string organizationName, OrganizationUser orgUser, ExpiringToken token) => + BulkSendOrganizationInviteEmailAsync(organizationName, new[] { (orgUser, token) }); - // Email enterprise org user - await SendFamiliesForEnterpriseInviteRedeemedToEnterpriseUserEmailAsync(sponsorEmail); - } - - private async Task SendFamiliesForEnterpriseInviteRedeemedToFamilyUserEmailAsync(string email) - { - var message = CreateDefaultMessage("Success! Families Subscription Accepted", email); - await AddMessageContentAsync(message, "FamiliesForEnterprise.FamiliesForEnterpriseRedeemedToFamilyUser", new BaseMailModel()); - message.Category = "FamilyForEnterpriseRedeemedToFamilyUser"; - await _mailDeliveryService.SendEmailAsync(message); - } - - private async Task SendFamiliesForEnterpriseInviteRedeemedToEnterpriseUserEmailAsync(string email) - { - var message = CreateDefaultMessage("Success! Families Subscription Accepted", email); - await AddMessageContentAsync(message, "FamiliesForEnterprise.FamiliesForEnterpriseRedeemedToEnterpriseUser", new BaseMailModel()); - message.Category = "FamilyForEnterpriseRedeemedToEnterpriseUser"; - await _mailDeliveryService.SendEmailAsync(message); - } - - public async Task SendFamiliesForEnterpriseSponsorshipRevertingEmailAsync(string email, DateTime expirationDate) - { - var message = CreateDefaultMessage("Your Families Sponsorship was Removed", email); - var model = new FamiliesForEnterpriseSponsorshipRevertingViewModel + public async Task BulkSendOrganizationInviteEmailAsync(string organizationName, IEnumerable<(OrganizationUser orgUser, ExpiringToken token)> invites) { - ExpirationDate = expirationDate, - }; - await AddMessageContentAsync(message, "FamiliesForEnterprise.FamiliesForEnterpriseSponsorshipReverting", model); - message.Category = "FamiliesForEnterpriseSponsorshipReverting"; - await _mailDeliveryService.SendEmailAsync(message); - } + MailQueueMessage CreateMessage(string email, object model) + { + var message = CreateDefaultMessage($"Join {organizationName}", email); + return new MailQueueMessage(message, "OrganizationUserInvited", model); + } - public async Task SendOTPEmailAsync(string email, string token) - { - var message = CreateDefaultMessage("Your Bitwarden Verification Code", email); - var model = new EmailTokenViewModel + var messageModels = invites.Select(invite => CreateMessage(invite.orgUser.Email, + new OrganizationUserInvitedViewModel + { + OrganizationName = CoreHelpers.SanitizeForEmail(organizationName, false), + Email = WebUtility.UrlEncode(invite.orgUser.Email), + OrganizationId = invite.orgUser.OrganizationId.ToString(), + OrganizationUserId = invite.orgUser.Id.ToString(), + Token = WebUtility.UrlEncode(invite.token.Token), + ExpirationDate = $"{invite.token.ExpirationDate.ToLongDateString()} {invite.token.ExpirationDate.ToShortTimeString()} UTC", + OrganizationNameUrlEncoded = WebUtility.UrlEncode(organizationName), + WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, + SiteName = _globalSettings.SiteName, + } + )); + + await EnqueueMailAsync(messageModels); + } + + public async Task SendOrganizationUserRemovedForPolicyTwoStepEmailAsync(string organizationName, string email) { - Token = token, - WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, - SiteName = _globalSettings.SiteName, - }; - await AddMessageContentAsync(message, "OTPEmail", model); - message.MetaData.Add("SendGridBypassListManagement", true); - message.Category = "OTP"; - await _mailDeliveryService.SendEmailAsync(message); - } + var message = CreateDefaultMessage($"You have been removed from {organizationName}", email); + var model = new OrganizationUserRemovedForPolicyTwoStepViewModel + { + OrganizationName = CoreHelpers.SanitizeForEmail(organizationName, false), + WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, + SiteName = _globalSettings.SiteName + }; + await AddMessageContentAsync(message, "OrganizationUserRemovedForPolicyTwoStep", model); + message.Category = "OrganizationUserRemovedForPolicyTwoStep"; + await _mailDeliveryService.SendEmailAsync(message); + } - public async Task SendFailedLoginAttemptsEmailAsync(string email, DateTime utcNow, string ip) - { - var message = CreateDefaultMessage("Failed login attempts detected", email); - var model = new FailedAuthAttemptsModel() + public async Task SendWelcomeEmailAsync(User user) { - TheDate = utcNow.ToLongDateString(), - TheTime = utcNow.ToShortTimeString(), - TimeZone = "UTC", - IpAddress = ip, - AffectedEmail = email + var message = CreateDefaultMessage("Welcome to Bitwarden!", user.Email); + var model = new BaseMailModel + { + WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, + SiteName = _globalSettings.SiteName + }; + await AddMessageContentAsync(message, "Welcome", model); + message.Category = "Welcome"; + await _mailDeliveryService.SendEmailAsync(message); + } - }; - await AddMessageContentAsync(message, "FailedLoginAttempts", model); - message.Category = "FailedLoginAttempts"; - await _mailDeliveryService.SendEmailAsync(message); - } - - public async Task SendFailedTwoFactorAttemptsEmailAsync(string email, DateTime utcNow, string ip) - { - var message = CreateDefaultMessage("Failed login attempts detected", email); - var model = new FailedAuthAttemptsModel() + public async Task SendPasswordlessSignInAsync(string returnUrl, string token, string email) { - TheDate = utcNow.ToLongDateString(), - TheTime = utcNow.ToShortTimeString(), - TimeZone = "UTC", - IpAddress = ip, - AffectedEmail = email + var message = CreateDefaultMessage("[Admin] Continue Logging In", email); + var url = CoreHelpers.ExtendQuery(new Uri($"{_globalSettings.BaseServiceUri.Admin}/login/confirm"), + new Dictionary + { + ["returnUrl"] = returnUrl, + ["email"] = email, + ["token"] = token, + }); + var model = new PasswordlessSignInModel + { + Url = url.ToString() + }; + await AddMessageContentAsync(message, "PasswordlessSignIn", model); + message.Category = "PasswordlessSignIn"; + await _mailDeliveryService.SendEmailAsync(message); + } - }; - await AddMessageContentAsync(message, "FailedTwoFactorAttempts", model); - message.Category = "FailedTwoFactorAttempts"; - await _mailDeliveryService.SendEmailAsync(message); - } + public async Task SendInvoiceUpcomingAsync(string email, decimal amount, DateTime dueDate, + List items, bool mentionInvoices) + { + var message = CreateDefaultMessage("Your Subscription Will Renew Soon", email); + var model = new InvoiceUpcomingViewModel + { + WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, + SiteName = _globalSettings.SiteName, + AmountDue = amount, + DueDate = dueDate, + Items = items, + MentionInvoices = mentionInvoices + }; + await AddMessageContentAsync(message, "InvoiceUpcoming", model); + message.Category = "InvoiceUpcoming"; + await _mailDeliveryService.SendEmailAsync(message); + } - private static string GetUserIdentifier(string email, string userName) - { - return string.IsNullOrEmpty(userName) ? email : CoreHelpers.SanitizeForEmail(userName, false); + public async Task SendPaymentFailedAsync(string email, decimal amount, bool mentionInvoices) + { + var message = CreateDefaultMessage("Payment Failed", email); + var model = new PaymentFailedViewModel + { + WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, + SiteName = _globalSettings.SiteName, + Amount = amount, + MentionInvoices = mentionInvoices + }; + await AddMessageContentAsync(message, "PaymentFailed", model); + message.Category = "PaymentFailed"; + await _mailDeliveryService.SendEmailAsync(message); + } + + public async Task SendAddedCreditAsync(string email, decimal amount) + { + var message = CreateDefaultMessage("Account Credit Payment Processed", email); + var model = new AddedCreditViewModel + { + WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, + SiteName = _globalSettings.SiteName, + Amount = amount + }; + await AddMessageContentAsync(message, "AddedCredit", model); + message.Category = "AddedCredit"; + await _mailDeliveryService.SendEmailAsync(message); + } + + public async Task SendLicenseExpiredAsync(IEnumerable emails, string organizationName = null) + { + var message = CreateDefaultMessage("License Expired", emails); + var model = new LicenseExpiredViewModel + { + OrganizationName = CoreHelpers.SanitizeForEmail(organizationName, false), + }; + await AddMessageContentAsync(message, "LicenseExpired", model); + message.Category = "LicenseExpired"; + await _mailDeliveryService.SendEmailAsync(message); + } + + public async Task SendNewDeviceLoggedInEmail(string email, string deviceType, DateTime timestamp, string ip) + { + var message = CreateDefaultMessage($"New Device Logged In From {deviceType}", email); + var model = new NewDeviceLoggedInModel + { + WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, + SiteName = _globalSettings.SiteName, + DeviceType = deviceType, + TheDate = timestamp.ToLongDateString(), + TheTime = timestamp.ToShortTimeString(), + TimeZone = "UTC", + IpAddress = ip + }; + await AddMessageContentAsync(message, "NewDeviceLoggedIn", model); + message.Category = "NewDeviceLoggedIn"; + await _mailDeliveryService.SendEmailAsync(message); + } + + public async Task SendRecoverTwoFactorEmail(string email, DateTime timestamp, string ip) + { + var message = CreateDefaultMessage($"Recover 2FA From {ip}", email); + var model = new RecoverTwoFactorModel + { + WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, + SiteName = _globalSettings.SiteName, + TheDate = timestamp.ToLongDateString(), + TheTime = timestamp.ToShortTimeString(), + TimeZone = "UTC", + IpAddress = ip + }; + await AddMessageContentAsync(message, "RecoverTwoFactor", model); + message.Category = "RecoverTwoFactor"; + await _mailDeliveryService.SendEmailAsync(message); + } + + public async Task SendOrganizationUserRemovedForPolicySingleOrgEmailAsync(string organizationName, string email) + { + var message = CreateDefaultMessage($"You have been removed from {organizationName}", email); + var model = new OrganizationUserRemovedForPolicySingleOrgViewModel + { + OrganizationName = CoreHelpers.SanitizeForEmail(organizationName, false), + WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, + SiteName = _globalSettings.SiteName + }; + await AddMessageContentAsync(message, "OrganizationUserRemovedForPolicySingleOrg", model); + message.Category = "OrganizationUserRemovedForPolicySingleOrg"; + await _mailDeliveryService.SendEmailAsync(message); + } + + public async Task SendEnqueuedMailMessageAsync(IMailQueueMessage queueMessage) + { + var message = CreateDefaultMessage(queueMessage.Subject, queueMessage.ToEmails); + message.BccEmails = queueMessage.BccEmails; + message.Category = queueMessage.Category; + await AddMessageContentAsync(message, queueMessage.TemplateName, queueMessage.Model); + await _mailDeliveryService.SendEmailAsync(message); + } + + public async Task SendAdminResetPasswordEmailAsync(string email, string userName, string orgName) + { + var message = CreateDefaultMessage("Master Password Has Been Changed", email); + var model = new AdminResetPasswordViewModel() + { + UserName = GetUserIdentifier(email, userName), + OrgName = CoreHelpers.SanitizeForEmail(orgName, false), + }; + await AddMessageContentAsync(message, "AdminResetPassword", model); + message.Category = "AdminResetPassword"; + await _mailDeliveryService.SendEmailAsync(message); + } + + private Task EnqueueMailAsync(IMailQueueMessage queueMessage) => + _mailEnqueuingService.EnqueueAsync(queueMessage, SendEnqueuedMailMessageAsync); + + private Task EnqueueMailAsync(IEnumerable queueMessages) => + _mailEnqueuingService.EnqueueManyAsync(queueMessages, SendEnqueuedMailMessageAsync); + + private MailMessage CreateDefaultMessage(string subject, string toEmail) + { + return CreateDefaultMessage(subject, new List { toEmail }); + } + + private MailMessage CreateDefaultMessage(string subject, IEnumerable toEmails) + { + return new MailMessage + { + ToEmails = toEmails, + Subject = subject, + MetaData = new Dictionary() + }; + } + + private async Task AddMessageContentAsync(MailMessage message, string templateName, T model) + { + message.HtmlContent = await RenderAsync($"{templateName}.html", model); + message.TextContent = await RenderAsync($"{templateName}.text", model); + } + + private async Task RenderAsync(string templateName, T model) + { + await RegisterHelpersAndPartialsAsync(); + if (!_templateCache.TryGetValue(templateName, out var template)) + { + var source = await ReadSourceAsync(templateName); + if (source != null) + { + template = Handlebars.Compile(source); + _templateCache.Add(templateName, template); + } + } + return template != null ? template(model) : null; + } + + private async Task ReadSourceAsync(string templateName) + { + var assembly = typeof(HandlebarsMailService).GetTypeInfo().Assembly; + var fullTemplateName = $"{Namespace}.{templateName}.hbs"; + if (!assembly.GetManifestResourceNames().Any(f => f == fullTemplateName)) + { + return null; + } + using (var s = assembly.GetManifestResourceStream(fullTemplateName)) + using (var sr = new StreamReader(s)) + { + return await sr.ReadToEndAsync(); + } + } + + private async Task RegisterHelpersAndPartialsAsync() + { + if (_registeredHelpersAndPartials) + { + return; + } + _registeredHelpersAndPartials = true; + + var basicHtmlLayoutSource = await ReadSourceAsync("Layouts.Basic.html"); + Handlebars.RegisterTemplate("BasicHtmlLayout", basicHtmlLayoutSource); + var basicTextLayoutSource = await ReadSourceAsync("Layouts.Basic.text"); + Handlebars.RegisterTemplate("BasicTextLayout", basicTextLayoutSource); + var fullHtmlLayoutSource = await ReadSourceAsync("Layouts.Full.html"); + Handlebars.RegisterTemplate("FullHtmlLayout", fullHtmlLayoutSource); + var fullTextLayoutSource = await ReadSourceAsync("Layouts.Full.text"); + Handlebars.RegisterTemplate("FullTextLayout", fullTextLayoutSource); + + Handlebars.RegisterHelper("date", (writer, context, parameters) => + { + if (parameters.Length == 0 || !(parameters[0] is DateTime)) + { + writer.WriteSafeString(string.Empty); + return; + } + if (parameters.Length > 0 && parameters[1] is string) + { + writer.WriteSafeString(((DateTime)parameters[0]).ToString(parameters[1].ToString())); + } + else + { + writer.WriteSafeString(((DateTime)parameters[0]).ToString()); + } + }); + + Handlebars.RegisterHelper("usd", (writer, context, parameters) => + { + if (parameters.Length == 0 || !(parameters[0] is decimal)) + { + writer.WriteSafeString(string.Empty); + return; + } + writer.WriteSafeString(((decimal)parameters[0]).ToString("C")); + }); + + Handlebars.RegisterHelper("link", (writer, context, parameters) => + { + if (parameters.Length == 0) + { + writer.WriteSafeString(string.Empty); + return; + } + + var text = parameters[0].ToString(); + var href = text; + var clickTrackingOff = false; + if (parameters.Length == 2) + { + if (parameters[1] is string) + { + var p1 = parameters[1].ToString(); + if (p1 == "true" || p1 == "false") + { + clickTrackingOff = p1 == "true"; + } + else + { + href = p1; + } + } + else if (parameters[1] is bool) + { + clickTrackingOff = (bool)parameters[1]; + } + } + else if (parameters.Length > 2) + { + if (parameters[1] is string) + { + href = parameters[1].ToString(); + } + if (parameters[2] is string) + { + var p2 = parameters[2].ToString(); + if (p2 == "true" || p2 == "false") + { + clickTrackingOff = p2 == "true"; + } + } + else if (parameters[2] is bool) + { + clickTrackingOff = (bool)parameters[2]; + } + } + + var clickTrackingText = (clickTrackingOff ? "clicktracking=off" : string.Empty); + writer.WriteSafeString($"{text}"); + }); + } + + public async Task SendEmergencyAccessInviteEmailAsync(EmergencyAccess emergencyAccess, string name, string token) + { + var message = CreateDefaultMessage($"Emergency Access Contact Invite", emergencyAccess.Email); + var model = new EmergencyAccessInvitedViewModel + { + Name = CoreHelpers.SanitizeForEmail(name), + Email = WebUtility.UrlEncode(emergencyAccess.Email), + Id = emergencyAccess.Id.ToString(), + Token = WebUtility.UrlEncode(token), + WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, + SiteName = _globalSettings.SiteName + }; + await AddMessageContentAsync(message, "EmergencyAccessInvited", model); + message.Category = "EmergencyAccessInvited"; + await _mailDeliveryService.SendEmailAsync(message); + } + + public async Task SendEmergencyAccessAcceptedEmailAsync(string granteeEmail, string email) + { + var message = CreateDefaultMessage($"Accepted Emergency Access", email); + var model = new EmergencyAccessAcceptedViewModel + { + GranteeEmail = granteeEmail, + WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, + SiteName = _globalSettings.SiteName + }; + await AddMessageContentAsync(message, "EmergencyAccessAccepted", model); + message.Category = "EmergencyAccessAccepted"; + await _mailDeliveryService.SendEmailAsync(message); + } + + public async Task SendEmergencyAccessConfirmedEmailAsync(string grantorName, string email) + { + var message = CreateDefaultMessage($"You Have Been Confirmed as Emergency Access Contact", email); + var model = new EmergencyAccessConfirmedViewModel + { + Name = CoreHelpers.SanitizeForEmail(grantorName), + WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, + SiteName = _globalSettings.SiteName + }; + await AddMessageContentAsync(message, "EmergencyAccessConfirmed", model); + message.Category = "EmergencyAccessConfirmed"; + await _mailDeliveryService.SendEmailAsync(message); + } + + public async Task SendEmergencyAccessRecoveryInitiated(EmergencyAccess emergencyAccess, string initiatingName, string email) + { + var message = CreateDefaultMessage("Emergency Access Initiated", email); + + var remainingTime = DateTime.UtcNow - emergencyAccess.RecoveryInitiatedDate.GetValueOrDefault(); + + var model = new EmergencyAccessRecoveryViewModel + { + Name = CoreHelpers.SanitizeForEmail(initiatingName), + Action = emergencyAccess.Type.ToString(), + DaysLeft = emergencyAccess.WaitTimeDays - Convert.ToInt32((remainingTime).TotalDays), + }; + await AddMessageContentAsync(message, "EmergencyAccessRecovery", model); + message.Category = "EmergencyAccessRecovery"; + await _mailDeliveryService.SendEmailAsync(message); + } + + public async Task SendEmergencyAccessRecoveryApproved(EmergencyAccess emergencyAccess, string approvingName, string email) + { + var message = CreateDefaultMessage("Emergency Access Approved", email); + var model = new EmergencyAccessApprovedViewModel + { + Name = CoreHelpers.SanitizeForEmail(approvingName), + }; + await AddMessageContentAsync(message, "EmergencyAccessApproved", model); + message.Category = "EmergencyAccessApproved"; + await _mailDeliveryService.SendEmailAsync(message); + } + + public async Task SendEmergencyAccessRecoveryRejected(EmergencyAccess emergencyAccess, string rejectingName, string email) + { + var message = CreateDefaultMessage("Emergency Access Rejected", email); + var model = new EmergencyAccessRejectedViewModel + { + Name = CoreHelpers.SanitizeForEmail(rejectingName), + }; + await AddMessageContentAsync(message, "EmergencyAccessRejected", model); + message.Category = "EmergencyAccessRejected"; + await _mailDeliveryService.SendEmailAsync(message); + } + + public async Task SendEmergencyAccessRecoveryReminder(EmergencyAccess emergencyAccess, string initiatingName, string email) + { + var message = CreateDefaultMessage("Pending Emergency Access Request", email); + + var remainingTime = DateTime.UtcNow - emergencyAccess.RecoveryInitiatedDate.GetValueOrDefault(); + + var model = new EmergencyAccessRecoveryViewModel + { + Name = CoreHelpers.SanitizeForEmail(initiatingName), + Action = emergencyAccess.Type.ToString(), + DaysLeft = emergencyAccess.WaitTimeDays - Convert.ToInt32((remainingTime).TotalDays), + }; + await AddMessageContentAsync(message, "EmergencyAccessRecoveryReminder", model); + message.Category = "EmergencyAccessRecoveryReminder"; + await _mailDeliveryService.SendEmailAsync(message); + } + + public async Task SendEmergencyAccessRecoveryTimedOut(EmergencyAccess emergencyAccess, string initiatingName, string email) + { + var message = CreateDefaultMessage("Emergency Access Granted", email); + var model = new EmergencyAccessRecoveryTimedOutViewModel + { + Name = CoreHelpers.SanitizeForEmail(initiatingName), + Action = emergencyAccess.Type.ToString(), + }; + await AddMessageContentAsync(message, "EmergencyAccessRecoveryTimedOut", model); + message.Category = "EmergencyAccessRecoveryTimedOut"; + await _mailDeliveryService.SendEmailAsync(message); + } + + public async Task SendProviderSetupInviteEmailAsync(Provider provider, string token, string email) + { + var message = CreateDefaultMessage($"Create a Provider", email); + var model = new ProviderSetupInviteViewModel + { + WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, + SiteName = _globalSettings.SiteName, + ProviderId = provider.Id.ToString(), + Email = WebUtility.UrlEncode(email), + Token = WebUtility.UrlEncode(token), + }; + await AddMessageContentAsync(message, "Provider.ProviderSetupInvite", model); + message.Category = "ProviderSetupInvite"; + await _mailDeliveryService.SendEmailAsync(message); + } + + public async Task SendProviderInviteEmailAsync(string providerName, ProviderUser providerUser, string token, string email) + { + var message = CreateDefaultMessage($"Join {providerName}", email); + var model = new ProviderUserInvitedViewModel + { + ProviderName = CoreHelpers.SanitizeForEmail(providerName), + Email = WebUtility.UrlEncode(providerUser.Email), + ProviderId = providerUser.ProviderId.ToString(), + ProviderUserId = providerUser.Id.ToString(), + ProviderNameUrlEncoded = WebUtility.UrlEncode(providerName), + Token = WebUtility.UrlEncode(token), + WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, + SiteName = _globalSettings.SiteName, + }; + await AddMessageContentAsync(message, "Provider.ProviderUserInvited", model); + message.Category = "ProviderSetupInvite"; + await _mailDeliveryService.SendEmailAsync(message); + } + + public async Task SendProviderConfirmedEmailAsync(string providerName, string email) + { + var message = CreateDefaultMessage($"You Have Been Confirmed To {providerName}", email); + var model = new ProviderUserConfirmedViewModel + { + ProviderName = CoreHelpers.SanitizeForEmail(providerName), + WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, + SiteName = _globalSettings.SiteName + }; + await AddMessageContentAsync(message, "Provider.ProviderUserConfirmed", model); + message.Category = "ProviderUserConfirmed"; + await _mailDeliveryService.SendEmailAsync(message); + } + + public async Task SendProviderUserRemoved(string providerName, string email) + { + var message = CreateDefaultMessage($"You Have Been Removed from {providerName}", email); + var model = new ProviderUserRemovedViewModel + { + ProviderName = CoreHelpers.SanitizeForEmail(providerName), + WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, + SiteName = _globalSettings.SiteName + }; + await AddMessageContentAsync(message, "Provider.ProviderUserRemoved", model); + message.Category = "ProviderUserRemoved"; + await _mailDeliveryService.SendEmailAsync(message); + } + + public async Task SendUpdatedTempPasswordEmailAsync(string email, string userName) + { + var message = CreateDefaultMessage("Master Password Has Been Changed", email); + var model = new UpdateTempPasswordViewModel() + { + UserName = GetUserIdentifier(email, userName) + }; + await AddMessageContentAsync(message, "UpdatedTempPassword", model); + message.Category = "UpdatedTempPassword"; + await _mailDeliveryService.SendEmailAsync(message); + } + + public async Task SendFamiliesForEnterpriseOfferEmailAsync(string sponsorOrgName, string email, bool existingAccount, string token) => + await BulkSendFamiliesForEnterpriseOfferEmailAsync(sponsorOrgName, new[] { (email, existingAccount, token) }); + + public async Task BulkSendFamiliesForEnterpriseOfferEmailAsync(string sponsorOrgName, IEnumerable<(string Email, bool ExistingAccount, string Token)> invites) + { + MailQueueMessage CreateMessage((string Email, bool ExistingAccount, string Token) invite) + { + var message = CreateDefaultMessage("Accept Your Free Families Subscription", invite.Email); + message.Category = "FamiliesForEnterpriseOffer"; + var model = new FamiliesForEnterpriseOfferViewModel + { + SponsorOrgName = sponsorOrgName, + SponsoredEmail = WebUtility.UrlEncode(invite.Email), + ExistingAccount = invite.ExistingAccount, + WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, + SiteName = _globalSettings.SiteName, + SponsorshipToken = invite.Token, + }; + var templateName = invite.ExistingAccount ? + "FamiliesForEnterprise.FamiliesForEnterpriseOfferExistingAccount" : + "FamiliesForEnterprise.FamiliesForEnterpriseOfferNewAccount"; + + return new MailQueueMessage(message, templateName, model); + } + var messageModels = invites.Select(invite => CreateMessage(invite)); + await EnqueueMailAsync(messageModels); + } + + public async Task SendFamiliesForEnterpriseRedeemedEmailsAsync(string familyUserEmail, string sponsorEmail) + { + // Email family user + await SendFamiliesForEnterpriseInviteRedeemedToFamilyUserEmailAsync(familyUserEmail); + + // Email enterprise org user + await SendFamiliesForEnterpriseInviteRedeemedToEnterpriseUserEmailAsync(sponsorEmail); + } + + private async Task SendFamiliesForEnterpriseInviteRedeemedToFamilyUserEmailAsync(string email) + { + var message = CreateDefaultMessage("Success! Families Subscription Accepted", email); + await AddMessageContentAsync(message, "FamiliesForEnterprise.FamiliesForEnterpriseRedeemedToFamilyUser", new BaseMailModel()); + message.Category = "FamilyForEnterpriseRedeemedToFamilyUser"; + await _mailDeliveryService.SendEmailAsync(message); + } + + private async Task SendFamiliesForEnterpriseInviteRedeemedToEnterpriseUserEmailAsync(string email) + { + var message = CreateDefaultMessage("Success! Families Subscription Accepted", email); + await AddMessageContentAsync(message, "FamiliesForEnterprise.FamiliesForEnterpriseRedeemedToEnterpriseUser", new BaseMailModel()); + message.Category = "FamilyForEnterpriseRedeemedToEnterpriseUser"; + await _mailDeliveryService.SendEmailAsync(message); + } + + public async Task SendFamiliesForEnterpriseSponsorshipRevertingEmailAsync(string email, DateTime expirationDate) + { + var message = CreateDefaultMessage("Your Families Sponsorship was Removed", email); + var model = new FamiliesForEnterpriseSponsorshipRevertingViewModel + { + ExpirationDate = expirationDate, + }; + await AddMessageContentAsync(message, "FamiliesForEnterprise.FamiliesForEnterpriseSponsorshipReverting", model); + message.Category = "FamiliesForEnterpriseSponsorshipReverting"; + await _mailDeliveryService.SendEmailAsync(message); + } + + public async Task SendOTPEmailAsync(string email, string token) + { + var message = CreateDefaultMessage("Your Bitwarden Verification Code", email); + var model = new EmailTokenViewModel + { + Token = token, + WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, + SiteName = _globalSettings.SiteName, + }; + await AddMessageContentAsync(message, "OTPEmail", model); + message.MetaData.Add("SendGridBypassListManagement", true); + message.Category = "OTP"; + await _mailDeliveryService.SendEmailAsync(message); + } + + public async Task SendFailedLoginAttemptsEmailAsync(string email, DateTime utcNow, string ip) + { + var message = CreateDefaultMessage("Failed login attempts detected", email); + var model = new FailedAuthAttemptsModel() + { + TheDate = utcNow.ToLongDateString(), + TheTime = utcNow.ToShortTimeString(), + TimeZone = "UTC", + IpAddress = ip, + AffectedEmail = email + + }; + await AddMessageContentAsync(message, "FailedLoginAttempts", model); + message.Category = "FailedLoginAttempts"; + await _mailDeliveryService.SendEmailAsync(message); + } + + public async Task SendFailedTwoFactorAttemptsEmailAsync(string email, DateTime utcNow, string ip) + { + var message = CreateDefaultMessage("Failed login attempts detected", email); + var model = new FailedAuthAttemptsModel() + { + TheDate = utcNow.ToLongDateString(), + TheTime = utcNow.ToShortTimeString(), + TimeZone = "UTC", + IpAddress = ip, + AffectedEmail = email + + }; + await AddMessageContentAsync(message, "FailedTwoFactorAttempts", model); + message.Category = "FailedTwoFactorAttempts"; + await _mailDeliveryService.SendEmailAsync(message); + } + + private static string GetUserIdentifier(string email, string userName) + { + return string.IsNullOrEmpty(userName) ? email : CoreHelpers.SanitizeForEmail(userName, false); + } } } diff --git a/src/Core/Services/Implementations/I18nService.cs b/src/Core/Services/Implementations/I18nService.cs index 7d99dacbac..e9675ca58d 100644 --- a/src/Core/Services/Implementations/I18nService.cs +++ b/src/Core/Services/Implementations/I18nService.cs @@ -2,35 +2,36 @@ using Bit.Core.Resources; using Microsoft.Extensions.Localization; -namespace Bit.Core.Services; - -public class I18nService : II18nService +namespace Bit.Core.Services { - private readonly IStringLocalizer _localizer; - - public I18nService(IStringLocalizerFactory factory) + public class I18nService : II18nService { - var assemblyName = new AssemblyName(typeof(SharedResources).GetTypeInfo().Assembly.FullName); - _localizer = factory.Create("SharedResources", assemblyName.Name); - } + private readonly IStringLocalizer _localizer; - public LocalizedString GetLocalizedHtmlString(string key) - { - return _localizer[key]; - } + public I18nService(IStringLocalizerFactory factory) + { + var assemblyName = new AssemblyName(typeof(SharedResources).GetTypeInfo().Assembly.FullName); + _localizer = factory.Create("SharedResources", assemblyName.Name); + } - public LocalizedString GetLocalizedHtmlString(string key, params object[] args) - { - return _localizer[key, args]; - } + public LocalizedString GetLocalizedHtmlString(string key) + { + return _localizer[key]; + } - public string Translate(string key, params object[] args) - { - return string.Format(GetLocalizedHtmlString(key).ToString(), args); - } + public LocalizedString GetLocalizedHtmlString(string key, params object[] args) + { + return _localizer[key, args]; + } - public string T(string key, params object[] args) - { - return Translate(key, args); + public string Translate(string key, params object[] args) + { + return string.Format(GetLocalizedHtmlString(key).ToString(), args); + } + + public string T(string key, params object[] args) + { + return Translate(key, args); + } } } diff --git a/src/Core/Services/Implementations/I18nViewLocalizer.cs b/src/Core/Services/Implementations/I18nViewLocalizer.cs index 69699d9c4b..4a8d866786 100644 --- a/src/Core/Services/Implementations/I18nViewLocalizer.cs +++ b/src/Core/Services/Implementations/I18nViewLocalizer.cs @@ -3,28 +3,29 @@ using Bit.Core.Resources; using Microsoft.AspNetCore.Mvc.Localization; using Microsoft.Extensions.Localization; -namespace Bit.Core.Services; - -public class I18nViewLocalizer : IViewLocalizer +namespace Bit.Core.Services { - private readonly IStringLocalizer _stringLocalizer; - private readonly IHtmlLocalizer _htmlLocalizer; - - public I18nViewLocalizer(IStringLocalizerFactory stringFactory, - IHtmlLocalizerFactory htmlFactory) + public class I18nViewLocalizer : IViewLocalizer { - var assemblyName = new AssemblyName(typeof(SharedResources).GetTypeInfo().Assembly.FullName); - _stringLocalizer = stringFactory.Create("SharedResources", assemblyName.Name); - _htmlLocalizer = htmlFactory.Create("SharedResources", assemblyName.Name); + private readonly IStringLocalizer _stringLocalizer; + private readonly IHtmlLocalizer _htmlLocalizer; + + public I18nViewLocalizer(IStringLocalizerFactory stringFactory, + IHtmlLocalizerFactory htmlFactory) + { + var assemblyName = new AssemblyName(typeof(SharedResources).GetTypeInfo().Assembly.FullName); + _stringLocalizer = stringFactory.Create("SharedResources", assemblyName.Name); + _htmlLocalizer = htmlFactory.Create("SharedResources", assemblyName.Name); + } + + public LocalizedHtmlString this[string name] => _htmlLocalizer[name]; + public LocalizedHtmlString this[string name, params object[] args] => _htmlLocalizer[name, args]; + + public IEnumerable GetAllStrings(bool includeParentCultures) => + _stringLocalizer.GetAllStrings(includeParentCultures); + + public LocalizedString GetString(string name) => _stringLocalizer[name]; + public LocalizedString GetString(string name, params object[] arguments) => + _stringLocalizer[name, arguments]; } - - public LocalizedHtmlString this[string name] => _htmlLocalizer[name]; - public LocalizedHtmlString this[string name, params object[] args] => _htmlLocalizer[name, args]; - - public IEnumerable GetAllStrings(bool includeParentCultures) => - _stringLocalizer.GetAllStrings(includeParentCultures); - - public LocalizedString GetString(string name) => _stringLocalizer[name]; - public LocalizedString GetString(string name, params object[] arguments) => - _stringLocalizer[name, arguments]; } diff --git a/src/Core/Services/Implementations/InMemoryApplicationCacheService.cs b/src/Core/Services/Implementations/InMemoryApplicationCacheService.cs index dc23fcdb82..98333ff55c 100644 --- a/src/Core/Services/Implementations/InMemoryApplicationCacheService.cs +++ b/src/Core/Services/Implementations/InMemoryApplicationCacheService.cs @@ -4,96 +4,97 @@ using Bit.Core.Models.Data; using Bit.Core.Models.Data.Organizations; using Bit.Core.Repositories; -namespace Bit.Core.Services; - -public class InMemoryApplicationCacheService : IApplicationCacheService +namespace Bit.Core.Services { - private readonly IOrganizationRepository _organizationRepository; - private readonly IProviderRepository _providerRepository; - private DateTime _lastOrgAbilityRefresh = DateTime.MinValue; - private IDictionary _orgAbilities; - private TimeSpan _orgAbilitiesRefreshInterval = TimeSpan.FromMinutes(10); - - private IDictionary _providerAbilities; - - public InMemoryApplicationCacheService( - IOrganizationRepository organizationRepository, IProviderRepository providerRepository) + public class InMemoryApplicationCacheService : IApplicationCacheService { - _organizationRepository = organizationRepository; - _providerRepository = providerRepository; - } + private readonly IOrganizationRepository _organizationRepository; + private readonly IProviderRepository _providerRepository; + private DateTime _lastOrgAbilityRefresh = DateTime.MinValue; + private IDictionary _orgAbilities; + private TimeSpan _orgAbilitiesRefreshInterval = TimeSpan.FromMinutes(10); - public virtual async Task> GetOrganizationAbilitiesAsync() - { - await InitOrganizationAbilitiesAsync(); - return _orgAbilities; - } + private IDictionary _providerAbilities; - public virtual async Task> GetProviderAbilitiesAsync() - { - await InitProviderAbilitiesAsync(); - return _providerAbilities; - } - - public virtual async Task UpsertProviderAbilityAsync(Provider provider) - { - await InitProviderAbilitiesAsync(); - var newAbility = new ProviderAbility(provider); - - if (_providerAbilities.ContainsKey(provider.Id)) + public InMemoryApplicationCacheService( + IOrganizationRepository organizationRepository, IProviderRepository providerRepository) { - _providerAbilities[provider.Id] = newAbility; - } - else - { - _providerAbilities.Add(provider.Id, newAbility); - } - } - - public virtual async Task UpsertOrganizationAbilityAsync(Organization organization) - { - await InitOrganizationAbilitiesAsync(); - var newAbility = new OrganizationAbility(organization); - - if (_orgAbilities.ContainsKey(organization.Id)) - { - _orgAbilities[organization.Id] = newAbility; - } - else - { - _orgAbilities.Add(organization.Id, newAbility); - } - } - - public virtual Task DeleteOrganizationAbilityAsync(Guid organizationId) - { - if (_orgAbilities != null && _orgAbilities.ContainsKey(organizationId)) - { - _orgAbilities.Remove(organizationId); + _organizationRepository = organizationRepository; + _providerRepository = providerRepository; } - return Task.FromResult(0); - } - - private async Task InitOrganizationAbilitiesAsync() - { - var now = DateTime.UtcNow; - if (_orgAbilities == null || (now - _lastOrgAbilityRefresh) > _orgAbilitiesRefreshInterval) + public virtual async Task> GetOrganizationAbilitiesAsync() { - var abilities = await _organizationRepository.GetManyAbilitiesAsync(); - _orgAbilities = abilities.ToDictionary(a => a.Id); - _lastOrgAbilityRefresh = now; + await InitOrganizationAbilitiesAsync(); + return _orgAbilities; } - } - private async Task InitProviderAbilitiesAsync() - { - var now = DateTime.UtcNow; - if (_providerAbilities == null || (now - _lastOrgAbilityRefresh) > _orgAbilitiesRefreshInterval) + public virtual async Task> GetProviderAbilitiesAsync() { - var abilities = await _providerRepository.GetManyAbilitiesAsync(); - _providerAbilities = abilities.ToDictionary(a => a.Id); - _lastOrgAbilityRefresh = now; + await InitProviderAbilitiesAsync(); + return _providerAbilities; + } + + public virtual async Task UpsertProviderAbilityAsync(Provider provider) + { + await InitProviderAbilitiesAsync(); + var newAbility = new ProviderAbility(provider); + + if (_providerAbilities.ContainsKey(provider.Id)) + { + _providerAbilities[provider.Id] = newAbility; + } + else + { + _providerAbilities.Add(provider.Id, newAbility); + } + } + + public virtual async Task UpsertOrganizationAbilityAsync(Organization organization) + { + await InitOrganizationAbilitiesAsync(); + var newAbility = new OrganizationAbility(organization); + + if (_orgAbilities.ContainsKey(organization.Id)) + { + _orgAbilities[organization.Id] = newAbility; + } + else + { + _orgAbilities.Add(organization.Id, newAbility); + } + } + + public virtual Task DeleteOrganizationAbilityAsync(Guid organizationId) + { + if (_orgAbilities != null && _orgAbilities.ContainsKey(organizationId)) + { + _orgAbilities.Remove(organizationId); + } + + return Task.FromResult(0); + } + + private async Task InitOrganizationAbilitiesAsync() + { + var now = DateTime.UtcNow; + if (_orgAbilities == null || (now - _lastOrgAbilityRefresh) > _orgAbilitiesRefreshInterval) + { + var abilities = await _organizationRepository.GetManyAbilitiesAsync(); + _orgAbilities = abilities.ToDictionary(a => a.Id); + _lastOrgAbilityRefresh = now; + } + } + + private async Task InitProviderAbilitiesAsync() + { + var now = DateTime.UtcNow; + if (_providerAbilities == null || (now - _lastOrgAbilityRefresh) > _orgAbilitiesRefreshInterval) + { + var abilities = await _providerRepository.GetManyAbilitiesAsync(); + _providerAbilities = abilities.ToDictionary(a => a.Id); + _lastOrgAbilityRefresh = now; + } } } } diff --git a/src/Core/Services/Implementations/InMemoryServiceBusApplicationCacheService.cs b/src/Core/Services/Implementations/InMemoryServiceBusApplicationCacheService.cs index 1c059e4ca6..c12efb4091 100644 --- a/src/Core/Services/Implementations/InMemoryServiceBusApplicationCacheService.cs +++ b/src/Core/Services/Implementations/InMemoryServiceBusApplicationCacheService.cs @@ -5,61 +5,62 @@ using Bit.Core.Settings; using Bit.Core.Utilities; using Microsoft.Azure.ServiceBus; -namespace Bit.Core.Services; - -public class InMemoryServiceBusApplicationCacheService : InMemoryApplicationCacheService, IApplicationCacheService +namespace Bit.Core.Services { - private readonly TopicClient _topicClient; - private readonly string _subName; - - public InMemoryServiceBusApplicationCacheService( - IOrganizationRepository organizationRepository, - IProviderRepository providerRepository, - GlobalSettings globalSettings) - : base(organizationRepository, providerRepository) + public class InMemoryServiceBusApplicationCacheService : InMemoryApplicationCacheService, IApplicationCacheService { - _subName = CoreHelpers.GetApplicationCacheServiceBusSubcriptionName(globalSettings); - _topicClient = new TopicClient(globalSettings.ServiceBus.ConnectionString, - globalSettings.ServiceBus.ApplicationCacheTopicName); - } + private readonly TopicClient _topicClient; + private readonly string _subName; - public override async Task UpsertOrganizationAbilityAsync(Organization organization) - { - await base.UpsertOrganizationAbilityAsync(organization); - var message = new Message + public InMemoryServiceBusApplicationCacheService( + IOrganizationRepository organizationRepository, + IProviderRepository providerRepository, + GlobalSettings globalSettings) + : base(organizationRepository, providerRepository) { - Label = _subName, - UserProperties = - { - { "type", (byte)ApplicationCacheMessageType.UpsertOrganizationAbility }, - { "id", organization.Id }, - } - }; - var task = _topicClient.SendAsync(message); - } + _subName = CoreHelpers.GetApplicationCacheServiceBusSubcriptionName(globalSettings); + _topicClient = new TopicClient(globalSettings.ServiceBus.ConnectionString, + globalSettings.ServiceBus.ApplicationCacheTopicName); + } - public override async Task DeleteOrganizationAbilityAsync(Guid organizationId) - { - await base.DeleteOrganizationAbilityAsync(organizationId); - var message = new Message + public override async Task UpsertOrganizationAbilityAsync(Organization organization) { - Label = _subName, - UserProperties = + await base.UpsertOrganizationAbilityAsync(organization); + var message = new Message { - { "type", (byte)ApplicationCacheMessageType.DeleteOrganizationAbility }, - { "id", organizationId }, - } - }; - var task = _topicClient.SendAsync(message); - } + Label = _subName, + UserProperties = + { + { "type", (byte)ApplicationCacheMessageType.UpsertOrganizationAbility }, + { "id", organization.Id }, + } + }; + var task = _topicClient.SendAsync(message); + } - public async Task BaseUpsertOrganizationAbilityAsync(Organization organization) - { - await base.UpsertOrganizationAbilityAsync(organization); - } + public override async Task DeleteOrganizationAbilityAsync(Guid organizationId) + { + await base.DeleteOrganizationAbilityAsync(organizationId); + var message = new Message + { + Label = _subName, + UserProperties = + { + { "type", (byte)ApplicationCacheMessageType.DeleteOrganizationAbility }, + { "id", organizationId }, + } + }; + var task = _topicClient.SendAsync(message); + } - public async Task BaseDeleteOrganizationAbilityAsync(Guid organizationId) - { - await base.DeleteOrganizationAbilityAsync(organizationId); + public async Task BaseUpsertOrganizationAbilityAsync(Organization organization) + { + await base.UpsertOrganizationAbilityAsync(organization); + } + + public async Task BaseDeleteOrganizationAbilityAsync(Guid organizationId) + { + await base.DeleteOrganizationAbilityAsync(organizationId); + } } } diff --git a/src/Core/Services/Implementations/LicensingService.cs b/src/Core/Services/Implementations/LicensingService.cs index 893ea22689..70625e7584 100644 --- a/src/Core/Services/Implementations/LicensingService.cs +++ b/src/Core/Services/Implementations/LicensingService.cs @@ -10,251 +10,252 @@ using Microsoft.AspNetCore.Hosting; using Microsoft.Extensions.Hosting; using Microsoft.Extensions.Logging; -namespace Bit.Core.Services; - -public class LicensingService : ILicensingService +namespace Bit.Core.Services { - private readonly X509Certificate2 _certificate; - private readonly IGlobalSettings _globalSettings; - private readonly IUserRepository _userRepository; - private readonly IOrganizationRepository _organizationRepository; - private readonly IOrganizationUserRepository _organizationUserRepository; - private readonly IMailService _mailService; - private readonly ILogger _logger; - - private IDictionary _userCheckCache = new Dictionary(); - - public LicensingService( - IUserRepository userRepository, - IOrganizationRepository organizationRepository, - IOrganizationUserRepository organizationUserRepository, - IMailService mailService, - IWebHostEnvironment environment, - ILogger logger, - IGlobalSettings globalSettings) + public class LicensingService : ILicensingService { - _userRepository = userRepository; - _organizationRepository = organizationRepository; - _organizationUserRepository = organizationUserRepository; - _mailService = mailService; - _logger = logger; - _globalSettings = globalSettings; + private readonly X509Certificate2 _certificate; + private readonly IGlobalSettings _globalSettings; + private readonly IUserRepository _userRepository; + private readonly IOrganizationRepository _organizationRepository; + private readonly IOrganizationUserRepository _organizationUserRepository; + private readonly IMailService _mailService; + private readonly ILogger _logger; - var certThumbprint = environment.IsDevelopment() ? - "207E64A231E8AA32AAF68A61037C075EBEBD553F" : - "‎B34876439FCDA2846505B2EFBBA6C4A951313EBE"; - if (_globalSettings.SelfHosted) + private IDictionary _userCheckCache = new Dictionary(); + + public LicensingService( + IUserRepository userRepository, + IOrganizationRepository organizationRepository, + IOrganizationUserRepository organizationUserRepository, + IMailService mailService, + IWebHostEnvironment environment, + ILogger logger, + IGlobalSettings globalSettings) { - _certificate = CoreHelpers.GetEmbeddedCertificateAsync(environment.IsDevelopment() ? "licensing_dev.cer" : "licensing.cer", null) - .GetAwaiter().GetResult(); - } - else if (CoreHelpers.SettingHasValue(_globalSettings.Storage?.ConnectionString) && - CoreHelpers.SettingHasValue(_globalSettings.LicenseCertificatePassword)) - { - _certificate = CoreHelpers.GetBlobCertificateAsync(globalSettings.Storage.ConnectionString, "certificates", - "licensing.pfx", _globalSettings.LicenseCertificatePassword) - .GetAwaiter().GetResult(); - } - else - { - _certificate = CoreHelpers.GetCertificate(certThumbprint); + _userRepository = userRepository; + _organizationRepository = organizationRepository; + _organizationUserRepository = organizationUserRepository; + _mailService = mailService; + _logger = logger; + _globalSettings = globalSettings; + + var certThumbprint = environment.IsDevelopment() ? + "207E64A231E8AA32AAF68A61037C075EBEBD553F" : + "‎B34876439FCDA2846505B2EFBBA6C4A951313EBE"; + if (_globalSettings.SelfHosted) + { + _certificate = CoreHelpers.GetEmbeddedCertificateAsync(environment.IsDevelopment() ? "licensing_dev.cer" : "licensing.cer", null) + .GetAwaiter().GetResult(); + } + else if (CoreHelpers.SettingHasValue(_globalSettings.Storage?.ConnectionString) && + CoreHelpers.SettingHasValue(_globalSettings.LicenseCertificatePassword)) + { + _certificate = CoreHelpers.GetBlobCertificateAsync(globalSettings.Storage.ConnectionString, "certificates", + "licensing.pfx", _globalSettings.LicenseCertificatePassword) + .GetAwaiter().GetResult(); + } + else + { + _certificate = CoreHelpers.GetCertificate(certThumbprint); + } + + if (_certificate == null || !_certificate.Thumbprint.Equals(CoreHelpers.CleanCertificateThumbprint(certThumbprint), + StringComparison.InvariantCultureIgnoreCase)) + { + throw new Exception("Invalid licensing certificate."); + } + + if (_globalSettings.SelfHosted && !CoreHelpers.SettingHasValue(_globalSettings.LicenseDirectory)) + { + throw new InvalidOperationException("No license directory."); + } } - if (_certificate == null || !_certificate.Thumbprint.Equals(CoreHelpers.CleanCertificateThumbprint(certThumbprint), - StringComparison.InvariantCultureIgnoreCase)) + public async Task ValidateOrganizationsAsync() { - throw new Exception("Invalid licensing certificate."); + if (!_globalSettings.SelfHosted) + { + return; + } + + var enabledOrgs = await _organizationRepository.GetManyByEnabledAsync(); + _logger.LogInformation(Constants.BypassFiltersEventId, null, + "Validating licenses for {0} organizations.", enabledOrgs.Count); + + foreach (var org in enabledOrgs) + { + var license = await ReadOrganizationLicenseAsync(org); + if (license == null) + { + await DisableOrganizationAsync(org, null, "No license file."); + continue; + } + + var totalLicensedOrgs = enabledOrgs.Count(o => o.LicenseKey.Equals(license.LicenseKey)); + if (totalLicensedOrgs > 1) + { + await DisableOrganizationAsync(org, license, "Multiple organizations."); + continue; + } + + if (!license.VerifyData(org, _globalSettings)) + { + await DisableOrganizationAsync(org, license, "Invalid data."); + continue; + } + + if (!license.VerifySignature(_certificate)) + { + await DisableOrganizationAsync(org, license, "Invalid signature."); + continue; + } + } } - if (_globalSettings.SelfHosted && !CoreHelpers.SettingHasValue(_globalSettings.LicenseDirectory)) + private async Task DisableOrganizationAsync(Organization org, ILicense license, string reason) { - throw new InvalidOperationException("No license directory."); - } - } + _logger.LogInformation(Constants.BypassFiltersEventId, null, + "Organization {0} ({1}) has an invalid license and is being disabled. Reason: {2}", + org.Id, org.Name, reason); + org.Enabled = false; + org.ExpirationDate = license?.Expires ?? DateTime.UtcNow; + org.RevisionDate = DateTime.UtcNow; + await _organizationRepository.ReplaceAsync(org); - public async Task ValidateOrganizationsAsync() - { - if (!_globalSettings.SelfHosted) - { - return; + await _mailService.SendLicenseExpiredAsync(new List { org.BillingEmail }, org.Name); } - var enabledOrgs = await _organizationRepository.GetManyByEnabledAsync(); - _logger.LogInformation(Constants.BypassFiltersEventId, null, - "Validating licenses for {0} organizations.", enabledOrgs.Count); - - foreach (var org in enabledOrgs) + public async Task ValidateUsersAsync() { - var license = await ReadOrganizationLicenseAsync(org); + if (!_globalSettings.SelfHosted) + { + return; + } + + var premiumUsers = await _userRepository.GetManyByPremiumAsync(true); + _logger.LogInformation(Constants.BypassFiltersEventId, null, + "Validating premium for {0} users.", premiumUsers.Count); + + foreach (var user in premiumUsers) + { + await ProcessUserValidationAsync(user); + } + } + + public async Task ValidateUserPremiumAsync(User user) + { + if (!_globalSettings.SelfHosted) + { + return user.Premium; + } + + if (!user.Premium) + { + return false; + } + + // Only check once per day + var now = DateTime.UtcNow; + if (_userCheckCache.ContainsKey(user.Id)) + { + var lastCheck = _userCheckCache[user.Id]; + if (lastCheck < now && now - lastCheck < TimeSpan.FromDays(1)) + { + return user.Premium; + } + else + { + _userCheckCache[user.Id] = now; + } + } + else + { + _userCheckCache.Add(user.Id, now); + } + + _logger.LogInformation(Constants.BypassFiltersEventId, null, + "Validating premium license for user {0}({1}).", user.Id, user.Email); + return await ProcessUserValidationAsync(user); + } + + private async Task ProcessUserValidationAsync(User user) + { + var license = ReadUserLicense(user); if (license == null) { - await DisableOrganizationAsync(org, null, "No license file."); - continue; + await DisablePremiumAsync(user, null, "No license file."); + return false; } - var totalLicensedOrgs = enabledOrgs.Count(o => o.LicenseKey.Equals(license.LicenseKey)); - if (totalLicensedOrgs > 1) + if (!license.VerifyData(user)) { - await DisableOrganizationAsync(org, license, "Multiple organizations."); - continue; - } - - if (!license.VerifyData(org, _globalSettings)) - { - await DisableOrganizationAsync(org, license, "Invalid data."); - continue; + await DisablePremiumAsync(user, license, "Invalid data."); + return false; } if (!license.VerifySignature(_certificate)) { - await DisableOrganizationAsync(org, license, "Invalid signature."); - continue; + await DisablePremiumAsync(user, license, "Invalid signature."); + return false; } - } - } - private async Task DisableOrganizationAsync(Organization org, ILicense license, string reason) - { - _logger.LogInformation(Constants.BypassFiltersEventId, null, - "Organization {0} ({1}) has an invalid license and is being disabled. Reason: {2}", - org.Id, org.Name, reason); - org.Enabled = false; - org.ExpirationDate = license?.Expires ?? DateTime.UtcNow; - org.RevisionDate = DateTime.UtcNow; - await _organizationRepository.ReplaceAsync(org); - - await _mailService.SendLicenseExpiredAsync(new List { org.BillingEmail }, org.Name); - } - - public async Task ValidateUsersAsync() - { - if (!_globalSettings.SelfHosted) - { - return; + return true; } - var premiumUsers = await _userRepository.GetManyByPremiumAsync(true); - _logger.LogInformation(Constants.BypassFiltersEventId, null, - "Validating premium for {0} users.", premiumUsers.Count); - - foreach (var user in premiumUsers) + private async Task DisablePremiumAsync(User user, ILicense license, string reason) { - await ProcessUserValidationAsync(user); - } - } + _logger.LogInformation(Constants.BypassFiltersEventId, null, + "User {0}({1}) has an invalid license and premium is being disabled. Reason: {2}", + user.Id, user.Email, reason); - public async Task ValidateUserPremiumAsync(User user) - { - if (!_globalSettings.SelfHosted) - { - return user.Premium; + user.Premium = false; + user.PremiumExpirationDate = license?.Expires ?? DateTime.UtcNow; + user.RevisionDate = DateTime.UtcNow; + await _userRepository.ReplaceAsync(user); + + await _mailService.SendLicenseExpiredAsync(new List { user.Email }); } - if (!user.Premium) + public bool VerifyLicense(ILicense license) { - return false; + return license.VerifySignature(_certificate); } - // Only check once per day - var now = DateTime.UtcNow; - if (_userCheckCache.ContainsKey(user.Id)) + public byte[] SignLicense(ILicense license) { - var lastCheck = _userCheckCache[user.Id]; - if (lastCheck < now && now - lastCheck < TimeSpan.FromDays(1)) + if (_globalSettings.SelfHosted || !_certificate.HasPrivateKey) { - return user.Premium; + throw new InvalidOperationException("Cannot sign licenses."); } - else + + return license.Sign(_certificate); + } + + private UserLicense ReadUserLicense(User user) + { + var filePath = $"{_globalSettings.LicenseDirectory}/user/{user.Id}.json"; + if (!File.Exists(filePath)) { - _userCheckCache[user.Id] = now; + return null; } + + var data = File.ReadAllText(filePath, Encoding.UTF8); + return JsonSerializer.Deserialize(data); } - else + + public Task ReadOrganizationLicenseAsync(Organization organization) => + ReadOrganizationLicenseAsync(organization.Id); + public async Task ReadOrganizationLicenseAsync(Guid organizationId) { - _userCheckCache.Add(user.Id, now); + var filePath = Path.Combine(_globalSettings.LicenseDirectory, "organization", $"{organizationId}.json"); + if (!File.Exists(filePath)) + { + return null; + } + + using var fs = File.OpenRead(filePath); + return await JsonSerializer.DeserializeAsync(fs); } - - _logger.LogInformation(Constants.BypassFiltersEventId, null, - "Validating premium license for user {0}({1}).", user.Id, user.Email); - return await ProcessUserValidationAsync(user); - } - - private async Task ProcessUserValidationAsync(User user) - { - var license = ReadUserLicense(user); - if (license == null) - { - await DisablePremiumAsync(user, null, "No license file."); - return false; - } - - if (!license.VerifyData(user)) - { - await DisablePremiumAsync(user, license, "Invalid data."); - return false; - } - - if (!license.VerifySignature(_certificate)) - { - await DisablePremiumAsync(user, license, "Invalid signature."); - return false; - } - - return true; - } - - private async Task DisablePremiumAsync(User user, ILicense license, string reason) - { - _logger.LogInformation(Constants.BypassFiltersEventId, null, - "User {0}({1}) has an invalid license and premium is being disabled. Reason: {2}", - user.Id, user.Email, reason); - - user.Premium = false; - user.PremiumExpirationDate = license?.Expires ?? DateTime.UtcNow; - user.RevisionDate = DateTime.UtcNow; - await _userRepository.ReplaceAsync(user); - - await _mailService.SendLicenseExpiredAsync(new List { user.Email }); - } - - public bool VerifyLicense(ILicense license) - { - return license.VerifySignature(_certificate); - } - - public byte[] SignLicense(ILicense license) - { - if (_globalSettings.SelfHosted || !_certificate.HasPrivateKey) - { - throw new InvalidOperationException("Cannot sign licenses."); - } - - return license.Sign(_certificate); - } - - private UserLicense ReadUserLicense(User user) - { - var filePath = $"{_globalSettings.LicenseDirectory}/user/{user.Id}.json"; - if (!File.Exists(filePath)) - { - return null; - } - - var data = File.ReadAllText(filePath, Encoding.UTF8); - return JsonSerializer.Deserialize(data); - } - - public Task ReadOrganizationLicenseAsync(Organization organization) => - ReadOrganizationLicenseAsync(organization.Id); - public async Task ReadOrganizationLicenseAsync(Guid organizationId) - { - var filePath = Path.Combine(_globalSettings.LicenseDirectory, "organization", $"{organizationId}.json"); - if (!File.Exists(filePath)) - { - return null; - } - - using var fs = File.OpenRead(filePath); - return await JsonSerializer.DeserializeAsync(fs); } } diff --git a/src/Core/Services/Implementations/LocalAttachmentStorageService.cs b/src/Core/Services/Implementations/LocalAttachmentStorageService.cs index 4949ff3128..d24a561e35 100644 --- a/src/Core/Services/Implementations/LocalAttachmentStorageService.cs +++ b/src/Core/Services/Implementations/LocalAttachmentStorageService.cs @@ -3,194 +3,195 @@ using Bit.Core.Enums; using Bit.Core.Models.Data; using Bit.Core.Settings; -namespace Bit.Core.Services; - -public class LocalAttachmentStorageService : IAttachmentStorageService +namespace Bit.Core.Services { - private readonly string _baseAttachmentUrl; - private readonly string _baseDirPath; - private readonly string _baseTempDirPath; - - public FileUploadType FileUploadType => FileUploadType.Direct; - - public LocalAttachmentStorageService( - IGlobalSettings globalSettings) + public class LocalAttachmentStorageService : IAttachmentStorageService { - _baseDirPath = globalSettings.Attachment.BaseDirectory; - _baseTempDirPath = $"{_baseDirPath}/temp"; - _baseAttachmentUrl = globalSettings.Attachment.BaseUrl; - } + private readonly string _baseAttachmentUrl; + private readonly string _baseDirPath; + private readonly string _baseTempDirPath; - public async Task GetAttachmentDownloadUrlAsync(Cipher cipher, CipherAttachment.MetaData attachmentData) - { - await InitAsync(); - return $"{_baseAttachmentUrl}/{cipher.Id}/{attachmentData.AttachmentId}"; - } + public FileUploadType FileUploadType => FileUploadType.Direct; - public async Task UploadNewAttachmentAsync(Stream stream, Cipher cipher, CipherAttachment.MetaData attachmentData) - { - await InitAsync(); - var cipherDirPath = CipherDirectoryPath(cipher.Id, temp: false); - CreateDirectoryIfNotExists(cipherDirPath); - - using (var fs = File.Create(AttachmentFilePath(cipherDirPath, attachmentData.AttachmentId))) + public LocalAttachmentStorageService( + IGlobalSettings globalSettings) { - stream.Seek(0, SeekOrigin.Begin); - await stream.CopyToAsync(fs); - } - } - - public async Task UploadShareAttachmentAsync(Stream stream, Guid cipherId, Guid organizationId, CipherAttachment.MetaData attachmentData) - { - await InitAsync(); - var tempCipherOrgDirPath = OrganizationDirectoryPath(cipherId, organizationId, temp: true); - CreateDirectoryIfNotExists(tempCipherOrgDirPath); - - using (var fs = File.Create(AttachmentFilePath(tempCipherOrgDirPath, attachmentData.AttachmentId))) - { - stream.Seek(0, SeekOrigin.Begin); - await stream.CopyToAsync(fs); - } - } - - public async Task StartShareAttachmentAsync(Guid cipherId, Guid organizationId, CipherAttachment.MetaData attachmentData) - { - await InitAsync(); - var sourceFilePath = AttachmentFilePath(attachmentData.AttachmentId, cipherId, organizationId, temp: true); - if (!File.Exists(sourceFilePath)) - { - return; + _baseDirPath = globalSettings.Attachment.BaseDirectory; + _baseTempDirPath = $"{_baseDirPath}/temp"; + _baseAttachmentUrl = globalSettings.Attachment.BaseUrl; } - var destFilePath = AttachmentFilePath(attachmentData.AttachmentId, cipherId, temp: false); - if (!File.Exists(destFilePath)) + public async Task GetAttachmentDownloadUrlAsync(Cipher cipher, CipherAttachment.MetaData attachmentData) { - return; + await InitAsync(); + return $"{_baseAttachmentUrl}/{cipher.Id}/{attachmentData.AttachmentId}"; } - var originalFilePath = AttachmentFilePath(attachmentData.AttachmentId, cipherId, temp: true); - DeleteFileIfExists(originalFilePath); - - File.Move(destFilePath, originalFilePath); - DeleteFileIfExists(destFilePath); - - File.Move(sourceFilePath, destFilePath); - } - - public async Task RollbackShareAttachmentAsync(Guid cipherId, Guid organizationId, CipherAttachment.MetaData attachmentData, string originalContainer) - { - await InitAsync(); - DeleteFileIfExists(AttachmentFilePath(attachmentData.AttachmentId, cipherId, organizationId, temp: true)); - - var originalFilePath = AttachmentFilePath(attachmentData.AttachmentId, cipherId, temp: true); - if (!File.Exists(originalFilePath)) + public async Task UploadNewAttachmentAsync(Stream stream, Cipher cipher, CipherAttachment.MetaData attachmentData) { - return; + await InitAsync(); + var cipherDirPath = CipherDirectoryPath(cipher.Id, temp: false); + CreateDirectoryIfNotExists(cipherDirPath); + + using (var fs = File.Create(AttachmentFilePath(cipherDirPath, attachmentData.AttachmentId))) + { + stream.Seek(0, SeekOrigin.Begin); + await stream.CopyToAsync(fs); + } } - var destFilePath = AttachmentFilePath(attachmentData.AttachmentId, cipherId, temp: false); - DeleteFileIfExists(destFilePath); - - File.Move(originalFilePath, destFilePath); - DeleteFileIfExists(originalFilePath); - } - - public async Task DeleteAttachmentAsync(Guid cipherId, CipherAttachment.MetaData attachmentData) - { - await InitAsync(); - DeleteFileIfExists(AttachmentFilePath(attachmentData.AttachmentId, cipherId, temp: false)); - } - - public async Task CleanupAsync(Guid cipherId) - { - await InitAsync(); - DeleteDirectoryIfExists(CipherDirectoryPath(cipherId, temp: true)); - } - - public async Task DeleteAttachmentsForCipherAsync(Guid cipherId) - { - await InitAsync(); - DeleteDirectoryIfExists(CipherDirectoryPath(cipherId, temp: false)); - } - - public async Task DeleteAttachmentsForOrganizationAsync(Guid organizationId) - { - await InitAsync(); - } - - public async Task DeleteAttachmentsForUserAsync(Guid userId) - { - await InitAsync(); - } - - private void DeleteFileIfExists(string path) - { - if (File.Exists(path)) + public async Task UploadShareAttachmentAsync(Stream stream, Guid cipherId, Guid organizationId, CipherAttachment.MetaData attachmentData) { - File.Delete(path); - } - } + await InitAsync(); + var tempCipherOrgDirPath = OrganizationDirectoryPath(cipherId, organizationId, temp: true); + CreateDirectoryIfNotExists(tempCipherOrgDirPath); - private void DeleteDirectoryIfExists(string path) - { - if (Directory.Exists(path)) - { - Directory.Delete(path, true); - } - } - - private void CreateDirectoryIfNotExists(string path) - { - if (!Directory.Exists(path)) - { - Directory.CreateDirectory(path); - } - } - - private Task InitAsync() - { - if (!Directory.Exists(_baseDirPath)) - { - Directory.CreateDirectory(_baseDirPath); + using (var fs = File.Create(AttachmentFilePath(tempCipherOrgDirPath, attachmentData.AttachmentId))) + { + stream.Seek(0, SeekOrigin.Begin); + await stream.CopyToAsync(fs); + } } - if (!Directory.Exists(_baseTempDirPath)) + public async Task StartShareAttachmentAsync(Guid cipherId, Guid organizationId, CipherAttachment.MetaData attachmentData) { - Directory.CreateDirectory(_baseTempDirPath); + await InitAsync(); + var sourceFilePath = AttachmentFilePath(attachmentData.AttachmentId, cipherId, organizationId, temp: true); + if (!File.Exists(sourceFilePath)) + { + return; + } + + var destFilePath = AttachmentFilePath(attachmentData.AttachmentId, cipherId, temp: false); + if (!File.Exists(destFilePath)) + { + return; + } + + var originalFilePath = AttachmentFilePath(attachmentData.AttachmentId, cipherId, temp: true); + DeleteFileIfExists(originalFilePath); + + File.Move(destFilePath, originalFilePath); + DeleteFileIfExists(destFilePath); + + File.Move(sourceFilePath, destFilePath); } - return Task.FromResult(0); - } - - private string CipherDirectoryPath(Guid cipherId, bool temp = false) => - Path.Combine(temp ? _baseTempDirPath : _baseDirPath, cipherId.ToString()); - private string OrganizationDirectoryPath(Guid cipherId, Guid organizationId, bool temp = false) => - Path.Combine(temp ? _baseTempDirPath : _baseDirPath, cipherId.ToString(), organizationId.ToString()); - - private string AttachmentFilePath(string dir, string attachmentId) => Path.Combine(dir, attachmentId); - private string AttachmentFilePath(string attachmentId, Guid cipherId, Guid? organizationId = null, - bool temp = false) => - organizationId.HasValue ? - AttachmentFilePath(OrganizationDirectoryPath(cipherId, organizationId.Value, temp), attachmentId) : - AttachmentFilePath(CipherDirectoryPath(cipherId, temp), attachmentId); - public Task GetAttachmentUploadUrlAsync(Cipher cipher, CipherAttachment.MetaData attachmentData) - => Task.FromResult($"{cipher.Id}/attachment/{attachmentData.AttachmentId}"); - - public Task<(bool, long?)> ValidateFileAsync(Cipher cipher, CipherAttachment.MetaData attachmentData, long leeway) - { - long? length = null; - var path = AttachmentFilePath(attachmentData.AttachmentId, cipher.Id, temp: false); - if (!File.Exists(path)) + public async Task RollbackShareAttachmentAsync(Guid cipherId, Guid organizationId, CipherAttachment.MetaData attachmentData, string originalContainer) { - return Task.FromResult((false, length)); + await InitAsync(); + DeleteFileIfExists(AttachmentFilePath(attachmentData.AttachmentId, cipherId, organizationId, temp: true)); + + var originalFilePath = AttachmentFilePath(attachmentData.AttachmentId, cipherId, temp: true); + if (!File.Exists(originalFilePath)) + { + return; + } + + var destFilePath = AttachmentFilePath(attachmentData.AttachmentId, cipherId, temp: false); + DeleteFileIfExists(destFilePath); + + File.Move(originalFilePath, destFilePath); + DeleteFileIfExists(originalFilePath); } - length = new FileInfo(path).Length; - if (attachmentData.Size < length - leeway || attachmentData.Size > length + leeway) + public async Task DeleteAttachmentAsync(Guid cipherId, CipherAttachment.MetaData attachmentData) { - return Task.FromResult((false, length)); + await InitAsync(); + DeleteFileIfExists(AttachmentFilePath(attachmentData.AttachmentId, cipherId, temp: false)); } - return Task.FromResult((true, length)); + public async Task CleanupAsync(Guid cipherId) + { + await InitAsync(); + DeleteDirectoryIfExists(CipherDirectoryPath(cipherId, temp: true)); + } + + public async Task DeleteAttachmentsForCipherAsync(Guid cipherId) + { + await InitAsync(); + DeleteDirectoryIfExists(CipherDirectoryPath(cipherId, temp: false)); + } + + public async Task DeleteAttachmentsForOrganizationAsync(Guid organizationId) + { + await InitAsync(); + } + + public async Task DeleteAttachmentsForUserAsync(Guid userId) + { + await InitAsync(); + } + + private void DeleteFileIfExists(string path) + { + if (File.Exists(path)) + { + File.Delete(path); + } + } + + private void DeleteDirectoryIfExists(string path) + { + if (Directory.Exists(path)) + { + Directory.Delete(path, true); + } + } + + private void CreateDirectoryIfNotExists(string path) + { + if (!Directory.Exists(path)) + { + Directory.CreateDirectory(path); + } + } + + private Task InitAsync() + { + if (!Directory.Exists(_baseDirPath)) + { + Directory.CreateDirectory(_baseDirPath); + } + + if (!Directory.Exists(_baseTempDirPath)) + { + Directory.CreateDirectory(_baseTempDirPath); + } + + return Task.FromResult(0); + } + + private string CipherDirectoryPath(Guid cipherId, bool temp = false) => + Path.Combine(temp ? _baseTempDirPath : _baseDirPath, cipherId.ToString()); + private string OrganizationDirectoryPath(Guid cipherId, Guid organizationId, bool temp = false) => + Path.Combine(temp ? _baseTempDirPath : _baseDirPath, cipherId.ToString(), organizationId.ToString()); + + private string AttachmentFilePath(string dir, string attachmentId) => Path.Combine(dir, attachmentId); + private string AttachmentFilePath(string attachmentId, Guid cipherId, Guid? organizationId = null, + bool temp = false) => + organizationId.HasValue ? + AttachmentFilePath(OrganizationDirectoryPath(cipherId, organizationId.Value, temp), attachmentId) : + AttachmentFilePath(CipherDirectoryPath(cipherId, temp), attachmentId); + public Task GetAttachmentUploadUrlAsync(Cipher cipher, CipherAttachment.MetaData attachmentData) + => Task.FromResult($"{cipher.Id}/attachment/{attachmentData.AttachmentId}"); + + public Task<(bool, long?)> ValidateFileAsync(Cipher cipher, CipherAttachment.MetaData attachmentData, long leeway) + { + long? length = null; + var path = AttachmentFilePath(attachmentData.AttachmentId, cipher.Id, temp: false); + if (!File.Exists(path)) + { + return Task.FromResult((false, length)); + } + + length = new FileInfo(path).Length; + if (attachmentData.Size < length - leeway || attachmentData.Size > length + leeway) + { + return Task.FromResult((false, length)); + } + + return Task.FromResult((true, length)); + } } } diff --git a/src/Core/Services/Implementations/LocalSendStorageService.cs b/src/Core/Services/Implementations/LocalSendStorageService.cs index 30872cbcca..200309f5bd 100644 --- a/src/Core/Services/Implementations/LocalSendStorageService.cs +++ b/src/Core/Services/Implementations/LocalSendStorageService.cs @@ -2,104 +2,105 @@ using Bit.Core.Enums; using Bit.Core.Settings; -namespace Bit.Core.Services; - -public class LocalSendStorageService : ISendFileStorageService +namespace Bit.Core.Services { - private readonly string _baseDirPath; - private readonly string _baseSendUrl; - - private string RelativeFilePath(Send send, string fileID) => $"{send.Id}/{fileID}"; - private string FilePath(Send send, string fileID) => $"{_baseDirPath}/{RelativeFilePath(send, fileID)}"; - public FileUploadType FileUploadType => FileUploadType.Direct; - - public LocalSendStorageService( - GlobalSettings globalSettings) + public class LocalSendStorageService : ISendFileStorageService { - _baseDirPath = globalSettings.Send.BaseDirectory; - _baseSendUrl = globalSettings.Send.BaseUrl; - } + private readonly string _baseDirPath; + private readonly string _baseSendUrl; - public async Task UploadNewFileAsync(Stream stream, Send send, string fileId) - { - await InitAsync(); - var path = FilePath(send, fileId); - Directory.CreateDirectory(Path.GetDirectoryName(path)); - using (var fs = File.Create(path)) + private string RelativeFilePath(Send send, string fileID) => $"{send.Id}/{fileID}"; + private string FilePath(Send send, string fileID) => $"{_baseDirPath}/{RelativeFilePath(send, fileID)}"; + public FileUploadType FileUploadType => FileUploadType.Direct; + + public LocalSendStorageService( + GlobalSettings globalSettings) { - stream.Seek(0, SeekOrigin.Begin); - await stream.CopyToAsync(fs); - } - } - - public async Task DeleteFileAsync(Send send, string fileId) - { - await InitAsync(); - var path = FilePath(send, fileId); - DeleteFileIfExists(path); - DeleteDirectoryIfExistsAndEmpty(Path.GetDirectoryName(path)); - } - - public async Task DeleteFilesForOrganizationAsync(Guid organizationId) - { - await InitAsync(); - } - - public async Task DeleteFilesForUserAsync(Guid userId) - { - await InitAsync(); - } - - public async Task GetSendFileDownloadUrlAsync(Send send, string fileId) - { - await InitAsync(); - return $"{_baseSendUrl}/{RelativeFilePath(send, fileId)}"; - } - - private void DeleteFileIfExists(string path) - { - if (File.Exists(path)) - { - File.Delete(path); - } - } - - private void DeleteDirectoryIfExistsAndEmpty(string path) - { - if (Directory.Exists(path) && !Directory.EnumerateFiles(path).Any()) - { - Directory.Delete(path); - } - } - - private Task InitAsync() - { - if (!Directory.Exists(_baseDirPath)) - { - Directory.CreateDirectory(_baseDirPath); + _baseDirPath = globalSettings.Send.BaseDirectory; + _baseSendUrl = globalSettings.Send.BaseUrl; } - return Task.FromResult(0); - } - - public Task GetSendFileUploadUrlAsync(Send send, string fileId) - => Task.FromResult($"/sends/{send.Id}/file/{fileId}"); - - public Task<(bool, long?)> ValidateFileAsync(Send send, string fileId, long expectedFileSize, long leeway) - { - long? length = null; - var path = FilePath(send, fileId); - if (!File.Exists(path)) + public async Task UploadNewFileAsync(Stream stream, Send send, string fileId) { - return Task.FromResult((false, length)); + await InitAsync(); + var path = FilePath(send, fileId); + Directory.CreateDirectory(Path.GetDirectoryName(path)); + using (var fs = File.Create(path)) + { + stream.Seek(0, SeekOrigin.Begin); + await stream.CopyToAsync(fs); + } } - length = new FileInfo(path).Length; - if (expectedFileSize < length - leeway || expectedFileSize > length + leeway) + public async Task DeleteFileAsync(Send send, string fileId) { - return Task.FromResult((false, length)); + await InitAsync(); + var path = FilePath(send, fileId); + DeleteFileIfExists(path); + DeleteDirectoryIfExistsAndEmpty(Path.GetDirectoryName(path)); } - return Task.FromResult((true, length)); + public async Task DeleteFilesForOrganizationAsync(Guid organizationId) + { + await InitAsync(); + } + + public async Task DeleteFilesForUserAsync(Guid userId) + { + await InitAsync(); + } + + public async Task GetSendFileDownloadUrlAsync(Send send, string fileId) + { + await InitAsync(); + return $"{_baseSendUrl}/{RelativeFilePath(send, fileId)}"; + } + + private void DeleteFileIfExists(string path) + { + if (File.Exists(path)) + { + File.Delete(path); + } + } + + private void DeleteDirectoryIfExistsAndEmpty(string path) + { + if (Directory.Exists(path) && !Directory.EnumerateFiles(path).Any()) + { + Directory.Delete(path); + } + } + + private Task InitAsync() + { + if (!Directory.Exists(_baseDirPath)) + { + Directory.CreateDirectory(_baseDirPath); + } + + return Task.FromResult(0); + } + + public Task GetSendFileUploadUrlAsync(Send send, string fileId) + => Task.FromResult($"/sends/{send.Id}/file/{fileId}"); + + public Task<(bool, long?)> ValidateFileAsync(Send send, string fileId, long expectedFileSize, long leeway) + { + long? length = null; + var path = FilePath(send, fileId); + if (!File.Exists(path)) + { + return Task.FromResult((false, length)); + } + + length = new FileInfo(path).Length; + if (expectedFileSize < length - leeway || expectedFileSize > length + leeway) + { + return Task.FromResult((false, length)); + } + + return Task.FromResult((true, length)); + } } } diff --git a/src/Core/Services/Implementations/MailKitSmtpMailDeliveryService.cs b/src/Core/Services/Implementations/MailKitSmtpMailDeliveryService.cs index 4e7b7ee105..b4b93278e5 100644 --- a/src/Core/Services/Implementations/MailKitSmtpMailDeliveryService.cs +++ b/src/Core/Services/Implementations/MailKitSmtpMailDeliveryService.cs @@ -4,97 +4,98 @@ using MailKit.Net.Smtp; using Microsoft.Extensions.Logging; using MimeKit; -namespace Bit.Core.Services; - -public class MailKitSmtpMailDeliveryService : IMailDeliveryService +namespace Bit.Core.Services { - private readonly GlobalSettings _globalSettings; - private readonly ILogger _logger; - private readonly string _replyDomain; - private readonly string _replyEmail; - - public MailKitSmtpMailDeliveryService( - GlobalSettings globalSettings, - ILogger logger) + public class MailKitSmtpMailDeliveryService : IMailDeliveryService { - if (globalSettings.Mail?.Smtp?.Host == null) + private readonly GlobalSettings _globalSettings; + private readonly ILogger _logger; + private readonly string _replyDomain; + private readonly string _replyEmail; + + public MailKitSmtpMailDeliveryService( + GlobalSettings globalSettings, + ILogger logger) { - throw new ArgumentNullException(nameof(globalSettings.Mail.Smtp.Host)); + if (globalSettings.Mail?.Smtp?.Host == null) + { + throw new ArgumentNullException(nameof(globalSettings.Mail.Smtp.Host)); + } + + _replyEmail = CoreHelpers.PunyEncode(globalSettings.Mail?.ReplyToEmail); + + if (_replyEmail.Contains("@")) + { + _replyDomain = _replyEmail.Split('@')[1]; + } + + _globalSettings = globalSettings; + _logger = logger; } - _replyEmail = CoreHelpers.PunyEncode(globalSettings.Mail?.ReplyToEmail); - - if (_replyEmail.Contains("@")) + public async Task SendEmailAsync(Models.Mail.MailMessage message) { - _replyDomain = _replyEmail.Split('@')[1]; - } + var mimeMessage = new MimeMessage(); + mimeMessage.From.Add(new MailboxAddress(_globalSettings.SiteName, _replyEmail)); + mimeMessage.Subject = message.Subject; + if (!string.IsNullOrWhiteSpace(_replyDomain)) + { + mimeMessage.MessageId = $"<{Guid.NewGuid()}@{_replyDomain}>"; + } - _globalSettings = globalSettings; - _logger = logger; - } - - public async Task SendEmailAsync(Models.Mail.MailMessage message) - { - var mimeMessage = new MimeMessage(); - mimeMessage.From.Add(new MailboxAddress(_globalSettings.SiteName, _replyEmail)); - mimeMessage.Subject = message.Subject; - if (!string.IsNullOrWhiteSpace(_replyDomain)) - { - mimeMessage.MessageId = $"<{Guid.NewGuid()}@{_replyDomain}>"; - } - - foreach (var address in message.ToEmails) - { - var punyencoded = CoreHelpers.PunyEncode(address); - mimeMessage.To.Add(MailboxAddress.Parse(punyencoded)); - } - - if (message.BccEmails != null) - { - foreach (var address in message.BccEmails) + foreach (var address in message.ToEmails) { var punyencoded = CoreHelpers.PunyEncode(address); - mimeMessage.Bcc.Add(MailboxAddress.Parse(punyencoded)); + mimeMessage.To.Add(MailboxAddress.Parse(punyencoded)); } - } - var builder = new BodyBuilder(); - if (!string.IsNullOrWhiteSpace(message.TextContent)) - { - builder.TextBody = message.TextContent; - } - builder.HtmlBody = message.HtmlContent; - mimeMessage.Body = builder.ToMessageBody(); - - using (var client = new SmtpClient()) - { - if (_globalSettings.Mail.Smtp.TrustServer) + if (message.BccEmails != null) { - client.ServerCertificateValidationCallback = (s, c, h, e) => true; + foreach (var address in message.BccEmails) + { + var punyencoded = CoreHelpers.PunyEncode(address); + mimeMessage.Bcc.Add(MailboxAddress.Parse(punyencoded)); + } } - if (!_globalSettings.Mail.Smtp.StartTls && !_globalSettings.Mail.Smtp.Ssl && - _globalSettings.Mail.Smtp.Port == 25) + var builder = new BodyBuilder(); + if (!string.IsNullOrWhiteSpace(message.TextContent)) { - await client.ConnectAsync(_globalSettings.Mail.Smtp.Host, _globalSettings.Mail.Smtp.Port, - MailKit.Security.SecureSocketOptions.None); - } - else - { - var useSsl = _globalSettings.Mail.Smtp.Port == 587 && !_globalSettings.Mail.Smtp.SslOverride ? - false : _globalSettings.Mail.Smtp.Ssl; - await client.ConnectAsync(_globalSettings.Mail.Smtp.Host, _globalSettings.Mail.Smtp.Port, useSsl); + builder.TextBody = message.TextContent; } + builder.HtmlBody = message.HtmlContent; + mimeMessage.Body = builder.ToMessageBody(); - if (CoreHelpers.SettingHasValue(_globalSettings.Mail.Smtp.Username) && - CoreHelpers.SettingHasValue(_globalSettings.Mail.Smtp.Password)) + using (var client = new SmtpClient()) { - await client.AuthenticateAsync(_globalSettings.Mail.Smtp.Username, - _globalSettings.Mail.Smtp.Password); - } + if (_globalSettings.Mail.Smtp.TrustServer) + { + client.ServerCertificateValidationCallback = (s, c, h, e) => true; + } - await client.SendAsync(mimeMessage); - await client.DisconnectAsync(true); + if (!_globalSettings.Mail.Smtp.StartTls && !_globalSettings.Mail.Smtp.Ssl && + _globalSettings.Mail.Smtp.Port == 25) + { + await client.ConnectAsync(_globalSettings.Mail.Smtp.Host, _globalSettings.Mail.Smtp.Port, + MailKit.Security.SecureSocketOptions.None); + } + else + { + var useSsl = _globalSettings.Mail.Smtp.Port == 587 && !_globalSettings.Mail.Smtp.SslOverride ? + false : _globalSettings.Mail.Smtp.Ssl; + await client.ConnectAsync(_globalSettings.Mail.Smtp.Host, _globalSettings.Mail.Smtp.Port, useSsl); + } + + if (CoreHelpers.SettingHasValue(_globalSettings.Mail.Smtp.Username) && + CoreHelpers.SettingHasValue(_globalSettings.Mail.Smtp.Password)) + { + await client.AuthenticateAsync(_globalSettings.Mail.Smtp.Username, + _globalSettings.Mail.Smtp.Password); + } + + await client.SendAsync(mimeMessage); + await client.DisconnectAsync(true); + } } } } diff --git a/src/Core/Services/Implementations/MultiServiceMailDeliveryService.cs b/src/Core/Services/Implementations/MultiServiceMailDeliveryService.cs index e088410967..286415fc27 100644 --- a/src/Core/Services/Implementations/MultiServiceMailDeliveryService.cs +++ b/src/Core/Services/Implementations/MultiServiceMailDeliveryService.cs @@ -3,39 +3,40 @@ using Bit.Core.Settings; using Microsoft.AspNetCore.Hosting; using Microsoft.Extensions.Logging; -namespace Bit.Core.Services; - -public class MultiServiceMailDeliveryService : IMailDeliveryService +namespace Bit.Core.Services { - private readonly IMailDeliveryService _sesService; - private readonly IMailDeliveryService _sendGridService; - private readonly int _sendGridPercentage; - - private static Random _random = new Random(); - - public MultiServiceMailDeliveryService( - GlobalSettings globalSettings, - IWebHostEnvironment hostingEnvironment, - ILogger sesLogger, - ILogger sendGridLogger) + public class MultiServiceMailDeliveryService : IMailDeliveryService { - _sesService = new AmazonSesMailDeliveryService(globalSettings, hostingEnvironment, sesLogger); - _sendGridService = new SendGridMailDeliveryService(globalSettings, hostingEnvironment, sendGridLogger); + private readonly IMailDeliveryService _sesService; + private readonly IMailDeliveryService _sendGridService; + private readonly int _sendGridPercentage; - // disabled by default (-1) - _sendGridPercentage = (globalSettings.Mail?.SendGridPercentage).GetValueOrDefault(-1); - } + private static Random _random = new Random(); - public async Task SendEmailAsync(MailMessage message) - { - var roll = _random.Next(0, 99); - if (roll < _sendGridPercentage) + public MultiServiceMailDeliveryService( + GlobalSettings globalSettings, + IWebHostEnvironment hostingEnvironment, + ILogger sesLogger, + ILogger sendGridLogger) { - await _sendGridService.SendEmailAsync(message); + _sesService = new AmazonSesMailDeliveryService(globalSettings, hostingEnvironment, sesLogger); + _sendGridService = new SendGridMailDeliveryService(globalSettings, hostingEnvironment, sendGridLogger); + + // disabled by default (-1) + _sendGridPercentage = (globalSettings.Mail?.SendGridPercentage).GetValueOrDefault(-1); } - else + + public async Task SendEmailAsync(MailMessage message) { - await _sesService.SendEmailAsync(message); + var roll = _random.Next(0, 99); + if (roll < _sendGridPercentage) + { + await _sendGridService.SendEmailAsync(message); + } + else + { + await _sesService.SendEmailAsync(message); + } } } } diff --git a/src/Core/Services/Implementations/MultiServicePushNotificationService.cs b/src/Core/Services/Implementations/MultiServicePushNotificationService.cs index 4e1678da6d..f940bad005 100644 --- a/src/Core/Services/Implementations/MultiServicePushNotificationService.cs +++ b/src/Core/Services/Implementations/MultiServicePushNotificationService.cs @@ -6,160 +6,161 @@ using Bit.Core.Utilities; using Microsoft.AspNetCore.Http; using Microsoft.Extensions.Logging; -namespace Bit.Core.Services; - -public class MultiServicePushNotificationService : IPushNotificationService +namespace Bit.Core.Services { - private readonly List _services = new List(); - private readonly ILogger _logger; - - public MultiServicePushNotificationService( - IHttpClientFactory httpFactory, - IDeviceRepository deviceRepository, - IInstallationDeviceRepository installationDeviceRepository, - GlobalSettings globalSettings, - IHttpContextAccessor httpContextAccessor, - ILogger logger, - ILogger relayLogger, - ILogger hubLogger) + public class MultiServicePushNotificationService : IPushNotificationService { - if (globalSettings.SelfHosted) + private readonly List _services = new List(); + private readonly ILogger _logger; + + public MultiServicePushNotificationService( + IHttpClientFactory httpFactory, + IDeviceRepository deviceRepository, + IInstallationDeviceRepository installationDeviceRepository, + GlobalSettings globalSettings, + IHttpContextAccessor httpContextAccessor, + ILogger logger, + ILogger relayLogger, + ILogger hubLogger) { - if (CoreHelpers.SettingHasValue(globalSettings.PushRelayBaseUri) && - globalSettings.Installation?.Id != null && - CoreHelpers.SettingHasValue(globalSettings.Installation?.Key)) + if (globalSettings.SelfHosted) { - _services.Add(new RelayPushNotificationService(httpFactory, deviceRepository, globalSettings, - httpContextAccessor, relayLogger)); + if (CoreHelpers.SettingHasValue(globalSettings.PushRelayBaseUri) && + globalSettings.Installation?.Id != null && + CoreHelpers.SettingHasValue(globalSettings.Installation?.Key)) + { + _services.Add(new RelayPushNotificationService(httpFactory, deviceRepository, globalSettings, + httpContextAccessor, relayLogger)); + } + if (CoreHelpers.SettingHasValue(globalSettings.InternalIdentityKey) && + CoreHelpers.SettingHasValue(globalSettings.BaseServiceUri.InternalNotifications)) + { + _services.Add(new NotificationsApiPushNotificationService( + httpFactory, globalSettings, httpContextAccessor, hubLogger)); + } } - if (CoreHelpers.SettingHasValue(globalSettings.InternalIdentityKey) && - CoreHelpers.SettingHasValue(globalSettings.BaseServiceUri.InternalNotifications)) + else { - _services.Add(new NotificationsApiPushNotificationService( - httpFactory, globalSettings, httpContextAccessor, hubLogger)); - } - } - else - { - if (CoreHelpers.SettingHasValue(globalSettings.NotificationHub.ConnectionString)) - { - _services.Add(new NotificationHubPushNotificationService(installationDeviceRepository, - globalSettings, httpContextAccessor)); - } - if (CoreHelpers.SettingHasValue(globalSettings.Notifications?.ConnectionString)) - { - _services.Add(new AzureQueuePushNotificationService(globalSettings, httpContextAccessor)); + if (CoreHelpers.SettingHasValue(globalSettings.NotificationHub.ConnectionString)) + { + _services.Add(new NotificationHubPushNotificationService(installationDeviceRepository, + globalSettings, httpContextAccessor)); + } + if (CoreHelpers.SettingHasValue(globalSettings.Notifications?.ConnectionString)) + { + _services.Add(new AzureQueuePushNotificationService(globalSettings, httpContextAccessor)); + } } + + _logger = logger; } - _logger = logger; - } - - public Task PushSyncCipherCreateAsync(Cipher cipher, IEnumerable collectionIds) - { - PushToServices((s) => s.PushSyncCipherCreateAsync(cipher, collectionIds)); - return Task.FromResult(0); - } - - public Task PushSyncCipherUpdateAsync(Cipher cipher, IEnumerable collectionIds) - { - PushToServices((s) => s.PushSyncCipherUpdateAsync(cipher, collectionIds)); - return Task.FromResult(0); - } - - public Task PushSyncCipherDeleteAsync(Cipher cipher) - { - PushToServices((s) => s.PushSyncCipherDeleteAsync(cipher)); - return Task.FromResult(0); - } - - public Task PushSyncFolderCreateAsync(Folder folder) - { - PushToServices((s) => s.PushSyncFolderCreateAsync(folder)); - return Task.FromResult(0); - } - - public Task PushSyncFolderUpdateAsync(Folder folder) - { - PushToServices((s) => s.PushSyncFolderUpdateAsync(folder)); - return Task.FromResult(0); - } - - public Task PushSyncFolderDeleteAsync(Folder folder) - { - PushToServices((s) => s.PushSyncFolderDeleteAsync(folder)); - return Task.FromResult(0); - } - - public Task PushSyncCiphersAsync(Guid userId) - { - PushToServices((s) => s.PushSyncCiphersAsync(userId)); - return Task.FromResult(0); - } - - public Task PushSyncVaultAsync(Guid userId) - { - PushToServices((s) => s.PushSyncVaultAsync(userId)); - return Task.FromResult(0); - } - - public Task PushSyncOrgKeysAsync(Guid userId) - { - PushToServices((s) => s.PushSyncOrgKeysAsync(userId)); - return Task.FromResult(0); - } - - public Task PushSyncSettingsAsync(Guid userId) - { - PushToServices((s) => s.PushSyncSettingsAsync(userId)); - return Task.FromResult(0); - } - - public Task PushLogOutAsync(Guid userId) - { - PushToServices((s) => s.PushLogOutAsync(userId)); - return Task.FromResult(0); - } - - public Task PushSyncSendCreateAsync(Send send) - { - PushToServices((s) => s.PushSyncSendCreateAsync(send)); - return Task.FromResult(0); - } - - public Task PushSyncSendUpdateAsync(Send send) - { - PushToServices((s) => s.PushSyncSendUpdateAsync(send)); - return Task.FromResult(0); - } - - public Task PushSyncSendDeleteAsync(Send send) - { - PushToServices((s) => s.PushSyncSendDeleteAsync(send)); - return Task.FromResult(0); - } - - public Task SendPayloadToUserAsync(string userId, PushType type, object payload, string identifier, - string deviceId = null) - { - PushToServices((s) => s.SendPayloadToUserAsync(userId, type, payload, identifier, deviceId)); - return Task.FromResult(0); - } - - public Task SendPayloadToOrganizationAsync(string orgId, PushType type, object payload, string identifier, - string deviceId = null) - { - PushToServices((s) => s.SendPayloadToOrganizationAsync(orgId, type, payload, identifier, deviceId)); - return Task.FromResult(0); - } - - private void PushToServices(Func pushFunc) - { - if (_services != null) + public Task PushSyncCipherCreateAsync(Cipher cipher, IEnumerable collectionIds) { - foreach (var service in _services) + PushToServices((s) => s.PushSyncCipherCreateAsync(cipher, collectionIds)); + return Task.FromResult(0); + } + + public Task PushSyncCipherUpdateAsync(Cipher cipher, IEnumerable collectionIds) + { + PushToServices((s) => s.PushSyncCipherUpdateAsync(cipher, collectionIds)); + return Task.FromResult(0); + } + + public Task PushSyncCipherDeleteAsync(Cipher cipher) + { + PushToServices((s) => s.PushSyncCipherDeleteAsync(cipher)); + return Task.FromResult(0); + } + + public Task PushSyncFolderCreateAsync(Folder folder) + { + PushToServices((s) => s.PushSyncFolderCreateAsync(folder)); + return Task.FromResult(0); + } + + public Task PushSyncFolderUpdateAsync(Folder folder) + { + PushToServices((s) => s.PushSyncFolderUpdateAsync(folder)); + return Task.FromResult(0); + } + + public Task PushSyncFolderDeleteAsync(Folder folder) + { + PushToServices((s) => s.PushSyncFolderDeleteAsync(folder)); + return Task.FromResult(0); + } + + public Task PushSyncCiphersAsync(Guid userId) + { + PushToServices((s) => s.PushSyncCiphersAsync(userId)); + return Task.FromResult(0); + } + + public Task PushSyncVaultAsync(Guid userId) + { + PushToServices((s) => s.PushSyncVaultAsync(userId)); + return Task.FromResult(0); + } + + public Task PushSyncOrgKeysAsync(Guid userId) + { + PushToServices((s) => s.PushSyncOrgKeysAsync(userId)); + return Task.FromResult(0); + } + + public Task PushSyncSettingsAsync(Guid userId) + { + PushToServices((s) => s.PushSyncSettingsAsync(userId)); + return Task.FromResult(0); + } + + public Task PushLogOutAsync(Guid userId) + { + PushToServices((s) => s.PushLogOutAsync(userId)); + return Task.FromResult(0); + } + + public Task PushSyncSendCreateAsync(Send send) + { + PushToServices((s) => s.PushSyncSendCreateAsync(send)); + return Task.FromResult(0); + } + + public Task PushSyncSendUpdateAsync(Send send) + { + PushToServices((s) => s.PushSyncSendUpdateAsync(send)); + return Task.FromResult(0); + } + + public Task PushSyncSendDeleteAsync(Send send) + { + PushToServices((s) => s.PushSyncSendDeleteAsync(send)); + return Task.FromResult(0); + } + + public Task SendPayloadToUserAsync(string userId, PushType type, object payload, string identifier, + string deviceId = null) + { + PushToServices((s) => s.SendPayloadToUserAsync(userId, type, payload, identifier, deviceId)); + return Task.FromResult(0); + } + + public Task SendPayloadToOrganizationAsync(string orgId, PushType type, object payload, string identifier, + string deviceId = null) + { + PushToServices((s) => s.SendPayloadToOrganizationAsync(orgId, type, payload, identifier, deviceId)); + return Task.FromResult(0); + } + + private void PushToServices(Func pushFunc) + { + if (_services != null) { - pushFunc(service); + foreach (var service in _services) + { + pushFunc(service); + } } } } diff --git a/src/Core/Services/Implementations/NotificationHubPushNotificationService.cs b/src/Core/Services/Implementations/NotificationHubPushNotificationService.cs index fbd7ab9ce1..dbf4e55aa9 100644 --- a/src/Core/Services/Implementations/NotificationHubPushNotificationService.cs +++ b/src/Core/Services/Implementations/NotificationHubPushNotificationService.cs @@ -10,230 +10,231 @@ using Bit.Core.Settings; using Microsoft.AspNetCore.Http; using Microsoft.Azure.NotificationHubs; -namespace Bit.Core.Services; - -public class NotificationHubPushNotificationService : IPushNotificationService +namespace Bit.Core.Services { - private readonly IInstallationDeviceRepository _installationDeviceRepository; - private readonly GlobalSettings _globalSettings; - private readonly IHttpContextAccessor _httpContextAccessor; - - private NotificationHubClient _client = null; - - public NotificationHubPushNotificationService( - IInstallationDeviceRepository installationDeviceRepository, - GlobalSettings globalSettings, - IHttpContextAccessor httpContextAccessor) + public class NotificationHubPushNotificationService : IPushNotificationService { - _installationDeviceRepository = installationDeviceRepository; - _globalSettings = globalSettings; - _httpContextAccessor = httpContextAccessor; - _client = NotificationHubClient.CreateClientFromConnectionString( - _globalSettings.NotificationHub.ConnectionString, - _globalSettings.NotificationHub.HubName); - } + private readonly IInstallationDeviceRepository _installationDeviceRepository; + private readonly GlobalSettings _globalSettings; + private readonly IHttpContextAccessor _httpContextAccessor; - public async Task PushSyncCipherCreateAsync(Cipher cipher, IEnumerable collectionIds) - { - await PushCipherAsync(cipher, PushType.SyncCipherCreate, collectionIds); - } + private NotificationHubClient _client = null; - public async Task PushSyncCipherUpdateAsync(Cipher cipher, IEnumerable collectionIds) - { - await PushCipherAsync(cipher, PushType.SyncCipherUpdate, collectionIds); - } - - public async Task PushSyncCipherDeleteAsync(Cipher cipher) - { - await PushCipherAsync(cipher, PushType.SyncLoginDelete, null); - } - - private async Task PushCipherAsync(Cipher cipher, PushType type, IEnumerable collectionIds) - { - if (cipher.OrganizationId.HasValue) + public NotificationHubPushNotificationService( + IInstallationDeviceRepository installationDeviceRepository, + GlobalSettings globalSettings, + IHttpContextAccessor httpContextAccessor) { - // We cannot send org pushes since access logic is much more complicated than just the fact that they belong - // to the organization. Potentially we could blindly send to just users that have the access all permission - // device registration needs to be more granular to handle that appropriately. A more brute force approach could - // me to send "full sync" push to all org users, but that has the potential to DDOS the API in bursts. - - // await SendPayloadToOrganizationAsync(cipher.OrganizationId.Value, type, message, true); + _installationDeviceRepository = installationDeviceRepository; + _globalSettings = globalSettings; + _httpContextAccessor = httpContextAccessor; + _client = NotificationHubClient.CreateClientFromConnectionString( + _globalSettings.NotificationHub.ConnectionString, + _globalSettings.NotificationHub.HubName); } - else if (cipher.UserId.HasValue) + + public async Task PushSyncCipherCreateAsync(Cipher cipher, IEnumerable collectionIds) { - var message = new SyncCipherPushNotification + await PushCipherAsync(cipher, PushType.SyncCipherCreate, collectionIds); + } + + public async Task PushSyncCipherUpdateAsync(Cipher cipher, IEnumerable collectionIds) + { + await PushCipherAsync(cipher, PushType.SyncCipherUpdate, collectionIds); + } + + public async Task PushSyncCipherDeleteAsync(Cipher cipher) + { + await PushCipherAsync(cipher, PushType.SyncLoginDelete, null); + } + + private async Task PushCipherAsync(Cipher cipher, PushType type, IEnumerable collectionIds) + { + if (cipher.OrganizationId.HasValue) { - Id = cipher.Id, - UserId = cipher.UserId, - OrganizationId = cipher.OrganizationId, - RevisionDate = cipher.RevisionDate, - CollectionIds = collectionIds, + // We cannot send org pushes since access logic is much more complicated than just the fact that they belong + // to the organization. Potentially we could blindly send to just users that have the access all permission + // device registration needs to be more granular to handle that appropriately. A more brute force approach could + // me to send "full sync" push to all org users, but that has the potential to DDOS the API in bursts. + + // await SendPayloadToOrganizationAsync(cipher.OrganizationId.Value, type, message, true); + } + else if (cipher.UserId.HasValue) + { + var message = new SyncCipherPushNotification + { + Id = cipher.Id, + UserId = cipher.UserId, + OrganizationId = cipher.OrganizationId, + RevisionDate = cipher.RevisionDate, + CollectionIds = collectionIds, + }; + + await SendPayloadToUserAsync(cipher.UserId.Value, type, message, true); + } + } + + public async Task PushSyncFolderCreateAsync(Folder folder) + { + await PushFolderAsync(folder, PushType.SyncFolderCreate); + } + + public async Task PushSyncFolderUpdateAsync(Folder folder) + { + await PushFolderAsync(folder, PushType.SyncFolderUpdate); + } + + public async Task PushSyncFolderDeleteAsync(Folder folder) + { + await PushFolderAsync(folder, PushType.SyncFolderDelete); + } + + private async Task PushFolderAsync(Folder folder, PushType type) + { + var message = new SyncFolderPushNotification + { + Id = folder.Id, + UserId = folder.UserId, + RevisionDate = folder.RevisionDate }; - await SendPayloadToUserAsync(cipher.UserId.Value, type, message, true); + await SendPayloadToUserAsync(folder.UserId, type, message, true); } - } - public async Task PushSyncFolderCreateAsync(Folder folder) - { - await PushFolderAsync(folder, PushType.SyncFolderCreate); - } - - public async Task PushSyncFolderUpdateAsync(Folder folder) - { - await PushFolderAsync(folder, PushType.SyncFolderUpdate); - } - - public async Task PushSyncFolderDeleteAsync(Folder folder) - { - await PushFolderAsync(folder, PushType.SyncFolderDelete); - } - - private async Task PushFolderAsync(Folder folder, PushType type) - { - var message = new SyncFolderPushNotification + public async Task PushSyncCiphersAsync(Guid userId) { - Id = folder.Id, - UserId = folder.UserId, - RevisionDate = folder.RevisionDate - }; + await PushUserAsync(userId, PushType.SyncCiphers); + } - await SendPayloadToUserAsync(folder.UserId, type, message, true); - } - - public async Task PushSyncCiphersAsync(Guid userId) - { - await PushUserAsync(userId, PushType.SyncCiphers); - } - - public async Task PushSyncVaultAsync(Guid userId) - { - await PushUserAsync(userId, PushType.SyncVault); - } - - public async Task PushSyncOrgKeysAsync(Guid userId) - { - await PushUserAsync(userId, PushType.SyncOrgKeys); - } - - public async Task PushSyncSettingsAsync(Guid userId) - { - await PushUserAsync(userId, PushType.SyncSettings); - } - - public async Task PushLogOutAsync(Guid userId) - { - await PushUserAsync(userId, PushType.LogOut); - } - - private async Task PushUserAsync(Guid userId, PushType type) - { - var message = new UserPushNotification + public async Task PushSyncVaultAsync(Guid userId) { - UserId = userId, - Date = DateTime.UtcNow - }; + await PushUserAsync(userId, PushType.SyncVault); + } - await SendPayloadToUserAsync(userId, type, message, false); - } - - public async Task PushSyncSendCreateAsync(Send send) - { - await PushSendAsync(send, PushType.SyncSendCreate); - } - - public async Task PushSyncSendUpdateAsync(Send send) - { - await PushSendAsync(send, PushType.SyncSendUpdate); - } - - public async Task PushSyncSendDeleteAsync(Send send) - { - await PushSendAsync(send, PushType.SyncSendDelete); - } - - private async Task PushSendAsync(Send send, PushType type) - { - if (send.UserId.HasValue) + public async Task PushSyncOrgKeysAsync(Guid userId) { - var message = new SyncSendPushNotification + await PushUserAsync(userId, PushType.SyncOrgKeys); + } + + public async Task PushSyncSettingsAsync(Guid userId) + { + await PushUserAsync(userId, PushType.SyncSettings); + } + + public async Task PushLogOutAsync(Guid userId) + { + await PushUserAsync(userId, PushType.LogOut); + } + + private async Task PushUserAsync(Guid userId, PushType type) + { + var message = new UserPushNotification { - Id = send.Id, - UserId = send.UserId.Value, - RevisionDate = send.RevisionDate + UserId = userId, + Date = DateTime.UtcNow }; - await SendPayloadToUserAsync(message.UserId, type, message, true); + await SendPayloadToUserAsync(userId, type, message, false); } - } - private async Task SendPayloadToUserAsync(Guid userId, PushType type, object payload, bool excludeCurrentContext) - { - await SendPayloadToUserAsync(userId.ToString(), type, payload, GetContextIdentifier(excludeCurrentContext)); - } - - private async Task SendPayloadToOrganizationAsync(Guid orgId, PushType type, object payload, bool excludeCurrentContext) - { - await SendPayloadToUserAsync(orgId.ToString(), type, payload, GetContextIdentifier(excludeCurrentContext)); - } - - public async Task SendPayloadToUserAsync(string userId, PushType type, object payload, string identifier, - string deviceId = null) - { - var tag = BuildTag($"template:payload_userId:{SanitizeTagInput(userId)}", identifier); - await SendPayloadAsync(tag, type, payload); - if (InstallationDeviceEntity.IsInstallationDeviceId(deviceId)) + public async Task PushSyncSendCreateAsync(Send send) { - await _installationDeviceRepository.UpsertAsync(new InstallationDeviceEntity(deviceId)); + await PushSendAsync(send, PushType.SyncSendCreate); } - } - public async Task SendPayloadToOrganizationAsync(string orgId, PushType type, object payload, string identifier, - string deviceId = null) - { - var tag = BuildTag($"template:payload && organizationId:{SanitizeTagInput(orgId)}", identifier); - await SendPayloadAsync(tag, type, payload); - if (InstallationDeviceEntity.IsInstallationDeviceId(deviceId)) + public async Task PushSyncSendUpdateAsync(Send send) { - await _installationDeviceRepository.UpsertAsync(new InstallationDeviceEntity(deviceId)); + await PushSendAsync(send, PushType.SyncSendUpdate); } - } - private string GetContextIdentifier(bool excludeCurrentContext) - { - if (!excludeCurrentContext) + public async Task PushSyncSendDeleteAsync(Send send) { - return null; + await PushSendAsync(send, PushType.SyncSendDelete); } - var currentContext = _httpContextAccessor?.HttpContext?. - RequestServices.GetService(typeof(ICurrentContext)) as ICurrentContext; - return currentContext?.DeviceIdentifier; - } - - private string BuildTag(string tag, string identifier) - { - if (!string.IsNullOrWhiteSpace(identifier)) + private async Task PushSendAsync(Send send, PushType type) { - tag += $" && !deviceIdentifier:{SanitizeTagInput(identifier)}"; - } - - return $"({tag})"; - } - - private async Task SendPayloadAsync(string tag, PushType type, object payload) - { - await _client.SendTemplateNotificationAsync( - new Dictionary + if (send.UserId.HasValue) { - { "type", ((byte)type).ToString() }, - { "payload", JsonSerializer.Serialize(payload) } - }, tag); - } + var message = new SyncSendPushNotification + { + Id = send.Id, + UserId = send.UserId.Value, + RevisionDate = send.RevisionDate + }; - private string SanitizeTagInput(string input) - { - // Only allow a-z, A-Z, 0-9, and special characters -_: - return Regex.Replace(input, "[^a-zA-Z0-9-_:]", string.Empty); + await SendPayloadToUserAsync(message.UserId, type, message, true); + } + } + + private async Task SendPayloadToUserAsync(Guid userId, PushType type, object payload, bool excludeCurrentContext) + { + await SendPayloadToUserAsync(userId.ToString(), type, payload, GetContextIdentifier(excludeCurrentContext)); + } + + private async Task SendPayloadToOrganizationAsync(Guid orgId, PushType type, object payload, bool excludeCurrentContext) + { + await SendPayloadToUserAsync(orgId.ToString(), type, payload, GetContextIdentifier(excludeCurrentContext)); + } + + public async Task SendPayloadToUserAsync(string userId, PushType type, object payload, string identifier, + string deviceId = null) + { + var tag = BuildTag($"template:payload_userId:{SanitizeTagInput(userId)}", identifier); + await SendPayloadAsync(tag, type, payload); + if (InstallationDeviceEntity.IsInstallationDeviceId(deviceId)) + { + await _installationDeviceRepository.UpsertAsync(new InstallationDeviceEntity(deviceId)); + } + } + + public async Task SendPayloadToOrganizationAsync(string orgId, PushType type, object payload, string identifier, + string deviceId = null) + { + var tag = BuildTag($"template:payload && organizationId:{SanitizeTagInput(orgId)}", identifier); + await SendPayloadAsync(tag, type, payload); + if (InstallationDeviceEntity.IsInstallationDeviceId(deviceId)) + { + await _installationDeviceRepository.UpsertAsync(new InstallationDeviceEntity(deviceId)); + } + } + + private string GetContextIdentifier(bool excludeCurrentContext) + { + if (!excludeCurrentContext) + { + return null; + } + + var currentContext = _httpContextAccessor?.HttpContext?. + RequestServices.GetService(typeof(ICurrentContext)) as ICurrentContext; + return currentContext?.DeviceIdentifier; + } + + private string BuildTag(string tag, string identifier) + { + if (!string.IsNullOrWhiteSpace(identifier)) + { + tag += $" && !deviceIdentifier:{SanitizeTagInput(identifier)}"; + } + + return $"({tag})"; + } + + private async Task SendPayloadAsync(string tag, PushType type, object payload) + { + await _client.SendTemplateNotificationAsync( + new Dictionary + { + { "type", ((byte)type).ToString() }, + { "payload", JsonSerializer.Serialize(payload) } + }, tag); + } + + private string SanitizeTagInput(string input) + { + // Only allow a-z, A-Z, 0-9, and special characters -_: + return Regex.Replace(input, "[^a-zA-Z0-9-_:]", string.Empty); + } } } diff --git a/src/Core/Services/Implementations/NotificationHubPushRegistrationService.cs b/src/Core/Services/Implementations/NotificationHubPushRegistrationService.cs index 6f09375398..be3f8735f9 100644 --- a/src/Core/Services/Implementations/NotificationHubPushRegistrationService.cs +++ b/src/Core/Services/Implementations/NotificationHubPushRegistrationService.cs @@ -4,191 +4,192 @@ using Bit.Core.Repositories; using Bit.Core.Settings; using Microsoft.Azure.NotificationHubs; -namespace Bit.Core.Services; - -public class NotificationHubPushRegistrationService : IPushRegistrationService +namespace Bit.Core.Services { - private readonly IInstallationDeviceRepository _installationDeviceRepository; - private readonly GlobalSettings _globalSettings; - - private NotificationHubClient _client = null; - - public NotificationHubPushRegistrationService( - IInstallationDeviceRepository installationDeviceRepository, - GlobalSettings globalSettings) + public class NotificationHubPushRegistrationService : IPushRegistrationService { - _installationDeviceRepository = installationDeviceRepository; - _globalSettings = globalSettings; - _client = NotificationHubClient.CreateClientFromConnectionString( - _globalSettings.NotificationHub.ConnectionString, - _globalSettings.NotificationHub.HubName); - } + private readonly IInstallationDeviceRepository _installationDeviceRepository; + private readonly GlobalSettings _globalSettings; - public async Task CreateOrUpdateRegistrationAsync(string pushToken, string deviceId, string userId, - string identifier, DeviceType type) - { - if (string.IsNullOrWhiteSpace(pushToken)) + private NotificationHubClient _client = null; + + public NotificationHubPushRegistrationService( + IInstallationDeviceRepository installationDeviceRepository, + GlobalSettings globalSettings) { - return; + _installationDeviceRepository = installationDeviceRepository; + _globalSettings = globalSettings; + _client = NotificationHubClient.CreateClientFromConnectionString( + _globalSettings.NotificationHub.ConnectionString, + _globalSettings.NotificationHub.HubName); } - var installation = new Installation + public async Task CreateOrUpdateRegistrationAsync(string pushToken, string deviceId, string userId, + string identifier, DeviceType type) { - InstallationId = deviceId, - PushChannel = pushToken, - Templates = new Dictionary() - }; - - installation.Tags = new List - { - $"userId:{userId}" - }; - - if (!string.IsNullOrWhiteSpace(identifier)) - { - installation.Tags.Add("deviceIdentifier:" + identifier); - } - - string payloadTemplate = null, messageTemplate = null, badgeMessageTemplate = null; - switch (type) - { - case DeviceType.Android: - payloadTemplate = "{\"data\":{\"data\":{\"type\":\"#(type)\",\"payload\":\"$(payload)\"}}}"; - messageTemplate = "{\"data\":{\"data\":{\"type\":\"#(type)\"}," + - "\"notification\":{\"title\":\"$(title)\",\"body\":\"$(message)\"}}}"; - - installation.Platform = NotificationPlatform.Fcm; - break; - case DeviceType.iOS: - payloadTemplate = "{\"data\":{\"type\":\"#(type)\",\"payload\":\"$(payload)\"}," + - "\"aps\":{\"content-available\":1}}"; - messageTemplate = "{\"data\":{\"type\":\"#(type)\"}," + - "\"aps\":{\"alert\":\"$(message)\",\"badge\":null,\"content-available\":1}}"; - badgeMessageTemplate = "{\"data\":{\"type\":\"#(type)\"}," + - "\"aps\":{\"alert\":\"$(message)\",\"badge\":\"#(badge)\",\"content-available\":1}}"; - - installation.Platform = NotificationPlatform.Apns; - break; - case DeviceType.AndroidAmazon: - payloadTemplate = "{\"data\":{\"type\":\"#(type)\",\"payload\":\"$(payload)\"}}"; - messageTemplate = "{\"data\":{\"type\":\"#(type)\",\"message\":\"$(message)\"}}"; - - installation.Platform = NotificationPlatform.Adm; - break; - default: - break; - } - - BuildInstallationTemplate(installation, "payload", payloadTemplate, userId, identifier); - BuildInstallationTemplate(installation, "message", messageTemplate, userId, identifier); - BuildInstallationTemplate(installation, "badgeMessage", badgeMessageTemplate ?? messageTemplate, - userId, identifier); - - await _client.CreateOrUpdateInstallationAsync(installation); - if (InstallationDeviceEntity.IsInstallationDeviceId(deviceId)) - { - await _installationDeviceRepository.UpsertAsync(new InstallationDeviceEntity(deviceId)); - } - } - - private void BuildInstallationTemplate(Installation installation, string templateId, string templateBody, - string userId, string identifier) - { - if (templateBody == null) - { - return; - } - - var fullTemplateId = $"template:{templateId}"; - - var template = new InstallationTemplate - { - Body = templateBody, - Tags = new List + if (string.IsNullOrWhiteSpace(pushToken)) { - fullTemplateId, - $"{fullTemplateId}_userId:{userId}" + return; } - }; - if (!string.IsNullOrWhiteSpace(identifier)) - { - template.Tags.Add($"{fullTemplateId}_deviceIdentifier:{identifier}"); - } + var installation = new Installation + { + InstallationId = deviceId, + PushChannel = pushToken, + Templates = new Dictionary() + }; - installation.Templates.Add(fullTemplateId, template); - } + installation.Tags = new List + { + $"userId:{userId}" + }; - public async Task DeleteRegistrationAsync(string deviceId) - { - try - { - await _client.DeleteInstallationAsync(deviceId); + if (!string.IsNullOrWhiteSpace(identifier)) + { + installation.Tags.Add("deviceIdentifier:" + identifier); + } + + string payloadTemplate = null, messageTemplate = null, badgeMessageTemplate = null; + switch (type) + { + case DeviceType.Android: + payloadTemplate = "{\"data\":{\"data\":{\"type\":\"#(type)\",\"payload\":\"$(payload)\"}}}"; + messageTemplate = "{\"data\":{\"data\":{\"type\":\"#(type)\"}," + + "\"notification\":{\"title\":\"$(title)\",\"body\":\"$(message)\"}}}"; + + installation.Platform = NotificationPlatform.Fcm; + break; + case DeviceType.iOS: + payloadTemplate = "{\"data\":{\"type\":\"#(type)\",\"payload\":\"$(payload)\"}," + + "\"aps\":{\"content-available\":1}}"; + messageTemplate = "{\"data\":{\"type\":\"#(type)\"}," + + "\"aps\":{\"alert\":\"$(message)\",\"badge\":null,\"content-available\":1}}"; + badgeMessageTemplate = "{\"data\":{\"type\":\"#(type)\"}," + + "\"aps\":{\"alert\":\"$(message)\",\"badge\":\"#(badge)\",\"content-available\":1}}"; + + installation.Platform = NotificationPlatform.Apns; + break; + case DeviceType.AndroidAmazon: + payloadTemplate = "{\"data\":{\"type\":\"#(type)\",\"payload\":\"$(payload)\"}}"; + messageTemplate = "{\"data\":{\"type\":\"#(type)\",\"message\":\"$(message)\"}}"; + + installation.Platform = NotificationPlatform.Adm; + break; + default: + break; + } + + BuildInstallationTemplate(installation, "payload", payloadTemplate, userId, identifier); + BuildInstallationTemplate(installation, "message", messageTemplate, userId, identifier); + BuildInstallationTemplate(installation, "badgeMessage", badgeMessageTemplate ?? messageTemplate, + userId, identifier); + + await _client.CreateOrUpdateInstallationAsync(installation); if (InstallationDeviceEntity.IsInstallationDeviceId(deviceId)) { - await _installationDeviceRepository.DeleteAsync(new InstallationDeviceEntity(deviceId)); + await _installationDeviceRepository.UpsertAsync(new InstallationDeviceEntity(deviceId)); } } - catch (Exception e) when (e.InnerException == null || !e.InnerException.Message.Contains("(404) Not Found")) - { - throw; - } - } - public async Task AddUserRegistrationOrganizationAsync(IEnumerable deviceIds, string organizationId) - { - await PatchTagsForUserDevicesAsync(deviceIds, UpdateOperationType.Add, $"organizationId:{organizationId}"); - if (deviceIds.Any() && InstallationDeviceEntity.IsInstallationDeviceId(deviceIds.First())) + private void BuildInstallationTemplate(Installation installation, string templateId, string templateBody, + string userId, string identifier) { - var entities = deviceIds.Select(e => new InstallationDeviceEntity(e)); - await _installationDeviceRepository.UpsertManyAsync(entities.ToList()); - } - } + if (templateBody == null) + { + return; + } - public async Task DeleteUserRegistrationOrganizationAsync(IEnumerable deviceIds, string organizationId) - { - await PatchTagsForUserDevicesAsync(deviceIds, UpdateOperationType.Remove, - $"organizationId:{organizationId}"); - if (deviceIds.Any() && InstallationDeviceEntity.IsInstallationDeviceId(deviceIds.First())) - { - var entities = deviceIds.Select(e => new InstallationDeviceEntity(e)); - await _installationDeviceRepository.UpsertManyAsync(entities.ToList()); - } - } + var fullTemplateId = $"template:{templateId}"; - private async Task PatchTagsForUserDevicesAsync(IEnumerable deviceIds, UpdateOperationType op, - string tag) - { - if (!deviceIds.Any()) - { - return; + var template = new InstallationTemplate + { + Body = templateBody, + Tags = new List + { + fullTemplateId, + $"{fullTemplateId}_userId:{userId}" + } + }; + + if (!string.IsNullOrWhiteSpace(identifier)) + { + template.Tags.Add($"{fullTemplateId}_deviceIdentifier:{identifier}"); + } + + installation.Templates.Add(fullTemplateId, template); } - var operation = new PartialUpdateOperation - { - Operation = op, - Path = "/tags" - }; - - if (op == UpdateOperationType.Add) - { - operation.Value = tag; - } - else if (op == UpdateOperationType.Remove) - { - operation.Path += $"/{tag}"; - } - - foreach (var id in deviceIds) + public async Task DeleteRegistrationAsync(string deviceId) { try { - await _client.PatchInstallationAsync(id, new List { operation }); + await _client.DeleteInstallationAsync(deviceId); + if (InstallationDeviceEntity.IsInstallationDeviceId(deviceId)) + { + await _installationDeviceRepository.DeleteAsync(new InstallationDeviceEntity(deviceId)); + } } catch (Exception e) when (e.InnerException == null || !e.InnerException.Message.Contains("(404) Not Found")) { throw; } } + + public async Task AddUserRegistrationOrganizationAsync(IEnumerable deviceIds, string organizationId) + { + await PatchTagsForUserDevicesAsync(deviceIds, UpdateOperationType.Add, $"organizationId:{organizationId}"); + if (deviceIds.Any() && InstallationDeviceEntity.IsInstallationDeviceId(deviceIds.First())) + { + var entities = deviceIds.Select(e => new InstallationDeviceEntity(e)); + await _installationDeviceRepository.UpsertManyAsync(entities.ToList()); + } + } + + public async Task DeleteUserRegistrationOrganizationAsync(IEnumerable deviceIds, string organizationId) + { + await PatchTagsForUserDevicesAsync(deviceIds, UpdateOperationType.Remove, + $"organizationId:{organizationId}"); + if (deviceIds.Any() && InstallationDeviceEntity.IsInstallationDeviceId(deviceIds.First())) + { + var entities = deviceIds.Select(e => new InstallationDeviceEntity(e)); + await _installationDeviceRepository.UpsertManyAsync(entities.ToList()); + } + } + + private async Task PatchTagsForUserDevicesAsync(IEnumerable deviceIds, UpdateOperationType op, + string tag) + { + if (!deviceIds.Any()) + { + return; + } + + var operation = new PartialUpdateOperation + { + Operation = op, + Path = "/tags" + }; + + if (op == UpdateOperationType.Add) + { + operation.Value = tag; + } + else if (op == UpdateOperationType.Remove) + { + operation.Path += $"/{tag}"; + } + + foreach (var id in deviceIds) + { + try + { + await _client.PatchInstallationAsync(id, new List { operation }); + } + catch (Exception e) when (e.InnerException == null || !e.InnerException.Message.Contains("(404) Not Found")) + { + throw; + } + } + } } } diff --git a/src/Core/Services/Implementations/NotificationsApiPushNotificationService.cs b/src/Core/Services/Implementations/NotificationsApiPushNotificationService.cs index 144178f84d..87729be705 100644 --- a/src/Core/Services/Implementations/NotificationsApiPushNotificationService.cs +++ b/src/Core/Services/Implementations/NotificationsApiPushNotificationService.cs @@ -6,197 +6,198 @@ using Bit.Core.Settings; using Microsoft.AspNetCore.Http; using Microsoft.Extensions.Logging; -namespace Bit.Core.Services; - -public class NotificationsApiPushNotificationService : BaseIdentityClientService, IPushNotificationService +namespace Bit.Core.Services { - private readonly GlobalSettings _globalSettings; - private readonly IHttpContextAccessor _httpContextAccessor; - - public NotificationsApiPushNotificationService( - IHttpClientFactory httpFactory, - GlobalSettings globalSettings, - IHttpContextAccessor httpContextAccessor, - ILogger logger) - : base( - httpFactory, - globalSettings.BaseServiceUri.InternalNotifications, - globalSettings.BaseServiceUri.InternalIdentity, - "internal", - $"internal.{globalSettings.ProjectName}", - globalSettings.InternalIdentityKey, - logger) + public class NotificationsApiPushNotificationService : BaseIdentityClientService, IPushNotificationService { - _globalSettings = globalSettings; - _httpContextAccessor = httpContextAccessor; - } + private readonly GlobalSettings _globalSettings; + private readonly IHttpContextAccessor _httpContextAccessor; - public async Task PushSyncCipherCreateAsync(Cipher cipher, IEnumerable collectionIds) - { - await PushCipherAsync(cipher, PushType.SyncCipherCreate, collectionIds); - } - - public async Task PushSyncCipherUpdateAsync(Cipher cipher, IEnumerable collectionIds) - { - await PushCipherAsync(cipher, PushType.SyncCipherUpdate, collectionIds); - } - - public async Task PushSyncCipherDeleteAsync(Cipher cipher) - { - await PushCipherAsync(cipher, PushType.SyncLoginDelete, null); - } - - private async Task PushCipherAsync(Cipher cipher, PushType type, IEnumerable collectionIds) - { - if (cipher.OrganizationId.HasValue) + public NotificationsApiPushNotificationService( + IHttpClientFactory httpFactory, + GlobalSettings globalSettings, + IHttpContextAccessor httpContextAccessor, + ILogger logger) + : base( + httpFactory, + globalSettings.BaseServiceUri.InternalNotifications, + globalSettings.BaseServiceUri.InternalIdentity, + "internal", + $"internal.{globalSettings.ProjectName}", + globalSettings.InternalIdentityKey, + logger) { - var message = new SyncCipherPushNotification + _globalSettings = globalSettings; + _httpContextAccessor = httpContextAccessor; + } + + public async Task PushSyncCipherCreateAsync(Cipher cipher, IEnumerable collectionIds) + { + await PushCipherAsync(cipher, PushType.SyncCipherCreate, collectionIds); + } + + public async Task PushSyncCipherUpdateAsync(Cipher cipher, IEnumerable collectionIds) + { + await PushCipherAsync(cipher, PushType.SyncCipherUpdate, collectionIds); + } + + public async Task PushSyncCipherDeleteAsync(Cipher cipher) + { + await PushCipherAsync(cipher, PushType.SyncLoginDelete, null); + } + + private async Task PushCipherAsync(Cipher cipher, PushType type, IEnumerable collectionIds) + { + if (cipher.OrganizationId.HasValue) { - Id = cipher.Id, - OrganizationId = cipher.OrganizationId, - RevisionDate = cipher.RevisionDate, - CollectionIds = collectionIds, + var message = new SyncCipherPushNotification + { + Id = cipher.Id, + OrganizationId = cipher.OrganizationId, + RevisionDate = cipher.RevisionDate, + CollectionIds = collectionIds, + }; + + await SendMessageAsync(type, message, true); + } + else if (cipher.UserId.HasValue) + { + var message = new SyncCipherPushNotification + { + Id = cipher.Id, + UserId = cipher.UserId, + RevisionDate = cipher.RevisionDate, + CollectionIds = collectionIds, + }; + + await SendMessageAsync(type, message, true); + } + } + + public async Task PushSyncFolderCreateAsync(Folder folder) + { + await PushFolderAsync(folder, PushType.SyncFolderCreate); + } + + public async Task PushSyncFolderUpdateAsync(Folder folder) + { + await PushFolderAsync(folder, PushType.SyncFolderUpdate); + } + + public async Task PushSyncFolderDeleteAsync(Folder folder) + { + await PushFolderAsync(folder, PushType.SyncFolderDelete); + } + + private async Task PushFolderAsync(Folder folder, PushType type) + { + var message = new SyncFolderPushNotification + { + Id = folder.Id, + UserId = folder.UserId, + RevisionDate = folder.RevisionDate }; await SendMessageAsync(type, message, true); } - else if (cipher.UserId.HasValue) - { - var message = new SyncCipherPushNotification - { - Id = cipher.Id, - UserId = cipher.UserId, - RevisionDate = cipher.RevisionDate, - CollectionIds = collectionIds, - }; - await SendMessageAsync(type, message, true); + public async Task PushSyncCiphersAsync(Guid userId) + { + await PushUserAsync(userId, PushType.SyncCiphers); } - } - public async Task PushSyncFolderCreateAsync(Folder folder) - { - await PushFolderAsync(folder, PushType.SyncFolderCreate); - } - - public async Task PushSyncFolderUpdateAsync(Folder folder) - { - await PushFolderAsync(folder, PushType.SyncFolderUpdate); - } - - public async Task PushSyncFolderDeleteAsync(Folder folder) - { - await PushFolderAsync(folder, PushType.SyncFolderDelete); - } - - private async Task PushFolderAsync(Folder folder, PushType type) - { - var message = new SyncFolderPushNotification + public async Task PushSyncVaultAsync(Guid userId) { - Id = folder.Id, - UserId = folder.UserId, - RevisionDate = folder.RevisionDate - }; + await PushUserAsync(userId, PushType.SyncVault); + } - await SendMessageAsync(type, message, true); - } - - public async Task PushSyncCiphersAsync(Guid userId) - { - await PushUserAsync(userId, PushType.SyncCiphers); - } - - public async Task PushSyncVaultAsync(Guid userId) - { - await PushUserAsync(userId, PushType.SyncVault); - } - - public async Task PushSyncOrgKeysAsync(Guid userId) - { - await PushUserAsync(userId, PushType.SyncOrgKeys); - } - - public async Task PushSyncSettingsAsync(Guid userId) - { - await PushUserAsync(userId, PushType.SyncSettings); - } - - public async Task PushLogOutAsync(Guid userId) - { - await PushUserAsync(userId, PushType.LogOut); - } - - private async Task PushUserAsync(Guid userId, PushType type) - { - var message = new UserPushNotification + public async Task PushSyncOrgKeysAsync(Guid userId) { - UserId = userId, - Date = DateTime.UtcNow - }; + await PushUserAsync(userId, PushType.SyncOrgKeys); + } - await SendMessageAsync(type, message, false); - } - - public async Task PushSyncSendCreateAsync(Send send) - { - await PushSendAsync(send, PushType.SyncSendCreate); - } - - public async Task PushSyncSendUpdateAsync(Send send) - { - await PushSendAsync(send, PushType.SyncSendUpdate); - } - - public async Task PushSyncSendDeleteAsync(Send send) - { - await PushSendAsync(send, PushType.SyncSendDelete); - } - - private async Task PushSendAsync(Send send, PushType type) - { - if (send.UserId.HasValue) + public async Task PushSyncSettingsAsync(Guid userId) { - var message = new SyncSendPushNotification + await PushUserAsync(userId, PushType.SyncSettings); + } + + public async Task PushLogOutAsync(Guid userId) + { + await PushUserAsync(userId, PushType.LogOut); + } + + private async Task PushUserAsync(Guid userId, PushType type) + { + var message = new UserPushNotification { - Id = send.Id, - UserId = send.UserId.Value, - RevisionDate = send.RevisionDate + UserId = userId, + Date = DateTime.UtcNow }; await SendMessageAsync(type, message, false); } - } - private async Task SendMessageAsync(PushType type, T payload, bool excludeCurrentContext) - { - var contextId = GetContextIdentifier(excludeCurrentContext); - var request = new PushNotificationData(type, payload, contextId); - await SendAsync(HttpMethod.Post, "send", request); - } - - private string GetContextIdentifier(bool excludeCurrentContext) - { - if (!excludeCurrentContext) + public async Task PushSyncSendCreateAsync(Send send) { - return null; + await PushSendAsync(send, PushType.SyncSendCreate); } - var currentContext = _httpContextAccessor?.HttpContext?. - RequestServices.GetService(typeof(ICurrentContext)) as ICurrentContext; - return currentContext?.DeviceIdentifier; - } + public async Task PushSyncSendUpdateAsync(Send send) + { + await PushSendAsync(send, PushType.SyncSendUpdate); + } - public Task SendPayloadToUserAsync(string userId, PushType type, object payload, string identifier, - string deviceId = null) - { - // Noop - return Task.FromResult(0); - } + public async Task PushSyncSendDeleteAsync(Send send) + { + await PushSendAsync(send, PushType.SyncSendDelete); + } - public Task SendPayloadToOrganizationAsync(string orgId, PushType type, object payload, string identifier, - string deviceId = null) - { - // Noop - return Task.FromResult(0); + private async Task PushSendAsync(Send send, PushType type) + { + if (send.UserId.HasValue) + { + var message = new SyncSendPushNotification + { + Id = send.Id, + UserId = send.UserId.Value, + RevisionDate = send.RevisionDate + }; + + await SendMessageAsync(type, message, false); + } + } + + private async Task SendMessageAsync(PushType type, T payload, bool excludeCurrentContext) + { + var contextId = GetContextIdentifier(excludeCurrentContext); + var request = new PushNotificationData(type, payload, contextId); + await SendAsync(HttpMethod.Post, "send", request); + } + + private string GetContextIdentifier(bool excludeCurrentContext) + { + if (!excludeCurrentContext) + { + return null; + } + + var currentContext = _httpContextAccessor?.HttpContext?. + RequestServices.GetService(typeof(ICurrentContext)) as ICurrentContext; + return currentContext?.DeviceIdentifier; + } + + public Task SendPayloadToUserAsync(string userId, PushType type, object payload, string identifier, + string deviceId = null) + { + // Noop + return Task.FromResult(0); + } + + public Task SendPayloadToOrganizationAsync(string orgId, PushType type, object payload, string identifier, + string deviceId = null) + { + // Noop + return Task.FromResult(0); + } } } diff --git a/src/Core/Services/Implementations/OrganizationService.cs b/src/Core/Services/Implementations/OrganizationService.cs index 3d9f1da1f8..b0b3dfc07b 100644 --- a/src/Core/Services/Implementations/OrganizationService.cs +++ b/src/Core/Services/Implementations/OrganizationService.cs @@ -14,2440 +14,2441 @@ using Microsoft.AspNetCore.DataProtection; using Microsoft.Extensions.Logging; using Stripe; -namespace Bit.Core.Services; - -public class OrganizationService : IOrganizationService +namespace Bit.Core.Services { - private readonly IOrganizationRepository _organizationRepository; - private readonly IOrganizationUserRepository _organizationUserRepository; - private readonly ICollectionRepository _collectionRepository; - private readonly IUserRepository _userRepository; - private readonly IGroupRepository _groupRepository; - private readonly IDataProtector _dataProtector; - private readonly IMailService _mailService; - private readonly IPushNotificationService _pushNotificationService; - private readonly IPushRegistrationService _pushRegistrationService; - private readonly IDeviceRepository _deviceRepository; - private readonly ILicensingService _licensingService; - private readonly IEventService _eventService; - private readonly IInstallationRepository _installationRepository; - private readonly IApplicationCacheService _applicationCacheService; - private readonly IPaymentService _paymentService; - private readonly IPolicyRepository _policyRepository; - private readonly ISsoConfigRepository _ssoConfigRepository; - private readonly ISsoUserRepository _ssoUserRepository; - private readonly IReferenceEventService _referenceEventService; - private readonly IGlobalSettings _globalSettings; - private readonly ITaxRateRepository _taxRateRepository; - private readonly IOrganizationApiKeyRepository _organizationApiKeyRepository; - private readonly IOrganizationConnectionRepository _organizationConnectionRepository; - private readonly ICurrentContext _currentContext; - private readonly ILogger _logger; - - - public OrganizationService( - IOrganizationRepository organizationRepository, - IOrganizationUserRepository organizationUserRepository, - ICollectionRepository collectionRepository, - IUserRepository userRepository, - IGroupRepository groupRepository, - IDataProtectionProvider dataProtectionProvider, - IMailService mailService, - IPushNotificationService pushNotificationService, - IPushRegistrationService pushRegistrationService, - IDeviceRepository deviceRepository, - ILicensingService licensingService, - IEventService eventService, - IInstallationRepository installationRepository, - IApplicationCacheService applicationCacheService, - IPaymentService paymentService, - IPolicyRepository policyRepository, - ISsoConfigRepository ssoConfigRepository, - ISsoUserRepository ssoUserRepository, - IReferenceEventService referenceEventService, - IGlobalSettings globalSettings, - ITaxRateRepository taxRateRepository, - IOrganizationApiKeyRepository organizationApiKeyRepository, - IOrganizationConnectionRepository organizationConnectionRepository, - ICurrentContext currentContext, - ILogger logger) + public class OrganizationService : IOrganizationService { - _organizationRepository = organizationRepository; - _organizationUserRepository = organizationUserRepository; - _collectionRepository = collectionRepository; - _userRepository = userRepository; - _groupRepository = groupRepository; - _dataProtector = dataProtectionProvider.CreateProtector("OrganizationServiceDataProtector"); - _mailService = mailService; - _pushNotificationService = pushNotificationService; - _pushRegistrationService = pushRegistrationService; - _deviceRepository = deviceRepository; - _licensingService = licensingService; - _eventService = eventService; - _installationRepository = installationRepository; - _applicationCacheService = applicationCacheService; - _paymentService = paymentService; - _policyRepository = policyRepository; - _ssoConfigRepository = ssoConfigRepository; - _ssoUserRepository = ssoUserRepository; - _referenceEventService = referenceEventService; - _globalSettings = globalSettings; - _taxRateRepository = taxRateRepository; - _organizationApiKeyRepository = organizationApiKeyRepository; - _organizationConnectionRepository = organizationConnectionRepository; - _currentContext = currentContext; - _logger = logger; - } + private readonly IOrganizationRepository _organizationRepository; + private readonly IOrganizationUserRepository _organizationUserRepository; + private readonly ICollectionRepository _collectionRepository; + private readonly IUserRepository _userRepository; + private readonly IGroupRepository _groupRepository; + private readonly IDataProtector _dataProtector; + private readonly IMailService _mailService; + private readonly IPushNotificationService _pushNotificationService; + private readonly IPushRegistrationService _pushRegistrationService; + private readonly IDeviceRepository _deviceRepository; + private readonly ILicensingService _licensingService; + private readonly IEventService _eventService; + private readonly IInstallationRepository _installationRepository; + private readonly IApplicationCacheService _applicationCacheService; + private readonly IPaymentService _paymentService; + private readonly IPolicyRepository _policyRepository; + private readonly ISsoConfigRepository _ssoConfigRepository; + private readonly ISsoUserRepository _ssoUserRepository; + private readonly IReferenceEventService _referenceEventService; + private readonly IGlobalSettings _globalSettings; + private readonly ITaxRateRepository _taxRateRepository; + private readonly IOrganizationApiKeyRepository _organizationApiKeyRepository; + private readonly IOrganizationConnectionRepository _organizationConnectionRepository; + private readonly ICurrentContext _currentContext; + private readonly ILogger _logger; - public async Task ReplacePaymentMethodAsync(Guid organizationId, string paymentToken, - PaymentMethodType paymentMethodType, TaxInfo taxInfo) - { - var organization = await GetOrgById(organizationId); - if (organization == null) + + public OrganizationService( + IOrganizationRepository organizationRepository, + IOrganizationUserRepository organizationUserRepository, + ICollectionRepository collectionRepository, + IUserRepository userRepository, + IGroupRepository groupRepository, + IDataProtectionProvider dataProtectionProvider, + IMailService mailService, + IPushNotificationService pushNotificationService, + IPushRegistrationService pushRegistrationService, + IDeviceRepository deviceRepository, + ILicensingService licensingService, + IEventService eventService, + IInstallationRepository installationRepository, + IApplicationCacheService applicationCacheService, + IPaymentService paymentService, + IPolicyRepository policyRepository, + ISsoConfigRepository ssoConfigRepository, + ISsoUserRepository ssoUserRepository, + IReferenceEventService referenceEventService, + IGlobalSettings globalSettings, + ITaxRateRepository taxRateRepository, + IOrganizationApiKeyRepository organizationApiKeyRepository, + IOrganizationConnectionRepository organizationConnectionRepository, + ICurrentContext currentContext, + ILogger logger) { - throw new NotFoundException(); + _organizationRepository = organizationRepository; + _organizationUserRepository = organizationUserRepository; + _collectionRepository = collectionRepository; + _userRepository = userRepository; + _groupRepository = groupRepository; + _dataProtector = dataProtectionProvider.CreateProtector("OrganizationServiceDataProtector"); + _mailService = mailService; + _pushNotificationService = pushNotificationService; + _pushRegistrationService = pushRegistrationService; + _deviceRepository = deviceRepository; + _licensingService = licensingService; + _eventService = eventService; + _installationRepository = installationRepository; + _applicationCacheService = applicationCacheService; + _paymentService = paymentService; + _policyRepository = policyRepository; + _ssoConfigRepository = ssoConfigRepository; + _ssoUserRepository = ssoUserRepository; + _referenceEventService = referenceEventService; + _globalSettings = globalSettings; + _taxRateRepository = taxRateRepository; + _organizationApiKeyRepository = organizationApiKeyRepository; + _organizationConnectionRepository = organizationConnectionRepository; + _currentContext = currentContext; + _logger = logger; } - await _paymentService.SaveTaxInfoAsync(organization, taxInfo); - var updated = await _paymentService.UpdatePaymentMethodAsync(organization, - paymentMethodType, paymentToken); - if (updated) + public async Task ReplacePaymentMethodAsync(Guid organizationId, string paymentToken, + PaymentMethodType paymentMethodType, TaxInfo taxInfo) { - await ReplaceAndUpdateCache(organization); - } - } - - public async Task CancelSubscriptionAsync(Guid organizationId, bool? endOfPeriod = null) - { - var organization = await GetOrgById(organizationId); - if (organization == null) - { - throw new NotFoundException(); - } - - var eop = endOfPeriod.GetValueOrDefault(true); - if (!endOfPeriod.HasValue && organization.ExpirationDate.HasValue && - organization.ExpirationDate.Value < DateTime.UtcNow) - { - eop = false; - } - - await _paymentService.CancelSubscriptionAsync(organization, eop); - await _referenceEventService.RaiseEventAsync( - new ReferenceEvent(ReferenceEventType.CancelSubscription, organization) + var organization = await GetOrgById(organizationId); + if (organization == null) { - EndOfPeriod = endOfPeriod, - }); - } + throw new NotFoundException(); + } - public async Task ReinstateSubscriptionAsync(Guid organizationId) - { - var organization = await GetOrgById(organizationId); - if (organization == null) - { - throw new NotFoundException(); - } - - await _paymentService.ReinstateSubscriptionAsync(organization); - await _referenceEventService.RaiseEventAsync( - new ReferenceEvent(ReferenceEventType.ReinstateSubscription, organization)); - } - - public async Task> UpgradePlanAsync(Guid organizationId, OrganizationUpgrade upgrade) - { - var organization = await GetOrgById(organizationId); - if (organization == null) - { - throw new NotFoundException(); - } - - if (string.IsNullOrWhiteSpace(organization.GatewayCustomerId)) - { - throw new BadRequestException("Your account has no payment method available."); - } - - var existingPlan = StaticStore.Plans.FirstOrDefault(p => p.Type == organization.PlanType); - if (existingPlan == null) - { - throw new BadRequestException("Existing plan not found."); - } - - var newPlan = StaticStore.Plans.FirstOrDefault(p => p.Type == upgrade.Plan && !p.Disabled); - if (newPlan == null) - { - throw new BadRequestException("Plan not found."); - } - - if (existingPlan.Type == newPlan.Type) - { - throw new BadRequestException("Organization is already on this plan."); - } - - if (existingPlan.UpgradeSortOrder >= newPlan.UpgradeSortOrder) - { - throw new BadRequestException("You cannot upgrade to this plan."); - } - - if (existingPlan.Type != PlanType.Free) - { - throw new BadRequestException("You can only upgrade from the free plan. Contact support."); - } - - ValidateOrganizationUpgradeParameters(newPlan, upgrade); - - var newPlanSeats = (short)(newPlan.BaseSeats + - (newPlan.HasAdditionalSeatsOption ? upgrade.AdditionalSeats : 0)); - if (!organization.Seats.HasValue || organization.Seats.Value > newPlanSeats) - { - var userCount = await _organizationUserRepository.GetCountByOrganizationIdAsync(organization.Id); - if (userCount > newPlanSeats) + await _paymentService.SaveTaxInfoAsync(organization, taxInfo); + var updated = await _paymentService.UpdatePaymentMethodAsync(organization, + paymentMethodType, paymentToken); + if (updated) { - throw new BadRequestException($"Your organization currently has {userCount} seats filled. " + - $"Your new plan only has ({newPlanSeats}) seats. Remove some users."); + await ReplaceAndUpdateCache(organization); } } - if (newPlan.MaxCollections.HasValue && (!organization.MaxCollections.HasValue || - organization.MaxCollections.Value > newPlan.MaxCollections.Value)) + public async Task CancelSubscriptionAsync(Guid organizationId, bool? endOfPeriod = null) { - var collectionCount = await _collectionRepository.GetCountByOrganizationIdAsync(organization.Id); - if (collectionCount > newPlan.MaxCollections.Value) + var organization = await GetOrgById(organizationId); + if (organization == null) { - throw new BadRequestException($"Your organization currently has {collectionCount} collections. " + - $"Your new plan allows for a maximum of ({newPlan.MaxCollections.Value}) collections. " + - "Remove some collections."); + throw new NotFoundException(); } - } - if (!newPlan.HasGroups && organization.UseGroups) - { - var groups = await _groupRepository.GetManyByOrganizationIdAsync(organization.Id); - if (groups.Any()) + var eop = endOfPeriod.GetValueOrDefault(true); + if (!endOfPeriod.HasValue && organization.ExpirationDate.HasValue && + organization.ExpirationDate.Value < DateTime.UtcNow) { - throw new BadRequestException($"Your new plan does not allow the groups feature. " + - $"Remove your groups."); + eop = false; } - } - if (!newPlan.HasPolicies && organization.UsePolicies) - { - var policies = await _policyRepository.GetManyByOrganizationIdAsync(organization.Id); - if (policies.Any(p => p.Enabled)) - { - throw new BadRequestException($"Your new plan does not allow the policies feature. " + - $"Disable your policies."); - } - } - - if (!newPlan.HasSso && organization.UseSso) - { - var ssoConfig = await _ssoConfigRepository.GetByOrganizationIdAsync(organization.Id); - if (ssoConfig != null && ssoConfig.Enabled) - { - throw new BadRequestException($"Your new plan does not allow the SSO feature. " + - $"Disable your SSO configuration."); - } - } - - if (!newPlan.HasKeyConnector && organization.UseKeyConnector) - { - var ssoConfig = await _ssoConfigRepository.GetByOrganizationIdAsync(organization.Id); - if (ssoConfig != null && ssoConfig.GetData().KeyConnectorEnabled) - { - throw new BadRequestException("Your new plan does not allow the Key Connector feature. " + - "Disable your Key Connector."); - } - } - - if (!newPlan.HasResetPassword && organization.UseResetPassword) - { - var resetPasswordPolicy = - await _policyRepository.GetByOrganizationIdTypeAsync(organization.Id, PolicyType.ResetPassword); - if (resetPasswordPolicy != null && resetPasswordPolicy.Enabled) - { - throw new BadRequestException("Your new plan does not allow the Password Reset feature. " + - "Disable your Password Reset policy."); - } - } - - if (!newPlan.HasScim && organization.UseScim) - { - var scimConnections = await _organizationConnectionRepository.GetByOrganizationIdTypeAsync(organization.Id, - OrganizationConnectionType.Scim); - if (scimConnections != null && scimConnections.Any(c => c.GetConfig()?.Enabled == true)) - { - throw new BadRequestException("Your new plan does not allow the SCIM feature. " + - "Disable your SCIM configuration."); - } - } - - // TODO: Check storage? - - string paymentIntentClientSecret = null; - var success = true; - if (string.IsNullOrWhiteSpace(organization.GatewaySubscriptionId)) - { - paymentIntentClientSecret = await _paymentService.UpgradeFreeOrganizationAsync(organization, newPlan, - upgrade.AdditionalStorageGb, upgrade.AdditionalSeats, upgrade.PremiumAccessAddon, upgrade.TaxInfo); - success = string.IsNullOrWhiteSpace(paymentIntentClientSecret); - } - else - { - // TODO: Update existing sub - throw new BadRequestException("You can only upgrade from the free plan. Contact support."); - } - - organization.BusinessName = upgrade.BusinessName; - organization.PlanType = newPlan.Type; - organization.Seats = (short)(newPlan.BaseSeats + upgrade.AdditionalSeats); - organization.MaxCollections = newPlan.MaxCollections; - organization.UseGroups = newPlan.HasGroups; - organization.UseDirectory = newPlan.HasDirectory; - organization.UseEvents = newPlan.HasEvents; - organization.UseTotp = newPlan.HasTotp; - organization.Use2fa = newPlan.Has2fa; - organization.UseApi = newPlan.HasApi; - organization.SelfHost = newPlan.HasSelfHost; - organization.UsePolicies = newPlan.HasPolicies; - organization.MaxStorageGb = !newPlan.BaseStorageGb.HasValue ? - (short?)null : (short)(newPlan.BaseStorageGb.Value + upgrade.AdditionalStorageGb); - organization.UseGroups = newPlan.HasGroups; - organization.UseDirectory = newPlan.HasDirectory; - organization.UseEvents = newPlan.HasEvents; - organization.UseTotp = newPlan.HasTotp; - organization.Use2fa = newPlan.Has2fa; - organization.UseApi = newPlan.HasApi; - organization.UseSso = newPlan.HasSso; - organization.UseKeyConnector = newPlan.HasKeyConnector; - organization.UseScim = newPlan.HasScim; - organization.UseResetPassword = newPlan.HasResetPassword; - organization.SelfHost = newPlan.HasSelfHost; - organization.UsersGetPremium = newPlan.UsersGetPremium || upgrade.PremiumAccessAddon; - organization.Plan = newPlan.Name; - organization.Enabled = success; - organization.PublicKey = upgrade.PublicKey; - organization.PrivateKey = upgrade.PrivateKey; - await ReplaceAndUpdateCache(organization); - if (success) - { + await _paymentService.CancelSubscriptionAsync(organization, eop); await _referenceEventService.RaiseEventAsync( - new ReferenceEvent(ReferenceEventType.UpgradePlan, organization) + new ReferenceEvent(ReferenceEventType.CancelSubscription, organization) { - PlanName = newPlan.Name, - PlanType = newPlan.Type, - OldPlanName = existingPlan.Name, - OldPlanType = existingPlan.Type, - Seats = organization.Seats, - Storage = organization.MaxStorageGb, + EndOfPeriod = endOfPeriod, }); } - return new Tuple(success, paymentIntentClientSecret); - } - - public async Task AdjustStorageAsync(Guid organizationId, short storageAdjustmentGb) - { - var organization = await GetOrgById(organizationId); - if (organization == null) + public async Task ReinstateSubscriptionAsync(Guid organizationId) { - throw new NotFoundException(); - } - - var plan = StaticStore.Plans.FirstOrDefault(p => p.Type == organization.PlanType); - if (plan == null) - { - throw new BadRequestException("Existing plan not found."); - } - - if (!plan.HasAdditionalStorageOption) - { - throw new BadRequestException("Plan does not allow additional storage."); - } - - var secret = await BillingHelpers.AdjustStorageAsync(_paymentService, organization, storageAdjustmentGb, - plan.StripeStoragePlanId); - await _referenceEventService.RaiseEventAsync( - new ReferenceEvent(ReferenceEventType.AdjustStorage, organization) + var organization = await GetOrgById(organizationId); + if (organization == null) { - PlanName = plan.Name, - PlanType = plan.Type, - Storage = storageAdjustmentGb, - }); - await ReplaceAndUpdateCache(organization); - return secret; - } - - public async Task UpdateSubscription(Guid organizationId, int seatAdjustment, int? maxAutoscaleSeats) - { - var organization = await GetOrgById(organizationId); - if (organization == null) - { - throw new NotFoundException(); - } - - var newSeatCount = organization.Seats + seatAdjustment; - if (maxAutoscaleSeats.HasValue && newSeatCount > maxAutoscaleSeats.Value) - { - throw new BadRequestException("Cannot set max seat autoscaling below seat count."); - } - - if (seatAdjustment != 0) - { - await AdjustSeatsAsync(organization, seatAdjustment); - } - if (maxAutoscaleSeats != organization.MaxAutoscaleSeats) - { - await UpdateAutoscalingAsync(organization, maxAutoscaleSeats); - } - } - - private async Task UpdateAutoscalingAsync(Organization organization, int? maxAutoscaleSeats) - { - - if (maxAutoscaleSeats.HasValue && - organization.Seats.HasValue && - maxAutoscaleSeats.Value < organization.Seats.Value) - { - throw new BadRequestException($"Cannot set max seat autoscaling below current seat count."); - } - - var plan = StaticStore.Plans.FirstOrDefault(p => p.Type == organization.PlanType); - if (plan == null) - { - throw new BadRequestException("Existing plan not found."); - } - - if (!plan.AllowSeatAutoscale) - { - throw new BadRequestException("Your plan does not allow seat autoscaling."); - } - - if (plan.MaxUsers.HasValue && maxAutoscaleSeats.HasValue && - maxAutoscaleSeats > plan.MaxUsers) - { - throw new BadRequestException(string.Concat($"Your plan has a seat limit of {plan.MaxUsers}, ", - $"but you have specified a max autoscale count of {maxAutoscaleSeats}.", - "Reduce your max autoscale seat count.")); - } - - organization.MaxAutoscaleSeats = maxAutoscaleSeats; - - await ReplaceAndUpdateCache(organization); - } - - public async Task AdjustSeatsAsync(Guid organizationId, int seatAdjustment, DateTime? prorationDate = null) - { - var organization = await GetOrgById(organizationId); - if (organization == null) - { - throw new NotFoundException(); - } - - return await AdjustSeatsAsync(organization, seatAdjustment, prorationDate); - } - - private async Task AdjustSeatsAsync(Organization organization, int seatAdjustment, DateTime? prorationDate = null, IEnumerable ownerEmails = null) - { - if (organization.Seats == null) - { - throw new BadRequestException("Organization has no seat limit, no need to adjust seats"); - } - - if (string.IsNullOrWhiteSpace(organization.GatewayCustomerId)) - { - throw new BadRequestException("No payment method found."); - } - - if (string.IsNullOrWhiteSpace(organization.GatewaySubscriptionId)) - { - throw new BadRequestException("No subscription found."); - } - - var plan = StaticStore.Plans.FirstOrDefault(p => p.Type == organization.PlanType); - if (plan == null) - { - throw new BadRequestException("Existing plan not found."); - } - - if (!plan.HasAdditionalSeatsOption) - { - throw new BadRequestException("Plan does not allow additional seats."); - } - - var newSeatTotal = organization.Seats.Value + seatAdjustment; - if (plan.BaseSeats > newSeatTotal) - { - throw new BadRequestException($"Plan has a minimum of {plan.BaseSeats} seats."); - } - - if (newSeatTotal <= 0) - { - throw new BadRequestException("You must have at least 1 seat."); - } - - var additionalSeats = newSeatTotal - plan.BaseSeats; - if (plan.MaxAdditionalSeats.HasValue && additionalSeats > plan.MaxAdditionalSeats.Value) - { - throw new BadRequestException($"Organization plan allows a maximum of " + - $"{plan.MaxAdditionalSeats.Value} additional seats."); - } - - if (!organization.Seats.HasValue || organization.Seats.Value > newSeatTotal) - { - var userCount = await _organizationUserRepository.GetCountByOrganizationIdAsync(organization.Id); - if (userCount > newSeatTotal) - { - throw new BadRequestException($"Your organization currently has {userCount} seats filled. " + - $"Your new plan only has ({newSeatTotal}) seats. Remove some users."); + throw new NotFoundException(); } + + await _paymentService.ReinstateSubscriptionAsync(organization); + await _referenceEventService.RaiseEventAsync( + new ReferenceEvent(ReferenceEventType.ReinstateSubscription, organization)); } - var paymentIntentClientSecret = await _paymentService.AdjustSeatsAsync(organization, plan, additionalSeats, prorationDate); - await _referenceEventService.RaiseEventAsync( - new ReferenceEvent(ReferenceEventType.AdjustSeats, organization) - { - PlanName = plan.Name, - PlanType = plan.Type, - Seats = newSeatTotal, - PreviousSeats = organization.Seats - }); - organization.Seats = (short?)newSeatTotal; - await ReplaceAndUpdateCache(organization); - - if (organization.Seats.HasValue && organization.MaxAutoscaleSeats.HasValue && organization.Seats == organization.MaxAutoscaleSeats) + public async Task> UpgradePlanAsync(Guid organizationId, OrganizationUpgrade upgrade) { - try + var organization = await GetOrgById(organizationId); + if (organization == null) { - if (ownerEmails == null) + throw new NotFoundException(); + } + + if (string.IsNullOrWhiteSpace(organization.GatewayCustomerId)) + { + throw new BadRequestException("Your account has no payment method available."); + } + + var existingPlan = StaticStore.Plans.FirstOrDefault(p => p.Type == organization.PlanType); + if (existingPlan == null) + { + throw new BadRequestException("Existing plan not found."); + } + + var newPlan = StaticStore.Plans.FirstOrDefault(p => p.Type == upgrade.Plan && !p.Disabled); + if (newPlan == null) + { + throw new BadRequestException("Plan not found."); + } + + if (existingPlan.Type == newPlan.Type) + { + throw new BadRequestException("Organization is already on this plan."); + } + + if (existingPlan.UpgradeSortOrder >= newPlan.UpgradeSortOrder) + { + throw new BadRequestException("You cannot upgrade to this plan."); + } + + if (existingPlan.Type != PlanType.Free) + { + throw new BadRequestException("You can only upgrade from the free plan. Contact support."); + } + + ValidateOrganizationUpgradeParameters(newPlan, upgrade); + + var newPlanSeats = (short)(newPlan.BaseSeats + + (newPlan.HasAdditionalSeatsOption ? upgrade.AdditionalSeats : 0)); + if (!organization.Seats.HasValue || organization.Seats.Value > newPlanSeats) + { + var userCount = await _organizationUserRepository.GetCountByOrganizationIdAsync(organization.Id); + if (userCount > newPlanSeats) { - ownerEmails = (await _organizationUserRepository.GetManyByMinimumRoleAsync(organization.Id, - OrganizationUserType.Owner)).Select(u => u.Email).Distinct(); + throw new BadRequestException($"Your organization currently has {userCount} seats filled. " + + $"Your new plan only has ({newPlanSeats}) seats. Remove some users."); } - await _mailService.SendOrganizationMaxSeatLimitReachedEmailAsync(organization, organization.MaxAutoscaleSeats.Value, ownerEmails); } - catch (Exception e) + + if (newPlan.MaxCollections.HasValue && (!organization.MaxCollections.HasValue || + organization.MaxCollections.Value > newPlan.MaxCollections.Value)) { - _logger.LogError(e, "Error encountered notifying organization owners of seat limit reached."); - } - } - - return paymentIntentClientSecret; - } - - public async Task VerifyBankAsync(Guid organizationId, int amount1, int amount2) - { - var organization = await GetOrgById(organizationId); - if (organization == null) - { - throw new NotFoundException(); - } - - if (string.IsNullOrWhiteSpace(organization.GatewayCustomerId)) - { - throw new GatewayException("Not a gateway customer."); - } - - var bankService = new BankAccountService(); - var customerService = new CustomerService(); - var customer = await customerService.GetAsync(organization.GatewayCustomerId, - new CustomerGetOptions { Expand = new List { "sources" } }); - if (customer == null) - { - throw new GatewayException("Cannot find customer."); - } - - var bankAccount = customer.Sources - .FirstOrDefault(s => s is BankAccount && ((BankAccount)s).Status != "verified") as BankAccount; - if (bankAccount == null) - { - throw new GatewayException("Cannot find an unverified bank account."); - } - - try - { - var result = await bankService.VerifyAsync(organization.GatewayCustomerId, bankAccount.Id, - new BankAccountVerifyOptions { Amounts = new List { amount1, amount2 } }); - if (result.Status != "verified") - { - throw new GatewayException("Unable to verify account."); - } - } - catch (StripeException e) - { - throw new GatewayException(e.Message); - } - } - - public async Task> SignUpAsync(OrganizationSignup signup, - bool provider = false) - { - var plan = StaticStore.Plans.FirstOrDefault(p => p.Type == signup.Plan); - if (!(plan is { LegacyYear: null })) - { - throw new BadRequestException("Invalid plan selected."); - } - - if (plan.Disabled) - { - throw new BadRequestException("Plan not found."); - } - - if (!provider) - { - await ValidateSignUpPoliciesAsync(signup.Owner.Id); - } - - ValidateOrganizationUpgradeParameters(plan, signup); - - var organization = new Organization - { - // Pre-generate the org id so that we can save it with the Stripe subscription.. - Id = CoreHelpers.GenerateComb(), - Name = signup.Name, - BillingEmail = signup.BillingEmail, - BusinessName = signup.BusinessName, - PlanType = plan.Type, - Seats = (short)(plan.BaseSeats + signup.AdditionalSeats), - MaxCollections = plan.MaxCollections, - MaxStorageGb = !plan.BaseStorageGb.HasValue ? - (short?)null : (short)(plan.BaseStorageGb.Value + signup.AdditionalStorageGb), - UsePolicies = plan.HasPolicies, - UseSso = plan.HasSso, - UseGroups = plan.HasGroups, - UseEvents = plan.HasEvents, - UseDirectory = plan.HasDirectory, - UseTotp = plan.HasTotp, - Use2fa = plan.Has2fa, - UseApi = plan.HasApi, - UseResetPassword = plan.HasResetPassword, - SelfHost = plan.HasSelfHost, - UsersGetPremium = plan.UsersGetPremium || signup.PremiumAccessAddon, - UseScim = plan.HasScim, - Plan = plan.Name, - Gateway = null, - ReferenceData = signup.Owner.ReferenceData, - Enabled = true, - LicenseKey = CoreHelpers.SecureRandomString(20), - PublicKey = signup.PublicKey, - PrivateKey = signup.PrivateKey, - CreationDate = DateTime.UtcNow, - RevisionDate = DateTime.UtcNow, - }; - - if (plan.Type == PlanType.Free && !provider) - { - var adminCount = - await _organizationUserRepository.GetCountByFreeOrganizationAdminUserAsync(signup.Owner.Id); - if (adminCount > 0) - { - throw new BadRequestException("You can only be an admin of one free organization."); - } - } - else if (plan.Type != PlanType.Free) - { - await _paymentService.PurchaseOrganizationAsync(organization, signup.PaymentMethodType.Value, - signup.PaymentToken, plan, signup.AdditionalStorageGb, signup.AdditionalSeats, - signup.PremiumAccessAddon, signup.TaxInfo); - } - - var ownerId = provider ? default : signup.Owner.Id; - var returnValue = await SignUpAsync(organization, ownerId, signup.OwnerKey, signup.CollectionName, true); - await _referenceEventService.RaiseEventAsync( - new ReferenceEvent(ReferenceEventType.Signup, organization) - { - PlanName = plan.Name, - PlanType = plan.Type, - Seats = returnValue.Item1.Seats, - Storage = returnValue.Item1.MaxStorageGb, - }); - return returnValue; - } - - private async Task ValidateSignUpPoliciesAsync(Guid ownerId) - { - var singleOrgPolicyCount = await _policyRepository.GetCountByTypeApplicableToUserIdAsync(ownerId, PolicyType.SingleOrg); - if (singleOrgPolicyCount > 0) - { - throw new BadRequestException("You may not create an organization. You belong to an organization " + - "which has a policy that prohibits you from being a member of any other organization."); - } - } - - public async Task> SignUpAsync( - OrganizationLicense license, User owner, string ownerKey, string collectionName, string publicKey, - string privateKey) - { - if (license?.LicenseType != null && license.LicenseType != LicenseType.Organization) - { - throw new BadRequestException("Premium licenses cannot be applied to an organization. " - + "Upload this license from your personal account settings page."); - } - - if (license == null || !_licensingService.VerifyLicense(license)) - { - throw new BadRequestException("Invalid license."); - } - - if (!license.CanUse(_globalSettings)) - { - throw new BadRequestException("Invalid license. Make sure your license allows for on-premise " + - "hosting of organizations and that the installation id matches your current installation."); - } - - if (license.PlanType != PlanType.Custom && - StaticStore.Plans.FirstOrDefault(p => p.Type == license.PlanType && !p.Disabled) == null) - { - throw new BadRequestException("Plan not found."); - } - - var enabledOrgs = await _organizationRepository.GetManyByEnabledAsync(); - if (enabledOrgs.Any(o => o.LicenseKey.Equals(license.LicenseKey))) - { - throw new BadRequestException("License is already in use by another organization."); - } - - await ValidateSignUpPoliciesAsync(owner.Id); - - var organization = new Organization - { - Name = license.Name, - BillingEmail = license.BillingEmail, - BusinessName = license.BusinessName, - PlanType = license.PlanType, - Seats = license.Seats, - MaxCollections = license.MaxCollections, - MaxStorageGb = _globalSettings.SelfHosted ? 10240 : license.MaxStorageGb, // 10 TB - UsePolicies = license.UsePolicies, - UseSso = license.UseSso, - UseKeyConnector = license.UseKeyConnector, - UseScim = license.UseScim, - UseGroups = license.UseGroups, - UseDirectory = license.UseDirectory, - UseEvents = license.UseEvents, - UseTotp = license.UseTotp, - Use2fa = license.Use2fa, - UseApi = license.UseApi, - UseResetPassword = license.UseResetPassword, - Plan = license.Plan, - SelfHost = license.SelfHost, - UsersGetPremium = license.UsersGetPremium, - Gateway = null, - GatewayCustomerId = null, - GatewaySubscriptionId = null, - ReferenceData = owner.ReferenceData, - Enabled = license.Enabled, - ExpirationDate = license.Expires, - LicenseKey = license.LicenseKey, - PublicKey = publicKey, - PrivateKey = privateKey, - CreationDate = DateTime.UtcNow, - RevisionDate = DateTime.UtcNow - }; - - var result = await SignUpAsync(organization, owner.Id, ownerKey, collectionName, false); - - var dir = $"{_globalSettings.LicenseDirectory}/organization"; - Directory.CreateDirectory(dir); - await using var fs = new FileStream(Path.Combine(dir, $"{organization.Id}.json"), FileMode.Create); - await JsonSerializer.SerializeAsync(fs, license, JsonHelpers.Indented); - return result; - } - - private async Task> SignUpAsync(Organization organization, - Guid ownerId, string ownerKey, string collectionName, bool withPayment) - { - try - { - await _organizationRepository.CreateAsync(organization); - await _organizationApiKeyRepository.CreateAsync(new OrganizationApiKey - { - OrganizationId = organization.Id, - ApiKey = CoreHelpers.SecureRandomString(30), - Type = OrganizationApiKeyType.Default, - RevisionDate = DateTime.UtcNow, - }); - await _applicationCacheService.UpsertOrganizationAbilityAsync(organization); - - if (!string.IsNullOrWhiteSpace(collectionName)) - { - var defaultCollection = new Collection + var collectionCount = await _collectionRepository.GetCountByOrganizationIdAsync(organization.Id); + if (collectionCount > newPlan.MaxCollections.Value) { - Name = collectionName, - OrganizationId = organization.Id, - CreationDate = organization.CreationDate, - RevisionDate = organization.CreationDate - }; - await _collectionRepository.CreateAsync(defaultCollection); + throw new BadRequestException($"Your organization currently has {collectionCount} collections. " + + $"Your new plan allows for a maximum of ({newPlan.MaxCollections.Value}) collections. " + + "Remove some collections."); + } } - OrganizationUser orgUser = null; - if (ownerId != default) + if (!newPlan.HasGroups && organization.UseGroups) { - orgUser = new OrganizationUser + var groups = await _groupRepository.GetManyByOrganizationIdAsync(organization.Id); + if (groups.Any()) { - OrganizationId = organization.Id, - UserId = ownerId, - Key = ownerKey, - Type = OrganizationUserType.Owner, - Status = OrganizationUserStatusType.Confirmed, - AccessAll = true, - CreationDate = organization.CreationDate, - RevisionDate = organization.CreationDate - }; - - await _organizationUserRepository.CreateAsync(orgUser); - - var deviceIds = await GetUserDeviceIdsAsync(orgUser.UserId.Value); - await _pushRegistrationService.AddUserRegistrationOrganizationAsync(deviceIds, - organization.Id.ToString()); - await _pushNotificationService.PushSyncOrgKeysAsync(ownerId); + throw new BadRequestException($"Your new plan does not allow the groups feature. " + + $"Remove your groups."); + } } - return new Tuple(organization, orgUser); - } - catch - { - if (withPayment) + if (!newPlan.HasPolicies && organization.UsePolicies) { - await _paymentService.CancelAndRecoverChargesAsync(organization); + var policies = await _policyRepository.GetManyByOrganizationIdAsync(organization.Id); + if (policies.Any(p => p.Enabled)) + { + throw new BadRequestException($"Your new plan does not allow the policies feature. " + + $"Disable your policies."); + } } - if (organization.Id != default(Guid)) + if (!newPlan.HasSso && organization.UseSso) { - await _organizationRepository.DeleteAsync(organization); - await _applicationCacheService.DeleteOrganizationAbilityAsync(organization.Id); + var ssoConfig = await _ssoConfigRepository.GetByOrganizationIdAsync(organization.Id); + if (ssoConfig != null && ssoConfig.Enabled) + { + throw new BadRequestException($"Your new plan does not allow the SSO feature. " + + $"Disable your SSO configuration."); + } } - throw; - } - } - - public async Task UpdateLicenseAsync(Guid organizationId, OrganizationLicense license) - { - var organization = await GetOrgById(organizationId); - if (organization == null) - { - throw new NotFoundException(); - } - - if (!_globalSettings.SelfHosted) - { - throw new InvalidOperationException("Licenses require self hosting."); - } - - if (license?.LicenseType != null && license.LicenseType != LicenseType.Organization) - { - throw new BadRequestException("Premium licenses cannot be applied to an organization. " - + "Upload this license from your personal account settings page."); - } - - if (license == null || !_licensingService.VerifyLicense(license)) - { - throw new BadRequestException("Invalid license."); - } - - if (!license.CanUse(_globalSettings)) - { - throw new BadRequestException("Invalid license. Make sure your license allows for on-premise " + - "hosting of organizations and that the installation id matches your current installation."); - } - - var enabledOrgs = await _organizationRepository.GetManyByEnabledAsync(); - if (enabledOrgs.Any(o => o.LicenseKey.Equals(license.LicenseKey) && o.Id != organizationId)) - { - throw new BadRequestException("License is already in use by another organization."); - } - - if (license.Seats.HasValue && - (!organization.Seats.HasValue || organization.Seats.Value > license.Seats.Value)) - { - var userCount = await _organizationUserRepository.GetCountByOrganizationIdAsync(organization.Id); - if (userCount > license.Seats.Value) + if (!newPlan.HasKeyConnector && organization.UseKeyConnector) { - throw new BadRequestException($"Your organization currently has {userCount} seats filled. " + - $"Your new license only has ({license.Seats.Value}) seats. Remove some users."); + var ssoConfig = await _ssoConfigRepository.GetByOrganizationIdAsync(organization.Id); + if (ssoConfig != null && ssoConfig.GetData().KeyConnectorEnabled) + { + throw new BadRequestException("Your new plan does not allow the Key Connector feature. " + + "Disable your Key Connector."); + } } - } - if (license.MaxCollections.HasValue && (!organization.MaxCollections.HasValue || - organization.MaxCollections.Value > license.MaxCollections.Value)) - { - var collectionCount = await _collectionRepository.GetCountByOrganizationIdAsync(organization.Id); - if (collectionCount > license.MaxCollections.Value) + if (!newPlan.HasResetPassword && organization.UseResetPassword) { - throw new BadRequestException($"Your organization currently has {collectionCount} collections. " + - $"Your new license allows for a maximum of ({license.MaxCollections.Value}) collections. " + - "Remove some collections."); + var resetPasswordPolicy = + await _policyRepository.GetByOrganizationIdTypeAsync(organization.Id, PolicyType.ResetPassword); + if (resetPasswordPolicy != null && resetPasswordPolicy.Enabled) + { + throw new BadRequestException("Your new plan does not allow the Password Reset feature. " + + "Disable your Password Reset policy."); + } } - } - if (!license.UseGroups && organization.UseGroups) - { - var groups = await _groupRepository.GetManyByOrganizationIdAsync(organization.Id); - if (groups.Count > 0) + if (!newPlan.HasScim && organization.UseScim) { - throw new BadRequestException($"Your organization currently has {groups.Count} groups. " + - $"Your new license does not allow for the use of groups. Remove all groups."); + var scimConnections = await _organizationConnectionRepository.GetByOrganizationIdTypeAsync(organization.Id, + OrganizationConnectionType.Scim); + if (scimConnections != null && scimConnections.Any(c => c.GetConfig()?.Enabled == true)) + { + throw new BadRequestException("Your new plan does not allow the SCIM feature. " + + "Disable your SCIM configuration."); + } } - } - if (!license.UsePolicies && organization.UsePolicies) - { - var policies = await _policyRepository.GetManyByOrganizationIdAsync(organization.Id); - if (policies.Any(p => p.Enabled)) + // TODO: Check storage? + + string paymentIntentClientSecret = null; + var success = true; + if (string.IsNullOrWhiteSpace(organization.GatewaySubscriptionId)) { - throw new BadRequestException($"Your organization currently has {policies.Count} enabled " + - $"policies. Your new license does not allow for the use of policies. Disable all policies."); + paymentIntentClientSecret = await _paymentService.UpgradeFreeOrganizationAsync(organization, newPlan, + upgrade.AdditionalStorageGb, upgrade.AdditionalSeats, upgrade.PremiumAccessAddon, upgrade.TaxInfo); + success = string.IsNullOrWhiteSpace(paymentIntentClientSecret); } - } - - if (!license.UseSso && organization.UseSso) - { - var ssoConfig = await _ssoConfigRepository.GetByOrganizationIdAsync(organization.Id); - if (ssoConfig != null && ssoConfig.Enabled) + else { - throw new BadRequestException($"Your organization currently has a SSO configuration. " + - $"Your new license does not allow for the use of SSO. Disable your SSO configuration."); + // TODO: Update existing sub + throw new BadRequestException("You can only upgrade from the free plan. Contact support."); } - } - if (!license.UseKeyConnector && organization.UseKeyConnector) - { - var ssoConfig = await _ssoConfigRepository.GetByOrganizationIdAsync(organization.Id); - if (ssoConfig != null && ssoConfig.GetData().KeyConnectorEnabled) + organization.BusinessName = upgrade.BusinessName; + organization.PlanType = newPlan.Type; + organization.Seats = (short)(newPlan.BaseSeats + upgrade.AdditionalSeats); + organization.MaxCollections = newPlan.MaxCollections; + organization.UseGroups = newPlan.HasGroups; + organization.UseDirectory = newPlan.HasDirectory; + organization.UseEvents = newPlan.HasEvents; + organization.UseTotp = newPlan.HasTotp; + organization.Use2fa = newPlan.Has2fa; + organization.UseApi = newPlan.HasApi; + organization.SelfHost = newPlan.HasSelfHost; + organization.UsePolicies = newPlan.HasPolicies; + organization.MaxStorageGb = !newPlan.BaseStorageGb.HasValue ? + (short?)null : (short)(newPlan.BaseStorageGb.Value + upgrade.AdditionalStorageGb); + organization.UseGroups = newPlan.HasGroups; + organization.UseDirectory = newPlan.HasDirectory; + organization.UseEvents = newPlan.HasEvents; + organization.UseTotp = newPlan.HasTotp; + organization.Use2fa = newPlan.Has2fa; + organization.UseApi = newPlan.HasApi; + organization.UseSso = newPlan.HasSso; + organization.UseKeyConnector = newPlan.HasKeyConnector; + organization.UseScim = newPlan.HasScim; + organization.UseResetPassword = newPlan.HasResetPassword; + organization.SelfHost = newPlan.HasSelfHost; + organization.UsersGetPremium = newPlan.UsersGetPremium || upgrade.PremiumAccessAddon; + organization.Plan = newPlan.Name; + organization.Enabled = success; + organization.PublicKey = upgrade.PublicKey; + organization.PrivateKey = upgrade.PrivateKey; + await ReplaceAndUpdateCache(organization); + if (success) { - throw new BadRequestException($"Your organization currently has Key Connector enabled. " + - $"Your new license does not allow for the use of Key Connector. Disable your Key Connector."); - } - } - - if (!license.UseScim && organization.UseScim) - { - var scimConnections = await _organizationConnectionRepository.GetByOrganizationIdTypeAsync(organization.Id, - OrganizationConnectionType.Scim); - if (scimConnections != null && scimConnections.Any(c => c.GetConfig()?.Enabled == true)) - { - throw new BadRequestException("Your new plan does not allow the SCIM feature. " + - "Disable your SCIM configuration."); - } - } - - if (!license.UseResetPassword && organization.UseResetPassword) - { - var resetPasswordPolicy = - await _policyRepository.GetByOrganizationIdTypeAsync(organization.Id, PolicyType.ResetPassword); - if (resetPasswordPolicy != null && resetPasswordPolicy.Enabled) - { - throw new BadRequestException("Your new license does not allow the Password Reset feature. " - + "Disable your Password Reset policy."); - } - } - - var dir = $"{_globalSettings.LicenseDirectory}/organization"; - Directory.CreateDirectory(dir); - await using var fs = new FileStream(Path.Combine(dir, $"{organization.Id}.json"), FileMode.Create); - await JsonSerializer.SerializeAsync(fs, license, JsonHelpers.Indented); - - organization.Name = license.Name; - organization.BusinessName = license.BusinessName; - organization.BillingEmail = license.BillingEmail; - organization.PlanType = license.PlanType; - organization.Seats = license.Seats; - organization.MaxCollections = license.MaxCollections; - organization.UseGroups = license.UseGroups; - organization.UseDirectory = license.UseDirectory; - organization.UseEvents = license.UseEvents; - organization.UseTotp = license.UseTotp; - organization.Use2fa = license.Use2fa; - organization.UseApi = license.UseApi; - organization.UsePolicies = license.UsePolicies; - organization.UseSso = license.UseSso; - organization.UseKeyConnector = license.UseKeyConnector; - organization.UseScim = license.UseScim; - organization.UseResetPassword = license.UseResetPassword; - organization.SelfHost = license.SelfHost; - organization.UsersGetPremium = license.UsersGetPremium; - organization.Plan = license.Plan; - organization.Enabled = license.Enabled; - organization.ExpirationDate = license.Expires; - organization.LicenseKey = license.LicenseKey; - organization.RevisionDate = DateTime.UtcNow; - await ReplaceAndUpdateCache(organization); - } - - public async Task DeleteAsync(Organization organization) - { - await ValidateDeleteOrganizationAsync(organization); - - if (!string.IsNullOrWhiteSpace(organization.GatewaySubscriptionId)) - { - try - { - var eop = !organization.ExpirationDate.HasValue || - organization.ExpirationDate.Value >= DateTime.UtcNow; - await _paymentService.CancelSubscriptionAsync(organization, eop); await _referenceEventService.RaiseEventAsync( - new ReferenceEvent(ReferenceEventType.DeleteAccount, organization)); + new ReferenceEvent(ReferenceEventType.UpgradePlan, organization) + { + PlanName = newPlan.Name, + PlanType = newPlan.Type, + OldPlanName = existingPlan.Name, + OldPlanType = existingPlan.Type, + Seats = organization.Seats, + Storage = organization.MaxStorageGb, + }); } - catch (GatewayException) { } + + return new Tuple(success, paymentIntentClientSecret); } - await _organizationRepository.DeleteAsync(organization); - await _applicationCacheService.DeleteOrganizationAbilityAsync(organization.Id); - } - - public async Task EnableAsync(Guid organizationId, DateTime? expirationDate) - { - var org = await GetOrgById(organizationId); - if (org != null && !org.Enabled && org.Gateway.HasValue) + public async Task AdjustStorageAsync(Guid organizationId, short storageAdjustmentGb) { - org.Enabled = true; - org.ExpirationDate = expirationDate; - org.RevisionDate = DateTime.UtcNow; - await ReplaceAndUpdateCache(org); - } - } - - public async Task DisableAsync(Guid organizationId, DateTime? expirationDate) - { - var org = await GetOrgById(organizationId); - if (org != null && org.Enabled) - { - org.Enabled = false; - org.ExpirationDate = expirationDate; - org.RevisionDate = DateTime.UtcNow; - await ReplaceAndUpdateCache(org); - - // TODO: send email to owners? - } - } - - public async Task UpdateExpirationDateAsync(Guid organizationId, DateTime? expirationDate) - { - var org = await GetOrgById(organizationId); - if (org != null) - { - org.ExpirationDate = expirationDate; - org.RevisionDate = DateTime.UtcNow; - await ReplaceAndUpdateCache(org); - } - } - - public async Task EnableAsync(Guid organizationId) - { - var org = await GetOrgById(organizationId); - if (org != null && !org.Enabled) - { - org.Enabled = true; - await ReplaceAndUpdateCache(org); - } - } - - public async Task UpdateAsync(Organization organization, bool updateBilling = false) - { - if (organization.Id == default(Guid)) - { - throw new ApplicationException("Cannot create org this way. Call SignUpAsync."); - } - - if (!string.IsNullOrWhiteSpace(organization.Identifier)) - { - var orgById = await _organizationRepository.GetByIdentifierAsync(organization.Identifier); - if (orgById != null && orgById.Id != organization.Id) + var organization = await GetOrgById(organizationId); + if (organization == null) { - throw new BadRequestException("Identifier already in use by another organization."); + throw new NotFoundException(); + } + + var plan = StaticStore.Plans.FirstOrDefault(p => p.Type == organization.PlanType); + if (plan == null) + { + throw new BadRequestException("Existing plan not found."); + } + + if (!plan.HasAdditionalStorageOption) + { + throw new BadRequestException("Plan does not allow additional storage."); + } + + var secret = await BillingHelpers.AdjustStorageAsync(_paymentService, organization, storageAdjustmentGb, + plan.StripeStoragePlanId); + await _referenceEventService.RaiseEventAsync( + new ReferenceEvent(ReferenceEventType.AdjustStorage, organization) + { + PlanName = plan.Name, + PlanType = plan.Type, + Storage = storageAdjustmentGb, + }); + await ReplaceAndUpdateCache(organization); + return secret; + } + + public async Task UpdateSubscription(Guid organizationId, int seatAdjustment, int? maxAutoscaleSeats) + { + var organization = await GetOrgById(organizationId); + if (organization == null) + { + throw new NotFoundException(); + } + + var newSeatCount = organization.Seats + seatAdjustment; + if (maxAutoscaleSeats.HasValue && newSeatCount > maxAutoscaleSeats.Value) + { + throw new BadRequestException("Cannot set max seat autoscaling below seat count."); + } + + if (seatAdjustment != 0) + { + await AdjustSeatsAsync(organization, seatAdjustment); + } + if (maxAutoscaleSeats != organization.MaxAutoscaleSeats) + { + await UpdateAutoscalingAsync(organization, maxAutoscaleSeats); } } - await ReplaceAndUpdateCache(organization, EventType.Organization_Updated); - - if (updateBilling && !string.IsNullOrWhiteSpace(organization.GatewayCustomerId)) + private async Task UpdateAutoscalingAsync(Organization organization, int? maxAutoscaleSeats) { - var customerService = new CustomerService(); - await customerService.UpdateAsync(organization.GatewayCustomerId, new CustomerUpdateOptions + + if (maxAutoscaleSeats.HasValue && + organization.Seats.HasValue && + maxAutoscaleSeats.Value < organization.Seats.Value) { - Email = organization.BillingEmail, - Description = organization.BusinessName - }); - } - } - - public async Task UpdateTwoFactorProviderAsync(Organization organization, TwoFactorProviderType type) - { - if (!type.ToString().Contains("Organization")) - { - throw new ArgumentException("Not an organization provider type."); - } - - if (!organization.Use2fa) - { - throw new BadRequestException("Organization cannot use 2FA."); - } - - var providers = organization.GetTwoFactorProviders(); - if (!providers?.ContainsKey(type) ?? true) - { - return; - } - - providers[type].Enabled = true; - organization.SetTwoFactorProviders(providers); - await UpdateAsync(organization); - } - - public async Task DisableTwoFactorProviderAsync(Organization organization, TwoFactorProviderType type) - { - if (!type.ToString().Contains("Organization")) - { - throw new ArgumentException("Not an organization provider type."); - } - - var providers = organization.GetTwoFactorProviders(); - if (!providers?.ContainsKey(type) ?? true) - { - return; - } - - providers.Remove(type); - organization.SetTwoFactorProviders(providers); - await UpdateAsync(organization); - } - - public async Task> InviteUsersAsync(Guid organizationId, Guid? invitingUserId, - IEnumerable<(OrganizationUserInvite invite, string externalId)> invites) - { - var organization = await GetOrgById(organizationId); - var initialSeatCount = organization.Seats; - if (organization == null || invites.Any(i => i.invite.Emails == null)) - { - throw new NotFoundException(); - } - - var inviteTypes = new HashSet(invites.Where(i => i.invite.Type.HasValue) - .Select(i => i.invite.Type.Value)); - if (invitingUserId.HasValue && inviteTypes.Count > 0) - { - foreach (var type in inviteTypes) - { - await ValidateOrganizationUserUpdatePermissions(organizationId, type, null); + throw new BadRequestException($"Cannot set max seat autoscaling below current seat count."); } - } - var newSeatsRequired = 0; - var existingEmails = new HashSet(await _organizationUserRepository.SelectKnownEmailsAsync( - organizationId, invites.SelectMany(i => i.invite.Emails), false), StringComparer.InvariantCultureIgnoreCase); - if (organization.Seats.HasValue) - { - var userCount = await _organizationUserRepository.GetCountByOrganizationIdAsync(organizationId); - var availableSeats = organization.Seats.Value - userCount; - newSeatsRequired = invites.Sum(i => i.invite.Emails.Count()) - existingEmails.Count() - availableSeats; - } - - if (newSeatsRequired > 0) - { - var (canScale, failureReason) = CanScale(organization, newSeatsRequired); - if (!canScale) + var plan = StaticStore.Plans.FirstOrDefault(p => p.Type == organization.PlanType); + if (plan == null) { - throw new BadRequestException(failureReason); + throw new BadRequestException("Existing plan not found."); } + + if (!plan.AllowSeatAutoscale) + { + throw new BadRequestException("Your plan does not allow seat autoscaling."); + } + + if (plan.MaxUsers.HasValue && maxAutoscaleSeats.HasValue && + maxAutoscaleSeats > plan.MaxUsers) + { + throw new BadRequestException(string.Concat($"Your plan has a seat limit of {plan.MaxUsers}, ", + $"but you have specified a max autoscale count of {maxAutoscaleSeats}.", + "Reduce your max autoscale seat count.")); + } + + organization.MaxAutoscaleSeats = maxAutoscaleSeats; + + await ReplaceAndUpdateCache(organization); } - var invitedAreAllOwners = invites.All(i => i.invite.Type == OrganizationUserType.Owner); - if (!invitedAreAllOwners && !await HasConfirmedOwnersExceptAsync(organizationId, new Guid[] { })) + public async Task AdjustSeatsAsync(Guid organizationId, int seatAdjustment, DateTime? prorationDate = null) { - throw new BadRequestException("Organization must have at least one confirmed owner."); + var organization = await GetOrgById(organizationId); + if (organization == null) + { + throw new NotFoundException(); + } + + return await AdjustSeatsAsync(organization, seatAdjustment, prorationDate); } - - var orgUsers = new List(); - var limitedCollectionOrgUsers = new List<(OrganizationUser, IEnumerable)>(); - var orgUserInvitedCount = 0; - var exceptions = new List(); - var events = new List<(OrganizationUser, EventType, DateTime?)>(); - foreach (var (invite, externalId) in invites) + private async Task AdjustSeatsAsync(Organization organization, int seatAdjustment, DateTime? prorationDate = null, IEnumerable ownerEmails = null) { - // Prevent duplicate invitations - foreach (var email in invite.Emails.Distinct()) + if (organization.Seats == null) + { + throw new BadRequestException("Organization has no seat limit, no need to adjust seats"); + } + + if (string.IsNullOrWhiteSpace(organization.GatewayCustomerId)) + { + throw new BadRequestException("No payment method found."); + } + + if (string.IsNullOrWhiteSpace(organization.GatewaySubscriptionId)) + { + throw new BadRequestException("No subscription found."); + } + + var plan = StaticStore.Plans.FirstOrDefault(p => p.Type == organization.PlanType); + if (plan == null) + { + throw new BadRequestException("Existing plan not found."); + } + + if (!plan.HasAdditionalSeatsOption) + { + throw new BadRequestException("Plan does not allow additional seats."); + } + + var newSeatTotal = organization.Seats.Value + seatAdjustment; + if (plan.BaseSeats > newSeatTotal) + { + throw new BadRequestException($"Plan has a minimum of {plan.BaseSeats} seats."); + } + + if (newSeatTotal <= 0) + { + throw new BadRequestException("You must have at least 1 seat."); + } + + var additionalSeats = newSeatTotal - plan.BaseSeats; + if (plan.MaxAdditionalSeats.HasValue && additionalSeats > plan.MaxAdditionalSeats.Value) + { + throw new BadRequestException($"Organization plan allows a maximum of " + + $"{plan.MaxAdditionalSeats.Value} additional seats."); + } + + if (!organization.Seats.HasValue || organization.Seats.Value > newSeatTotal) + { + var userCount = await _organizationUserRepository.GetCountByOrganizationIdAsync(organization.Id); + if (userCount > newSeatTotal) + { + throw new BadRequestException($"Your organization currently has {userCount} seats filled. " + + $"Your new plan only has ({newSeatTotal}) seats. Remove some users."); + } + } + + var paymentIntentClientSecret = await _paymentService.AdjustSeatsAsync(organization, plan, additionalSeats, prorationDate); + await _referenceEventService.RaiseEventAsync( + new ReferenceEvent(ReferenceEventType.AdjustSeats, organization) + { + PlanName = plan.Name, + PlanType = plan.Type, + Seats = newSeatTotal, + PreviousSeats = organization.Seats + }); + organization.Seats = (short?)newSeatTotal; + await ReplaceAndUpdateCache(organization); + + if (organization.Seats.HasValue && organization.MaxAutoscaleSeats.HasValue && organization.Seats == organization.MaxAutoscaleSeats) { try { - // Make sure user is not already invited - if (existingEmails.Contains(email)) + if (ownerEmails == null) { - continue; + ownerEmails = (await _organizationUserRepository.GetManyByMinimumRoleAsync(organization.Id, + OrganizationUserType.Owner)).Select(u => u.Email).Distinct(); } - - var orgUser = new OrganizationUser - { - OrganizationId = organizationId, - UserId = null, - Email = email.ToLowerInvariant(), - Key = null, - Type = invite.Type.Value, - Status = OrganizationUserStatusType.Invited, - AccessAll = invite.AccessAll, - ExternalId = externalId, - CreationDate = DateTime.UtcNow, - RevisionDate = DateTime.UtcNow, - }; - - if (invite.Permissions != null) - { - orgUser.Permissions = JsonSerializer.Serialize(invite.Permissions, JsonHelpers.CamelCase); - } - - if (!orgUser.AccessAll && invite.Collections.Any()) - { - limitedCollectionOrgUsers.Add((orgUser, invite.Collections)); - } - else - { - orgUsers.Add(orgUser); - } - - events.Add((orgUser, EventType.OrganizationUser_Invited, DateTime.UtcNow)); - orgUserInvitedCount++; + await _mailService.SendOrganizationMaxSeatLimitReachedEmailAsync(organization, organization.MaxAutoscaleSeats.Value, ownerEmails); } catch (Exception e) { - exceptions.Add(e); + _logger.LogError(e, "Error encountered notifying organization owners of seat limit reached."); } } + + return paymentIntentClientSecret; } - if (exceptions.Any()) + public async Task VerifyBankAsync(Guid organizationId, int amount1, int amount2) { - throw new AggregateException("One or more errors occurred while inviting users.", exceptions); - } - - var prorationDate = DateTime.UtcNow; - try - { - await _organizationUserRepository.CreateManyAsync(orgUsers); - foreach (var (orgUser, collections) in limitedCollectionOrgUsers) + var organization = await GetOrgById(organizationId); + if (organization == null) { - await _organizationUserRepository.CreateAsync(orgUser, collections); + throw new NotFoundException(); } - if (!await _currentContext.ManageUsers(organization.Id)) + if (string.IsNullOrWhiteSpace(organization.GatewayCustomerId)) { - throw new BadRequestException("Cannot add seats. Cannot manage organization users."); + throw new GatewayException("Not a gateway customer."); } - await AutoAddSeatsAsync(organization, newSeatsRequired, prorationDate); - await SendInvitesAsync(orgUsers.Concat(limitedCollectionOrgUsers.Select(u => u.Item1)), organization); - await _eventService.LogOrganizationUserEventsAsync(events); + var bankService = new BankAccountService(); + var customerService = new CustomerService(); + var customer = await customerService.GetAsync(organization.GatewayCustomerId, + new CustomerGetOptions { Expand = new List { "sources" } }); + if (customer == null) + { + throw new GatewayException("Cannot find customer."); + } - await _referenceEventService.RaiseEventAsync( - new ReferenceEvent(ReferenceEventType.InvitedUsers, organization) + var bankAccount = customer.Sources + .FirstOrDefault(s => s is BankAccount && ((BankAccount)s).Status != "verified") as BankAccount; + if (bankAccount == null) + { + throw new GatewayException("Cannot find an unverified bank account."); + } + + try + { + var result = await bankService.VerifyAsync(organization.GatewayCustomerId, bankAccount.Id, + new BankAccountVerifyOptions { Amounts = new List { amount1, amount2 } }); + if (result.Status != "verified") { - Users = orgUserInvitedCount - }); - } - catch (Exception e) - { - // Revert any added users. - var invitedOrgUserIds = orgUsers.Select(u => u.Id).Concat(limitedCollectionOrgUsers.Select(u => u.Item1.Id)); - await _organizationUserRepository.DeleteManyAsync(invitedOrgUserIds); - var currentSeatCount = (await _organizationRepository.GetByIdAsync(organization.Id)).Seats; - - if (initialSeatCount.HasValue && currentSeatCount.HasValue && currentSeatCount.Value != initialSeatCount.Value) + throw new GatewayException("Unable to verify account."); + } + } + catch (StripeException e) { - await AdjustSeatsAsync(organization, initialSeatCount.Value - currentSeatCount.Value, prorationDate); + throw new GatewayException(e.Message); + } + } + + public async Task> SignUpAsync(OrganizationSignup signup, + bool provider = false) + { + var plan = StaticStore.Plans.FirstOrDefault(p => p.Type == signup.Plan); + if (!(plan is { LegacyYear: null })) + { + throw new BadRequestException("Invalid plan selected."); } - exceptions.Add(e); - } - - if (exceptions.Any()) - { - throw new AggregateException("One or more errors occurred while inviting users.", exceptions); - } - - return orgUsers; - } - - public async Task>> ResendInvitesAsync(Guid organizationId, Guid? invitingUserId, - IEnumerable organizationUsersId) - { - var orgUsers = await _organizationUserRepository.GetManyAsync(organizationUsersId); - var org = await GetOrgById(organizationId); - - var result = new List>(); - foreach (var orgUser in orgUsers) - { - if (orgUser.Status != OrganizationUserStatusType.Invited || orgUser.OrganizationId != organizationId) + if (plan.Disabled) { - result.Add(Tuple.Create(orgUser, "User invalid.")); - continue; + throw new BadRequestException("Plan not found."); } - await SendInviteAsync(orgUser, org); - result.Add(Tuple.Create(orgUser, "")); - } - - return result; - } - - public async Task ResendInviteAsync(Guid organizationId, Guid? invitingUserId, Guid organizationUserId) - { - var orgUser = await _organizationUserRepository.GetByIdAsync(organizationUserId); - if (orgUser == null || orgUser.OrganizationId != organizationId || - orgUser.Status != OrganizationUserStatusType.Invited) - { - throw new BadRequestException("User invalid."); - } - - var org = await GetOrgById(orgUser.OrganizationId); - await SendInviteAsync(orgUser, org); - } - - private async Task SendInvitesAsync(IEnumerable orgUsers, Organization organization) - { - string MakeToken(OrganizationUser orgUser) => - _dataProtector.Protect($"OrganizationUserInvite {orgUser.Id} {orgUser.Email} {CoreHelpers.ToEpocMilliseconds(DateTime.UtcNow)}"); - - await _mailService.BulkSendOrganizationInviteEmailAsync(organization.Name, - orgUsers.Select(o => (o, new ExpiringToken(MakeToken(o), DateTime.UtcNow.AddDays(5))))); - } - - private async Task SendInviteAsync(OrganizationUser orgUser, Organization organization) - { - var now = DateTime.UtcNow; - var nowMillis = CoreHelpers.ToEpocMilliseconds(now); - var token = _dataProtector.Protect( - $"OrganizationUserInvite {orgUser.Id} {orgUser.Email} {nowMillis}"); - - await _mailService.SendOrganizationInviteEmailAsync(organization.Name, orgUser, new ExpiringToken(token, now.AddDays(5))); - } - - public async Task AcceptUserAsync(Guid organizationUserId, User user, string token, - IUserService userService) - { - var orgUser = await _organizationUserRepository.GetByIdAsync(organizationUserId); - if (orgUser == null) - { - throw new BadRequestException("User invalid."); - } - - if (!CoreHelpers.UserInviteTokenIsValid(_dataProtector, token, user.Email, orgUser.Id, _globalSettings)) - { - throw new BadRequestException("Invalid token."); - } - - var existingOrgUserCount = await _organizationUserRepository.GetCountByOrganizationAsync( - orgUser.OrganizationId, user.Email, true); - if (existingOrgUserCount > 0) - { - if (orgUser.Status == OrganizationUserStatusType.Accepted) + if (!provider) { - throw new BadRequestException("Invitation already accepted. You will receive an email when your organization membership is confirmed."); + await ValidateSignUpPoliciesAsync(signup.Owner.Id); } - throw new BadRequestException("You are already part of this organization."); - } - if (string.IsNullOrWhiteSpace(orgUser.Email) || - !orgUser.Email.Equals(user.Email, StringComparison.InvariantCultureIgnoreCase)) - { - throw new BadRequestException("User email does not match invite."); - } + ValidateOrganizationUpgradeParameters(plan, signup); - return await AcceptUserAsync(orgUser, user, userService); - } - - public async Task AcceptUserAsync(string orgIdentifier, User user, IUserService userService) - { - var org = await _organizationRepository.GetByIdentifierAsync(orgIdentifier); - if (org == null) - { - throw new BadRequestException("Organization invalid."); - } - - var usersOrgs = await _organizationUserRepository.GetManyByUserAsync(user.Id); - var orgUser = usersOrgs.FirstOrDefault(u => u.OrganizationId == org.Id); - if (orgUser == null) - { - throw new BadRequestException("User not found within organization."); - } - - return await AcceptUserAsync(orgUser, user, userService); - } - - private async Task AcceptUserAsync(OrganizationUser orgUser, User user, - IUserService userService) - { - if (orgUser.Status != OrganizationUserStatusType.Invited) - { - throw new BadRequestException("Already accepted."); - } - - if (orgUser.Type == OrganizationUserType.Owner || orgUser.Type == OrganizationUserType.Admin) - { - var org = await GetOrgById(orgUser.OrganizationId); - if (org.PlanType == PlanType.Free) + var organization = new Organization { - var adminCount = await _organizationUserRepository.GetCountByFreeOrganizationAdminUserAsync( - user.Id); + // Pre-generate the org id so that we can save it with the Stripe subscription.. + Id = CoreHelpers.GenerateComb(), + Name = signup.Name, + BillingEmail = signup.BillingEmail, + BusinessName = signup.BusinessName, + PlanType = plan.Type, + Seats = (short)(plan.BaseSeats + signup.AdditionalSeats), + MaxCollections = plan.MaxCollections, + MaxStorageGb = !plan.BaseStorageGb.HasValue ? + (short?)null : (short)(plan.BaseStorageGb.Value + signup.AdditionalStorageGb), + UsePolicies = plan.HasPolicies, + UseSso = plan.HasSso, + UseGroups = plan.HasGroups, + UseEvents = plan.HasEvents, + UseDirectory = plan.HasDirectory, + UseTotp = plan.HasTotp, + Use2fa = plan.Has2fa, + UseApi = plan.HasApi, + UseResetPassword = plan.HasResetPassword, + SelfHost = plan.HasSelfHost, + UsersGetPremium = plan.UsersGetPremium || signup.PremiumAccessAddon, + UseScim = plan.HasScim, + Plan = plan.Name, + Gateway = null, + ReferenceData = signup.Owner.ReferenceData, + Enabled = true, + LicenseKey = CoreHelpers.SecureRandomString(20), + PublicKey = signup.PublicKey, + PrivateKey = signup.PrivateKey, + CreationDate = DateTime.UtcNow, + RevisionDate = DateTime.UtcNow, + }; + + if (plan.Type == PlanType.Free && !provider) + { + var adminCount = + await _organizationUserRepository.GetCountByFreeOrganizationAdminUserAsync(signup.Owner.Id); if (adminCount > 0) { throw new BadRequestException("You can only be an admin of one free organization."); } } - } - - // Enforce Single Organization Policy of organization user is trying to join - var allOrgUsers = await _organizationUserRepository.GetManyByUserAsync(user.Id); - var hasOtherOrgs = allOrgUsers.Any(ou => ou.OrganizationId != orgUser.OrganizationId); - var invitedSingleOrgPolicies = await _policyRepository.GetManyByTypeApplicableToUserIdAsync(user.Id, - PolicyType.SingleOrg, OrganizationUserStatusType.Invited); - - if (hasOtherOrgs && invitedSingleOrgPolicies.Any(p => p.OrganizationId == orgUser.OrganizationId)) - { - throw new BadRequestException("You may not join this organization until you leave or remove " + - "all other organizations."); - } - - // Enforce Single Organization Policy of other organizations user is a member of - var singleOrgPolicyCount = await _policyRepository.GetCountByTypeApplicableToUserIdAsync(user.Id, - PolicyType.SingleOrg); - if (singleOrgPolicyCount > 0) - { - throw new BadRequestException("You cannot join this organization because you are a member of " + - "another organization which forbids it"); - } - - // Enforce Two Factor Authentication Policy of organization user is trying to join - if (!await userService.TwoFactorIsEnabledAsync(user)) - { - var invitedTwoFactorPolicies = await _policyRepository.GetManyByTypeApplicableToUserIdAsync(user.Id, - PolicyType.TwoFactorAuthentication, OrganizationUserStatusType.Invited); - if (invitedTwoFactorPolicies.Any(p => p.OrganizationId == orgUser.OrganizationId)) + else if (plan.Type != PlanType.Free) { - throw new BadRequestException("You cannot join this organization until you enable " + - "two-step login on your user account."); + await _paymentService.PurchaseOrganizationAsync(organization, signup.PaymentMethodType.Value, + signup.PaymentToken, plan, signup.AdditionalStorageGb, signup.AdditionalSeats, + signup.PremiumAccessAddon, signup.TaxInfo); + } + + var ownerId = provider ? default : signup.Owner.Id; + var returnValue = await SignUpAsync(organization, ownerId, signup.OwnerKey, signup.CollectionName, true); + await _referenceEventService.RaiseEventAsync( + new ReferenceEvent(ReferenceEventType.Signup, organization) + { + PlanName = plan.Name, + PlanType = plan.Type, + Seats = returnValue.Item1.Seats, + Storage = returnValue.Item1.MaxStorageGb, + }); + return returnValue; + } + + private async Task ValidateSignUpPoliciesAsync(Guid ownerId) + { + var singleOrgPolicyCount = await _policyRepository.GetCountByTypeApplicableToUserIdAsync(ownerId, PolicyType.SingleOrg); + if (singleOrgPolicyCount > 0) + { + throw new BadRequestException("You may not create an organization. You belong to an organization " + + "which has a policy that prohibits you from being a member of any other organization."); } } - orgUser.Status = OrganizationUserStatusType.Accepted; - orgUser.UserId = user.Id; - orgUser.Email = null; - - await _organizationUserRepository.ReplaceAsync(orgUser); - - var admins = await _organizationUserRepository.GetManyByMinimumRoleAsync(orgUser.OrganizationId, OrganizationUserType.Admin); - var adminEmails = admins.Select(a => a.Email).Distinct().ToList(); - - if (adminEmails.Count > 0) + public async Task> SignUpAsync( + OrganizationLicense license, User owner, string ownerKey, string collectionName, string publicKey, + string privateKey) { - var organization = await _organizationRepository.GetByIdAsync(orgUser.OrganizationId); - await _mailService.SendOrganizationAcceptedEmailAsync(organization, user.Email, adminEmails); - } - - return orgUser; - } - - public async Task ConfirmUserAsync(Guid organizationId, Guid organizationUserId, string key, - Guid confirmingUserId, IUserService userService) - { - var result = await ConfirmUsersAsync(organizationId, new Dictionary() { { organizationUserId, key } }, - confirmingUserId, userService); - - if (!result.Any()) - { - throw new BadRequestException("User not valid."); - } - - var (orgUser, error) = result[0]; - if (error != "") - { - throw new BadRequestException(error); - } - return orgUser; - } - - public async Task>> ConfirmUsersAsync(Guid organizationId, Dictionary keys, - Guid confirmingUserId, IUserService userService) - { - var organizationUsers = await _organizationUserRepository.GetManyAsync(keys.Keys); - var validOrganizationUsers = organizationUsers - .Where(u => u.Status == OrganizationUserStatusType.Accepted && u.OrganizationId == organizationId && u.UserId != null) - .ToList(); - - if (!validOrganizationUsers.Any()) - { - return new List>(); - } - - var validOrganizationUserIds = validOrganizationUsers.Select(u => u.UserId.Value).ToList(); - - var organization = await GetOrgById(organizationId); - var policies = await _policyRepository.GetManyByOrganizationIdAsync(organizationId); - var usersOrgs = await _organizationUserRepository.GetManyByManyUsersAsync(validOrganizationUserIds); - var users = await _userRepository.GetManyAsync(validOrganizationUserIds); - - var keyedFilteredUsers = validOrganizationUsers.ToDictionary(u => u.UserId.Value, u => u); - var keyedOrganizationUsers = usersOrgs.GroupBy(u => u.UserId.Value) - .ToDictionary(u => u.Key, u => u.ToList()); - - var succeededUsers = new List(); - var result = new List>(); - - foreach (var user in users) - { - if (!keyedFilteredUsers.ContainsKey(user.Id)) + if (license?.LicenseType != null && license.LicenseType != LicenseType.Organization) { - continue; + throw new BadRequestException("Premium licenses cannot be applied to an organization. " + + "Upload this license from your personal account settings page."); } - var orgUser = keyedFilteredUsers[user.Id]; - var orgUsers = keyedOrganizationUsers.GetValueOrDefault(user.Id, new List()); + + if (license == null || !_licensingService.VerifyLicense(license)) + { + throw new BadRequestException("Invalid license."); + } + + if (!license.CanUse(_globalSettings)) + { + throw new BadRequestException("Invalid license. Make sure your license allows for on-premise " + + "hosting of organizations and that the installation id matches your current installation."); + } + + if (license.PlanType != PlanType.Custom && + StaticStore.Plans.FirstOrDefault(p => p.Type == license.PlanType && !p.Disabled) == null) + { + throw new BadRequestException("Plan not found."); + } + + var enabledOrgs = await _organizationRepository.GetManyByEnabledAsync(); + if (enabledOrgs.Any(o => o.LicenseKey.Equals(license.LicenseKey))) + { + throw new BadRequestException("License is already in use by another organization."); + } + + await ValidateSignUpPoliciesAsync(owner.Id); + + var organization = new Organization + { + Name = license.Name, + BillingEmail = license.BillingEmail, + BusinessName = license.BusinessName, + PlanType = license.PlanType, + Seats = license.Seats, + MaxCollections = license.MaxCollections, + MaxStorageGb = _globalSettings.SelfHosted ? 10240 : license.MaxStorageGb, // 10 TB + UsePolicies = license.UsePolicies, + UseSso = license.UseSso, + UseKeyConnector = license.UseKeyConnector, + UseScim = license.UseScim, + UseGroups = license.UseGroups, + UseDirectory = license.UseDirectory, + UseEvents = license.UseEvents, + UseTotp = license.UseTotp, + Use2fa = license.Use2fa, + UseApi = license.UseApi, + UseResetPassword = license.UseResetPassword, + Plan = license.Plan, + SelfHost = license.SelfHost, + UsersGetPremium = license.UsersGetPremium, + Gateway = null, + GatewayCustomerId = null, + GatewaySubscriptionId = null, + ReferenceData = owner.ReferenceData, + Enabled = license.Enabled, + ExpirationDate = license.Expires, + LicenseKey = license.LicenseKey, + PublicKey = publicKey, + PrivateKey = privateKey, + CreationDate = DateTime.UtcNow, + RevisionDate = DateTime.UtcNow + }; + + var result = await SignUpAsync(organization, owner.Id, ownerKey, collectionName, false); + + var dir = $"{_globalSettings.LicenseDirectory}/organization"; + Directory.CreateDirectory(dir); + await using var fs = new FileStream(Path.Combine(dir, $"{organization.Id}.json"), FileMode.Create); + await JsonSerializer.SerializeAsync(fs, license, JsonHelpers.Indented); + return result; + } + + private async Task> SignUpAsync(Organization organization, + Guid ownerId, string ownerKey, string collectionName, bool withPayment) + { try { - if (organization.PlanType == PlanType.Free && (orgUser.Type == OrganizationUserType.Admin - || orgUser.Type == OrganizationUserType.Owner)) + await _organizationRepository.CreateAsync(organization); + await _organizationApiKeyRepository.CreateAsync(new OrganizationApiKey { - // Since free organizations only supports a few users there is not much point in avoiding N+1 queries for this. - var adminCount = await _organizationUserRepository.GetCountByFreeOrganizationAdminUserAsync(user.Id); - if (adminCount > 0) + OrganizationId = organization.Id, + ApiKey = CoreHelpers.SecureRandomString(30), + Type = OrganizationApiKeyType.Default, + RevisionDate = DateTime.UtcNow, + }); + await _applicationCacheService.UpsertOrganizationAbilityAsync(organization); + + if (!string.IsNullOrWhiteSpace(collectionName)) + { + var defaultCollection = new Collection { - throw new BadRequestException("User can only be an admin of one free organization."); - } + Name = collectionName, + OrganizationId = organization.Id, + CreationDate = organization.CreationDate, + RevisionDate = organization.CreationDate + }; + await _collectionRepository.CreateAsync(defaultCollection); } - await CheckPolicies(policies, organizationId, user, orgUsers, userService); - orgUser.Status = OrganizationUserStatusType.Confirmed; - orgUser.Key = keys[orgUser.Id]; - orgUser.Email = null; - - await _eventService.LogOrganizationUserEventAsync(orgUser, EventType.OrganizationUser_Confirmed); - await _mailService.SendOrganizationConfirmedEmailAsync(organization.Name, user.Email); - await DeleteAndPushUserRegistrationAsync(organizationId, user.Id); - succeededUsers.Add(orgUser); - result.Add(Tuple.Create(orgUser, "")); - } - catch (BadRequestException e) - { - result.Add(Tuple.Create(orgUser, e.Message)); - } - } - - await _organizationUserRepository.ReplaceManyAsync(succeededUsers); - - return result; - } - - internal (bool canScale, string failureReason) CanScale(Organization organization, - int seatsToAdd) - { - var failureReason = ""; - if (_globalSettings.SelfHosted) - { - failureReason = "Cannot autoscale on self-hosted instance."; - return (false, failureReason); - } - - if (seatsToAdd < 1) - { - return (true, failureReason); - } - - if (organization.Seats.HasValue && - organization.MaxAutoscaleSeats.HasValue && - organization.MaxAutoscaleSeats.Value < organization.Seats.Value + seatsToAdd) - { - return (false, $"Cannot invite new users. Seat limit has been reached."); - } - - return (true, failureReason); - } - - public async Task AutoAddSeatsAsync(Organization organization, int seatsToAdd, DateTime? prorationDate = null) - { - if (seatsToAdd < 1 || !organization.Seats.HasValue) - { - return; - } - - var (canScale, failureMessage) = CanScale(organization, seatsToAdd); - if (!canScale) - { - throw new BadRequestException(failureMessage); - } - - var ownerEmails = (await _organizationUserRepository.GetManyByMinimumRoleAsync(organization.Id, - OrganizationUserType.Owner)).Select(u => u.Email).Distinct(); - var initialSeatCount = organization.Seats.Value; - - await AdjustSeatsAsync(organization, seatsToAdd, prorationDate, ownerEmails); - - if (!organization.OwnersNotifiedOfAutoscaling.HasValue) - { - await _mailService.SendOrganizationAutoscaledEmailAsync(organization, initialSeatCount, - ownerEmails); - organization.OwnersNotifiedOfAutoscaling = DateTime.UtcNow; - await _organizationRepository.UpsertAsync(organization); - } - } - - private async Task CheckPolicies(ICollection policies, Guid organizationId, User user, - ICollection userOrgs, IUserService userService) - { - var usingTwoFactorPolicy = policies.Any(p => p.Type == PolicyType.TwoFactorAuthentication && p.Enabled); - if (usingTwoFactorPolicy && !await userService.TwoFactorIsEnabledAsync(user)) - { - throw new BadRequestException("User does not have two-step login enabled."); - } - - var usingSingleOrgPolicy = policies.Any(p => p.Type == PolicyType.SingleOrg && p.Enabled); - if (usingSingleOrgPolicy) - { - if (userOrgs.Any(ou => ou.OrganizationId != organizationId && ou.Status != OrganizationUserStatusType.Invited)) - { - throw new BadRequestException("User is a member of another organization."); - } - } - } - - public async Task SaveUserAsync(OrganizationUser user, Guid? savingUserId, - IEnumerable collections) - { - if (user.Id.Equals(default(Guid))) - { - throw new BadRequestException("Invite the user first."); - } - - var originalUser = await _organizationUserRepository.GetByIdAsync(user.Id); - if (user.Equals(originalUser)) - { - throw new BadRequestException("Please make changes before saving."); - } - - if (savingUserId.HasValue) - { - await ValidateOrganizationUserUpdatePermissions(user.OrganizationId, user.Type, originalUser.Type); - } - - if (user.Type != OrganizationUserType.Owner && - !await HasConfirmedOwnersExceptAsync(user.OrganizationId, new[] { user.Id })) - { - throw new BadRequestException("Organization must have at least one confirmed owner."); - } - - if (user.AccessAll) - { - // We don't need any collections if we're flagged to have all access. - collections = new List(); - } - await _organizationUserRepository.ReplaceAsync(user, collections); - await _eventService.LogOrganizationUserEventAsync(user, EventType.OrganizationUser_Updated); - } - - public async Task DeleteUserAsync(Guid organizationId, Guid organizationUserId, Guid? deletingUserId) - { - var orgUser = await _organizationUserRepository.GetByIdAsync(organizationUserId); - if (orgUser == null || orgUser.OrganizationId != organizationId) - { - throw new BadRequestException("User not valid."); - } - - if (deletingUserId.HasValue && orgUser.UserId == deletingUserId.Value) - { - throw new BadRequestException("You cannot remove yourself."); - } - - if (orgUser.Type == OrganizationUserType.Owner && deletingUserId.HasValue && - !await _currentContext.OrganizationOwner(organizationId)) - { - throw new BadRequestException("Only owners can delete other owners."); - } - - if (!await HasConfirmedOwnersExceptAsync(organizationId, new[] { organizationUserId })) - { - throw new BadRequestException("Organization must have at least one confirmed owner."); - } - - await _organizationUserRepository.DeleteAsync(orgUser); - await _eventService.LogOrganizationUserEventAsync(orgUser, EventType.OrganizationUser_Removed); - - if (orgUser.UserId.HasValue) - { - await DeleteAndPushUserRegistrationAsync(organizationId, orgUser.UserId.Value); - } - } - - public async Task DeleteUserAsync(Guid organizationId, Guid userId) - { - var orgUser = await _organizationUserRepository.GetByOrganizationAsync(organizationId, userId); - if (orgUser == null) - { - throw new NotFoundException(); - } - - if (!await HasConfirmedOwnersExceptAsync(organizationId, new[] { orgUser.Id })) - { - throw new BadRequestException("Organization must have at least one confirmed owner."); - } - - await _organizationUserRepository.DeleteAsync(orgUser); - await _eventService.LogOrganizationUserEventAsync(orgUser, EventType.OrganizationUser_Removed); - - if (orgUser.UserId.HasValue) - { - await DeleteAndPushUserRegistrationAsync(organizationId, orgUser.UserId.Value); - } - } - - public async Task>> DeleteUsersAsync(Guid organizationId, - IEnumerable organizationUsersId, - Guid? deletingUserId) - { - var orgUsers = await _organizationUserRepository.GetManyAsync(organizationUsersId); - var filteredUsers = orgUsers.Where(u => u.OrganizationId == organizationId) - .ToList(); - - if (!filteredUsers.Any()) - { - throw new BadRequestException("Users invalid."); - } - - if (!await HasConfirmedOwnersExceptAsync(organizationId, organizationUsersId)) - { - throw new BadRequestException("Organization must have at least one confirmed owner."); - } - - var deletingUserIsOwner = false; - if (deletingUserId.HasValue) - { - deletingUserIsOwner = await _currentContext.OrganizationOwner(organizationId); - } - - var result = new List>(); - var deletedUserIds = new List(); - foreach (var orgUser in filteredUsers) - { - try - { - if (deletingUserId.HasValue && orgUser.UserId == deletingUserId) + OrganizationUser orgUser = null; + if (ownerId != default) { - throw new BadRequestException("You cannot remove yourself."); + orgUser = new OrganizationUser + { + OrganizationId = organization.Id, + UserId = ownerId, + Key = ownerKey, + Type = OrganizationUserType.Owner, + Status = OrganizationUserStatusType.Confirmed, + AccessAll = true, + CreationDate = organization.CreationDate, + RevisionDate = organization.CreationDate + }; + + await _organizationUserRepository.CreateAsync(orgUser); + + var deviceIds = await GetUserDeviceIdsAsync(orgUser.UserId.Value); + await _pushRegistrationService.AddUserRegistrationOrganizationAsync(deviceIds, + organization.Id.ToString()); + await _pushNotificationService.PushSyncOrgKeysAsync(ownerId); } - if (orgUser.Type == OrganizationUserType.Owner && deletingUserId.HasValue && !deletingUserIsOwner) + return new Tuple(organization, orgUser); + } + catch + { + if (withPayment) { - throw new BadRequestException("Only owners can delete other owners."); + await _paymentService.CancelAndRecoverChargesAsync(organization); } - await _eventService.LogOrganizationUserEventAsync(orgUser, EventType.OrganizationUser_Removed); - - if (orgUser.UserId.HasValue) + if (organization.Id != default(Guid)) { - await DeleteAndPushUserRegistrationAsync(organizationId, orgUser.UserId.Value); + await _organizationRepository.DeleteAsync(organization); + await _applicationCacheService.DeleteOrganizationAbilityAsync(organization.Id); } - result.Add(Tuple.Create(orgUser, "")); - deletedUserIds.Add(orgUser.Id); - } - catch (BadRequestException e) - { - result.Add(Tuple.Create(orgUser, e.Message)); - } - await _organizationUserRepository.DeleteManyAsync(deletedUserIds); - } - - return result; - } - - public async Task HasConfirmedOwnersExceptAsync(Guid organizationId, IEnumerable organizationUsersId, bool includeProvider = true) - { - var confirmedOwners = await GetConfirmedOwnersAsync(organizationId); - var confirmedOwnersIds = confirmedOwners.Select(u => u.Id); - bool hasOtherOwner = confirmedOwnersIds.Except(organizationUsersId).Any(); - if (!hasOtherOwner && includeProvider) - { - return (await _currentContext.ProviderIdForOrg(organizationId)).HasValue; - } - return hasOtherOwner; - } - - public async Task UpdateUserGroupsAsync(OrganizationUser organizationUser, IEnumerable groupIds, Guid? loggedInUserId) - { - if (loggedInUserId.HasValue) - { - await ValidateOrganizationUserUpdatePermissions(organizationUser.OrganizationId, organizationUser.Type, null); - } - await _organizationUserRepository.UpdateGroupsAsync(organizationUser.Id, groupIds); - await _eventService.LogOrganizationUserEventAsync(organizationUser, - EventType.OrganizationUser_UpdatedGroups); - } - - public async Task UpdateUserResetPasswordEnrollmentAsync(Guid organizationId, Guid userId, string resetPasswordKey, Guid? callingUserId) - { - // Org User must be the same as the calling user and the organization ID associated with the user must match passed org ID - var orgUser = await _organizationUserRepository.GetByOrganizationAsync(organizationId, userId); - if (!callingUserId.HasValue || orgUser == null || orgUser.UserId != callingUserId.Value || - orgUser.OrganizationId != organizationId) - { - throw new BadRequestException("User not valid."); - } - - // Make sure the organization has the ability to use password reset - var org = await _organizationRepository.GetByIdAsync(organizationId); - if (org == null || !org.UseResetPassword) - { - throw new BadRequestException("Organization does not allow password reset enrollment."); - } - - // Make sure the organization has the policy enabled - var resetPasswordPolicy = - await _policyRepository.GetByOrganizationIdTypeAsync(organizationId, PolicyType.ResetPassword); - if (resetPasswordPolicy == null || !resetPasswordPolicy.Enabled) - { - throw new BadRequestException("Organization does not have the password reset policy enabled."); - } - - // Block the user from withdrawal if auto enrollment is enabled - if (resetPasswordKey == null && resetPasswordPolicy.Data != null) - { - var data = JsonSerializer.Deserialize(resetPasswordPolicy.Data, JsonHelpers.IgnoreCase); - - if (data?.AutoEnrollEnabled ?? false) - { - throw new BadRequestException("Due to an Enterprise Policy, you are not allowed to withdraw from Password Reset."); + throw; } } - orgUser.ResetPasswordKey = resetPasswordKey; - await _organizationUserRepository.ReplaceAsync(orgUser); - await _eventService.LogOrganizationUserEventAsync(orgUser, resetPasswordKey != null ? - EventType.OrganizationUser_ResetPassword_Enroll : EventType.OrganizationUser_ResetPassword_Withdraw); - } - - public async Task GenerateLicenseAsync(Guid organizationId, Guid installationId) - { - var organization = await GetOrgById(organizationId); - return await GenerateLicenseAsync(organization, installationId); - } - - public async Task GenerateLicenseAsync(Organization organization, Guid installationId, - int? version = null) - { - if (organization == null) + public async Task UpdateLicenseAsync(Guid organizationId, OrganizationLicense license) { - throw new NotFoundException(); - } - - var installation = await _installationRepository.GetByIdAsync(installationId); - if (installation == null || !installation.Enabled) - { - throw new BadRequestException("Invalid installation id"); - } - - var subInfo = await _paymentService.GetSubscriptionAsync(organization); - return new OrganizationLicense(organization, subInfo, installationId, _licensingService, version); - } - - public async Task InviteUserAsync(Guid organizationId, Guid? invitingUserId, string email, - OrganizationUserType type, bool accessAll, string externalId, IEnumerable collections) - { - var invite = new OrganizationUserInvite() - { - Emails = new List { email }, - Type = type, - AccessAll = accessAll, - Collections = collections, - }; - var results = await InviteUsersAsync(organizationId, invitingUserId, - new (OrganizationUserInvite, string)[] { (invite, externalId) }); - var result = results.FirstOrDefault(); - if (result == null) - { - throw new BadRequestException("This user has already been invited."); - } - return result; - } - - public async Task ImportAsync(Guid organizationId, - Guid? importingUserId, - IEnumerable groups, - IEnumerable newUsers, - IEnumerable removeUserExternalIds, - bool overwriteExisting) - { - var organization = await GetOrgById(organizationId); - if (organization == null) - { - throw new NotFoundException(); - } - - if (!organization.UseDirectory) - { - throw new BadRequestException("Organization cannot use directory syncing."); - } - - var newUsersSet = new HashSet(newUsers?.Select(u => u.ExternalId) ?? new List()); - var existingUsers = await _organizationUserRepository.GetManyDetailsByOrganizationAsync(organizationId); - var existingExternalUsers = existingUsers.Where(u => !string.IsNullOrWhiteSpace(u.ExternalId)).ToList(); - var existingExternalUsersIdDict = existingExternalUsers.ToDictionary(u => u.ExternalId, u => u.Id); - - // Users - - // Remove Users - if (removeUserExternalIds?.Any() ?? false) - { - var removeUsersSet = new HashSet(removeUserExternalIds); - var existingUsersDict = existingExternalUsers.ToDictionary(u => u.ExternalId); - - await _organizationUserRepository.DeleteManyAsync(removeUsersSet - .Except(newUsersSet) - .Where(u => existingUsersDict.ContainsKey(u) && existingUsersDict[u].Type != OrganizationUserType.Owner) - .Select(u => existingUsersDict[u].Id)); - } - - if (overwriteExisting) - { - // Remove existing external users that are not in new user set - var usersToDelete = existingExternalUsers.Where(u => - u.Type != OrganizationUserType.Owner && - !newUsersSet.Contains(u.ExternalId) && - existingExternalUsersIdDict.ContainsKey(u.ExternalId)); - await _organizationUserRepository.DeleteManyAsync(usersToDelete.Select(u => u.Id)); - foreach (var deletedUser in usersToDelete) + var organization = await GetOrgById(organizationId); + if (organization == null) { - existingExternalUsersIdDict.Remove(deletedUser.ExternalId); + throw new NotFoundException(); } - } - if (newUsers?.Any() ?? false) - { - // Marry existing users - var existingUsersEmailsDict = existingUsers - .Where(u => string.IsNullOrWhiteSpace(u.ExternalId)) - .ToDictionary(u => u.Email); - var newUsersEmailsDict = newUsers.ToDictionary(u => u.Email); - var usersToAttach = existingUsersEmailsDict.Keys.Intersect(newUsersEmailsDict.Keys).ToList(); - var usersToUpsert = new List(); - foreach (var user in usersToAttach) + if (!_globalSettings.SelfHosted) { - var orgUserDetails = existingUsersEmailsDict[user]; - var orgUser = await _organizationUserRepository.GetByIdAsync(orgUserDetails.Id); - if (orgUser != null) + throw new InvalidOperationException("Licenses require self hosting."); + } + + if (license?.LicenseType != null && license.LicenseType != LicenseType.Organization) + { + throw new BadRequestException("Premium licenses cannot be applied to an organization. " + + "Upload this license from your personal account settings page."); + } + + if (license == null || !_licensingService.VerifyLicense(license)) + { + throw new BadRequestException("Invalid license."); + } + + if (!license.CanUse(_globalSettings)) + { + throw new BadRequestException("Invalid license. Make sure your license allows for on-premise " + + "hosting of organizations and that the installation id matches your current installation."); + } + + var enabledOrgs = await _organizationRepository.GetManyByEnabledAsync(); + if (enabledOrgs.Any(o => o.LicenseKey.Equals(license.LicenseKey) && o.Id != organizationId)) + { + throw new BadRequestException("License is already in use by another organization."); + } + + if (license.Seats.HasValue && + (!organization.Seats.HasValue || organization.Seats.Value > license.Seats.Value)) + { + var userCount = await _organizationUserRepository.GetCountByOrganizationIdAsync(organization.Id); + if (userCount > license.Seats.Value) { - orgUser.ExternalId = newUsersEmailsDict[user].ExternalId; - usersToUpsert.Add(orgUser); - existingExternalUsersIdDict.Add(orgUser.ExternalId, orgUser.Id); + throw new BadRequestException($"Your organization currently has {userCount} seats filled. " + + $"Your new license only has ({license.Seats.Value}) seats. Remove some users."); } } - await _organizationUserRepository.UpsertManyAsync(usersToUpsert); - // Add new users - var existingUsersSet = new HashSet(existingExternalUsersIdDict.Keys); - var usersToAdd = newUsersSet.Except(existingUsersSet).ToList(); + if (license.MaxCollections.HasValue && (!organization.MaxCollections.HasValue || + organization.MaxCollections.Value > license.MaxCollections.Value)) + { + var collectionCount = await _collectionRepository.GetCountByOrganizationIdAsync(organization.Id); + if (collectionCount > license.MaxCollections.Value) + { + throw new BadRequestException($"Your organization currently has {collectionCount} collections. " + + $"Your new license allows for a maximum of ({license.MaxCollections.Value}) collections. " + + "Remove some collections."); + } + } - var seatsAvailable = int.MaxValue; - var enoughSeatsAvailable = true; + if (!license.UseGroups && organization.UseGroups) + { + var groups = await _groupRepository.GetManyByOrganizationIdAsync(organization.Id); + if (groups.Count > 0) + { + throw new BadRequestException($"Your organization currently has {groups.Count} groups. " + + $"Your new license does not allow for the use of groups. Remove all groups."); + } + } + + if (!license.UsePolicies && organization.UsePolicies) + { + var policies = await _policyRepository.GetManyByOrganizationIdAsync(organization.Id); + if (policies.Any(p => p.Enabled)) + { + throw new BadRequestException($"Your organization currently has {policies.Count} enabled " + + $"policies. Your new license does not allow for the use of policies. Disable all policies."); + } + } + + if (!license.UseSso && organization.UseSso) + { + var ssoConfig = await _ssoConfigRepository.GetByOrganizationIdAsync(organization.Id); + if (ssoConfig != null && ssoConfig.Enabled) + { + throw new BadRequestException($"Your organization currently has a SSO configuration. " + + $"Your new license does not allow for the use of SSO. Disable your SSO configuration."); + } + } + + if (!license.UseKeyConnector && organization.UseKeyConnector) + { + var ssoConfig = await _ssoConfigRepository.GetByOrganizationIdAsync(organization.Id); + if (ssoConfig != null && ssoConfig.GetData().KeyConnectorEnabled) + { + throw new BadRequestException($"Your organization currently has Key Connector enabled. " + + $"Your new license does not allow for the use of Key Connector. Disable your Key Connector."); + } + } + + if (!license.UseScim && organization.UseScim) + { + var scimConnections = await _organizationConnectionRepository.GetByOrganizationIdTypeAsync(organization.Id, + OrganizationConnectionType.Scim); + if (scimConnections != null && scimConnections.Any(c => c.GetConfig()?.Enabled == true)) + { + throw new BadRequestException("Your new plan does not allow the SCIM feature. " + + "Disable your SCIM configuration."); + } + } + + if (!license.UseResetPassword && organization.UseResetPassword) + { + var resetPasswordPolicy = + await _policyRepository.GetByOrganizationIdTypeAsync(organization.Id, PolicyType.ResetPassword); + if (resetPasswordPolicy != null && resetPasswordPolicy.Enabled) + { + throw new BadRequestException("Your new license does not allow the Password Reset feature. " + + "Disable your Password Reset policy."); + } + } + + var dir = $"{_globalSettings.LicenseDirectory}/organization"; + Directory.CreateDirectory(dir); + await using var fs = new FileStream(Path.Combine(dir, $"{organization.Id}.json"), FileMode.Create); + await JsonSerializer.SerializeAsync(fs, license, JsonHelpers.Indented); + + organization.Name = license.Name; + organization.BusinessName = license.BusinessName; + organization.BillingEmail = license.BillingEmail; + organization.PlanType = license.PlanType; + organization.Seats = license.Seats; + organization.MaxCollections = license.MaxCollections; + organization.UseGroups = license.UseGroups; + organization.UseDirectory = license.UseDirectory; + organization.UseEvents = license.UseEvents; + organization.UseTotp = license.UseTotp; + organization.Use2fa = license.Use2fa; + organization.UseApi = license.UseApi; + organization.UsePolicies = license.UsePolicies; + organization.UseSso = license.UseSso; + organization.UseKeyConnector = license.UseKeyConnector; + organization.UseScim = license.UseScim; + organization.UseResetPassword = license.UseResetPassword; + organization.SelfHost = license.SelfHost; + organization.UsersGetPremium = license.UsersGetPremium; + organization.Plan = license.Plan; + organization.Enabled = license.Enabled; + organization.ExpirationDate = license.Expires; + organization.LicenseKey = license.LicenseKey; + organization.RevisionDate = DateTime.UtcNow; + await ReplaceAndUpdateCache(organization); + } + + public async Task DeleteAsync(Organization organization) + { + await ValidateDeleteOrganizationAsync(organization); + + if (!string.IsNullOrWhiteSpace(organization.GatewaySubscriptionId)) + { + try + { + var eop = !organization.ExpirationDate.HasValue || + organization.ExpirationDate.Value >= DateTime.UtcNow; + await _paymentService.CancelSubscriptionAsync(organization, eop); + await _referenceEventService.RaiseEventAsync( + new ReferenceEvent(ReferenceEventType.DeleteAccount, organization)); + } + catch (GatewayException) { } + } + + await _organizationRepository.DeleteAsync(organization); + await _applicationCacheService.DeleteOrganizationAbilityAsync(organization.Id); + } + + public async Task EnableAsync(Guid organizationId, DateTime? expirationDate) + { + var org = await GetOrgById(organizationId); + if (org != null && !org.Enabled && org.Gateway.HasValue) + { + org.Enabled = true; + org.ExpirationDate = expirationDate; + org.RevisionDate = DateTime.UtcNow; + await ReplaceAndUpdateCache(org); + } + } + + public async Task DisableAsync(Guid organizationId, DateTime? expirationDate) + { + var org = await GetOrgById(organizationId); + if (org != null && org.Enabled) + { + org.Enabled = false; + org.ExpirationDate = expirationDate; + org.RevisionDate = DateTime.UtcNow; + await ReplaceAndUpdateCache(org); + + // TODO: send email to owners? + } + } + + public async Task UpdateExpirationDateAsync(Guid organizationId, DateTime? expirationDate) + { + var org = await GetOrgById(organizationId); + if (org != null) + { + org.ExpirationDate = expirationDate; + org.RevisionDate = DateTime.UtcNow; + await ReplaceAndUpdateCache(org); + } + } + + public async Task EnableAsync(Guid organizationId) + { + var org = await GetOrgById(organizationId); + if (org != null && !org.Enabled) + { + org.Enabled = true; + await ReplaceAndUpdateCache(org); + } + } + + public async Task UpdateAsync(Organization organization, bool updateBilling = false) + { + if (organization.Id == default(Guid)) + { + throw new ApplicationException("Cannot create org this way. Call SignUpAsync."); + } + + if (!string.IsNullOrWhiteSpace(organization.Identifier)) + { + var orgById = await _organizationRepository.GetByIdentifierAsync(organization.Identifier); + if (orgById != null && orgById.Id != organization.Id) + { + throw new BadRequestException("Identifier already in use by another organization."); + } + } + + await ReplaceAndUpdateCache(organization, EventType.Organization_Updated); + + if (updateBilling && !string.IsNullOrWhiteSpace(organization.GatewayCustomerId)) + { + var customerService = new CustomerService(); + await customerService.UpdateAsync(organization.GatewayCustomerId, new CustomerUpdateOptions + { + Email = organization.BillingEmail, + Description = organization.BusinessName + }); + } + } + + public async Task UpdateTwoFactorProviderAsync(Organization organization, TwoFactorProviderType type) + { + if (!type.ToString().Contains("Organization")) + { + throw new ArgumentException("Not an organization provider type."); + } + + if (!organization.Use2fa) + { + throw new BadRequestException("Organization cannot use 2FA."); + } + + var providers = organization.GetTwoFactorProviders(); + if (!providers?.ContainsKey(type) ?? true) + { + return; + } + + providers[type].Enabled = true; + organization.SetTwoFactorProviders(providers); + await UpdateAsync(organization); + } + + public async Task DisableTwoFactorProviderAsync(Organization organization, TwoFactorProviderType type) + { + if (!type.ToString().Contains("Organization")) + { + throw new ArgumentException("Not an organization provider type."); + } + + var providers = organization.GetTwoFactorProviders(); + if (!providers?.ContainsKey(type) ?? true) + { + return; + } + + providers.Remove(type); + organization.SetTwoFactorProviders(providers); + await UpdateAsync(organization); + } + + public async Task> InviteUsersAsync(Guid organizationId, Guid? invitingUserId, + IEnumerable<(OrganizationUserInvite invite, string externalId)> invites) + { + var organization = await GetOrgById(organizationId); + var initialSeatCount = organization.Seats; + if (organization == null || invites.Any(i => i.invite.Emails == null)) + { + throw new NotFoundException(); + } + + var inviteTypes = new HashSet(invites.Where(i => i.invite.Type.HasValue) + .Select(i => i.invite.Type.Value)); + if (invitingUserId.HasValue && inviteTypes.Count > 0) + { + foreach (var type in inviteTypes) + { + await ValidateOrganizationUserUpdatePermissions(organizationId, type, null); + } + } + + var newSeatsRequired = 0; + var existingEmails = new HashSet(await _organizationUserRepository.SelectKnownEmailsAsync( + organizationId, invites.SelectMany(i => i.invite.Emails), false), StringComparer.InvariantCultureIgnoreCase); if (organization.Seats.HasValue) { var userCount = await _organizationUserRepository.GetCountByOrganizationIdAsync(organizationId); - seatsAvailable = organization.Seats.Value - userCount; - enoughSeatsAvailable = seatsAvailable >= usersToAdd.Count; + var availableSeats = organization.Seats.Value - userCount; + newSeatsRequired = invites.Sum(i => i.invite.Emails.Count()) - existingEmails.Count() - availableSeats; } - var userInvites = new List<(OrganizationUserInvite, string)>(); - foreach (var user in newUsers) + if (newSeatsRequired > 0) { - if (!usersToAdd.Contains(user.ExternalId) || string.IsNullOrWhiteSpace(user.Email)) + var (canScale, failureReason) = CanScale(organization, newSeatsRequired); + if (!canScale) { - continue; + throw new BadRequestException(failureReason); } + } - try + var invitedAreAllOwners = invites.All(i => i.invite.Type == OrganizationUserType.Owner); + if (!invitedAreAllOwners && !await HasConfirmedOwnersExceptAsync(organizationId, new Guid[] { })) + { + throw new BadRequestException("Organization must have at least one confirmed owner."); + } + + + var orgUsers = new List(); + var limitedCollectionOrgUsers = new List<(OrganizationUser, IEnumerable)>(); + var orgUserInvitedCount = 0; + var exceptions = new List(); + var events = new List<(OrganizationUser, EventType, DateTime?)>(); + foreach (var (invite, externalId) in invites) + { + // Prevent duplicate invitations + foreach (var email in invite.Emails.Distinct()) { - var invite = new OrganizationUserInvite + try { - Emails = new List { user.Email }, - Type = OrganizationUserType.User, - AccessAll = false, - Collections = new List(), - }; - userInvites.Add((invite, user.ExternalId)); - } - catch (BadRequestException) - { - // Thrown when the user is already invited to the organization - continue; + // Make sure user is not already invited + if (existingEmails.Contains(email)) + { + continue; + } + + var orgUser = new OrganizationUser + { + OrganizationId = organizationId, + UserId = null, + Email = email.ToLowerInvariant(), + Key = null, + Type = invite.Type.Value, + Status = OrganizationUserStatusType.Invited, + AccessAll = invite.AccessAll, + ExternalId = externalId, + CreationDate = DateTime.UtcNow, + RevisionDate = DateTime.UtcNow, + }; + + if (invite.Permissions != null) + { + orgUser.Permissions = JsonSerializer.Serialize(invite.Permissions, JsonHelpers.CamelCase); + } + + if (!orgUser.AccessAll && invite.Collections.Any()) + { + limitedCollectionOrgUsers.Add((orgUser, invite.Collections)); + } + else + { + orgUsers.Add(orgUser); + } + + events.Add((orgUser, EventType.OrganizationUser_Invited, DateTime.UtcNow)); + orgUserInvitedCount++; + } + catch (Exception e) + { + exceptions.Add(e); + } } } - var invitedUsers = await InviteUsersAsync(organizationId, importingUserId, userInvites); - foreach (var invitedUser in invitedUsers) + if (exceptions.Any()) { - existingExternalUsersIdDict.Add(invitedUser.ExternalId, invitedUser.Id); + throw new AggregateException("One or more errors occurred while inviting users.", exceptions); } + + var prorationDate = DateTime.UtcNow; + try + { + await _organizationUserRepository.CreateManyAsync(orgUsers); + foreach (var (orgUser, collections) in limitedCollectionOrgUsers) + { + await _organizationUserRepository.CreateAsync(orgUser, collections); + } + + if (!await _currentContext.ManageUsers(organization.Id)) + { + throw new BadRequestException("Cannot add seats. Cannot manage organization users."); + } + + await AutoAddSeatsAsync(organization, newSeatsRequired, prorationDate); + await SendInvitesAsync(orgUsers.Concat(limitedCollectionOrgUsers.Select(u => u.Item1)), organization); + await _eventService.LogOrganizationUserEventsAsync(events); + + await _referenceEventService.RaiseEventAsync( + new ReferenceEvent(ReferenceEventType.InvitedUsers, organization) + { + Users = orgUserInvitedCount + }); + } + catch (Exception e) + { + // Revert any added users. + var invitedOrgUserIds = orgUsers.Select(u => u.Id).Concat(limitedCollectionOrgUsers.Select(u => u.Item1.Id)); + await _organizationUserRepository.DeleteManyAsync(invitedOrgUserIds); + var currentSeatCount = (await _organizationRepository.GetByIdAsync(organization.Id)).Seats; + + if (initialSeatCount.HasValue && currentSeatCount.HasValue && currentSeatCount.Value != initialSeatCount.Value) + { + await AdjustSeatsAsync(organization, initialSeatCount.Value - currentSeatCount.Value, prorationDate); + } + + exceptions.Add(e); + } + + if (exceptions.Any()) + { + throw new AggregateException("One or more errors occurred while inviting users.", exceptions); + } + + return orgUsers; } - - // Groups - if (groups?.Any() ?? false) + public async Task>> ResendInvitesAsync(Guid organizationId, Guid? invitingUserId, + IEnumerable organizationUsersId) { - if (!organization.UseGroups) + var orgUsers = await _organizationUserRepository.GetManyAsync(organizationUsersId); + var org = await GetOrgById(organizationId); + + var result = new List>(); + foreach (var orgUser in orgUsers) { - throw new BadRequestException("Organization cannot use groups."); + if (orgUser.Status != OrganizationUserStatusType.Invited || orgUser.OrganizationId != organizationId) + { + result.Add(Tuple.Create(orgUser, "User invalid.")); + continue; + } + + await SendInviteAsync(orgUser, org); + result.Add(Tuple.Create(orgUser, "")); } - var groupsDict = groups.ToDictionary(g => g.Group.ExternalId); - var existingGroups = await _groupRepository.GetManyByOrganizationIdAsync(organizationId); - var existingExternalGroups = existingGroups - .Where(u => !string.IsNullOrWhiteSpace(u.ExternalId)).ToList(); - var existingExternalGroupsDict = existingExternalGroups.ToDictionary(g => g.ExternalId); + return result; + } - var newGroups = groups - .Where(g => !existingExternalGroupsDict.ContainsKey(g.Group.ExternalId)) - .Select(g => g.Group); - - foreach (var group in newGroups) + public async Task ResendInviteAsync(Guid organizationId, Guid? invitingUserId, Guid organizationUserId) + { + var orgUser = await _organizationUserRepository.GetByIdAsync(organizationUserId); + if (orgUser == null || orgUser.OrganizationId != organizationId || + orgUser.Status != OrganizationUserStatusType.Invited) { - group.CreationDate = group.RevisionDate = DateTime.UtcNow; - - await _groupRepository.CreateAsync(group); - await UpdateUsersAsync(group, groupsDict[group.ExternalId].ExternalUserIds, - existingExternalUsersIdDict); + throw new BadRequestException("User invalid."); } - var updateGroups = existingExternalGroups - .Where(g => groupsDict.ContainsKey(g.ExternalId)) + var org = await GetOrgById(orgUser.OrganizationId); + await SendInviteAsync(orgUser, org); + } + + private async Task SendInvitesAsync(IEnumerable orgUsers, Organization organization) + { + string MakeToken(OrganizationUser orgUser) => + _dataProtector.Protect($"OrganizationUserInvite {orgUser.Id} {orgUser.Email} {CoreHelpers.ToEpocMilliseconds(DateTime.UtcNow)}"); + + await _mailService.BulkSendOrganizationInviteEmailAsync(organization.Name, + orgUsers.Select(o => (o, new ExpiringToken(MakeToken(o), DateTime.UtcNow.AddDays(5))))); + } + + private async Task SendInviteAsync(OrganizationUser orgUser, Organization organization) + { + var now = DateTime.UtcNow; + var nowMillis = CoreHelpers.ToEpocMilliseconds(now); + var token = _dataProtector.Protect( + $"OrganizationUserInvite {orgUser.Id} {orgUser.Email} {nowMillis}"); + + await _mailService.SendOrganizationInviteEmailAsync(organization.Name, orgUser, new ExpiringToken(token, now.AddDays(5))); + } + + public async Task AcceptUserAsync(Guid organizationUserId, User user, string token, + IUserService userService) + { + var orgUser = await _organizationUserRepository.GetByIdAsync(organizationUserId); + if (orgUser == null) + { + throw new BadRequestException("User invalid."); + } + + if (!CoreHelpers.UserInviteTokenIsValid(_dataProtector, token, user.Email, orgUser.Id, _globalSettings)) + { + throw new BadRequestException("Invalid token."); + } + + var existingOrgUserCount = await _organizationUserRepository.GetCountByOrganizationAsync( + orgUser.OrganizationId, user.Email, true); + if (existingOrgUserCount > 0) + { + if (orgUser.Status == OrganizationUserStatusType.Accepted) + { + throw new BadRequestException("Invitation already accepted. You will receive an email when your organization membership is confirmed."); + } + throw new BadRequestException("You are already part of this organization."); + } + + if (string.IsNullOrWhiteSpace(orgUser.Email) || + !orgUser.Email.Equals(user.Email, StringComparison.InvariantCultureIgnoreCase)) + { + throw new BadRequestException("User email does not match invite."); + } + + return await AcceptUserAsync(orgUser, user, userService); + } + + public async Task AcceptUserAsync(string orgIdentifier, User user, IUserService userService) + { + var org = await _organizationRepository.GetByIdentifierAsync(orgIdentifier); + if (org == null) + { + throw new BadRequestException("Organization invalid."); + } + + var usersOrgs = await _organizationUserRepository.GetManyByUserAsync(user.Id); + var orgUser = usersOrgs.FirstOrDefault(u => u.OrganizationId == org.Id); + if (orgUser == null) + { + throw new BadRequestException("User not found within organization."); + } + + return await AcceptUserAsync(orgUser, user, userService); + } + + private async Task AcceptUserAsync(OrganizationUser orgUser, User user, + IUserService userService) + { + if (orgUser.Status != OrganizationUserStatusType.Invited) + { + throw new BadRequestException("Already accepted."); + } + + if (orgUser.Type == OrganizationUserType.Owner || orgUser.Type == OrganizationUserType.Admin) + { + var org = await GetOrgById(orgUser.OrganizationId); + if (org.PlanType == PlanType.Free) + { + var adminCount = await _organizationUserRepository.GetCountByFreeOrganizationAdminUserAsync( + user.Id); + if (adminCount > 0) + { + throw new BadRequestException("You can only be an admin of one free organization."); + } + } + } + + // Enforce Single Organization Policy of organization user is trying to join + var allOrgUsers = await _organizationUserRepository.GetManyByUserAsync(user.Id); + var hasOtherOrgs = allOrgUsers.Any(ou => ou.OrganizationId != orgUser.OrganizationId); + var invitedSingleOrgPolicies = await _policyRepository.GetManyByTypeApplicableToUserIdAsync(user.Id, + PolicyType.SingleOrg, OrganizationUserStatusType.Invited); + + if (hasOtherOrgs && invitedSingleOrgPolicies.Any(p => p.OrganizationId == orgUser.OrganizationId)) + { + throw new BadRequestException("You may not join this organization until you leave or remove " + + "all other organizations."); + } + + // Enforce Single Organization Policy of other organizations user is a member of + var singleOrgPolicyCount = await _policyRepository.GetCountByTypeApplicableToUserIdAsync(user.Id, + PolicyType.SingleOrg); + if (singleOrgPolicyCount > 0) + { + throw new BadRequestException("You cannot join this organization because you are a member of " + + "another organization which forbids it"); + } + + // Enforce Two Factor Authentication Policy of organization user is trying to join + if (!await userService.TwoFactorIsEnabledAsync(user)) + { + var invitedTwoFactorPolicies = await _policyRepository.GetManyByTypeApplicableToUserIdAsync(user.Id, + PolicyType.TwoFactorAuthentication, OrganizationUserStatusType.Invited); + if (invitedTwoFactorPolicies.Any(p => p.OrganizationId == orgUser.OrganizationId)) + { + throw new BadRequestException("You cannot join this organization until you enable " + + "two-step login on your user account."); + } + } + + orgUser.Status = OrganizationUserStatusType.Accepted; + orgUser.UserId = user.Id; + orgUser.Email = null; + + await _organizationUserRepository.ReplaceAsync(orgUser); + + var admins = await _organizationUserRepository.GetManyByMinimumRoleAsync(orgUser.OrganizationId, OrganizationUserType.Admin); + var adminEmails = admins.Select(a => a.Email).Distinct().ToList(); + + if (adminEmails.Count > 0) + { + var organization = await _organizationRepository.GetByIdAsync(orgUser.OrganizationId); + await _mailService.SendOrganizationAcceptedEmailAsync(organization, user.Email, adminEmails); + } + + return orgUser; + } + + public async Task ConfirmUserAsync(Guid organizationId, Guid organizationUserId, string key, + Guid confirmingUserId, IUserService userService) + { + var result = await ConfirmUsersAsync(organizationId, new Dictionary() { { organizationUserId, key } }, + confirmingUserId, userService); + + if (!result.Any()) + { + throw new BadRequestException("User not valid."); + } + + var (orgUser, error) = result[0]; + if (error != "") + { + throw new BadRequestException(error); + } + return orgUser; + } + + public async Task>> ConfirmUsersAsync(Guid organizationId, Dictionary keys, + Guid confirmingUserId, IUserService userService) + { + var organizationUsers = await _organizationUserRepository.GetManyAsync(keys.Keys); + var validOrganizationUsers = organizationUsers + .Where(u => u.Status == OrganizationUserStatusType.Accepted && u.OrganizationId == organizationId && u.UserId != null) .ToList(); - if (updateGroups.Any()) + if (!validOrganizationUsers.Any()) { - var groupUsers = await _groupRepository.GetManyGroupUsersByOrganizationIdAsync(organizationId); - var existingGroupUsers = groupUsers - .GroupBy(gu => gu.GroupId) - .ToDictionary(g => g.Key, g => new HashSet(g.Select(gr => gr.OrganizationUserId))); + return new List>(); + } - foreach (var group in updateGroups) + var validOrganizationUserIds = validOrganizationUsers.Select(u => u.UserId.Value).ToList(); + + var organization = await GetOrgById(organizationId); + var policies = await _policyRepository.GetManyByOrganizationIdAsync(organizationId); + var usersOrgs = await _organizationUserRepository.GetManyByManyUsersAsync(validOrganizationUserIds); + var users = await _userRepository.GetManyAsync(validOrganizationUserIds); + + var keyedFilteredUsers = validOrganizationUsers.ToDictionary(u => u.UserId.Value, u => u); + var keyedOrganizationUsers = usersOrgs.GroupBy(u => u.UserId.Value) + .ToDictionary(u => u.Key, u => u.ToList()); + + var succeededUsers = new List(); + var result = new List>(); + + foreach (var user in users) + { + if (!keyedFilteredUsers.ContainsKey(user.Id)) { - var updatedGroup = groupsDict[group.ExternalId].Group; - if (group.Name != updatedGroup.Name) + continue; + } + var orgUser = keyedFilteredUsers[user.Id]; + var orgUsers = keyedOrganizationUsers.GetValueOrDefault(user.Id, new List()); + try + { + if (organization.PlanType == PlanType.Free && (orgUser.Type == OrganizationUserType.Admin + || orgUser.Type == OrganizationUserType.Owner)) { - group.RevisionDate = DateTime.UtcNow; - group.Name = updatedGroup.Name; - - await _groupRepository.ReplaceAsync(group); + // Since free organizations only supports a few users there is not much point in avoiding N+1 queries for this. + var adminCount = await _organizationUserRepository.GetCountByFreeOrganizationAdminUserAsync(user.Id); + if (adminCount > 0) + { + throw new BadRequestException("User can only be an admin of one free organization."); + } } + await CheckPolicies(policies, organizationId, user, orgUsers, userService); + orgUser.Status = OrganizationUserStatusType.Confirmed; + orgUser.Key = keys[orgUser.Id]; + orgUser.Email = null; + + await _eventService.LogOrganizationUserEventAsync(orgUser, EventType.OrganizationUser_Confirmed); + await _mailService.SendOrganizationConfirmedEmailAsync(organization.Name, user.Email); + await DeleteAndPushUserRegistrationAsync(organizationId, user.Id); + succeededUsers.Add(orgUser); + result.Add(Tuple.Create(orgUser, "")); + } + catch (BadRequestException e) + { + result.Add(Tuple.Create(orgUser, e.Message)); + } + } + + await _organizationUserRepository.ReplaceManyAsync(succeededUsers); + + return result; + } + + internal (bool canScale, string failureReason) CanScale(Organization organization, + int seatsToAdd) + { + var failureReason = ""; + if (_globalSettings.SelfHosted) + { + failureReason = "Cannot autoscale on self-hosted instance."; + return (false, failureReason); + } + + if (seatsToAdd < 1) + { + return (true, failureReason); + } + + if (organization.Seats.HasValue && + organization.MaxAutoscaleSeats.HasValue && + organization.MaxAutoscaleSeats.Value < organization.Seats.Value + seatsToAdd) + { + return (false, $"Cannot invite new users. Seat limit has been reached."); + } + + return (true, failureReason); + } + + public async Task AutoAddSeatsAsync(Organization organization, int seatsToAdd, DateTime? prorationDate = null) + { + if (seatsToAdd < 1 || !organization.Seats.HasValue) + { + return; + } + + var (canScale, failureMessage) = CanScale(organization, seatsToAdd); + if (!canScale) + { + throw new BadRequestException(failureMessage); + } + + var ownerEmails = (await _organizationUserRepository.GetManyByMinimumRoleAsync(organization.Id, + OrganizationUserType.Owner)).Select(u => u.Email).Distinct(); + var initialSeatCount = organization.Seats.Value; + + await AdjustSeatsAsync(organization, seatsToAdd, prorationDate, ownerEmails); + + if (!organization.OwnersNotifiedOfAutoscaling.HasValue) + { + await _mailService.SendOrganizationAutoscaledEmailAsync(organization, initialSeatCount, + ownerEmails); + organization.OwnersNotifiedOfAutoscaling = DateTime.UtcNow; + await _organizationRepository.UpsertAsync(organization); + } + } + + private async Task CheckPolicies(ICollection policies, Guid organizationId, User user, + ICollection userOrgs, IUserService userService) + { + var usingTwoFactorPolicy = policies.Any(p => p.Type == PolicyType.TwoFactorAuthentication && p.Enabled); + if (usingTwoFactorPolicy && !await userService.TwoFactorIsEnabledAsync(user)) + { + throw new BadRequestException("User does not have two-step login enabled."); + } + + var usingSingleOrgPolicy = policies.Any(p => p.Type == PolicyType.SingleOrg && p.Enabled); + if (usingSingleOrgPolicy) + { + if (userOrgs.Any(ou => ou.OrganizationId != organizationId && ou.Status != OrganizationUserStatusType.Invited)) + { + throw new BadRequestException("User is a member of another organization."); + } + } + } + + public async Task SaveUserAsync(OrganizationUser user, Guid? savingUserId, + IEnumerable collections) + { + if (user.Id.Equals(default(Guid))) + { + throw new BadRequestException("Invite the user first."); + } + + var originalUser = await _organizationUserRepository.GetByIdAsync(user.Id); + if (user.Equals(originalUser)) + { + throw new BadRequestException("Please make changes before saving."); + } + + if (savingUserId.HasValue) + { + await ValidateOrganizationUserUpdatePermissions(user.OrganizationId, user.Type, originalUser.Type); + } + + if (user.Type != OrganizationUserType.Owner && + !await HasConfirmedOwnersExceptAsync(user.OrganizationId, new[] { user.Id })) + { + throw new BadRequestException("Organization must have at least one confirmed owner."); + } + + if (user.AccessAll) + { + // We don't need any collections if we're flagged to have all access. + collections = new List(); + } + await _organizationUserRepository.ReplaceAsync(user, collections); + await _eventService.LogOrganizationUserEventAsync(user, EventType.OrganizationUser_Updated); + } + + public async Task DeleteUserAsync(Guid organizationId, Guid organizationUserId, Guid? deletingUserId) + { + var orgUser = await _organizationUserRepository.GetByIdAsync(organizationUserId); + if (orgUser == null || orgUser.OrganizationId != organizationId) + { + throw new BadRequestException("User not valid."); + } + + if (deletingUserId.HasValue && orgUser.UserId == deletingUserId.Value) + { + throw new BadRequestException("You cannot remove yourself."); + } + + if (orgUser.Type == OrganizationUserType.Owner && deletingUserId.HasValue && + !await _currentContext.OrganizationOwner(organizationId)) + { + throw new BadRequestException("Only owners can delete other owners."); + } + + if (!await HasConfirmedOwnersExceptAsync(organizationId, new[] { organizationUserId })) + { + throw new BadRequestException("Organization must have at least one confirmed owner."); + } + + await _organizationUserRepository.DeleteAsync(orgUser); + await _eventService.LogOrganizationUserEventAsync(orgUser, EventType.OrganizationUser_Removed); + + if (orgUser.UserId.HasValue) + { + await DeleteAndPushUserRegistrationAsync(organizationId, orgUser.UserId.Value); + } + } + + public async Task DeleteUserAsync(Guid organizationId, Guid userId) + { + var orgUser = await _organizationUserRepository.GetByOrganizationAsync(organizationId, userId); + if (orgUser == null) + { + throw new NotFoundException(); + } + + if (!await HasConfirmedOwnersExceptAsync(organizationId, new[] { orgUser.Id })) + { + throw new BadRequestException("Organization must have at least one confirmed owner."); + } + + await _organizationUserRepository.DeleteAsync(orgUser); + await _eventService.LogOrganizationUserEventAsync(orgUser, EventType.OrganizationUser_Removed); + + if (orgUser.UserId.HasValue) + { + await DeleteAndPushUserRegistrationAsync(organizationId, orgUser.UserId.Value); + } + } + + public async Task>> DeleteUsersAsync(Guid organizationId, + IEnumerable organizationUsersId, + Guid? deletingUserId) + { + var orgUsers = await _organizationUserRepository.GetManyAsync(organizationUsersId); + var filteredUsers = orgUsers.Where(u => u.OrganizationId == organizationId) + .ToList(); + + if (!filteredUsers.Any()) + { + throw new BadRequestException("Users invalid."); + } + + if (!await HasConfirmedOwnersExceptAsync(organizationId, organizationUsersId)) + { + throw new BadRequestException("Organization must have at least one confirmed owner."); + } + + var deletingUserIsOwner = false; + if (deletingUserId.HasValue) + { + deletingUserIsOwner = await _currentContext.OrganizationOwner(organizationId); + } + + var result = new List>(); + var deletedUserIds = new List(); + foreach (var orgUser in filteredUsers) + { + try + { + if (deletingUserId.HasValue && orgUser.UserId == deletingUserId) + { + throw new BadRequestException("You cannot remove yourself."); + } + + if (orgUser.Type == OrganizationUserType.Owner && deletingUserId.HasValue && !deletingUserIsOwner) + { + throw new BadRequestException("Only owners can delete other owners."); + } + + await _eventService.LogOrganizationUserEventAsync(orgUser, EventType.OrganizationUser_Removed); + + if (orgUser.UserId.HasValue) + { + await DeleteAndPushUserRegistrationAsync(organizationId, orgUser.UserId.Value); + } + result.Add(Tuple.Create(orgUser, "")); + deletedUserIds.Add(orgUser.Id); + } + catch (BadRequestException e) + { + result.Add(Tuple.Create(orgUser, e.Message)); + } + + await _organizationUserRepository.DeleteManyAsync(deletedUserIds); + } + + return result; + } + + public async Task HasConfirmedOwnersExceptAsync(Guid organizationId, IEnumerable organizationUsersId, bool includeProvider = true) + { + var confirmedOwners = await GetConfirmedOwnersAsync(organizationId); + var confirmedOwnersIds = confirmedOwners.Select(u => u.Id); + bool hasOtherOwner = confirmedOwnersIds.Except(organizationUsersId).Any(); + if (!hasOtherOwner && includeProvider) + { + return (await _currentContext.ProviderIdForOrg(organizationId)).HasValue; + } + return hasOtherOwner; + } + + public async Task UpdateUserGroupsAsync(OrganizationUser organizationUser, IEnumerable groupIds, Guid? loggedInUserId) + { + if (loggedInUserId.HasValue) + { + await ValidateOrganizationUserUpdatePermissions(organizationUser.OrganizationId, organizationUser.Type, null); + } + await _organizationUserRepository.UpdateGroupsAsync(organizationUser.Id, groupIds); + await _eventService.LogOrganizationUserEventAsync(organizationUser, + EventType.OrganizationUser_UpdatedGroups); + } + + public async Task UpdateUserResetPasswordEnrollmentAsync(Guid organizationId, Guid userId, string resetPasswordKey, Guid? callingUserId) + { + // Org User must be the same as the calling user and the organization ID associated with the user must match passed org ID + var orgUser = await _organizationUserRepository.GetByOrganizationAsync(organizationId, userId); + if (!callingUserId.HasValue || orgUser == null || orgUser.UserId != callingUserId.Value || + orgUser.OrganizationId != organizationId) + { + throw new BadRequestException("User not valid."); + } + + // Make sure the organization has the ability to use password reset + var org = await _organizationRepository.GetByIdAsync(organizationId); + if (org == null || !org.UseResetPassword) + { + throw new BadRequestException("Organization does not allow password reset enrollment."); + } + + // Make sure the organization has the policy enabled + var resetPasswordPolicy = + await _policyRepository.GetByOrganizationIdTypeAsync(organizationId, PolicyType.ResetPassword); + if (resetPasswordPolicy == null || !resetPasswordPolicy.Enabled) + { + throw new BadRequestException("Organization does not have the password reset policy enabled."); + } + + // Block the user from withdrawal if auto enrollment is enabled + if (resetPasswordKey == null && resetPasswordPolicy.Data != null) + { + var data = JsonSerializer.Deserialize(resetPasswordPolicy.Data, JsonHelpers.IgnoreCase); + + if (data?.AutoEnrollEnabled ?? false) + { + throw new BadRequestException("Due to an Enterprise Policy, you are not allowed to withdraw from Password Reset."); + } + } + + orgUser.ResetPasswordKey = resetPasswordKey; + await _organizationUserRepository.ReplaceAsync(orgUser); + await _eventService.LogOrganizationUserEventAsync(orgUser, resetPasswordKey != null ? + EventType.OrganizationUser_ResetPassword_Enroll : EventType.OrganizationUser_ResetPassword_Withdraw); + } + + public async Task GenerateLicenseAsync(Guid organizationId, Guid installationId) + { + var organization = await GetOrgById(organizationId); + return await GenerateLicenseAsync(organization, installationId); + } + + public async Task GenerateLicenseAsync(Organization organization, Guid installationId, + int? version = null) + { + if (organization == null) + { + throw new NotFoundException(); + } + + var installation = await _installationRepository.GetByIdAsync(installationId); + if (installation == null || !installation.Enabled) + { + throw new BadRequestException("Invalid installation id"); + } + + var subInfo = await _paymentService.GetSubscriptionAsync(organization); + return new OrganizationLicense(organization, subInfo, installationId, _licensingService, version); + } + + public async Task InviteUserAsync(Guid organizationId, Guid? invitingUserId, string email, + OrganizationUserType type, bool accessAll, string externalId, IEnumerable collections) + { + var invite = new OrganizationUserInvite() + { + Emails = new List { email }, + Type = type, + AccessAll = accessAll, + Collections = collections, + }; + var results = await InviteUsersAsync(organizationId, invitingUserId, + new (OrganizationUserInvite, string)[] { (invite, externalId) }); + var result = results.FirstOrDefault(); + if (result == null) + { + throw new BadRequestException("This user has already been invited."); + } + return result; + } + + public async Task ImportAsync(Guid organizationId, + Guid? importingUserId, + IEnumerable groups, + IEnumerable newUsers, + IEnumerable removeUserExternalIds, + bool overwriteExisting) + { + var organization = await GetOrgById(organizationId); + if (organization == null) + { + throw new NotFoundException(); + } + + if (!organization.UseDirectory) + { + throw new BadRequestException("Organization cannot use directory syncing."); + } + + var newUsersSet = new HashSet(newUsers?.Select(u => u.ExternalId) ?? new List()); + var existingUsers = await _organizationUserRepository.GetManyDetailsByOrganizationAsync(organizationId); + var existingExternalUsers = existingUsers.Where(u => !string.IsNullOrWhiteSpace(u.ExternalId)).ToList(); + var existingExternalUsersIdDict = existingExternalUsers.ToDictionary(u => u.ExternalId, u => u.Id); + + // Users + + // Remove Users + if (removeUserExternalIds?.Any() ?? false) + { + var removeUsersSet = new HashSet(removeUserExternalIds); + var existingUsersDict = existingExternalUsers.ToDictionary(u => u.ExternalId); + + await _organizationUserRepository.DeleteManyAsync(removeUsersSet + .Except(newUsersSet) + .Where(u => existingUsersDict.ContainsKey(u) && existingUsersDict[u].Type != OrganizationUserType.Owner) + .Select(u => existingUsersDict[u].Id)); + } + + if (overwriteExisting) + { + // Remove existing external users that are not in new user set + var usersToDelete = existingExternalUsers.Where(u => + u.Type != OrganizationUserType.Owner && + !newUsersSet.Contains(u.ExternalId) && + existingExternalUsersIdDict.ContainsKey(u.ExternalId)); + await _organizationUserRepository.DeleteManyAsync(usersToDelete.Select(u => u.Id)); + foreach (var deletedUser in usersToDelete) + { + existingExternalUsersIdDict.Remove(deletedUser.ExternalId); + } + } + + if (newUsers?.Any() ?? false) + { + // Marry existing users + var existingUsersEmailsDict = existingUsers + .Where(u => string.IsNullOrWhiteSpace(u.ExternalId)) + .ToDictionary(u => u.Email); + var newUsersEmailsDict = newUsers.ToDictionary(u => u.Email); + var usersToAttach = existingUsersEmailsDict.Keys.Intersect(newUsersEmailsDict.Keys).ToList(); + var usersToUpsert = new List(); + foreach (var user in usersToAttach) + { + var orgUserDetails = existingUsersEmailsDict[user]; + var orgUser = await _organizationUserRepository.GetByIdAsync(orgUserDetails.Id); + if (orgUser != null) + { + orgUser.ExternalId = newUsersEmailsDict[user].ExternalId; + usersToUpsert.Add(orgUser); + existingExternalUsersIdDict.Add(orgUser.ExternalId, orgUser.Id); + } + } + await _organizationUserRepository.UpsertManyAsync(usersToUpsert); + + // Add new users + var existingUsersSet = new HashSet(existingExternalUsersIdDict.Keys); + var usersToAdd = newUsersSet.Except(existingUsersSet).ToList(); + + var seatsAvailable = int.MaxValue; + var enoughSeatsAvailable = true; + if (organization.Seats.HasValue) + { + var userCount = await _organizationUserRepository.GetCountByOrganizationIdAsync(organizationId); + seatsAvailable = organization.Seats.Value - userCount; + enoughSeatsAvailable = seatsAvailable >= usersToAdd.Count; + } + + var userInvites = new List<(OrganizationUserInvite, string)>(); + foreach (var user in newUsers) + { + if (!usersToAdd.Contains(user.ExternalId) || string.IsNullOrWhiteSpace(user.Email)) + { + continue; + } + + try + { + var invite = new OrganizationUserInvite + { + Emails = new List { user.Email }, + Type = OrganizationUserType.User, + AccessAll = false, + Collections = new List(), + }; + userInvites.Add((invite, user.ExternalId)); + } + catch (BadRequestException) + { + // Thrown when the user is already invited to the organization + continue; + } + } + + var invitedUsers = await InviteUsersAsync(organizationId, importingUserId, userInvites); + foreach (var invitedUser in invitedUsers) + { + existingExternalUsersIdDict.Add(invitedUser.ExternalId, invitedUser.Id); + } + } + + + // Groups + if (groups?.Any() ?? false) + { + if (!organization.UseGroups) + { + throw new BadRequestException("Organization cannot use groups."); + } + + var groupsDict = groups.ToDictionary(g => g.Group.ExternalId); + var existingGroups = await _groupRepository.GetManyByOrganizationIdAsync(organizationId); + var existingExternalGroups = existingGroups + .Where(u => !string.IsNullOrWhiteSpace(u.ExternalId)).ToList(); + var existingExternalGroupsDict = existingExternalGroups.ToDictionary(g => g.ExternalId); + + var newGroups = groups + .Where(g => !existingExternalGroupsDict.ContainsKey(g.Group.ExternalId)) + .Select(g => g.Group); + + foreach (var group in newGroups) + { + group.CreationDate = group.RevisionDate = DateTime.UtcNow; + + await _groupRepository.CreateAsync(group); await UpdateUsersAsync(group, groupsDict[group.ExternalId].ExternalUserIds, - existingExternalUsersIdDict, - existingGroupUsers.ContainsKey(group.Id) ? existingGroupUsers[group.Id] : null); + existingExternalUsersIdDict); } - } - } - await _referenceEventService.RaiseEventAsync( - new ReferenceEvent(ReferenceEventType.DirectorySynced, organization)); - } + var updateGroups = existingExternalGroups + .Where(g => groupsDict.ContainsKey(g.ExternalId)) + .ToList(); - public async Task DeleteSsoUserAsync(Guid userId, Guid? organizationId) - { - await _ssoUserRepository.DeleteAsync(userId, organizationId); - if (organizationId.HasValue) - { - var organizationUser = await _organizationUserRepository.GetByOrganizationAsync(organizationId.Value, userId); - if (organizationUser != null) - { - await _eventService.LogOrganizationUserEventAsync(organizationUser, EventType.OrganizationUser_UnlinkedSso); - } - } - } - - public async Task UpdateOrganizationKeysAsync(Guid orgId, string publicKey, string privateKey) - { - if (!await _currentContext.ManageResetPassword(orgId)) - { - throw new UnauthorizedAccessException(); - } - - // If the keys already exist, error out - var org = await _organizationRepository.GetByIdAsync(orgId); - if (org.PublicKey != null && org.PrivateKey != null) - { - throw new BadRequestException("Organization Keys already exist"); - } - - // Update org with generated public/private key - org.PublicKey = publicKey; - org.PrivateKey = privateKey; - await UpdateAsync(org); - - return org; - } - - private async Task UpdateUsersAsync(Group group, HashSet groupUsers, - Dictionary existingUsersIdDict, HashSet existingUsers = null) - { - var availableUsers = groupUsers.Intersect(existingUsersIdDict.Keys); - var users = new HashSet(availableUsers.Select(u => existingUsersIdDict[u])); - if (existingUsers != null && existingUsers.Count == users.Count && users.SetEquals(existingUsers)) - { - return; - } - - await _groupRepository.UpdateUsersAsync(group.Id, users); - } - - private async Task> GetConfirmedOwnersAsync(Guid organizationId) - { - var owners = await _organizationUserRepository.GetManyByOrganizationAsync(organizationId, - OrganizationUserType.Owner); - return owners.Where(o => o.Status == OrganizationUserStatusType.Confirmed); - } - - private async Task DeleteAndPushUserRegistrationAsync(Guid organizationId, Guid userId) - { - var deviceIds = await GetUserDeviceIdsAsync(userId); - await _pushRegistrationService.DeleteUserRegistrationOrganizationAsync(deviceIds, - organizationId.ToString()); - await _pushNotificationService.PushSyncOrgKeysAsync(userId); - } - - - private async Task> GetUserDeviceIdsAsync(Guid userId) - { - var devices = await _deviceRepository.GetManyByUserIdAsync(userId); - return devices.Where(d => !string.IsNullOrWhiteSpace(d.PushToken)).Select(d => d.Id.ToString()); - } - - private async Task ReplaceAndUpdateCache(Organization org, EventType? orgEvent = null) - { - await _organizationRepository.ReplaceAsync(org); - await _applicationCacheService.UpsertOrganizationAbilityAsync(org); - - if (orgEvent.HasValue) - { - await _eventService.LogOrganizationEventAsync(org, orgEvent.Value); - } - } - - private async Task GetOrgById(Guid id) - { - return await _organizationRepository.GetByIdAsync(id); - } - - private void ValidateOrganizationUpgradeParameters(Models.StaticStore.Plan plan, OrganizationUpgrade upgrade) - { - if (!plan.HasAdditionalStorageOption && upgrade.AdditionalStorageGb > 0) - { - throw new BadRequestException("Plan does not allow additional storage."); - } - - if (upgrade.AdditionalStorageGb < 0) - { - throw new BadRequestException("You can't subtract storage!"); - } - - if (!plan.HasPremiumAccessOption && upgrade.PremiumAccessAddon) - { - throw new BadRequestException("This plan does not allow you to buy the premium access addon."); - } - - if (plan.BaseSeats + upgrade.AdditionalSeats <= 0) - { - throw new BadRequestException("You do not have any seats!"); - } - - if (upgrade.AdditionalSeats < 0) - { - throw new BadRequestException("You can't subtract seats!"); - } - - if (!plan.HasAdditionalSeatsOption && upgrade.AdditionalSeats > 0) - { - throw new BadRequestException("Plan does not allow additional users."); - } - - if (plan.HasAdditionalSeatsOption && plan.MaxAdditionalSeats.HasValue && - upgrade.AdditionalSeats > plan.MaxAdditionalSeats.Value) - { - throw new BadRequestException($"Selected plan allows a maximum of " + - $"{plan.MaxAdditionalSeats.GetValueOrDefault(0)} additional users."); - } - } - - private async Task ValidateOrganizationUserUpdatePermissions(Guid organizationId, OrganizationUserType newType, - OrganizationUserType? oldType) - { - if (await _currentContext.OrganizationOwner(organizationId)) - { - return; - } - - if (oldType == OrganizationUserType.Owner || newType == OrganizationUserType.Owner) - { - throw new BadRequestException("Only an Owner can configure another Owner's account."); - } - - if (await _currentContext.OrganizationAdmin(organizationId)) - { - return; - } - - if (oldType == OrganizationUserType.Custom || newType == OrganizationUserType.Custom) - { - throw new BadRequestException("Only Owners and Admins can configure Custom accounts."); - } - - if (!await _currentContext.ManageUsers(organizationId)) - { - throw new BadRequestException("Your account does not have permission to manage users."); - } - - if (oldType == OrganizationUserType.Admin || newType == OrganizationUserType.Admin) - { - throw new BadRequestException("Custom users can not manage Admins or Owners."); - } - } - - private async Task ValidateDeleteOrganizationAsync(Organization organization) - { - var ssoConfig = await _ssoConfigRepository.GetByOrganizationIdAsync(organization.Id); - if (ssoConfig?.GetData()?.KeyConnectorEnabled == true) - { - throw new BadRequestException("You cannot delete an Organization that is using Key Connector."); - } - } - - public async Task RevokeUserAsync(OrganizationUser organizationUser, Guid? revokingUserId) - { - if (organizationUser.Status == OrganizationUserStatusType.Revoked) - { - throw new BadRequestException("Already revoked."); - } - - if (revokingUserId.HasValue && organizationUser.UserId == revokingUserId.Value) - { - throw new BadRequestException("You cannot revoke yourself."); - } - - if (organizationUser.Type == OrganizationUserType.Owner && revokingUserId.HasValue && - !await _currentContext.OrganizationOwner(organizationUser.OrganizationId)) - { - throw new BadRequestException("Only owners can revoke other owners."); - } - - if (!await HasConfirmedOwnersExceptAsync(organizationUser.OrganizationId, new[] { organizationUser.Id })) - { - throw new BadRequestException("Organization must have at least one confirmed owner."); - } - - await _organizationUserRepository.RevokeAsync(organizationUser.Id); - organizationUser.Status = OrganizationUserStatusType.Revoked; - await _eventService.LogOrganizationUserEventAsync(organizationUser, EventType.OrganizationUser_Revoked); - } - - public async Task>> RevokeUsersAsync(Guid organizationId, - IEnumerable organizationUserIds, Guid? revokingUserId) - { - var orgUsers = await _organizationUserRepository.GetManyAsync(organizationUserIds); - var filteredUsers = orgUsers.Where(u => u.OrganizationId == organizationId) - .ToList(); - - if (!filteredUsers.Any()) - { - throw new BadRequestException("Users invalid."); - } - - if (!await HasConfirmedOwnersExceptAsync(organizationId, organizationUserIds)) - { - throw new BadRequestException("Organization must have at least one confirmed owner."); - } - - var deletingUserIsOwner = false; - if (revokingUserId.HasValue) - { - deletingUserIsOwner = await _currentContext.OrganizationOwner(organizationId); - } - - var result = new List>(); - - foreach (var organizationUser in filteredUsers) - { - try - { - if (organizationUser.Status == OrganizationUserStatusType.Revoked) + if (updateGroups.Any()) { - throw new BadRequestException("Already revoked."); + var groupUsers = await _groupRepository.GetManyGroupUsersByOrganizationIdAsync(organizationId); + var existingGroupUsers = groupUsers + .GroupBy(gu => gu.GroupId) + .ToDictionary(g => g.Key, g => new HashSet(g.Select(gr => gr.OrganizationUserId))); + + foreach (var group in updateGroups) + { + var updatedGroup = groupsDict[group.ExternalId].Group; + if (group.Name != updatedGroup.Name) + { + group.RevisionDate = DateTime.UtcNow; + group.Name = updatedGroup.Name; + + await _groupRepository.ReplaceAsync(group); + } + + await UpdateUsersAsync(group, groupsDict[group.ExternalId].ExternalUserIds, + existingExternalUsersIdDict, + existingGroupUsers.ContainsKey(group.Id) ? existingGroupUsers[group.Id] : null); + } } - - if (revokingUserId.HasValue && organizationUser.UserId == revokingUserId) - { - throw new BadRequestException("You cannot revoke yourself."); - } - - if (organizationUser.Type == OrganizationUserType.Owner && revokingUserId.HasValue && !deletingUserIsOwner) - { - throw new BadRequestException("Only owners can revoke other owners."); - } - - await _organizationUserRepository.RevokeAsync(organizationUser.Id); - organizationUser.Status = OrganizationUserStatusType.Revoked; - await _eventService.LogOrganizationUserEventAsync(organizationUser, EventType.OrganizationUser_Revoked); - - result.Add(Tuple.Create(organizationUser, "")); } - catch (BadRequestException e) + + await _referenceEventService.RaiseEventAsync( + new ReferenceEvent(ReferenceEventType.DirectorySynced, organization)); + } + + public async Task DeleteSsoUserAsync(Guid userId, Guid? organizationId) + { + await _ssoUserRepository.DeleteAsync(userId, organizationId); + if (organizationId.HasValue) { - result.Add(Tuple.Create(organizationUser, e.Message)); - } - } - - return result; - } - - public async Task RestoreUserAsync(OrganizationUser organizationUser, Guid? restoringUserId, IUserService userService) - { - if (organizationUser.Status != OrganizationUserStatusType.Revoked) - { - throw new BadRequestException("Already active."); - } - - if (restoringUserId.HasValue && organizationUser.UserId == restoringUserId.Value) - { - throw new BadRequestException("You cannot restore yourself."); - } - - if (organizationUser.Type == OrganizationUserType.Owner && restoringUserId.HasValue && - !await _currentContext.OrganizationOwner(organizationUser.OrganizationId)) - { - throw new BadRequestException("Only owners can restore other owners."); - } - - await CheckPoliciesBeforeRestoreAsync(organizationUser, userService); - - var status = GetPriorActiveOrganizationUserStatusType(organizationUser); - - await _organizationUserRepository.RestoreAsync(organizationUser.Id, status); - organizationUser.Status = status; - await _eventService.LogOrganizationUserEventAsync(organizationUser, EventType.OrganizationUser_Restored); - } - - public async Task>> RestoreUsersAsync(Guid organizationId, - IEnumerable organizationUserIds, Guid? restoringUserId, IUserService userService) - { - var orgUsers = await _organizationUserRepository.GetManyAsync(organizationUserIds); - var filteredUsers = orgUsers.Where(u => u.OrganizationId == organizationId) - .ToList(); - - if (!filteredUsers.Any()) - { - throw new BadRequestException("Users invalid."); - } - - var deletingUserIsOwner = false; - if (restoringUserId.HasValue) - { - deletingUserIsOwner = await _currentContext.OrganizationOwner(organizationId); - } - - var result = new List>(); - - foreach (var organizationUser in filteredUsers) - { - try - { - if (organizationUser.Status != OrganizationUserStatusType.Revoked) + var organizationUser = await _organizationUserRepository.GetByOrganizationAsync(organizationId.Value, userId); + if (organizationUser != null) { - throw new BadRequestException("Already active."); + await _eventService.LogOrganizationUserEventAsync(organizationUser, EventType.OrganizationUser_UnlinkedSso); } + } + } - if (restoringUserId.HasValue && organizationUser.UserId == restoringUserId) + public async Task UpdateOrganizationKeysAsync(Guid orgId, string publicKey, string privateKey) + { + if (!await _currentContext.ManageResetPassword(orgId)) + { + throw new UnauthorizedAccessException(); + } + + // If the keys already exist, error out + var org = await _organizationRepository.GetByIdAsync(orgId); + if (org.PublicKey != null && org.PrivateKey != null) + { + throw new BadRequestException("Organization Keys already exist"); + } + + // Update org with generated public/private key + org.PublicKey = publicKey; + org.PrivateKey = privateKey; + await UpdateAsync(org); + + return org; + } + + private async Task UpdateUsersAsync(Group group, HashSet groupUsers, + Dictionary existingUsersIdDict, HashSet existingUsers = null) + { + var availableUsers = groupUsers.Intersect(existingUsersIdDict.Keys); + var users = new HashSet(availableUsers.Select(u => existingUsersIdDict[u])); + if (existingUsers != null && existingUsers.Count == users.Count && users.SetEquals(existingUsers)) + { + return; + } + + await _groupRepository.UpdateUsersAsync(group.Id, users); + } + + private async Task> GetConfirmedOwnersAsync(Guid organizationId) + { + var owners = await _organizationUserRepository.GetManyByOrganizationAsync(organizationId, + OrganizationUserType.Owner); + return owners.Where(o => o.Status == OrganizationUserStatusType.Confirmed); + } + + private async Task DeleteAndPushUserRegistrationAsync(Guid organizationId, Guid userId) + { + var deviceIds = await GetUserDeviceIdsAsync(userId); + await _pushRegistrationService.DeleteUserRegistrationOrganizationAsync(deviceIds, + organizationId.ToString()); + await _pushNotificationService.PushSyncOrgKeysAsync(userId); + } + + + private async Task> GetUserDeviceIdsAsync(Guid userId) + { + var devices = await _deviceRepository.GetManyByUserIdAsync(userId); + return devices.Where(d => !string.IsNullOrWhiteSpace(d.PushToken)).Select(d => d.Id.ToString()); + } + + private async Task ReplaceAndUpdateCache(Organization org, EventType? orgEvent = null) + { + await _organizationRepository.ReplaceAsync(org); + await _applicationCacheService.UpsertOrganizationAbilityAsync(org); + + if (orgEvent.HasValue) + { + await _eventService.LogOrganizationEventAsync(org, orgEvent.Value); + } + } + + private async Task GetOrgById(Guid id) + { + return await _organizationRepository.GetByIdAsync(id); + } + + private void ValidateOrganizationUpgradeParameters(Models.StaticStore.Plan plan, OrganizationUpgrade upgrade) + { + if (!plan.HasAdditionalStorageOption && upgrade.AdditionalStorageGb > 0) + { + throw new BadRequestException("Plan does not allow additional storage."); + } + + if (upgrade.AdditionalStorageGb < 0) + { + throw new BadRequestException("You can't subtract storage!"); + } + + if (!plan.HasPremiumAccessOption && upgrade.PremiumAccessAddon) + { + throw new BadRequestException("This plan does not allow you to buy the premium access addon."); + } + + if (plan.BaseSeats + upgrade.AdditionalSeats <= 0) + { + throw new BadRequestException("You do not have any seats!"); + } + + if (upgrade.AdditionalSeats < 0) + { + throw new BadRequestException("You can't subtract seats!"); + } + + if (!plan.HasAdditionalSeatsOption && upgrade.AdditionalSeats > 0) + { + throw new BadRequestException("Plan does not allow additional users."); + } + + if (plan.HasAdditionalSeatsOption && plan.MaxAdditionalSeats.HasValue && + upgrade.AdditionalSeats > plan.MaxAdditionalSeats.Value) + { + throw new BadRequestException($"Selected plan allows a maximum of " + + $"{plan.MaxAdditionalSeats.GetValueOrDefault(0)} additional users."); + } + } + + private async Task ValidateOrganizationUserUpdatePermissions(Guid organizationId, OrganizationUserType newType, + OrganizationUserType? oldType) + { + if (await _currentContext.OrganizationOwner(organizationId)) + { + return; + } + + if (oldType == OrganizationUserType.Owner || newType == OrganizationUserType.Owner) + { + throw new BadRequestException("Only an Owner can configure another Owner's account."); + } + + if (await _currentContext.OrganizationAdmin(organizationId)) + { + return; + } + + if (oldType == OrganizationUserType.Custom || newType == OrganizationUserType.Custom) + { + throw new BadRequestException("Only Owners and Admins can configure Custom accounts."); + } + + if (!await _currentContext.ManageUsers(organizationId)) + { + throw new BadRequestException("Your account does not have permission to manage users."); + } + + if (oldType == OrganizationUserType.Admin || newType == OrganizationUserType.Admin) + { + throw new BadRequestException("Custom users can not manage Admins or Owners."); + } + } + + private async Task ValidateDeleteOrganizationAsync(Organization organization) + { + var ssoConfig = await _ssoConfigRepository.GetByOrganizationIdAsync(organization.Id); + if (ssoConfig?.GetData()?.KeyConnectorEnabled == true) + { + throw new BadRequestException("You cannot delete an Organization that is using Key Connector."); + } + } + + public async Task RevokeUserAsync(OrganizationUser organizationUser, Guid? revokingUserId) + { + if (organizationUser.Status == OrganizationUserStatusType.Revoked) + { + throw new BadRequestException("Already revoked."); + } + + if (revokingUserId.HasValue && organizationUser.UserId == revokingUserId.Value) + { + throw new BadRequestException("You cannot revoke yourself."); + } + + if (organizationUser.Type == OrganizationUserType.Owner && revokingUserId.HasValue && + !await _currentContext.OrganizationOwner(organizationUser.OrganizationId)) + { + throw new BadRequestException("Only owners can revoke other owners."); + } + + if (!await HasConfirmedOwnersExceptAsync(organizationUser.OrganizationId, new[] { organizationUser.Id })) + { + throw new BadRequestException("Organization must have at least one confirmed owner."); + } + + await _organizationUserRepository.RevokeAsync(organizationUser.Id); + organizationUser.Status = OrganizationUserStatusType.Revoked; + await _eventService.LogOrganizationUserEventAsync(organizationUser, EventType.OrganizationUser_Revoked); + } + + public async Task>> RevokeUsersAsync(Guid organizationId, + IEnumerable organizationUserIds, Guid? revokingUserId) + { + var orgUsers = await _organizationUserRepository.GetManyAsync(organizationUserIds); + var filteredUsers = orgUsers.Where(u => u.OrganizationId == organizationId) + .ToList(); + + if (!filteredUsers.Any()) + { + throw new BadRequestException("Users invalid."); + } + + if (!await HasConfirmedOwnersExceptAsync(organizationId, organizationUserIds)) + { + throw new BadRequestException("Organization must have at least one confirmed owner."); + } + + var deletingUserIsOwner = false; + if (revokingUserId.HasValue) + { + deletingUserIsOwner = await _currentContext.OrganizationOwner(organizationId); + } + + var result = new List>(); + + foreach (var organizationUser in filteredUsers) + { + try { - throw new BadRequestException("You cannot restore yourself."); - } + if (organizationUser.Status == OrganizationUserStatusType.Revoked) + { + throw new BadRequestException("Already revoked."); + } - if (organizationUser.Type == OrganizationUserType.Owner && restoringUserId.HasValue && !deletingUserIsOwner) + if (revokingUserId.HasValue && organizationUser.UserId == revokingUserId) + { + throw new BadRequestException("You cannot revoke yourself."); + } + + if (organizationUser.Type == OrganizationUserType.Owner && revokingUserId.HasValue && !deletingUserIsOwner) + { + throw new BadRequestException("Only owners can revoke other owners."); + } + + await _organizationUserRepository.RevokeAsync(organizationUser.Id); + organizationUser.Status = OrganizationUserStatusType.Revoked; + await _eventService.LogOrganizationUserEventAsync(organizationUser, EventType.OrganizationUser_Revoked); + + result.Add(Tuple.Create(organizationUser, "")); + } + catch (BadRequestException e) { - throw new BadRequestException("Only owners can restore other owners."); + result.Add(Tuple.Create(organizationUser, e.Message)); } - - await CheckPoliciesBeforeRestoreAsync(organizationUser, userService); - - var status = GetPriorActiveOrganizationUserStatusType(organizationUser); - - await _organizationUserRepository.RestoreAsync(organizationUser.Id, status); - organizationUser.Status = status; - await _eventService.LogOrganizationUserEventAsync(organizationUser, EventType.OrganizationUser_Restored); - - result.Add(Tuple.Create(organizationUser, "")); } - catch (BadRequestException e) + + return result; + } + + public async Task RestoreUserAsync(OrganizationUser organizationUser, Guid? restoringUserId, IUserService userService) + { + if (organizationUser.Status != OrganizationUserStatusType.Revoked) { - result.Add(Tuple.Create(organizationUser, e.Message)); + throw new BadRequestException("Already active."); } - } - return result; - } - - private async Task CheckPoliciesBeforeRestoreAsync(OrganizationUser orgUser, IUserService userService) - { - // An invited OrganizationUser isn't linked with a user account yet, so these checks are irrelevant - // The user will be subject to the same checks when they try to accept the invite - if (GetPriorActiveOrganizationUserStatusType(orgUser) == OrganizationUserStatusType.Invited) - { - return; - } - - var userId = orgUser.UserId.Value; - - // Enforce Single Organization Policy of organization user is being restored to - var allOrgUsers = await _organizationUserRepository.GetManyByUserAsync(userId); - var hasOtherOrgs = allOrgUsers.Any(ou => ou.OrganizationId != orgUser.OrganizationId); - var singleOrgPoliciesApplyingToRevokedUsers = await _policyRepository.GetManyByTypeApplicableToUserIdAsync(userId, - PolicyType.SingleOrg, OrganizationUserStatusType.Revoked); - var singleOrgPolicyApplies = singleOrgPoliciesApplyingToRevokedUsers.Any(p => p.OrganizationId == orgUser.OrganizationId); - - if (hasOtherOrgs && singleOrgPolicyApplies) - { - throw new BadRequestException("You cannot restore this user until " + - "they leave or remove all other organizations."); - } - - // Enforce Single Organization Policy of other organizations user is a member of - var singleOrgPolicyCount = await _policyRepository.GetCountByTypeApplicableToUserIdAsync(userId, - PolicyType.SingleOrg); - if (singleOrgPolicyCount > 0) - { - throw new BadRequestException("You cannot restore this user because they are a member of " + - "another organization which forbids it"); - } - - // Enforce Two Factor Authentication Policy of organization user is trying to join - var user = await _userRepository.GetByIdAsync(userId); - if (!await userService.TwoFactorIsEnabledAsync(user)) - { - var invitedTwoFactorPolicies = await _policyRepository.GetManyByTypeApplicableToUserIdAsync(userId, - PolicyType.TwoFactorAuthentication, OrganizationUserStatusType.Invited); - if (invitedTwoFactorPolicies.Any(p => p.OrganizationId == orgUser.OrganizationId)) + if (restoringUserId.HasValue && organizationUser.UserId == restoringUserId.Value) { - throw new BadRequestException("You cannot restore this user until they enable " + - "two-step login on their user account."); + throw new BadRequestException("You cannot restore yourself."); } - } - } - static OrganizationUserStatusType GetPriorActiveOrganizationUserStatusType(OrganizationUser organizationUser) - { - // Determine status to revert back to - var status = OrganizationUserStatusType.Invited; - if (organizationUser.UserId.HasValue && string.IsNullOrWhiteSpace(organizationUser.Email)) - { - // Has UserId & Email is null, then Accepted - status = OrganizationUserStatusType.Accepted; - if (!string.IsNullOrWhiteSpace(organizationUser.Key)) + if (organizationUser.Type == OrganizationUserType.Owner && restoringUserId.HasValue && + !await _currentContext.OrganizationOwner(organizationUser.OrganizationId)) { - // We have an org key for this user, user was confirmed - status = OrganizationUserStatusType.Confirmed; + throw new BadRequestException("Only owners can restore other owners."); + } + + await CheckPoliciesBeforeRestoreAsync(organizationUser, userService); + + var status = GetPriorActiveOrganizationUserStatusType(organizationUser); + + await _organizationUserRepository.RestoreAsync(organizationUser.Id, status); + organizationUser.Status = status; + await _eventService.LogOrganizationUserEventAsync(organizationUser, EventType.OrganizationUser_Restored); + } + + public async Task>> RestoreUsersAsync(Guid organizationId, + IEnumerable organizationUserIds, Guid? restoringUserId, IUserService userService) + { + var orgUsers = await _organizationUserRepository.GetManyAsync(organizationUserIds); + var filteredUsers = orgUsers.Where(u => u.OrganizationId == organizationId) + .ToList(); + + if (!filteredUsers.Any()) + { + throw new BadRequestException("Users invalid."); + } + + var deletingUserIsOwner = false; + if (restoringUserId.HasValue) + { + deletingUserIsOwner = await _currentContext.OrganizationOwner(organizationId); + } + + var result = new List>(); + + foreach (var organizationUser in filteredUsers) + { + try + { + if (organizationUser.Status != OrganizationUserStatusType.Revoked) + { + throw new BadRequestException("Already active."); + } + + if (restoringUserId.HasValue && organizationUser.UserId == restoringUserId) + { + throw new BadRequestException("You cannot restore yourself."); + } + + if (organizationUser.Type == OrganizationUserType.Owner && restoringUserId.HasValue && !deletingUserIsOwner) + { + throw new BadRequestException("Only owners can restore other owners."); + } + + await CheckPoliciesBeforeRestoreAsync(organizationUser, userService); + + var status = GetPriorActiveOrganizationUserStatusType(organizationUser); + + await _organizationUserRepository.RestoreAsync(organizationUser.Id, status); + organizationUser.Status = status; + await _eventService.LogOrganizationUserEventAsync(organizationUser, EventType.OrganizationUser_Restored); + + result.Add(Tuple.Create(organizationUser, "")); + } + catch (BadRequestException e) + { + result.Add(Tuple.Create(organizationUser, e.Message)); + } + } + + return result; + } + + private async Task CheckPoliciesBeforeRestoreAsync(OrganizationUser orgUser, IUserService userService) + { + // An invited OrganizationUser isn't linked with a user account yet, so these checks are irrelevant + // The user will be subject to the same checks when they try to accept the invite + if (GetPriorActiveOrganizationUserStatusType(orgUser) == OrganizationUserStatusType.Invited) + { + return; + } + + var userId = orgUser.UserId.Value; + + // Enforce Single Organization Policy of organization user is being restored to + var allOrgUsers = await _organizationUserRepository.GetManyByUserAsync(userId); + var hasOtherOrgs = allOrgUsers.Any(ou => ou.OrganizationId != orgUser.OrganizationId); + var singleOrgPoliciesApplyingToRevokedUsers = await _policyRepository.GetManyByTypeApplicableToUserIdAsync(userId, + PolicyType.SingleOrg, OrganizationUserStatusType.Revoked); + var singleOrgPolicyApplies = singleOrgPoliciesApplyingToRevokedUsers.Any(p => p.OrganizationId == orgUser.OrganizationId); + + if (hasOtherOrgs && singleOrgPolicyApplies) + { + throw new BadRequestException("You cannot restore this user until " + + "they leave or remove all other organizations."); + } + + // Enforce Single Organization Policy of other organizations user is a member of + var singleOrgPolicyCount = await _policyRepository.GetCountByTypeApplicableToUserIdAsync(userId, + PolicyType.SingleOrg); + if (singleOrgPolicyCount > 0) + { + throw new BadRequestException("You cannot restore this user because they are a member of " + + "another organization which forbids it"); + } + + // Enforce Two Factor Authentication Policy of organization user is trying to join + var user = await _userRepository.GetByIdAsync(userId); + if (!await userService.TwoFactorIsEnabledAsync(user)) + { + var invitedTwoFactorPolicies = await _policyRepository.GetManyByTypeApplicableToUserIdAsync(userId, + PolicyType.TwoFactorAuthentication, OrganizationUserStatusType.Invited); + if (invitedTwoFactorPolicies.Any(p => p.OrganizationId == orgUser.OrganizationId)) + { + throw new BadRequestException("You cannot restore this user until they enable " + + "two-step login on their user account."); + } } } - return status; + static OrganizationUserStatusType GetPriorActiveOrganizationUserStatusType(OrganizationUser organizationUser) + { + // Determine status to revert back to + var status = OrganizationUserStatusType.Invited; + if (organizationUser.UserId.HasValue && string.IsNullOrWhiteSpace(organizationUser.Email)) + { + // Has UserId & Email is null, then Accepted + status = OrganizationUserStatusType.Accepted; + if (!string.IsNullOrWhiteSpace(organizationUser.Key)) + { + // We have an org key for this user, user was confirmed + status = OrganizationUserStatusType.Confirmed; + } + } + + return status; + } } } diff --git a/src/Core/Services/Implementations/PolicyService.cs b/src/Core/Services/Implementations/PolicyService.cs index 938975f591..e84a124e6c 100644 --- a/src/Core/Services/Implementations/PolicyService.cs +++ b/src/Core/Services/Implementations/PolicyService.cs @@ -3,169 +3,170 @@ using Bit.Core.Enums; using Bit.Core.Exceptions; using Bit.Core.Repositories; -namespace Bit.Core.Services; - -public class PolicyService : IPolicyService +namespace Bit.Core.Services { - private readonly IEventService _eventService; - private readonly IOrganizationRepository _organizationRepository; - private readonly IOrganizationUserRepository _organizationUserRepository; - private readonly IPolicyRepository _policyRepository; - private readonly ISsoConfigRepository _ssoConfigRepository; - private readonly IMailService _mailService; - - public PolicyService( - IEventService eventService, - IOrganizationRepository organizationRepository, - IOrganizationUserRepository organizationUserRepository, - IPolicyRepository policyRepository, - ISsoConfigRepository ssoConfigRepository, - IMailService mailService) + public class PolicyService : IPolicyService { - _eventService = eventService; - _organizationRepository = organizationRepository; - _organizationUserRepository = organizationUserRepository; - _policyRepository = policyRepository; - _ssoConfigRepository = ssoConfigRepository; - _mailService = mailService; - } + private readonly IEventService _eventService; + private readonly IOrganizationRepository _organizationRepository; + private readonly IOrganizationUserRepository _organizationUserRepository; + private readonly IPolicyRepository _policyRepository; + private readonly ISsoConfigRepository _ssoConfigRepository; + private readonly IMailService _mailService; - public async Task SaveAsync(Policy policy, IUserService userService, IOrganizationService organizationService, - Guid? savingUserId) - { - var org = await _organizationRepository.GetByIdAsync(policy.OrganizationId); - if (org == null) + public PolicyService( + IEventService eventService, + IOrganizationRepository organizationRepository, + IOrganizationUserRepository organizationUserRepository, + IPolicyRepository policyRepository, + ISsoConfigRepository ssoConfigRepository, + IMailService mailService) { - throw new BadRequestException("Organization not found"); + _eventService = eventService; + _organizationRepository = organizationRepository; + _organizationUserRepository = organizationUserRepository; + _policyRepository = policyRepository; + _ssoConfigRepository = ssoConfigRepository; + _mailService = mailService; } - if (!org.UsePolicies) + public async Task SaveAsync(Policy policy, IUserService userService, IOrganizationService organizationService, + Guid? savingUserId) { - throw new BadRequestException("This organization cannot use policies."); - } - - // Handle dependent policy checks - switch (policy.Type) - { - case PolicyType.SingleOrg: - if (!policy.Enabled) - { - await RequiredBySsoAsync(org); - await RequiredByVaultTimeoutAsync(org); - await RequiredByKeyConnectorAsync(org); - } - break; - - case PolicyType.RequireSso: - if (policy.Enabled) - { - await DependsOnSingleOrgAsync(org); - } - else - { - await RequiredByKeyConnectorAsync(org); - } - break; - - case PolicyType.MaximumVaultTimeout: - if (policy.Enabled) - { - await DependsOnSingleOrgAsync(org); - } - break; - } - - var now = DateTime.UtcNow; - if (policy.Id == default(Guid)) - { - policy.CreationDate = now; - } - - if (policy.Enabled) - { - var currentPolicy = await _policyRepository.GetByIdAsync(policy.Id); - if (!currentPolicy?.Enabled ?? true) + var org = await _organizationRepository.GetByIdAsync(policy.OrganizationId); + if (org == null) { - var orgUsers = await _organizationUserRepository.GetManyDetailsByOrganizationAsync( - policy.OrganizationId); - var removableOrgUsers = orgUsers.Where(ou => - ou.Status != Enums.OrganizationUserStatusType.Invited && - ou.Type != Enums.OrganizationUserType.Owner && ou.Type != Enums.OrganizationUserType.Admin && - ou.UserId != savingUserId); - switch (policy.Type) + throw new BadRequestException("Organization not found"); + } + + if (!org.UsePolicies) + { + throw new BadRequestException("This organization cannot use policies."); + } + + // Handle dependent policy checks + switch (policy.Type) + { + case PolicyType.SingleOrg: + if (!policy.Enabled) + { + await RequiredBySsoAsync(org); + await RequiredByVaultTimeoutAsync(org); + await RequiredByKeyConnectorAsync(org); + } + break; + + case PolicyType.RequireSso: + if (policy.Enabled) + { + await DependsOnSingleOrgAsync(org); + } + else + { + await RequiredByKeyConnectorAsync(org); + } + break; + + case PolicyType.MaximumVaultTimeout: + if (policy.Enabled) + { + await DependsOnSingleOrgAsync(org); + } + break; + } + + var now = DateTime.UtcNow; + if (policy.Id == default(Guid)) + { + policy.CreationDate = now; + } + + if (policy.Enabled) + { + var currentPolicy = await _policyRepository.GetByIdAsync(policy.Id); + if (!currentPolicy?.Enabled ?? true) { - case Enums.PolicyType.TwoFactorAuthentication: - foreach (var orgUser in removableOrgUsers) - { - if (!await userService.TwoFactorIsEnabledAsync(orgUser)) + var orgUsers = await _organizationUserRepository.GetManyDetailsByOrganizationAsync( + policy.OrganizationId); + var removableOrgUsers = orgUsers.Where(ou => + ou.Status != Enums.OrganizationUserStatusType.Invited && + ou.Type != Enums.OrganizationUserType.Owner && ou.Type != Enums.OrganizationUserType.Admin && + ou.UserId != savingUserId); + switch (policy.Type) + { + case Enums.PolicyType.TwoFactorAuthentication: + foreach (var orgUser in removableOrgUsers) { - await organizationService.DeleteUserAsync(policy.OrganizationId, orgUser.Id, - savingUserId); - await _mailService.SendOrganizationUserRemovedForPolicyTwoStepEmailAsync( - org.Name, orgUser.Email); + if (!await userService.TwoFactorIsEnabledAsync(orgUser)) + { + await organizationService.DeleteUserAsync(policy.OrganizationId, orgUser.Id, + savingUserId); + await _mailService.SendOrganizationUserRemovedForPolicyTwoStepEmailAsync( + org.Name, orgUser.Email); + } } - } - break; - case Enums.PolicyType.SingleOrg: - var userOrgs = await _organizationUserRepository.GetManyByManyUsersAsync( - removableOrgUsers.Select(ou => ou.UserId.Value)); - foreach (var orgUser in removableOrgUsers) - { - if (userOrgs.Any(ou => ou.UserId == orgUser.UserId - && ou.OrganizationId != org.Id - && ou.Status != OrganizationUserStatusType.Invited)) + break; + case Enums.PolicyType.SingleOrg: + var userOrgs = await _organizationUserRepository.GetManyByManyUsersAsync( + removableOrgUsers.Select(ou => ou.UserId.Value)); + foreach (var orgUser in removableOrgUsers) { - await organizationService.DeleteUserAsync(policy.OrganizationId, orgUser.Id, - savingUserId); - await _mailService.SendOrganizationUserRemovedForPolicySingleOrgEmailAsync( - org.Name, orgUser.Email); + if (userOrgs.Any(ou => ou.UserId == orgUser.UserId + && ou.OrganizationId != org.Id + && ou.Status != OrganizationUserStatusType.Invited)) + { + await organizationService.DeleteUserAsync(policy.OrganizationId, orgUser.Id, + savingUserId); + await _mailService.SendOrganizationUserRemovedForPolicySingleOrgEmailAsync( + org.Name, orgUser.Email); + } } - } - break; - default: - break; + break; + default: + break; + } } } + policy.RevisionDate = now; + await _policyRepository.UpsertAsync(policy); + await _eventService.LogPolicyEventAsync(policy, Enums.EventType.Policy_Updated); } - policy.RevisionDate = now; - await _policyRepository.UpsertAsync(policy); - await _eventService.LogPolicyEventAsync(policy, Enums.EventType.Policy_Updated); - } - private async Task DependsOnSingleOrgAsync(Organization org) - { - var singleOrg = await _policyRepository.GetByOrganizationIdTypeAsync(org.Id, PolicyType.SingleOrg); - if (singleOrg?.Enabled != true) + private async Task DependsOnSingleOrgAsync(Organization org) { - throw new BadRequestException("Single Organization policy not enabled."); + var singleOrg = await _policyRepository.GetByOrganizationIdTypeAsync(org.Id, PolicyType.SingleOrg); + if (singleOrg?.Enabled != true) + { + throw new BadRequestException("Single Organization policy not enabled."); + } } - } - private async Task RequiredBySsoAsync(Organization org) - { - var requireSso = await _policyRepository.GetByOrganizationIdTypeAsync(org.Id, PolicyType.RequireSso); - if (requireSso?.Enabled == true) + private async Task RequiredBySsoAsync(Organization org) { - throw new BadRequestException("Single Sign-On Authentication policy is enabled."); + var requireSso = await _policyRepository.GetByOrganizationIdTypeAsync(org.Id, PolicyType.RequireSso); + if (requireSso?.Enabled == true) + { + throw new BadRequestException("Single Sign-On Authentication policy is enabled."); + } } - } - private async Task RequiredByKeyConnectorAsync(Organization org) - { - - var ssoConfig = await _ssoConfigRepository.GetByOrganizationIdAsync(org.Id); - if (ssoConfig?.GetData()?.KeyConnectorEnabled == true) + private async Task RequiredByKeyConnectorAsync(Organization org) { - throw new BadRequestException("Key Connector is enabled."); + + var ssoConfig = await _ssoConfigRepository.GetByOrganizationIdAsync(org.Id); + if (ssoConfig?.GetData()?.KeyConnectorEnabled == true) + { + throw new BadRequestException("Key Connector is enabled."); + } } - } - private async Task RequiredByVaultTimeoutAsync(Organization org) - { - var vaultTimeout = await _policyRepository.GetByOrganizationIdTypeAsync(org.Id, PolicyType.MaximumVaultTimeout); - if (vaultTimeout?.Enabled == true) + private async Task RequiredByVaultTimeoutAsync(Organization org) { - throw new BadRequestException("Maximum Vault Timeout policy is enabled."); + var vaultTimeout = await _policyRepository.GetByOrganizationIdTypeAsync(org.Id, PolicyType.MaximumVaultTimeout); + if (vaultTimeout?.Enabled == true) + { + throw new BadRequestException("Maximum Vault Timeout policy is enabled."); + } } } } diff --git a/src/Core/Services/Implementations/RelayPushNotificationService.cs b/src/Core/Services/Implementations/RelayPushNotificationService.cs index b3670ad7b1..b66cb7ca10 100644 --- a/src/Core/Services/Implementations/RelayPushNotificationService.cs +++ b/src/Core/Services/Implementations/RelayPushNotificationService.cs @@ -8,218 +8,219 @@ using Bit.Core.Settings; using Microsoft.AspNetCore.Http; using Microsoft.Extensions.Logging; -namespace Bit.Core.Services; - -public class RelayPushNotificationService : BaseIdentityClientService, IPushNotificationService +namespace Bit.Core.Services { - private readonly IDeviceRepository _deviceRepository; - private readonly IHttpContextAccessor _httpContextAccessor; - - public RelayPushNotificationService( - IHttpClientFactory httpFactory, - IDeviceRepository deviceRepository, - GlobalSettings globalSettings, - IHttpContextAccessor httpContextAccessor, - ILogger logger) - : base( - httpFactory, - globalSettings.PushRelayBaseUri, - globalSettings.Installation.IdentityUri, - "api.push", - $"installation.{globalSettings.Installation.Id}", - globalSettings.Installation.Key, - logger) + public class RelayPushNotificationService : BaseIdentityClientService, IPushNotificationService { - _deviceRepository = deviceRepository; - _httpContextAccessor = httpContextAccessor; - } + private readonly IDeviceRepository _deviceRepository; + private readonly IHttpContextAccessor _httpContextAccessor; - public async Task PushSyncCipherCreateAsync(Cipher cipher, IEnumerable collectionIds) - { - await PushCipherAsync(cipher, PushType.SyncCipherCreate, collectionIds); - } - - public async Task PushSyncCipherUpdateAsync(Cipher cipher, IEnumerable collectionIds) - { - await PushCipherAsync(cipher, PushType.SyncCipherUpdate, collectionIds); - } - - public async Task PushSyncCipherDeleteAsync(Cipher cipher) - { - await PushCipherAsync(cipher, PushType.SyncLoginDelete, null); - } - - private async Task PushCipherAsync(Cipher cipher, PushType type, IEnumerable collectionIds) - { - if (cipher.OrganizationId.HasValue) + public RelayPushNotificationService( + IHttpClientFactory httpFactory, + IDeviceRepository deviceRepository, + GlobalSettings globalSettings, + IHttpContextAccessor httpContextAccessor, + ILogger logger) + : base( + httpFactory, + globalSettings.PushRelayBaseUri, + globalSettings.Installation.IdentityUri, + "api.push", + $"installation.{globalSettings.Installation.Id}", + globalSettings.Installation.Key, + logger) { - // We cannot send org pushes since access logic is much more complicated than just the fact that they belong - // to the organization. Potentially we could blindly send to just users that have the access all permission - // device registration needs to be more granular to handle that appropriately. A more brute force approach could - // me to send "full sync" push to all org users, but that has the potential to DDOS the API in bursts. - - // await SendPayloadToOrganizationAsync(cipher.OrganizationId.Value, type, message, true); + _deviceRepository = deviceRepository; + _httpContextAccessor = httpContextAccessor; } - else if (cipher.UserId.HasValue) - { - var message = new SyncCipherPushNotification - { - Id = cipher.Id, - UserId = cipher.UserId, - OrganizationId = cipher.OrganizationId, - RevisionDate = cipher.RevisionDate, - }; - await SendPayloadToUserAsync(cipher.UserId.Value, type, message, true); + public async Task PushSyncCipherCreateAsync(Cipher cipher, IEnumerable collectionIds) + { + await PushCipherAsync(cipher, PushType.SyncCipherCreate, collectionIds); } - } - public async Task PushSyncFolderCreateAsync(Folder folder) - { - await PushFolderAsync(folder, PushType.SyncFolderCreate); - } - - public async Task PushSyncFolderUpdateAsync(Folder folder) - { - await PushFolderAsync(folder, PushType.SyncFolderUpdate); - } - - public async Task PushSyncFolderDeleteAsync(Folder folder) - { - await PushFolderAsync(folder, PushType.SyncFolderDelete); - } - - private async Task PushFolderAsync(Folder folder, PushType type) - { - var message = new SyncFolderPushNotification + public async Task PushSyncCipherUpdateAsync(Cipher cipher, IEnumerable collectionIds) { - Id = folder.Id, - UserId = folder.UserId, - RevisionDate = folder.RevisionDate - }; - - await SendPayloadToUserAsync(folder.UserId, type, message, true); - } - - public async Task PushSyncCiphersAsync(Guid userId) - { - await PushUserAsync(userId, PushType.SyncCiphers); - } - - public async Task PushSyncVaultAsync(Guid userId) - { - await PushUserAsync(userId, PushType.SyncVault); - } - - public async Task PushSyncOrgKeysAsync(Guid userId) - { - await PushUserAsync(userId, PushType.SyncOrgKeys); - } - - public async Task PushSyncSettingsAsync(Guid userId) - { - await PushUserAsync(userId, PushType.SyncSettings); - } - - public async Task PushLogOutAsync(Guid userId) - { - await PushUserAsync(userId, PushType.LogOut); - } - - private async Task PushUserAsync(Guid userId, PushType type) - { - var message = new UserPushNotification - { - UserId = userId, - Date = DateTime.UtcNow - }; - - await SendPayloadToUserAsync(userId, type, message, false); - } - - public async Task PushSyncSendCreateAsync(Send send) - { - await PushSendAsync(send, PushType.SyncSendCreate); - } - - public async Task PushSyncSendUpdateAsync(Send send) - { - await PushSendAsync(send, PushType.SyncSendUpdate); - } - - public async Task PushSyncSendDeleteAsync(Send send) - { - await PushSendAsync(send, PushType.SyncSendDelete); - } - - private async Task PushSendAsync(Send send, PushType type) - { - if (send.UserId.HasValue) - { - var message = new SyncSendPushNotification - { - Id = send.Id, - UserId = send.UserId.Value, - RevisionDate = send.RevisionDate - }; - - await SendPayloadToUserAsync(message.UserId, type, message, true); + await PushCipherAsync(cipher, PushType.SyncCipherUpdate, collectionIds); } - } - private async Task SendPayloadToUserAsync(Guid userId, PushType type, object payload, bool excludeCurrentContext) - { - var request = new PushSendRequestModel + public async Task PushSyncCipherDeleteAsync(Cipher cipher) { - UserId = userId.ToString(), - Type = type, - Payload = payload - }; + await PushCipherAsync(cipher, PushType.SyncLoginDelete, null); + } - await AddCurrentContextAsync(request, excludeCurrentContext); - await SendAsync(HttpMethod.Post, "push/send", request); - } - - private async Task SendPayloadToOrganizationAsync(Guid orgId, PushType type, object payload, bool excludeCurrentContext) - { - var request = new PushSendRequestModel + private async Task PushCipherAsync(Cipher cipher, PushType type, IEnumerable collectionIds) { - OrganizationId = orgId.ToString(), - Type = type, - Payload = payload - }; - - await AddCurrentContextAsync(request, excludeCurrentContext); - await SendAsync(HttpMethod.Post, "push/send", request); - } - - private async Task AddCurrentContextAsync(PushSendRequestModel request, bool addIdentifier) - { - var currentContext = _httpContextAccessor?.HttpContext?. - RequestServices.GetService(typeof(ICurrentContext)) as ICurrentContext; - if (!string.IsNullOrWhiteSpace(currentContext?.DeviceIdentifier)) - { - var device = await _deviceRepository.GetByIdentifierAsync(currentContext.DeviceIdentifier); - if (device != null) + if (cipher.OrganizationId.HasValue) { - request.DeviceId = device.Id.ToString(); + // We cannot send org pushes since access logic is much more complicated than just the fact that they belong + // to the organization. Potentially we could blindly send to just users that have the access all permission + // device registration needs to be more granular to handle that appropriately. A more brute force approach could + // me to send "full sync" push to all org users, but that has the potential to DDOS the API in bursts. + + // await SendPayloadToOrganizationAsync(cipher.OrganizationId.Value, type, message, true); } - if (addIdentifier) + else if (cipher.UserId.HasValue) { - request.Identifier = currentContext.DeviceIdentifier; + var message = new SyncCipherPushNotification + { + Id = cipher.Id, + UserId = cipher.UserId, + OrganizationId = cipher.OrganizationId, + RevisionDate = cipher.RevisionDate, + }; + + await SendPayloadToUserAsync(cipher.UserId.Value, type, message, true); } } - } - public Task SendPayloadToUserAsync(string userId, PushType type, object payload, string identifier, - string deviceId = null) - { - throw new NotImplementedException(); - } + public async Task PushSyncFolderCreateAsync(Folder folder) + { + await PushFolderAsync(folder, PushType.SyncFolderCreate); + } - public Task SendPayloadToOrganizationAsync(string orgId, PushType type, object payload, string identifier, - string deviceId = null) - { - throw new NotImplementedException(); + public async Task PushSyncFolderUpdateAsync(Folder folder) + { + await PushFolderAsync(folder, PushType.SyncFolderUpdate); + } + + public async Task PushSyncFolderDeleteAsync(Folder folder) + { + await PushFolderAsync(folder, PushType.SyncFolderDelete); + } + + private async Task PushFolderAsync(Folder folder, PushType type) + { + var message = new SyncFolderPushNotification + { + Id = folder.Id, + UserId = folder.UserId, + RevisionDate = folder.RevisionDate + }; + + await SendPayloadToUserAsync(folder.UserId, type, message, true); + } + + public async Task PushSyncCiphersAsync(Guid userId) + { + await PushUserAsync(userId, PushType.SyncCiphers); + } + + public async Task PushSyncVaultAsync(Guid userId) + { + await PushUserAsync(userId, PushType.SyncVault); + } + + public async Task PushSyncOrgKeysAsync(Guid userId) + { + await PushUserAsync(userId, PushType.SyncOrgKeys); + } + + public async Task PushSyncSettingsAsync(Guid userId) + { + await PushUserAsync(userId, PushType.SyncSettings); + } + + public async Task PushLogOutAsync(Guid userId) + { + await PushUserAsync(userId, PushType.LogOut); + } + + private async Task PushUserAsync(Guid userId, PushType type) + { + var message = new UserPushNotification + { + UserId = userId, + Date = DateTime.UtcNow + }; + + await SendPayloadToUserAsync(userId, type, message, false); + } + + public async Task PushSyncSendCreateAsync(Send send) + { + await PushSendAsync(send, PushType.SyncSendCreate); + } + + public async Task PushSyncSendUpdateAsync(Send send) + { + await PushSendAsync(send, PushType.SyncSendUpdate); + } + + public async Task PushSyncSendDeleteAsync(Send send) + { + await PushSendAsync(send, PushType.SyncSendDelete); + } + + private async Task PushSendAsync(Send send, PushType type) + { + if (send.UserId.HasValue) + { + var message = new SyncSendPushNotification + { + Id = send.Id, + UserId = send.UserId.Value, + RevisionDate = send.RevisionDate + }; + + await SendPayloadToUserAsync(message.UserId, type, message, true); + } + } + + private async Task SendPayloadToUserAsync(Guid userId, PushType type, object payload, bool excludeCurrentContext) + { + var request = new PushSendRequestModel + { + UserId = userId.ToString(), + Type = type, + Payload = payload + }; + + await AddCurrentContextAsync(request, excludeCurrentContext); + await SendAsync(HttpMethod.Post, "push/send", request); + } + + private async Task SendPayloadToOrganizationAsync(Guid orgId, PushType type, object payload, bool excludeCurrentContext) + { + var request = new PushSendRequestModel + { + OrganizationId = orgId.ToString(), + Type = type, + Payload = payload + }; + + await AddCurrentContextAsync(request, excludeCurrentContext); + await SendAsync(HttpMethod.Post, "push/send", request); + } + + private async Task AddCurrentContextAsync(PushSendRequestModel request, bool addIdentifier) + { + var currentContext = _httpContextAccessor?.HttpContext?. + RequestServices.GetService(typeof(ICurrentContext)) as ICurrentContext; + if (!string.IsNullOrWhiteSpace(currentContext?.DeviceIdentifier)) + { + var device = await _deviceRepository.GetByIdentifierAsync(currentContext.DeviceIdentifier); + if (device != null) + { + request.DeviceId = device.Id.ToString(); + } + if (addIdentifier) + { + request.Identifier = currentContext.DeviceIdentifier; + } + } + } + + public Task SendPayloadToUserAsync(string userId, PushType type, object payload, string identifier, + string deviceId = null) + { + throw new NotImplementedException(); + } + + public Task SendPayloadToOrganizationAsync(string orgId, PushType type, object payload, string identifier, + string deviceId = null) + { + throw new NotImplementedException(); + } } } diff --git a/src/Core/Services/Implementations/RelayPushRegistrationService.cs b/src/Core/Services/Implementations/RelayPushRegistrationService.cs index 2e3087421b..82ae88799d 100644 --- a/src/Core/Services/Implementations/RelayPushRegistrationService.cs +++ b/src/Core/Services/Implementations/RelayPushRegistrationService.cs @@ -3,64 +3,65 @@ using Bit.Core.Models.Api; using Bit.Core.Settings; using Microsoft.Extensions.Logging; -namespace Bit.Core.Services; - -public class RelayPushRegistrationService : BaseIdentityClientService, IPushRegistrationService +namespace Bit.Core.Services { - - public RelayPushRegistrationService( - IHttpClientFactory httpFactory, - GlobalSettings globalSettings, - ILogger logger) - : base( - httpFactory, - globalSettings.PushRelayBaseUri, - globalSettings.Installation.IdentityUri, - "api.push", - $"installation.{globalSettings.Installation.Id}", - globalSettings.Installation.Key, - logger) + public class RelayPushRegistrationService : BaseIdentityClientService, IPushRegistrationService { - } - public async Task CreateOrUpdateRegistrationAsync(string pushToken, string deviceId, string userId, - string identifier, DeviceType type) - { - var requestModel = new PushRegistrationRequestModel + public RelayPushRegistrationService( + IHttpClientFactory httpFactory, + GlobalSettings globalSettings, + ILogger logger) + : base( + httpFactory, + globalSettings.PushRelayBaseUri, + globalSettings.Installation.IdentityUri, + "api.push", + $"installation.{globalSettings.Installation.Id}", + globalSettings.Installation.Key, + logger) { - DeviceId = deviceId, - Identifier = identifier, - PushToken = pushToken, - Type = type, - UserId = userId - }; - await SendAsync(HttpMethod.Post, "push/register", requestModel); - } - - public async Task DeleteRegistrationAsync(string deviceId) - { - await SendAsync(HttpMethod.Delete, string.Concat("push/", deviceId)); - } - - public async Task AddUserRegistrationOrganizationAsync(IEnumerable deviceIds, string organizationId) - { - if (!deviceIds.Any()) - { - return; } - var requestModel = new PushUpdateRequestModel(deviceIds, organizationId); - await SendAsync(HttpMethod.Put, "push/add-organization", requestModel); - } - - public async Task DeleteUserRegistrationOrganizationAsync(IEnumerable deviceIds, string organizationId) - { - if (!deviceIds.Any()) + public async Task CreateOrUpdateRegistrationAsync(string pushToken, string deviceId, string userId, + string identifier, DeviceType type) { - return; + var requestModel = new PushRegistrationRequestModel + { + DeviceId = deviceId, + Identifier = identifier, + PushToken = pushToken, + Type = type, + UserId = userId + }; + await SendAsync(HttpMethod.Post, "push/register", requestModel); } - var requestModel = new PushUpdateRequestModel(deviceIds, organizationId); - await SendAsync(HttpMethod.Put, "push/delete-organization", requestModel); + public async Task DeleteRegistrationAsync(string deviceId) + { + await SendAsync(HttpMethod.Delete, string.Concat("push/", deviceId)); + } + + public async Task AddUserRegistrationOrganizationAsync(IEnumerable deviceIds, string organizationId) + { + if (!deviceIds.Any()) + { + return; + } + + var requestModel = new PushUpdateRequestModel(deviceIds, organizationId); + await SendAsync(HttpMethod.Put, "push/add-organization", requestModel); + } + + public async Task DeleteUserRegistrationOrganizationAsync(IEnumerable deviceIds, string organizationId) + { + if (!deviceIds.Any()) + { + return; + } + + var requestModel = new PushUpdateRequestModel(deviceIds, organizationId); + await SendAsync(HttpMethod.Put, "push/delete-organization", requestModel); + } } } diff --git a/src/Core/Services/Implementations/RepositoryEventWriteService.cs b/src/Core/Services/Implementations/RepositoryEventWriteService.cs index a8299c1e88..11d028340b 100644 --- a/src/Core/Services/Implementations/RepositoryEventWriteService.cs +++ b/src/Core/Services/Implementations/RepositoryEventWriteService.cs @@ -1,25 +1,26 @@ using Bit.Core.Models.Data; using Bit.Core.Repositories; -namespace Bit.Core.Services; - -public class RepositoryEventWriteService : IEventWriteService +namespace Bit.Core.Services { - private readonly IEventRepository _eventRepository; - - public RepositoryEventWriteService( - IEventRepository eventRepository) + public class RepositoryEventWriteService : IEventWriteService { - _eventRepository = eventRepository; - } + private readonly IEventRepository _eventRepository; - public async Task CreateAsync(IEvent e) - { - await _eventRepository.CreateAsync(e); - } + public RepositoryEventWriteService( + IEventRepository eventRepository) + { + _eventRepository = eventRepository; + } - public async Task CreateManyAsync(IEnumerable e) - { - await _eventRepository.CreateManyAsync(e); + public async Task CreateAsync(IEvent e) + { + await _eventRepository.CreateAsync(e); + } + + public async Task CreateManyAsync(IEnumerable e) + { + await _eventRepository.CreateManyAsync(e); + } } } diff --git a/src/Core/Services/Implementations/SendGridMailDeliveryService.cs b/src/Core/Services/Implementations/SendGridMailDeliveryService.cs index a35d119970..0e34b170d4 100644 --- a/src/Core/Services/Implementations/SendGridMailDeliveryService.cs +++ b/src/Core/Services/Implementations/SendGridMailDeliveryService.cs @@ -6,109 +6,110 @@ using Microsoft.Extensions.Logging; using SendGrid; using SendGrid.Helpers.Mail; -namespace Bit.Core.Services; - -public class SendGridMailDeliveryService : IMailDeliveryService, IDisposable +namespace Bit.Core.Services { - private readonly GlobalSettings _globalSettings; - private readonly IWebHostEnvironment _hostingEnvironment; - private readonly ILogger _logger; - private readonly ISendGridClient _client; - private readonly string _senderTag; - private readonly string _replyToEmail; - - public SendGridMailDeliveryService( - GlobalSettings globalSettings, - IWebHostEnvironment hostingEnvironment, - ILogger logger) - : this(new SendGridClient(globalSettings.Mail.SendGridApiKey), - globalSettings, hostingEnvironment, logger) + public class SendGridMailDeliveryService : IMailDeliveryService, IDisposable { - } + private readonly GlobalSettings _globalSettings; + private readonly IWebHostEnvironment _hostingEnvironment; + private readonly ILogger _logger; + private readonly ISendGridClient _client; + private readonly string _senderTag; + private readonly string _replyToEmail; - public void Dispose() - { - // TODO: nothing to dispose - } - - public SendGridMailDeliveryService( - ISendGridClient client, - GlobalSettings globalSettings, - IWebHostEnvironment hostingEnvironment, - ILogger logger) - { - if (string.IsNullOrWhiteSpace(globalSettings.Mail?.SendGridApiKey)) + public SendGridMailDeliveryService( + GlobalSettings globalSettings, + IWebHostEnvironment hostingEnvironment, + ILogger logger) + : this(new SendGridClient(globalSettings.Mail.SendGridApiKey), + globalSettings, hostingEnvironment, logger) { - throw new ArgumentNullException(nameof(globalSettings.Mail.SendGridApiKey)); } - _globalSettings = globalSettings; - _hostingEnvironment = hostingEnvironment; - _logger = logger; - _client = client; - _senderTag = $"Server_{globalSettings.ProjectName?.Replace(' ', '_')}"; - _replyToEmail = CoreHelpers.PunyEncode(globalSettings.Mail.ReplyToEmail); - } - - public async Task SendEmailAsync(MailMessage message) - { - var msg = new SendGridMessage(); - msg.SetFrom(new EmailAddress(_replyToEmail, _globalSettings.SiteName)); - msg.AddTos(message.ToEmails.Select(e => new EmailAddress(CoreHelpers.PunyEncode(e))).ToList()); - if (message.BccEmails?.Any() ?? false) + public void Dispose() { - msg.AddBccs(message.BccEmails.Select(e => new EmailAddress(CoreHelpers.PunyEncode(e))).ToList()); + // TODO: nothing to dispose } - msg.SetSubject(message.Subject); - msg.AddContent(MimeType.Text, message.TextContent); - msg.AddContent(MimeType.Html, message.HtmlContent); - - msg.AddCategory($"type:{message.Category}"); - msg.AddCategory($"env:{_hostingEnvironment.EnvironmentName}"); - msg.AddCategory($"sender:{_senderTag}"); - - msg.SetClickTracking(false, false); - msg.SetOpenTracking(false); - - if (message.MetaData != null && - message.MetaData.ContainsKey("SendGridBypassListManagement") && - Convert.ToBoolean(message.MetaData["SendGridBypassListManagement"])) + public SendGridMailDeliveryService( + ISendGridClient client, + GlobalSettings globalSettings, + IWebHostEnvironment hostingEnvironment, + ILogger logger) { - msg.SetBypassListManagement(true); - } - - try - { - var success = await SendAsync(msg, false); - if (!success) + if (string.IsNullOrWhiteSpace(globalSettings.Mail?.SendGridApiKey)) { - _logger.LogWarning("Failed to send email. Retrying..."); + throw new ArgumentNullException(nameof(globalSettings.Mail.SendGridApiKey)); + } + + _globalSettings = globalSettings; + _hostingEnvironment = hostingEnvironment; + _logger = logger; + _client = client; + _senderTag = $"Server_{globalSettings.ProjectName?.Replace(' ', '_')}"; + _replyToEmail = CoreHelpers.PunyEncode(globalSettings.Mail.ReplyToEmail); + } + + public async Task SendEmailAsync(MailMessage message) + { + var msg = new SendGridMessage(); + msg.SetFrom(new EmailAddress(_replyToEmail, _globalSettings.SiteName)); + msg.AddTos(message.ToEmails.Select(e => new EmailAddress(CoreHelpers.PunyEncode(e))).ToList()); + if (message.BccEmails?.Any() ?? false) + { + msg.AddBccs(message.BccEmails.Select(e => new EmailAddress(CoreHelpers.PunyEncode(e))).ToList()); + } + + msg.SetSubject(message.Subject); + msg.AddContent(MimeType.Text, message.TextContent); + msg.AddContent(MimeType.Html, message.HtmlContent); + + msg.AddCategory($"type:{message.Category}"); + msg.AddCategory($"env:{_hostingEnvironment.EnvironmentName}"); + msg.AddCategory($"sender:{_senderTag}"); + + msg.SetClickTracking(false, false); + msg.SetOpenTracking(false); + + if (message.MetaData != null && + message.MetaData.ContainsKey("SendGridBypassListManagement") && + Convert.ToBoolean(message.MetaData["SendGridBypassListManagement"])) + { + msg.SetBypassListManagement(true); + } + + try + { + var success = await SendAsync(msg, false); + if (!success) + { + _logger.LogWarning("Failed to send email. Retrying..."); + await SendAsync(msg, true); + } + } + catch (Exception e) + { + _logger.LogWarning(e, "Failed to send email (with exception). Retrying..."); await SendAsync(msg, true); + throw; } } - catch (Exception e) - { - _logger.LogWarning(e, "Failed to send email (with exception). Retrying..."); - await SendAsync(msg, true); - throw; - } - } - private async Task SendAsync(SendGridMessage message, bool retry) - { - if (retry) + private async Task SendAsync(SendGridMessage message, bool retry) { - // wait and try again - await Task.Delay(2000); - } + if (retry) + { + // wait and try again + await Task.Delay(2000); + } - var response = await _client.SendEmailAsync(message); - if (!response.IsSuccessStatusCode) - { - var responseBody = await response.Body.ReadAsStringAsync(); - _logger.LogError("SendGrid email sending failed with {0}: {1}", response.StatusCode, responseBody); + var response = await _client.SendEmailAsync(message); + if (!response.IsSuccessStatusCode) + { + var responseBody = await response.Body.ReadAsStringAsync(); + _logger.LogError("SendGrid email sending failed with {0}: {1}", response.StatusCode, responseBody); + } + return response.IsSuccessStatusCode; } - return response.IsSuccessStatusCode; } } diff --git a/src/Core/Services/Implementations/SendService.cs b/src/Core/Services/Implementations/SendService.cs index f16f2da419..6c41c6d9c7 100644 --- a/src/Core/Services/Implementations/SendService.cs +++ b/src/Core/Services/Implementations/SendService.cs @@ -11,329 +11,330 @@ using Bit.Core.Settings; using Bit.Core.Utilities; using Microsoft.AspNetCore.Identity; -namespace Bit.Core.Services; - -public class SendService : ISendService +namespace Bit.Core.Services { - public const long MAX_FILE_SIZE = Constants.FileSize501mb; - public const string MAX_FILE_SIZE_READABLE = "500 MB"; - private readonly ISendRepository _sendRepository; - private readonly IUserRepository _userRepository; - private readonly IPolicyRepository _policyRepository; - private readonly IUserService _userService; - private readonly IOrganizationRepository _organizationRepository; - private readonly ISendFileStorageService _sendFileStorageService; - private readonly IPasswordHasher _passwordHasher; - private readonly IPushNotificationService _pushService; - private readonly IReferenceEventService _referenceEventService; - private readonly GlobalSettings _globalSettings; - private readonly ICurrentContext _currentContext; - private const long _fileSizeLeeway = 1024L * 1024L; // 1MB - - public SendService( - ISendRepository sendRepository, - IUserRepository userRepository, - IUserService userService, - IOrganizationRepository organizationRepository, - ISendFileStorageService sendFileStorageService, - IPasswordHasher passwordHasher, - IPushNotificationService pushService, - IReferenceEventService referenceEventService, - GlobalSettings globalSettings, - IPolicyRepository policyRepository, - ICurrentContext currentContext) + public class SendService : ISendService { - _sendRepository = sendRepository; - _userRepository = userRepository; - _userService = userService; - _policyRepository = policyRepository; - _organizationRepository = organizationRepository; - _sendFileStorageService = sendFileStorageService; - _passwordHasher = passwordHasher; - _pushService = pushService; - _referenceEventService = referenceEventService; - _globalSettings = globalSettings; - _currentContext = currentContext; - } + public const long MAX_FILE_SIZE = Constants.FileSize501mb; + public const string MAX_FILE_SIZE_READABLE = "500 MB"; + private readonly ISendRepository _sendRepository; + private readonly IUserRepository _userRepository; + private readonly IPolicyRepository _policyRepository; + private readonly IUserService _userService; + private readonly IOrganizationRepository _organizationRepository; + private readonly ISendFileStorageService _sendFileStorageService; + private readonly IPasswordHasher _passwordHasher; + private readonly IPushNotificationService _pushService; + private readonly IReferenceEventService _referenceEventService; + private readonly GlobalSettings _globalSettings; + private readonly ICurrentContext _currentContext; + private const long _fileSizeLeeway = 1024L * 1024L; // 1MB - public async Task SaveSendAsync(Send send) - { - // Make sure user can save Sends - await ValidateUserCanSaveAsync(send.UserId, send); - - if (send.Id == default(Guid)) + public SendService( + ISendRepository sendRepository, + IUserRepository userRepository, + IUserService userService, + IOrganizationRepository organizationRepository, + ISendFileStorageService sendFileStorageService, + IPasswordHasher passwordHasher, + IPushNotificationService pushService, + IReferenceEventService referenceEventService, + GlobalSettings globalSettings, + IPolicyRepository policyRepository, + ICurrentContext currentContext) { - await _sendRepository.CreateAsync(send); - await _pushService.PushSyncSendCreateAsync(send); - await RaiseReferenceEventAsync(send, ReferenceEventType.SendCreated); - } - else - { - send.RevisionDate = DateTime.UtcNow; - await _sendRepository.UpsertAsync(send); - await _pushService.PushSyncSendUpdateAsync(send); - } - } - - public async Task SaveFileSendAsync(Send send, SendFileData data, long fileLength) - { - if (send.Type != SendType.File) - { - throw new BadRequestException("Send is not of type \"file\"."); + _sendRepository = sendRepository; + _userRepository = userRepository; + _userService = userService; + _policyRepository = policyRepository; + _organizationRepository = organizationRepository; + _sendFileStorageService = sendFileStorageService; + _passwordHasher = passwordHasher; + _pushService = pushService; + _referenceEventService = referenceEventService; + _globalSettings = globalSettings; + _currentContext = currentContext; } - if (fileLength < 1) + public async Task SaveSendAsync(Send send) { - throw new BadRequestException("No file data."); - } + // Make sure user can save Sends + await ValidateUserCanSaveAsync(send.UserId, send); - var storageBytesRemaining = await StorageRemainingForSendAsync(send); - - if (storageBytesRemaining < fileLength) - { - throw new BadRequestException("Not enough storage available."); - } - - var fileId = Utilities.CoreHelpers.SecureRandomString(32, upper: false, special: false); - - try - { - data.Id = fileId; - data.Size = fileLength; - data.Validated = false; - send.Data = JsonSerializer.Serialize(data, - JsonHelpers.IgnoreWritingNull); - await SaveSendAsync(send); - return await _sendFileStorageService.GetSendFileUploadUrlAsync(send, fileId); - } - catch - { - // Clean up since this is not transactional - await _sendFileStorageService.DeleteFileAsync(send, fileId); - throw; - } - } - - public async Task UploadFileToExistingSendAsync(Stream stream, Send send) - { - if (send?.Data == null) - { - throw new BadRequestException("Send does not have file data"); - } - - if (send.Type != SendType.File) - { - throw new BadRequestException("Not a File Type Send."); - } - - var data = JsonSerializer.Deserialize(send.Data); - - if (data.Validated) - { - throw new BadRequestException("File has already been uploaded."); - } - - await _sendFileStorageService.UploadNewFileAsync(stream, send, data.Id); - - if (!await ValidateSendFile(send)) - { - throw new BadRequestException("File received does not match expected file length."); - } - } - - public async Task ValidateSendFile(Send send) - { - var fileData = JsonSerializer.Deserialize(send.Data); - - var (valid, realSize) = await _sendFileStorageService.ValidateFileAsync(send, fileData.Id, fileData.Size, _fileSizeLeeway); - - if (!valid || realSize > MAX_FILE_SIZE) - { - // File reported differs in size from that promised. Must be a rogue client. Delete Send - await DeleteSendAsync(send); - return false; - } - - // Update Send data if necessary - if (realSize != fileData.Size) - { - fileData.Size = realSize.Value; - } - fileData.Validated = true; - send.Data = JsonSerializer.Serialize(fileData, - JsonHelpers.IgnoreWritingNull); - await SaveSendAsync(send); - - return valid; - } - - public async Task DeleteSendAsync(Send send) - { - await _sendRepository.DeleteAsync(send); - if (send.Type == Enums.SendType.File) - { - var data = JsonSerializer.Deserialize(send.Data); - await _sendFileStorageService.DeleteFileAsync(send, data.Id); - } - await _pushService.PushSyncSendDeleteAsync(send); - } - - public (bool grant, bool passwordRequiredError, bool passwordInvalidError) SendCanBeAccessed(Send send, - string password) - { - var now = DateTime.UtcNow; - if (send == null || send.MaxAccessCount.GetValueOrDefault(int.MaxValue) <= send.AccessCount || - send.ExpirationDate.GetValueOrDefault(DateTime.MaxValue) < now || send.Disabled || - send.DeletionDate < now) - { - return (false, false, false); - } - if (!string.IsNullOrWhiteSpace(send.Password)) - { - if (string.IsNullOrWhiteSpace(password)) + if (send.Id == default(Guid)) { - return (false, true, false); - } - var passwordResult = _passwordHasher.VerifyHashedPassword(new User(), send.Password, password); - if (passwordResult == PasswordVerificationResult.SuccessRehashNeeded) - { - send.Password = HashPassword(password); - } - if (passwordResult == PasswordVerificationResult.Failed) - { - return (false, false, true); - } - } - - return (true, false, false); - } - - // Response: Send, password required, password invalid - public async Task<(string, bool, bool)> GetSendFileDownloadUrlAsync(Send send, string fileId, string password) - { - if (send.Type != SendType.File) - { - throw new BadRequestException("Can only get a download URL for a file type of Send"); - } - - var (grantAccess, passwordRequired, passwordInvalid) = SendCanBeAccessed(send, password); - - if (!grantAccess) - { - return (null, passwordRequired, passwordInvalid); - } - - send.AccessCount++; - await _sendRepository.ReplaceAsync(send); - await _pushService.PushSyncSendUpdateAsync(send); - return (await _sendFileStorageService.GetSendFileDownloadUrlAsync(send, fileId), false, false); - } - - // Response: Send, password required, password invalid - public async Task<(Send, bool, bool)> AccessAsync(Guid sendId, string password) - { - var send = await _sendRepository.GetByIdAsync(sendId); - var (grantAccess, passwordRequired, passwordInvalid) = SendCanBeAccessed(send, password); - - if (!grantAccess) - { - return (null, passwordRequired, passwordInvalid); - } - - // TODO: maybe move this to a simple ++ sproc? - if (send.Type != SendType.File) - { - // File sends are incremented during file download - send.AccessCount++; - } - - await _sendRepository.ReplaceAsync(send); - await _pushService.PushSyncSendUpdateAsync(send); - await RaiseReferenceEventAsync(send, ReferenceEventType.SendAccessed); - return (send, false, false); - } - - private async Task RaiseReferenceEventAsync(Send send, ReferenceEventType eventType) - { - await _referenceEventService.RaiseEventAsync(new ReferenceEvent - { - Id = send.UserId ?? default, - Type = eventType, - Source = ReferenceEventSource.User, - SendType = send.Type, - MaxAccessCount = send.MaxAccessCount, - HasPassword = !string.IsNullOrWhiteSpace(send.Password), - }); - } - - public string HashPassword(string password) - { - return _passwordHasher.HashPassword(new User(), password); - } - - private async Task ValidateUserCanSaveAsync(Guid? userId, Send send) - { - if (!userId.HasValue || (!_currentContext.Organizations?.Any() ?? true)) - { - return; - } - - var disableSendPolicyCount = await _policyRepository.GetCountByTypeApplicableToUserIdAsync(userId.Value, - PolicyType.DisableSend); - if (disableSendPolicyCount > 0) - { - throw new BadRequestException("Due to an Enterprise Policy, you are only able to delete an existing Send."); - } - - if (send.HideEmail.GetValueOrDefault()) - { - var sendOptionsPolicies = await _policyRepository.GetManyByTypeApplicableToUserIdAsync(userId.Value, PolicyType.SendOptions); - if (sendOptionsPolicies.Any(p => p.GetDataModel()?.DisableHideEmail ?? false)) - { - throw new BadRequestException("Due to an Enterprise Policy, you are not allowed to hide your email address from recipients when creating or editing a Send."); - } - } - } - - private async Task StorageRemainingForSendAsync(Send send) - { - var storageBytesRemaining = 0L; - if (send.UserId.HasValue) - { - var user = await _userRepository.GetByIdAsync(send.UserId.Value); - if (!await _userService.CanAccessPremium(user)) - { - throw new BadRequestException("You must have premium status to use file Sends."); - } - - if (!user.EmailVerified) - { - throw new BadRequestException("You must confirm your email to use file Sends."); - } - - if (user.Premium) - { - storageBytesRemaining = user.StorageBytesRemaining(); + await _sendRepository.CreateAsync(send); + await _pushService.PushSyncSendCreateAsync(send); + await RaiseReferenceEventAsync(send, ReferenceEventType.SendCreated); } else { - // Users that get access to file storage/premium from their organization get the default - // 1 GB max storage. - storageBytesRemaining = user.StorageBytesRemaining( - _globalSettings.SelfHosted ? (short)10240 : (short)1); + send.RevisionDate = DateTime.UtcNow; + await _sendRepository.UpsertAsync(send); + await _pushService.PushSyncSendUpdateAsync(send); } } - else if (send.OrganizationId.HasValue) + + public async Task SaveFileSendAsync(Send send, SendFileData data, long fileLength) { - var org = await _organizationRepository.GetByIdAsync(send.OrganizationId.Value); - if (!org.MaxStorageGb.HasValue) + if (send.Type != SendType.File) { - throw new BadRequestException("This organization cannot use file sends."); + throw new BadRequestException("Send is not of type \"file\"."); } - storageBytesRemaining = org.StorageBytesRemaining(); + if (fileLength < 1) + { + throw new BadRequestException("No file data."); + } + + var storageBytesRemaining = await StorageRemainingForSendAsync(send); + + if (storageBytesRemaining < fileLength) + { + throw new BadRequestException("Not enough storage available."); + } + + var fileId = Utilities.CoreHelpers.SecureRandomString(32, upper: false, special: false); + + try + { + data.Id = fileId; + data.Size = fileLength; + data.Validated = false; + send.Data = JsonSerializer.Serialize(data, + JsonHelpers.IgnoreWritingNull); + await SaveSendAsync(send); + return await _sendFileStorageService.GetSendFileUploadUrlAsync(send, fileId); + } + catch + { + // Clean up since this is not transactional + await _sendFileStorageService.DeleteFileAsync(send, fileId); + throw; + } } - return storageBytesRemaining; + public async Task UploadFileToExistingSendAsync(Stream stream, Send send) + { + if (send?.Data == null) + { + throw new BadRequestException("Send does not have file data"); + } + + if (send.Type != SendType.File) + { + throw new BadRequestException("Not a File Type Send."); + } + + var data = JsonSerializer.Deserialize(send.Data); + + if (data.Validated) + { + throw new BadRequestException("File has already been uploaded."); + } + + await _sendFileStorageService.UploadNewFileAsync(stream, send, data.Id); + + if (!await ValidateSendFile(send)) + { + throw new BadRequestException("File received does not match expected file length."); + } + } + + public async Task ValidateSendFile(Send send) + { + var fileData = JsonSerializer.Deserialize(send.Data); + + var (valid, realSize) = await _sendFileStorageService.ValidateFileAsync(send, fileData.Id, fileData.Size, _fileSizeLeeway); + + if (!valid || realSize > MAX_FILE_SIZE) + { + // File reported differs in size from that promised. Must be a rogue client. Delete Send + await DeleteSendAsync(send); + return false; + } + + // Update Send data if necessary + if (realSize != fileData.Size) + { + fileData.Size = realSize.Value; + } + fileData.Validated = true; + send.Data = JsonSerializer.Serialize(fileData, + JsonHelpers.IgnoreWritingNull); + await SaveSendAsync(send); + + return valid; + } + + public async Task DeleteSendAsync(Send send) + { + await _sendRepository.DeleteAsync(send); + if (send.Type == Enums.SendType.File) + { + var data = JsonSerializer.Deserialize(send.Data); + await _sendFileStorageService.DeleteFileAsync(send, data.Id); + } + await _pushService.PushSyncSendDeleteAsync(send); + } + + public (bool grant, bool passwordRequiredError, bool passwordInvalidError) SendCanBeAccessed(Send send, + string password) + { + var now = DateTime.UtcNow; + if (send == null || send.MaxAccessCount.GetValueOrDefault(int.MaxValue) <= send.AccessCount || + send.ExpirationDate.GetValueOrDefault(DateTime.MaxValue) < now || send.Disabled || + send.DeletionDate < now) + { + return (false, false, false); + } + if (!string.IsNullOrWhiteSpace(send.Password)) + { + if (string.IsNullOrWhiteSpace(password)) + { + return (false, true, false); + } + var passwordResult = _passwordHasher.VerifyHashedPassword(new User(), send.Password, password); + if (passwordResult == PasswordVerificationResult.SuccessRehashNeeded) + { + send.Password = HashPassword(password); + } + if (passwordResult == PasswordVerificationResult.Failed) + { + return (false, false, true); + } + } + + return (true, false, false); + } + + // Response: Send, password required, password invalid + public async Task<(string, bool, bool)> GetSendFileDownloadUrlAsync(Send send, string fileId, string password) + { + if (send.Type != SendType.File) + { + throw new BadRequestException("Can only get a download URL for a file type of Send"); + } + + var (grantAccess, passwordRequired, passwordInvalid) = SendCanBeAccessed(send, password); + + if (!grantAccess) + { + return (null, passwordRequired, passwordInvalid); + } + + send.AccessCount++; + await _sendRepository.ReplaceAsync(send); + await _pushService.PushSyncSendUpdateAsync(send); + return (await _sendFileStorageService.GetSendFileDownloadUrlAsync(send, fileId), false, false); + } + + // Response: Send, password required, password invalid + public async Task<(Send, bool, bool)> AccessAsync(Guid sendId, string password) + { + var send = await _sendRepository.GetByIdAsync(sendId); + var (grantAccess, passwordRequired, passwordInvalid) = SendCanBeAccessed(send, password); + + if (!grantAccess) + { + return (null, passwordRequired, passwordInvalid); + } + + // TODO: maybe move this to a simple ++ sproc? + if (send.Type != SendType.File) + { + // File sends are incremented during file download + send.AccessCount++; + } + + await _sendRepository.ReplaceAsync(send); + await _pushService.PushSyncSendUpdateAsync(send); + await RaiseReferenceEventAsync(send, ReferenceEventType.SendAccessed); + return (send, false, false); + } + + private async Task RaiseReferenceEventAsync(Send send, ReferenceEventType eventType) + { + await _referenceEventService.RaiseEventAsync(new ReferenceEvent + { + Id = send.UserId ?? default, + Type = eventType, + Source = ReferenceEventSource.User, + SendType = send.Type, + MaxAccessCount = send.MaxAccessCount, + HasPassword = !string.IsNullOrWhiteSpace(send.Password), + }); + } + + public string HashPassword(string password) + { + return _passwordHasher.HashPassword(new User(), password); + } + + private async Task ValidateUserCanSaveAsync(Guid? userId, Send send) + { + if (!userId.HasValue || (!_currentContext.Organizations?.Any() ?? true)) + { + return; + } + + var disableSendPolicyCount = await _policyRepository.GetCountByTypeApplicableToUserIdAsync(userId.Value, + PolicyType.DisableSend); + if (disableSendPolicyCount > 0) + { + throw new BadRequestException("Due to an Enterprise Policy, you are only able to delete an existing Send."); + } + + if (send.HideEmail.GetValueOrDefault()) + { + var sendOptionsPolicies = await _policyRepository.GetManyByTypeApplicableToUserIdAsync(userId.Value, PolicyType.SendOptions); + if (sendOptionsPolicies.Any(p => p.GetDataModel()?.DisableHideEmail ?? false)) + { + throw new BadRequestException("Due to an Enterprise Policy, you are not allowed to hide your email address from recipients when creating or editing a Send."); + } + } + } + + private async Task StorageRemainingForSendAsync(Send send) + { + var storageBytesRemaining = 0L; + if (send.UserId.HasValue) + { + var user = await _userRepository.GetByIdAsync(send.UserId.Value); + if (!await _userService.CanAccessPremium(user)) + { + throw new BadRequestException("You must have premium status to use file Sends."); + } + + if (!user.EmailVerified) + { + throw new BadRequestException("You must confirm your email to use file Sends."); + } + + if (user.Premium) + { + storageBytesRemaining = user.StorageBytesRemaining(); + } + else + { + // Users that get access to file storage/premium from their organization get the default + // 1 GB max storage. + storageBytesRemaining = user.StorageBytesRemaining( + _globalSettings.SelfHosted ? (short)10240 : (short)1); + } + } + else if (send.OrganizationId.HasValue) + { + var org = await _organizationRepository.GetByIdAsync(send.OrganizationId.Value); + if (!org.MaxStorageGb.HasValue) + { + throw new BadRequestException("This organization cannot use file sends."); + } + + storageBytesRemaining = org.StorageBytesRemaining(); + } + + return storageBytesRemaining; + } } } diff --git a/src/Core/Services/Implementations/SsoConfigService.cs b/src/Core/Services/Implementations/SsoConfigService.cs index 5f44cb9310..4af794967b 100644 --- a/src/Core/Services/Implementations/SsoConfigService.cs +++ b/src/Core/Services/Implementations/SsoConfigService.cs @@ -3,104 +3,105 @@ using Bit.Core.Enums; using Bit.Core.Exceptions; using Bit.Core.Repositories; -namespace Bit.Core.Services; - -public class SsoConfigService : ISsoConfigService +namespace Bit.Core.Services { - private readonly ISsoConfigRepository _ssoConfigRepository; - private readonly IPolicyRepository _policyRepository; - private readonly IOrganizationRepository _organizationRepository; - private readonly IOrganizationUserRepository _organizationUserRepository; - private readonly IEventService _eventService; - - public SsoConfigService( - ISsoConfigRepository ssoConfigRepository, - IPolicyRepository policyRepository, - IOrganizationRepository organizationRepository, - IOrganizationUserRepository organizationUserRepository, - IEventService eventService) + public class SsoConfigService : ISsoConfigService { - _ssoConfigRepository = ssoConfigRepository; - _policyRepository = policyRepository; - _organizationRepository = organizationRepository; - _organizationUserRepository = organizationUserRepository; - _eventService = eventService; - } + private readonly ISsoConfigRepository _ssoConfigRepository; + private readonly IPolicyRepository _policyRepository; + private readonly IOrganizationRepository _organizationRepository; + private readonly IOrganizationUserRepository _organizationUserRepository; + private readonly IEventService _eventService; - public async Task SaveAsync(SsoConfig config, Organization organization) - { - var now = DateTime.UtcNow; - config.RevisionDate = now; - if (config.Id == default) + public SsoConfigService( + ISsoConfigRepository ssoConfigRepository, + IPolicyRepository policyRepository, + IOrganizationRepository organizationRepository, + IOrganizationUserRepository organizationUserRepository, + IEventService eventService) { - config.CreationDate = now; + _ssoConfigRepository = ssoConfigRepository; + _policyRepository = policyRepository; + _organizationRepository = organizationRepository; + _organizationUserRepository = organizationUserRepository; + _eventService = eventService; } - var useKeyConnector = config.GetData().KeyConnectorEnabled; - if (useKeyConnector) + public async Task SaveAsync(SsoConfig config, Organization organization) { - await VerifyDependenciesAsync(config, organization); + var now = DateTime.UtcNow; + config.RevisionDate = now; + if (config.Id == default) + { + config.CreationDate = now; + } + + var useKeyConnector = config.GetData().KeyConnectorEnabled; + if (useKeyConnector) + { + await VerifyDependenciesAsync(config, organization); + } + + var oldConfig = await _ssoConfigRepository.GetByOrganizationIdAsync(config.OrganizationId); + var disabledKeyConnector = oldConfig?.GetData()?.KeyConnectorEnabled == true && !useKeyConnector; + if (disabledKeyConnector && await AnyOrgUserHasKeyConnectorEnabledAsync(config.OrganizationId)) + { + throw new BadRequestException("Key Connector cannot be disabled at this moment."); + } + + await LogEventsAsync(config, oldConfig); + await _ssoConfigRepository.UpsertAsync(config); } - var oldConfig = await _ssoConfigRepository.GetByOrganizationIdAsync(config.OrganizationId); - var disabledKeyConnector = oldConfig?.GetData()?.KeyConnectorEnabled == true && !useKeyConnector; - if (disabledKeyConnector && await AnyOrgUserHasKeyConnectorEnabledAsync(config.OrganizationId)) + private async Task AnyOrgUserHasKeyConnectorEnabledAsync(Guid organizationId) { - throw new BadRequestException("Key Connector cannot be disabled at this moment."); + var userDetails = + await _organizationUserRepository.GetManyDetailsByOrganizationAsync(organizationId); + return userDetails.Any(u => u.UsesKeyConnector); } - await LogEventsAsync(config, oldConfig); - await _ssoConfigRepository.UpsertAsync(config); - } - - private async Task AnyOrgUserHasKeyConnectorEnabledAsync(Guid organizationId) - { - var userDetails = - await _organizationUserRepository.GetManyDetailsByOrganizationAsync(organizationId); - return userDetails.Any(u => u.UsesKeyConnector); - } - - private async Task VerifyDependenciesAsync(SsoConfig config, Organization organization) - { - if (!organization.UseKeyConnector) + private async Task VerifyDependenciesAsync(SsoConfig config, Organization organization) { - throw new BadRequestException("Organization cannot use Key Connector."); + if (!organization.UseKeyConnector) + { + throw new BadRequestException("Organization cannot use Key Connector."); + } + + var singleOrgPolicy = await _policyRepository.GetByOrganizationIdTypeAsync(config.OrganizationId, PolicyType.SingleOrg); + if (singleOrgPolicy is not { Enabled: true }) + { + throw new BadRequestException("Key Connector requires the Single Organization policy to be enabled."); + } + + var ssoPolicy = await _policyRepository.GetByOrganizationIdTypeAsync(config.OrganizationId, PolicyType.RequireSso); + if (ssoPolicy is not { Enabled: true }) + { + throw new BadRequestException("Key Connector requires the Single Sign-On Authentication policy to be enabled."); + } + + if (!config.Enabled) + { + throw new BadRequestException("You must enable SSO to use Key Connector."); + } } - var singleOrgPolicy = await _policyRepository.GetByOrganizationIdTypeAsync(config.OrganizationId, PolicyType.SingleOrg); - if (singleOrgPolicy is not { Enabled: true }) + private async Task LogEventsAsync(SsoConfig config, SsoConfig oldConfig) { - throw new BadRequestException("Key Connector requires the Single Organization policy to be enabled."); - } + var organization = await _organizationRepository.GetByIdAsync(config.OrganizationId); + if (oldConfig?.Enabled != config.Enabled) + { + var e = config.Enabled ? EventType.Organization_EnabledSso : EventType.Organization_DisabledSso; + await _eventService.LogOrganizationEventAsync(organization, e); + } - var ssoPolicy = await _policyRepository.GetByOrganizationIdTypeAsync(config.OrganizationId, PolicyType.RequireSso); - if (ssoPolicy is not { Enabled: true }) - { - throw new BadRequestException("Key Connector requires the Single Sign-On Authentication policy to be enabled."); - } - - if (!config.Enabled) - { - throw new BadRequestException("You must enable SSO to use Key Connector."); - } - } - - private async Task LogEventsAsync(SsoConfig config, SsoConfig oldConfig) - { - var organization = await _organizationRepository.GetByIdAsync(config.OrganizationId); - if (oldConfig?.Enabled != config.Enabled) - { - var e = config.Enabled ? EventType.Organization_EnabledSso : EventType.Organization_DisabledSso; - await _eventService.LogOrganizationEventAsync(organization, e); - } - - var keyConnectorEnabled = config.GetData().KeyConnectorEnabled; - if (oldConfig?.GetData()?.KeyConnectorEnabled != keyConnectorEnabled) - { - var e = keyConnectorEnabled - ? EventType.Organization_EnabledKeyConnector - : EventType.Organization_DisabledKeyConnector; - await _eventService.LogOrganizationEventAsync(organization, e); + var keyConnectorEnabled = config.GetData().KeyConnectorEnabled; + if (oldConfig?.GetData()?.KeyConnectorEnabled != keyConnectorEnabled) + { + var e = keyConnectorEnabled + ? EventType.Organization_EnabledKeyConnector + : EventType.Organization_DisabledKeyConnector; + await _eventService.LogOrganizationEventAsync(organization, e); + } } } } diff --git a/src/Core/Services/Implementations/StripeAdapter.cs b/src/Core/Services/Implementations/StripeAdapter.cs index b4776bc6ef..eb467dd573 100644 --- a/src/Core/Services/Implementations/StripeAdapter.cs +++ b/src/Core/Services/Implementations/StripeAdapter.cs @@ -1,217 +1,218 @@ using Bit.Core.Models.BitStripe; -namespace Bit.Core.Services; - -public class StripeAdapter : IStripeAdapter +namespace Bit.Core.Services { - private readonly Stripe.CustomerService _customerService; - private readonly Stripe.SubscriptionService _subscriptionService; - private readonly Stripe.InvoiceService _invoiceService; - private readonly Stripe.PaymentMethodService _paymentMethodService; - private readonly Stripe.TaxRateService _taxRateService; - private readonly Stripe.TaxIdService _taxIdService; - private readonly Stripe.ChargeService _chargeService; - private readonly Stripe.RefundService _refundService; - private readonly Stripe.CardService _cardService; - private readonly Stripe.BankAccountService _bankAccountService; - private readonly Stripe.PriceService _priceService; - private readonly Stripe.TestHelpers.TestClockService _testClockService; - - public StripeAdapter() + public class StripeAdapter : IStripeAdapter { - _customerService = new Stripe.CustomerService(); - _subscriptionService = new Stripe.SubscriptionService(); - _invoiceService = new Stripe.InvoiceService(); - _paymentMethodService = new Stripe.PaymentMethodService(); - _taxRateService = new Stripe.TaxRateService(); - _taxIdService = new Stripe.TaxIdService(); - _chargeService = new Stripe.ChargeService(); - _refundService = new Stripe.RefundService(); - _cardService = new Stripe.CardService(); - _bankAccountService = new Stripe.BankAccountService(); - _priceService = new Stripe.PriceService(); - _testClockService = new Stripe.TestHelpers.TestClockService(); - } + private readonly Stripe.CustomerService _customerService; + private readonly Stripe.SubscriptionService _subscriptionService; + private readonly Stripe.InvoiceService _invoiceService; + private readonly Stripe.PaymentMethodService _paymentMethodService; + private readonly Stripe.TaxRateService _taxRateService; + private readonly Stripe.TaxIdService _taxIdService; + private readonly Stripe.ChargeService _chargeService; + private readonly Stripe.RefundService _refundService; + private readonly Stripe.CardService _cardService; + private readonly Stripe.BankAccountService _bankAccountService; + private readonly Stripe.PriceService _priceService; + private readonly Stripe.TestHelpers.TestClockService _testClockService; - public Task CustomerCreateAsync(Stripe.CustomerCreateOptions options) - { - return _customerService.CreateAsync(options); - } - - public Task CustomerGetAsync(string id, Stripe.CustomerGetOptions options = null) - { - return _customerService.GetAsync(id, options); - } - - public Task CustomerUpdateAsync(string id, Stripe.CustomerUpdateOptions options = null) - { - return _customerService.UpdateAsync(id, options); - } - - public Task CustomerDeleteAsync(string id) - { - return _customerService.DeleteAsync(id); - } - - public Task SubscriptionCreateAsync(Stripe.SubscriptionCreateOptions options) - { - return _subscriptionService.CreateAsync(options); - } - - public Task SubscriptionGetAsync(string id, Stripe.SubscriptionGetOptions options = null) - { - return _subscriptionService.GetAsync(id, options); - } - - public Task SubscriptionUpdateAsync(string id, - Stripe.SubscriptionUpdateOptions options = null) - { - return _subscriptionService.UpdateAsync(id, options); - } - - public Task SubscriptionCancelAsync(string Id, Stripe.SubscriptionCancelOptions options = null) - { - return _subscriptionService.CancelAsync(Id, options); - } - - public Task InvoiceUpcomingAsync(Stripe.UpcomingInvoiceOptions options) - { - return _invoiceService.UpcomingAsync(options); - } - - public Task InvoiceGetAsync(string id, Stripe.InvoiceGetOptions options) - { - return _invoiceService.GetAsync(id, options); - } - - public Task> InvoiceListAsync(Stripe.InvoiceListOptions options) - { - return _invoiceService.ListAsync(options); - } - - public Task InvoiceUpdateAsync(string id, Stripe.InvoiceUpdateOptions options) - { - return _invoiceService.UpdateAsync(id, options); - } - - public Task InvoiceFinalizeInvoiceAsync(string id, Stripe.InvoiceFinalizeOptions options) - { - return _invoiceService.FinalizeInvoiceAsync(id, options); - } - - public Task InvoiceSendInvoiceAsync(string id, Stripe.InvoiceSendOptions options) - { - return _invoiceService.SendInvoiceAsync(id, options); - } - - public Task InvoicePayAsync(string id, Stripe.InvoicePayOptions options = null) - { - return _invoiceService.PayAsync(id, options); - } - - public Task InvoiceDeleteAsync(string id, Stripe.InvoiceDeleteOptions options = null) - { - return _invoiceService.DeleteAsync(id, options); - } - - public Task InvoiceVoidInvoiceAsync(string id, Stripe.InvoiceVoidOptions options = null) - { - return _invoiceService.VoidInvoiceAsync(id, options); - } - - public IEnumerable PaymentMethodListAutoPaging(Stripe.PaymentMethodListOptions options) - { - return _paymentMethodService.ListAutoPaging(options); - } - - public Task PaymentMethodAttachAsync(string id, Stripe.PaymentMethodAttachOptions options = null) - { - return _paymentMethodService.AttachAsync(id, options); - } - - public Task PaymentMethodDetachAsync(string id, Stripe.PaymentMethodDetachOptions options = null) - { - return _paymentMethodService.DetachAsync(id, options); - } - - public Task TaxRateCreateAsync(Stripe.TaxRateCreateOptions options) - { - return _taxRateService.CreateAsync(options); - } - - public Task TaxRateUpdateAsync(string id, Stripe.TaxRateUpdateOptions options) - { - return _taxRateService.UpdateAsync(id, options); - } - - public Task TaxIdCreateAsync(string id, Stripe.TaxIdCreateOptions options) - { - return _taxIdService.CreateAsync(id, options); - } - - public Task TaxIdDeleteAsync(string customerId, string taxIdId, - Stripe.TaxIdDeleteOptions options = null) - { - return _taxIdService.DeleteAsync(customerId, taxIdId); - } - - public Task> ChargeListAsync(Stripe.ChargeListOptions options) - { - return _chargeService.ListAsync(options); - } - - public Task RefundCreateAsync(Stripe.RefundCreateOptions options) - { - return _refundService.CreateAsync(options); - } - - public Task CardDeleteAsync(string customerId, string cardId, Stripe.CardDeleteOptions options = null) - { - return _cardService.DeleteAsync(customerId, cardId, options); - } - - public Task BankAccountCreateAsync(string customerId, Stripe.BankAccountCreateOptions options = null) - { - return _bankAccountService.CreateAsync(customerId, options); - } - - public Task BankAccountDeleteAsync(string customerId, string bankAccount, Stripe.BankAccountDeleteOptions options = null) - { - return _bankAccountService.DeleteAsync(customerId, bankAccount, options); - } - - public async Task> SubscriptionListAsync(StripeSubscriptionListOptions options) - { - if (!options.SelectAll) + public StripeAdapter() { - return (await _subscriptionService.ListAsync(options.ToStripeApiOptions())).Data; + _customerService = new Stripe.CustomerService(); + _subscriptionService = new Stripe.SubscriptionService(); + _invoiceService = new Stripe.InvoiceService(); + _paymentMethodService = new Stripe.PaymentMethodService(); + _taxRateService = new Stripe.TaxRateService(); + _taxIdService = new Stripe.TaxIdService(); + _chargeService = new Stripe.ChargeService(); + _refundService = new Stripe.RefundService(); + _cardService = new Stripe.CardService(); + _bankAccountService = new Stripe.BankAccountService(); + _priceService = new Stripe.PriceService(); + _testClockService = new Stripe.TestHelpers.TestClockService(); } - options.Limit = 100; - var items = new List(); - await foreach (var i in _subscriptionService.ListAutoPagingAsync(options.ToStripeApiOptions())) + public Task CustomerCreateAsync(Stripe.CustomerCreateOptions options) { - items.Add(i); + return _customerService.CreateAsync(options); } - return items; - } - public async Task> PriceListAsync(Stripe.PriceListOptions options = null) - { - return await _priceService.ListAsync(options); - } - - public async Task> TestClockListAsync() - { - var items = new List(); - var options = new Stripe.TestHelpers.TestClockListOptions() + public Task CustomerGetAsync(string id, Stripe.CustomerGetOptions options = null) { - Limit = 100 - }; - await foreach (var i in _testClockService.ListAutoPagingAsync(options)) - { - items.Add(i); + return _customerService.GetAsync(id, options); + } + + public Task CustomerUpdateAsync(string id, Stripe.CustomerUpdateOptions options = null) + { + return _customerService.UpdateAsync(id, options); + } + + public Task CustomerDeleteAsync(string id) + { + return _customerService.DeleteAsync(id); + } + + public Task SubscriptionCreateAsync(Stripe.SubscriptionCreateOptions options) + { + return _subscriptionService.CreateAsync(options); + } + + public Task SubscriptionGetAsync(string id, Stripe.SubscriptionGetOptions options = null) + { + return _subscriptionService.GetAsync(id, options); + } + + public Task SubscriptionUpdateAsync(string id, + Stripe.SubscriptionUpdateOptions options = null) + { + return _subscriptionService.UpdateAsync(id, options); + } + + public Task SubscriptionCancelAsync(string Id, Stripe.SubscriptionCancelOptions options = null) + { + return _subscriptionService.CancelAsync(Id, options); + } + + public Task InvoiceUpcomingAsync(Stripe.UpcomingInvoiceOptions options) + { + return _invoiceService.UpcomingAsync(options); + } + + public Task InvoiceGetAsync(string id, Stripe.InvoiceGetOptions options) + { + return _invoiceService.GetAsync(id, options); + } + + public Task> InvoiceListAsync(Stripe.InvoiceListOptions options) + { + return _invoiceService.ListAsync(options); + } + + public Task InvoiceUpdateAsync(string id, Stripe.InvoiceUpdateOptions options) + { + return _invoiceService.UpdateAsync(id, options); + } + + public Task InvoiceFinalizeInvoiceAsync(string id, Stripe.InvoiceFinalizeOptions options) + { + return _invoiceService.FinalizeInvoiceAsync(id, options); + } + + public Task InvoiceSendInvoiceAsync(string id, Stripe.InvoiceSendOptions options) + { + return _invoiceService.SendInvoiceAsync(id, options); + } + + public Task InvoicePayAsync(string id, Stripe.InvoicePayOptions options = null) + { + return _invoiceService.PayAsync(id, options); + } + + public Task InvoiceDeleteAsync(string id, Stripe.InvoiceDeleteOptions options = null) + { + return _invoiceService.DeleteAsync(id, options); + } + + public Task InvoiceVoidInvoiceAsync(string id, Stripe.InvoiceVoidOptions options = null) + { + return _invoiceService.VoidInvoiceAsync(id, options); + } + + public IEnumerable PaymentMethodListAutoPaging(Stripe.PaymentMethodListOptions options) + { + return _paymentMethodService.ListAutoPaging(options); + } + + public Task PaymentMethodAttachAsync(string id, Stripe.PaymentMethodAttachOptions options = null) + { + return _paymentMethodService.AttachAsync(id, options); + } + + public Task PaymentMethodDetachAsync(string id, Stripe.PaymentMethodDetachOptions options = null) + { + return _paymentMethodService.DetachAsync(id, options); + } + + public Task TaxRateCreateAsync(Stripe.TaxRateCreateOptions options) + { + return _taxRateService.CreateAsync(options); + } + + public Task TaxRateUpdateAsync(string id, Stripe.TaxRateUpdateOptions options) + { + return _taxRateService.UpdateAsync(id, options); + } + + public Task TaxIdCreateAsync(string id, Stripe.TaxIdCreateOptions options) + { + return _taxIdService.CreateAsync(id, options); + } + + public Task TaxIdDeleteAsync(string customerId, string taxIdId, + Stripe.TaxIdDeleteOptions options = null) + { + return _taxIdService.DeleteAsync(customerId, taxIdId); + } + + public Task> ChargeListAsync(Stripe.ChargeListOptions options) + { + return _chargeService.ListAsync(options); + } + + public Task RefundCreateAsync(Stripe.RefundCreateOptions options) + { + return _refundService.CreateAsync(options); + } + + public Task CardDeleteAsync(string customerId, string cardId, Stripe.CardDeleteOptions options = null) + { + return _cardService.DeleteAsync(customerId, cardId, options); + } + + public Task BankAccountCreateAsync(string customerId, Stripe.BankAccountCreateOptions options = null) + { + return _bankAccountService.CreateAsync(customerId, options); + } + + public Task BankAccountDeleteAsync(string customerId, string bankAccount, Stripe.BankAccountDeleteOptions options = null) + { + return _bankAccountService.DeleteAsync(customerId, bankAccount, options); + } + + public async Task> SubscriptionListAsync(StripeSubscriptionListOptions options) + { + if (!options.SelectAll) + { + return (await _subscriptionService.ListAsync(options.ToStripeApiOptions())).Data; + } + + options.Limit = 100; + var items = new List(); + await foreach (var i in _subscriptionService.ListAutoPagingAsync(options.ToStripeApiOptions())) + { + items.Add(i); + } + return items; + } + + public async Task> PriceListAsync(Stripe.PriceListOptions options = null) + { + return await _priceService.ListAsync(options); + } + + public async Task> TestClockListAsync() + { + var items = new List(); + var options = new Stripe.TestHelpers.TestClockListOptions() + { + Limit = 100 + }; + await foreach (var i in _testClockService.ListAutoPagingAsync(options)) + { + items.Add(i); + } + return items; } - return items; } } diff --git a/src/Core/Services/Implementations/StripePaymentService.cs b/src/Core/Services/Implementations/StripePaymentService.cs index 25561db4be..e3ed163687 100644 --- a/src/Core/Services/Implementations/StripePaymentService.cs +++ b/src/Core/Services/Implementations/StripePaymentService.cs @@ -8,388 +8,74 @@ using Microsoft.Extensions.Logging; using StaticStore = Bit.Core.Models.StaticStore; using TaxRate = Bit.Core.Entities.TaxRate; -namespace Bit.Core.Services; - -public class StripePaymentService : IPaymentService +namespace Bit.Core.Services { - private const string PremiumPlanId = "premium-annually"; - private const string PremiumPlanAppleIapId = "premium-annually-appleiap"; - private const decimal PremiumPlanAppleIapPrice = 14.99M; - private const string StoragePlanId = "storage-gb-annually"; - - private readonly ITransactionRepository _transactionRepository; - private readonly IUserRepository _userRepository; - private readonly IAppleIapService _appleIapService; - private readonly ILogger _logger; - private readonly Braintree.IBraintreeGateway _btGateway; - private readonly ITaxRateRepository _taxRateRepository; - private readonly IStripeAdapter _stripeAdapter; - - public StripePaymentService( - ITransactionRepository transactionRepository, - IUserRepository userRepository, - IAppleIapService appleIapService, - ILogger logger, - ITaxRateRepository taxRateRepository, - IStripeAdapter stripeAdapter, - Braintree.IBraintreeGateway braintreeGateway) + public class StripePaymentService : IPaymentService { - _transactionRepository = transactionRepository; - _userRepository = userRepository; - _appleIapService = appleIapService; - _logger = logger; - _taxRateRepository = taxRateRepository; - _stripeAdapter = stripeAdapter; - _btGateway = braintreeGateway; - } + private const string PremiumPlanId = "premium-annually"; + private const string PremiumPlanAppleIapId = "premium-annually-appleiap"; + private const decimal PremiumPlanAppleIapPrice = 14.99M; + private const string StoragePlanId = "storage-gb-annually"; - public async Task PurchaseOrganizationAsync(Organization org, PaymentMethodType paymentMethodType, - string paymentToken, StaticStore.Plan plan, short additionalStorageGb, - int additionalSeats, bool premiumAccessAddon, TaxInfo taxInfo) - { - Braintree.Customer braintreeCustomer = null; - string stipeCustomerSourceToken = null; - string stipeCustomerPaymentMethodId = null; - var stripeCustomerMetadata = new Dictionary(); - var stripePaymentMethod = paymentMethodType == PaymentMethodType.Card || - paymentMethodType == PaymentMethodType.BankAccount; + private readonly ITransactionRepository _transactionRepository; + private readonly IUserRepository _userRepository; + private readonly IAppleIapService _appleIapService; + private readonly ILogger _logger; + private readonly Braintree.IBraintreeGateway _btGateway; + private readonly ITaxRateRepository _taxRateRepository; + private readonly IStripeAdapter _stripeAdapter; - if (stripePaymentMethod && !string.IsNullOrWhiteSpace(paymentToken)) + public StripePaymentService( + ITransactionRepository transactionRepository, + IUserRepository userRepository, + IAppleIapService appleIapService, + ILogger logger, + ITaxRateRepository taxRateRepository, + IStripeAdapter stripeAdapter, + Braintree.IBraintreeGateway braintreeGateway) { - if (paymentToken.StartsWith("pm_")) - { - stipeCustomerPaymentMethodId = paymentToken; - } - else - { - stipeCustomerSourceToken = paymentToken; - } - } - else if (paymentMethodType == PaymentMethodType.PayPal) - { - var randomSuffix = Utilities.CoreHelpers.RandomString(3, upper: false, numeric: false); - var customerResult = await _btGateway.Customer.CreateAsync(new Braintree.CustomerRequest - { - PaymentMethodNonce = paymentToken, - Email = org.BillingEmail, - Id = org.BraintreeCustomerIdPrefix() + org.Id.ToString("N").ToLower() + randomSuffix, - CustomFields = new Dictionary - { - [org.BraintreeIdField()] = org.Id.ToString() - } - }); - - if (!customerResult.IsSuccess() || customerResult.Target.PaymentMethods.Length == 0) - { - throw new GatewayException("Failed to create PayPal customer record."); - } - - braintreeCustomer = customerResult.Target; - stripeCustomerMetadata.Add("btCustomerId", braintreeCustomer.Id); - } - else - { - throw new GatewayException("Payment method is not supported at this time."); + _transactionRepository = transactionRepository; + _userRepository = userRepository; + _appleIapService = appleIapService; + _logger = logger; + _taxRateRepository = taxRateRepository; + _stripeAdapter = stripeAdapter; + _btGateway = braintreeGateway; } - if (taxInfo != null && !string.IsNullOrWhiteSpace(taxInfo.BillingAddressCountry) && !string.IsNullOrWhiteSpace(taxInfo.BillingAddressPostalCode)) - { - var taxRateSearch = new TaxRate - { - Country = taxInfo.BillingAddressCountry, - PostalCode = taxInfo.BillingAddressPostalCode - }; - var taxRates = await _taxRateRepository.GetByLocationAsync(taxRateSearch); - - // should only be one tax rate per country/zip combo - var taxRate = taxRates.FirstOrDefault(); - if (taxRate != null) - { - taxInfo.StripeTaxRateId = taxRate.Id; - } - } - - var subCreateOptions = new OrganizationPurchaseSubscriptionOptions(org, plan, taxInfo, additionalSeats, additionalStorageGb, premiumAccessAddon); - - Stripe.Customer customer = null; - Stripe.Subscription subscription; - try - { - customer = await _stripeAdapter.CustomerCreateAsync(new Stripe.CustomerCreateOptions - { - Description = org.BusinessName, - Email = org.BillingEmail, - Source = stipeCustomerSourceToken, - PaymentMethod = stipeCustomerPaymentMethodId, - Metadata = stripeCustomerMetadata, - InvoiceSettings = new Stripe.CustomerInvoiceSettingsOptions - { - DefaultPaymentMethod = stipeCustomerPaymentMethodId - }, - Address = new Stripe.AddressOptions - { - Country = taxInfo.BillingAddressCountry, - PostalCode = taxInfo.BillingAddressPostalCode, - // Line1 is required in Stripe's API, suggestion in Docs is to use Business Name intead. - Line1 = taxInfo.BillingAddressLine1 ?? string.Empty, - Line2 = taxInfo.BillingAddressLine2, - City = taxInfo.BillingAddressCity, - State = taxInfo.BillingAddressState, - }, - TaxIdData = !taxInfo.HasTaxId ? null : new List - { - new Stripe.CustomerTaxIdDataOptions - { - Type = taxInfo.TaxIdType, - Value = taxInfo.TaxIdNumber, - }, - }, - }); - subCreateOptions.AddExpand("latest_invoice.payment_intent"); - subCreateOptions.Customer = customer.Id; - subscription = await _stripeAdapter.SubscriptionCreateAsync(subCreateOptions); - if (subscription.Status == "incomplete" && subscription.LatestInvoice?.PaymentIntent != null) - { - if (subscription.LatestInvoice.PaymentIntent.Status == "requires_payment_method") - { - await _stripeAdapter.SubscriptionCancelAsync(subscription.Id, new Stripe.SubscriptionCancelOptions()); - throw new GatewayException("Payment method was declined."); - } - } - } - catch (Exception ex) - { - _logger.LogError(ex, "Error creating customer, walking back operation."); - if (customer != null) - { - await _stripeAdapter.CustomerDeleteAsync(customer.Id); - } - if (braintreeCustomer != null) - { - await _btGateway.Customer.DeleteAsync(braintreeCustomer.Id); - } - throw; - } - - org.Gateway = GatewayType.Stripe; - org.GatewayCustomerId = customer.Id; - org.GatewaySubscriptionId = subscription.Id; - - if (subscription.Status == "incomplete" && - subscription.LatestInvoice?.PaymentIntent?.Status == "requires_action") - { - org.Enabled = false; - return subscription.LatestInvoice.PaymentIntent.ClientSecret; - } - else - { - org.Enabled = true; - org.ExpirationDate = subscription.CurrentPeriodEnd; - return null; - } - } - - private async Task ChangeOrganizationSponsorship(Organization org, OrganizationSponsorship sponsorship, bool applySponsorship) - { - var existingPlan = Utilities.StaticStore.GetPlan(org.PlanType); - var sponsoredPlan = sponsorship != null ? - Utilities.StaticStore.GetSponsoredPlan(sponsorship.PlanSponsorshipType.Value) : - null; - var subscriptionUpdate = new SponsorOrganizationSubscriptionUpdate(existingPlan, sponsoredPlan, applySponsorship); - - await FinalizeSubscriptionChangeAsync(org, subscriptionUpdate, DateTime.UtcNow); - - var sub = await _stripeAdapter.SubscriptionGetAsync(org.GatewaySubscriptionId); - org.ExpirationDate = sub.CurrentPeriodEnd; - sponsorship.ValidUntil = sub.CurrentPeriodEnd; - - } - - public Task SponsorOrganizationAsync(Organization org, OrganizationSponsorship sponsorship) => - ChangeOrganizationSponsorship(org, sponsorship, true); - - public Task RemoveOrganizationSponsorshipAsync(Organization org, OrganizationSponsorship sponsorship) => - ChangeOrganizationSponsorship(org, sponsorship, false); - - public async Task UpgradeFreeOrganizationAsync(Organization org, StaticStore.Plan plan, - short additionalStorageGb, int additionalSeats, bool premiumAccessAddon, TaxInfo taxInfo) - { - if (!string.IsNullOrWhiteSpace(org.GatewaySubscriptionId)) - { - throw new BadRequestException("Organization already has a subscription."); - } - - var customerOptions = new Stripe.CustomerGetOptions(); - customerOptions.AddExpand("default_source"); - customerOptions.AddExpand("invoice_settings.default_payment_method"); - var customer = await _stripeAdapter.CustomerGetAsync(org.GatewayCustomerId, customerOptions); - if (customer == null) - { - throw new GatewayException("Could not find customer payment profile."); - } - - if (taxInfo != null && !string.IsNullOrWhiteSpace(taxInfo.BillingAddressCountry) && !string.IsNullOrWhiteSpace(taxInfo.BillingAddressPostalCode)) - { - var taxRateSearch = new TaxRate - { - Country = taxInfo.BillingAddressCountry, - PostalCode = taxInfo.BillingAddressPostalCode - }; - var taxRates = await _taxRateRepository.GetByLocationAsync(taxRateSearch); - - // should only be one tax rate per country/zip combo - var taxRate = taxRates.FirstOrDefault(); - if (taxRate != null) - { - taxInfo.StripeTaxRateId = taxRate.Id; - } - } - - var subCreateOptions = new OrganizationUpgradeSubscriptionOptions(customer.Id, org, plan, taxInfo, additionalSeats, additionalStorageGb, premiumAccessAddon); - var (stripePaymentMethod, paymentMethodType) = IdentifyPaymentMethod(customer, subCreateOptions); - - var subscription = await ChargeForNewSubscriptionAsync(org, customer, false, - stripePaymentMethod, paymentMethodType, subCreateOptions, null); - org.GatewaySubscriptionId = subscription.Id; - - if (subscription.Status == "incomplete" && - subscription.LatestInvoice?.PaymentIntent?.Status == "requires_action") - { - org.Enabled = false; - return subscription.LatestInvoice.PaymentIntent.ClientSecret; - } - else - { - org.Enabled = true; - org.ExpirationDate = subscription.CurrentPeriodEnd; - return null; - } - } - - private (bool stripePaymentMethod, PaymentMethodType PaymentMethodType) IdentifyPaymentMethod( - Stripe.Customer customer, Stripe.SubscriptionCreateOptions subCreateOptions) - { - var stripePaymentMethod = false; - var paymentMethodType = PaymentMethodType.Credit; - var hasBtCustomerId = customer.Metadata.ContainsKey("btCustomerId"); - if (hasBtCustomerId) - { - paymentMethodType = PaymentMethodType.PayPal; - } - else - { - if (customer.InvoiceSettings?.DefaultPaymentMethod?.Type == "card") - { - paymentMethodType = PaymentMethodType.Card; - stripePaymentMethod = true; - } - else if (customer.DefaultSource != null) - { - if (customer.DefaultSource is Stripe.Card || customer.DefaultSource is Stripe.SourceCard) - { - paymentMethodType = PaymentMethodType.Card; - stripePaymentMethod = true; - } - else if (customer.DefaultSource is Stripe.BankAccount || customer.DefaultSource is Stripe.SourceAchDebit) - { - paymentMethodType = PaymentMethodType.BankAccount; - stripePaymentMethod = true; - } - } - else - { - var paymentMethod = GetLatestCardPaymentMethod(customer.Id); - if (paymentMethod != null) - { - paymentMethodType = PaymentMethodType.Card; - stripePaymentMethod = true; - subCreateOptions.DefaultPaymentMethod = paymentMethod.Id; - } - } - } - return (stripePaymentMethod, paymentMethodType); - } - - public async Task PurchasePremiumAsync(User user, PaymentMethodType paymentMethodType, - string paymentToken, short additionalStorageGb, TaxInfo taxInfo) - { - if (paymentMethodType != PaymentMethodType.Credit && string.IsNullOrWhiteSpace(paymentToken)) - { - throw new BadRequestException("Payment token is required."); - } - if (paymentMethodType == PaymentMethodType.Credit && - (user.Gateway != GatewayType.Stripe || string.IsNullOrWhiteSpace(user.GatewayCustomerId))) - { - throw new BadRequestException("Your account does not have any credit available."); - } - if (paymentMethodType == PaymentMethodType.BankAccount || paymentMethodType == PaymentMethodType.GoogleInApp) - { - throw new GatewayException("Payment method is not supported at this time."); - } - if ((paymentMethodType == PaymentMethodType.GoogleInApp || - paymentMethodType == PaymentMethodType.AppleInApp) && additionalStorageGb > 0) - { - throw new BadRequestException("You cannot add storage with this payment method."); - } - - var createdStripeCustomer = false; - Stripe.Customer customer = null; - Braintree.Customer braintreeCustomer = null; - var stripePaymentMethod = paymentMethodType == PaymentMethodType.Card || - paymentMethodType == PaymentMethodType.BankAccount || paymentMethodType == PaymentMethodType.Credit; - - string stipeCustomerPaymentMethodId = null; - string stipeCustomerSourceToken = null; - if (stripePaymentMethod && !string.IsNullOrWhiteSpace(paymentToken)) - { - if (paymentToken.StartsWith("pm_")) - { - stipeCustomerPaymentMethodId = paymentToken; - } - else - { - stipeCustomerSourceToken = paymentToken; - } - } - - if (user.Gateway == GatewayType.Stripe && !string.IsNullOrWhiteSpace(user.GatewayCustomerId)) - { - if (!string.IsNullOrWhiteSpace(paymentToken)) - { - try - { - await UpdatePaymentMethodAsync(user, paymentMethodType, paymentToken, true, taxInfo); - } - catch (Exception e) - { - var message = e.Message.ToLowerInvariant(); - if (message.Contains("apple") || message.Contains("in-app")) - { - throw; - } - } - } - try - { - customer = await _stripeAdapter.CustomerGetAsync(user.GatewayCustomerId); - } - catch { } - } - - if (customer == null && !string.IsNullOrWhiteSpace(paymentToken)) + public async Task PurchaseOrganizationAsync(Organization org, PaymentMethodType paymentMethodType, + string paymentToken, StaticStore.Plan plan, short additionalStorageGb, + int additionalSeats, bool premiumAccessAddon, TaxInfo taxInfo) { + Braintree.Customer braintreeCustomer = null; + string stipeCustomerSourceToken = null; + string stipeCustomerPaymentMethodId = null; var stripeCustomerMetadata = new Dictionary(); - if (paymentMethodType == PaymentMethodType.PayPal) + var stripePaymentMethod = paymentMethodType == PaymentMethodType.Card || + paymentMethodType == PaymentMethodType.BankAccount; + + if (stripePaymentMethod && !string.IsNullOrWhiteSpace(paymentToken)) + { + if (paymentToken.StartsWith("pm_")) + { + stipeCustomerPaymentMethodId = paymentToken; + } + else + { + stipeCustomerSourceToken = paymentToken; + } + } + else if (paymentMethodType == PaymentMethodType.PayPal) { var randomSuffix = Utilities.CoreHelpers.RandomString(3, upper: false, numeric: false); var customerResult = await _btGateway.Customer.CreateAsync(new Braintree.CustomerRequest { PaymentMethodNonce = paymentToken, - Email = user.Email, - Id = user.BraintreeCustomerIdPrefix() + user.Id.ToString("N").ToLower() + randomSuffix, + Email = org.BillingEmail, + Id = org.BraintreeCustomerIdPrefix() + org.Id.ToString("N").ToLower() + randomSuffix, CustomFields = new Dictionary { - [user.BraintreeIdField()] = user.Id.ToString() + [org.BraintreeIdField()] = org.Id.ToString() } }); @@ -401,1386 +87,1701 @@ public class StripePaymentService : IPaymentService braintreeCustomer = customerResult.Target; stripeCustomerMetadata.Add("btCustomerId", braintreeCustomer.Id); } - else if (paymentMethodType == PaymentMethodType.AppleInApp) - { - var verifiedReceiptStatus = await _appleIapService.GetVerifiedReceiptStatusAsync(paymentToken); - if (verifiedReceiptStatus == null) - { - throw new GatewayException("Cannot verify apple in-app purchase."); - } - var receiptOriginalTransactionId = verifiedReceiptStatus.GetOriginalTransactionId(); - await VerifyAppleReceiptNotInUseAsync(receiptOriginalTransactionId, user); - await _appleIapService.SaveReceiptAsync(verifiedReceiptStatus, user.Id); - stripeCustomerMetadata.Add("appleReceipt", receiptOriginalTransactionId); - } - else if (!stripePaymentMethod) + else { throw new GatewayException("Payment method is not supported at this time."); } - customer = await _stripeAdapter.CustomerCreateAsync(new Stripe.CustomerCreateOptions + if (taxInfo != null && !string.IsNullOrWhiteSpace(taxInfo.BillingAddressCountry) && !string.IsNullOrWhiteSpace(taxInfo.BillingAddressPostalCode)) { - Description = user.Name, - Email = user.Email, - Metadata = stripeCustomerMetadata, - PaymentMethod = stipeCustomerPaymentMethodId, - Source = stipeCustomerSourceToken, - InvoiceSettings = new Stripe.CustomerInvoiceSettingsOptions - { - DefaultPaymentMethod = stipeCustomerPaymentMethodId - }, - Address = new Stripe.AddressOptions - { - Line1 = string.Empty, - Country = taxInfo.BillingAddressCountry, - PostalCode = taxInfo.BillingAddressPostalCode, - }, - }); - createdStripeCustomer = true; - } - - if (customer == null) - { - throw new GatewayException("Could not set up customer payment profile."); - } - - var subCreateOptions = new Stripe.SubscriptionCreateOptions - { - Customer = customer.Id, - Items = new List(), - Metadata = new Dictionary - { - [user.GatewayIdField()] = user.Id.ToString() - } - }; - - subCreateOptions.Items.Add(new Stripe.SubscriptionItemOptions - { - Plan = paymentMethodType == PaymentMethodType.AppleInApp ? PremiumPlanAppleIapId : PremiumPlanId, - Quantity = 1, - }); - - if (!string.IsNullOrWhiteSpace(taxInfo?.BillingAddressCountry) - && !string.IsNullOrWhiteSpace(taxInfo?.BillingAddressPostalCode)) - { - var taxRates = await _taxRateRepository.GetByLocationAsync( - new TaxRate() + var taxRateSearch = new TaxRate { Country = taxInfo.BillingAddressCountry, PostalCode = taxInfo.BillingAddressPostalCode - } - ); - var taxRate = taxRates.FirstOrDefault(); - if (taxRate != null) - { - subCreateOptions.DefaultTaxRates = new List(1) - { - taxRate.Id }; + var taxRates = await _taxRateRepository.GetByLocationAsync(taxRateSearch); + + // should only be one tax rate per country/zip combo + var taxRate = taxRates.FirstOrDefault(); + if (taxRate != null) + { + taxInfo.StripeTaxRateId = taxRate.Id; + } } - } - if (additionalStorageGb > 0) - { - subCreateOptions.Items.Add(new Stripe.SubscriptionItemOptions + var subCreateOptions = new OrganizationPurchaseSubscriptionOptions(org, plan, taxInfo, additionalSeats, additionalStorageGb, premiumAccessAddon); + + Stripe.Customer customer = null; + Stripe.Subscription subscription; + try { - Plan = StoragePlanId, - Quantity = additionalStorageGb - }); - } - - var subscription = await ChargeForNewSubscriptionAsync(user, customer, createdStripeCustomer, - stripePaymentMethod, paymentMethodType, subCreateOptions, braintreeCustomer); - - user.Gateway = GatewayType.Stripe; - user.GatewayCustomerId = customer.Id; - user.GatewaySubscriptionId = subscription.Id; - - if (subscription.Status == "incomplete" && - subscription.LatestInvoice?.PaymentIntent?.Status == "requires_action") - { - return subscription.LatestInvoice.PaymentIntent.ClientSecret; - } - else - { - user.Premium = true; - user.PremiumExpirationDate = subscription.CurrentPeriodEnd; - return null; - } - } - - private async Task ChargeForNewSubscriptionAsync(ISubscriber subcriber, Stripe.Customer customer, - bool createdStripeCustomer, bool stripePaymentMethod, PaymentMethodType paymentMethodType, - Stripe.SubscriptionCreateOptions subCreateOptions, Braintree.Customer braintreeCustomer) - { - var addedCreditToStripeCustomer = false; - Braintree.Transaction braintreeTransaction = null; - Transaction appleTransaction = null; - - var subInvoiceMetadata = new Dictionary(); - Stripe.Subscription subscription = null; - try - { - if (!stripePaymentMethod) - { - var previewInvoice = await _stripeAdapter.InvoiceUpcomingAsync(new Stripe.UpcomingInvoiceOptions + customer = await _stripeAdapter.CustomerCreateAsync(new Stripe.CustomerCreateOptions { - Customer = customer.Id, - SubscriptionItems = ToInvoiceSubscriptionItemOptions(subCreateOptions.Items), - SubscriptionDefaultTaxRates = subCreateOptions.DefaultTaxRates, - }); - - if (previewInvoice.AmountDue > 0) - { - var appleReceiptOrigTransactionId = customer.Metadata != null && - customer.Metadata.ContainsKey("appleReceipt") ? customer.Metadata["appleReceipt"] : null; - var braintreeCustomerId = customer.Metadata != null && - customer.Metadata.ContainsKey("btCustomerId") ? customer.Metadata["btCustomerId"] : null; - if (!string.IsNullOrWhiteSpace(appleReceiptOrigTransactionId)) + Description = org.BusinessName, + Email = org.BillingEmail, + Source = stipeCustomerSourceToken, + PaymentMethod = stipeCustomerPaymentMethodId, + Metadata = stripeCustomerMetadata, + InvoiceSettings = new Stripe.CustomerInvoiceSettingsOptions { - if (!subcriber.IsUser()) + DefaultPaymentMethod = stipeCustomerPaymentMethodId + }, + Address = new Stripe.AddressOptions + { + Country = taxInfo.BillingAddressCountry, + PostalCode = taxInfo.BillingAddressPostalCode, + // Line1 is required in Stripe's API, suggestion in Docs is to use Business Name intead. + Line1 = taxInfo.BillingAddressLine1 ?? string.Empty, + Line2 = taxInfo.BillingAddressLine2, + City = taxInfo.BillingAddressCity, + State = taxInfo.BillingAddressState, + }, + TaxIdData = !taxInfo.HasTaxId ? null : new List + { + new Stripe.CustomerTaxIdDataOptions { - throw new GatewayException("In-app purchase is only allowed for users."); - } - - var appleReceipt = await _appleIapService.GetReceiptAsync( - appleReceiptOrigTransactionId); - var verifiedAppleReceipt = await _appleIapService.GetVerifiedReceiptStatusAsync( - appleReceipt.Item1); - if (verifiedAppleReceipt == null) - { - throw new GatewayException("Failed to get Apple in-app purchase receipt data."); - } - subInvoiceMetadata.Add("appleReceipt", verifiedAppleReceipt.GetOriginalTransactionId()); - var lastTransactionId = verifiedAppleReceipt.GetLastTransactionId(); - subInvoiceMetadata.Add("appleReceiptTransactionId", lastTransactionId); - var existingTransaction = await _transactionRepository.GetByGatewayIdAsync( - GatewayType.AppStore, lastTransactionId); - if (existingTransaction == null) - { - appleTransaction = verifiedAppleReceipt.BuildTransactionFromLastTransaction( - PremiumPlanAppleIapPrice, subcriber.Id); - appleTransaction.Type = TransactionType.Charge; - await _transactionRepository.CreateAsync(appleTransaction); - } + Type = taxInfo.TaxIdType, + Value = taxInfo.TaxIdNumber, + }, + }, + }); + subCreateOptions.AddExpand("latest_invoice.payment_intent"); + subCreateOptions.Customer = customer.Id; + subscription = await _stripeAdapter.SubscriptionCreateAsync(subCreateOptions); + if (subscription.Status == "incomplete" && subscription.LatestInvoice?.PaymentIntent != null) + { + if (subscription.LatestInvoice.PaymentIntent.Status == "requires_payment_method") + { + await _stripeAdapter.SubscriptionCancelAsync(subscription.Id, new Stripe.SubscriptionCancelOptions()); + throw new GatewayException("Payment method was declined."); } - else if (!string.IsNullOrWhiteSpace(braintreeCustomerId)) - { - var btInvoiceAmount = (previewInvoice.AmountDue / 100M); - var transactionResult = await _btGateway.Transaction.SaleAsync( - new Braintree.TransactionRequest - { - Amount = btInvoiceAmount, - CustomerId = braintreeCustomerId, - Options = new Braintree.TransactionOptionsRequest - { - SubmitForSettlement = true, - PayPal = new Braintree.TransactionOptionsPayPalRequest - { - CustomField = $"{subcriber.BraintreeIdField()}:{subcriber.Id}" - } - }, - CustomFields = new Dictionary - { - [subcriber.BraintreeIdField()] = subcriber.Id.ToString() - } - }); - - if (!transactionResult.IsSuccess()) - { - throw new GatewayException("Failed to charge PayPal customer."); - } - - braintreeTransaction = transactionResult.Target; - subInvoiceMetadata.Add("btTransactionId", braintreeTransaction.Id); - subInvoiceMetadata.Add("btPayPalTransactionId", - braintreeTransaction.PayPalDetails.AuthorizationId); - } - else - { - throw new GatewayException("No payment was able to be collected."); - } - - await _stripeAdapter.CustomerUpdateAsync(customer.Id, new Stripe.CustomerUpdateOptions - { - Balance = customer.Balance - previewInvoice.AmountDue - }); - addedCreditToStripeCustomer = true; } } - else if (paymentMethodType == PaymentMethodType.Credit) + catch (Exception ex) { - var previewInvoice = await _stripeAdapter.InvoiceUpcomingAsync(new Stripe.UpcomingInvoiceOptions - { - Customer = customer.Id, - SubscriptionItems = ToInvoiceSubscriptionItemOptions(subCreateOptions.Items), - SubscriptionDefaultTaxRates = subCreateOptions.DefaultTaxRates, - }); - if (previewInvoice.AmountDue > 0) - { - throw new GatewayException("Your account does not have enough credit available."); - } - } - - subCreateOptions.OffSession = true; - subCreateOptions.AddExpand("latest_invoice.payment_intent"); - subscription = await _stripeAdapter.SubscriptionCreateAsync(subCreateOptions); - if (subscription.Status == "incomplete" && subscription.LatestInvoice?.PaymentIntent != null) - { - if (subscription.LatestInvoice.PaymentIntent.Status == "requires_payment_method") - { - await _stripeAdapter.SubscriptionCancelAsync(subscription.Id, new Stripe.SubscriptionCancelOptions()); - throw new GatewayException("Payment method was declined."); - } - } - - if (!stripePaymentMethod && subInvoiceMetadata.Any()) - { - var invoices = await _stripeAdapter.InvoiceListAsync(new Stripe.InvoiceListOptions - { - Subscription = subscription.Id - }); - - var invoice = invoices?.FirstOrDefault(); - if (invoice == null) - { - throw new GatewayException("Invoice not found."); - } - - await _stripeAdapter.InvoiceUpdateAsync(invoice.Id, new Stripe.InvoiceUpdateOptions - { - Metadata = subInvoiceMetadata - }); - } - - return subscription; - } - catch (Exception e) - { - if (customer != null) - { - if (createdStripeCustomer) + _logger.LogError(ex, "Error creating customer, walking back operation."); + if (customer != null) { await _stripeAdapter.CustomerDeleteAsync(customer.Id); } - else if (addedCreditToStripeCustomer || customer.Balance < 0) + if (braintreeCustomer != null) { - await _stripeAdapter.CustomerUpdateAsync(customer.Id, new Stripe.CustomerUpdateOptions - { - Balance = customer.Balance - }); + await _btGateway.Customer.DeleteAsync(braintreeCustomer.Id); } + throw; } - if (braintreeTransaction != null) + + org.Gateway = GatewayType.Stripe; + org.GatewayCustomerId = customer.Id; + org.GatewaySubscriptionId = subscription.Id; + + if (subscription.Status == "incomplete" && + subscription.LatestInvoice?.PaymentIntent?.Status == "requires_action") { - await _btGateway.Transaction.RefundAsync(braintreeTransaction.Id); + org.Enabled = false; + return subscription.LatestInvoice.PaymentIntent.ClientSecret; } - if (braintreeCustomer != null) + else { - await _btGateway.Customer.DeleteAsync(braintreeCustomer.Id); + org.Enabled = true; + org.ExpirationDate = subscription.CurrentPeriodEnd; + return null; } - if (appleTransaction != null) - { - await _transactionRepository.DeleteAsync(appleTransaction); - } - - if (e is Stripe.StripeException strEx && - (strEx.StripeError?.Message?.Contains("cannot be used because it is not verified") ?? false)) - { - throw new GatewayException("Bank account is not yet verified."); - } - - throw; - } - } - - private List ToInvoiceSubscriptionItemOptions( - List subItemOptions) - { - return subItemOptions.Select(si => new Stripe.InvoiceSubscriptionItemOptions - { - Plan = si.Plan, - Quantity = si.Quantity - }).ToList(); - } - - private async Task FinalizeSubscriptionChangeAsync(IStorableSubscriber storableSubscriber, - SubscriptionUpdate subscriptionUpdate, DateTime? prorationDate) - { - // remember, when in doubt, throw - - var sub = await _stripeAdapter.SubscriptionGetAsync(storableSubscriber.GatewaySubscriptionId); - if (sub == null) - { - throw new GatewayException("Subscription not found."); } - prorationDate ??= DateTime.UtcNow; - var collectionMethod = sub.CollectionMethod; - var daysUntilDue = sub.DaysUntilDue; - var chargeNow = collectionMethod == "charge_automatically"; - var updatedItemOptions = subscriptionUpdate.UpgradeItemsOptions(sub); - - var subUpdateOptions = new Stripe.SubscriptionUpdateOptions + private async Task ChangeOrganizationSponsorship(Organization org, OrganizationSponsorship sponsorship, bool applySponsorship) { - Items = updatedItemOptions, - ProrationBehavior = "always_invoice", - DaysUntilDue = daysUntilDue ?? 1, - CollectionMethod = "send_invoice", - ProrationDate = prorationDate, - }; + var existingPlan = Utilities.StaticStore.GetPlan(org.PlanType); + var sponsoredPlan = sponsorship != null ? + Utilities.StaticStore.GetSponsoredPlan(sponsorship.PlanSponsorshipType.Value) : + null; + var subscriptionUpdate = new SponsorOrganizationSubscriptionUpdate(existingPlan, sponsoredPlan, applySponsorship); + + await FinalizeSubscriptionChangeAsync(org, subscriptionUpdate, DateTime.UtcNow); + + var sub = await _stripeAdapter.SubscriptionGetAsync(org.GatewaySubscriptionId); + org.ExpirationDate = sub.CurrentPeriodEnd; + sponsorship.ValidUntil = sub.CurrentPeriodEnd; - if (!subscriptionUpdate.UpdateNeeded(sub)) - { - // No need to update subscription, quantity matches - return null; } - var customer = await _stripeAdapter.CustomerGetAsync(sub.CustomerId); + public Task SponsorOrganizationAsync(Organization org, OrganizationSponsorship sponsorship) => + ChangeOrganizationSponsorship(org, sponsorship, true); - if (!string.IsNullOrWhiteSpace(customer?.Address?.Country) - && !string.IsNullOrWhiteSpace(customer?.Address?.PostalCode)) + public Task RemoveOrganizationSponsorshipAsync(Organization org, OrganizationSponsorship sponsorship) => + ChangeOrganizationSponsorship(org, sponsorship, false); + + public async Task UpgradeFreeOrganizationAsync(Organization org, StaticStore.Plan plan, + short additionalStorageGb, int additionalSeats, bool premiumAccessAddon, TaxInfo taxInfo) { - var taxRates = await _taxRateRepository.GetByLocationAsync( - new TaxRate() + if (!string.IsNullOrWhiteSpace(org.GatewaySubscriptionId)) + { + throw new BadRequestException("Organization already has a subscription."); + } + + var customerOptions = new Stripe.CustomerGetOptions(); + customerOptions.AddExpand("default_source"); + customerOptions.AddExpand("invoice_settings.default_payment_method"); + var customer = await _stripeAdapter.CustomerGetAsync(org.GatewayCustomerId, customerOptions); + if (customer == null) + { + throw new GatewayException("Could not find customer payment profile."); + } + + if (taxInfo != null && !string.IsNullOrWhiteSpace(taxInfo.BillingAddressCountry) && !string.IsNullOrWhiteSpace(taxInfo.BillingAddressPostalCode)) + { + var taxRateSearch = new TaxRate { - Country = customer.Address.Country, - PostalCode = customer.Address.PostalCode - } - ); - var taxRate = taxRates.FirstOrDefault(); - if (taxRate != null && !sub.DefaultTaxRates.Any(x => x.Equals(taxRate.Id))) - { - subUpdateOptions.DefaultTaxRates = new List(1) - { - taxRate.Id + Country = taxInfo.BillingAddressCountry, + PostalCode = taxInfo.BillingAddressPostalCode }; + var taxRates = await _taxRateRepository.GetByLocationAsync(taxRateSearch); + + // should only be one tax rate per country/zip combo + var taxRate = taxRates.FirstOrDefault(); + if (taxRate != null) + { + taxInfo.StripeTaxRateId = taxRate.Id; + } + } + + var subCreateOptions = new OrganizationUpgradeSubscriptionOptions(customer.Id, org, plan, taxInfo, additionalSeats, additionalStorageGb, premiumAccessAddon); + var (stripePaymentMethod, paymentMethodType) = IdentifyPaymentMethod(customer, subCreateOptions); + + var subscription = await ChargeForNewSubscriptionAsync(org, customer, false, + stripePaymentMethod, paymentMethodType, subCreateOptions, null); + org.GatewaySubscriptionId = subscription.Id; + + if (subscription.Status == "incomplete" && + subscription.LatestInvoice?.PaymentIntent?.Status == "requires_action") + { + org.Enabled = false; + return subscription.LatestInvoice.PaymentIntent.ClientSecret; + } + else + { + org.Enabled = true; + org.ExpirationDate = subscription.CurrentPeriodEnd; + return null; } } - string paymentIntentClientSecret = null; - try + private (bool stripePaymentMethod, PaymentMethodType PaymentMethodType) IdentifyPaymentMethod( + Stripe.Customer customer, Stripe.SubscriptionCreateOptions subCreateOptions) { - var subResponse = await _stripeAdapter.SubscriptionUpdateAsync(sub.Id, subUpdateOptions); - - var invoice = await _stripeAdapter.InvoiceGetAsync(subResponse?.LatestInvoiceId, new Stripe.InvoiceGetOptions()); - if (invoice == null) + var stripePaymentMethod = false; + var paymentMethodType = PaymentMethodType.Credit; + var hasBtCustomerId = customer.Metadata.ContainsKey("btCustomerId"); + if (hasBtCustomerId) { - throw new BadRequestException("Unable to locate draft invoice for subscription update."); + paymentMethodType = PaymentMethodType.PayPal; } - - if (invoice.AmountDue > 0 && updatedItemOptions.Any(i => i.Quantity > 0)) + else { - try + if (customer.InvoiceSettings?.DefaultPaymentMethod?.Type == "card") { - if (chargeNow) + paymentMethodType = PaymentMethodType.Card; + stripePaymentMethod = true; + } + else if (customer.DefaultSource != null) + { + if (customer.DefaultSource is Stripe.Card || customer.DefaultSource is Stripe.SourceCard) { - paymentIntentClientSecret = await PayInvoiceAfterSubscriptionChangeAsync( - storableSubscriber, invoice); + paymentMethodType = PaymentMethodType.Card; + stripePaymentMethod = true; } - else + else if (customer.DefaultSource is Stripe.BankAccount || customer.DefaultSource is Stripe.SourceAchDebit) { - invoice = await _stripeAdapter.InvoiceFinalizeInvoiceAsync(subResponse.LatestInvoiceId, new Stripe.InvoiceFinalizeOptions - { - AutoAdvance = false, - }); - await _stripeAdapter.InvoiceSendInvoiceAsync(invoice.Id, new Stripe.InvoiceSendOptions()); - paymentIntentClientSecret = null; + paymentMethodType = PaymentMethodType.BankAccount; + stripePaymentMethod = true; } } - catch + else { - // Need to revert the subscription - await _stripeAdapter.SubscriptionUpdateAsync(sub.Id, new Stripe.SubscriptionUpdateOptions + var paymentMethod = GetLatestCardPaymentMethod(customer.Id); + if (paymentMethod != null) { - Items = subscriptionUpdate.RevertItemsOptions(sub), - // This proration behavior prevents a false "credit" from - // being applied forward to the next month's invoice - ProrationBehavior = "none", - CollectionMethod = collectionMethod, - DaysUntilDue = daysUntilDue, - }); - throw; + paymentMethodType = PaymentMethodType.Card; + stripePaymentMethod = true; + subCreateOptions.DefaultPaymentMethod = paymentMethod.Id; + } } } - else if (!invoice.Paid) + return (stripePaymentMethod, paymentMethodType); + } + + public async Task PurchasePremiumAsync(User user, PaymentMethodType paymentMethodType, + string paymentToken, short additionalStorageGb, TaxInfo taxInfo) + { + if (paymentMethodType != PaymentMethodType.Credit && string.IsNullOrWhiteSpace(paymentToken)) { - // Pay invoice with no charge to customer this completes the invoice immediately without waiting the scheduled 1h - invoice = await _stripeAdapter.InvoicePayAsync(subResponse.LatestInvoiceId); - paymentIntentClientSecret = null; + throw new BadRequestException("Payment token is required."); + } + if (paymentMethodType == PaymentMethodType.Credit && + (user.Gateway != GatewayType.Stripe || string.IsNullOrWhiteSpace(user.GatewayCustomerId))) + { + throw new BadRequestException("Your account does not have any credit available."); + } + if (paymentMethodType == PaymentMethodType.BankAccount || paymentMethodType == PaymentMethodType.GoogleInApp) + { + throw new GatewayException("Payment method is not supported at this time."); + } + if ((paymentMethodType == PaymentMethodType.GoogleInApp || + paymentMethodType == PaymentMethodType.AppleInApp) && additionalStorageGb > 0) + { + throw new BadRequestException("You cannot add storage with this payment method."); } - } - finally - { - // Change back the subscription collection method and/or days until due - if (collectionMethod != "send_invoice" || daysUntilDue == null) + var createdStripeCustomer = false; + Stripe.Customer customer = null; + Braintree.Customer braintreeCustomer = null; + var stripePaymentMethod = paymentMethodType == PaymentMethodType.Card || + paymentMethodType == PaymentMethodType.BankAccount || paymentMethodType == PaymentMethodType.Credit; + + string stipeCustomerPaymentMethodId = null; + string stipeCustomerSourceToken = null; + if (stripePaymentMethod && !string.IsNullOrWhiteSpace(paymentToken)) { - await _stripeAdapter.SubscriptionUpdateAsync(sub.Id, new Stripe.SubscriptionUpdateOptions + if (paymentToken.StartsWith("pm_")) { - CollectionMethod = collectionMethod, - DaysUntilDue = daysUntilDue, - }); - } - } - - return paymentIntentClientSecret; - } - - public Task AdjustSeatsAsync(Organization organization, StaticStore.Plan plan, int additionalSeats, DateTime? prorationDate = null) - { - return FinalizeSubscriptionChangeAsync(organization, new SeatSubscriptionUpdate(organization, plan, additionalSeats), prorationDate); - } - - public Task AdjustStorageAsync(IStorableSubscriber storableSubscriber, int additionalStorage, - string storagePlanId, DateTime? prorationDate = null) - { - return FinalizeSubscriptionChangeAsync(storableSubscriber, new StorageSubscriptionUpdate(storagePlanId, additionalStorage), prorationDate); - } - - public async Task CancelAndRecoverChargesAsync(ISubscriber subscriber) - { - if (!string.IsNullOrWhiteSpace(subscriber.GatewaySubscriptionId)) - { - await _stripeAdapter.SubscriptionCancelAsync(subscriber.GatewaySubscriptionId, - new Stripe.SubscriptionCancelOptions()); - } - - if (string.IsNullOrWhiteSpace(subscriber.GatewayCustomerId)) - { - return; - } - - var customer = await _stripeAdapter.CustomerGetAsync(subscriber.GatewayCustomerId); - if (customer == null) - { - return; - } - - if (customer.Metadata.ContainsKey("btCustomerId")) - { - var transactionRequest = new Braintree.TransactionSearchRequest() - .CustomerId.Is(customer.Metadata["btCustomerId"]); - var transactions = _btGateway.Transaction.Search(transactionRequest); - - if ((transactions?.MaximumCount ?? 0) > 0) - { - var txs = transactions.Cast().Where(c => c.RefundedTransactionId == null); - foreach (var transaction in txs) + stipeCustomerPaymentMethodId = paymentToken; + } + else { - await _btGateway.Transaction.RefundAsync(transaction.Id); + stipeCustomerSourceToken = paymentToken; } } - await _btGateway.Customer.DeleteAsync(customer.Metadata["btCustomerId"]); - } - else - { - var charges = await _stripeAdapter.ChargeListAsync(new Stripe.ChargeListOptions + if (user.Gateway == GatewayType.Stripe && !string.IsNullOrWhiteSpace(user.GatewayCustomerId)) { - Customer = subscriber.GatewayCustomerId - }); - - if (charges?.Data != null) - { - foreach (var charge in charges.Data.Where(c => c.Captured && !c.Refunded)) + if (!string.IsNullOrWhiteSpace(paymentToken)) { - await _stripeAdapter.RefundCreateAsync(new Stripe.RefundCreateOptions { Charge = charge.Id }); - } - } - } - - await _stripeAdapter.CustomerDeleteAsync(subscriber.GatewayCustomerId); - } - - public async Task PayInvoiceAfterSubscriptionChangeAsync(ISubscriber subscriber, Stripe.Invoice invoice) - { - var customerOptions = new Stripe.CustomerGetOptions(); - customerOptions.AddExpand("default_source"); - customerOptions.AddExpand("invoice_settings.default_payment_method"); - var customer = await _stripeAdapter.CustomerGetAsync(subscriber.GatewayCustomerId, customerOptions); - var usingInAppPaymentMethod = customer.Metadata.ContainsKey("appleReceipt"); - if (usingInAppPaymentMethod) - { - throw new BadRequestException("Cannot perform this action with in-app purchase payment method. " + - "Contact support."); - } - - string paymentIntentClientSecret = null; - - // Invoice them and pay now instead of waiting until Stripe does this automatically. - - string cardPaymentMethodId = null; - if (!customer.Metadata.ContainsKey("btCustomerId")) - { - var hasDefaultCardPaymentMethod = customer.InvoiceSettings?.DefaultPaymentMethod?.Type == "card"; - var hasDefaultValidSource = customer.DefaultSource != null && - (customer.DefaultSource is Stripe.Card || customer.DefaultSource is Stripe.BankAccount); - if (!hasDefaultCardPaymentMethod && !hasDefaultValidSource) - { - cardPaymentMethodId = GetLatestCardPaymentMethod(customer.Id)?.Id; - if (cardPaymentMethodId == null) - { - // We're going to delete this draft invoice, it can't be paid try { - await _stripeAdapter.InvoiceDeleteAsync(invoice.Id); + await UpdatePaymentMethodAsync(user, paymentMethodType, paymentToken, true, taxInfo); + } + catch (Exception e) + { + var message = e.Message.ToLowerInvariant(); + if (message.Contains("apple") || message.Contains("in-app")) + { + throw; + } + } + } + try + { + customer = await _stripeAdapter.CustomerGetAsync(user.GatewayCustomerId); + } + catch { } + } + + if (customer == null && !string.IsNullOrWhiteSpace(paymentToken)) + { + var stripeCustomerMetadata = new Dictionary(); + if (paymentMethodType == PaymentMethodType.PayPal) + { + var randomSuffix = Utilities.CoreHelpers.RandomString(3, upper: false, numeric: false); + var customerResult = await _btGateway.Customer.CreateAsync(new Braintree.CustomerRequest + { + PaymentMethodNonce = paymentToken, + Email = user.Email, + Id = user.BraintreeCustomerIdPrefix() + user.Id.ToString("N").ToLower() + randomSuffix, + CustomFields = new Dictionary + { + [user.BraintreeIdField()] = user.Id.ToString() + } + }); + + if (!customerResult.IsSuccess() || customerResult.Target.PaymentMethods.Length == 0) + { + throw new GatewayException("Failed to create PayPal customer record."); + } + + braintreeCustomer = customerResult.Target; + stripeCustomerMetadata.Add("btCustomerId", braintreeCustomer.Id); + } + else if (paymentMethodType == PaymentMethodType.AppleInApp) + { + var verifiedReceiptStatus = await _appleIapService.GetVerifiedReceiptStatusAsync(paymentToken); + if (verifiedReceiptStatus == null) + { + throw new GatewayException("Cannot verify apple in-app purchase."); + } + var receiptOriginalTransactionId = verifiedReceiptStatus.GetOriginalTransactionId(); + await VerifyAppleReceiptNotInUseAsync(receiptOriginalTransactionId, user); + await _appleIapService.SaveReceiptAsync(verifiedReceiptStatus, user.Id); + stripeCustomerMetadata.Add("appleReceipt", receiptOriginalTransactionId); + } + else if (!stripePaymentMethod) + { + throw new GatewayException("Payment method is not supported at this time."); + } + + customer = await _stripeAdapter.CustomerCreateAsync(new Stripe.CustomerCreateOptions + { + Description = user.Name, + Email = user.Email, + Metadata = stripeCustomerMetadata, + PaymentMethod = stipeCustomerPaymentMethodId, + Source = stipeCustomerSourceToken, + InvoiceSettings = new Stripe.CustomerInvoiceSettingsOptions + { + DefaultPaymentMethod = stipeCustomerPaymentMethodId + }, + Address = new Stripe.AddressOptions + { + Line1 = string.Empty, + Country = taxInfo.BillingAddressCountry, + PostalCode = taxInfo.BillingAddressPostalCode, + }, + }); + createdStripeCustomer = true; + } + + if (customer == null) + { + throw new GatewayException("Could not set up customer payment profile."); + } + + var subCreateOptions = new Stripe.SubscriptionCreateOptions + { + Customer = customer.Id, + Items = new List(), + Metadata = new Dictionary + { + [user.GatewayIdField()] = user.Id.ToString() + } + }; + + subCreateOptions.Items.Add(new Stripe.SubscriptionItemOptions + { + Plan = paymentMethodType == PaymentMethodType.AppleInApp ? PremiumPlanAppleIapId : PremiumPlanId, + Quantity = 1, + }); + + if (!string.IsNullOrWhiteSpace(taxInfo?.BillingAddressCountry) + && !string.IsNullOrWhiteSpace(taxInfo?.BillingAddressPostalCode)) + { + var taxRates = await _taxRateRepository.GetByLocationAsync( + new TaxRate() + { + Country = taxInfo.BillingAddressCountry, + PostalCode = taxInfo.BillingAddressPostalCode + } + ); + var taxRate = taxRates.FirstOrDefault(); + if (taxRate != null) + { + subCreateOptions.DefaultTaxRates = new List(1) + { + taxRate.Id + }; + } + } + + if (additionalStorageGb > 0) + { + subCreateOptions.Items.Add(new Stripe.SubscriptionItemOptions + { + Plan = StoragePlanId, + Quantity = additionalStorageGb + }); + } + + var subscription = await ChargeForNewSubscriptionAsync(user, customer, createdStripeCustomer, + stripePaymentMethod, paymentMethodType, subCreateOptions, braintreeCustomer); + + user.Gateway = GatewayType.Stripe; + user.GatewayCustomerId = customer.Id; + user.GatewaySubscriptionId = subscription.Id; + + if (subscription.Status == "incomplete" && + subscription.LatestInvoice?.PaymentIntent?.Status == "requires_action") + { + return subscription.LatestInvoice.PaymentIntent.ClientSecret; + } + else + { + user.Premium = true; + user.PremiumExpirationDate = subscription.CurrentPeriodEnd; + return null; + } + } + + private async Task ChargeForNewSubscriptionAsync(ISubscriber subcriber, Stripe.Customer customer, + bool createdStripeCustomer, bool stripePaymentMethod, PaymentMethodType paymentMethodType, + Stripe.SubscriptionCreateOptions subCreateOptions, Braintree.Customer braintreeCustomer) + { + var addedCreditToStripeCustomer = false; + Braintree.Transaction braintreeTransaction = null; + Transaction appleTransaction = null; + + var subInvoiceMetadata = new Dictionary(); + Stripe.Subscription subscription = null; + try + { + if (!stripePaymentMethod) + { + var previewInvoice = await _stripeAdapter.InvoiceUpcomingAsync(new Stripe.UpcomingInvoiceOptions + { + Customer = customer.Id, + SubscriptionItems = ToInvoiceSubscriptionItemOptions(subCreateOptions.Items), + SubscriptionDefaultTaxRates = subCreateOptions.DefaultTaxRates, + }); + + if (previewInvoice.AmountDue > 0) + { + var appleReceiptOrigTransactionId = customer.Metadata != null && + customer.Metadata.ContainsKey("appleReceipt") ? customer.Metadata["appleReceipt"] : null; + var braintreeCustomerId = customer.Metadata != null && + customer.Metadata.ContainsKey("btCustomerId") ? customer.Metadata["btCustomerId"] : null; + if (!string.IsNullOrWhiteSpace(appleReceiptOrigTransactionId)) + { + if (!subcriber.IsUser()) + { + throw new GatewayException("In-app purchase is only allowed for users."); + } + + var appleReceipt = await _appleIapService.GetReceiptAsync( + appleReceiptOrigTransactionId); + var verifiedAppleReceipt = await _appleIapService.GetVerifiedReceiptStatusAsync( + appleReceipt.Item1); + if (verifiedAppleReceipt == null) + { + throw new GatewayException("Failed to get Apple in-app purchase receipt data."); + } + subInvoiceMetadata.Add("appleReceipt", verifiedAppleReceipt.GetOriginalTransactionId()); + var lastTransactionId = verifiedAppleReceipt.GetLastTransactionId(); + subInvoiceMetadata.Add("appleReceiptTransactionId", lastTransactionId); + var existingTransaction = await _transactionRepository.GetByGatewayIdAsync( + GatewayType.AppStore, lastTransactionId); + if (existingTransaction == null) + { + appleTransaction = verifiedAppleReceipt.BuildTransactionFromLastTransaction( + PremiumPlanAppleIapPrice, subcriber.Id); + appleTransaction.Type = TransactionType.Charge; + await _transactionRepository.CreateAsync(appleTransaction); + } + } + else if (!string.IsNullOrWhiteSpace(braintreeCustomerId)) + { + var btInvoiceAmount = (previewInvoice.AmountDue / 100M); + var transactionResult = await _btGateway.Transaction.SaleAsync( + new Braintree.TransactionRequest + { + Amount = btInvoiceAmount, + CustomerId = braintreeCustomerId, + Options = new Braintree.TransactionOptionsRequest + { + SubmitForSettlement = true, + PayPal = new Braintree.TransactionOptionsPayPalRequest + { + CustomField = $"{subcriber.BraintreeIdField()}:{subcriber.Id}" + } + }, + CustomFields = new Dictionary + { + [subcriber.BraintreeIdField()] = subcriber.Id.ToString() + } + }); + + if (!transactionResult.IsSuccess()) + { + throw new GatewayException("Failed to charge PayPal customer."); + } + + braintreeTransaction = transactionResult.Target; + subInvoiceMetadata.Add("btTransactionId", braintreeTransaction.Id); + subInvoiceMetadata.Add("btPayPalTransactionId", + braintreeTransaction.PayPalDetails.AuthorizationId); + } + else + { + throw new GatewayException("No payment was able to be collected."); + } + + await _stripeAdapter.CustomerUpdateAsync(customer.Id, new Stripe.CustomerUpdateOptions + { + Balance = customer.Balance - previewInvoice.AmountDue + }); + addedCreditToStripeCustomer = true; + } + } + else if (paymentMethodType == PaymentMethodType.Credit) + { + var previewInvoice = await _stripeAdapter.InvoiceUpcomingAsync(new Stripe.UpcomingInvoiceOptions + { + Customer = customer.Id, + SubscriptionItems = ToInvoiceSubscriptionItemOptions(subCreateOptions.Items), + SubscriptionDefaultTaxRates = subCreateOptions.DefaultTaxRates, + }); + if (previewInvoice.AmountDue > 0) + { + throw new GatewayException("Your account does not have enough credit available."); + } + } + + subCreateOptions.OffSession = true; + subCreateOptions.AddExpand("latest_invoice.payment_intent"); + subscription = await _stripeAdapter.SubscriptionCreateAsync(subCreateOptions); + if (subscription.Status == "incomplete" && subscription.LatestInvoice?.PaymentIntent != null) + { + if (subscription.LatestInvoice.PaymentIntent.Status == "requires_payment_method") + { + await _stripeAdapter.SubscriptionCancelAsync(subscription.Id, new Stripe.SubscriptionCancelOptions()); + throw new GatewayException("Payment method was declined."); + } + } + + if (!stripePaymentMethod && subInvoiceMetadata.Any()) + { + var invoices = await _stripeAdapter.InvoiceListAsync(new Stripe.InvoiceListOptions + { + Subscription = subscription.Id + }); + + var invoice = invoices?.FirstOrDefault(); + if (invoice == null) + { + throw new GatewayException("Invoice not found."); + } + + await _stripeAdapter.InvoiceUpdateAsync(invoice.Id, new Stripe.InvoiceUpdateOptions + { + Metadata = subInvoiceMetadata + }); + } + + return subscription; + } + catch (Exception e) + { + if (customer != null) + { + if (createdStripeCustomer) + { + await _stripeAdapter.CustomerDeleteAsync(customer.Id); + } + else if (addedCreditToStripeCustomer || customer.Balance < 0) + { + await _stripeAdapter.CustomerUpdateAsync(customer.Id, new Stripe.CustomerUpdateOptions + { + Balance = customer.Balance + }); + } + } + if (braintreeTransaction != null) + { + await _btGateway.Transaction.RefundAsync(braintreeTransaction.Id); + } + if (braintreeCustomer != null) + { + await _btGateway.Customer.DeleteAsync(braintreeCustomer.Id); + } + if (appleTransaction != null) + { + await _transactionRepository.DeleteAsync(appleTransaction); + } + + if (e is Stripe.StripeException strEx && + (strEx.StripeError?.Message?.Contains("cannot be used because it is not verified") ?? false)) + { + throw new GatewayException("Bank account is not yet verified."); + } + + throw; + } + } + + private List ToInvoiceSubscriptionItemOptions( + List subItemOptions) + { + return subItemOptions.Select(si => new Stripe.InvoiceSubscriptionItemOptions + { + Plan = si.Plan, + Quantity = si.Quantity + }).ToList(); + } + + private async Task FinalizeSubscriptionChangeAsync(IStorableSubscriber storableSubscriber, + SubscriptionUpdate subscriptionUpdate, DateTime? prorationDate) + { + // remember, when in doubt, throw + + var sub = await _stripeAdapter.SubscriptionGetAsync(storableSubscriber.GatewaySubscriptionId); + if (sub == null) + { + throw new GatewayException("Subscription not found."); + } + + prorationDate ??= DateTime.UtcNow; + var collectionMethod = sub.CollectionMethod; + var daysUntilDue = sub.DaysUntilDue; + var chargeNow = collectionMethod == "charge_automatically"; + var updatedItemOptions = subscriptionUpdate.UpgradeItemsOptions(sub); + + var subUpdateOptions = new Stripe.SubscriptionUpdateOptions + { + Items = updatedItemOptions, + ProrationBehavior = "always_invoice", + DaysUntilDue = daysUntilDue ?? 1, + CollectionMethod = "send_invoice", + ProrationDate = prorationDate, + }; + + if (!subscriptionUpdate.UpdateNeeded(sub)) + { + // No need to update subscription, quantity matches + return null; + } + + var customer = await _stripeAdapter.CustomerGetAsync(sub.CustomerId); + + if (!string.IsNullOrWhiteSpace(customer?.Address?.Country) + && !string.IsNullOrWhiteSpace(customer?.Address?.PostalCode)) + { + var taxRates = await _taxRateRepository.GetByLocationAsync( + new TaxRate() + { + Country = customer.Address.Country, + PostalCode = customer.Address.PostalCode + } + ); + var taxRate = taxRates.FirstOrDefault(); + if (taxRate != null && !sub.DefaultTaxRates.Any(x => x.Equals(taxRate.Id))) + { + subUpdateOptions.DefaultTaxRates = new List(1) + { + taxRate.Id + }; + } + } + + string paymentIntentClientSecret = null; + try + { + var subResponse = await _stripeAdapter.SubscriptionUpdateAsync(sub.Id, subUpdateOptions); + + var invoice = await _stripeAdapter.InvoiceGetAsync(subResponse?.LatestInvoiceId, new Stripe.InvoiceGetOptions()); + if (invoice == null) + { + throw new BadRequestException("Unable to locate draft invoice for subscription update."); + } + + if (invoice.AmountDue > 0 && updatedItemOptions.Any(i => i.Quantity > 0)) + { + try + { + if (chargeNow) + { + paymentIntentClientSecret = await PayInvoiceAfterSubscriptionChangeAsync( + storableSubscriber, invoice); + } + else + { + invoice = await _stripeAdapter.InvoiceFinalizeInvoiceAsync(subResponse.LatestInvoiceId, new Stripe.InvoiceFinalizeOptions + { + AutoAdvance = false, + }); + await _stripeAdapter.InvoiceSendInvoiceAsync(invoice.Id, new Stripe.InvoiceSendOptions()); + paymentIntentClientSecret = null; + } } catch { - await _stripeAdapter.InvoiceFinalizeInvoiceAsync(invoice.Id, new Stripe.InvoiceFinalizeOptions + // Need to revert the subscription + await _stripeAdapter.SubscriptionUpdateAsync(sub.Id, new Stripe.SubscriptionUpdateOptions { - AutoAdvance = false + Items = subscriptionUpdate.RevertItemsOptions(sub), + // This proration behavior prevents a false "credit" from + // being applied forward to the next month's invoice + ProrationBehavior = "none", + CollectionMethod = collectionMethod, + DaysUntilDue = daysUntilDue, }); - await _stripeAdapter.InvoiceVoidInvoiceAsync(invoice.Id); + throw; } - throw new BadRequestException("No payment method is available."); + } + else if (!invoice.Paid) + { + // Pay invoice with no charge to customer this completes the invoice immediately without waiting the scheduled 1h + invoice = await _stripeAdapter.InvoicePayAsync(subResponse.LatestInvoiceId); + paymentIntentClientSecret = null; + } + + } + finally + { + // Change back the subscription collection method and/or days until due + if (collectionMethod != "send_invoice" || daysUntilDue == null) + { + await _stripeAdapter.SubscriptionUpdateAsync(sub.Id, new Stripe.SubscriptionUpdateOptions + { + CollectionMethod = collectionMethod, + DaysUntilDue = daysUntilDue, + }); + } + } + + return paymentIntentClientSecret; + } + + public Task AdjustSeatsAsync(Organization organization, StaticStore.Plan plan, int additionalSeats, DateTime? prorationDate = null) + { + return FinalizeSubscriptionChangeAsync(organization, new SeatSubscriptionUpdate(organization, plan, additionalSeats), prorationDate); + } + + public Task AdjustStorageAsync(IStorableSubscriber storableSubscriber, int additionalStorage, + string storagePlanId, DateTime? prorationDate = null) + { + return FinalizeSubscriptionChangeAsync(storableSubscriber, new StorageSubscriptionUpdate(storagePlanId, additionalStorage), prorationDate); + } + + public async Task CancelAndRecoverChargesAsync(ISubscriber subscriber) + { + if (!string.IsNullOrWhiteSpace(subscriber.GatewaySubscriptionId)) + { + await _stripeAdapter.SubscriptionCancelAsync(subscriber.GatewaySubscriptionId, + new Stripe.SubscriptionCancelOptions()); + } + + if (string.IsNullOrWhiteSpace(subscriber.GatewayCustomerId)) + { + return; + } + + var customer = await _stripeAdapter.CustomerGetAsync(subscriber.GatewayCustomerId); + if (customer == null) + { + return; + } + + if (customer.Metadata.ContainsKey("btCustomerId")) + { + var transactionRequest = new Braintree.TransactionSearchRequest() + .CustomerId.Is(customer.Metadata["btCustomerId"]); + var transactions = _btGateway.Transaction.Search(transactionRequest); + + if ((transactions?.MaximumCount ?? 0) > 0) + { + var txs = transactions.Cast().Where(c => c.RefundedTransactionId == null); + foreach (var transaction in txs) + { + await _btGateway.Transaction.RefundAsync(transaction.Id); + } + } + + await _btGateway.Customer.DeleteAsync(customer.Metadata["btCustomerId"]); + } + else + { + var charges = await _stripeAdapter.ChargeListAsync(new Stripe.ChargeListOptions + { + Customer = subscriber.GatewayCustomerId + }); + + if (charges?.Data != null) + { + foreach (var charge in charges.Data.Where(c => c.Captured && !c.Refunded)) + { + await _stripeAdapter.RefundCreateAsync(new Stripe.RefundCreateOptions { Charge = charge.Id }); + } + } + } + + await _stripeAdapter.CustomerDeleteAsync(subscriber.GatewayCustomerId); + } + + public async Task PayInvoiceAfterSubscriptionChangeAsync(ISubscriber subscriber, Stripe.Invoice invoice) + { + var customerOptions = new Stripe.CustomerGetOptions(); + customerOptions.AddExpand("default_source"); + customerOptions.AddExpand("invoice_settings.default_payment_method"); + var customer = await _stripeAdapter.CustomerGetAsync(subscriber.GatewayCustomerId, customerOptions); + var usingInAppPaymentMethod = customer.Metadata.ContainsKey("appleReceipt"); + if (usingInAppPaymentMethod) + { + throw new BadRequestException("Cannot perform this action with in-app purchase payment method. " + + "Contact support."); + } + + string paymentIntentClientSecret = null; + + // Invoice them and pay now instead of waiting until Stripe does this automatically. + + string cardPaymentMethodId = null; + if (!customer.Metadata.ContainsKey("btCustomerId")) + { + var hasDefaultCardPaymentMethod = customer.InvoiceSettings?.DefaultPaymentMethod?.Type == "card"; + var hasDefaultValidSource = customer.DefaultSource != null && + (customer.DefaultSource is Stripe.Card || customer.DefaultSource is Stripe.BankAccount); + if (!hasDefaultCardPaymentMethod && !hasDefaultValidSource) + { + cardPaymentMethodId = GetLatestCardPaymentMethod(customer.Id)?.Id; + if (cardPaymentMethodId == null) + { + // We're going to delete this draft invoice, it can't be paid + try + { + await _stripeAdapter.InvoiceDeleteAsync(invoice.Id); + } + catch + { + await _stripeAdapter.InvoiceFinalizeInvoiceAsync(invoice.Id, new Stripe.InvoiceFinalizeOptions + { + AutoAdvance = false + }); + await _stripeAdapter.InvoiceVoidInvoiceAsync(invoice.Id); + } + throw new BadRequestException("No payment method is available."); + } + } + } + + Braintree.Transaction braintreeTransaction = null; + try + { + // Finalize the invoice (from Draft) w/o auto-advance so we + // can attempt payment manually. + invoice = await _stripeAdapter.InvoiceFinalizeInvoiceAsync(invoice.Id, new Stripe.InvoiceFinalizeOptions + { + AutoAdvance = false, + }); + var invoicePayOptions = new Stripe.InvoicePayOptions + { + PaymentMethod = cardPaymentMethodId, + }; + if (customer?.Metadata?.ContainsKey("btCustomerId") ?? false) + { + invoicePayOptions.PaidOutOfBand = true; + var btInvoiceAmount = (invoice.AmountDue / 100M); + var transactionResult = await _btGateway.Transaction.SaleAsync( + new Braintree.TransactionRequest + { + Amount = btInvoiceAmount, + CustomerId = customer.Metadata["btCustomerId"], + Options = new Braintree.TransactionOptionsRequest + { + SubmitForSettlement = true, + PayPal = new Braintree.TransactionOptionsPayPalRequest + { + CustomField = $"{subscriber.BraintreeIdField()}:{subscriber.Id}" + } + }, + CustomFields = new Dictionary + { + [subscriber.BraintreeIdField()] = subscriber.Id.ToString() + } + }); + + if (!transactionResult.IsSuccess()) + { + throw new GatewayException("Failed to charge PayPal customer."); + } + + braintreeTransaction = transactionResult.Target; + invoice = await _stripeAdapter.InvoiceUpdateAsync(invoice.Id, new Stripe.InvoiceUpdateOptions + { + Metadata = new Dictionary + { + ["btTransactionId"] = braintreeTransaction.Id, + ["btPayPalTransactionId"] = + braintreeTransaction.PayPalDetails.AuthorizationId + }, + }); + invoicePayOptions.PaidOutOfBand = true; + } + + try + { + invoice = await _stripeAdapter.InvoicePayAsync(invoice.Id, invoicePayOptions); + } + catch (Stripe.StripeException e) + { + if (e.HttpStatusCode == System.Net.HttpStatusCode.PaymentRequired && + e.StripeError?.Code == "invoice_payment_intent_requires_action") + { + // SCA required, get intent client secret + var invoiceGetOptions = new Stripe.InvoiceGetOptions(); + invoiceGetOptions.AddExpand("payment_intent"); + invoice = await _stripeAdapter.InvoiceGetAsync(invoice.Id, invoiceGetOptions); + paymentIntentClientSecret = invoice?.PaymentIntent?.ClientSecret; + } + else + { + throw new GatewayException("Unable to pay invoice."); + } + } + } + catch (Exception e) + { + if (braintreeTransaction != null) + { + await _btGateway.Transaction.RefundAsync(braintreeTransaction.Id); + } + if (invoice != null) + { + if (invoice.Status == "paid") + { + // It's apparently paid, so we need to return w/o throwing an exception + return paymentIntentClientSecret; + } + + invoice = await _stripeAdapter.InvoiceVoidInvoiceAsync(invoice.Id, new Stripe.InvoiceVoidOptions()); + + // HACK: Workaround for customer balance credit + if (invoice.StartingBalance < 0) + { + // Customer had a balance applied to this invoice. Since we can't fully trust Stripe to + // credit it back to the customer (even though their docs claim they will), we need to + // check that balance against the current customer balance and determine if it needs to be re-applied + customer = await _stripeAdapter.CustomerGetAsync(subscriber.GatewayCustomerId, customerOptions); + + // Assumption: Customer balance should now be $0, otherwise payment would not have failed. + if (customer.Balance == 0) + { + await _stripeAdapter.CustomerUpdateAsync(customer.Id, new Stripe.CustomerUpdateOptions + { + Balance = invoice.StartingBalance + }); + } + } + } + + if (e is Stripe.StripeException strEx && + (strEx.StripeError?.Message?.Contains("cannot be used because it is not verified") ?? false)) + { + throw new GatewayException("Bank account is not yet verified."); + } + + // Let the caller perform any subscription change cleanup + throw; + } + return paymentIntentClientSecret; + } + + public async Task CancelSubscriptionAsync(ISubscriber subscriber, bool endOfPeriod = false, + bool skipInAppPurchaseCheck = false) + { + if (subscriber == null) + { + throw new ArgumentNullException(nameof(subscriber)); + } + + if (string.IsNullOrWhiteSpace(subscriber.GatewaySubscriptionId)) + { + throw new GatewayException("No subscription."); + } + + if (!string.IsNullOrWhiteSpace(subscriber.GatewayCustomerId) && !skipInAppPurchaseCheck) + { + var customer = await _stripeAdapter.CustomerGetAsync(subscriber.GatewayCustomerId); + if (customer.Metadata.ContainsKey("appleReceipt")) + { + throw new BadRequestException("You are required to manage your subscription from the app store."); + } + } + + var sub = await _stripeAdapter.SubscriptionGetAsync(subscriber.GatewaySubscriptionId); + if (sub == null) + { + throw new GatewayException("Subscription was not found."); + } + + if (sub.CanceledAt.HasValue || sub.Status == "canceled" || sub.Status == "unpaid" || + sub.Status == "incomplete_expired") + { + // Already canceled + return; + } + + try + { + var canceledSub = endOfPeriod ? + await _stripeAdapter.SubscriptionUpdateAsync(sub.Id, + new Stripe.SubscriptionUpdateOptions { CancelAtPeriodEnd = true }) : + await _stripeAdapter.SubscriptionCancelAsync(sub.Id, new Stripe.SubscriptionCancelOptions()); + if (!canceledSub.CanceledAt.HasValue) + { + throw new GatewayException("Unable to cancel subscription."); + } + } + catch (Stripe.StripeException e) + { + if (e.Message != $"No such subscription: {subscriber.GatewaySubscriptionId}") + { + throw; } } } - Braintree.Transaction braintreeTransaction = null; - try + public async Task ReinstateSubscriptionAsync(ISubscriber subscriber) { - // Finalize the invoice (from Draft) w/o auto-advance so we - // can attempt payment manually. - invoice = await _stripeAdapter.InvoiceFinalizeInvoiceAsync(invoice.Id, new Stripe.InvoiceFinalizeOptions + if (subscriber == null) { - AutoAdvance = false, - }); - var invoicePayOptions = new Stripe.InvoicePayOptions + throw new ArgumentNullException(nameof(subscriber)); + } + + if (string.IsNullOrWhiteSpace(subscriber.GatewaySubscriptionId)) { - PaymentMethod = cardPaymentMethodId, - }; - if (customer?.Metadata?.ContainsKey("btCustomerId") ?? false) + throw new GatewayException("No subscription."); + } + + var sub = await _stripeAdapter.SubscriptionGetAsync(subscriber.GatewaySubscriptionId); + if (sub == null) { - invoicePayOptions.PaidOutOfBand = true; - var btInvoiceAmount = (invoice.AmountDue / 100M); - var transactionResult = await _btGateway.Transaction.SaleAsync( - new Braintree.TransactionRequest + throw new GatewayException("Subscription was not found."); + } + + if ((sub.Status != "active" && sub.Status != "trialing" && !sub.Status.StartsWith("incomplete")) || + !sub.CanceledAt.HasValue) + { + throw new GatewayException("Subscription is not marked for cancellation."); + } + + var updatedSub = await _stripeAdapter.SubscriptionUpdateAsync(sub.Id, + new Stripe.SubscriptionUpdateOptions { CancelAtPeriodEnd = false }); + if (updatedSub.CanceledAt.HasValue) + { + throw new GatewayException("Unable to reinstate subscription."); + } + } + + public async Task UpdatePaymentMethodAsync(ISubscriber subscriber, PaymentMethodType paymentMethodType, + string paymentToken, bool allowInAppPurchases = false, TaxInfo taxInfo = null) + { + if (subscriber == null) + { + throw new ArgumentNullException(nameof(subscriber)); + } + + if (subscriber.Gateway.HasValue && subscriber.Gateway.Value != GatewayType.Stripe) + { + throw new GatewayException("Switching from one payment type to another is not supported. " + + "Contact us for assistance."); + } + + var createdCustomer = false; + AppleReceiptStatus appleReceiptStatus = null; + Braintree.Customer braintreeCustomer = null; + string stipeCustomerSourceToken = null; + string stipeCustomerPaymentMethodId = null; + var stripeCustomerMetadata = new Dictionary(); + var stripePaymentMethod = paymentMethodType == PaymentMethodType.Card || + paymentMethodType == PaymentMethodType.BankAccount; + var inAppPurchase = paymentMethodType == PaymentMethodType.AppleInApp || + paymentMethodType == PaymentMethodType.GoogleInApp; + + Stripe.Customer customer = null; + + if (!allowInAppPurchases && inAppPurchase) + { + throw new GatewayException("In-app purchase payment method is not allowed."); + } + + if (!subscriber.IsUser() && inAppPurchase) + { + throw new GatewayException("In-app purchase payment method is only allowed for users."); + } + + if (!string.IsNullOrWhiteSpace(subscriber.GatewayCustomerId)) + { + var options = new Stripe.CustomerGetOptions(); + options.AddExpand("sources"); + customer = await _stripeAdapter.CustomerGetAsync(subscriber.GatewayCustomerId, options); + if (customer.Metadata?.Any() ?? false) + { + stripeCustomerMetadata = customer.Metadata; + } + } + + if (inAppPurchase && customer != null && customer.Balance != 0) + { + throw new GatewayException("Customer balance cannot exist when using in-app purchases."); + } + + if (!inAppPurchase && customer != null && stripeCustomerMetadata.ContainsKey("appleReceipt")) + { + throw new GatewayException("Cannot change from in-app payment method. Contact support."); + } + + var hadBtCustomer = stripeCustomerMetadata.ContainsKey("btCustomerId"); + if (stripePaymentMethod) + { + if (paymentToken.StartsWith("pm_")) + { + stipeCustomerPaymentMethodId = paymentToken; + } + else + { + stipeCustomerSourceToken = paymentToken; + } + } + else if (paymentMethodType == PaymentMethodType.PayPal) + { + if (hadBtCustomer) + { + var pmResult = await _btGateway.PaymentMethod.CreateAsync(new Braintree.PaymentMethodRequest { - Amount = btInvoiceAmount, - CustomerId = customer.Metadata["btCustomerId"], - Options = new Braintree.TransactionOptionsRequest - { - SubmitForSettlement = true, - PayPal = new Braintree.TransactionOptionsPayPalRequest + CustomerId = stripeCustomerMetadata["btCustomerId"], + PaymentMethodNonce = paymentToken + }); + + if (pmResult.IsSuccess()) + { + var customerResult = await _btGateway.Customer.UpdateAsync( + stripeCustomerMetadata["btCustomerId"], new Braintree.CustomerRequest { - CustomField = $"{subscriber.BraintreeIdField()}:{subscriber.Id}" - } - }, + DefaultPaymentMethodToken = pmResult.Target.Token + }); + + if (customerResult.IsSuccess() && customerResult.Target.PaymentMethods.Length > 0) + { + braintreeCustomer = customerResult.Target; + } + else + { + await _btGateway.PaymentMethod.DeleteAsync(pmResult.Target.Token); + hadBtCustomer = false; + } + } + else + { + hadBtCustomer = false; + } + } + + if (!hadBtCustomer) + { + var customerResult = await _btGateway.Customer.CreateAsync(new Braintree.CustomerRequest + { + PaymentMethodNonce = paymentToken, + Email = subscriber.BillingEmailAddress(), + Id = subscriber.BraintreeCustomerIdPrefix() + subscriber.Id.ToString("N").ToLower() + + Utilities.CoreHelpers.RandomString(3, upper: false, numeric: false), CustomFields = new Dictionary { [subscriber.BraintreeIdField()] = subscriber.Id.ToString() } }); - if (!transactionResult.IsSuccess()) - { - throw new GatewayException("Failed to charge PayPal customer."); - } - - braintreeTransaction = transactionResult.Target; - invoice = await _stripeAdapter.InvoiceUpdateAsync(invoice.Id, new Stripe.InvoiceUpdateOptions - { - Metadata = new Dictionary + if (!customerResult.IsSuccess() || customerResult.Target.PaymentMethods.Length == 0) { - ["btTransactionId"] = braintreeTransaction.Id, - ["btPayPalTransactionId"] = - braintreeTransaction.PayPalDetails.AuthorizationId - }, - }); - invoicePayOptions.PaidOutOfBand = true; + throw new GatewayException("Failed to create PayPal customer record."); + } + + braintreeCustomer = customerResult.Target; + } + } + else if (paymentMethodType == PaymentMethodType.AppleInApp) + { + appleReceiptStatus = await _appleIapService.GetVerifiedReceiptStatusAsync(paymentToken); + if (appleReceiptStatus == null) + { + throw new GatewayException("Cannot verify Apple in-app purchase."); + } + await VerifyAppleReceiptNotInUseAsync(appleReceiptStatus.GetOriginalTransactionId(), subscriber); + } + else + { + throw new GatewayException("Payment method is not supported at this time."); + } + + if (stripeCustomerMetadata.ContainsKey("btCustomerId")) + { + if (braintreeCustomer?.Id != stripeCustomerMetadata["btCustomerId"]) + { + var nowSec = Utilities.CoreHelpers.ToEpocSeconds(DateTime.UtcNow); + stripeCustomerMetadata.Add($"btCustomerId_{nowSec}", stripeCustomerMetadata["btCustomerId"]); + } + stripeCustomerMetadata["btCustomerId"] = braintreeCustomer?.Id; + } + else if (!string.IsNullOrWhiteSpace(braintreeCustomer?.Id)) + { + stripeCustomerMetadata.Add("btCustomerId", braintreeCustomer.Id); + } + + if (appleReceiptStatus != null) + { + var originalTransactionId = appleReceiptStatus.GetOriginalTransactionId(); + if (stripeCustomerMetadata.ContainsKey("appleReceipt")) + { + if (originalTransactionId != stripeCustomerMetadata["appleReceipt"]) + { + var nowSec = Utilities.CoreHelpers.ToEpocSeconds(DateTime.UtcNow); + stripeCustomerMetadata.Add($"appleReceipt_{nowSec}", stripeCustomerMetadata["appleReceipt"]); + } + stripeCustomerMetadata["appleReceipt"] = originalTransactionId; + } + else + { + stripeCustomerMetadata.Add("appleReceipt", originalTransactionId); + } + await _appleIapService.SaveReceiptAsync(appleReceiptStatus, subscriber.Id); } try { - invoice = await _stripeAdapter.InvoicePayAsync(invoice.Id, invoicePayOptions); - } - catch (Stripe.StripeException e) - { - if (e.HttpStatusCode == System.Net.HttpStatusCode.PaymentRequired && - e.StripeError?.Code == "invoice_payment_intent_requires_action") + if (customer == null) { - // SCA required, get intent client secret - var invoiceGetOptions = new Stripe.InvoiceGetOptions(); - invoiceGetOptions.AddExpand("payment_intent"); - invoice = await _stripeAdapter.InvoiceGetAsync(invoice.Id, invoiceGetOptions); - paymentIntentClientSecret = invoice?.PaymentIntent?.ClientSecret; - } - else - { - throw new GatewayException("Unable to pay invoice."); - } - } - } - catch (Exception e) - { - if (braintreeTransaction != null) - { - await _btGateway.Transaction.RefundAsync(braintreeTransaction.Id); - } - if (invoice != null) - { - if (invoice.Status == "paid") - { - // It's apparently paid, so we need to return w/o throwing an exception - return paymentIntentClientSecret; - } - - invoice = await _stripeAdapter.InvoiceVoidInvoiceAsync(invoice.Id, new Stripe.InvoiceVoidOptions()); - - // HACK: Workaround for customer balance credit - if (invoice.StartingBalance < 0) - { - // Customer had a balance applied to this invoice. Since we can't fully trust Stripe to - // credit it back to the customer (even though their docs claim they will), we need to - // check that balance against the current customer balance and determine if it needs to be re-applied - customer = await _stripeAdapter.CustomerGetAsync(subscriber.GatewayCustomerId, customerOptions); - - // Assumption: Customer balance should now be $0, otherwise payment would not have failed. - if (customer.Balance == 0) + customer = await _stripeAdapter.CustomerCreateAsync(new Stripe.CustomerCreateOptions { - await _stripeAdapter.CustomerUpdateAsync(customer.Id, new Stripe.CustomerUpdateOptions + Description = subscriber.BillingName(), + Email = subscriber.BillingEmailAddress(), + Metadata = stripeCustomerMetadata, + Source = stipeCustomerSourceToken, + PaymentMethod = stipeCustomerPaymentMethodId, + InvoiceSettings = new Stripe.CustomerInvoiceSettingsOptions { - Balance = invoice.StartingBalance - }); - } - } - } - - if (e is Stripe.StripeException strEx && - (strEx.StripeError?.Message?.Contains("cannot be used because it is not verified") ?? false)) - { - throw new GatewayException("Bank account is not yet verified."); - } - - // Let the caller perform any subscription change cleanup - throw; - } - return paymentIntentClientSecret; - } - - public async Task CancelSubscriptionAsync(ISubscriber subscriber, bool endOfPeriod = false, - bool skipInAppPurchaseCheck = false) - { - if (subscriber == null) - { - throw new ArgumentNullException(nameof(subscriber)); - } - - if (string.IsNullOrWhiteSpace(subscriber.GatewaySubscriptionId)) - { - throw new GatewayException("No subscription."); - } - - if (!string.IsNullOrWhiteSpace(subscriber.GatewayCustomerId) && !skipInAppPurchaseCheck) - { - var customer = await _stripeAdapter.CustomerGetAsync(subscriber.GatewayCustomerId); - if (customer.Metadata.ContainsKey("appleReceipt")) - { - throw new BadRequestException("You are required to manage your subscription from the app store."); - } - } - - var sub = await _stripeAdapter.SubscriptionGetAsync(subscriber.GatewaySubscriptionId); - if (sub == null) - { - throw new GatewayException("Subscription was not found."); - } - - if (sub.CanceledAt.HasValue || sub.Status == "canceled" || sub.Status == "unpaid" || - sub.Status == "incomplete_expired") - { - // Already canceled - return; - } - - try - { - var canceledSub = endOfPeriod ? - await _stripeAdapter.SubscriptionUpdateAsync(sub.Id, - new Stripe.SubscriptionUpdateOptions { CancelAtPeriodEnd = true }) : - await _stripeAdapter.SubscriptionCancelAsync(sub.Id, new Stripe.SubscriptionCancelOptions()); - if (!canceledSub.CanceledAt.HasValue) - { - throw new GatewayException("Unable to cancel subscription."); - } - } - catch (Stripe.StripeException e) - { - if (e.Message != $"No such subscription: {subscriber.GatewaySubscriptionId}") - { - throw; - } - } - } - - public async Task ReinstateSubscriptionAsync(ISubscriber subscriber) - { - if (subscriber == null) - { - throw new ArgumentNullException(nameof(subscriber)); - } - - if (string.IsNullOrWhiteSpace(subscriber.GatewaySubscriptionId)) - { - throw new GatewayException("No subscription."); - } - - var sub = await _stripeAdapter.SubscriptionGetAsync(subscriber.GatewaySubscriptionId); - if (sub == null) - { - throw new GatewayException("Subscription was not found."); - } - - if ((sub.Status != "active" && sub.Status != "trialing" && !sub.Status.StartsWith("incomplete")) || - !sub.CanceledAt.HasValue) - { - throw new GatewayException("Subscription is not marked for cancellation."); - } - - var updatedSub = await _stripeAdapter.SubscriptionUpdateAsync(sub.Id, - new Stripe.SubscriptionUpdateOptions { CancelAtPeriodEnd = false }); - if (updatedSub.CanceledAt.HasValue) - { - throw new GatewayException("Unable to reinstate subscription."); - } - } - - public async Task UpdatePaymentMethodAsync(ISubscriber subscriber, PaymentMethodType paymentMethodType, - string paymentToken, bool allowInAppPurchases = false, TaxInfo taxInfo = null) - { - if (subscriber == null) - { - throw new ArgumentNullException(nameof(subscriber)); - } - - if (subscriber.Gateway.HasValue && subscriber.Gateway.Value != GatewayType.Stripe) - { - throw new GatewayException("Switching from one payment type to another is not supported. " + - "Contact us for assistance."); - } - - var createdCustomer = false; - AppleReceiptStatus appleReceiptStatus = null; - Braintree.Customer braintreeCustomer = null; - string stipeCustomerSourceToken = null; - string stipeCustomerPaymentMethodId = null; - var stripeCustomerMetadata = new Dictionary(); - var stripePaymentMethod = paymentMethodType == PaymentMethodType.Card || - paymentMethodType == PaymentMethodType.BankAccount; - var inAppPurchase = paymentMethodType == PaymentMethodType.AppleInApp || - paymentMethodType == PaymentMethodType.GoogleInApp; - - Stripe.Customer customer = null; - - if (!allowInAppPurchases && inAppPurchase) - { - throw new GatewayException("In-app purchase payment method is not allowed."); - } - - if (!subscriber.IsUser() && inAppPurchase) - { - throw new GatewayException("In-app purchase payment method is only allowed for users."); - } - - if (!string.IsNullOrWhiteSpace(subscriber.GatewayCustomerId)) - { - var options = new Stripe.CustomerGetOptions(); - options.AddExpand("sources"); - customer = await _stripeAdapter.CustomerGetAsync(subscriber.GatewayCustomerId, options); - if (customer.Metadata?.Any() ?? false) - { - stripeCustomerMetadata = customer.Metadata; - } - } - - if (inAppPurchase && customer != null && customer.Balance != 0) - { - throw new GatewayException("Customer balance cannot exist when using in-app purchases."); - } - - if (!inAppPurchase && customer != null && stripeCustomerMetadata.ContainsKey("appleReceipt")) - { - throw new GatewayException("Cannot change from in-app payment method. Contact support."); - } - - var hadBtCustomer = stripeCustomerMetadata.ContainsKey("btCustomerId"); - if (stripePaymentMethod) - { - if (paymentToken.StartsWith("pm_")) - { - stipeCustomerPaymentMethodId = paymentToken; - } - else - { - stipeCustomerSourceToken = paymentToken; - } - } - else if (paymentMethodType == PaymentMethodType.PayPal) - { - if (hadBtCustomer) - { - var pmResult = await _btGateway.PaymentMethod.CreateAsync(new Braintree.PaymentMethodRequest - { - CustomerId = stripeCustomerMetadata["btCustomerId"], - PaymentMethodNonce = paymentToken - }); - - if (pmResult.IsSuccess()) - { - var customerResult = await _btGateway.Customer.UpdateAsync( - stripeCustomerMetadata["btCustomerId"], new Braintree.CustomerRequest + DefaultPaymentMethod = stipeCustomerPaymentMethodId + }, + Address = taxInfo == null ? null : new Stripe.AddressOptions { - DefaultPaymentMethodToken = pmResult.Target.Token - }); + Country = taxInfo.BillingAddressCountry, + PostalCode = taxInfo.BillingAddressPostalCode, + Line1 = taxInfo.BillingAddressLine1 ?? string.Empty, + Line2 = taxInfo.BillingAddressLine2, + City = taxInfo.BillingAddressCity, + State = taxInfo.BillingAddressState, + }, + Expand = new List { "sources" }, + }); - if (customerResult.IsSuccess() && customerResult.Target.PaymentMethods.Length > 0) - { - braintreeCustomer = customerResult.Target; - } - else - { - await _btGateway.PaymentMethod.DeleteAsync(pmResult.Target.Token); - hadBtCustomer = false; - } - } - else - { - hadBtCustomer = false; - } - } - - if (!hadBtCustomer) - { - var customerResult = await _btGateway.Customer.CreateAsync(new Braintree.CustomerRequest - { - PaymentMethodNonce = paymentToken, - Email = subscriber.BillingEmailAddress(), - Id = subscriber.BraintreeCustomerIdPrefix() + subscriber.Id.ToString("N").ToLower() + - Utilities.CoreHelpers.RandomString(3, upper: false, numeric: false), - CustomFields = new Dictionary - { - [subscriber.BraintreeIdField()] = subscriber.Id.ToString() - } - }); - - if (!customerResult.IsSuccess() || customerResult.Target.PaymentMethods.Length == 0) - { - throw new GatewayException("Failed to create PayPal customer record."); + subscriber.Gateway = GatewayType.Stripe; + subscriber.GatewayCustomerId = customer.Id; + createdCustomer = true; } - braintreeCustomer = customerResult.Target; - } - } - else if (paymentMethodType == PaymentMethodType.AppleInApp) - { - appleReceiptStatus = await _appleIapService.GetVerifiedReceiptStatusAsync(paymentToken); - if (appleReceiptStatus == null) - { - throw new GatewayException("Cannot verify Apple in-app purchase."); - } - await VerifyAppleReceiptNotInUseAsync(appleReceiptStatus.GetOriginalTransactionId(), subscriber); - } - else - { - throw new GatewayException("Payment method is not supported at this time."); - } - - if (stripeCustomerMetadata.ContainsKey("btCustomerId")) - { - if (braintreeCustomer?.Id != stripeCustomerMetadata["btCustomerId"]) - { - var nowSec = Utilities.CoreHelpers.ToEpocSeconds(DateTime.UtcNow); - stripeCustomerMetadata.Add($"btCustomerId_{nowSec}", stripeCustomerMetadata["btCustomerId"]); - } - stripeCustomerMetadata["btCustomerId"] = braintreeCustomer?.Id; - } - else if (!string.IsNullOrWhiteSpace(braintreeCustomer?.Id)) - { - stripeCustomerMetadata.Add("btCustomerId", braintreeCustomer.Id); - } - - if (appleReceiptStatus != null) - { - var originalTransactionId = appleReceiptStatus.GetOriginalTransactionId(); - if (stripeCustomerMetadata.ContainsKey("appleReceipt")) - { - if (originalTransactionId != stripeCustomerMetadata["appleReceipt"]) + if (!createdCustomer) { - var nowSec = Utilities.CoreHelpers.ToEpocSeconds(DateTime.UtcNow); - stripeCustomerMetadata.Add($"appleReceipt_{nowSec}", stripeCustomerMetadata["appleReceipt"]); - } - stripeCustomerMetadata["appleReceipt"] = originalTransactionId; - } - else - { - stripeCustomerMetadata.Add("appleReceipt", originalTransactionId); - } - await _appleIapService.SaveReceiptAsync(appleReceiptStatus, subscriber.Id); - } - - try - { - if (customer == null) - { - customer = await _stripeAdapter.CustomerCreateAsync(new Stripe.CustomerCreateOptions - { - Description = subscriber.BillingName(), - Email = subscriber.BillingEmailAddress(), - Metadata = stripeCustomerMetadata, - Source = stipeCustomerSourceToken, - PaymentMethod = stipeCustomerPaymentMethodId, - InvoiceSettings = new Stripe.CustomerInvoiceSettingsOptions + string defaultSourceId = null; + string defaultPaymentMethodId = null; + if (stripePaymentMethod) { - DefaultPaymentMethod = stipeCustomerPaymentMethodId - }, - Address = taxInfo == null ? null : new Stripe.AddressOptions - { - Country = taxInfo.BillingAddressCountry, - PostalCode = taxInfo.BillingAddressPostalCode, - Line1 = taxInfo.BillingAddressLine1 ?? string.Empty, - Line2 = taxInfo.BillingAddressLine2, - City = taxInfo.BillingAddressCity, - State = taxInfo.BillingAddressState, - }, - Expand = new List { "sources" }, - }); - - subscriber.Gateway = GatewayType.Stripe; - subscriber.GatewayCustomerId = customer.Id; - createdCustomer = true; - } - - if (!createdCustomer) - { - string defaultSourceId = null; - string defaultPaymentMethodId = null; - if (stripePaymentMethod) - { - if (!string.IsNullOrWhiteSpace(stipeCustomerSourceToken) && paymentToken.StartsWith("btok_")) - { - var bankAccount = await _stripeAdapter.BankAccountCreateAsync(customer.Id, new Stripe.BankAccountCreateOptions + if (!string.IsNullOrWhiteSpace(stipeCustomerSourceToken) && paymentToken.StartsWith("btok_")) { - Source = paymentToken - }); - defaultSourceId = bankAccount.Id; - } - else if (!string.IsNullOrWhiteSpace(stipeCustomerPaymentMethodId)) - { - await _stripeAdapter.PaymentMethodAttachAsync(stipeCustomerPaymentMethodId, - new Stripe.PaymentMethodAttachOptions { Customer = customer.Id }); - defaultPaymentMethodId = stipeCustomerPaymentMethodId; - } - } - - if (customer.Sources != null) - { - foreach (var source in customer.Sources.Where(s => s.Id != defaultSourceId)) - { - if (source is Stripe.BankAccount) - { - await _stripeAdapter.BankAccountDeleteAsync(customer.Id, source.Id); + var bankAccount = await _stripeAdapter.BankAccountCreateAsync(customer.Id, new Stripe.BankAccountCreateOptions + { + Source = paymentToken + }); + defaultSourceId = bankAccount.Id; } - else if (source is Stripe.Card) + else if (!string.IsNullOrWhiteSpace(stipeCustomerPaymentMethodId)) { - await _stripeAdapter.CardDeleteAsync(customer.Id, source.Id); + await _stripeAdapter.PaymentMethodAttachAsync(stipeCustomerPaymentMethodId, + new Stripe.PaymentMethodAttachOptions { Customer = customer.Id }); + defaultPaymentMethodId = stipeCustomerPaymentMethodId; } } - } - var cardPaymentMethods = _stripeAdapter.PaymentMethodListAutoPaging(new Stripe.PaymentMethodListOptions - { - Customer = customer.Id, - Type = "card" - }); - foreach (var cardMethod in cardPaymentMethods.Where(m => m.Id != defaultPaymentMethodId)) - { - await _stripeAdapter.PaymentMethodDetachAsync(cardMethod.Id, new Stripe.PaymentMethodDetachOptions()); - } - - customer = await _stripeAdapter.CustomerUpdateAsync(customer.Id, new Stripe.CustomerUpdateOptions - { - Metadata = stripeCustomerMetadata, - DefaultSource = defaultSourceId, - InvoiceSettings = new Stripe.CustomerInvoiceSettingsOptions + if (customer.Sources != null) { - DefaultPaymentMethod = defaultPaymentMethodId - }, - Address = taxInfo == null ? null : new Stripe.AddressOptions - { - Country = taxInfo.BillingAddressCountry, - PostalCode = taxInfo.BillingAddressPostalCode, - Line1 = taxInfo.BillingAddressLine1 ?? string.Empty, - Line2 = taxInfo.BillingAddressLine2, - City = taxInfo.BillingAddressCity, - State = taxInfo.BillingAddressState, - }, - }); - } - } - catch - { - if (braintreeCustomer != null && !hadBtCustomer) - { - await _btGateway.Customer.DeleteAsync(braintreeCustomer.Id); - } - throw; - } - - return createdCustomer; - } - - public async Task CreditAccountAsync(ISubscriber subscriber, decimal creditAmount) - { - Stripe.Customer customer = null; - var customerExists = subscriber.Gateway == GatewayType.Stripe && - !string.IsNullOrWhiteSpace(subscriber.GatewayCustomerId); - if (customerExists) - { - customer = await _stripeAdapter.CustomerGetAsync(subscriber.GatewayCustomerId); - } - else - { - customer = await _stripeAdapter.CustomerCreateAsync(new Stripe.CustomerCreateOptions - { - Email = subscriber.BillingEmailAddress(), - Description = subscriber.BillingName(), - }); - subscriber.Gateway = GatewayType.Stripe; - subscriber.GatewayCustomerId = customer.Id; - } - await _stripeAdapter.CustomerUpdateAsync(customer.Id, new Stripe.CustomerUpdateOptions - { - Balance = customer.Balance - (long)(creditAmount * 100) - }); - return !customerExists; - } - - public async Task GetBillingAsync(ISubscriber subscriber) - { - var customer = await GetCustomerAsync(subscriber.GatewayCustomerId, GetCustomerPaymentOptions()); - var billingInfo = new BillingInfo - { - Balance = GetBillingBalance(customer), - PaymentSource = await GetBillingPaymentSourceAsync(customer), - Invoices = await GetBillingInvoicesAsync(customer), - Transactions = await GetBillingTransactionsAsync(subscriber) - }; - - return billingInfo; - } - - public async Task GetBillingBalanceAndSourceAsync(ISubscriber subscriber) - { - var customer = await GetCustomerAsync(subscriber.GatewayCustomerId, GetCustomerPaymentOptions()); - var billingInfo = new BillingInfo - { - Balance = GetBillingBalance(customer), - PaymentSource = await GetBillingPaymentSourceAsync(customer) - }; - - return billingInfo; - } - - public async Task GetBillingHistoryAsync(ISubscriber subscriber) - { - var customer = await GetCustomerAsync(subscriber.GatewayCustomerId); - var billingInfo = new BillingInfo - { - Transactions = await GetBillingTransactionsAsync(subscriber), - Invoices = await GetBillingInvoicesAsync(customer) - }; - - return billingInfo; - } - - public async Task GetSubscriptionAsync(ISubscriber subscriber) - { - var subscriptionInfo = new SubscriptionInfo(); - - if (subscriber.IsUser() && !string.IsNullOrWhiteSpace(subscriber.GatewayCustomerId)) - { - var customer = await _stripeAdapter.CustomerGetAsync(subscriber.GatewayCustomerId); - subscriptionInfo.UsingInAppPurchase = customer.Metadata.ContainsKey("appleReceipt"); - } - - if (!string.IsNullOrWhiteSpace(subscriber.GatewaySubscriptionId)) - { - var sub = await _stripeAdapter.SubscriptionGetAsync(subscriber.GatewaySubscriptionId); - if (sub != null) - { - subscriptionInfo.Subscription = new SubscriptionInfo.BillingSubscription(sub); - } - - if (!sub.CanceledAt.HasValue && !string.IsNullOrWhiteSpace(subscriber.GatewayCustomerId)) - { - try - { - var upcomingInvoice = await _stripeAdapter.InvoiceUpcomingAsync( - new Stripe.UpcomingInvoiceOptions { Customer = subscriber.GatewayCustomerId }); - if (upcomingInvoice != null) - { - subscriptionInfo.UpcomingInvoice = - new SubscriptionInfo.BillingUpcomingInvoice(upcomingInvoice); + foreach (var source in customer.Sources.Where(s => s.Id != defaultSourceId)) + { + if (source is Stripe.BankAccount) + { + await _stripeAdapter.BankAccountDeleteAsync(customer.Id, source.Id); + } + else if (source is Stripe.Card) + { + await _stripeAdapter.CardDeleteAsync(customer.Id, source.Id); + } + } } - } - catch (Stripe.StripeException) { } - } - } - return subscriptionInfo; - } - - public async Task GetTaxInfoAsync(ISubscriber subscriber) - { - if (subscriber == null || string.IsNullOrWhiteSpace(subscriber.GatewayCustomerId)) - { - return null; - } - - var customer = await _stripeAdapter.CustomerGetAsync(subscriber.GatewayCustomerId, - new Stripe.CustomerGetOptions { Expand = new List { "tax_ids" } }); - - if (customer == null) - { - return null; - } - - var address = customer.Address; - var taxId = customer.TaxIds?.FirstOrDefault(); - - // Line1 is required, so if missing we're using the subscriber name - // see: https://stripe.com/docs/api/customers/create#create_customer-address-line1 - if (address != null && string.IsNullOrWhiteSpace(address.Line1)) - { - address.Line1 = null; - } - - return new TaxInfo - { - TaxIdNumber = taxId?.Value, - BillingAddressLine1 = address?.Line1, - BillingAddressLine2 = address?.Line2, - BillingAddressCity = address?.City, - BillingAddressState = address?.State, - BillingAddressPostalCode = address?.PostalCode, - BillingAddressCountry = address?.Country, - }; - } - - public async Task SaveTaxInfoAsync(ISubscriber subscriber, TaxInfo taxInfo) - { - if (subscriber != null && !string.IsNullOrWhiteSpace(subscriber.GatewayCustomerId)) - { - var customer = await _stripeAdapter.CustomerUpdateAsync(subscriber.GatewayCustomerId, new Stripe.CustomerUpdateOptions - { - Address = new Stripe.AddressOptions - { - Line1 = taxInfo.BillingAddressLine1 ?? string.Empty, - Line2 = taxInfo.BillingAddressLine2, - City = taxInfo.BillingAddressCity, - State = taxInfo.BillingAddressState, - PostalCode = taxInfo.BillingAddressPostalCode, - Country = taxInfo.BillingAddressCountry, - }, - Expand = new List { "tax_ids" } - }); - - if (!subscriber.IsUser() && customer != null) - { - var taxId = customer.TaxIds?.FirstOrDefault(); - - if (taxId != null) - { - await _stripeAdapter.TaxIdDeleteAsync(customer.Id, taxId.Id); - } - if (!string.IsNullOrWhiteSpace(taxInfo.TaxIdNumber) && - !string.IsNullOrWhiteSpace(taxInfo.TaxIdType)) - { - await _stripeAdapter.TaxIdCreateAsync(customer.Id, new Stripe.TaxIdCreateOptions + var cardPaymentMethods = _stripeAdapter.PaymentMethodListAutoPaging(new Stripe.PaymentMethodListOptions { - Type = taxInfo.TaxIdType, - Value = taxInfo.TaxIdNumber, + Customer = customer.Id, + Type = "card" + }); + foreach (var cardMethod in cardPaymentMethods.Where(m => m.Id != defaultPaymentMethodId)) + { + await _stripeAdapter.PaymentMethodDetachAsync(cardMethod.Id, new Stripe.PaymentMethodDetachOptions()); + } + + customer = await _stripeAdapter.CustomerUpdateAsync(customer.Id, new Stripe.CustomerUpdateOptions + { + Metadata = stripeCustomerMetadata, + DefaultSource = defaultSourceId, + InvoiceSettings = new Stripe.CustomerInvoiceSettingsOptions + { + DefaultPaymentMethod = defaultPaymentMethodId + }, + Address = taxInfo == null ? null : new Stripe.AddressOptions + { + Country = taxInfo.BillingAddressCountry, + PostalCode = taxInfo.BillingAddressPostalCode, + Line1 = taxInfo.BillingAddressLine1 ?? string.Empty, + Line2 = taxInfo.BillingAddressLine2, + City = taxInfo.BillingAddressCity, + State = taxInfo.BillingAddressState, + }, }); } } - } - } - - public async Task CreateTaxRateAsync(TaxRate taxRate) - { - var stripeTaxRateOptions = new Stripe.TaxRateCreateOptions() - { - DisplayName = $"{taxRate.Country} - {taxRate.PostalCode}", - Inclusive = false, - Percentage = taxRate.Rate, - Active = true - }; - var stripeTaxRate = await _stripeAdapter.TaxRateCreateAsync(stripeTaxRateOptions); - taxRate.Id = stripeTaxRate.Id; - await _taxRateRepository.CreateAsync(taxRate); - return taxRate; - } - - public async Task UpdateTaxRateAsync(TaxRate taxRate) - { - if (string.IsNullOrWhiteSpace(taxRate.Id)) - { - return; - } - - await ArchiveTaxRateAsync(taxRate); - await CreateTaxRateAsync(taxRate); - } - - public async Task ArchiveTaxRateAsync(TaxRate taxRate) - { - if (string.IsNullOrWhiteSpace(taxRate.Id)) - { - return; - } - - var updatedStripeTaxRate = await _stripeAdapter.TaxRateUpdateAsync( - taxRate.Id, - new Stripe.TaxRateUpdateOptions() { Active = false } - ); - if (!updatedStripeTaxRate.Active) - { - taxRate.Active = false; - await _taxRateRepository.ArchiveAsync(taxRate); - } - } - - private Stripe.PaymentMethod GetLatestCardPaymentMethod(string customerId) - { - var cardPaymentMethods = _stripeAdapter.PaymentMethodListAutoPaging( - new Stripe.PaymentMethodListOptions { Customer = customerId, Type = "card" }); - return cardPaymentMethods.OrderByDescending(m => m.Created).FirstOrDefault(); - } - - private async Task VerifyAppleReceiptNotInUseAsync(string receiptOriginalTransactionId, ISubscriber subscriber) - { - var existingReceipt = await _appleIapService.GetReceiptAsync(receiptOriginalTransactionId); - if (existingReceipt != null && existingReceipt.Item2.HasValue && existingReceipt.Item2 != subscriber.Id) - { - var existingUser = await _userRepository.GetByIdAsync(existingReceipt.Item2.Value); - if (existingUser != null) + catch { - throw new GatewayException("Apple receipt already in use by another user."); + if (braintreeCustomer != null && !hadBtCustomer) + { + await _btGateway.Customer.DeleteAsync(braintreeCustomer.Id); + } + throw; } - } - } - private decimal GetBillingBalance(Stripe.Customer customer) - { - return customer != null ? customer.Balance / 100M : default; - } - - private async Task GetBillingPaymentSourceAsync(Stripe.Customer customer) - { - if (customer == null) - { - return null; + return createdCustomer; } - if (customer.Metadata?.ContainsKey("appleReceipt") ?? false) + public async Task CreditAccountAsync(ISubscriber subscriber, decimal creditAmount) { - return new BillingInfo.BillingSource + Stripe.Customer customer = null; + var customerExists = subscriber.Gateway == GatewayType.Stripe && + !string.IsNullOrWhiteSpace(subscriber.GatewayCustomerId); + if (customerExists) { - Type = PaymentMethodType.AppleInApp + customer = await _stripeAdapter.CustomerGetAsync(subscriber.GatewayCustomerId); + } + else + { + customer = await _stripeAdapter.CustomerCreateAsync(new Stripe.CustomerCreateOptions + { + Email = subscriber.BillingEmailAddress(), + Description = subscriber.BillingName(), + }); + subscriber.Gateway = GatewayType.Stripe; + subscriber.GatewayCustomerId = customer.Id; + } + await _stripeAdapter.CustomerUpdateAsync(customer.Id, new Stripe.CustomerUpdateOptions + { + Balance = customer.Balance - (long)(creditAmount * 100) + }); + return !customerExists; + } + + public async Task GetBillingAsync(ISubscriber subscriber) + { + var customer = await GetCustomerAsync(subscriber.GatewayCustomerId, GetCustomerPaymentOptions()); + var billingInfo = new BillingInfo + { + Balance = GetBillingBalance(customer), + PaymentSource = await GetBillingPaymentSourceAsync(customer), + Invoices = await GetBillingInvoicesAsync(customer), + Transactions = await GetBillingTransactionsAsync(subscriber) + }; + + return billingInfo; + } + + public async Task GetBillingBalanceAndSourceAsync(ISubscriber subscriber) + { + var customer = await GetCustomerAsync(subscriber.GatewayCustomerId, GetCustomerPaymentOptions()); + var billingInfo = new BillingInfo + { + Balance = GetBillingBalance(customer), + PaymentSource = await GetBillingPaymentSourceAsync(customer) + }; + + return billingInfo; + } + + public async Task GetBillingHistoryAsync(ISubscriber subscriber) + { + var customer = await GetCustomerAsync(subscriber.GatewayCustomerId); + var billingInfo = new BillingInfo + { + Transactions = await GetBillingTransactionsAsync(subscriber), + Invoices = await GetBillingInvoicesAsync(customer) + }; + + return billingInfo; + } + + public async Task GetSubscriptionAsync(ISubscriber subscriber) + { + var subscriptionInfo = new SubscriptionInfo(); + + if (subscriber.IsUser() && !string.IsNullOrWhiteSpace(subscriber.GatewayCustomerId)) + { + var customer = await _stripeAdapter.CustomerGetAsync(subscriber.GatewayCustomerId); + subscriptionInfo.UsingInAppPurchase = customer.Metadata.ContainsKey("appleReceipt"); + } + + if (!string.IsNullOrWhiteSpace(subscriber.GatewaySubscriptionId)) + { + var sub = await _stripeAdapter.SubscriptionGetAsync(subscriber.GatewaySubscriptionId); + if (sub != null) + { + subscriptionInfo.Subscription = new SubscriptionInfo.BillingSubscription(sub); + } + + if (!sub.CanceledAt.HasValue && !string.IsNullOrWhiteSpace(subscriber.GatewayCustomerId)) + { + try + { + var upcomingInvoice = await _stripeAdapter.InvoiceUpcomingAsync( + new Stripe.UpcomingInvoiceOptions { Customer = subscriber.GatewayCustomerId }); + if (upcomingInvoice != null) + { + subscriptionInfo.UpcomingInvoice = + new SubscriptionInfo.BillingUpcomingInvoice(upcomingInvoice); + } + } + catch (Stripe.StripeException) { } + } + } + + return subscriptionInfo; + } + + public async Task GetTaxInfoAsync(ISubscriber subscriber) + { + if (subscriber == null || string.IsNullOrWhiteSpace(subscriber.GatewayCustomerId)) + { + return null; + } + + var customer = await _stripeAdapter.CustomerGetAsync(subscriber.GatewayCustomerId, + new Stripe.CustomerGetOptions { Expand = new List { "tax_ids" } }); + + if (customer == null) + { + return null; + } + + var address = customer.Address; + var taxId = customer.TaxIds?.FirstOrDefault(); + + // Line1 is required, so if missing we're using the subscriber name + // see: https://stripe.com/docs/api/customers/create#create_customer-address-line1 + if (address != null && string.IsNullOrWhiteSpace(address.Line1)) + { + address.Line1 = null; + } + + return new TaxInfo + { + TaxIdNumber = taxId?.Value, + BillingAddressLine1 = address?.Line1, + BillingAddressLine2 = address?.Line2, + BillingAddressCity = address?.City, + BillingAddressState = address?.State, + BillingAddressPostalCode = address?.PostalCode, + BillingAddressCountry = address?.Country, }; } - if (customer.Metadata?.ContainsKey("btCustomerId") ?? false) + public async Task SaveTaxInfoAsync(ISubscriber subscriber, TaxInfo taxInfo) { - try + if (subscriber != null && !string.IsNullOrWhiteSpace(subscriber.GatewayCustomerId)) { - var braintreeCustomer = await _btGateway.Customer.FindAsync( - customer.Metadata["btCustomerId"]); - if (braintreeCustomer?.DefaultPaymentMethod != null) + var customer = await _stripeAdapter.CustomerUpdateAsync(subscriber.GatewayCustomerId, new Stripe.CustomerUpdateOptions { - return new BillingInfo.BillingSource( - braintreeCustomer.DefaultPaymentMethod); + Address = new Stripe.AddressOptions + { + Line1 = taxInfo.BillingAddressLine1 ?? string.Empty, + Line2 = taxInfo.BillingAddressLine2, + City = taxInfo.BillingAddressCity, + State = taxInfo.BillingAddressState, + PostalCode = taxInfo.BillingAddressPostalCode, + Country = taxInfo.BillingAddressCountry, + }, + Expand = new List { "tax_ids" } + }); + + if (!subscriber.IsUser() && customer != null) + { + var taxId = customer.TaxIds?.FirstOrDefault(); + + if (taxId != null) + { + await _stripeAdapter.TaxIdDeleteAsync(customer.Id, taxId.Id); + } + if (!string.IsNullOrWhiteSpace(taxInfo.TaxIdNumber) && + !string.IsNullOrWhiteSpace(taxInfo.TaxIdType)) + { + await _stripeAdapter.TaxIdCreateAsync(customer.Id, new Stripe.TaxIdCreateOptions + { + Type = taxInfo.TaxIdType, + Value = taxInfo.TaxIdNumber, + }); + } } } - catch (Braintree.Exceptions.NotFoundException) { } } - if (customer.InvoiceSettings?.DefaultPaymentMethod?.Type == "card") + public async Task CreateTaxRateAsync(TaxRate taxRate) { - return new BillingInfo.BillingSource( - customer.InvoiceSettings.DefaultPaymentMethod); + var stripeTaxRateOptions = new Stripe.TaxRateCreateOptions() + { + DisplayName = $"{taxRate.Country} - {taxRate.PostalCode}", + Inclusive = false, + Percentage = taxRate.Rate, + Active = true + }; + var stripeTaxRate = await _stripeAdapter.TaxRateCreateAsync(stripeTaxRateOptions); + taxRate.Id = stripeTaxRate.Id; + await _taxRateRepository.CreateAsync(taxRate); + return taxRate; } - if (customer.DefaultSource != null && - (customer.DefaultSource is Stripe.Card || customer.DefaultSource is Stripe.BankAccount)) + public async Task UpdateTaxRateAsync(TaxRate taxRate) { - return new BillingInfo.BillingSource(customer.DefaultSource); + if (string.IsNullOrWhiteSpace(taxRate.Id)) + { + return; + } + + await ArchiveTaxRateAsync(taxRate); + await CreateTaxRateAsync(taxRate); } - var paymentMethod = GetLatestCardPaymentMethod(customer.Id); - return paymentMethod != null ? new BillingInfo.BillingSource(paymentMethod) : null; - } - - private Stripe.CustomerGetOptions GetCustomerPaymentOptions() - { - var customerOptions = new Stripe.CustomerGetOptions(); - customerOptions.AddExpand("default_source"); - customerOptions.AddExpand("invoice_settings.default_payment_method"); - return customerOptions; - } - - private async Task GetCustomerAsync(string gatewayCustomerId, Stripe.CustomerGetOptions options = null) - { - if (string.IsNullOrWhiteSpace(gatewayCustomerId)) + public async Task ArchiveTaxRateAsync(TaxRate taxRate) { - return null; + if (string.IsNullOrWhiteSpace(taxRate.Id)) + { + return; + } + + var updatedStripeTaxRate = await _stripeAdapter.TaxRateUpdateAsync( + taxRate.Id, + new Stripe.TaxRateUpdateOptions() { Active = false } + ); + if (!updatedStripeTaxRate.Active) + { + taxRate.Active = false; + await _taxRateRepository.ArchiveAsync(taxRate); + } } - Stripe.Customer customer = null; - try + private Stripe.PaymentMethod GetLatestCardPaymentMethod(string customerId) { - customer = await _stripeAdapter.CustomerGetAsync(gatewayCustomerId, options); - } - catch (Stripe.StripeException) { } - - return customer; - } - - private async Task> GetBillingTransactionsAsync(ISubscriber subscriber) - { - ICollection transactions = null; - if (subscriber is User) - { - transactions = await _transactionRepository.GetManyByUserIdAsync(subscriber.Id); - } - else if (subscriber is Organization) - { - transactions = await _transactionRepository.GetManyByOrganizationIdAsync(subscriber.Id); + var cardPaymentMethods = _stripeAdapter.PaymentMethodListAutoPaging( + new Stripe.PaymentMethodListOptions { Customer = customerId, Type = "card" }); + return cardPaymentMethods.OrderByDescending(m => m.Created).FirstOrDefault(); } - return transactions?.OrderByDescending(i => i.CreationDate) - .Select(t => new BillingInfo.BillingTransaction(t)); - - } - - private async Task> GetBillingInvoicesAsync(Stripe.Customer customer) - { - if (customer == null) + private async Task VerifyAppleReceiptNotInUseAsync(string receiptOriginalTransactionId, ISubscriber subscriber) { - return null; + var existingReceipt = await _appleIapService.GetReceiptAsync(receiptOriginalTransactionId); + if (existingReceipt != null && existingReceipt.Item2.HasValue && existingReceipt.Item2 != subscriber.Id) + { + var existingUser = await _userRepository.GetByIdAsync(existingReceipt.Item2.Value); + if (existingUser != null) + { + throw new GatewayException("Apple receipt already in use by another user."); + } + } } - var invoices = await _stripeAdapter.InvoiceListAsync(new Stripe.InvoiceListOptions + private decimal GetBillingBalance(Stripe.Customer customer) { - Customer = customer.Id, - Limit = 50 - }); + return customer != null ? customer.Balance / 100M : default; + } - return invoices.Data.Where(i => i.Status != "void" && i.Status != "draft") - .OrderByDescending(i => i.Created).Select(i => new BillingInfo.BillingInvoice(i)); + private async Task GetBillingPaymentSourceAsync(Stripe.Customer customer) + { + if (customer == null) + { + return null; + } + if (customer.Metadata?.ContainsKey("appleReceipt") ?? false) + { + return new BillingInfo.BillingSource + { + Type = PaymentMethodType.AppleInApp + }; + } + + if (customer.Metadata?.ContainsKey("btCustomerId") ?? false) + { + try + { + var braintreeCustomer = await _btGateway.Customer.FindAsync( + customer.Metadata["btCustomerId"]); + if (braintreeCustomer?.DefaultPaymentMethod != null) + { + return new BillingInfo.BillingSource( + braintreeCustomer.DefaultPaymentMethod); + } + } + catch (Braintree.Exceptions.NotFoundException) { } + } + + if (customer.InvoiceSettings?.DefaultPaymentMethod?.Type == "card") + { + return new BillingInfo.BillingSource( + customer.InvoiceSettings.DefaultPaymentMethod); + } + + if (customer.DefaultSource != null && + (customer.DefaultSource is Stripe.Card || customer.DefaultSource is Stripe.BankAccount)) + { + return new BillingInfo.BillingSource(customer.DefaultSource); + } + + var paymentMethod = GetLatestCardPaymentMethod(customer.Id); + return paymentMethod != null ? new BillingInfo.BillingSource(paymentMethod) : null; + } + + private Stripe.CustomerGetOptions GetCustomerPaymentOptions() + { + var customerOptions = new Stripe.CustomerGetOptions(); + customerOptions.AddExpand("default_source"); + customerOptions.AddExpand("invoice_settings.default_payment_method"); + return customerOptions; + } + + private async Task GetCustomerAsync(string gatewayCustomerId, Stripe.CustomerGetOptions options = null) + { + if (string.IsNullOrWhiteSpace(gatewayCustomerId)) + { + return null; + } + + Stripe.Customer customer = null; + try + { + customer = await _stripeAdapter.CustomerGetAsync(gatewayCustomerId, options); + } + catch (Stripe.StripeException) { } + + return customer; + } + + private async Task> GetBillingTransactionsAsync(ISubscriber subscriber) + { + ICollection transactions = null; + if (subscriber is User) + { + transactions = await _transactionRepository.GetManyByUserIdAsync(subscriber.Id); + } + else if (subscriber is Organization) + { + transactions = await _transactionRepository.GetManyByOrganizationIdAsync(subscriber.Id); + } + + return transactions?.OrderByDescending(i => i.CreationDate) + .Select(t => new BillingInfo.BillingTransaction(t)); + + } + + private async Task> GetBillingInvoicesAsync(Stripe.Customer customer) + { + if (customer == null) + { + return null; + } + + var invoices = await _stripeAdapter.InvoiceListAsync(new Stripe.InvoiceListOptions + { + Customer = customer.Id, + Limit = 50 + }); + + return invoices.Data.Where(i => i.Status != "void" && i.Status != "draft") + .OrderByDescending(i => i.Created).Select(i => new BillingInfo.BillingInvoice(i)); + + } } } diff --git a/src/Core/Services/Implementations/StripeSyncService.cs b/src/Core/Services/Implementations/StripeSyncService.cs index b2700e65d1..f042eac5c0 100644 --- a/src/Core/Services/Implementations/StripeSyncService.cs +++ b/src/Core/Services/Implementations/StripeSyncService.cs @@ -1,31 +1,32 @@ using Bit.Core.Exceptions; -namespace Bit.Core.Services; - -public class StripeSyncService : IStripeSyncService +namespace Bit.Core.Services { - private readonly IStripeAdapter _stripeAdapter; - - public StripeSyncService(IStripeAdapter stripeAdapter) + public class StripeSyncService : IStripeSyncService { - _stripeAdapter = stripeAdapter; - } + private readonly IStripeAdapter _stripeAdapter; - public async Task UpdateCustomerEmailAddress(string gatewayCustomerId, string emailAddress) - { - if (string.IsNullOrWhiteSpace(gatewayCustomerId)) + public StripeSyncService(IStripeAdapter stripeAdapter) { - throw new InvalidGatewayCustomerIdException(); + _stripeAdapter = stripeAdapter; } - if (string.IsNullOrWhiteSpace(emailAddress)) + public async Task UpdateCustomerEmailAddress(string gatewayCustomerId, string emailAddress) { - throw new InvalidEmailException(); + if (string.IsNullOrWhiteSpace(gatewayCustomerId)) + { + throw new InvalidGatewayCustomerIdException(); + } + + if (string.IsNullOrWhiteSpace(emailAddress)) + { + throw new InvalidEmailException(); + } + + var customer = await _stripeAdapter.CustomerGetAsync(gatewayCustomerId); + + await _stripeAdapter.CustomerUpdateAsync(customer.Id, + new Stripe.CustomerUpdateOptions { Email = emailAddress }); } - - var customer = await _stripeAdapter.CustomerGetAsync(gatewayCustomerId); - - await _stripeAdapter.CustomerUpdateAsync(customer.Id, - new Stripe.CustomerUpdateOptions { Email = emailAddress }); } } diff --git a/src/Core/Services/Implementations/UserService.cs b/src/Core/Services/Implementations/UserService.cs index 46509ceda4..d54ea7bb41 100644 --- a/src/Core/Services/Implementations/UserService.cs +++ b/src/Core/Services/Implementations/UserService.cs @@ -17,980 +17,1072 @@ using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; using File = System.IO.File; -namespace Bit.Core.Services; - -public class UserService : UserManager, IUserService, IDisposable +namespace Bit.Core.Services { - private const string PremiumPlanId = "premium-annually"; - private const string StoragePlanId = "storage-gb-annually"; - - private readonly IUserRepository _userRepository; - private readonly ICipherRepository _cipherRepository; - private readonly IOrganizationUserRepository _organizationUserRepository; - private readonly IOrganizationRepository _organizationRepository; - private readonly IMailService _mailService; - private readonly IPushNotificationService _pushService; - private readonly IdentityErrorDescriber _identityErrorDescriber; - private readonly IdentityOptions _identityOptions; - private readonly IPasswordHasher _passwordHasher; - private readonly IEnumerable> _passwordValidators; - private readonly ILicensingService _licenseService; - private readonly IEventService _eventService; - private readonly IApplicationCacheService _applicationCacheService; - private readonly IPaymentService _paymentService; - private readonly IPolicyRepository _policyRepository; - private readonly IDataProtector _organizationServiceDataProtector; - private readonly IReferenceEventService _referenceEventService; - private readonly IFido2 _fido2; - private readonly ICurrentContext _currentContext; - private readonly IGlobalSettings _globalSettings; - private readonly IOrganizationService _organizationService; - private readonly IProviderUserRepository _providerUserRepository; - private readonly IDeviceRepository _deviceRepository; - private readonly IStripeSyncService _stripeSyncService; - - public UserService( - IUserRepository userRepository, - ICipherRepository cipherRepository, - IOrganizationUserRepository organizationUserRepository, - IOrganizationRepository organizationRepository, - IMailService mailService, - IPushNotificationService pushService, - IUserStore store, - IOptions optionsAccessor, - IPasswordHasher passwordHasher, - IEnumerable> userValidators, - IEnumerable> passwordValidators, - ILookupNormalizer keyNormalizer, - IdentityErrorDescriber errors, - IServiceProvider services, - ILogger> logger, - ILicensingService licenseService, - IEventService eventService, - IApplicationCacheService applicationCacheService, - IDataProtectionProvider dataProtectionProvider, - IPaymentService paymentService, - IPolicyRepository policyRepository, - IReferenceEventService referenceEventService, - IFido2 fido2, - ICurrentContext currentContext, - IGlobalSettings globalSettings, - IOrganizationService organizationService, - IProviderUserRepository providerUserRepository, - IDeviceRepository deviceRepository, - IStripeSyncService stripeSyncService) - : base( - store, - optionsAccessor, - passwordHasher, - userValidators, - passwordValidators, - keyNormalizer, - errors, - services, - logger) + public class UserService : UserManager, IUserService, IDisposable { - _userRepository = userRepository; - _cipherRepository = cipherRepository; - _organizationUserRepository = organizationUserRepository; - _organizationRepository = organizationRepository; - _mailService = mailService; - _pushService = pushService; - _identityOptions = optionsAccessor?.Value ?? new IdentityOptions(); - _identityErrorDescriber = errors; - _passwordHasher = passwordHasher; - _passwordValidators = passwordValidators; - _licenseService = licenseService; - _eventService = eventService; - _applicationCacheService = applicationCacheService; - _paymentService = paymentService; - _policyRepository = policyRepository; - _organizationServiceDataProtector = dataProtectionProvider.CreateProtector( - "OrganizationServiceDataProtector"); - _referenceEventService = referenceEventService; - _fido2 = fido2; - _currentContext = currentContext; - _globalSettings = globalSettings; - _organizationService = organizationService; - _providerUserRepository = providerUserRepository; - _deviceRepository = deviceRepository; - _stripeSyncService = stripeSyncService; - } + private const string PremiumPlanId = "premium-annually"; + private const string StoragePlanId = "storage-gb-annually"; - public Guid? GetProperUserId(ClaimsPrincipal principal) - { - if (!Guid.TryParse(GetUserId(principal), out var userIdGuid)) + private readonly IUserRepository _userRepository; + private readonly ICipherRepository _cipherRepository; + private readonly IOrganizationUserRepository _organizationUserRepository; + private readonly IOrganizationRepository _organizationRepository; + private readonly IMailService _mailService; + private readonly IPushNotificationService _pushService; + private readonly IdentityErrorDescriber _identityErrorDescriber; + private readonly IdentityOptions _identityOptions; + private readonly IPasswordHasher _passwordHasher; + private readonly IEnumerable> _passwordValidators; + private readonly ILicensingService _licenseService; + private readonly IEventService _eventService; + private readonly IApplicationCacheService _applicationCacheService; + private readonly IPaymentService _paymentService; + private readonly IPolicyRepository _policyRepository; + private readonly IDataProtector _organizationServiceDataProtector; + private readonly IReferenceEventService _referenceEventService; + private readonly IFido2 _fido2; + private readonly ICurrentContext _currentContext; + private readonly IGlobalSettings _globalSettings; + private readonly IOrganizationService _organizationService; + private readonly IProviderUserRepository _providerUserRepository; + private readonly IDeviceRepository _deviceRepository; + private readonly IStripeSyncService _stripeSyncService; + + public UserService( + IUserRepository userRepository, + ICipherRepository cipherRepository, + IOrganizationUserRepository organizationUserRepository, + IOrganizationRepository organizationRepository, + IMailService mailService, + IPushNotificationService pushService, + IUserStore store, + IOptions optionsAccessor, + IPasswordHasher passwordHasher, + IEnumerable> userValidators, + IEnumerable> passwordValidators, + ILookupNormalizer keyNormalizer, + IdentityErrorDescriber errors, + IServiceProvider services, + ILogger> logger, + ILicensingService licenseService, + IEventService eventService, + IApplicationCacheService applicationCacheService, + IDataProtectionProvider dataProtectionProvider, + IPaymentService paymentService, + IPolicyRepository policyRepository, + IReferenceEventService referenceEventService, + IFido2 fido2, + ICurrentContext currentContext, + IGlobalSettings globalSettings, + IOrganizationService organizationService, + IProviderUserRepository providerUserRepository, + IDeviceRepository deviceRepository, + IStripeSyncService stripeSyncService) + : base( + store, + optionsAccessor, + passwordHasher, + userValidators, + passwordValidators, + keyNormalizer, + errors, + services, + logger) { - return null; + _userRepository = userRepository; + _cipherRepository = cipherRepository; + _organizationUserRepository = organizationUserRepository; + _organizationRepository = organizationRepository; + _mailService = mailService; + _pushService = pushService; + _identityOptions = optionsAccessor?.Value ?? new IdentityOptions(); + _identityErrorDescriber = errors; + _passwordHasher = passwordHasher; + _passwordValidators = passwordValidators; + _licenseService = licenseService; + _eventService = eventService; + _applicationCacheService = applicationCacheService; + _paymentService = paymentService; + _policyRepository = policyRepository; + _organizationServiceDataProtector = dataProtectionProvider.CreateProtector( + "OrganizationServiceDataProtector"); + _referenceEventService = referenceEventService; + _fido2 = fido2; + _currentContext = currentContext; + _globalSettings = globalSettings; + _organizationService = organizationService; + _providerUserRepository = providerUserRepository; + _deviceRepository = deviceRepository; + _stripeSyncService = stripeSyncService; } - return userIdGuid; - } - - public async Task GetUserByIdAsync(string userId) - { - if (_currentContext?.User != null && - string.Equals(_currentContext.User.Id.ToString(), userId, StringComparison.InvariantCultureIgnoreCase)) + public Guid? GetProperUserId(ClaimsPrincipal principal) { - return _currentContext.User; - } - - if (!Guid.TryParse(userId, out var userIdGuid)) - { - return null; - } - - _currentContext.User = await _userRepository.GetByIdAsync(userIdGuid); - return _currentContext.User; - } - - public async Task GetUserByIdAsync(Guid userId) - { - if (_currentContext?.User != null && _currentContext.User.Id == userId) - { - return _currentContext.User; - } - - _currentContext.User = await _userRepository.GetByIdAsync(userId); - return _currentContext.User; - } - - public async Task GetUserByPrincipalAsync(ClaimsPrincipal principal) - { - var userId = GetProperUserId(principal); - if (!userId.HasValue) - { - return null; - } - - return await GetUserByIdAsync(userId.Value); - } - - public async Task GetAccountRevisionDateByIdAsync(Guid userId) - { - return await _userRepository.GetAccountRevisionDateAsync(userId); - } - - public async Task SaveUserAsync(User user, bool push = false) - { - if (user.Id == default(Guid)) - { - throw new ApplicationException("Use register method to create a new user."); - } - - user.RevisionDate = user.AccountRevisionDate = DateTime.UtcNow; - await _userRepository.ReplaceAsync(user); - - if (push) - { - // push - await _pushService.PushSyncSettingsAsync(user.Id); - } - } - - public override async Task DeleteAsync(User user) - { - // Check if user is the only owner of any organizations. - var onlyOwnerCount = await _organizationUserRepository.GetCountByOnlyOwnerAsync(user.Id); - if (onlyOwnerCount > 0) - { - var deletedOrg = false; - var orgs = await _organizationUserRepository.GetManyDetailsByUserAsync(user.Id, - OrganizationUserStatusType.Confirmed); - if (orgs.Count == 1) + if (!Guid.TryParse(GetUserId(principal), out var userIdGuid)) { - var org = await _organizationRepository.GetByIdAsync(orgs.First().OrganizationId); - if (org != null && (!org.Enabled || string.IsNullOrWhiteSpace(org.GatewaySubscriptionId))) + return null; + } + + return userIdGuid; + } + + public async Task GetUserByIdAsync(string userId) + { + if (_currentContext?.User != null && + string.Equals(_currentContext.User.Id.ToString(), userId, StringComparison.InvariantCultureIgnoreCase)) + { + return _currentContext.User; + } + + if (!Guid.TryParse(userId, out var userIdGuid)) + { + return null; + } + + _currentContext.User = await _userRepository.GetByIdAsync(userIdGuid); + return _currentContext.User; + } + + public async Task GetUserByIdAsync(Guid userId) + { + if (_currentContext?.User != null && _currentContext.User.Id == userId) + { + return _currentContext.User; + } + + _currentContext.User = await _userRepository.GetByIdAsync(userId); + return _currentContext.User; + } + + public async Task GetUserByPrincipalAsync(ClaimsPrincipal principal) + { + var userId = GetProperUserId(principal); + if (!userId.HasValue) + { + return null; + } + + return await GetUserByIdAsync(userId.Value); + } + + public async Task GetAccountRevisionDateByIdAsync(Guid userId) + { + return await _userRepository.GetAccountRevisionDateAsync(userId); + } + + public async Task SaveUserAsync(User user, bool push = false) + { + if (user.Id == default(Guid)) + { + throw new ApplicationException("Use register method to create a new user."); + } + + user.RevisionDate = user.AccountRevisionDate = DateTime.UtcNow; + await _userRepository.ReplaceAsync(user); + + if (push) + { + // push + await _pushService.PushSyncSettingsAsync(user.Id); + } + } + + public override async Task DeleteAsync(User user) + { + // Check if user is the only owner of any organizations. + var onlyOwnerCount = await _organizationUserRepository.GetCountByOnlyOwnerAsync(user.Id); + if (onlyOwnerCount > 0) + { + var deletedOrg = false; + var orgs = await _organizationUserRepository.GetManyDetailsByUserAsync(user.Id, + OrganizationUserStatusType.Confirmed); + if (orgs.Count == 1) { - var orgCount = await _organizationUserRepository.GetCountByOrganizationIdAsync(org.Id); - if (orgCount <= 1) + var org = await _organizationRepository.GetByIdAsync(orgs.First().OrganizationId); + if (org != null && (!org.Enabled || string.IsNullOrWhiteSpace(org.GatewaySubscriptionId))) { - await _organizationRepository.DeleteAsync(org); - deletedOrg = true; + var orgCount = await _organizationUserRepository.GetCountByOrganizationIdAsync(org.Id); + if (orgCount <= 1) + { + await _organizationRepository.DeleteAsync(org); + deletedOrg = true; + } + } + } + + if (!deletedOrg) + { + return IdentityResult.Failed(new IdentityError + { + Description = "Cannot delete this user because it is the sole owner of at least one organization. Please delete these organizations or upgrade another user.", + }); + } + } + + var onlyOwnerProviderCount = await _providerUserRepository.GetCountByOnlyOwnerAsync(user.Id); + if (onlyOwnerProviderCount > 0) + { + return IdentityResult.Failed(new IdentityError + { + Description = "Cannot delete this user because it is the sole owner of at least one provider. Please delete these providers or upgrade another user.", + }); + } + + if (!string.IsNullOrWhiteSpace(user.GatewaySubscriptionId)) + { + try + { + await CancelPremiumAsync(user, null, true); + } + catch (GatewayException) { } + } + + await _userRepository.DeleteAsync(user); + await _referenceEventService.RaiseEventAsync( + new ReferenceEvent(ReferenceEventType.DeleteAccount, user)); + await _pushService.PushLogOutAsync(user.Id); + return IdentityResult.Success; + } + + public async Task DeleteAsync(User user, string token) + { + if (!(await VerifyUserTokenAsync(user, TokenOptions.DefaultProvider, "DeleteAccount", token))) + { + return IdentityResult.Failed(ErrorDescriber.InvalidToken()); + } + + return await DeleteAsync(user); + } + + public async Task SendDeleteConfirmationAsync(string email) + { + var user = await _userRepository.GetByEmailAsync(email); + if (user == null) + { + // No user exists. + return; + } + + var token = await base.GenerateUserTokenAsync(user, TokenOptions.DefaultProvider, "DeleteAccount"); + await _mailService.SendVerifyDeleteEmailAsync(user.Email, user.Id, token); + } + + public async Task RegisterUserAsync(User user, string masterPassword, + string token, Guid? orgUserId) + { + var tokenValid = false; + if (_globalSettings.DisableUserRegistration && !string.IsNullOrWhiteSpace(token) && orgUserId.HasValue) + { + tokenValid = CoreHelpers.UserInviteTokenIsValid(_organizationServiceDataProtector, token, + user.Email, orgUserId.Value, _globalSettings); + } + + if (_globalSettings.DisableUserRegistration && !tokenValid) + { + throw new BadRequestException("Open registration has been disabled by the system administrator."); + } + + if (orgUserId.HasValue) + { + var orgUser = await _organizationUserRepository.GetByIdAsync(orgUserId.Value); + if (orgUser != null) + { + var twoFactorPolicy = await _policyRepository.GetByOrganizationIdTypeAsync(orgUser.OrganizationId, + PolicyType.TwoFactorAuthentication); + if (twoFactorPolicy != null && twoFactorPolicy.Enabled) + { + user.SetTwoFactorProviders(new Dictionary + { + + [TwoFactorProviderType.Email] = new TwoFactorProvider + { + MetaData = new Dictionary { ["Email"] = user.Email.ToLowerInvariant() }, + Enabled = true + } + }); + SetTwoFactorProvider(user, TwoFactorProviderType.Email); } } } - if (!deletedOrg) + user.ApiKey = CoreHelpers.SecureRandomString(30); + var result = await base.CreateAsync(user, masterPassword); + if (result == IdentityResult.Success) { - return IdentityResult.Failed(new IdentityError - { - Description = "Cannot delete this user because it is the sole owner of at least one organization. Please delete these organizations or upgrade another user.", - }); + await _mailService.SendWelcomeEmailAsync(user); + await _referenceEventService.RaiseEventAsync(new ReferenceEvent(ReferenceEventType.Signup, user)); + } + + return result; + } + + public async Task RegisterUserAsync(User user) + { + var result = await base.CreateAsync(user); + if (result == IdentityResult.Success) + { + await _mailService.SendWelcomeEmailAsync(user); + await _referenceEventService.RaiseEventAsync(new ReferenceEvent(ReferenceEventType.Signup, user)); + } + + return result; + } + + public async Task SendMasterPasswordHintAsync(string email) + { + var user = await _userRepository.GetByEmailAsync(email); + if (user == null) + { + // No user exists. Do we want to send an email telling them this in the future? + return; + } + + if (string.IsNullOrWhiteSpace(user.MasterPasswordHint)) + { + await _mailService.SendNoMasterPasswordHintEmailAsync(email); + return; + } + + await _mailService.SendMasterPasswordHintEmailAsync(email, user.MasterPasswordHint); + } + + public async Task SendTwoFactorEmailAsync(User user, bool isBecauseNewDeviceLogin = false) + { + var provider = user.GetTwoFactorProvider(TwoFactorProviderType.Email); + if (provider == null || provider.MetaData == null || !provider.MetaData.ContainsKey("Email")) + { + throw new ArgumentNullException("No email."); + } + + var email = ((string)provider.MetaData["Email"]).ToLowerInvariant(); + var token = await base.GenerateUserTokenAsync(user, TokenOptions.DefaultEmailProvider, + "2faEmail:" + email); + + if (isBecauseNewDeviceLogin) + { + await _mailService.SendNewDeviceLoginTwoFactorEmailAsync(email, token); + } + else + { + await _mailService.SendTwoFactorEmailAsync(email, token); } } - var onlyOwnerProviderCount = await _providerUserRepository.GetCountByOnlyOwnerAsync(user.Id); - if (onlyOwnerProviderCount > 0) + public async Task VerifyTwoFactorEmailAsync(User user, string token) { - return IdentityResult.Failed(new IdentityError + var provider = user.GetTwoFactorProvider(TwoFactorProviderType.Email); + if (provider == null || provider.MetaData == null || !provider.MetaData.ContainsKey("Email")) { - Description = "Cannot delete this user because it is the sole owner of at least one provider. Please delete these providers or upgrade another user.", - }); + throw new ArgumentNullException("No email."); + } + + var email = ((string)provider.MetaData["Email"]).ToLowerInvariant(); + return await base.VerifyUserTokenAsync(user, TokenOptions.DefaultEmailProvider, + "2faEmail:" + email, token); } - if (!string.IsNullOrWhiteSpace(user.GatewaySubscriptionId)) + public async Task StartWebAuthnRegistrationAsync(User user) { + var providers = user.GetTwoFactorProviders(); + if (providers == null) + { + providers = new Dictionary(); + } + var provider = user.GetTwoFactorProvider(TwoFactorProviderType.WebAuthn); + if (provider == null) + { + provider = new TwoFactorProvider + { + Enabled = false + }; + } + if (provider.MetaData == null) + { + provider.MetaData = new Dictionary(); + } + + var fidoUser = new Fido2User + { + DisplayName = user.Name, + Name = user.Email, + Id = user.Id.ToByteArray(), + }; + + var excludeCredentials = provider.MetaData + .Where(k => k.Key.StartsWith("Key")) + .Select(k => new TwoFactorProvider.WebAuthnData((dynamic)k.Value).Descriptor) + .ToList(); + + var authenticatorSelection = new AuthenticatorSelection + { + AuthenticatorAttachment = null, + RequireResidentKey = false, + UserVerification = UserVerificationRequirement.Discouraged + }; + var options = _fido2.RequestNewCredential(fidoUser, excludeCredentials, authenticatorSelection, AttestationConveyancePreference.None); + + provider.MetaData["pending"] = options.ToJson(); + providers[TwoFactorProviderType.WebAuthn] = provider; + user.SetTwoFactorProviders(providers); + await UpdateTwoFactorProviderAsync(user, TwoFactorProviderType.WebAuthn, false); + + return options; + } + + public async Task CompleteWebAuthRegistrationAsync(User user, int id, string name, AuthenticatorAttestationRawResponse attestationResponse) + { + var keyId = $"Key{id}"; + + var provider = user.GetTwoFactorProvider(TwoFactorProviderType.WebAuthn); + if (!provider?.MetaData?.ContainsKey("pending") ?? true) + { + return false; + } + + var options = CredentialCreateOptions.FromJson((string)provider.MetaData["pending"]); + + // Callback to ensure credential id is unique. Always return true since we don't care if another + // account uses the same 2fa key. + IsCredentialIdUniqueToUserAsyncDelegate callback = args => Task.FromResult(true); + + var success = await _fido2.MakeNewCredentialAsync(attestationResponse, options, callback); + + provider.MetaData.Remove("pending"); + provider.MetaData[keyId] = new TwoFactorProvider.WebAuthnData + { + Name = name, + Descriptor = new PublicKeyCredentialDescriptor(success.Result.CredentialId), + PublicKey = success.Result.PublicKey, + UserHandle = success.Result.User.Id, + SignatureCounter = success.Result.Counter, + CredType = success.Result.CredType, + RegDate = DateTime.Now, + AaGuid = success.Result.Aaguid + }; + + var providers = user.GetTwoFactorProviders(); + providers[TwoFactorProviderType.WebAuthn] = provider; + user.SetTwoFactorProviders(providers); + await UpdateTwoFactorProviderAsync(user, TwoFactorProviderType.WebAuthn); + + return true; + } + + public async Task DeleteWebAuthnKeyAsync(User user, int id) + { + var providers = user.GetTwoFactorProviders(); + if (providers == null) + { + return false; + } + + var keyName = $"Key{id}"; + var provider = user.GetTwoFactorProvider(TwoFactorProviderType.WebAuthn); + if (!provider?.MetaData?.ContainsKey(keyName) ?? true) + { + return false; + } + + if (provider.MetaData.Count < 2) + { + return false; + } + + provider.MetaData.Remove(keyName); + providers[TwoFactorProviderType.WebAuthn] = provider; + user.SetTwoFactorProviders(providers); + await UpdateTwoFactorProviderAsync(user, TwoFactorProviderType.WebAuthn); + return true; + } + + public async Task SendEmailVerificationAsync(User user) + { + if (user.EmailVerified) + { + throw new BadRequestException("Email already verified."); + } + + var token = await base.GenerateEmailConfirmationTokenAsync(user); + await _mailService.SendVerifyEmailEmailAsync(user.Email, user.Id, token); + } + + public async Task InitiateEmailChangeAsync(User user, string newEmail) + { + var existingUser = await _userRepository.GetByEmailAsync(newEmail); + if (existingUser != null) + { + await _mailService.SendChangeEmailAlreadyExistsEmailAsync(user.Email, newEmail); + return; + } + + var token = await base.GenerateChangeEmailTokenAsync(user, newEmail); + await _mailService.SendChangeEmailEmailAsync(newEmail, token); + } + + public async Task ChangeEmailAsync(User user, string masterPassword, string newEmail, + string newMasterPassword, string token, string key) + { + var verifyPasswordResult = _passwordHasher.VerifyHashedPassword(user, user.MasterPassword, masterPassword); + if (verifyPasswordResult == PasswordVerificationResult.Failed) + { + return IdentityResult.Failed(_identityErrorDescriber.PasswordMismatch()); + } + + if (!await base.VerifyUserTokenAsync(user, _identityOptions.Tokens.ChangeEmailTokenProvider, + GetChangeEmailTokenPurpose(newEmail), token)) + { + return IdentityResult.Failed(_identityErrorDescriber.InvalidToken()); + } + + var existingUser = await _userRepository.GetByEmailAsync(newEmail); + if (existingUser != null && existingUser.Id != user.Id) + { + return IdentityResult.Failed(_identityErrorDescriber.DuplicateEmail(newEmail)); + } + + var previousState = new + { + Key = user.Key, + MasterPassword = user.MasterPassword, + SecurityStamp = user.SecurityStamp, + Email = user.Email + }; + + var result = await UpdatePasswordHash(user, newMasterPassword); + if (!result.Succeeded) + { + return result; + } + + user.Key = key; + user.Email = newEmail; + user.EmailVerified = true; + user.RevisionDate = user.AccountRevisionDate = DateTime.UtcNow; + await _userRepository.ReplaceAsync(user); + + if (user.Gateway == GatewayType.Stripe) + { + + try + { + await _stripeSyncService.UpdateCustomerEmailAddress(user.GatewayCustomerId, + user.BillingEmailAddress()); + } + catch (Exception ex) + { + //if sync to strip fails, update email and securityStamp to previous + user.Key = previousState.Key; + user.Email = previousState.Email; + user.RevisionDate = user.AccountRevisionDate = DateTime.UtcNow; + user.MasterPassword = previousState.MasterPassword; + user.SecurityStamp = previousState.SecurityStamp; + + await _userRepository.ReplaceAsync(user); + return IdentityResult.Failed(new IdentityError + { + Description = ex.Message + }); + } + } + + await _pushService.PushLogOutAsync(user.Id); + + return IdentityResult.Success; + } + + public override Task ChangePasswordAsync(User user, string masterPassword, string newMasterPassword) + { + throw new NotImplementedException(); + } + + public async Task ChangePasswordAsync(User user, string masterPassword, string newMasterPassword, string passwordHint, + string key) + { + if (user == null) + { + throw new ArgumentNullException(nameof(user)); + } + + if (await CheckPasswordAsync(user, masterPassword)) + { + var result = await UpdatePasswordHash(user, newMasterPassword); + if (!result.Succeeded) + { + return result; + } + + user.RevisionDate = user.AccountRevisionDate = DateTime.UtcNow; + user.Key = key; + user.MasterPasswordHint = passwordHint; + + await _userRepository.ReplaceAsync(user); + await _eventService.LogUserEventAsync(user.Id, EventType.User_ChangedPassword); + await _pushService.PushLogOutAsync(user.Id); + + return IdentityResult.Success; + } + + Logger.LogWarning("Change password failed for user {userId}.", user.Id); + return IdentityResult.Failed(_identityErrorDescriber.PasswordMismatch()); + } + + public async Task SetPasswordAsync(User user, string masterPassword, string key, + string orgIdentifier = null) + { + if (user == null) + { + throw new ArgumentNullException(nameof(user)); + } + + if (!string.IsNullOrWhiteSpace(user.MasterPassword)) + { + Logger.LogWarning("Change password failed for user {userId} - already has password.", user.Id); + return IdentityResult.Failed(_identityErrorDescriber.UserAlreadyHasPassword()); + } + + var result = await UpdatePasswordHash(user, masterPassword, true, false); + if (!result.Succeeded) + { + return result; + } + + user.RevisionDate = user.AccountRevisionDate = DateTime.UtcNow; + user.Key = key; + + await _userRepository.ReplaceAsync(user); + await _eventService.LogUserEventAsync(user.Id, EventType.User_ChangedPassword); + + if (!string.IsNullOrWhiteSpace(orgIdentifier)) + { + await _organizationService.AcceptUserAsync(orgIdentifier, user, this); + } + + return IdentityResult.Success; + } + + public async Task SetKeyConnectorKeyAsync(User user, string key, string orgIdentifier) + { + var identityResult = CheckCanUseKeyConnector(user); + if (identityResult != null) + { + return identityResult; + } + + user.RevisionDate = user.AccountRevisionDate = DateTime.UtcNow; + user.Key = key; + user.UsesKeyConnector = true; + + await _userRepository.ReplaceAsync(user); + await _eventService.LogUserEventAsync(user.Id, EventType.User_MigratedKeyToKeyConnector); + + await _organizationService.AcceptUserAsync(orgIdentifier, user, this); + + return IdentityResult.Success; + } + + public async Task ConvertToKeyConnectorAsync(User user) + { + var identityResult = CheckCanUseKeyConnector(user); + if (identityResult != null) + { + return identityResult; + } + + user.RevisionDate = user.AccountRevisionDate = DateTime.UtcNow; + user.MasterPassword = null; + user.UsesKeyConnector = true; + + await _userRepository.ReplaceAsync(user); + await _eventService.LogUserEventAsync(user.Id, EventType.User_MigratedKeyToKeyConnector); + + return IdentityResult.Success; + } + + private IdentityResult CheckCanUseKeyConnector(User user) + { + if (user == null) + { + throw new ArgumentNullException(nameof(user)); + } + + if (user.UsesKeyConnector) + { + Logger.LogWarning("Already uses Key Connector."); + return IdentityResult.Failed(_identityErrorDescriber.UserAlreadyHasPassword()); + } + + if (_currentContext.Organizations.Any(u => + u.Type is OrganizationUserType.Owner or OrganizationUserType.Admin)) + { + throw new BadRequestException("Cannot use Key Connector when admin or owner of an organization."); + } + + return null; + } + + public async Task AdminResetPasswordAsync(OrganizationUserType callingUserType, Guid orgId, Guid id, string newMasterPassword, string key) + { + // Org must be able to use reset password + var org = await _organizationRepository.GetByIdAsync(orgId); + if (org == null || !org.UseResetPassword) + { + throw new BadRequestException("Organization does not allow password reset."); + } + + // Enterprise policy must be enabled + var resetPasswordPolicy = + await _policyRepository.GetByOrganizationIdTypeAsync(orgId, PolicyType.ResetPassword); + if (resetPasswordPolicy == null || !resetPasswordPolicy.Enabled) + { + throw new BadRequestException("Organization does not have the password reset policy enabled."); + } + + // Org User must be confirmed and have a ResetPasswordKey + var orgUser = await _organizationUserRepository.GetByIdAsync(id); + if (orgUser == null || orgUser.Status != OrganizationUserStatusType.Confirmed || + orgUser.OrganizationId != orgId || string.IsNullOrEmpty(orgUser.ResetPasswordKey) || + !orgUser.UserId.HasValue) + { + throw new BadRequestException("Organization User not valid"); + } + + // Calling User must be of higher/equal user type to reset user's password + var canAdjustPassword = false; + switch (callingUserType) + { + case OrganizationUserType.Owner: + canAdjustPassword = true; + break; + case OrganizationUserType.Admin: + canAdjustPassword = orgUser.Type != OrganizationUserType.Owner; + break; + case OrganizationUserType.Custom: + canAdjustPassword = orgUser.Type != OrganizationUserType.Owner && + orgUser.Type != OrganizationUserType.Admin; + break; + } + + if (!canAdjustPassword) + { + throw new BadRequestException("Calling user does not have permission to reset this user's master password"); + } + + var user = await GetUserByIdAsync(orgUser.UserId.Value); + if (user == null) + { + throw new NotFoundException(); + } + + if (user.UsesKeyConnector) + { + throw new BadRequestException("Cannot reset password of a user with Key Connector."); + } + + var result = await UpdatePasswordHash(user, newMasterPassword); + if (!result.Succeeded) + { + return result; + } + + user.RevisionDate = user.AccountRevisionDate = DateTime.UtcNow; + user.Key = key; + user.ForcePasswordReset = true; + + await _userRepository.ReplaceAsync(user); + await _mailService.SendAdminResetPasswordEmailAsync(user.Email, user.Name, org.Name); + await _eventService.LogOrganizationUserEventAsync(orgUser, EventType.OrganizationUser_AdminResetPassword); + await _pushService.PushLogOutAsync(user.Id); + + return IdentityResult.Success; + } + + public async Task UpdateTempPasswordAsync(User user, string newMasterPassword, string key, string hint) + { + if (!user.ForcePasswordReset) + { + throw new BadRequestException("User does not have a temporary password to update."); + } + + var result = await UpdatePasswordHash(user, newMasterPassword); + if (!result.Succeeded) + { + return result; + } + + user.RevisionDate = user.AccountRevisionDate = DateTime.UtcNow; + user.ForcePasswordReset = false; + user.Key = key; + user.MasterPasswordHint = hint; + + await _userRepository.ReplaceAsync(user); + await _mailService.SendUpdatedTempPasswordEmailAsync(user.Email, user.Name); + await _eventService.LogUserEventAsync(user.Id, EventType.User_UpdatedTempPassword); + await _pushService.PushLogOutAsync(user.Id); + + return IdentityResult.Success; + } + + public async Task ChangeKdfAsync(User user, string masterPassword, string newMasterPassword, + string key, KdfType kdf, int kdfIterations) + { + if (user == null) + { + throw new ArgumentNullException(nameof(user)); + } + + if (await CheckPasswordAsync(user, masterPassword)) + { + var result = await UpdatePasswordHash(user, newMasterPassword); + if (!result.Succeeded) + { + return result; + } + + user.RevisionDate = user.AccountRevisionDate = DateTime.UtcNow; + user.Key = key; + user.Kdf = kdf; + user.KdfIterations = kdfIterations; + await _userRepository.ReplaceAsync(user); + await _pushService.PushLogOutAsync(user.Id); + return IdentityResult.Success; + } + + Logger.LogWarning("Change KDF failed for user {userId}.", user.Id); + return IdentityResult.Failed(_identityErrorDescriber.PasswordMismatch()); + } + + public async Task UpdateKeyAsync(User user, string masterPassword, string key, string privateKey, + IEnumerable ciphers, IEnumerable folders, IEnumerable sends) + { + if (user == null) + { + throw new ArgumentNullException(nameof(user)); + } + + if (await CheckPasswordAsync(user, masterPassword)) + { + user.RevisionDate = user.AccountRevisionDate = DateTime.UtcNow; + user.SecurityStamp = Guid.NewGuid().ToString(); + user.Key = key; + user.PrivateKey = privateKey; + if (ciphers.Any() || folders.Any() || sends.Any()) + { + await _cipherRepository.UpdateUserKeysAndCiphersAsync(user, ciphers, folders, sends); + } + else + { + await _userRepository.ReplaceAsync(user); + } + + await _pushService.PushLogOutAsync(user.Id); + return IdentityResult.Success; + } + + Logger.LogWarning("Update key failed for user {userId}.", user.Id); + return IdentityResult.Failed(_identityErrorDescriber.PasswordMismatch()); + } + + public async Task RefreshSecurityStampAsync(User user, string secret) + { + if (user == null) + { + throw new ArgumentNullException(nameof(user)); + } + + if (await VerifySecretAsync(user, secret)) + { + var result = await base.UpdateSecurityStampAsync(user); + if (!result.Succeeded) + { + return result; + } + + await SaveUserAsync(user); + await _pushService.PushLogOutAsync(user.Id); + return IdentityResult.Success; + } + + Logger.LogWarning("Refresh security stamp failed for user {userId}.", user.Id); + return IdentityResult.Failed(_identityErrorDescriber.PasswordMismatch()); + } + + public async Task UpdateTwoFactorProviderAsync(User user, TwoFactorProviderType type, bool setEnabled = true, bool logEvent = true) + { + SetTwoFactorProvider(user, type, setEnabled); + await SaveUserAsync(user); + if (logEvent) + { + await _eventService.LogUserEventAsync(user.Id, EventType.User_Updated2fa); + } + } + + public async Task DisableTwoFactorProviderAsync(User user, TwoFactorProviderType type, + IOrganizationService organizationService) + { + var providers = user.GetTwoFactorProviders(); + if (!providers?.ContainsKey(type) ?? true) + { + return; + } + + providers.Remove(type); + user.SetTwoFactorProviders(providers); + await SaveUserAsync(user); + await _eventService.LogUserEventAsync(user.Id, EventType.User_Disabled2fa); + + if (!await TwoFactorIsEnabledAsync(user)) + { + await CheckPoliciesOnTwoFactorRemovalAsync(user, organizationService); + } + } + + public async Task RecoverTwoFactorAsync(string email, string secret, string recoveryCode, + IOrganizationService organizationService) + { + var user = await _userRepository.GetByEmailAsync(email); + if (user == null) + { + // No user exists. Do we want to send an email telling them this in the future? + return false; + } + + if (!await VerifySecretAsync(user, secret)) + { + return false; + } + + if (!CoreHelpers.FixedTimeEquals(user.TwoFactorRecoveryCode, recoveryCode)) + { + return false; + } + + user.TwoFactorProviders = null; + user.TwoFactorRecoveryCode = CoreHelpers.SecureRandomString(32, upper: false, special: false); + await SaveUserAsync(user); + await _mailService.SendRecoverTwoFactorEmail(user.Email, DateTime.UtcNow, _currentContext.IpAddress); + await _eventService.LogUserEventAsync(user.Id, EventType.User_Recovered2fa); + await CheckPoliciesOnTwoFactorRemovalAsync(user, organizationService); + + return true; + } + + public async Task> SignUpPremiumAsync(User user, string paymentToken, + PaymentMethodType paymentMethodType, short additionalStorageGb, UserLicense license, + TaxInfo taxInfo) + { + if (user.Premium) + { + throw new BadRequestException("Already a premium user."); + } + + if (additionalStorageGb < 0) + { + throw new BadRequestException("You can't subtract storage!"); + } + + if ((paymentMethodType == PaymentMethodType.GoogleInApp || + paymentMethodType == PaymentMethodType.AppleInApp) && additionalStorageGb > 0) + { + throw new BadRequestException("You cannot add storage with this payment method."); + } + + string paymentIntentClientSecret = null; + IPaymentService paymentService = null; + if (_globalSettings.SelfHosted) + { + if (license == null || !_licenseService.VerifyLicense(license)) + { + throw new BadRequestException("Invalid license."); + } + + if (!license.CanUse(user)) + { + throw new BadRequestException("This license is not valid for this user."); + } + + var dir = $"{_globalSettings.LicenseDirectory}/user"; + Directory.CreateDirectory(dir); + using var fs = File.OpenWrite(Path.Combine(dir, $"{user.Id}.json")); + await JsonSerializer.SerializeAsync(fs, license, JsonHelpers.Indented); + } + else + { + paymentIntentClientSecret = await _paymentService.PurchasePremiumAsync(user, paymentMethodType, + paymentToken, additionalStorageGb, taxInfo); + } + + user.Premium = true; + user.RevisionDate = DateTime.UtcNow; + + if (_globalSettings.SelfHosted) + { + user.MaxStorageGb = 10240; // 10 TB + user.LicenseKey = license.LicenseKey; + user.PremiumExpirationDate = license.Expires; + } + else + { + user.MaxStorageGb = (short)(1 + additionalStorageGb); + user.LicenseKey = CoreHelpers.SecureRandomString(20); + } + try { - await CancelPremiumAsync(user, null, true); - } - catch (GatewayException) { } - } - - await _userRepository.DeleteAsync(user); - await _referenceEventService.RaiseEventAsync( - new ReferenceEvent(ReferenceEventType.DeleteAccount, user)); - await _pushService.PushLogOutAsync(user.Id); - return IdentityResult.Success; - } - - public async Task DeleteAsync(User user, string token) - { - if (!(await VerifyUserTokenAsync(user, TokenOptions.DefaultProvider, "DeleteAccount", token))) - { - return IdentityResult.Failed(ErrorDescriber.InvalidToken()); - } - - return await DeleteAsync(user); - } - - public async Task SendDeleteConfirmationAsync(string email) - { - var user = await _userRepository.GetByEmailAsync(email); - if (user == null) - { - // No user exists. - return; - } - - var token = await base.GenerateUserTokenAsync(user, TokenOptions.DefaultProvider, "DeleteAccount"); - await _mailService.SendVerifyDeleteEmailAsync(user.Email, user.Id, token); - } - - public async Task RegisterUserAsync(User user, string masterPassword, - string token, Guid? orgUserId) - { - var tokenValid = false; - if (_globalSettings.DisableUserRegistration && !string.IsNullOrWhiteSpace(token) && orgUserId.HasValue) - { - tokenValid = CoreHelpers.UserInviteTokenIsValid(_organizationServiceDataProtector, token, - user.Email, orgUserId.Value, _globalSettings); - } - - if (_globalSettings.DisableUserRegistration && !tokenValid) - { - throw new BadRequestException("Open registration has been disabled by the system administrator."); - } - - if (orgUserId.HasValue) - { - var orgUser = await _organizationUserRepository.GetByIdAsync(orgUserId.Value); - if (orgUser != null) - { - var twoFactorPolicy = await _policyRepository.GetByOrganizationIdTypeAsync(orgUser.OrganizationId, - PolicyType.TwoFactorAuthentication); - if (twoFactorPolicy != null && twoFactorPolicy.Enabled) - { - user.SetTwoFactorProviders(new Dictionary + await SaveUserAsync(user); + await _pushService.PushSyncVaultAsync(user.Id); + await _referenceEventService.RaiseEventAsync( + new ReferenceEvent(ReferenceEventType.UpgradePlan, user) { - - [TwoFactorProviderType.Email] = new TwoFactorProvider - { - MetaData = new Dictionary { ["Email"] = user.Email.ToLowerInvariant() }, - Enabled = true - } + Storage = user.MaxStorageGb, + PlanName = PremiumPlanId, }); - SetTwoFactorProvider(user, TwoFactorProviderType.Email); + } + catch when (!_globalSettings.SelfHosted) + { + await paymentService.CancelAndRecoverChargesAsync(user); + throw; + } + return new Tuple(string.IsNullOrWhiteSpace(paymentIntentClientSecret), + paymentIntentClientSecret); + } + + public async Task IapCheckAsync(User user, PaymentMethodType paymentMethodType) + { + if (paymentMethodType != PaymentMethodType.AppleInApp) + { + throw new BadRequestException("Payment method not supported for in-app purchases."); + } + + if (user.Premium) + { + throw new BadRequestException("Already a premium user."); + } + + if (!string.IsNullOrWhiteSpace(user.GatewayCustomerId)) + { + var customerService = new Stripe.CustomerService(); + var customer = await customerService.GetAsync(user.GatewayCustomerId); + if (customer != null && customer.Balance != 0) + { + throw new BadRequestException("Customer balance cannot exist when using in-app purchases."); } } } - user.ApiKey = CoreHelpers.SecureRandomString(30); - var result = await base.CreateAsync(user, masterPassword); - if (result == IdentityResult.Success) + public async Task UpdateLicenseAsync(User user, UserLicense license) { - await _mailService.SendWelcomeEmailAsync(user); - await _referenceEventService.RaiseEventAsync(new ReferenceEvent(ReferenceEventType.Signup, user)); - } - - return result; - } - - public async Task RegisterUserAsync(User user) - { - var result = await base.CreateAsync(user); - if (result == IdentityResult.Success) - { - await _mailService.SendWelcomeEmailAsync(user); - await _referenceEventService.RaiseEventAsync(new ReferenceEvent(ReferenceEventType.Signup, user)); - } - - return result; - } - - public async Task SendMasterPasswordHintAsync(string email) - { - var user = await _userRepository.GetByEmailAsync(email); - if (user == null) - { - // No user exists. Do we want to send an email telling them this in the future? - return; - } - - if (string.IsNullOrWhiteSpace(user.MasterPasswordHint)) - { - await _mailService.SendNoMasterPasswordHintEmailAsync(email); - return; - } - - await _mailService.SendMasterPasswordHintEmailAsync(email, user.MasterPasswordHint); - } - - public async Task SendTwoFactorEmailAsync(User user, bool isBecauseNewDeviceLogin = false) - { - var provider = user.GetTwoFactorProvider(TwoFactorProviderType.Email); - if (provider == null || provider.MetaData == null || !provider.MetaData.ContainsKey("Email")) - { - throw new ArgumentNullException("No email."); - } - - var email = ((string)provider.MetaData["Email"]).ToLowerInvariant(); - var token = await base.GenerateUserTokenAsync(user, TokenOptions.DefaultEmailProvider, - "2faEmail:" + email); - - if (isBecauseNewDeviceLogin) - { - await _mailService.SendNewDeviceLoginTwoFactorEmailAsync(email, token); - } - else - { - await _mailService.SendTwoFactorEmailAsync(email, token); - } - } - - public async Task VerifyTwoFactorEmailAsync(User user, string token) - { - var provider = user.GetTwoFactorProvider(TwoFactorProviderType.Email); - if (provider == null || provider.MetaData == null || !provider.MetaData.ContainsKey("Email")) - { - throw new ArgumentNullException("No email."); - } - - var email = ((string)provider.MetaData["Email"]).ToLowerInvariant(); - return await base.VerifyUserTokenAsync(user, TokenOptions.DefaultEmailProvider, - "2faEmail:" + email, token); - } - - public async Task StartWebAuthnRegistrationAsync(User user) - { - var providers = user.GetTwoFactorProviders(); - if (providers == null) - { - providers = new Dictionary(); - } - var provider = user.GetTwoFactorProvider(TwoFactorProviderType.WebAuthn); - if (provider == null) - { - provider = new TwoFactorProvider + if (!_globalSettings.SelfHosted) { - Enabled = false - }; - } - if (provider.MetaData == null) - { - provider.MetaData = new Dictionary(); - } - - var fidoUser = new Fido2User - { - DisplayName = user.Name, - Name = user.Email, - Id = user.Id.ToByteArray(), - }; - - var excludeCredentials = provider.MetaData - .Where(k => k.Key.StartsWith("Key")) - .Select(k => new TwoFactorProvider.WebAuthnData((dynamic)k.Value).Descriptor) - .ToList(); - - var authenticatorSelection = new AuthenticatorSelection - { - AuthenticatorAttachment = null, - RequireResidentKey = false, - UserVerification = UserVerificationRequirement.Discouraged - }; - var options = _fido2.RequestNewCredential(fidoUser, excludeCredentials, authenticatorSelection, AttestationConveyancePreference.None); - - provider.MetaData["pending"] = options.ToJson(); - providers[TwoFactorProviderType.WebAuthn] = provider; - user.SetTwoFactorProviders(providers); - await UpdateTwoFactorProviderAsync(user, TwoFactorProviderType.WebAuthn, false); - - return options; - } - - public async Task CompleteWebAuthRegistrationAsync(User user, int id, string name, AuthenticatorAttestationRawResponse attestationResponse) - { - var keyId = $"Key{id}"; - - var provider = user.GetTwoFactorProvider(TwoFactorProviderType.WebAuthn); - if (!provider?.MetaData?.ContainsKey("pending") ?? true) - { - return false; - } - - var options = CredentialCreateOptions.FromJson((string)provider.MetaData["pending"]); - - // Callback to ensure credential id is unique. Always return true since we don't care if another - // account uses the same 2fa key. - IsCredentialIdUniqueToUserAsyncDelegate callback = args => Task.FromResult(true); - - var success = await _fido2.MakeNewCredentialAsync(attestationResponse, options, callback); - - provider.MetaData.Remove("pending"); - provider.MetaData[keyId] = new TwoFactorProvider.WebAuthnData - { - Name = name, - Descriptor = new PublicKeyCredentialDescriptor(success.Result.CredentialId), - PublicKey = success.Result.PublicKey, - UserHandle = success.Result.User.Id, - SignatureCounter = success.Result.Counter, - CredType = success.Result.CredType, - RegDate = DateTime.Now, - AaGuid = success.Result.Aaguid - }; - - var providers = user.GetTwoFactorProviders(); - providers[TwoFactorProviderType.WebAuthn] = provider; - user.SetTwoFactorProviders(providers); - await UpdateTwoFactorProviderAsync(user, TwoFactorProviderType.WebAuthn); - - return true; - } - - public async Task DeleteWebAuthnKeyAsync(User user, int id) - { - var providers = user.GetTwoFactorProviders(); - if (providers == null) - { - return false; - } - - var keyName = $"Key{id}"; - var provider = user.GetTwoFactorProvider(TwoFactorProviderType.WebAuthn); - if (!provider?.MetaData?.ContainsKey(keyName) ?? true) - { - return false; - } - - if (provider.MetaData.Count < 2) - { - return false; - } - - provider.MetaData.Remove(keyName); - providers[TwoFactorProviderType.WebAuthn] = provider; - user.SetTwoFactorProviders(providers); - await UpdateTwoFactorProviderAsync(user, TwoFactorProviderType.WebAuthn); - return true; - } - - public async Task SendEmailVerificationAsync(User user) - { - if (user.EmailVerified) - { - throw new BadRequestException("Email already verified."); - } - - var token = await base.GenerateEmailConfirmationTokenAsync(user); - await _mailService.SendVerifyEmailEmailAsync(user.Email, user.Id, token); - } - - public async Task InitiateEmailChangeAsync(User user, string newEmail) - { - var existingUser = await _userRepository.GetByEmailAsync(newEmail); - if (existingUser != null) - { - await _mailService.SendChangeEmailAlreadyExistsEmailAsync(user.Email, newEmail); - return; - } - - var token = await base.GenerateChangeEmailTokenAsync(user, newEmail); - await _mailService.SendChangeEmailEmailAsync(newEmail, token); - } - - public async Task ChangeEmailAsync(User user, string masterPassword, string newEmail, - string newMasterPassword, string token, string key) - { - var verifyPasswordResult = _passwordHasher.VerifyHashedPassword(user, user.MasterPassword, masterPassword); - if (verifyPasswordResult == PasswordVerificationResult.Failed) - { - return IdentityResult.Failed(_identityErrorDescriber.PasswordMismatch()); - } - - if (!await base.VerifyUserTokenAsync(user, _identityOptions.Tokens.ChangeEmailTokenProvider, - GetChangeEmailTokenPurpose(newEmail), token)) - { - return IdentityResult.Failed(_identityErrorDescriber.InvalidToken()); - } - - var existingUser = await _userRepository.GetByEmailAsync(newEmail); - if (existingUser != null && existingUser.Id != user.Id) - { - return IdentityResult.Failed(_identityErrorDescriber.DuplicateEmail(newEmail)); - } - - var previousState = new - { - Key = user.Key, - MasterPassword = user.MasterPassword, - SecurityStamp = user.SecurityStamp, - Email = user.Email - }; - - var result = await UpdatePasswordHash(user, newMasterPassword); - if (!result.Succeeded) - { - return result; - } - - user.Key = key; - user.Email = newEmail; - user.EmailVerified = true; - user.RevisionDate = user.AccountRevisionDate = DateTime.UtcNow; - await _userRepository.ReplaceAsync(user); - - if (user.Gateway == GatewayType.Stripe) - { - - try - { - await _stripeSyncService.UpdateCustomerEmailAddress(user.GatewayCustomerId, - user.BillingEmailAddress()); - } - catch (Exception ex) - { - //if sync to strip fails, update email and securityStamp to previous - user.Key = previousState.Key; - user.Email = previousState.Email; - user.RevisionDate = user.AccountRevisionDate = DateTime.UtcNow; - user.MasterPassword = previousState.MasterPassword; - user.SecurityStamp = previousState.SecurityStamp; - - await _userRepository.ReplaceAsync(user); - return IdentityResult.Failed(new IdentityError - { - Description = ex.Message - }); - } - } - - await _pushService.PushLogOutAsync(user.Id); - - return IdentityResult.Success; - } - - public override Task ChangePasswordAsync(User user, string masterPassword, string newMasterPassword) - { - throw new NotImplementedException(); - } - - public async Task ChangePasswordAsync(User user, string masterPassword, string newMasterPassword, string passwordHint, - string key) - { - if (user == null) - { - throw new ArgumentNullException(nameof(user)); - } - - if (await CheckPasswordAsync(user, masterPassword)) - { - var result = await UpdatePasswordHash(user, newMasterPassword); - if (!result.Succeeded) - { - return result; + throw new InvalidOperationException("Licenses require self hosting."); } - user.RevisionDate = user.AccountRevisionDate = DateTime.UtcNow; - user.Key = key; - user.MasterPasswordHint = passwordHint; - - await _userRepository.ReplaceAsync(user); - await _eventService.LogUserEventAsync(user.Id, EventType.User_ChangedPassword); - await _pushService.PushLogOutAsync(user.Id); - - return IdentityResult.Success; - } - - Logger.LogWarning("Change password failed for user {userId}.", user.Id); - return IdentityResult.Failed(_identityErrorDescriber.PasswordMismatch()); - } - - public async Task SetPasswordAsync(User user, string masterPassword, string key, - string orgIdentifier = null) - { - if (user == null) - { - throw new ArgumentNullException(nameof(user)); - } - - if (!string.IsNullOrWhiteSpace(user.MasterPassword)) - { - Logger.LogWarning("Change password failed for user {userId} - already has password.", user.Id); - return IdentityResult.Failed(_identityErrorDescriber.UserAlreadyHasPassword()); - } - - var result = await UpdatePasswordHash(user, masterPassword, true, false); - if (!result.Succeeded) - { - return result; - } - - user.RevisionDate = user.AccountRevisionDate = DateTime.UtcNow; - user.Key = key; - - await _userRepository.ReplaceAsync(user); - await _eventService.LogUserEventAsync(user.Id, EventType.User_ChangedPassword); - - if (!string.IsNullOrWhiteSpace(orgIdentifier)) - { - await _organizationService.AcceptUserAsync(orgIdentifier, user, this); - } - - return IdentityResult.Success; - } - - public async Task SetKeyConnectorKeyAsync(User user, string key, string orgIdentifier) - { - var identityResult = CheckCanUseKeyConnector(user); - if (identityResult != null) - { - return identityResult; - } - - user.RevisionDate = user.AccountRevisionDate = DateTime.UtcNow; - user.Key = key; - user.UsesKeyConnector = true; - - await _userRepository.ReplaceAsync(user); - await _eventService.LogUserEventAsync(user.Id, EventType.User_MigratedKeyToKeyConnector); - - await _organizationService.AcceptUserAsync(orgIdentifier, user, this); - - return IdentityResult.Success; - } - - public async Task ConvertToKeyConnectorAsync(User user) - { - var identityResult = CheckCanUseKeyConnector(user); - if (identityResult != null) - { - return identityResult; - } - - user.RevisionDate = user.AccountRevisionDate = DateTime.UtcNow; - user.MasterPassword = null; - user.UsesKeyConnector = true; - - await _userRepository.ReplaceAsync(user); - await _eventService.LogUserEventAsync(user.Id, EventType.User_MigratedKeyToKeyConnector); - - return IdentityResult.Success; - } - - private IdentityResult CheckCanUseKeyConnector(User user) - { - if (user == null) - { - throw new ArgumentNullException(nameof(user)); - } - - if (user.UsesKeyConnector) - { - Logger.LogWarning("Already uses Key Connector."); - return IdentityResult.Failed(_identityErrorDescriber.UserAlreadyHasPassword()); - } - - if (_currentContext.Organizations.Any(u => - u.Type is OrganizationUserType.Owner or OrganizationUserType.Admin)) - { - throw new BadRequestException("Cannot use Key Connector when admin or owner of an organization."); - } - - return null; - } - - public async Task AdminResetPasswordAsync(OrganizationUserType callingUserType, Guid orgId, Guid id, string newMasterPassword, string key) - { - // Org must be able to use reset password - var org = await _organizationRepository.GetByIdAsync(orgId); - if (org == null || !org.UseResetPassword) - { - throw new BadRequestException("Organization does not allow password reset."); - } - - // Enterprise policy must be enabled - var resetPasswordPolicy = - await _policyRepository.GetByOrganizationIdTypeAsync(orgId, PolicyType.ResetPassword); - if (resetPasswordPolicy == null || !resetPasswordPolicy.Enabled) - { - throw new BadRequestException("Organization does not have the password reset policy enabled."); - } - - // Org User must be confirmed and have a ResetPasswordKey - var orgUser = await _organizationUserRepository.GetByIdAsync(id); - if (orgUser == null || orgUser.Status != OrganizationUserStatusType.Confirmed || - orgUser.OrganizationId != orgId || string.IsNullOrEmpty(orgUser.ResetPasswordKey) || - !orgUser.UserId.HasValue) - { - throw new BadRequestException("Organization User not valid"); - } - - // Calling User must be of higher/equal user type to reset user's password - var canAdjustPassword = false; - switch (callingUserType) - { - case OrganizationUserType.Owner: - canAdjustPassword = true; - break; - case OrganizationUserType.Admin: - canAdjustPassword = orgUser.Type != OrganizationUserType.Owner; - break; - case OrganizationUserType.Custom: - canAdjustPassword = orgUser.Type != OrganizationUserType.Owner && - orgUser.Type != OrganizationUserType.Admin; - break; - } - - if (!canAdjustPassword) - { - throw new BadRequestException("Calling user does not have permission to reset this user's master password"); - } - - var user = await GetUserByIdAsync(orgUser.UserId.Value); - if (user == null) - { - throw new NotFoundException(); - } - - if (user.UsesKeyConnector) - { - throw new BadRequestException("Cannot reset password of a user with Key Connector."); - } - - var result = await UpdatePasswordHash(user, newMasterPassword); - if (!result.Succeeded) - { - return result; - } - - user.RevisionDate = user.AccountRevisionDate = DateTime.UtcNow; - user.Key = key; - user.ForcePasswordReset = true; - - await _userRepository.ReplaceAsync(user); - await _mailService.SendAdminResetPasswordEmailAsync(user.Email, user.Name, org.Name); - await _eventService.LogOrganizationUserEventAsync(orgUser, EventType.OrganizationUser_AdminResetPassword); - await _pushService.PushLogOutAsync(user.Id); - - return IdentityResult.Success; - } - - public async Task UpdateTempPasswordAsync(User user, string newMasterPassword, string key, string hint) - { - if (!user.ForcePasswordReset) - { - throw new BadRequestException("User does not have a temporary password to update."); - } - - var result = await UpdatePasswordHash(user, newMasterPassword); - if (!result.Succeeded) - { - return result; - } - - user.RevisionDate = user.AccountRevisionDate = DateTime.UtcNow; - user.ForcePasswordReset = false; - user.Key = key; - user.MasterPasswordHint = hint; - - await _userRepository.ReplaceAsync(user); - await _mailService.SendUpdatedTempPasswordEmailAsync(user.Email, user.Name); - await _eventService.LogUserEventAsync(user.Id, EventType.User_UpdatedTempPassword); - await _pushService.PushLogOutAsync(user.Id); - - return IdentityResult.Success; - } - - public async Task ChangeKdfAsync(User user, string masterPassword, string newMasterPassword, - string key, KdfType kdf, int kdfIterations) - { - if (user == null) - { - throw new ArgumentNullException(nameof(user)); - } - - if (await CheckPasswordAsync(user, masterPassword)) - { - var result = await UpdatePasswordHash(user, newMasterPassword); - if (!result.Succeeded) + if (license?.LicenseType != null && license.LicenseType != LicenseType.User) { - return result; + throw new BadRequestException("Organization licenses cannot be applied to a user. " + + "Upload this license from the Organization settings page."); } - user.RevisionDate = user.AccountRevisionDate = DateTime.UtcNow; - user.Key = key; - user.Kdf = kdf; - user.KdfIterations = kdfIterations; - await _userRepository.ReplaceAsync(user); - await _pushService.PushLogOutAsync(user.Id); - return IdentityResult.Success; - } - - Logger.LogWarning("Change KDF failed for user {userId}.", user.Id); - return IdentityResult.Failed(_identityErrorDescriber.PasswordMismatch()); - } - - public async Task UpdateKeyAsync(User user, string masterPassword, string key, string privateKey, - IEnumerable ciphers, IEnumerable folders, IEnumerable sends) - { - if (user == null) - { - throw new ArgumentNullException(nameof(user)); - } - - if (await CheckPasswordAsync(user, masterPassword)) - { - user.RevisionDate = user.AccountRevisionDate = DateTime.UtcNow; - user.SecurityStamp = Guid.NewGuid().ToString(); - user.Key = key; - user.PrivateKey = privateKey; - if (ciphers.Any() || folders.Any() || sends.Any()) - { - await _cipherRepository.UpdateUserKeysAndCiphersAsync(user, ciphers, folders, sends); - } - else - { - await _userRepository.ReplaceAsync(user); - } - - await _pushService.PushLogOutAsync(user.Id); - return IdentityResult.Success; - } - - Logger.LogWarning("Update key failed for user {userId}.", user.Id); - return IdentityResult.Failed(_identityErrorDescriber.PasswordMismatch()); - } - - public async Task RefreshSecurityStampAsync(User user, string secret) - { - if (user == null) - { - throw new ArgumentNullException(nameof(user)); - } - - if (await VerifySecretAsync(user, secret)) - { - var result = await base.UpdateSecurityStampAsync(user); - if (!result.Succeeded) - { - return result; - } - - await SaveUserAsync(user); - await _pushService.PushLogOutAsync(user.Id); - return IdentityResult.Success; - } - - Logger.LogWarning("Refresh security stamp failed for user {userId}.", user.Id); - return IdentityResult.Failed(_identityErrorDescriber.PasswordMismatch()); - } - - public async Task UpdateTwoFactorProviderAsync(User user, TwoFactorProviderType type, bool setEnabled = true, bool logEvent = true) - { - SetTwoFactorProvider(user, type, setEnabled); - await SaveUserAsync(user); - if (logEvent) - { - await _eventService.LogUserEventAsync(user.Id, EventType.User_Updated2fa); - } - } - - public async Task DisableTwoFactorProviderAsync(User user, TwoFactorProviderType type, - IOrganizationService organizationService) - { - var providers = user.GetTwoFactorProviders(); - if (!providers?.ContainsKey(type) ?? true) - { - return; - } - - providers.Remove(type); - user.SetTwoFactorProviders(providers); - await SaveUserAsync(user); - await _eventService.LogUserEventAsync(user.Id, EventType.User_Disabled2fa); - - if (!await TwoFactorIsEnabledAsync(user)) - { - await CheckPoliciesOnTwoFactorRemovalAsync(user, organizationService); - } - } - - public async Task RecoverTwoFactorAsync(string email, string secret, string recoveryCode, - IOrganizationService organizationService) - { - var user = await _userRepository.GetByEmailAsync(email); - if (user == null) - { - // No user exists. Do we want to send an email telling them this in the future? - return false; - } - - if (!await VerifySecretAsync(user, secret)) - { - return false; - } - - if (!CoreHelpers.FixedTimeEquals(user.TwoFactorRecoveryCode, recoveryCode)) - { - return false; - } - - user.TwoFactorProviders = null; - user.TwoFactorRecoveryCode = CoreHelpers.SecureRandomString(32, upper: false, special: false); - await SaveUserAsync(user); - await _mailService.SendRecoverTwoFactorEmail(user.Email, DateTime.UtcNow, _currentContext.IpAddress); - await _eventService.LogUserEventAsync(user.Id, EventType.User_Recovered2fa); - await CheckPoliciesOnTwoFactorRemovalAsync(user, organizationService); - - return true; - } - - public async Task> SignUpPremiumAsync(User user, string paymentToken, - PaymentMethodType paymentMethodType, short additionalStorageGb, UserLicense license, - TaxInfo taxInfo) - { - if (user.Premium) - { - throw new BadRequestException("Already a premium user."); - } - - if (additionalStorageGb < 0) - { - throw new BadRequestException("You can't subtract storage!"); - } - - if ((paymentMethodType == PaymentMethodType.GoogleInApp || - paymentMethodType == PaymentMethodType.AppleInApp) && additionalStorageGb > 0) - { - throw new BadRequestException("You cannot add storage with this payment method."); - } - - string paymentIntentClientSecret = null; - IPaymentService paymentService = null; - if (_globalSettings.SelfHosted) - { if (license == null || !_licenseService.VerifyLicense(license)) { throw new BadRequestException("Invalid license."); @@ -1005,492 +1097,401 @@ public class UserService : UserManager, IUserService, IDisposable Directory.CreateDirectory(dir); using var fs = File.OpenWrite(Path.Combine(dir, $"{user.Id}.json")); await JsonSerializer.SerializeAsync(fs, license, JsonHelpers.Indented); - } - else - { - paymentIntentClientSecret = await _paymentService.PurchasePremiumAsync(user, paymentMethodType, - paymentToken, additionalStorageGb, taxInfo); - } - user.Premium = true; - user.RevisionDate = DateTime.UtcNow; - - if (_globalSettings.SelfHosted) - { - user.MaxStorageGb = 10240; // 10 TB + user.Premium = license.Premium; + user.RevisionDate = DateTime.UtcNow; + user.MaxStorageGb = _globalSettings.SelfHosted ? 10240 : license.MaxStorageGb; // 10 TB user.LicenseKey = license.LicenseKey; user.PremiumExpirationDate = license.Expires; - } - else - { - user.MaxStorageGb = (short)(1 + additionalStorageGb); - user.LicenseKey = CoreHelpers.SecureRandomString(20); + await SaveUserAsync(user); } - try + public async Task AdjustStorageAsync(User user, short storageAdjustmentGb) { - await SaveUserAsync(user); - await _pushService.PushSyncVaultAsync(user.Id); + if (user == null) + { + throw new ArgumentNullException(nameof(user)); + } + + if (!user.Premium) + { + throw new BadRequestException("Not a premium user."); + } + + var secret = await BillingHelpers.AdjustStorageAsync(_paymentService, user, storageAdjustmentGb, + StoragePlanId); await _referenceEventService.RaiseEventAsync( - new ReferenceEvent(ReferenceEventType.UpgradePlan, user) + new ReferenceEvent(ReferenceEventType.AdjustStorage, user) { - Storage = user.MaxStorageGb, - PlanName = PremiumPlanId, + Storage = storageAdjustmentGb, + PlanName = StoragePlanId, + }); + await SaveUserAsync(user); + return secret; + } + + public async Task ReplacePaymentMethodAsync(User user, string paymentToken, PaymentMethodType paymentMethodType, TaxInfo taxInfo) + { + if (paymentToken.StartsWith("btok_")) + { + throw new BadRequestException("Invalid token."); + } + + var updated = await _paymentService.UpdatePaymentMethodAsync(user, paymentMethodType, paymentToken, taxInfo: taxInfo); + if (updated) + { + await SaveUserAsync(user); + } + } + + public async Task CancelPremiumAsync(User user, bool? endOfPeriod = null, bool accountDelete = false) + { + var eop = endOfPeriod.GetValueOrDefault(true); + if (!endOfPeriod.HasValue && user.PremiumExpirationDate.HasValue && + user.PremiumExpirationDate.Value < DateTime.UtcNow) + { + eop = false; + } + await _paymentService.CancelSubscriptionAsync(user, eop, accountDelete); + await _referenceEventService.RaiseEventAsync( + new ReferenceEvent(ReferenceEventType.CancelSubscription, user) + { + EndOfPeriod = eop, }); } - catch when (!_globalSettings.SelfHosted) - { - await paymentService.CancelAndRecoverChargesAsync(user); - throw; - } - return new Tuple(string.IsNullOrWhiteSpace(paymentIntentClientSecret), - paymentIntentClientSecret); - } - - public async Task IapCheckAsync(User user, PaymentMethodType paymentMethodType) - { - if (paymentMethodType != PaymentMethodType.AppleInApp) - { - throw new BadRequestException("Payment method not supported for in-app purchases."); - } - - if (user.Premium) - { - throw new BadRequestException("Already a premium user."); - } - - if (!string.IsNullOrWhiteSpace(user.GatewayCustomerId)) - { - var customerService = new Stripe.CustomerService(); - var customer = await customerService.GetAsync(user.GatewayCustomerId); - if (customer != null && customer.Balance != 0) - { - throw new BadRequestException("Customer balance cannot exist when using in-app purchases."); - } - } - } - - public async Task UpdateLicenseAsync(User user, UserLicense license) - { - if (!_globalSettings.SelfHosted) - { - throw new InvalidOperationException("Licenses require self hosting."); - } - - if (license?.LicenseType != null && license.LicenseType != LicenseType.User) - { - throw new BadRequestException("Organization licenses cannot be applied to a user. " - + "Upload this license from the Organization settings page."); - } - - if (license == null || !_licenseService.VerifyLicense(license)) - { - throw new BadRequestException("Invalid license."); - } - - if (!license.CanUse(user)) - { - throw new BadRequestException("This license is not valid for this user."); - } - - var dir = $"{_globalSettings.LicenseDirectory}/user"; - Directory.CreateDirectory(dir); - using var fs = File.OpenWrite(Path.Combine(dir, $"{user.Id}.json")); - await JsonSerializer.SerializeAsync(fs, license, JsonHelpers.Indented); - - user.Premium = license.Premium; - user.RevisionDate = DateTime.UtcNow; - user.MaxStorageGb = _globalSettings.SelfHosted ? 10240 : license.MaxStorageGb; // 10 TB - user.LicenseKey = license.LicenseKey; - user.PremiumExpirationDate = license.Expires; - await SaveUserAsync(user); - } - - public async Task AdjustStorageAsync(User user, short storageAdjustmentGb) - { - if (user == null) - { - throw new ArgumentNullException(nameof(user)); - } - - if (!user.Premium) - { - throw new BadRequestException("Not a premium user."); - } - - var secret = await BillingHelpers.AdjustStorageAsync(_paymentService, user, storageAdjustmentGb, - StoragePlanId); - await _referenceEventService.RaiseEventAsync( - new ReferenceEvent(ReferenceEventType.AdjustStorage, user) - { - Storage = storageAdjustmentGb, - PlanName = StoragePlanId, - }); - await SaveUserAsync(user); - return secret; - } - - public async Task ReplacePaymentMethodAsync(User user, string paymentToken, PaymentMethodType paymentMethodType, TaxInfo taxInfo) - { - if (paymentToken.StartsWith("btok_")) - { - throw new BadRequestException("Invalid token."); - } - - var updated = await _paymentService.UpdatePaymentMethodAsync(user, paymentMethodType, paymentToken, taxInfo: taxInfo); - if (updated) - { - await SaveUserAsync(user); - } - } - - public async Task CancelPremiumAsync(User user, bool? endOfPeriod = null, bool accountDelete = false) - { - var eop = endOfPeriod.GetValueOrDefault(true); - if (!endOfPeriod.HasValue && user.PremiumExpirationDate.HasValue && - user.PremiumExpirationDate.Value < DateTime.UtcNow) - { - eop = false; - } - await _paymentService.CancelSubscriptionAsync(user, eop, accountDelete); - await _referenceEventService.RaiseEventAsync( - new ReferenceEvent(ReferenceEventType.CancelSubscription, user) - { - EndOfPeriod = eop, - }); - } - - public async Task ReinstatePremiumAsync(User user) - { - await _paymentService.ReinstateSubscriptionAsync(user); - await _referenceEventService.RaiseEventAsync( - new ReferenceEvent(ReferenceEventType.ReinstateSubscription, user)); - } - - public async Task EnablePremiumAsync(Guid userId, DateTime? expirationDate) - { - var user = await _userRepository.GetByIdAsync(userId); - await EnablePremiumAsync(user, expirationDate); - } - - public async Task EnablePremiumAsync(User user, DateTime? expirationDate) - { - if (user != null && !user.Premium && user.Gateway.HasValue) - { - user.Premium = true; - user.PremiumExpirationDate = expirationDate; - user.RevisionDate = DateTime.UtcNow; - await _userRepository.ReplaceAsync(user); - } - } - - public async Task DisablePremiumAsync(Guid userId, DateTime? expirationDate) - { - var user = await _userRepository.GetByIdAsync(userId); - await DisablePremiumAsync(user, expirationDate); - } - - public async Task DisablePremiumAsync(User user, DateTime? expirationDate) - { - if (user != null && user.Premium) - { - user.Premium = false; - user.PremiumExpirationDate = expirationDate; - user.RevisionDate = DateTime.UtcNow; - await _userRepository.ReplaceAsync(user); - } - } - - public async Task UpdatePremiumExpirationAsync(Guid userId, DateTime? expirationDate) - { - var user = await _userRepository.GetByIdAsync(userId); - if (user != null) - { - user.PremiumExpirationDate = expirationDate; - user.RevisionDate = DateTime.UtcNow; - await _userRepository.ReplaceAsync(user); - } - } - - public async Task GenerateLicenseAsync(User user, SubscriptionInfo subscriptionInfo = null, - int? version = null) - { - if (user == null) - { - throw new NotFoundException(); - } - - if (subscriptionInfo == null && user.Gateway != null) - { - subscriptionInfo = await _paymentService.GetSubscriptionAsync(user); - } - - return subscriptionInfo == null ? new UserLicense(user, _licenseService) : - new UserLicense(user, subscriptionInfo, _licenseService); - } - - public override async Task CheckPasswordAsync(User user, string password) - { - if (user == null) - { - return false; - } - - var result = await base.VerifyPasswordAsync(Store as IUserPasswordStore, user, password); - if (result == PasswordVerificationResult.SuccessRehashNeeded) - { - await UpdatePasswordHash(user, password, false, false); - user.RevisionDate = DateTime.UtcNow; - await _userRepository.ReplaceAsync(user); - } - - var success = result != PasswordVerificationResult.Failed; - if (!success) - { - Logger.LogWarning(0, "Invalid password for user {userId}.", user.Id); - } - return success; - } - - public async Task CanAccessPremium(ITwoFactorProvidersUser user) - { - var userId = user.GetUserId(); - if (!userId.HasValue) - { - return false; - } - - return user.GetPremium() || await this.HasPremiumFromOrganization(user); - } - - public async Task HasPremiumFromOrganization(ITwoFactorProvidersUser user) - { - var userId = user.GetUserId(); - if (!userId.HasValue) - { - return false; - } - - // orgUsers in the Invited status are not associated with a userId yet, so this will get - // orgUsers in Accepted and Confirmed states only - var orgUsers = await _organizationUserRepository.GetManyByUserAsync(userId.Value); - - if (!orgUsers.Any()) - { - return false; - } - - var orgAbilities = await _applicationCacheService.GetOrganizationAbilitiesAsync(); - return orgUsers.Any(ou => - orgAbilities.TryGetValue(ou.OrganizationId, out var orgAbility) && - orgAbility.UsersGetPremium && - orgAbility.Enabled); - } - - public async Task TwoFactorIsEnabledAsync(ITwoFactorProvidersUser user) - { - var providers = user.GetTwoFactorProviders(); - if (providers == null) - { - return false; - } - - foreach (var p in providers) - { - if (p.Value?.Enabled ?? false) - { - if (!TwoFactorProvider.RequiresPremium(p.Key)) - { - return true; - } - if (await CanAccessPremium(user)) - { - return true; - } - } - } - return false; - } - - public async Task TwoFactorProviderIsEnabledAsync(TwoFactorProviderType provider, ITwoFactorProvidersUser user) - { - var providers = user.GetTwoFactorProviders(); - if (providers == null || !providers.ContainsKey(provider) || !providers[provider].Enabled) - { - return false; - } - - if (!TwoFactorProvider.RequiresPremium(provider)) - { - return true; - } - - return await CanAccessPremium(user); - } - - public async Task GenerateSignInTokenAsync(User user, string purpose) - { - var token = await GenerateUserTokenAsync(user, Options.Tokens.PasswordResetTokenProvider, - purpose); - return token; - } - - private async Task UpdatePasswordHash(User user, string newPassword, - bool validatePassword = true, bool refreshStamp = true) - { - if (validatePassword) - { - var validate = await ValidatePasswordInternal(user, newPassword); - if (!validate.Succeeded) - { - return validate; - } - } - - user.MasterPassword = _passwordHasher.HashPassword(user, newPassword); - if (refreshStamp) - { - user.SecurityStamp = Guid.NewGuid().ToString(); - } - - return IdentityResult.Success; - } - - private async Task ValidatePasswordInternal(User user, string password) - { - var errors = new List(); - foreach (var v in _passwordValidators) - { - var result = await v.ValidateAsync(this, user, password); - if (!result.Succeeded) - { - errors.AddRange(result.Errors); - } - } - - if (errors.Count > 0) - { - Logger.LogWarning("User {userId} password validation failed: {errors}.", await GetUserIdAsync(user), - string.Join(";", errors.Select(e => e.Code))); - return IdentityResult.Failed(errors.ToArray()); - } - - return IdentityResult.Success; - } - - public void SetTwoFactorProvider(User user, TwoFactorProviderType type, bool setEnabled = true) - { - var providers = user.GetTwoFactorProviders(); - if (!providers?.ContainsKey(type) ?? true) - { - return; - } - - if (setEnabled) - { - providers[type].Enabled = true; - } - user.SetTwoFactorProviders(providers); - - if (string.IsNullOrWhiteSpace(user.TwoFactorRecoveryCode)) - { - user.TwoFactorRecoveryCode = CoreHelpers.SecureRandomString(32, upper: false, special: false); - } - } - - private async Task CheckPoliciesOnTwoFactorRemovalAsync(User user, IOrganizationService organizationService) - { - var twoFactorPolicies = await _policyRepository.GetManyByTypeApplicableToUserIdAsync(user.Id, - PolicyType.TwoFactorAuthentication); - - var removeOrgUserTasks = twoFactorPolicies.Select(async p => - { - await organizationService.DeleteUserAsync(p.OrganizationId, user.Id); - var organization = await _organizationRepository.GetByIdAsync(p.OrganizationId); - await _mailService.SendOrganizationUserRemovedForPolicyTwoStepEmailAsync( - organization.Name, user.Email); - }).ToArray(); - - await Task.WhenAll(removeOrgUserTasks); - } - - public override async Task ConfirmEmailAsync(User user, string token) - { - var result = await base.ConfirmEmailAsync(user, token); - if (result.Succeeded) + + public async Task ReinstatePremiumAsync(User user) { + await _paymentService.ReinstateSubscriptionAsync(user); await _referenceEventService.RaiseEventAsync( - new ReferenceEvent(ReferenceEventType.ConfirmEmailAddress, user)); - } - return result; - } - - public async Task RotateApiKeyAsync(User user) - { - user.ApiKey = CoreHelpers.SecureRandomString(30); - user.RevisionDate = DateTime.UtcNow; - await _userRepository.ReplaceAsync(user); - } - - public async Task SendOTPAsync(User user) - { - if (user.Email == null) - { - throw new BadRequestException("No user email."); + new ReferenceEvent(ReferenceEventType.ReinstateSubscription, user)); } - if (!user.UsesKeyConnector) + public async Task EnablePremiumAsync(Guid userId, DateTime? expirationDate) { - throw new BadRequestException("Not using Key Connector."); + var user = await _userRepository.GetByIdAsync(userId); + await EnablePremiumAsync(user, expirationDate); } - var token = await base.GenerateUserTokenAsync(user, TokenOptions.DefaultEmailProvider, - "otp:" + user.Email); - await _mailService.SendOTPEmailAsync(user.Email, token); - } - - public Task VerifyOTPAsync(User user, string token) - { - return base.VerifyUserTokenAsync(user, TokenOptions.DefaultEmailProvider, - "otp:" + user.Email, token); - } - - public async Task VerifySecretAsync(User user, string secret) - { - return user.UsesKeyConnector - ? await VerifyOTPAsync(user, secret) - : await CheckPasswordAsync(user, secret); - } - - public async Task Needs2FABecauseNewDeviceAsync(User user, string deviceIdentifier, string grantType) - { - return CanEditDeviceVerificationSettings(user) - && user.UnknownDeviceVerificationEnabled - && grantType != "authorization_code" - && await IsNewDeviceAndNotTheFirstOneAsync(user, deviceIdentifier); - } - - public bool CanEditDeviceVerificationSettings(User user) - { - return _globalSettings.TwoFactorAuth.EmailOnNewDeviceLogin - && user.EmailVerified - && !user.UsesKeyConnector - && !(user.GetTwoFactorProviders()?.Any() ?? false); - } - - private async Task IsNewDeviceAndNotTheFirstOneAsync(User user, string deviceIdentifier) - { - if (user == null) + public async Task EnablePremiumAsync(User user, DateTime? expirationDate) { - return default; + if (user != null && !user.Premium && user.Gateway.HasValue) + { + user.Premium = true; + user.PremiumExpirationDate = expirationDate; + user.RevisionDate = DateTime.UtcNow; + await _userRepository.ReplaceAsync(user); + } } - var devices = await _deviceRepository.GetManyByUserIdAsync(user.Id); - if (!devices.Any()) + public async Task DisablePremiumAsync(Guid userId, DateTime? expirationDate) { + var user = await _userRepository.GetByIdAsync(userId); + await DisablePremiumAsync(user, expirationDate); + } + + public async Task DisablePremiumAsync(User user, DateTime? expirationDate) + { + if (user != null && user.Premium) + { + user.Premium = false; + user.PremiumExpirationDate = expirationDate; + user.RevisionDate = DateTime.UtcNow; + await _userRepository.ReplaceAsync(user); + } + } + + public async Task UpdatePremiumExpirationAsync(Guid userId, DateTime? expirationDate) + { + var user = await _userRepository.GetByIdAsync(userId); + if (user != null) + { + user.PremiumExpirationDate = expirationDate; + user.RevisionDate = DateTime.UtcNow; + await _userRepository.ReplaceAsync(user); + } + } + + public async Task GenerateLicenseAsync(User user, SubscriptionInfo subscriptionInfo = null, + int? version = null) + { + if (user == null) + { + throw new NotFoundException(); + } + + if (subscriptionInfo == null && user.Gateway != null) + { + subscriptionInfo = await _paymentService.GetSubscriptionAsync(user); + } + + return subscriptionInfo == null ? new UserLicense(user, _licenseService) : + new UserLicense(user, subscriptionInfo, _licenseService); + } + + public override async Task CheckPasswordAsync(User user, string password) + { + if (user == null) + { + return false; + } + + var result = await base.VerifyPasswordAsync(Store as IUserPasswordStore, user, password); + if (result == PasswordVerificationResult.SuccessRehashNeeded) + { + await UpdatePasswordHash(user, password, false, false); + user.RevisionDate = DateTime.UtcNow; + await _userRepository.ReplaceAsync(user); + } + + var success = result != PasswordVerificationResult.Failed; + if (!success) + { + Logger.LogWarning(0, "Invalid password for user {userId}.", user.Id); + } + return success; + } + + public async Task CanAccessPremium(ITwoFactorProvidersUser user) + { + var userId = user.GetUserId(); + if (!userId.HasValue) + { + return false; + } + + return user.GetPremium() || await this.HasPremiumFromOrganization(user); + } + + public async Task HasPremiumFromOrganization(ITwoFactorProvidersUser user) + { + var userId = user.GetUserId(); + if (!userId.HasValue) + { + return false; + } + + // orgUsers in the Invited status are not associated with a userId yet, so this will get + // orgUsers in Accepted and Confirmed states only + var orgUsers = await _organizationUserRepository.GetManyByUserAsync(userId.Value); + + if (!orgUsers.Any()) + { + return false; + } + + var orgAbilities = await _applicationCacheService.GetOrganizationAbilitiesAsync(); + return orgUsers.Any(ou => + orgAbilities.TryGetValue(ou.OrganizationId, out var orgAbility) && + orgAbility.UsersGetPremium && + orgAbility.Enabled); + } + + public async Task TwoFactorIsEnabledAsync(ITwoFactorProvidersUser user) + { + var providers = user.GetTwoFactorProviders(); + if (providers == null) + { + return false; + } + + foreach (var p in providers) + { + if (p.Value?.Enabled ?? false) + { + if (!TwoFactorProvider.RequiresPremium(p.Key)) + { + return true; + } + if (await CanAccessPremium(user)) + { + return true; + } + } + } return false; } - return !devices.Any(d => d.Identifier == deviceIdentifier); + public async Task TwoFactorProviderIsEnabledAsync(TwoFactorProviderType provider, ITwoFactorProvidersUser user) + { + var providers = user.GetTwoFactorProviders(); + if (providers == null || !providers.ContainsKey(provider) || !providers[provider].Enabled) + { + return false; + } + + if (!TwoFactorProvider.RequiresPremium(provider)) + { + return true; + } + + return await CanAccessPremium(user); + } + + public async Task GenerateSignInTokenAsync(User user, string purpose) + { + var token = await GenerateUserTokenAsync(user, Options.Tokens.PasswordResetTokenProvider, + purpose); + return token; + } + + private async Task UpdatePasswordHash(User user, string newPassword, + bool validatePassword = true, bool refreshStamp = true) + { + if (validatePassword) + { + var validate = await ValidatePasswordInternal(user, newPassword); + if (!validate.Succeeded) + { + return validate; + } + } + + user.MasterPassword = _passwordHasher.HashPassword(user, newPassword); + if (refreshStamp) + { + user.SecurityStamp = Guid.NewGuid().ToString(); + } + + return IdentityResult.Success; + } + + private async Task ValidatePasswordInternal(User user, string password) + { + var errors = new List(); + foreach (var v in _passwordValidators) + { + var result = await v.ValidateAsync(this, user, password); + if (!result.Succeeded) + { + errors.AddRange(result.Errors); + } + } + + if (errors.Count > 0) + { + Logger.LogWarning("User {userId} password validation failed: {errors}.", await GetUserIdAsync(user), + string.Join(";", errors.Select(e => e.Code))); + return IdentityResult.Failed(errors.ToArray()); + } + + return IdentityResult.Success; + } + + public void SetTwoFactorProvider(User user, TwoFactorProviderType type, bool setEnabled = true) + { + var providers = user.GetTwoFactorProviders(); + if (!providers?.ContainsKey(type) ?? true) + { + return; + } + + if (setEnabled) + { + providers[type].Enabled = true; + } + user.SetTwoFactorProviders(providers); + + if (string.IsNullOrWhiteSpace(user.TwoFactorRecoveryCode)) + { + user.TwoFactorRecoveryCode = CoreHelpers.SecureRandomString(32, upper: false, special: false); + } + } + + private async Task CheckPoliciesOnTwoFactorRemovalAsync(User user, IOrganizationService organizationService) + { + var twoFactorPolicies = await _policyRepository.GetManyByTypeApplicableToUserIdAsync(user.Id, + PolicyType.TwoFactorAuthentication); + + var removeOrgUserTasks = twoFactorPolicies.Select(async p => + { + await organizationService.DeleteUserAsync(p.OrganizationId, user.Id); + var organization = await _organizationRepository.GetByIdAsync(p.OrganizationId); + await _mailService.SendOrganizationUserRemovedForPolicyTwoStepEmailAsync( + organization.Name, user.Email); + }).ToArray(); + + await Task.WhenAll(removeOrgUserTasks); + } + + public override async Task ConfirmEmailAsync(User user, string token) + { + var result = await base.ConfirmEmailAsync(user, token); + if (result.Succeeded) + { + await _referenceEventService.RaiseEventAsync( + new ReferenceEvent(ReferenceEventType.ConfirmEmailAddress, user)); + } + return result; + } + + public async Task RotateApiKeyAsync(User user) + { + user.ApiKey = CoreHelpers.SecureRandomString(30); + user.RevisionDate = DateTime.UtcNow; + await _userRepository.ReplaceAsync(user); + } + + public async Task SendOTPAsync(User user) + { + if (user.Email == null) + { + throw new BadRequestException("No user email."); + } + + if (!user.UsesKeyConnector) + { + throw new BadRequestException("Not using Key Connector."); + } + + var token = await base.GenerateUserTokenAsync(user, TokenOptions.DefaultEmailProvider, + "otp:" + user.Email); + await _mailService.SendOTPEmailAsync(user.Email, token); + } + + public Task VerifyOTPAsync(User user, string token) + { + return base.VerifyUserTokenAsync(user, TokenOptions.DefaultEmailProvider, + "otp:" + user.Email, token); + } + + public async Task VerifySecretAsync(User user, string secret) + { + return user.UsesKeyConnector + ? await VerifyOTPAsync(user, secret) + : await CheckPasswordAsync(user, secret); + } + + public async Task Needs2FABecauseNewDeviceAsync(User user, string deviceIdentifier, string grantType) + { + return CanEditDeviceVerificationSettings(user) + && user.UnknownDeviceVerificationEnabled + && grantType != "authorization_code" + && await IsNewDeviceAndNotTheFirstOneAsync(user, deviceIdentifier); + } + + public bool CanEditDeviceVerificationSettings(User user) + { + return _globalSettings.TwoFactorAuth.EmailOnNewDeviceLogin + && user.EmailVerified + && !user.UsesKeyConnector + && !(user.GetTwoFactorProviders()?.Any() ?? false); + } + + private async Task IsNewDeviceAndNotTheFirstOneAsync(User user, string deviceIdentifier) + { + if (user == null) + { + return default; + } + + var devices = await _deviceRepository.GetManyByUserIdAsync(user.Id); + if (!devices.Any()) + { + return false; + } + + return !devices.Any(d => d.Identifier == deviceIdentifier); + } } } diff --git a/src/Core/Services/NoopImplementations/NoopAttachmentStorageService.cs b/src/Core/Services/NoopImplementations/NoopAttachmentStorageService.cs index 24f669c366..7643fc43cb 100644 --- a/src/Core/Services/NoopImplementations/NoopAttachmentStorageService.cs +++ b/src/Core/Services/NoopImplementations/NoopAttachmentStorageService.cs @@ -2,68 +2,69 @@ using Bit.Core.Enums; using Bit.Core.Models.Data; -namespace Bit.Core.Services; - -public class NoopAttachmentStorageService : IAttachmentStorageService +namespace Bit.Core.Services { - public FileUploadType FileUploadType => FileUploadType.Direct; + public class NoopAttachmentStorageService : IAttachmentStorageService + { + public FileUploadType FileUploadType => FileUploadType.Direct; - public Task CleanupAsync(Guid cipherId) - { - return Task.FromResult(0); - } + public Task CleanupAsync(Guid cipherId) + { + return Task.FromResult(0); + } - public Task DeleteAttachmentAsync(Guid cipherId, CipherAttachment.MetaData attachmentData) - { - return Task.FromResult(0); - } + public Task DeleteAttachmentAsync(Guid cipherId, CipherAttachment.MetaData attachmentData) + { + return Task.FromResult(0); + } - public Task DeleteAttachmentsForCipherAsync(Guid cipherId) - { - return Task.FromResult(0); - } + public Task DeleteAttachmentsForCipherAsync(Guid cipherId) + { + return Task.FromResult(0); + } - public Task DeleteAttachmentsForOrganizationAsync(Guid organizationId) - { - return Task.FromResult(0); - } + public Task DeleteAttachmentsForOrganizationAsync(Guid organizationId) + { + return Task.FromResult(0); + } - public Task DeleteAttachmentsForUserAsync(Guid userId) - { - return Task.FromResult(0); - } + public Task DeleteAttachmentsForUserAsync(Guid userId) + { + return Task.FromResult(0); + } - public Task RollbackShareAttachmentAsync(Guid cipherId, Guid organizationId, CipherAttachment.MetaData attachmentData, string originalContainer) - { - return Task.FromResult(0); - } + public Task RollbackShareAttachmentAsync(Guid cipherId, Guid organizationId, CipherAttachment.MetaData attachmentData, string originalContainer) + { + return Task.FromResult(0); + } - public Task StartShareAttachmentAsync(Guid cipherId, Guid organizationId, CipherAttachment.MetaData attachmentData) - { - return Task.FromResult(0); - } + public Task StartShareAttachmentAsync(Guid cipherId, Guid organizationId, CipherAttachment.MetaData attachmentData) + { + return Task.FromResult(0); + } - public Task UploadNewAttachmentAsync(Stream stream, Cipher cipher, CipherAttachment.MetaData attachmentData) - { - return Task.FromResult(0); - } + public Task UploadNewAttachmentAsync(Stream stream, Cipher cipher, CipherAttachment.MetaData attachmentData) + { + return Task.FromResult(0); + } - public Task UploadShareAttachmentAsync(Stream stream, Guid cipherId, Guid organizationId, CipherAttachment.MetaData attachmentData) - { - return Task.FromResult(0); - } + public Task UploadShareAttachmentAsync(Stream stream, Guid cipherId, Guid organizationId, CipherAttachment.MetaData attachmentData) + { + return Task.FromResult(0); + } - public Task GetAttachmentDownloadUrlAsync(Cipher cipher, CipherAttachment.MetaData attachmentData) - { - return Task.FromResult((string)null); - } + public Task GetAttachmentDownloadUrlAsync(Cipher cipher, CipherAttachment.MetaData attachmentData) + { + return Task.FromResult((string)null); + } - public Task GetAttachmentUploadUrlAsync(Cipher cipher, CipherAttachment.MetaData attachmentData) - { - return Task.FromResult(default(string)); - } - public Task<(bool, long?)> ValidateFileAsync(Cipher cipher, CipherAttachment.MetaData attachmentData, long leeway) - { - return Task.FromResult((false, (long?)null)); + public Task GetAttachmentUploadUrlAsync(Cipher cipher, CipherAttachment.MetaData attachmentData) + { + return Task.FromResult(default(string)); + } + public Task<(bool, long?)> ValidateFileAsync(Cipher cipher, CipherAttachment.MetaData attachmentData, long leeway) + { + return Task.FromResult((false, (long?)null)); + } } } diff --git a/src/Core/Services/NoopImplementations/NoopBlockIpService.cs b/src/Core/Services/NoopImplementations/NoopBlockIpService.cs index fd034325e5..4ec59f09db 100644 --- a/src/Core/Services/NoopImplementations/NoopBlockIpService.cs +++ b/src/Core/Services/NoopImplementations/NoopBlockIpService.cs @@ -1,10 +1,11 @@ -namespace Bit.Core.Services; - -public class NoopBlockIpService : IBlockIpService +namespace Bit.Core.Services { - public Task BlockIpAsync(string ipAddress, bool permanentBlock) + public class NoopBlockIpService : IBlockIpService { - // Do nothing - return Task.FromResult(0); + public Task BlockIpAsync(string ipAddress, bool permanentBlock) + { + // Do nothing + return Task.FromResult(0); + } } } diff --git a/src/Core/Services/NoopImplementations/NoopCaptchaValidationService.cs b/src/Core/Services/NoopImplementations/NoopCaptchaValidationService.cs index ef5e3366da..6e680227ad 100644 --- a/src/Core/Services/NoopImplementations/NoopCaptchaValidationService.cs +++ b/src/Core/Services/NoopImplementations/NoopCaptchaValidationService.cs @@ -2,17 +2,18 @@ using Bit.Core.Entities; using Bit.Core.Models.Business; -namespace Bit.Core.Services; - -public class NoopCaptchaValidationService : ICaptchaValidationService +namespace Bit.Core.Services { - public string SiteKeyResponseKeyName => null; - public string SiteKey => null; - public bool RequireCaptchaValidation(ICurrentContext currentContext, User user = null) => false; - public string GenerateCaptchaBypassToken(User user) => ""; - public Task ValidateCaptchaResponseAsync(string captchaResponse, string clientIpAddress, - User user = null) + public class NoopCaptchaValidationService : ICaptchaValidationService { - return Task.FromResult(new CaptchaResponse { Success = true }); + public string SiteKeyResponseKeyName => null; + public string SiteKey => null; + public bool RequireCaptchaValidation(ICurrentContext currentContext, User user = null) => false; + public string GenerateCaptchaBypassToken(User user) => ""; + public Task ValidateCaptchaResponseAsync(string captchaResponse, string clientIpAddress, + User user = null) + { + return Task.FromResult(new CaptchaResponse { Success = true }); + } } } diff --git a/src/Core/Services/NoopImplementations/NoopEventService.cs b/src/Core/Services/NoopImplementations/NoopEventService.cs index 7c596717e9..976657bf33 100644 --- a/src/Core/Services/NoopImplementations/NoopEventService.cs +++ b/src/Core/Services/NoopImplementations/NoopEventService.cs @@ -2,69 +2,70 @@ using Bit.Core.Entities.Provider; using Bit.Core.Enums; -namespace Bit.Core.Services; - -public class NoopEventService : IEventService +namespace Bit.Core.Services { - public Task LogCipherEventAsync(Cipher cipher, EventType type, DateTime? date = null) + public class NoopEventService : IEventService { - return Task.FromResult(0); - } + public Task LogCipherEventAsync(Cipher cipher, EventType type, DateTime? date = null) + { + return Task.FromResult(0); + } - public Task LogCipherEventsAsync(IEnumerable> events) - { - return Task.FromResult(0); - } + public Task LogCipherEventsAsync(IEnumerable> events) + { + return Task.FromResult(0); + } - public Task LogCollectionEventAsync(Collection collection, EventType type, DateTime? date = null) - { - return Task.FromResult(0); - } + public Task LogCollectionEventAsync(Collection collection, EventType type, DateTime? date = null) + { + return Task.FromResult(0); + } - public Task LogPolicyEventAsync(Policy policy, EventType type, DateTime? date = null) - { - return Task.FromResult(0); - } + public Task LogPolicyEventAsync(Policy policy, EventType type, DateTime? date = null) + { + return Task.FromResult(0); + } - public Task LogGroupEventAsync(Group group, EventType type, DateTime? date = null) - { - return Task.FromResult(0); - } + public Task LogGroupEventAsync(Group group, EventType type, DateTime? date = null) + { + return Task.FromResult(0); + } - public Task LogOrganizationEventAsync(Organization organization, EventType type, DateTime? date = null) - { - return Task.FromResult(0); - } + public Task LogOrganizationEventAsync(Organization organization, EventType type, DateTime? date = null) + { + return Task.FromResult(0); + } - public Task LogProviderUserEventAsync(ProviderUser providerUser, EventType type, DateTime? date = null) - { - return Task.FromResult(0); - } + public Task LogProviderUserEventAsync(ProviderUser providerUser, EventType type, DateTime? date = null) + { + return Task.FromResult(0); + } - public Task LogProviderUsersEventAsync(IEnumerable<(ProviderUser, EventType, DateTime?)> events) - { - return Task.FromResult(0); - } + public Task LogProviderUsersEventAsync(IEnumerable<(ProviderUser, EventType, DateTime?)> events) + { + return Task.FromResult(0); + } - public Task LogProviderOrganizationEventAsync(ProviderOrganization providerOrganization, EventType type, - DateTime? date = null) - { - return Task.FromResult(0); - } + public Task LogProviderOrganizationEventAsync(ProviderOrganization providerOrganization, EventType type, + DateTime? date = null) + { + return Task.FromResult(0); + } - public Task LogOrganizationUserEventAsync(OrganizationUser organizationUser, EventType type, - DateTime? date = null) - { - return Task.FromResult(0); - } + public Task LogOrganizationUserEventAsync(OrganizationUser organizationUser, EventType type, + DateTime? date = null) + { + return Task.FromResult(0); + } - public Task LogOrganizationUserEventsAsync(IEnumerable<(OrganizationUser, EventType, DateTime?)> events) - { - return Task.FromResult(0); - } + public Task LogOrganizationUserEventsAsync(IEnumerable<(OrganizationUser, EventType, DateTime?)> events) + { + return Task.FromResult(0); + } - public Task LogUserEventAsync(Guid userId, EventType type, DateTime? date = null) - { - return Task.FromResult(0); + public Task LogUserEventAsync(Guid userId, EventType type, DateTime? date = null) + { + return Task.FromResult(0); + } } } diff --git a/src/Core/Services/NoopImplementations/NoopEventWriteService.cs b/src/Core/Services/NoopImplementations/NoopEventWriteService.cs index d7288389f5..94be40b209 100644 --- a/src/Core/Services/NoopImplementations/NoopEventWriteService.cs +++ b/src/Core/Services/NoopImplementations/NoopEventWriteService.cs @@ -1,16 +1,17 @@ using Bit.Core.Models.Data; -namespace Bit.Core.Services; - -public class NoopEventWriteService : IEventWriteService +namespace Bit.Core.Services { - public Task CreateAsync(IEvent e) + public class NoopEventWriteService : IEventWriteService { - return Task.FromResult(0); - } + public Task CreateAsync(IEvent e) + { + return Task.FromResult(0); + } - public Task CreateManyAsync(IEnumerable e) - { - return Task.FromResult(0); + public Task CreateManyAsync(IEnumerable e) + { + return Task.FromResult(0); + } } } diff --git a/src/Core/Services/NoopImplementations/NoopLicensingService.cs b/src/Core/Services/NoopImplementations/NoopLicensingService.cs index c79be8009e..ef5cb9b856 100644 --- a/src/Core/Services/NoopImplementations/NoopLicensingService.cs +++ b/src/Core/Services/NoopImplementations/NoopLicensingService.cs @@ -4,52 +4,53 @@ using Bit.Core.Settings; using Microsoft.AspNetCore.Hosting; using Microsoft.Extensions.Hosting; -namespace Bit.Core.Services; - -public class NoopLicensingService : ILicensingService +namespace Bit.Core.Services { - public NoopLicensingService( - IWebHostEnvironment environment, - GlobalSettings globalSettings) + public class NoopLicensingService : ILicensingService { - if (!environment.IsDevelopment() && globalSettings.SelfHosted) + public NoopLicensingService( + IWebHostEnvironment environment, + GlobalSettings globalSettings) { - throw new Exception($"{nameof(NoopLicensingService)} cannot be used for self hosted instances."); + if (!environment.IsDevelopment() && globalSettings.SelfHosted) + { + throw new Exception($"{nameof(NoopLicensingService)} cannot be used for self hosted instances."); + } + } + + public Task ValidateOrganizationsAsync() + { + return Task.FromResult(0); + } + + public Task ValidateUsersAsync() + { + return Task.FromResult(0); + } + + public Task ValidateUserPremiumAsync(User user) + { + return Task.FromResult(user.Premium); + } + + public bool VerifyLicense(ILicense license) + { + return true; + } + + public byte[] SignLicense(ILicense license) + { + return new byte[0]; + } + + public Task ReadOrganizationLicenseAsync(Organization organization) + { + return Task.FromResult(null); + } + + public Task ReadOrganizationLicenseAsync(Guid organizationId) + { + return Task.FromResult(null); } } - - public Task ValidateOrganizationsAsync() - { - return Task.FromResult(0); - } - - public Task ValidateUsersAsync() - { - return Task.FromResult(0); - } - - public Task ValidateUserPremiumAsync(User user) - { - return Task.FromResult(user.Premium); - } - - public bool VerifyLicense(ILicense license) - { - return true; - } - - public byte[] SignLicense(ILicense license) - { - return new byte[0]; - } - - public Task ReadOrganizationLicenseAsync(Organization organization) - { - return Task.FromResult(null); - } - - public Task ReadOrganizationLicenseAsync(Guid organizationId) - { - return Task.FromResult(null); - } } diff --git a/src/Core/Services/NoopImplementations/NoopMailDeliveryService.cs b/src/Core/Services/NoopImplementations/NoopMailDeliveryService.cs index 96b97b14f5..dc8ef6b60b 100644 --- a/src/Core/Services/NoopImplementations/NoopMailDeliveryService.cs +++ b/src/Core/Services/NoopImplementations/NoopMailDeliveryService.cs @@ -1,11 +1,12 @@ using Bit.Core.Models.Mail; -namespace Bit.Core.Services; - -public class NoopMailDeliveryService : IMailDeliveryService +namespace Bit.Core.Services { - public Task SendEmailAsync(MailMessage message) + public class NoopMailDeliveryService : IMailDeliveryService { - return Task.FromResult(0); + public Task SendEmailAsync(MailMessage message) + { + return Task.FromResult(0); + } } } diff --git a/src/Core/Services/NoopImplementations/NoopMailService.cs b/src/Core/Services/NoopImplementations/NoopMailService.cs index cee8c91f42..910516ab52 100644 --- a/src/Core/Services/NoopImplementations/NoopMailService.cs +++ b/src/Core/Services/NoopImplementations/NoopMailService.cs @@ -3,238 +3,239 @@ using Bit.Core.Entities.Provider; using Bit.Core.Models.Business; using Bit.Core.Models.Mail; -namespace Bit.Core.Services; - -public class NoopMailService : IMailService +namespace Bit.Core.Services { - public Task SendChangeEmailAlreadyExistsEmailAsync(string fromEmail, string toEmail) + public class NoopMailService : IMailService { - return Task.FromResult(0); - } + public Task SendChangeEmailAlreadyExistsEmailAsync(string fromEmail, string toEmail) + { + return Task.FromResult(0); + } - public Task SendVerifyEmailEmailAsync(string email, Guid userId, string hint) - { - return Task.FromResult(0); - } + public Task SendVerifyEmailEmailAsync(string email, Guid userId, string hint) + { + return Task.FromResult(0); + } - public Task SendChangeEmailEmailAsync(string newEmailAddress, string token) - { - return Task.FromResult(0); - } + public Task SendChangeEmailEmailAsync(string newEmailAddress, string token) + { + return Task.FromResult(0); + } - public Task SendMasterPasswordHintEmailAsync(string email, string hint) - { - return Task.FromResult(0); - } + public Task SendMasterPasswordHintEmailAsync(string email, string hint) + { + return Task.FromResult(0); + } - public Task SendNoMasterPasswordHintEmailAsync(string email) - { - return Task.FromResult(0); - } + public Task SendNoMasterPasswordHintEmailAsync(string email) + { + return Task.FromResult(0); + } - public Task SendOrganizationMaxSeatLimitReachedEmailAsync(Organization organization, int maxSeatCount, IEnumerable ownerEmails) - { - return Task.FromResult(0); - } + public Task SendOrganizationMaxSeatLimitReachedEmailAsync(Organization organization, int maxSeatCount, IEnumerable ownerEmails) + { + return Task.FromResult(0); + } - public Task SendOrganizationAutoscaledEmailAsync(Organization organization, int initialSeatCount, IEnumerable ownerEmails) - { - return Task.FromResult(0); - } + public Task SendOrganizationAutoscaledEmailAsync(Organization organization, int initialSeatCount, IEnumerable ownerEmails) + { + return Task.FromResult(0); + } - public Task SendOrganizationAcceptedEmailAsync(Organization organization, string userIdentifier, IEnumerable adminEmails) - { - return Task.FromResult(0); - } + public Task SendOrganizationAcceptedEmailAsync(Organization organization, string userIdentifier, IEnumerable adminEmails) + { + return Task.FromResult(0); + } - public Task SendOrganizationConfirmedEmailAsync(string organizationName, string email) - { - return Task.FromResult(0); - } + public Task SendOrganizationConfirmedEmailAsync(string organizationName, string email) + { + return Task.FromResult(0); + } - public Task SendOrganizationInviteEmailAsync(string organizationName, OrganizationUser orgUser, ExpiringToken token) - { - return Task.FromResult(0); - } + public Task SendOrganizationInviteEmailAsync(string organizationName, OrganizationUser orgUser, ExpiringToken token) + { + return Task.FromResult(0); + } - public Task BulkSendOrganizationInviteEmailAsync(string organizationName, IEnumerable<(OrganizationUser orgUser, ExpiringToken token)> invites) - { - return Task.FromResult(0); - } + public Task BulkSendOrganizationInviteEmailAsync(string organizationName, IEnumerable<(OrganizationUser orgUser, ExpiringToken token)> invites) + { + return Task.FromResult(0); + } - public Task SendOrganizationUserRemovedForPolicyTwoStepEmailAsync(string organizationName, string email) - { - return Task.FromResult(0); - } + public Task SendOrganizationUserRemovedForPolicyTwoStepEmailAsync(string organizationName, string email) + { + return Task.FromResult(0); + } - public Task SendTwoFactorEmailAsync(string email, string token) - { - return Task.FromResult(0); - } + public Task SendTwoFactorEmailAsync(string email, string token) + { + return Task.FromResult(0); + } - public Task SendNewDeviceLoginTwoFactorEmailAsync(string email, string token) - { - return Task.CompletedTask; - } + public Task SendNewDeviceLoginTwoFactorEmailAsync(string email, string token) + { + return Task.CompletedTask; + } - public Task SendWelcomeEmailAsync(User user) - { - return Task.FromResult(0); - } + public Task SendWelcomeEmailAsync(User user) + { + return Task.FromResult(0); + } - public Task SendVerifyDeleteEmailAsync(string email, Guid userId, string token) - { - return Task.FromResult(0); - } + public Task SendVerifyDeleteEmailAsync(string email, Guid userId, string token) + { + return Task.FromResult(0); + } - public Task SendPasswordlessSignInAsync(string returnUrl, string token, string email) - { - return Task.FromResult(0); - } + public Task SendPasswordlessSignInAsync(string returnUrl, string token, string email) + { + return Task.FromResult(0); + } - public Task SendInvoiceUpcomingAsync(string email, decimal amount, DateTime dueDate, - List items, bool mentionInvoices) - { - return Task.FromResult(0); - } + public Task SendInvoiceUpcomingAsync(string email, decimal amount, DateTime dueDate, + List items, bool mentionInvoices) + { + return Task.FromResult(0); + } - public Task SendPaymentFailedAsync(string email, decimal amount, bool mentionInvoices) - { - return Task.FromResult(0); - } + public Task SendPaymentFailedAsync(string email, decimal amount, bool mentionInvoices) + { + return Task.FromResult(0); + } - public Task SendAddedCreditAsync(string email, decimal amount) - { - return Task.FromResult(0); - } + public Task SendAddedCreditAsync(string email, decimal amount) + { + return Task.FromResult(0); + } - public Task SendLicenseExpiredAsync(IEnumerable emails, string organizationName = null) - { - return Task.FromResult(0); - } + public Task SendLicenseExpiredAsync(IEnumerable emails, string organizationName = null) + { + return Task.FromResult(0); + } - public Task SendNewDeviceLoggedInEmail(string email, string deviceType, DateTime timestamp, string ip) - { - return Task.FromResult(0); - } + public Task SendNewDeviceLoggedInEmail(string email, string deviceType, DateTime timestamp, string ip) + { + return Task.FromResult(0); + } - public Task SendRecoverTwoFactorEmail(string email, DateTime timestamp, string ip) - { - return Task.FromResult(0); - } + public Task SendRecoverTwoFactorEmail(string email, DateTime timestamp, string ip) + { + return Task.FromResult(0); + } - public Task SendOrganizationUserRemovedForPolicySingleOrgEmailAsync(string organizationName, string email) - { - return Task.FromResult(0); - } + public Task SendOrganizationUserRemovedForPolicySingleOrgEmailAsync(string organizationName, string email) + { + return Task.FromResult(0); + } - public Task SendEmergencyAccessInviteEmailAsync(EmergencyAccess emergencyAccess, string name, string token) - { - return Task.FromResult(0); - } + public Task SendEmergencyAccessInviteEmailAsync(EmergencyAccess emergencyAccess, string name, string token) + { + return Task.FromResult(0); + } - public Task SendEmergencyAccessAcceptedEmailAsync(string granteeEmail, string email) - { - return Task.FromResult(0); - } + public Task SendEmergencyAccessAcceptedEmailAsync(string granteeEmail, string email) + { + return Task.FromResult(0); + } - public Task SendEmergencyAccessConfirmedEmailAsync(string grantorName, string email) - { - return Task.FromResult(0); - } + public Task SendEmergencyAccessConfirmedEmailAsync(string grantorName, string email) + { + return Task.FromResult(0); + } - public Task SendEmergencyAccessRecoveryInitiated(EmergencyAccess emergencyAccess, string initiatingName, string email) - { - return Task.FromResult(0); - } + public Task SendEmergencyAccessRecoveryInitiated(EmergencyAccess emergencyAccess, string initiatingName, string email) + { + return Task.FromResult(0); + } - public Task SendEmergencyAccessRecoveryApproved(EmergencyAccess emergencyAccess, string approvingName, string email) - { - return Task.FromResult(0); - } + public Task SendEmergencyAccessRecoveryApproved(EmergencyAccess emergencyAccess, string approvingName, string email) + { + return Task.FromResult(0); + } - public Task SendEmergencyAccessRecoveryRejected(EmergencyAccess emergencyAccess, string rejectingName, string email) - { - return Task.FromResult(0); - } + public Task SendEmergencyAccessRecoveryRejected(EmergencyAccess emergencyAccess, string rejectingName, string email) + { + return Task.FromResult(0); + } - public Task SendEmergencyAccessRecoveryReminder(EmergencyAccess emergencyAccess, string initiatingName, string email) - { - return Task.FromResult(0); - } + public Task SendEmergencyAccessRecoveryReminder(EmergencyAccess emergencyAccess, string initiatingName, string email) + { + return Task.FromResult(0); + } - public Task SendEmergencyAccessRecoveryTimedOut(EmergencyAccess ea, string initiatingName, string email) - { - return Task.FromResult(0); - } + public Task SendEmergencyAccessRecoveryTimedOut(EmergencyAccess ea, string initiatingName, string email) + { + return Task.FromResult(0); + } - public Task SendEnqueuedMailMessageAsync(IMailQueueMessage queueMessage) - { - return Task.FromResult(0); - } + public Task SendEnqueuedMailMessageAsync(IMailQueueMessage queueMessage) + { + return Task.FromResult(0); + } - public Task SendAdminResetPasswordEmailAsync(string email, string userName, string orgName) - { - return Task.FromResult(0); - } + public Task SendAdminResetPasswordEmailAsync(string email, string userName, string orgName) + { + return Task.FromResult(0); + } - public Task SendProviderSetupInviteEmailAsync(Provider provider, string token, string email) - { - return Task.FromResult(0); - } + public Task SendProviderSetupInviteEmailAsync(Provider provider, string token, string email) + { + return Task.FromResult(0); + } - public Task SendProviderInviteEmailAsync(string providerName, ProviderUser providerUser, string token, string email) - { - return Task.FromResult(0); - } + public Task SendProviderInviteEmailAsync(string providerName, ProviderUser providerUser, string token, string email) + { + return Task.FromResult(0); + } - public Task SendProviderConfirmedEmailAsync(string providerName, string email) - { - return Task.FromResult(0); - } + public Task SendProviderConfirmedEmailAsync(string providerName, string email) + { + return Task.FromResult(0); + } - public Task SendProviderUserRemoved(string providerName, string email) - { - return Task.FromResult(0); - } + public Task SendProviderUserRemoved(string providerName, string email) + { + return Task.FromResult(0); + } - public Task SendUpdatedTempPasswordEmailAsync(string email, string userName) - { - return Task.FromResult(0); - } + public Task SendUpdatedTempPasswordEmailAsync(string email, string userName) + { + return Task.FromResult(0); + } - public Task SendFamiliesForEnterpriseOfferEmailAsync(string SponsorOrgName, string email, bool existingAccount, string token) - { - return Task.FromResult(0); - } + public Task SendFamiliesForEnterpriseOfferEmailAsync(string SponsorOrgName, string email, bool existingAccount, string token) + { + return Task.FromResult(0); + } - public Task BulkSendFamiliesForEnterpriseOfferEmailAsync(string SponsorOrgName, IEnumerable<(string Email, bool ExistingAccount, string Token)> invites) - { - return Task.FromResult(0); - } + public Task BulkSendFamiliesForEnterpriseOfferEmailAsync(string SponsorOrgName, IEnumerable<(string Email, bool ExistingAccount, string Token)> invites) + { + return Task.FromResult(0); + } - public Task SendFamiliesForEnterpriseRedeemedEmailsAsync(string familyUserEmail, string sponsorEmail) - { - return Task.FromResult(0); - } + public Task SendFamiliesForEnterpriseRedeemedEmailsAsync(string familyUserEmail, string sponsorEmail) + { + return Task.FromResult(0); + } - public Task SendFamiliesForEnterpriseSponsorshipRevertingEmailAsync(string email, DateTime expirationDate) - { - return Task.FromResult(0); - } + public Task SendFamiliesForEnterpriseSponsorshipRevertingEmailAsync(string email, DateTime expirationDate) + { + return Task.FromResult(0); + } - public Task SendOTPEmailAsync(string email, string token) - { - return Task.FromResult(0); - } + public Task SendOTPEmailAsync(string email, string token) + { + return Task.FromResult(0); + } - public Task SendFailedLoginAttemptsEmailAsync(string email, DateTime utcNow, string ip) - { - return Task.FromResult(0); - } + public Task SendFailedLoginAttemptsEmailAsync(string email, DateTime utcNow, string ip) + { + return Task.FromResult(0); + } - public Task SendFailedTwoFactorAttemptsEmailAsync(string email, DateTime utcNow, string ip) - { - return Task.FromResult(0); + public Task SendFailedTwoFactorAttemptsEmailAsync(string email, DateTime utcNow, string ip) + { + return Task.FromResult(0); + } } } diff --git a/src/Core/Services/NoopImplementations/NoopProviderService.cs b/src/Core/Services/NoopImplementations/NoopProviderService.cs index 478c5c6c10..efa5741449 100644 --- a/src/Core/Services/NoopImplementations/NoopProviderService.cs +++ b/src/Core/Services/NoopImplementations/NoopProviderService.cs @@ -3,35 +3,36 @@ using Bit.Core.Entities.Provider; using Bit.Core.Models.Business; using Bit.Core.Models.Business.Provider; -namespace Bit.Core.Services; - -public class NoopProviderService : IProviderService +namespace Bit.Core.Services { - public Task CreateAsync(string ownerEmail) => throw new NotImplementedException(); + public class NoopProviderService : IProviderService + { + public Task CreateAsync(string ownerEmail) => throw new NotImplementedException(); - public Task CompleteSetupAsync(Provider provider, Guid ownerUserId, string token, string key) => throw new NotImplementedException(); + public Task CompleteSetupAsync(Provider provider, Guid ownerUserId, string token, string key) => throw new NotImplementedException(); - public Task UpdateAsync(Provider provider, bool updateBilling = false) => throw new NotImplementedException(); + public Task UpdateAsync(Provider provider, bool updateBilling = false) => throw new NotImplementedException(); - public Task> InviteUserAsync(ProviderUserInvite invite) => throw new NotImplementedException(); + public Task> InviteUserAsync(ProviderUserInvite invite) => throw new NotImplementedException(); - public Task>> ResendInvitesAsync(ProviderUserInvite invite) => throw new NotImplementedException(); + public Task>> ResendInvitesAsync(ProviderUserInvite invite) => throw new NotImplementedException(); - public Task AcceptUserAsync(Guid providerUserId, User user, string token) => throw new NotImplementedException(); + public Task AcceptUserAsync(Guid providerUserId, User user, string token) => throw new NotImplementedException(); - public Task>> ConfirmUsersAsync(Guid providerId, Dictionary keys, Guid confirmingUserId) => throw new NotImplementedException(); + public Task>> ConfirmUsersAsync(Guid providerId, Dictionary keys, Guid confirmingUserId) => throw new NotImplementedException(); - public Task SaveUserAsync(ProviderUser user, Guid savingUserId) => throw new NotImplementedException(); + public Task SaveUserAsync(ProviderUser user, Guid savingUserId) => throw new NotImplementedException(); - public Task>> DeleteUsersAsync(Guid providerId, IEnumerable providerUserIds, Guid deletingUserId) => throw new NotImplementedException(); + public Task>> DeleteUsersAsync(Guid providerId, IEnumerable providerUserIds, Guid deletingUserId) => throw new NotImplementedException(); - public Task AddOrganization(Guid providerId, Guid organizationId, Guid addingUserId, string key) => throw new NotImplementedException(); + public Task AddOrganization(Guid providerId, Guid organizationId, Guid addingUserId, string key) => throw new NotImplementedException(); - public Task CreateOrganizationAsync(Guid providerId, OrganizationSignup organizationSignup, string clientOwnerEmail, User user) => throw new NotImplementedException(); + public Task CreateOrganizationAsync(Guid providerId, OrganizationSignup organizationSignup, string clientOwnerEmail, User user) => throw new NotImplementedException(); - public Task RemoveOrganizationAsync(Guid providerId, Guid providerOrganizationId, Guid removingUserId) => throw new NotImplementedException(); + public Task RemoveOrganizationAsync(Guid providerId, Guid providerOrganizationId, Guid removingUserId) => throw new NotImplementedException(); - public Task LogProviderAccessToOrganizationAsync(Guid organizationId) => throw new NotImplementedException(); + public Task LogProviderAccessToOrganizationAsync(Guid organizationId) => throw new NotImplementedException(); - public Task ResendProviderSetupInviteEmailAsync(Guid providerId, Guid userId) => throw new NotImplementedException(); + public Task ResendProviderSetupInviteEmailAsync(Guid providerId, Guid userId) => throw new NotImplementedException(); + } } diff --git a/src/Core/Services/NoopImplementations/NoopPushNotificationService.cs b/src/Core/Services/NoopImplementations/NoopPushNotificationService.cs index ee2c6a498b..8d9f1117eb 100644 --- a/src/Core/Services/NoopImplementations/NoopPushNotificationService.cs +++ b/src/Core/Services/NoopImplementations/NoopPushNotificationService.cs @@ -1,89 +1,90 @@ using Bit.Core.Entities; using Bit.Core.Enums; -namespace Bit.Core.Services; - -public class NoopPushNotificationService : IPushNotificationService +namespace Bit.Core.Services { - public Task PushSyncCipherCreateAsync(Cipher cipher, IEnumerable collectionIds) + public class NoopPushNotificationService : IPushNotificationService { - return Task.FromResult(0); - } + public Task PushSyncCipherCreateAsync(Cipher cipher, IEnumerable collectionIds) + { + return Task.FromResult(0); + } - public Task PushSyncCipherDeleteAsync(Cipher cipher) - { - return Task.FromResult(0); - } + public Task PushSyncCipherDeleteAsync(Cipher cipher) + { + return Task.FromResult(0); + } - public Task PushSyncCiphersAsync(Guid userId) - { - return Task.FromResult(0); - } + public Task PushSyncCiphersAsync(Guid userId) + { + return Task.FromResult(0); + } - public Task PushSyncCipherUpdateAsync(Cipher cipher, IEnumerable collectionIds) - { - return Task.FromResult(0); - } + public Task PushSyncCipherUpdateAsync(Cipher cipher, IEnumerable collectionIds) + { + return Task.FromResult(0); + } - public Task PushSyncFolderCreateAsync(Folder folder) - { - return Task.FromResult(0); - } + public Task PushSyncFolderCreateAsync(Folder folder) + { + return Task.FromResult(0); + } - public Task PushSyncFolderDeleteAsync(Folder folder) - { - return Task.FromResult(0); - } + public Task PushSyncFolderDeleteAsync(Folder folder) + { + return Task.FromResult(0); + } - public Task PushSyncFolderUpdateAsync(Folder folder) - { - return Task.FromResult(0); - } + public Task PushSyncFolderUpdateAsync(Folder folder) + { + return Task.FromResult(0); + } - public Task PushSyncOrgKeysAsync(Guid userId) - { - return Task.FromResult(0); - } + public Task PushSyncOrgKeysAsync(Guid userId) + { + return Task.FromResult(0); + } - public Task PushSyncSettingsAsync(Guid userId) - { - return Task.FromResult(0); - } + public Task PushSyncSettingsAsync(Guid userId) + { + return Task.FromResult(0); + } - public Task PushSyncVaultAsync(Guid userId) - { - return Task.FromResult(0); - } + public Task PushSyncVaultAsync(Guid userId) + { + return Task.FromResult(0); + } - public Task PushLogOutAsync(Guid userId) - { - return Task.FromResult(0); - } + public Task PushLogOutAsync(Guid userId) + { + return Task.FromResult(0); + } - public Task PushSyncSendCreateAsync(Send send) - { - return Task.FromResult(0); - } + public Task PushSyncSendCreateAsync(Send send) + { + return Task.FromResult(0); + } - public Task PushSyncSendDeleteAsync(Send send) - { - return Task.FromResult(0); - } + public Task PushSyncSendDeleteAsync(Send send) + { + return Task.FromResult(0); + } - public Task PushSyncSendUpdateAsync(Send send) - { - return Task.FromResult(0); - } + public Task PushSyncSendUpdateAsync(Send send) + { + return Task.FromResult(0); + } - public Task SendPayloadToOrganizationAsync(string orgId, PushType type, object payload, string identifier, - string deviceId = null) - { - return Task.FromResult(0); - } + public Task SendPayloadToOrganizationAsync(string orgId, PushType type, object payload, string identifier, + string deviceId = null) + { + return Task.FromResult(0); + } - public Task SendPayloadToUserAsync(string userId, PushType type, object payload, string identifier, - string deviceId = null) - { - return Task.FromResult(0); + public Task SendPayloadToUserAsync(string userId, PushType type, object payload, string identifier, + string deviceId = null) + { + return Task.FromResult(0); + } } } diff --git a/src/Core/Services/NoopImplementations/NoopPushRegistrationService.cs b/src/Core/Services/NoopImplementations/NoopPushRegistrationService.cs index f6279c9467..c574314e09 100644 --- a/src/Core/Services/NoopImplementations/NoopPushRegistrationService.cs +++ b/src/Core/Services/NoopImplementations/NoopPushRegistrationService.cs @@ -1,27 +1,28 @@ using Bit.Core.Enums; -namespace Bit.Core.Services; - -public class NoopPushRegistrationService : IPushRegistrationService +namespace Bit.Core.Services { - public Task AddUserRegistrationOrganizationAsync(IEnumerable deviceIds, string organizationId) + public class NoopPushRegistrationService : IPushRegistrationService { - return Task.FromResult(0); - } + public Task AddUserRegistrationOrganizationAsync(IEnumerable deviceIds, string organizationId) + { + return Task.FromResult(0); + } - public Task CreateOrUpdateRegistrationAsync(string pushToken, string deviceId, string userId, - string identifier, DeviceType type) - { - return Task.FromResult(0); - } + public Task CreateOrUpdateRegistrationAsync(string pushToken, string deviceId, string userId, + string identifier, DeviceType type) + { + return Task.FromResult(0); + } - public Task DeleteRegistrationAsync(string deviceId) - { - return Task.FromResult(0); - } + public Task DeleteRegistrationAsync(string deviceId) + { + return Task.FromResult(0); + } - public Task DeleteUserRegistrationOrganizationAsync(IEnumerable deviceIds, string organizationId) - { - return Task.FromResult(0); + public Task DeleteUserRegistrationOrganizationAsync(IEnumerable deviceIds, string organizationId) + { + return Task.FromResult(0); + } } } diff --git a/src/Core/Services/NoopImplementations/NoopReferenceEventService.cs b/src/Core/Services/NoopImplementations/NoopReferenceEventService.cs index a32001e854..fa15ce7274 100644 --- a/src/Core/Services/NoopImplementations/NoopReferenceEventService.cs +++ b/src/Core/Services/NoopImplementations/NoopReferenceEventService.cs @@ -1,11 +1,12 @@ using Bit.Core.Models.Business; -namespace Bit.Core.Services; - -public class NoopReferenceEventService : IReferenceEventService +namespace Bit.Core.Services { - public Task RaiseEventAsync(ReferenceEvent referenceEvent) + public class NoopReferenceEventService : IReferenceEventService { - return Task.CompletedTask; + public Task RaiseEventAsync(ReferenceEvent referenceEvent) + { + return Task.CompletedTask; + } } } diff --git a/src/Core/Services/NoopImplementations/NoopSendFileStorageService.cs b/src/Core/Services/NoopImplementations/NoopSendFileStorageService.cs index 08602ef9fb..407e3976fd 100644 --- a/src/Core/Services/NoopImplementations/NoopSendFileStorageService.cs +++ b/src/Core/Services/NoopImplementations/NoopSendFileStorageService.cs @@ -1,44 +1,45 @@ using Bit.Core.Entities; using Bit.Core.Enums; -namespace Bit.Core.Services; - -public class NoopSendFileStorageService : ISendFileStorageService +namespace Bit.Core.Services { - public FileUploadType FileUploadType => FileUploadType.Direct; - - public Task UploadNewFileAsync(Stream stream, Send send, string attachmentId) + public class NoopSendFileStorageService : ISendFileStorageService { - return Task.FromResult(0); - } + public FileUploadType FileUploadType => FileUploadType.Direct; - public Task DeleteFileAsync(Send send, string fileId) - { - return Task.FromResult(0); - } + public Task UploadNewFileAsync(Stream stream, Send send, string attachmentId) + { + return Task.FromResult(0); + } - public Task DeleteFilesForOrganizationAsync(Guid organizationId) - { - return Task.FromResult(0); - } + public Task DeleteFileAsync(Send send, string fileId) + { + return Task.FromResult(0); + } - public Task DeleteFilesForUserAsync(Guid userId) - { - return Task.FromResult(0); - } + public Task DeleteFilesForOrganizationAsync(Guid organizationId) + { + return Task.FromResult(0); + } - public Task GetSendFileDownloadUrlAsync(Send send, string fileId) - { - return Task.FromResult((string)null); - } + public Task DeleteFilesForUserAsync(Guid userId) + { + return Task.FromResult(0); + } - public Task GetSendFileUploadUrlAsync(Send send, string fileId) - { - return Task.FromResult((string)null); - } + public Task GetSendFileDownloadUrlAsync(Send send, string fileId) + { + return Task.FromResult((string)null); + } - public Task<(bool, long?)> ValidateFileAsync(Send send, string fileId, long expectedFileSize, long leeway) - { - return Task.FromResult((false, default(long?))); + public Task GetSendFileUploadUrlAsync(Send send, string fileId) + { + return Task.FromResult((string)null); + } + + public Task<(bool, long?)> ValidateFileAsync(Send send, string fileId, long expectedFileSize, long leeway) + { + return Task.FromResult((false, default(long?))); + } } } diff --git a/src/Core/Settings/GlobalSettings.cs b/src/Core/Settings/GlobalSettings.cs index bd4087f3a9..f0bdca4efe 100644 --- a/src/Core/Settings/GlobalSettings.cs +++ b/src/Core/Settings/GlobalSettings.cs @@ -1,501 +1,502 @@ -namespace Bit.Core.Settings; - -public class GlobalSettings : IGlobalSettings +namespace Bit.Core.Settings { - private string _logDirectory; - private string _licenseDirectory; - - public GlobalSettings() + public class GlobalSettings : IGlobalSettings { - BaseServiceUri = new BaseServiceUriSettings(this); - Attachment = new FileStorageSettings(this, "attachments", "attachments"); - Send = new FileStorageSettings(this, "attachments/send", "attachments/send"); - DataProtection = new DataProtectionSettings(this); - } + private string _logDirectory; + private string _licenseDirectory; - public bool SelfHosted { get; set; } - public virtual string KnownProxies { get; set; } - public virtual string SiteName { get; set; } - public virtual string ProjectName { get; set; } - public virtual string LogDirectory - { - get => BuildDirectory(_logDirectory, "/logs"); - set => _logDirectory = value; - } - public virtual long? LogRollBySizeLimit { get; set; } - public virtual string LicenseDirectory - { - get => BuildDirectory(_licenseDirectory, "/core/licenses"); - set => _licenseDirectory = value; - } - public string LicenseCertificatePassword { get; set; } - public virtual string PushRelayBaseUri { get; set; } - public virtual string InternalIdentityKey { get; set; } - public virtual string OidcIdentityClientKey { get; set; } - public virtual string HibpApiKey { get; set; } - public virtual bool DisableUserRegistration { get; set; } - public virtual bool DisableEmailNewDevice { get; set; } - public virtual bool EnableCloudCommunication { get; set; } = false; - public virtual int OrganizationInviteExpirationHours { get; set; } = 120; // 5 days - public virtual string EventGridKey { get; set; } - public virtual CaptchaSettings Captcha { get; set; } = new CaptchaSettings(); - public virtual IInstallationSettings Installation { get; set; } = new InstallationSettings(); - public virtual IBaseServiceUriSettings BaseServiceUri { get; set; } - public virtual string DatabaseProvider { get; set; } - public virtual SqlSettings SqlServer { get; set; } = new SqlSettings(); - public virtual SqlSettings PostgreSql { get; set; } = new SqlSettings(); - public virtual SqlSettings MySql { get; set; } = new SqlSettings(); - public virtual SqlSettings Sqlite { get; set; } = new SqlSettings(); - public virtual MailSettings Mail { get; set; } = new MailSettings(); - public virtual IConnectionStringSettings Storage { get; set; } = new ConnectionStringSettings(); - public virtual ConnectionStringSettings Events { get; set; } = new ConnectionStringSettings(); - public virtual IConnectionStringSettings Redis { get; set; } = new ConnectionStringSettings(); - public virtual NotificationsSettings Notifications { get; set; } = new NotificationsSettings(); - public virtual IFileStorageSettings Attachment { get; set; } - public virtual FileStorageSettings Send { get; set; } - public virtual IdentityServerSettings IdentityServer { get; set; } = new IdentityServerSettings(); - public virtual DataProtectionSettings DataProtection { get; set; } - public virtual DocumentDbSettings DocumentDb { get; set; } = new DocumentDbSettings(); - public virtual SentrySettings Sentry { get; set; } = new SentrySettings(); - public virtual SyslogSettings Syslog { get; set; } = new SyslogSettings(); - public virtual NotificationHubSettings NotificationHub { get; set; } = new NotificationHubSettings(); - public virtual YubicoSettings Yubico { get; set; } = new YubicoSettings(); - public virtual DuoSettings Duo { get; set; } = new DuoSettings(); - public virtual BraintreeSettings Braintree { get; set; } = new BraintreeSettings(); - public virtual BitPaySettings BitPay { get; set; } = new BitPaySettings(); - public virtual AmazonSettings Amazon { get; set; } = new AmazonSettings(); - public virtual ServiceBusSettings ServiceBus { get; set; } = new ServiceBusSettings(); - public virtual AppleIapSettings AppleIap { get; set; } = new AppleIapSettings(); - public virtual ISsoSettings Sso { get; set; } = new SsoSettings(); - public virtual StripeSettings Stripe { get; set; } = new StripeSettings(); - public virtual ITwoFactorAuthSettings TwoFactorAuth { get; set; } = new TwoFactorAuthSettings(); - - public string BuildExternalUri(string explicitValue, string name) - { - if (!string.IsNullOrWhiteSpace(explicitValue)) + public GlobalSettings() { - return explicitValue; - } - if (!SelfHosted) - { - return null; - } - return string.Format("{0}/{1}", BaseServiceUri.Vault, name); - } - - public string BuildInternalUri(string explicitValue, string name) - { - if (!string.IsNullOrWhiteSpace(explicitValue)) - { - return explicitValue; - } - if (!SelfHosted) - { - return null; - } - return string.Format("http://{0}:5000", name); - } - - public string BuildDirectory(string explicitValue, string appendedPath) - { - if (!string.IsNullOrWhiteSpace(explicitValue)) - { - return explicitValue; - } - if (!SelfHosted) - { - return null; - } - return string.Concat("/etc/bitwarden", appendedPath); - } - - public class BaseServiceUriSettings : IBaseServiceUriSettings - { - private readonly GlobalSettings _globalSettings; - - private string _api; - private string _identity; - private string _admin; - private string _notifications; - private string _sso; - private string _scim; - private string _internalApi; - private string _internalIdentity; - private string _internalAdmin; - private string _internalNotifications; - private string _internalSso; - private string _internalVault; - private string _internalScim; - - public BaseServiceUriSettings(GlobalSettings globalSettings) - { - _globalSettings = globalSettings; + BaseServiceUri = new BaseServiceUriSettings(this); + Attachment = new FileStorageSettings(this, "attachments", "attachments"); + Send = new FileStorageSettings(this, "attachments/send", "attachments/send"); + DataProtection = new DataProtectionSettings(this); } - public string Vault { get; set; } - public string VaultWithHash => $"{Vault}/#"; + public bool SelfHosted { get; set; } + public virtual string KnownProxies { get; set; } + public virtual string SiteName { get; set; } + public virtual string ProjectName { get; set; } + public virtual string LogDirectory + { + get => BuildDirectory(_logDirectory, "/logs"); + set => _logDirectory = value; + } + public virtual long? LogRollBySizeLimit { get; set; } + public virtual string LicenseDirectory + { + get => BuildDirectory(_licenseDirectory, "/core/licenses"); + set => _licenseDirectory = value; + } + public string LicenseCertificatePassword { get; set; } + public virtual string PushRelayBaseUri { get; set; } + public virtual string InternalIdentityKey { get; set; } + public virtual string OidcIdentityClientKey { get; set; } + public virtual string HibpApiKey { get; set; } + public virtual bool DisableUserRegistration { get; set; } + public virtual bool DisableEmailNewDevice { get; set; } + public virtual bool EnableCloudCommunication { get; set; } = false; + public virtual int OrganizationInviteExpirationHours { get; set; } = 120; // 5 days + public virtual string EventGridKey { get; set; } + public virtual CaptchaSettings Captcha { get; set; } = new CaptchaSettings(); + public virtual IInstallationSettings Installation { get; set; } = new InstallationSettings(); + public virtual IBaseServiceUriSettings BaseServiceUri { get; set; } + public virtual string DatabaseProvider { get; set; } + public virtual SqlSettings SqlServer { get; set; } = new SqlSettings(); + public virtual SqlSettings PostgreSql { get; set; } = new SqlSettings(); + public virtual SqlSettings MySql { get; set; } = new SqlSettings(); + public virtual SqlSettings Sqlite { get; set; } = new SqlSettings(); + public virtual MailSettings Mail { get; set; } = new MailSettings(); + public virtual IConnectionStringSettings Storage { get; set; } = new ConnectionStringSettings(); + public virtual ConnectionStringSettings Events { get; set; } = new ConnectionStringSettings(); + public virtual IConnectionStringSettings Redis { get; set; } = new ConnectionStringSettings(); + public virtual NotificationsSettings Notifications { get; set; } = new NotificationsSettings(); + public virtual IFileStorageSettings Attachment { get; set; } + public virtual FileStorageSettings Send { get; set; } + public virtual IdentityServerSettings IdentityServer { get; set; } = new IdentityServerSettings(); + public virtual DataProtectionSettings DataProtection { get; set; } + public virtual DocumentDbSettings DocumentDb { get; set; } = new DocumentDbSettings(); + public virtual SentrySettings Sentry { get; set; } = new SentrySettings(); + public virtual SyslogSettings Syslog { get; set; } = new SyslogSettings(); + public virtual NotificationHubSettings NotificationHub { get; set; } = new NotificationHubSettings(); + public virtual YubicoSettings Yubico { get; set; } = new YubicoSettings(); + public virtual DuoSettings Duo { get; set; } = new DuoSettings(); + public virtual BraintreeSettings Braintree { get; set; } = new BraintreeSettings(); + public virtual BitPaySettings BitPay { get; set; } = new BitPaySettings(); + public virtual AmazonSettings Amazon { get; set; } = new AmazonSettings(); + public virtual ServiceBusSettings ServiceBus { get; set; } = new ServiceBusSettings(); + public virtual AppleIapSettings AppleIap { get; set; } = new AppleIapSettings(); + public virtual ISsoSettings Sso { get; set; } = new SsoSettings(); + public virtual StripeSettings Stripe { get; set; } = new StripeSettings(); + public virtual ITwoFactorAuthSettings TwoFactorAuth { get; set; } = new TwoFactorAuthSettings(); - public string Api + public string BuildExternalUri(string explicitValue, string name) { - get => _globalSettings.BuildExternalUri(_api, "api"); - set => _api = value; - } - public string Identity - { - get => _globalSettings.BuildExternalUri(_identity, "identity"); - set => _identity = value; - } - public string Admin - { - get => _globalSettings.BuildExternalUri(_admin, "admin"); - set => _admin = value; - } - public string Notifications - { - get => _globalSettings.BuildExternalUri(_notifications, "notifications"); - set => _notifications = value; - } - public string Sso - { - get => _globalSettings.BuildExternalUri(_sso, "sso"); - set => _sso = value; - } - public string Scim - { - get => _globalSettings.BuildExternalUri(_scim, "scim"); - set => _scim = value; - } - - public string InternalNotifications - { - get => _globalSettings.BuildInternalUri(_internalNotifications, "notifications"); - set => _internalNotifications = value; - } - public string InternalAdmin - { - get => _globalSettings.BuildInternalUri(_internalAdmin, "admin"); - set => _internalAdmin = value; - } - public string InternalIdentity - { - get => _globalSettings.BuildInternalUri(_internalIdentity, "identity"); - set => _internalIdentity = value; - } - public string InternalApi - { - get => _globalSettings.BuildInternalUri(_internalApi, "api"); - set => _internalApi = value; - } - public string InternalVault - { - get => _globalSettings.BuildInternalUri(_internalVault, "web"); - set => _internalVault = value; - } - public string InternalSso - { - get => _globalSettings.BuildInternalUri(_internalSso, "sso"); - set => _internalSso = value; - } - public string InternalScim - { - get => _globalSettings.BuildInternalUri(_scim, "scim"); - set => _internalScim = value; - } - } - - public class SqlSettings - { - private string _connectionString; - private string _readOnlyConnectionString; - private string _jobSchedulerConnectionString; - - public string ConnectionString - { - get => _connectionString; - set => _connectionString = value.Trim('"'); - } - - public string ReadOnlyConnectionString - { - get => string.IsNullOrWhiteSpace(_readOnlyConnectionString) ? - _connectionString : _readOnlyConnectionString; - set => _readOnlyConnectionString = value.Trim('"'); - } - - public string JobSchedulerConnectionString - { - get => _jobSchedulerConnectionString; - set => _jobSchedulerConnectionString = value.Trim('"'); - } - } - - public class ConnectionStringSettings : IConnectionStringSettings - { - private string _connectionString; - - public string ConnectionString - { - get => _connectionString; - set => _connectionString = value.Trim('"'); - } - } - - public class FileStorageSettings : IFileStorageSettings - { - private readonly GlobalSettings _globalSettings; - private readonly string _urlName; - private readonly string _directoryName; - private string _connectionString; - private string _baseDirectory; - private string _baseUrl; - - public FileStorageSettings(GlobalSettings globalSettings, string urlName, string directoryName) - { - _globalSettings = globalSettings; - _urlName = urlName; - _directoryName = directoryName; - } - - public string ConnectionString - { - get => _connectionString; - set => _connectionString = value.Trim('"'); - } - - public string BaseDirectory - { - get => _globalSettings.BuildDirectory(_baseDirectory, string.Concat("/core/", _directoryName)); - set => _baseDirectory = value; - } - - public string BaseUrl - { - get => _globalSettings.BuildExternalUri(_baseUrl, _urlName); - set => _baseUrl = value; - } - } - - public class MailSettings - { - private ConnectionStringSettings _connectionStringSettings; - public string ConnectionString - { - get => _connectionStringSettings?.ConnectionString; - set + if (!string.IsNullOrWhiteSpace(explicitValue)) { - if (_connectionStringSettings == null) - { - _connectionStringSettings = new ConnectionStringSettings(); - } - _connectionStringSettings.ConnectionString = value; + return explicitValue; + } + if (!SelfHosted) + { + return null; + } + return string.Format("{0}/{1}", BaseServiceUri.Vault, name); + } + + public string BuildInternalUri(string explicitValue, string name) + { + if (!string.IsNullOrWhiteSpace(explicitValue)) + { + return explicitValue; + } + if (!SelfHosted) + { + return null; + } + return string.Format("http://{0}:5000", name); + } + + public string BuildDirectory(string explicitValue, string appendedPath) + { + if (!string.IsNullOrWhiteSpace(explicitValue)) + { + return explicitValue; + } + if (!SelfHosted) + { + return null; + } + return string.Concat("/etc/bitwarden", appendedPath); + } + + public class BaseServiceUriSettings : IBaseServiceUriSettings + { + private readonly GlobalSettings _globalSettings; + + private string _api; + private string _identity; + private string _admin; + private string _notifications; + private string _sso; + private string _scim; + private string _internalApi; + private string _internalIdentity; + private string _internalAdmin; + private string _internalNotifications; + private string _internalSso; + private string _internalVault; + private string _internalScim; + + public BaseServiceUriSettings(GlobalSettings globalSettings) + { + _globalSettings = globalSettings; + } + + public string Vault { get; set; } + public string VaultWithHash => $"{Vault}/#"; + + public string Api + { + get => _globalSettings.BuildExternalUri(_api, "api"); + set => _api = value; + } + public string Identity + { + get => _globalSettings.BuildExternalUri(_identity, "identity"); + set => _identity = value; + } + public string Admin + { + get => _globalSettings.BuildExternalUri(_admin, "admin"); + set => _admin = value; + } + public string Notifications + { + get => _globalSettings.BuildExternalUri(_notifications, "notifications"); + set => _notifications = value; + } + public string Sso + { + get => _globalSettings.BuildExternalUri(_sso, "sso"); + set => _sso = value; + } + public string Scim + { + get => _globalSettings.BuildExternalUri(_scim, "scim"); + set => _scim = value; + } + + public string InternalNotifications + { + get => _globalSettings.BuildInternalUri(_internalNotifications, "notifications"); + set => _internalNotifications = value; + } + public string InternalAdmin + { + get => _globalSettings.BuildInternalUri(_internalAdmin, "admin"); + set => _internalAdmin = value; + } + public string InternalIdentity + { + get => _globalSettings.BuildInternalUri(_internalIdentity, "identity"); + set => _internalIdentity = value; + } + public string InternalApi + { + get => _globalSettings.BuildInternalUri(_internalApi, "api"); + set => _internalApi = value; + } + public string InternalVault + { + get => _globalSettings.BuildInternalUri(_internalVault, "web"); + set => _internalVault = value; + } + public string InternalSso + { + get => _globalSettings.BuildInternalUri(_internalSso, "sso"); + set => _internalSso = value; + } + public string InternalScim + { + get => _globalSettings.BuildInternalUri(_scim, "scim"); + set => _internalScim = value; } } - public string ReplyToEmail { get; set; } - public string AmazonConfigSetName { get; set; } - public SmtpSettings Smtp { get; set; } = new SmtpSettings(); - public string SendGridApiKey { get; set; } - public int? SendGridPercentage { get; set; } - public class SmtpSettings + public class SqlSettings + { + private string _connectionString; + private string _readOnlyConnectionString; + private string _jobSchedulerConnectionString; + + public string ConnectionString + { + get => _connectionString; + set => _connectionString = value.Trim('"'); + } + + public string ReadOnlyConnectionString + { + get => string.IsNullOrWhiteSpace(_readOnlyConnectionString) ? + _connectionString : _readOnlyConnectionString; + set => _readOnlyConnectionString = value.Trim('"'); + } + + public string JobSchedulerConnectionString + { + get => _jobSchedulerConnectionString; + set => _jobSchedulerConnectionString = value.Trim('"'); + } + } + + public class ConnectionStringSettings : IConnectionStringSettings + { + private string _connectionString; + + public string ConnectionString + { + get => _connectionString; + set => _connectionString = value.Trim('"'); + } + } + + public class FileStorageSettings : IFileStorageSettings + { + private readonly GlobalSettings _globalSettings; + private readonly string _urlName; + private readonly string _directoryName; + private string _connectionString; + private string _baseDirectory; + private string _baseUrl; + + public FileStorageSettings(GlobalSettings globalSettings, string urlName, string directoryName) + { + _globalSettings = globalSettings; + _urlName = urlName; + _directoryName = directoryName; + } + + public string ConnectionString + { + get => _connectionString; + set => _connectionString = value.Trim('"'); + } + + public string BaseDirectory + { + get => _globalSettings.BuildDirectory(_baseDirectory, string.Concat("/core/", _directoryName)); + set => _baseDirectory = value; + } + + public string BaseUrl + { + get => _globalSettings.BuildExternalUri(_baseUrl, _urlName); + set => _baseUrl = value; + } + } + + public class MailSettings + { + private ConnectionStringSettings _connectionStringSettings; + public string ConnectionString + { + get => _connectionStringSettings?.ConnectionString; + set + { + if (_connectionStringSettings == null) + { + _connectionStringSettings = new ConnectionStringSettings(); + } + _connectionStringSettings.ConnectionString = value; + } + } + public string ReplyToEmail { get; set; } + public string AmazonConfigSetName { get; set; } + public SmtpSettings Smtp { get; set; } = new SmtpSettings(); + public string SendGridApiKey { get; set; } + public int? SendGridPercentage { get; set; } + + public class SmtpSettings + { + public string Host { get; set; } + public int Port { get; set; } = 25; + public bool StartTls { get; set; } = false; + public bool Ssl { get; set; } = false; + public bool SslOverride { get; set; } = false; + public string Username { get; set; } + public string Password { get; set; } + public bool TrustServer { get; set; } = false; + } + } + + public class IdentityServerSettings + { + public string CertificateThumbprint { get; set; } + public string CertificatePassword { get; set; } + public string RedisConnectionString { get; set; } + } + + public class DataProtectionSettings + { + private readonly GlobalSettings _globalSettings; + + private string _directory; + + public DataProtectionSettings(GlobalSettings globalSettings) + { + _globalSettings = globalSettings; + } + + public string CertificateThumbprint { get; set; } + public string CertificatePassword { get; set; } + public string Directory + { + get => _globalSettings.BuildDirectory(_directory, "/core/aspnet-dataprotection"); + set => _directory = value; + } + } + + public class DocumentDbSettings + { + public string Uri { get; set; } + public string Key { get; set; } + } + + public class SentrySettings + { + public string Dsn { get; set; } + } + + public class NotificationsSettings : ConnectionStringSettings + { + public string RedisConnectionString { get; set; } + } + + public class SyslogSettings + { + /// + /// The connection string used to connect to a remote syslog server over TCP or UDP, or to connect locally. + /// + /// + /// The connection string will be parsed using to extract the protocol, host name and port number. + /// + /// + /// Supported protocols are: + /// + /// UDP (use udp://) + /// TCP (use tcp://) + /// TLS over TCP (use tls://) + /// + /// + /// + /// + /// A remote server (logging.dev.example.com) is listening on UDP (port 514): + /// + /// udp://logging.dev.example.com:514. + /// + public string Destination { get; set; } + /// + /// The absolute path to a Certificate (DER or Base64 encoded with private key). + /// + /// + /// The certificate path and are passed into the . + /// The file format of the certificate may be binary encded (DER) or base64. If the private key is encrypted, provide the password in , + /// + public string CertificatePath { get; set; } + /// + /// The password for the encrypted private key in the certificate supplied in . + /// + /// + public string CertificatePassword { get; set; } + /// + /// The thumbprint of the certificate in the X.509 certificate store for personal certificates for the user account running Bitwarden. + /// + /// + public string CertificateThumbprint { get; set; } + } + + public class NotificationHubSettings + { + private string _connectionString; + + public string ConnectionString + { + get => _connectionString; + set => _connectionString = value.Trim('"'); + } + public string HubName { get; set; } + } + + public class YubicoSettings + { + public string ClientId { get; set; } + public string Key { get; set; } + public string[] ValidationUrls { get; set; } + } + + public class DuoSettings + { + public string AKey { get; set; } + } + + public class BraintreeSettings + { + public bool Production { get; set; } + public string MerchantId { get; set; } + public string PublicKey { get; set; } + public string PrivateKey { get; set; } + } + + public class BitPaySettings + { + public bool Production { get; set; } + public string Token { get; set; } + public string NotificationUrl { get; set; } + } + + public class InstallationSettings : IInstallationSettings + { + private string _identityUri; + private string _apiUri; + + public Guid Id { get; set; } + public string Key { get; set; } + public string IdentityUri + { + get => string.IsNullOrWhiteSpace(_identityUri) ? "https://identity.bitwarden.com" : _identityUri; + set => _identityUri = value; + } + public string ApiUri + { + get => string.IsNullOrWhiteSpace(_apiUri) ? "https://api.bitwarden.com" : _apiUri; + set => _apiUri = value; + } + } + + public class AmazonSettings + { + public string AccessKeyId { get; set; } + public string AccessKeySecret { get; set; } + public string Region { get; set; } + } + + public class ServiceBusSettings : ConnectionStringSettings + { + public string ApplicationCacheTopicName { get; set; } + public string ApplicationCacheSubscriptionName { get; set; } + } + + public class AppleIapSettings { - public string Host { get; set; } - public int Port { get; set; } = 25; - public bool StartTls { get; set; } = false; - public bool Ssl { get; set; } = false; - public bool SslOverride { get; set; } = false; - public string Username { get; set; } public string Password { get; set; } - public bool TrustServer { get; set; } = false; + public bool AppInReview { get; set; } } - } - public class IdentityServerSettings - { - public string CertificateThumbprint { get; set; } - public string CertificatePassword { get; set; } - public string RedisConnectionString { get; set; } - } - - public class DataProtectionSettings - { - private readonly GlobalSettings _globalSettings; - - private string _directory; - - public DataProtectionSettings(GlobalSettings globalSettings) + public class SsoSettings : ISsoSettings { - _globalSettings = globalSettings; + public int CacheLifetimeInSeconds { get; set; } = 60; + public double SsoTokenLifetimeInSeconds { get; set; } = 5; } - public string CertificateThumbprint { get; set; } - public string CertificatePassword { get; set; } - public string Directory + public class CaptchaSettings { - get => _globalSettings.BuildDirectory(_directory, "/core/aspnet-dataprotection"); - set => _directory = value; + public bool ForceCaptchaRequired { get; set; } = false; + public string HCaptchaSecretKey { get; set; } + public string HCaptchaSiteKey { get; set; } + public int MaximumFailedLoginAttempts { get; set; } + public double MaybeBotScoreThreshold { get; set; } = double.MaxValue; + public double IsBotScoreThreshold { get; set; } = double.MaxValue; } - } - public class DocumentDbSettings - { - public string Uri { get; set; } - public string Key { get; set; } - } - - public class SentrySettings - { - public string Dsn { get; set; } - } - - public class NotificationsSettings : ConnectionStringSettings - { - public string RedisConnectionString { get; set; } - } - - public class SyslogSettings - { - /// - /// The connection string used to connect to a remote syslog server over TCP or UDP, or to connect locally. - /// - /// - /// The connection string will be parsed using to extract the protocol, host name and port number. - /// - /// - /// Supported protocols are: - /// - /// UDP (use udp://) - /// TCP (use tcp://) - /// TLS over TCP (use tls://) - /// - /// - /// - /// - /// A remote server (logging.dev.example.com) is listening on UDP (port 514): - /// - /// udp://logging.dev.example.com:514. - /// - public string Destination { get; set; } - /// - /// The absolute path to a Certificate (DER or Base64 encoded with private key). - /// - /// - /// The certificate path and are passed into the . - /// The file format of the certificate may be binary encded (DER) or base64. If the private key is encrypted, provide the password in , - /// - public string CertificatePath { get; set; } - /// - /// The password for the encrypted private key in the certificate supplied in . - /// - /// - public string CertificatePassword { get; set; } - /// - /// The thumbprint of the certificate in the X.509 certificate store for personal certificates for the user account running Bitwarden. - /// - /// - public string CertificateThumbprint { get; set; } - } - - public class NotificationHubSettings - { - private string _connectionString; - - public string ConnectionString + public class StripeSettings { - get => _connectionString; - set => _connectionString = value.Trim('"'); + public string ApiKey { get; set; } + public int MaxNetworkRetries { get; set; } = 2; } - public string HubName { get; set; } - } - public class YubicoSettings - { - public string ClientId { get; set; } - public string Key { get; set; } - public string[] ValidationUrls { get; set; } - } - - public class DuoSettings - { - public string AKey { get; set; } - } - - public class BraintreeSettings - { - public bool Production { get; set; } - public string MerchantId { get; set; } - public string PublicKey { get; set; } - public string PrivateKey { get; set; } - } - - public class BitPaySettings - { - public bool Production { get; set; } - public string Token { get; set; } - public string NotificationUrl { get; set; } - } - - public class InstallationSettings : IInstallationSettings - { - private string _identityUri; - private string _apiUri; - - public Guid Id { get; set; } - public string Key { get; set; } - public string IdentityUri + public class TwoFactorAuthSettings : ITwoFactorAuthSettings { - get => string.IsNullOrWhiteSpace(_identityUri) ? "https://identity.bitwarden.com" : _identityUri; - set => _identityUri = value; + public bool EmailOnNewDeviceLogin { get; set; } = false; } - public string ApiUri - { - get => string.IsNullOrWhiteSpace(_apiUri) ? "https://api.bitwarden.com" : _apiUri; - set => _apiUri = value; - } - } - - public class AmazonSettings - { - public string AccessKeyId { get; set; } - public string AccessKeySecret { get; set; } - public string Region { get; set; } - } - - public class ServiceBusSettings : ConnectionStringSettings - { - public string ApplicationCacheTopicName { get; set; } - public string ApplicationCacheSubscriptionName { get; set; } - } - - public class AppleIapSettings - { - public string Password { get; set; } - public bool AppInReview { get; set; } - } - - public class SsoSettings : ISsoSettings - { - public int CacheLifetimeInSeconds { get; set; } = 60; - public double SsoTokenLifetimeInSeconds { get; set; } = 5; - } - - public class CaptchaSettings - { - public bool ForceCaptchaRequired { get; set; } = false; - public string HCaptchaSecretKey { get; set; } - public string HCaptchaSiteKey { get; set; } - public int MaximumFailedLoginAttempts { get; set; } - public double MaybeBotScoreThreshold { get; set; } = double.MaxValue; - public double IsBotScoreThreshold { get; set; } = double.MaxValue; - } - - public class StripeSettings - { - public string ApiKey { get; set; } - public int MaxNetworkRetries { get; set; } = 2; - } - - public class TwoFactorAuthSettings : ITwoFactorAuthSettings - { - public bool EmailOnNewDeviceLogin { get; set; } = false; } } diff --git a/src/Core/Settings/IBaseServiceUriSettings.cs b/src/Core/Settings/IBaseServiceUriSettings.cs index 0550ae3e67..0dfdaf0b93 100644 --- a/src/Core/Settings/IBaseServiceUriSettings.cs +++ b/src/Core/Settings/IBaseServiceUriSettings.cs @@ -1,21 +1,22 @@  -namespace Bit.Core.Settings; - -public interface IBaseServiceUriSettings +namespace Bit.Core.Settings { - string Vault { get; set; } - string VaultWithHash { get; } - string Api { get; set; } - public string Identity { get; set; } - public string Admin { get; set; } - public string Notifications { get; set; } - public string Sso { get; set; } - public string Scim { get; set; } - public string InternalNotifications { get; set; } - public string InternalAdmin { get; set; } - public string InternalIdentity { get; set; } - public string InternalApi { get; set; } - public string InternalVault { get; set; } - public string InternalSso { get; set; } - public string InternalScim { get; set; } + public interface IBaseServiceUriSettings + { + string Vault { get; set; } + string VaultWithHash { get; } + string Api { get; set; } + public string Identity { get; set; } + public string Admin { get; set; } + public string Notifications { get; set; } + public string Sso { get; set; } + public string Scim { get; set; } + public string InternalNotifications { get; set; } + public string InternalAdmin { get; set; } + public string InternalIdentity { get; set; } + public string InternalApi { get; set; } + public string InternalVault { get; set; } + public string InternalSso { get; set; } + public string InternalScim { get; set; } + } } diff --git a/src/Core/Settings/IConnectionStringSettings.cs b/src/Core/Settings/IConnectionStringSettings.cs index 5b67dc9ca9..aff2b06270 100644 --- a/src/Core/Settings/IConnectionStringSettings.cs +++ b/src/Core/Settings/IConnectionStringSettings.cs @@ -1,6 +1,8 @@ -namespace Bit.Core.Settings; - -public interface IConnectionStringSettings +namespace Bit.Core.Settings { - string ConnectionString { get; set; } + + public interface IConnectionStringSettings + { + string ConnectionString { get; set; } + } } diff --git a/src/Core/Settings/IFileStorageSettings.cs b/src/Core/Settings/IFileStorageSettings.cs index 44546042d9..45e44802da 100644 --- a/src/Core/Settings/IFileStorageSettings.cs +++ b/src/Core/Settings/IFileStorageSettings.cs @@ -1,8 +1,9 @@ -namespace Bit.Core.Settings; - -public interface IFileStorageSettings +namespace Bit.Core.Settings { - string ConnectionString { get; set; } - string BaseDirectory { get; set; } - string BaseUrl { get; set; } + public interface IFileStorageSettings + { + string ConnectionString { get; set; } + string BaseDirectory { get; set; } + string BaseUrl { get; set; } + } } diff --git a/src/Core/Settings/IGlobalSettings.cs b/src/Core/Settings/IGlobalSettings.cs index 1929da1f3d..ec648384e1 100644 --- a/src/Core/Settings/IGlobalSettings.cs +++ b/src/Core/Settings/IGlobalSettings.cs @@ -1,18 +1,19 @@ -namespace Bit.Core.Settings; - -public interface IGlobalSettings +namespace Bit.Core.Settings { - // This interface exists for testing. Add settings here as needed for testing - bool SelfHosted { get; set; } - bool EnableCloudCommunication { get; set; } - string LicenseDirectory { get; set; } - string LicenseCertificatePassword { get; set; } - int OrganizationInviteExpirationHours { get; set; } - bool DisableUserRegistration { get; set; } - IInstallationSettings Installation { get; set; } - IFileStorageSettings Attachment { get; set; } - IConnectionStringSettings Storage { get; set; } - IBaseServiceUriSettings BaseServiceUri { get; set; } - ITwoFactorAuthSettings TwoFactorAuth { get; set; } - ISsoSettings Sso { get; set; } + public interface IGlobalSettings + { + // This interface exists for testing. Add settings here as needed for testing + bool SelfHosted { get; set; } + bool EnableCloudCommunication { get; set; } + string LicenseDirectory { get; set; } + string LicenseCertificatePassword { get; set; } + int OrganizationInviteExpirationHours { get; set; } + bool DisableUserRegistration { get; set; } + IInstallationSettings Installation { get; set; } + IFileStorageSettings Attachment { get; set; } + IConnectionStringSettings Storage { get; set; } + IBaseServiceUriSettings BaseServiceUri { get; set; } + ITwoFactorAuthSettings TwoFactorAuth { get; set; } + ISsoSettings Sso { get; set; } + } } diff --git a/src/Core/Settings/IInstallationSettings.cs b/src/Core/Settings/IInstallationSettings.cs index 6f56a3fa0f..dbc966d541 100644 --- a/src/Core/Settings/IInstallationSettings.cs +++ b/src/Core/Settings/IInstallationSettings.cs @@ -1,9 +1,10 @@ -namespace Bit.Core.Settings; - -public interface IInstallationSettings +namespace Bit.Core.Settings { - public Guid Id { get; set; } - public string Key { get; set; } - public string IdentityUri { get; set; } - public string ApiUri { get; } + public interface IInstallationSettings + { + public Guid Id { get; set; } + public string Key { get; set; } + public string IdentityUri { get; set; } + public string ApiUri { get; } + } } diff --git a/src/Core/Settings/ISsoSettings.cs b/src/Core/Settings/ISsoSettings.cs index c7429baef2..de5193cef4 100644 --- a/src/Core/Settings/ISsoSettings.cs +++ b/src/Core/Settings/ISsoSettings.cs @@ -1,7 +1,8 @@ -namespace Bit.Core.Settings; - -public interface ISsoSettings +namespace Bit.Core.Settings { - int CacheLifetimeInSeconds { get; set; } - double SsoTokenLifetimeInSeconds { get; set; } + public interface ISsoSettings + { + int CacheLifetimeInSeconds { get; set; } + double SsoTokenLifetimeInSeconds { get; set; } + } } diff --git a/src/Core/Settings/ITwoFactorAuthSettings.cs b/src/Core/Settings/ITwoFactorAuthSettings.cs index 2e11e65079..06dced0f8d 100644 --- a/src/Core/Settings/ITwoFactorAuthSettings.cs +++ b/src/Core/Settings/ITwoFactorAuthSettings.cs @@ -1,6 +1,7 @@ -namespace Bit.Core.Settings; - -public interface ITwoFactorAuthSettings +namespace Bit.Core.Settings { - bool EmailOnNewDeviceLogin { get; set; } + public interface ITwoFactorAuthSettings + { + bool EmailOnNewDeviceLogin { get; set; } + } } diff --git a/src/Core/Sso/SamlSigningAlgorithms.cs b/src/Core/Sso/SamlSigningAlgorithms.cs index 68ad8e5fa5..fba67a4ab7 100644 --- a/src/Core/Sso/SamlSigningAlgorithms.cs +++ b/src/Core/Sso/SamlSigningAlgorithms.cs @@ -1,18 +1,19 @@ -namespace Bit.Core.Sso; - -public static class SamlSigningAlgorithms +namespace Bit.Core.Sso { - public const string Default = Sha256; - public const string Sha256 = "http://www.w3.org/2001/04/xmldsig-more#rsa-sha256"; - public const string Sha384 = "http://www.w3.org/2000/09/xmldsig#rsa-sha384"; - public const string Sha512 = "http://www.w3.org/2000/09/xmldsig#rsa-sha512"; - public const string Sha1 = "http://www.w3.org/2000/09/xmldsig#rsa-sha1"; - - public static IEnumerable GetEnumerable() + public static class SamlSigningAlgorithms { - yield return Sha256; - yield return Sha384; - yield return Sha512; - yield return Sha1; + public const string Default = Sha256; + public const string Sha256 = "http://www.w3.org/2001/04/xmldsig-more#rsa-sha256"; + public const string Sha384 = "http://www.w3.org/2000/09/xmldsig#rsa-sha384"; + public const string Sha512 = "http://www.w3.org/2000/09/xmldsig#rsa-sha512"; + public const string Sha1 = "http://www.w3.org/2000/09/xmldsig#rsa-sha1"; + + public static IEnumerable GetEnumerable() + { + yield return Sha256; + yield return Sha384; + yield return Sha512; + yield return Sha1; + } } } diff --git a/src/Core/Tokens/BadTokenException.cs b/src/Core/Tokens/BadTokenException.cs index ca2dcac498..ffd9cb5205 100644 --- a/src/Core/Tokens/BadTokenException.cs +++ b/src/Core/Tokens/BadTokenException.cs @@ -1,12 +1,13 @@ -namespace Bit.Core.Tokens; - -public class BadTokenException : Exception +namespace Bit.Core.Tokens { - public BadTokenException() + public class BadTokenException : Exception { - } + public BadTokenException() + { + } - public BadTokenException(string message) : base(message) - { + public BadTokenException(string message) : base(message) + { + } } } diff --git a/src/Core/Tokens/DataProtectorTokenFactory.cs b/src/Core/Tokens/DataProtectorTokenFactory.cs index e0ec9811f6..8029b35547 100644 --- a/src/Core/Tokens/DataProtectorTokenFactory.cs +++ b/src/Core/Tokens/DataProtectorTokenFactory.cs @@ -1,54 +1,55 @@ using Microsoft.AspNetCore.DataProtection; -namespace Bit.Core.Tokens; - -public class DataProtectorTokenFactory : IDataProtectorTokenFactory where T : Tokenable +namespace Bit.Core.Tokens { - private readonly IDataProtector _dataProtector; - private readonly string _clearTextPrefix; - - public DataProtectorTokenFactory(string clearTextPrefix, string purpose, IDataProtectionProvider dataProtectionProvider) + public class DataProtectorTokenFactory : IDataProtectorTokenFactory where T : Tokenable { - _dataProtector = dataProtectionProvider.CreateProtector(purpose); - _clearTextPrefix = clearTextPrefix; - } + private readonly IDataProtector _dataProtector; + private readonly string _clearTextPrefix; - public string Protect(T data) => - data.ToToken().ProtectWith(_dataProtector).WithPrefix(_clearTextPrefix).ToString(); - - /// - /// Unprotect token - /// - /// The token to parse - /// The tokenable type to parse to - /// The parsed tokenable - /// Throws CryptographicException if fails to unprotect - public T Unprotect(string token) => - Tokenable.FromToken(new Token(token).RemovePrefix(_clearTextPrefix).UnprotectWith(_dataProtector).ToString()); - - public bool TokenValid(string token) - { - try + public DataProtectorTokenFactory(string clearTextPrefix, string purpose, IDataProtectionProvider dataProtectionProvider) { - return Unprotect(token).Valid; + _dataProtector = dataProtectionProvider.CreateProtector(purpose); + _clearTextPrefix = clearTextPrefix; } - catch - { - return false; - } - } - public bool TryUnprotect(string token, out T data) - { - try + public string Protect(T data) => + data.ToToken().ProtectWith(_dataProtector).WithPrefix(_clearTextPrefix).ToString(); + + /// + /// Unprotect token + /// + /// The token to parse + /// The tokenable type to parse to + /// The parsed tokenable + /// Throws CryptographicException if fails to unprotect + public T Unprotect(string token) => + Tokenable.FromToken(new Token(token).RemovePrefix(_clearTextPrefix).UnprotectWith(_dataProtector).ToString()); + + public bool TokenValid(string token) { - data = Unprotect(token); - return true; + try + { + return Unprotect(token).Valid; + } + catch + { + return false; + } } - catch + + public bool TryUnprotect(string token, out T data) { - data = default; - return false; + try + { + data = Unprotect(token); + return true; + } + catch + { + data = default; + return false; + } } } } diff --git a/src/Core/Tokens/ExpiringTokenable.cs b/src/Core/Tokens/ExpiringTokenable.cs index 089405e536..37907bbe3f 100644 --- a/src/Core/Tokens/ExpiringTokenable.cs +++ b/src/Core/Tokens/ExpiringTokenable.cs @@ -1,13 +1,14 @@ using System.Text.Json.Serialization; using Bit.Core.Utilities; -namespace Bit.Core.Tokens; - -public abstract class ExpiringTokenable : Tokenable +namespace Bit.Core.Tokens { - [JsonConverter(typeof(EpochDateTimeJsonConverter))] - public DateTime ExpirationDate { get; set; } - public override bool Valid => ExpirationDate > DateTime.UtcNow && TokenIsValid(); + public abstract class ExpiringTokenable : Tokenable + { + [JsonConverter(typeof(EpochDateTimeJsonConverter))] + public DateTime ExpirationDate { get; set; } + public override bool Valid => ExpirationDate > DateTime.UtcNow && TokenIsValid(); - protected abstract bool TokenIsValid(); + protected abstract bool TokenIsValid(); + } } diff --git a/src/Core/Tokens/IBillingSyncTokenable.cs b/src/Core/Tokens/IBillingSyncTokenable.cs index d63df0cc7f..a9fdc06bd9 100644 --- a/src/Core/Tokens/IBillingSyncTokenable.cs +++ b/src/Core/Tokens/IBillingSyncTokenable.cs @@ -1,7 +1,8 @@ -namespace Bit.Core.Tokens; - -public interface IBillingSyncTokenable +namespace Bit.Core.Tokens { - public Guid OrganizationId { get; set; } - public string BillingSyncKey { get; set; } + public interface IBillingSyncTokenable + { + public Guid OrganizationId { get; set; } + public string BillingSyncKey { get; set; } + } } diff --git a/src/Core/Tokens/IDataProtectorTokenFactory.cs b/src/Core/Tokens/IDataProtectorTokenFactory.cs index 3809c40dab..038eff0f7d 100644 --- a/src/Core/Tokens/IDataProtectorTokenFactory.cs +++ b/src/Core/Tokens/IDataProtectorTokenFactory.cs @@ -1,9 +1,10 @@ -namespace Bit.Core.Tokens; - -public interface IDataProtectorTokenFactory where T : Tokenable +namespace Bit.Core.Tokens { - string Protect(T data); - T Unprotect(string token); - bool TryUnprotect(string token, out T data); - bool TokenValid(string token); + public interface IDataProtectorTokenFactory where T : Tokenable + { + string Protect(T data); + T Unprotect(string token); + bool TryUnprotect(string token, out T data); + bool TokenValid(string token); + } } diff --git a/src/Core/Tokens/Token.cs b/src/Core/Tokens/Token.cs index a50b81fbbb..396b8747d5 100644 --- a/src/Core/Tokens/Token.cs +++ b/src/Core/Tokens/Token.cs @@ -1,36 +1,37 @@ using Microsoft.AspNetCore.DataProtection; -namespace Bit.Core.Tokens; - -public class Token +namespace Bit.Core.Tokens { - private readonly string _token; - - public Token(string token) + public class Token { - _token = token; - } + private readonly string _token; - public Token WithPrefix(string prefix) - { - return new Token($"{prefix}{_token}"); - } - - public Token RemovePrefix(string expectedPrefix) - { - if (!_token.StartsWith(expectedPrefix)) + public Token(string token) { - throw new BadTokenException($"Expected prefix, {expectedPrefix}, was not present."); + _token = token; } - return new Token(_token[expectedPrefix.Length..]); + public Token WithPrefix(string prefix) + { + return new Token($"{prefix}{_token}"); + } + + public Token RemovePrefix(string expectedPrefix) + { + if (!_token.StartsWith(expectedPrefix)) + { + throw new BadTokenException($"Expected prefix, {expectedPrefix}, was not present."); + } + + return new Token(_token[expectedPrefix.Length..]); + } + + public Token ProtectWith(IDataProtector dataProtector) => + new(dataProtector.Protect(ToString())); + + public Token UnprotectWith(IDataProtector dataProtector) => + new(dataProtector.Unprotect(ToString())); + + public override string ToString() => _token; } - - public Token ProtectWith(IDataProtector dataProtector) => - new(dataProtector.Protect(ToString())); - - public Token UnprotectWith(IDataProtector dataProtector) => - new(dataProtector.Unprotect(ToString())); - - public override string ToString() => _token; } diff --git a/src/Core/Tokens/Tokenable.cs b/src/Core/Tokens/Tokenable.cs index a145e64bb5..c5c57c2f74 100644 --- a/src/Core/Tokens/Tokenable.cs +++ b/src/Core/Tokens/Tokenable.cs @@ -1,19 +1,20 @@ using System.Text.Json; -namespace Bit.Core.Tokens; - -public abstract class Tokenable +namespace Bit.Core.Tokens { - public abstract bool Valid { get; } - - public Token ToToken() + public abstract class Tokenable { - return new Token(JsonSerializer.Serialize(this, this.GetType())); - } + public abstract bool Valid { get; } - public static T FromToken(string token) => FromToken(new Token(token)); - public static T FromToken(Token token) - { - return JsonSerializer.Deserialize(token.ToString()); + public Token ToToken() + { + return new Token(JsonSerializer.Serialize(this, this.GetType())); + } + + public static T FromToken(string token) => FromToken(new Token(token)); + public static T FromToken(Token token) + { + return JsonSerializer.Deserialize(token.ToString()); + } } } diff --git a/src/Core/Utilities/BillingHelpers.cs b/src/Core/Utilities/BillingHelpers.cs index e7ccfc3547..41202a2b41 100644 --- a/src/Core/Utilities/BillingHelpers.cs +++ b/src/Core/Utilities/BillingHelpers.cs @@ -2,56 +2,57 @@ using Bit.Core.Exceptions; using Bit.Core.Services; -namespace Bit.Core.Utilities; - -public static class BillingHelpers +namespace Bit.Core.Utilities { - internal static async Task AdjustStorageAsync(IPaymentService paymentService, IStorableSubscriber storableSubscriber, - short storageAdjustmentGb, string storagePlanId) + public static class BillingHelpers { - if (storableSubscriber == null) + internal static async Task AdjustStorageAsync(IPaymentService paymentService, IStorableSubscriber storableSubscriber, + short storageAdjustmentGb, string storagePlanId) { - throw new ArgumentNullException(nameof(storableSubscriber)); - } + if (storableSubscriber == null) + { + throw new ArgumentNullException(nameof(storableSubscriber)); + } - if (string.IsNullOrWhiteSpace(storableSubscriber.GatewayCustomerId)) - { - throw new BadRequestException("No payment method found."); - } + if (string.IsNullOrWhiteSpace(storableSubscriber.GatewayCustomerId)) + { + throw new BadRequestException("No payment method found."); + } - if (string.IsNullOrWhiteSpace(storableSubscriber.GatewaySubscriptionId)) - { - throw new BadRequestException("No subscription found."); - } + if (string.IsNullOrWhiteSpace(storableSubscriber.GatewaySubscriptionId)) + { + throw new BadRequestException("No subscription found."); + } - if (!storableSubscriber.MaxStorageGb.HasValue) - { - throw new BadRequestException("No access to storage."); - } + if (!storableSubscriber.MaxStorageGb.HasValue) + { + throw new BadRequestException("No access to storage."); + } - var newStorageGb = (short)(storableSubscriber.MaxStorageGb.Value + storageAdjustmentGb); - if (newStorageGb < 1) - { - newStorageGb = 1; - } + var newStorageGb = (short)(storableSubscriber.MaxStorageGb.Value + storageAdjustmentGb); + if (newStorageGb < 1) + { + newStorageGb = 1; + } - if (newStorageGb > 100) - { - throw new BadRequestException("Maximum storage is 100 GB."); - } + if (newStorageGb > 100) + { + throw new BadRequestException("Maximum storage is 100 GB."); + } - var remainingStorage = storableSubscriber.StorageBytesRemaining(newStorageGb); - if (remainingStorage < 0) - { - throw new BadRequestException("You are currently using " + - $"{CoreHelpers.ReadableBytesSize(storableSubscriber.Storage.GetValueOrDefault(0))} of storage. " + - "Delete some stored data first."); - } + var remainingStorage = storableSubscriber.StorageBytesRemaining(newStorageGb); + if (remainingStorage < 0) + { + throw new BadRequestException("You are currently using " + + $"{CoreHelpers.ReadableBytesSize(storableSubscriber.Storage.GetValueOrDefault(0))} of storage. " + + "Delete some stored data first."); + } - var additionalStorage = newStorageGb - 1; - var paymentIntentClientSecret = await paymentService.AdjustStorageAsync(storableSubscriber, - additionalStorage, storagePlanId); - storableSubscriber.MaxStorageGb = newStorageGb; - return paymentIntentClientSecret; + var additionalStorage = newStorageGb - 1; + var paymentIntentClientSecret = await paymentService.AdjustStorageAsync(storableSubscriber, + additionalStorage, storagePlanId); + storableSubscriber.MaxStorageGb = newStorageGb; + return paymentIntentClientSecret; + } } } diff --git a/src/Core/Utilities/BitPayClient.cs b/src/Core/Utilities/BitPayClient.cs index 35a078998d..2532e8476f 100644 --- a/src/Core/Utilities/BitPayClient.cs +++ b/src/Core/Utilities/BitPayClient.cs @@ -1,27 +1,28 @@ using Bit.Core.Settings; -namespace Bit.Core.Utilities; - -public class BitPayClient +namespace Bit.Core.Utilities { - private readonly BitPayLight.BitPay _bpClient; - - public BitPayClient(GlobalSettings globalSettings) + public class BitPayClient { - if (CoreHelpers.SettingHasValue(globalSettings.BitPay.Token)) + private readonly BitPayLight.BitPay _bpClient; + + public BitPayClient(GlobalSettings globalSettings) { - _bpClient = new BitPayLight.BitPay(globalSettings.BitPay.Token, - globalSettings.BitPay.Production ? BitPayLight.Env.Prod : BitPayLight.Env.Test); + if (CoreHelpers.SettingHasValue(globalSettings.BitPay.Token)) + { + _bpClient = new BitPayLight.BitPay(globalSettings.BitPay.Token, + globalSettings.BitPay.Production ? BitPayLight.Env.Prod : BitPayLight.Env.Test); + } + } + + public Task GetInvoiceAsync(string id) + { + return _bpClient.GetInvoice(id); + } + + public Task CreateInvoiceAsync(BitPayLight.Models.Invoice.Invoice invoice) + { + return _bpClient.CreateInvoice(invoice); } } - - public Task GetInvoiceAsync(string id) - { - return _bpClient.GetInvoice(id); - } - - public Task CreateInvoiceAsync(BitPayLight.Models.Invoice.Invoice invoice) - { - return _bpClient.CreateInvoice(invoice); - } } diff --git a/src/Core/Utilities/CaptchaProtectedAttribute.cs b/src/Core/Utilities/CaptchaProtectedAttribute.cs index 6a5de6a9df..102f1f175a 100644 --- a/src/Core/Utilities/CaptchaProtectedAttribute.cs +++ b/src/Core/Utilities/CaptchaProtectedAttribute.cs @@ -5,31 +5,32 @@ using Bit.Core.Services; using Microsoft.AspNetCore.Mvc.Filters; using Microsoft.Extensions.DependencyInjection; -namespace Bit.Core.Utilities; - -public class CaptchaProtectedAttribute : ActionFilterAttribute +namespace Bit.Core.Utilities { - public string ModelParameterName { get; set; } = "model"; - - public override void OnActionExecuting(ActionExecutingContext context) + public class CaptchaProtectedAttribute : ActionFilterAttribute { - var currentContext = context.HttpContext.RequestServices.GetRequiredService(); - var captchaValidationService = context.HttpContext.RequestServices.GetRequiredService(); + public string ModelParameterName { get; set; } = "model"; - if (captchaValidationService.RequireCaptchaValidation(currentContext, null)) + public override void OnActionExecuting(ActionExecutingContext context) { - var captchaResponse = (context.ActionArguments[ModelParameterName] as ICaptchaProtectedModel)?.CaptchaResponse; + var currentContext = context.HttpContext.RequestServices.GetRequiredService(); + var captchaValidationService = context.HttpContext.RequestServices.GetRequiredService(); - if (string.IsNullOrWhiteSpace(captchaResponse)) + if (captchaValidationService.RequireCaptchaValidation(currentContext, null)) { - throw new BadRequestException(captchaValidationService.SiteKeyResponseKeyName, captchaValidationService.SiteKey); - } + var captchaResponse = (context.ActionArguments[ModelParameterName] as ICaptchaProtectedModel)?.CaptchaResponse; - var captchaValidationResponse = captchaValidationService.ValidateCaptchaResponseAsync(captchaResponse, - currentContext.IpAddress, null).GetAwaiter().GetResult(); - if (!captchaValidationResponse.Success || captchaValidationResponse.IsBot) - { - throw new BadRequestException("Captcha is invalid. Please refresh and try again"); + if (string.IsNullOrWhiteSpace(captchaResponse)) + { + throw new BadRequestException(captchaValidationService.SiteKeyResponseKeyName, captchaValidationService.SiteKey); + } + + var captchaValidationResponse = captchaValidationService.ValidateCaptchaResponseAsync(captchaResponse, + currentContext.IpAddress, null).GetAwaiter().GetResult(); + if (!captchaValidationResponse.Success || captchaValidationResponse.IsBot) + { + throw new BadRequestException("Captcha is invalid. Please refresh and try again"); + } } } } diff --git a/src/Core/Utilities/ClaimsExtensions.cs b/src/Core/Utilities/ClaimsExtensions.cs index 75478869e2..ef25d1483c 100644 --- a/src/Core/Utilities/ClaimsExtensions.cs +++ b/src/Core/Utilities/ClaimsExtensions.cs @@ -1,11 +1,12 @@ using System.Security.Claims; -namespace Bit.Core.Utilities; - -public static class ClaimsExtensions +namespace Bit.Core.Utilities { - public static bool HasSsoIdP(this IEnumerable claims) + public static class ClaimsExtensions { - return claims.Any(c => c.Type == "idp" && c.Value == "sso"); + public static bool HasSsoIdP(this IEnumerable claims) + { + return claims.Any(c => c.Type == "idp" && c.Value == "sso"); + } } } diff --git a/src/Core/Utilities/CoreHelpers.cs b/src/Core/Utilities/CoreHelpers.cs index ef6848cf15..7ad850a266 100644 --- a/src/Core/Utilities/CoreHelpers.cs +++ b/src/Core/Utilities/CoreHelpers.cs @@ -19,834 +19,835 @@ using IdentityModel; using Microsoft.AspNetCore.DataProtection; using MimeKit; -namespace Bit.Core.Utilities; - -public static class CoreHelpers +namespace Bit.Core.Utilities { - private static readonly long _baseDateTicks = new DateTime(1900, 1, 1).Ticks; - private static readonly DateTime _epoc = new DateTime(1970, 1, 1, 0, 0, 0, DateTimeKind.Utc); - private static readonly DateTime _max = new DateTime(9999, 1, 1, 0, 0, 0, DateTimeKind.Utc); - private static readonly Random _random = new Random(); - private static string _version; - private static readonly string CloudFlareConnectingIp = "CF-Connecting-IP"; - private static readonly string RealIp = "X-Real-IP"; - - /// - /// Generate sequential Guid for Sql Server. - /// ref: https://github.com/nhibernate/nhibernate-core/blob/master/src/NHibernate/Id/GuidCombGenerator.cs - /// - /// A comb Guid. - public static Guid GenerateComb() - => GenerateComb(Guid.NewGuid(), DateTime.UtcNow); - - /// - /// Implementation of with input parameters to remove randomness. - /// This should NOT be used outside of testing. - /// - /// - /// You probably don't want to use this method and instead want to use with no parameters - /// - internal static Guid GenerateComb(Guid startingGuid, DateTime time) + public static class CoreHelpers { - var guidArray = startingGuid.ToByteArray(); + private static readonly long _baseDateTicks = new DateTime(1900, 1, 1).Ticks; + private static readonly DateTime _epoc = new DateTime(1970, 1, 1, 0, 0, 0, DateTimeKind.Utc); + private static readonly DateTime _max = new DateTime(9999, 1, 1, 0, 0, 0, DateTimeKind.Utc); + private static readonly Random _random = new Random(); + private static string _version; + private static readonly string CloudFlareConnectingIp = "CF-Connecting-IP"; + private static readonly string RealIp = "X-Real-IP"; - // Get the days and milliseconds which will be used to build the byte string - var days = new TimeSpan(time.Ticks - _baseDateTicks); - var msecs = time.TimeOfDay; + /// + /// Generate sequential Guid for Sql Server. + /// ref: https://github.com/nhibernate/nhibernate-core/blob/master/src/NHibernate/Id/GuidCombGenerator.cs + /// + /// A comb Guid. + public static Guid GenerateComb() + => GenerateComb(Guid.NewGuid(), DateTime.UtcNow); - // Convert to a byte array - // Note that SQL Server is accurate to 1/300th of a millisecond so we divide by 3.333333 - var daysArray = BitConverter.GetBytes(days.Days); - var msecsArray = BitConverter.GetBytes((long)(msecs.TotalMilliseconds / 3.333333)); - - // Reverse the bytes to match SQL Servers ordering - Array.Reverse(daysArray); - Array.Reverse(msecsArray); - - // Copy the bytes into the guid - Array.Copy(daysArray, daysArray.Length - 2, guidArray, guidArray.Length - 6, 2); - Array.Copy(msecsArray, msecsArray.Length - 4, guidArray, guidArray.Length - 4, 4); - - return new Guid(guidArray); - } - - public static IEnumerable> Batch(this IEnumerable source, int size) - { - T[] bucket = null; - var count = 0; - foreach (var item in source) + /// + /// Implementation of with input parameters to remove randomness. + /// This should NOT be used outside of testing. + /// + /// + /// You probably don't want to use this method and instead want to use with no parameters + /// + internal static Guid GenerateComb(Guid startingGuid, DateTime time) { - if (bucket == null) + var guidArray = startingGuid.ToByteArray(); + + // Get the days and milliseconds which will be used to build the byte string + var days = new TimeSpan(time.Ticks - _baseDateTicks); + var msecs = time.TimeOfDay; + + // Convert to a byte array + // Note that SQL Server is accurate to 1/300th of a millisecond so we divide by 3.333333 + var daysArray = BitConverter.GetBytes(days.Days); + var msecsArray = BitConverter.GetBytes((long)(msecs.TotalMilliseconds / 3.333333)); + + // Reverse the bytes to match SQL Servers ordering + Array.Reverse(daysArray); + Array.Reverse(msecsArray); + + // Copy the bytes into the guid + Array.Copy(daysArray, daysArray.Length - 2, guidArray, guidArray.Length - 6, 2); + Array.Copy(msecsArray, msecsArray.Length - 4, guidArray, guidArray.Length - 4, 4); + + return new Guid(guidArray); + } + + public static IEnumerable> Batch(this IEnumerable source, int size) + { + T[] bucket = null; + var count = 0; + foreach (var item in source) { - bucket = new T[size]; - } - bucket[count++] = item; - if (count != size) - { - continue; - } - yield return bucket.Select(x => x); - bucket = null; - count = 0; - } - // Return the last bucket with all remaining elements - if (bucket != null && count > 0) - { - yield return bucket.Take(count); - } - } - - public static string CleanCertificateThumbprint(string thumbprint) - { - // Clean possible garbage characters from thumbprint copy/paste - // ref http://stackoverflow.com/questions/8448147/problems-with-x509store-certificates-find-findbythumbprint - return Regex.Replace(thumbprint, @"[^\da-fA-F]", string.Empty).ToUpper(); - } - - public static X509Certificate2 GetCertificate(string thumbprint) - { - thumbprint = CleanCertificateThumbprint(thumbprint); - - X509Certificate2 cert = null; - var certStore = new X509Store(StoreName.My, StoreLocation.CurrentUser); - certStore.Open(OpenFlags.ReadOnly); - var certCollection = certStore.Certificates.Find(X509FindType.FindByThumbprint, thumbprint, false); - if (certCollection.Count > 0) - { - cert = certCollection[0]; - } - - certStore.Close(); - return cert; - } - - public static X509Certificate2 GetCertificate(string file, string password) - { - return new X509Certificate2(file, password); - } - - public async static Task GetEmbeddedCertificateAsync(string file, string password) - { - var assembly = typeof(CoreHelpers).GetTypeInfo().Assembly; - using (var s = assembly.GetManifestResourceStream($"Bit.Core.{file}")) - using (var ms = new MemoryStream()) - { - await s.CopyToAsync(ms); - return new X509Certificate2(ms.ToArray(), password); - } - } - - public static string GetEmbeddedResourceContentsAsync(string file) - { - var assembly = Assembly.GetCallingAssembly(); - var resourceName = assembly.GetManifestResourceNames().Single(n => n.EndsWith(file)); - using (var stream = assembly.GetManifestResourceStream(resourceName)) - using (var reader = new StreamReader(stream)) - { - return reader.ReadToEnd(); - } - } - - public async static Task GetBlobCertificateAsync(string connectionString, string container, string file, string password) - { - try - { - var blobServiceClient = new BlobServiceClient(connectionString); - var containerRef2 = blobServiceClient.GetBlobContainerClient(container); - var blobRef = containerRef2.GetBlobClient(file); - - using var memStream = new MemoryStream(); - await blobRef.DownloadToAsync(memStream).ConfigureAwait(false); - return new X509Certificate2(memStream.ToArray(), password); - } - catch (RequestFailedException ex) - when (ex.ErrorCode == BlobErrorCode.ContainerNotFound || ex.ErrorCode == BlobErrorCode.BlobNotFound) - { - return null; - } - catch (Exception) - { - return null; - } - } - - public static long ToEpocMilliseconds(DateTime date) - { - return (long)Math.Round((date - _epoc).TotalMilliseconds, 0); - } - - public static DateTime FromEpocMilliseconds(long milliseconds) - { - return _epoc.AddMilliseconds(milliseconds); - } - - public static long ToEpocSeconds(DateTime date) - { - return (long)Math.Round((date - _epoc).TotalSeconds, 0); - } - - public static DateTime FromEpocSeconds(long seconds) - { - return _epoc.AddSeconds(seconds); - } - - public static string U2fAppIdUrl(GlobalSettings globalSettings) - { - return string.Concat(globalSettings.BaseServiceUri.Vault, "/app-id.json"); - } - - public static string RandomString(int length, bool alpha = true, bool upper = true, bool lower = true, - bool numeric = true, bool special = false) - { - return RandomString(length, RandomStringCharacters(alpha, upper, lower, numeric, special)); - } - - public static string RandomString(int length, string characters) - { - return new string(Enumerable.Repeat(characters, length).Select(s => s[_random.Next(s.Length)]).ToArray()); - } - - public static string SecureRandomString(int length, bool alpha = true, bool upper = true, bool lower = true, - bool numeric = true, bool special = false) - { - return SecureRandomString(length, RandomStringCharacters(alpha, upper, lower, numeric, special)); - } - - // ref https://stackoverflow.com/a/8996788/1090359 with modifications - public static string SecureRandomString(int length, string characters) - { - if (length < 0) - { - throw new ArgumentOutOfRangeException(nameof(length), "length cannot be less than zero."); - } - - if ((characters?.Length ?? 0) == 0) - { - throw new ArgumentOutOfRangeException(nameof(characters), "characters invalid."); - } - - const int byteSize = 0x100; - if (byteSize < characters.Length) - { - throw new ArgumentException( - string.Format("{0} may contain no more than {1} characters.", nameof(characters), byteSize), - nameof(characters)); - } - - var outOfRangeStart = byteSize - (byteSize % characters.Length); - using (var rng = RandomNumberGenerator.Create()) - { - var sb = new StringBuilder(); - var buffer = new byte[128]; - while (sb.Length < length) - { - rng.GetBytes(buffer); - for (var i = 0; i < buffer.Length && sb.Length < length; ++i) + if (bucket == null) { - // Divide the byte into charSet-sized groups. If the random value falls into the last group and the - // last group is too small to choose from the entire allowedCharSet, ignore the value in order to - // avoid biasing the result. - if (outOfRangeStart <= buffer[i]) - { - continue; - } + bucket = new T[size]; + } + bucket[count++] = item; + if (count != size) + { + continue; + } + yield return bucket.Select(x => x); + bucket = null; + count = 0; + } + // Return the last bucket with all remaining elements + if (bucket != null && count > 0) + { + yield return bucket.Take(count); + } + } - sb.Append(characters[buffer[i] % characters.Length]); + public static string CleanCertificateThumbprint(string thumbprint) + { + // Clean possible garbage characters from thumbprint copy/paste + // ref http://stackoverflow.com/questions/8448147/problems-with-x509store-certificates-find-findbythumbprint + return Regex.Replace(thumbprint, @"[^\da-fA-F]", string.Empty).ToUpper(); + } + + public static X509Certificate2 GetCertificate(string thumbprint) + { + thumbprint = CleanCertificateThumbprint(thumbprint); + + X509Certificate2 cert = null; + var certStore = new X509Store(StoreName.My, StoreLocation.CurrentUser); + certStore.Open(OpenFlags.ReadOnly); + var certCollection = certStore.Certificates.Find(X509FindType.FindByThumbprint, thumbprint, false); + if (certCollection.Count > 0) + { + cert = certCollection[0]; + } + + certStore.Close(); + return cert; + } + + public static X509Certificate2 GetCertificate(string file, string password) + { + return new X509Certificate2(file, password); + } + + public async static Task GetEmbeddedCertificateAsync(string file, string password) + { + var assembly = typeof(CoreHelpers).GetTypeInfo().Assembly; + using (var s = assembly.GetManifestResourceStream($"Bit.Core.{file}")) + using (var ms = new MemoryStream()) + { + await s.CopyToAsync(ms); + return new X509Certificate2(ms.ToArray(), password); + } + } + + public static string GetEmbeddedResourceContentsAsync(string file) + { + var assembly = Assembly.GetCallingAssembly(); + var resourceName = assembly.GetManifestResourceNames().Single(n => n.EndsWith(file)); + using (var stream = assembly.GetManifestResourceStream(resourceName)) + using (var reader = new StreamReader(stream)) + { + return reader.ReadToEnd(); + } + } + + public async static Task GetBlobCertificateAsync(string connectionString, string container, string file, string password) + { + try + { + var blobServiceClient = new BlobServiceClient(connectionString); + var containerRef2 = blobServiceClient.GetBlobContainerClient(container); + var blobRef = containerRef2.GetBlobClient(file); + + using var memStream = new MemoryStream(); + await blobRef.DownloadToAsync(memStream).ConfigureAwait(false); + return new X509Certificate2(memStream.ToArray(), password); + } + catch (RequestFailedException ex) + when (ex.ErrorCode == BlobErrorCode.ContainerNotFound || ex.ErrorCode == BlobErrorCode.BlobNotFound) + { + return null; + } + catch (Exception) + { + return null; + } + } + + public static long ToEpocMilliseconds(DateTime date) + { + return (long)Math.Round((date - _epoc).TotalMilliseconds, 0); + } + + public static DateTime FromEpocMilliseconds(long milliseconds) + { + return _epoc.AddMilliseconds(milliseconds); + } + + public static long ToEpocSeconds(DateTime date) + { + return (long)Math.Round((date - _epoc).TotalSeconds, 0); + } + + public static DateTime FromEpocSeconds(long seconds) + { + return _epoc.AddSeconds(seconds); + } + + public static string U2fAppIdUrl(GlobalSettings globalSettings) + { + return string.Concat(globalSettings.BaseServiceUri.Vault, "/app-id.json"); + } + + public static string RandomString(int length, bool alpha = true, bool upper = true, bool lower = true, + bool numeric = true, bool special = false) + { + return RandomString(length, RandomStringCharacters(alpha, upper, lower, numeric, special)); + } + + public static string RandomString(int length, string characters) + { + return new string(Enumerable.Repeat(characters, length).Select(s => s[_random.Next(s.Length)]).ToArray()); + } + + public static string SecureRandomString(int length, bool alpha = true, bool upper = true, bool lower = true, + bool numeric = true, bool special = false) + { + return SecureRandomString(length, RandomStringCharacters(alpha, upper, lower, numeric, special)); + } + + // ref https://stackoverflow.com/a/8996788/1090359 with modifications + public static string SecureRandomString(int length, string characters) + { + if (length < 0) + { + throw new ArgumentOutOfRangeException(nameof(length), "length cannot be less than zero."); + } + + if ((characters?.Length ?? 0) == 0) + { + throw new ArgumentOutOfRangeException(nameof(characters), "characters invalid."); + } + + const int byteSize = 0x100; + if (byteSize < characters.Length) + { + throw new ArgumentException( + string.Format("{0} may contain no more than {1} characters.", nameof(characters), byteSize), + nameof(characters)); + } + + var outOfRangeStart = byteSize - (byteSize % characters.Length); + using (var rng = RandomNumberGenerator.Create()) + { + var sb = new StringBuilder(); + var buffer = new byte[128]; + while (sb.Length < length) + { + rng.GetBytes(buffer); + for (var i = 0; i < buffer.Length && sb.Length < length; ++i) + { + // Divide the byte into charSet-sized groups. If the random value falls into the last group and the + // last group is too small to choose from the entire allowedCharSet, ignore the value in order to + // avoid biasing the result. + if (outOfRangeStart <= buffer[i]) + { + continue; + } + + sb.Append(characters[buffer[i] % characters.Length]); + } + } + + return sb.ToString(); + } + } + + private static string RandomStringCharacters(bool alpha, bool upper, bool lower, bool numeric, bool special) + { + var characters = string.Empty; + if (alpha) + { + if (upper) + { + characters += "ABCDEFGHIJKLMNOPQRSTUVWXYZ"; + } + + if (lower) + { + characters += "abcdefghijklmnopqrstuvwxyz"; } } - return sb.ToString(); - } - } - - private static string RandomStringCharacters(bool alpha, bool upper, bool lower, bool numeric, bool special) - { - var characters = string.Empty; - if (alpha) - { - if (upper) + if (numeric) { - characters += "ABCDEFGHIJKLMNOPQRSTUVWXYZ"; + characters += "0123456789"; } - if (lower) + if (special) { - characters += "abcdefghijklmnopqrstuvwxyz"; + characters += "!@#$%^*&"; } + + return characters; } - if (numeric) + // ref: https://stackoverflow.com/a/11124118/1090359 + // Returns the human-readable file size for an arbitrary 64-bit file size . + // The format is "0.## XB", ex: "4.2 KB" or "1.43 GB" + public static string ReadableBytesSize(long size) { - characters += "0123456789"; - } + // Get absolute value + var absoluteSize = (size < 0 ? -size : size); - if (special) - { - characters += "!@#$%^*&"; - } - - return characters; - } - - // ref: https://stackoverflow.com/a/11124118/1090359 - // Returns the human-readable file size for an arbitrary 64-bit file size . - // The format is "0.## XB", ex: "4.2 KB" or "1.43 GB" - public static string ReadableBytesSize(long size) - { - // Get absolute value - var absoluteSize = (size < 0 ? -size : size); - - // Determine the suffix and readable value - string suffix; - double readable; - if (absoluteSize >= 0x40000000) // 1 Gigabyte - { - suffix = "GB"; - readable = (size >> 20); - } - else if (absoluteSize >= 0x100000) // 1 Megabyte - { - suffix = "MB"; - readable = (size >> 10); - } - else if (absoluteSize >= 0x400) // 1 Kilobyte - { - suffix = "KB"; - readable = size; - } - else - { - return size.ToString("0 Bytes"); // Byte - } - - // Divide by 1024 to get fractional value - readable = (readable / 1024); - - // Return formatted number with suffix - return readable.ToString("0.## ") + suffix; - } - - /// - /// Creates a clone of the given object through serializing to json and deserializing. - /// This method is subject to the limitations of System.Text.Json. For example, properties with - /// inaccessible setters will not be set. - /// - public static T CloneObject(T obj) - { - return JsonSerializer.Deserialize(JsonSerializer.Serialize(obj)); - } - - public static bool SettingHasValue(string setting) - { - var normalizedSetting = setting?.ToLowerInvariant(); - return !string.IsNullOrWhiteSpace(normalizedSetting) && !normalizedSetting.Equals("secret") && - !normalizedSetting.Equals("replace"); - } - - public static string Base64EncodeString(string input) - { - return Convert.ToBase64String(Encoding.UTF8.GetBytes(input)); - } - - public static string Base64DecodeString(string input) - { - return Encoding.UTF8.GetString(Convert.FromBase64String(input)); - } - - public static string Base64UrlEncodeString(string input) - { - return Base64UrlEncode(Encoding.UTF8.GetBytes(input)); - } - - public static string Base64UrlDecodeString(string input) - { - return Encoding.UTF8.GetString(Base64UrlDecode(input)); - } - - public static string Base64UrlEncode(byte[] input) - { - var output = Convert.ToBase64String(input) - .Replace('+', '-') - .Replace('/', '_') - .Replace("=", string.Empty); - return output; - } - - public static byte[] Base64UrlDecode(string input) - { - var output = input; - // 62nd char of encoding - output = output.Replace('-', '+'); - // 63rd char of encoding - output = output.Replace('_', '/'); - // Pad with trailing '='s - switch (output.Length % 4) - { - case 0: - // No pad chars in this case - break; - case 2: - // Two pad chars - output += "=="; break; - case 3: - // One pad char - output += "="; break; - default: - throw new InvalidOperationException("Illegal base64url string!"); - } - - // Standard base64 decoder - return Convert.FromBase64String(output); - } - - public static string PunyEncode(string text) - { - if (text == "") - { - return ""; - } - - if (text == null) - { - return null; - } - - if (!text.Contains("@")) - { - // Assume domain name or non-email address - var idn = new IdnMapping(); - return idn.GetAscii(text); - } - else - { - // Assume email address - return MailboxAddress.EncodeAddrspec(text); - } - } - - public static string FormatLicenseSignatureValue(object val) - { - if (val == null) - { - return string.Empty; - } - - if (val.GetType() == typeof(DateTime)) - { - return ToEpocSeconds((DateTime)val).ToString(); - } - - if (val.GetType() == typeof(bool)) - { - return val.ToString().ToLowerInvariant(); - } - - if (val is PlanType planType) - { - return planType switch + // Determine the suffix and readable value + string suffix; + double readable; + if (absoluteSize >= 0x40000000) // 1 Gigabyte { - PlanType.Free => "Free", - PlanType.FamiliesAnnually2019 => "FamiliesAnnually", - PlanType.TeamsMonthly2019 => "TeamsMonthly", - PlanType.TeamsAnnually2019 => "TeamsAnnually", - PlanType.EnterpriseMonthly2019 => "EnterpriseMonthly", - PlanType.EnterpriseAnnually2019 => "EnterpriseAnnually", - PlanType.Custom => "Custom", - _ => ((byte)planType).ToString(), - }; - } - - return val.ToString(); - } - - public static string GetVersion() - { - if (string.IsNullOrWhiteSpace(_version)) - { - _version = Assembly.GetEntryAssembly() - .GetCustomAttribute() - .InformationalVersion; - } - - return _version; - } - - public static string SanitizeForEmail(string value, bool htmlEncode = true) - { - var cleanedValue = value.Replace("@", "[at]"); - var regexOptions = RegexOptions.CultureInvariant | - RegexOptions.Singleline | - RegexOptions.IgnoreCase; - cleanedValue = Regex.Replace(cleanedValue, @"(\.\w)", - m => string.Concat("[dot]", m.ToString().Last()), regexOptions); - while (Regex.IsMatch(cleanedValue, @"((^|\b)(\w*)://)", regexOptions)) - { - cleanedValue = Regex.Replace(cleanedValue, @"((^|\b)(\w*)://)", - string.Empty, regexOptions); - } - return htmlEncode ? HttpUtility.HtmlEncode(cleanedValue) : cleanedValue; - } - - public static string DateTimeToTableStorageKey(DateTime? date = null) - { - if (date.HasValue) - { - date = date.Value.ToUniversalTime(); - } - else - { - date = DateTime.UtcNow; - } - - return _max.Subtract(date.Value).TotalMilliseconds.ToString(CultureInfo.InvariantCulture); - } - - // ref: https://stackoverflow.com/a/27545010/1090359 - public static Uri ExtendQuery(Uri uri, IDictionary values) - { - var baseUri = uri.ToString(); - var queryString = string.Empty; - if (baseUri.Contains("?")) - { - var urlSplit = baseUri.Split('?'); - baseUri = urlSplit[0]; - queryString = urlSplit.Length > 1 ? urlSplit[1] : string.Empty; - } - - var queryCollection = HttpUtility.ParseQueryString(queryString); - foreach (var kvp in values ?? new Dictionary()) - { - queryCollection[kvp.Key] = kvp.Value; - } - - var uriKind = uri.IsAbsoluteUri ? UriKind.Absolute : UriKind.Relative; - if (queryCollection.Count == 0) - { - return new Uri(baseUri, uriKind); - } - return new Uri(string.Format("{0}?{1}", baseUri, queryCollection), uriKind); - } - - public static string CustomProviderName(TwoFactorProviderType type) - { - return string.Concat("Custom_", type.ToString()); - } - - public static bool UserInviteTokenIsValid(IDataProtector protector, string token, string userEmail, - Guid orgUserId, IGlobalSettings globalSettings) - { - return TokenIsValid("OrganizationUserInvite", protector, token, userEmail, orgUserId, - globalSettings.OrganizationInviteExpirationHours); - } - - public static bool TokenIsValid(string firstTokenPart, IDataProtector protector, string token, string userEmail, - Guid id, double expirationInHours) - { - var invalid = true; - try - { - var unprotectedData = protector.Unprotect(token); - var dataParts = unprotectedData.Split(' '); - if (dataParts.Length == 4 && dataParts[0] == firstTokenPart && - new Guid(dataParts[1]) == id && - dataParts[2].Equals(userEmail, StringComparison.InvariantCultureIgnoreCase)) - { - var creationTime = FromEpocMilliseconds(Convert.ToInt64(dataParts[3])); - var expTime = creationTime.AddHours(expirationInHours); - invalid = expTime < DateTime.UtcNow; + suffix = "GB"; + readable = (size >> 20); } - } - catch - { - invalid = true; - } - - return !invalid; - } - - public static string GetApplicationCacheServiceBusSubcriptionName(GlobalSettings globalSettings) - { - var subName = globalSettings.ServiceBus.ApplicationCacheSubscriptionName; - if (string.IsNullOrWhiteSpace(subName)) - { - var websiteInstanceId = Environment.GetEnvironmentVariable("WEBSITE_INSTANCE_ID"); - if (string.IsNullOrWhiteSpace(websiteInstanceId)) + else if (absoluteSize >= 0x100000) // 1 Megabyte { - throw new Exception("No service bus subscription name available."); + suffix = "MB"; + readable = (size >> 10); + } + else if (absoluteSize >= 0x400) // 1 Kilobyte + { + suffix = "KB"; + readable = size; } else { - subName = $"{globalSettings.ProjectName.ToLower()}_{websiteInstanceId}"; - if (subName.Length > 50) - { - subName = subName.Substring(0, 50); - } + return size.ToString("0 Bytes"); // Byte + } + + // Divide by 1024 to get fractional value + readable = (readable / 1024); + + // Return formatted number with suffix + return readable.ToString("0.## ") + suffix; + } + + /// + /// Creates a clone of the given object through serializing to json and deserializing. + /// This method is subject to the limitations of System.Text.Json. For example, properties with + /// inaccessible setters will not be set. + /// + public static T CloneObject(T obj) + { + return JsonSerializer.Deserialize(JsonSerializer.Serialize(obj)); + } + + public static bool SettingHasValue(string setting) + { + var normalizedSetting = setting?.ToLowerInvariant(); + return !string.IsNullOrWhiteSpace(normalizedSetting) && !normalizedSetting.Equals("secret") && + !normalizedSetting.Equals("replace"); + } + + public static string Base64EncodeString(string input) + { + return Convert.ToBase64String(Encoding.UTF8.GetBytes(input)); + } + + public static string Base64DecodeString(string input) + { + return Encoding.UTF8.GetString(Convert.FromBase64String(input)); + } + + public static string Base64UrlEncodeString(string input) + { + return Base64UrlEncode(Encoding.UTF8.GetBytes(input)); + } + + public static string Base64UrlDecodeString(string input) + { + return Encoding.UTF8.GetString(Base64UrlDecode(input)); + } + + public static string Base64UrlEncode(byte[] input) + { + var output = Convert.ToBase64String(input) + .Replace('+', '-') + .Replace('/', '_') + .Replace("=", string.Empty); + return output; + } + + public static byte[] Base64UrlDecode(string input) + { + var output = input; + // 62nd char of encoding + output = output.Replace('-', '+'); + // 63rd char of encoding + output = output.Replace('_', '/'); + // Pad with trailing '='s + switch (output.Length % 4) + { + case 0: + // No pad chars in this case + break; + case 2: + // Two pad chars + output += "=="; break; + case 3: + // One pad char + output += "="; break; + default: + throw new InvalidOperationException("Illegal base64url string!"); + } + + // Standard base64 decoder + return Convert.FromBase64String(output); + } + + public static string PunyEncode(string text) + { + if (text == "") + { + return ""; + } + + if (text == null) + { + return null; + } + + if (!text.Contains("@")) + { + // Assume domain name or non-email address + var idn = new IdnMapping(); + return idn.GetAscii(text); + } + else + { + // Assume email address + return MailboxAddress.EncodeAddrspec(text); } } - return subName; - } - public static string GetIpAddress(this Microsoft.AspNetCore.Http.HttpContext httpContext, - GlobalSettings globalSettings) - { - if (httpContext == null) + public static string FormatLicenseSignatureValue(object val) { + if (val == null) + { + return string.Empty; + } + + if (val.GetType() == typeof(DateTime)) + { + return ToEpocSeconds((DateTime)val).ToString(); + } + + if (val.GetType() == typeof(bool)) + { + return val.ToString().ToLowerInvariant(); + } + + if (val is PlanType planType) + { + return planType switch + { + PlanType.Free => "Free", + PlanType.FamiliesAnnually2019 => "FamiliesAnnually", + PlanType.TeamsMonthly2019 => "TeamsMonthly", + PlanType.TeamsAnnually2019 => "TeamsAnnually", + PlanType.EnterpriseMonthly2019 => "EnterpriseMonthly", + PlanType.EnterpriseAnnually2019 => "EnterpriseAnnually", + PlanType.Custom => "Custom", + _ => ((byte)planType).ToString(), + }; + } + + return val.ToString(); + } + + public static string GetVersion() + { + if (string.IsNullOrWhiteSpace(_version)) + { + _version = Assembly.GetEntryAssembly() + .GetCustomAttribute() + .InformationalVersion; + } + + return _version; + } + + public static string SanitizeForEmail(string value, bool htmlEncode = true) + { + var cleanedValue = value.Replace("@", "[at]"); + var regexOptions = RegexOptions.CultureInvariant | + RegexOptions.Singleline | + RegexOptions.IgnoreCase; + cleanedValue = Regex.Replace(cleanedValue, @"(\.\w)", + m => string.Concat("[dot]", m.ToString().Last()), regexOptions); + while (Regex.IsMatch(cleanedValue, @"((^|\b)(\w*)://)", regexOptions)) + { + cleanedValue = Regex.Replace(cleanedValue, @"((^|\b)(\w*)://)", + string.Empty, regexOptions); + } + return htmlEncode ? HttpUtility.HtmlEncode(cleanedValue) : cleanedValue; + } + + public static string DateTimeToTableStorageKey(DateTime? date = null) + { + if (date.HasValue) + { + date = date.Value.ToUniversalTime(); + } + else + { + date = DateTime.UtcNow; + } + + return _max.Subtract(date.Value).TotalMilliseconds.ToString(CultureInfo.InvariantCulture); + } + + // ref: https://stackoverflow.com/a/27545010/1090359 + public static Uri ExtendQuery(Uri uri, IDictionary values) + { + var baseUri = uri.ToString(); + var queryString = string.Empty; + if (baseUri.Contains("?")) + { + var urlSplit = baseUri.Split('?'); + baseUri = urlSplit[0]; + queryString = urlSplit.Length > 1 ? urlSplit[1] : string.Empty; + } + + var queryCollection = HttpUtility.ParseQueryString(queryString); + foreach (var kvp in values ?? new Dictionary()) + { + queryCollection[kvp.Key] = kvp.Value; + } + + var uriKind = uri.IsAbsoluteUri ? UriKind.Absolute : UriKind.Relative; + if (queryCollection.Count == 0) + { + return new Uri(baseUri, uriKind); + } + return new Uri(string.Format("{0}?{1}", baseUri, queryCollection), uriKind); + } + + public static string CustomProviderName(TwoFactorProviderType type) + { + return string.Concat("Custom_", type.ToString()); + } + + public static bool UserInviteTokenIsValid(IDataProtector protector, string token, string userEmail, + Guid orgUserId, IGlobalSettings globalSettings) + { + return TokenIsValid("OrganizationUserInvite", protector, token, userEmail, orgUserId, + globalSettings.OrganizationInviteExpirationHours); + } + + public static bool TokenIsValid(string firstTokenPart, IDataProtector protector, string token, string userEmail, + Guid id, double expirationInHours) + { + var invalid = true; + try + { + var unprotectedData = protector.Unprotect(token); + var dataParts = unprotectedData.Split(' '); + if (dataParts.Length == 4 && dataParts[0] == firstTokenPart && + new Guid(dataParts[1]) == id && + dataParts[2].Equals(userEmail, StringComparison.InvariantCultureIgnoreCase)) + { + var creationTime = FromEpocMilliseconds(Convert.ToInt64(dataParts[3])); + var expTime = creationTime.AddHours(expirationInHours); + invalid = expTime < DateTime.UtcNow; + } + } + catch + { + invalid = true; + } + + return !invalid; + } + + public static string GetApplicationCacheServiceBusSubcriptionName(GlobalSettings globalSettings) + { + var subName = globalSettings.ServiceBus.ApplicationCacheSubscriptionName; + if (string.IsNullOrWhiteSpace(subName)) + { + var websiteInstanceId = Environment.GetEnvironmentVariable("WEBSITE_INSTANCE_ID"); + if (string.IsNullOrWhiteSpace(websiteInstanceId)) + { + throw new Exception("No service bus subscription name available."); + } + else + { + subName = $"{globalSettings.ProjectName.ToLower()}_{websiteInstanceId}"; + if (subName.Length > 50) + { + subName = subName.Substring(0, 50); + } + } + } + return subName; + } + + public static string GetIpAddress(this Microsoft.AspNetCore.Http.HttpContext httpContext, + GlobalSettings globalSettings) + { + if (httpContext == null) + { + return null; + } + + if (!globalSettings.SelfHosted && httpContext.Request.Headers.ContainsKey(CloudFlareConnectingIp)) + { + return httpContext.Request.Headers[CloudFlareConnectingIp].ToString(); + } + if (globalSettings.SelfHosted && httpContext.Request.Headers.ContainsKey(RealIp)) + { + return httpContext.Request.Headers[RealIp].ToString(); + } + + return httpContext.Connection?.RemoteIpAddress?.ToString(); + } + + public static bool IsCorsOriginAllowed(string origin, GlobalSettings globalSettings) + { + return + // Web vault + origin == globalSettings.BaseServiceUri.Vault || + // Safari extension origin + origin == "file://" || + // Product website + (!globalSettings.SelfHosted && origin == "https://bitwarden.com"); + } + + public static X509Certificate2 GetIdentityServerCertificate(GlobalSettings globalSettings) + { + if (globalSettings.SelfHosted && + SettingHasValue(globalSettings.IdentityServer.CertificatePassword) + && File.Exists("identity.pfx")) + { + return GetCertificate("identity.pfx", + globalSettings.IdentityServer.CertificatePassword); + } + else if (SettingHasValue(globalSettings.IdentityServer.CertificateThumbprint)) + { + return GetCertificate( + globalSettings.IdentityServer.CertificateThumbprint); + } + else if (!globalSettings.SelfHosted && + SettingHasValue(globalSettings.Storage?.ConnectionString) && + SettingHasValue(globalSettings.IdentityServer.CertificatePassword)) + { + return GetBlobCertificateAsync(globalSettings.Storage.ConnectionString, "certificates", + "identity.pfx", globalSettings.IdentityServer.CertificatePassword).GetAwaiter().GetResult(); + } return null; } - if (!globalSettings.SelfHosted && httpContext.Request.Headers.ContainsKey(CloudFlareConnectingIp)) + public static Dictionary AdjustIdentityServerConfig(Dictionary configDict, + string publicServiceUri, string internalServiceUri) { - return httpContext.Request.Headers[CloudFlareConnectingIp].ToString(); - } - if (globalSettings.SelfHosted && httpContext.Request.Headers.ContainsKey(RealIp)) - { - return httpContext.Request.Headers[RealIp].ToString(); - } - - return httpContext.Connection?.RemoteIpAddress?.ToString(); - } - - public static bool IsCorsOriginAllowed(string origin, GlobalSettings globalSettings) - { - return - // Web vault - origin == globalSettings.BaseServiceUri.Vault || - // Safari extension origin - origin == "file://" || - // Product website - (!globalSettings.SelfHosted && origin == "https://bitwarden.com"); - } - - public static X509Certificate2 GetIdentityServerCertificate(GlobalSettings globalSettings) - { - if (globalSettings.SelfHosted && - SettingHasValue(globalSettings.IdentityServer.CertificatePassword) - && File.Exists("identity.pfx")) - { - return GetCertificate("identity.pfx", - globalSettings.IdentityServer.CertificatePassword); - } - else if (SettingHasValue(globalSettings.IdentityServer.CertificateThumbprint)) - { - return GetCertificate( - globalSettings.IdentityServer.CertificateThumbprint); - } - else if (!globalSettings.SelfHosted && - SettingHasValue(globalSettings.Storage?.ConnectionString) && - SettingHasValue(globalSettings.IdentityServer.CertificatePassword)) - { - return GetBlobCertificateAsync(globalSettings.Storage.ConnectionString, "certificates", - "identity.pfx", globalSettings.IdentityServer.CertificatePassword).GetAwaiter().GetResult(); - } - return null; - } - - public static Dictionary AdjustIdentityServerConfig(Dictionary configDict, - string publicServiceUri, string internalServiceUri) - { - var dictReplace = new Dictionary(); - foreach (var item in configDict) - { - if (item.Key == "authorization_endpoint" && item.Value is string val) + var dictReplace = new Dictionary(); + foreach (var item in configDict) { - var uri = new Uri(val); - dictReplace.Add(item.Key, string.Concat(publicServiceUri, uri.LocalPath)); - } - else if ((item.Key == "jwks_uri" || item.Key.EndsWith("_endpoint")) && item.Value is string val2) - { - var uri = new Uri(val2); - dictReplace.Add(item.Key, string.Concat(internalServiceUri, uri.LocalPath)); - } - } - foreach (var replace in dictReplace) - { - configDict[replace.Key] = replace.Value; - } - return configDict; - } - - public static List> BuildIdentityClaims(User user, ICollection orgs, - ICollection providers, bool isPremium) - { - var claims = new List>() - { - new KeyValuePair("premium", isPremium ? "true" : "false"), - new KeyValuePair(JwtClaimTypes.Email, user.Email), - new KeyValuePair(JwtClaimTypes.EmailVerified, user.EmailVerified ? "true" : "false"), - new KeyValuePair("sstamp", user.SecurityStamp) - }; - - if (!string.IsNullOrWhiteSpace(user.Name)) - { - claims.Add(new KeyValuePair(JwtClaimTypes.Name, user.Name)); - } - - // Orgs that this user belongs to - if (orgs.Any()) - { - foreach (var group in orgs.GroupBy(o => o.Type)) - { - switch (group.Key) + if (item.Key == "authorization_endpoint" && item.Value is string val) { - case Enums.OrganizationUserType.Owner: - foreach (var org in group) - { - claims.Add(new KeyValuePair("orgowner", org.Id.ToString())); - } - break; - case Enums.OrganizationUserType.Admin: - foreach (var org in group) - { - claims.Add(new KeyValuePair("orgadmin", org.Id.ToString())); - } - break; - case Enums.OrganizationUserType.Manager: - foreach (var org in group) - { - claims.Add(new KeyValuePair("orgmanager", org.Id.ToString())); - } - break; - case Enums.OrganizationUserType.User: - foreach (var org in group) - { - claims.Add(new KeyValuePair("orguser", org.Id.ToString())); - } - break; - case Enums.OrganizationUserType.Custom: - foreach (var org in group) - { - claims.Add(new KeyValuePair("orgcustom", org.Id.ToString())); - foreach (var (permission, claimName) in org.Permissions.ClaimsMap) + var uri = new Uri(val); + dictReplace.Add(item.Key, string.Concat(publicServiceUri, uri.LocalPath)); + } + else if ((item.Key == "jwks_uri" || item.Key.EndsWith("_endpoint")) && item.Value is string val2) + { + var uri = new Uri(val2); + dictReplace.Add(item.Key, string.Concat(internalServiceUri, uri.LocalPath)); + } + } + foreach (var replace in dictReplace) + { + configDict[replace.Key] = replace.Value; + } + return configDict; + } + + public static List> BuildIdentityClaims(User user, ICollection orgs, + ICollection providers, bool isPremium) + { + var claims = new List>() + { + new KeyValuePair("premium", isPremium ? "true" : "false"), + new KeyValuePair(JwtClaimTypes.Email, user.Email), + new KeyValuePair(JwtClaimTypes.EmailVerified, user.EmailVerified ? "true" : "false"), + new KeyValuePair("sstamp", user.SecurityStamp) + }; + + if (!string.IsNullOrWhiteSpace(user.Name)) + { + claims.Add(new KeyValuePair(JwtClaimTypes.Name, user.Name)); + } + + // Orgs that this user belongs to + if (orgs.Any()) + { + foreach (var group in orgs.GroupBy(o => o.Type)) + { + switch (group.Key) + { + case Enums.OrganizationUserType.Owner: + foreach (var org in group) { - if (!permission) - { - continue; - } - - claims.Add(new KeyValuePair(claimName, org.Id.ToString())); + claims.Add(new KeyValuePair("orgowner", org.Id.ToString())); } - } - break; - default: - break; + break; + case Enums.OrganizationUserType.Admin: + foreach (var org in group) + { + claims.Add(new KeyValuePair("orgadmin", org.Id.ToString())); + } + break; + case Enums.OrganizationUserType.Manager: + foreach (var org in group) + { + claims.Add(new KeyValuePair("orgmanager", org.Id.ToString())); + } + break; + case Enums.OrganizationUserType.User: + foreach (var org in group) + { + claims.Add(new KeyValuePair("orguser", org.Id.ToString())); + } + break; + case Enums.OrganizationUserType.Custom: + foreach (var org in group) + { + claims.Add(new KeyValuePair("orgcustom", org.Id.ToString())); + foreach (var (permission, claimName) in org.Permissions.ClaimsMap) + { + if (!permission) + { + continue; + } + + claims.Add(new KeyValuePair(claimName, org.Id.ToString())); + } + } + break; + default: + break; + } } } - } - if (providers.Any()) - { - foreach (var group in providers.GroupBy(o => o.Type)) + if (providers.Any()) { - switch (group.Key) + foreach (var group in providers.GroupBy(o => o.Type)) { - case ProviderUserType.ProviderAdmin: - foreach (var provider in group) - { - claims.Add(new KeyValuePair("providerprovideradmin", provider.Id.ToString())); - } - break; - case ProviderUserType.ServiceUser: - foreach (var provider in group) - { - claims.Add(new KeyValuePair("providerserviceuser", provider.Id.ToString())); - } - break; + switch (group.Key) + { + case ProviderUserType.ProviderAdmin: + foreach (var provider in group) + { + claims.Add(new KeyValuePair("providerprovideradmin", provider.Id.ToString())); + } + break; + case ProviderUserType.ServiceUser: + foreach (var provider in group) + { + claims.Add(new KeyValuePair("providerserviceuser", provider.Id.ToString())); + } + break; + } } } + + return claims; } - return claims; - } - - public static T LoadClassFromJsonData(string jsonData) where T : new() - { - if (string.IsNullOrWhiteSpace(jsonData)) + public static T LoadClassFromJsonData(string jsonData) where T : new() { - return new T(); + if (string.IsNullOrWhiteSpace(jsonData)) + { + return new T(); + } + + var options = new JsonSerializerOptions + { + PropertyNamingPolicy = JsonNamingPolicy.CamelCase, + }; + + return System.Text.Json.JsonSerializer.Deserialize(jsonData, options); } - var options = new JsonSerializerOptions + public static string ClassToJsonData(T data) { - PropertyNamingPolicy = JsonNamingPolicy.CamelCase, - }; + var options = new JsonSerializerOptions + { + PropertyNamingPolicy = JsonNamingPolicy.CamelCase, + }; - return System.Text.Json.JsonSerializer.Deserialize(jsonData, options); - } + return System.Text.Json.JsonSerializer.Serialize(data, options); + } - public static string ClassToJsonData(T data) - { - var options = new JsonSerializerOptions - { - PropertyNamingPolicy = JsonNamingPolicy.CamelCase, - }; - - return System.Text.Json.JsonSerializer.Serialize(data, options); - } - - public static ICollection AddIfNotExists(this ICollection list, T item) - { - if (list.Contains(item)) + public static ICollection AddIfNotExists(this ICollection list, T item) { + if (list.Contains(item)) + { + return list; + } + list.Add(item); return list; } - list.Add(item); - return list; - } - public static string DecodeMessageText(this QueueMessage message) - { - var text = message?.MessageText; - if (string.IsNullOrWhiteSpace(text)) + public static string DecodeMessageText(this QueueMessage message) { - return text; - } - try - { - return Base64DecodeString(text); - } - catch - { - return text; - } - } - - public static bool FixedTimeEquals(string input1, string input2) - { - return CryptographicOperations.FixedTimeEquals( - Encoding.UTF8.GetBytes(input1), Encoding.UTF8.GetBytes(input2)); - } - - public static string ObfuscateEmail(string email) - { - if (email == null) - { - return email; + var text = message?.MessageText; + if (string.IsNullOrWhiteSpace(text)) + { + return text; + } + try + { + return Base64DecodeString(text); + } + catch + { + return text; + } } - var emailParts = email.Split('@', StringSplitOptions.RemoveEmptyEntries); - - if (emailParts.Length != 2) + public static bool FixedTimeEquals(string input1, string input2) { - return email; + return CryptographicOperations.FixedTimeEquals( + Encoding.UTF8.GetBytes(input1), Encoding.UTF8.GetBytes(input2)); } - var username = emailParts[0]; - - if (username.Length < 2) + public static string ObfuscateEmail(string email) { - return email; + if (email == null) + { + return email; + } + + var emailParts = email.Split('@', StringSplitOptions.RemoveEmptyEntries); + + if (emailParts.Length != 2) + { + return email; + } + + var username = emailParts[0]; + + if (username.Length < 2) + { + return email; + } + + var sb = new StringBuilder(); + sb.Append(emailParts[0][..2]); + for (var i = 2; i < emailParts[0].Length; i++) + { + sb.Append('*'); + } + + return sb.Append('@') + .Append(emailParts[1]) + .ToString(); + } - - var sb = new StringBuilder(); - sb.Append(emailParts[0][..2]); - for (var i = 2; i < emailParts[0].Length; i++) - { - sb.Append('*'); - } - - return sb.Append('@') - .Append(emailParts[1]) - .ToString(); - } } diff --git a/src/Core/Utilities/CurrentContextMiddleware.cs b/src/Core/Utilities/CurrentContextMiddleware.cs index c1ac9322c2..bfba894dda 100644 --- a/src/Core/Utilities/CurrentContextMiddleware.cs +++ b/src/Core/Utilities/CurrentContextMiddleware.cs @@ -2,20 +2,21 @@ using Bit.Core.Settings; using Microsoft.AspNetCore.Http; -namespace Bit.Core.Utilities; - -public class CurrentContextMiddleware +namespace Bit.Core.Utilities { - private readonly RequestDelegate _next; - - public CurrentContextMiddleware(RequestDelegate next) + public class CurrentContextMiddleware { - _next = next; - } + private readonly RequestDelegate _next; - public async Task Invoke(HttpContext httpContext, ICurrentContext currentContext, GlobalSettings globalSettings) - { - await currentContext.BuildAsync(httpContext, globalSettings); - await _next.Invoke(httpContext); + public CurrentContextMiddleware(RequestDelegate next) + { + _next = next; + } + + public async Task Invoke(HttpContext httpContext, ICurrentContext currentContext, GlobalSettings globalSettings) + { + await currentContext.BuildAsync(httpContext, globalSettings); + await _next.Invoke(httpContext); + } } } diff --git a/src/Core/Utilities/CustomIpRateLimitMiddleware.cs b/src/Core/Utilities/CustomIpRateLimitMiddleware.cs index 5fb82cac02..529495e093 100644 --- a/src/Core/Utilities/CustomIpRateLimitMiddleware.cs +++ b/src/Core/Utilities/CustomIpRateLimitMiddleware.cs @@ -6,85 +6,86 @@ using Microsoft.Extensions.Caching.Distributed; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; -namespace Bit.Core.Utilities; - -public class CustomIpRateLimitMiddleware : IpRateLimitMiddleware +namespace Bit.Core.Utilities { - private readonly IBlockIpService _blockIpService; - private readonly ILogger _logger; - private readonly IDistributedCache _distributedCache; - private readonly IpRateLimitOptions _options; - - public CustomIpRateLimitMiddleware( - IDistributedCache distributedCache, - IBlockIpService blockIpService, - RequestDelegate next, - IProcessingStrategy processingStrategy, - IRateLimitConfiguration rateLimitConfiguration, - IOptions options, - IIpPolicyStore policyStore, - ILogger logger) - : base(next, processingStrategy, options, policyStore, rateLimitConfiguration, logger) + public class CustomIpRateLimitMiddleware : IpRateLimitMiddleware { - _distributedCache = distributedCache; - _blockIpService = blockIpService; - _options = options.Value; - _logger = logger; - } + private readonly IBlockIpService _blockIpService; + private readonly ILogger _logger; + private readonly IDistributedCache _distributedCache; + private readonly IpRateLimitOptions _options; - public override Task ReturnQuotaExceededResponse(HttpContext httpContext, RateLimitRule rule, string retryAfter) - { - var message = string.IsNullOrWhiteSpace(_options.QuotaExceededMessage) - ? $"Slow down! Too many requests. Try again in {rule.Period}." - : _options.QuotaExceededMessage; - httpContext.Response.Headers["Retry-After"] = retryAfter; - httpContext.Response.StatusCode = _options.HttpStatusCode; - var errorModel = new ErrorResponseModel { Message = message }; - return httpContext.Response.WriteAsJsonAsync(errorModel, httpContext.RequestAborted); - } - - protected override void LogBlockedRequest(HttpContext httpContext, ClientRequestIdentity identity, - RateLimitCounter counter, RateLimitRule rule) - { - base.LogBlockedRequest(httpContext, identity, counter, rule); - var key = $"blockedIp_{identity.ClientIp}"; - - _distributedCache.TryGetValue(key, out int blockedCount); - - blockedCount++; - if (blockedCount > 10) + public CustomIpRateLimitMiddleware( + IDistributedCache distributedCache, + IBlockIpService blockIpService, + RequestDelegate next, + IProcessingStrategy processingStrategy, + IRateLimitConfiguration rateLimitConfiguration, + IOptions options, + IIpPolicyStore policyStore, + ILogger logger) + : base(next, processingStrategy, options, policyStore, rateLimitConfiguration, logger) { - _blockIpService.BlockIpAsync(identity.ClientIp, false); - _logger.LogInformation(Constants.BypassFiltersEventId, null, - "Banned {0}. \nInfo: \n{1}", identity.ClientIp, GetRequestInfo(httpContext)); - } - else - { - _logger.LogInformation(Constants.BypassFiltersEventId, null, - "Request blocked {0}. \nInfo: \n{1}", identity.ClientIp, GetRequestInfo(httpContext)); - _distributedCache.Set(key, blockedCount, - new DistributedCacheEntryOptions().SetSlidingExpiration(new TimeSpan(0, 5, 0))); - } - } - - private string GetRequestInfo(HttpContext httpContext) - { - if (httpContext == null || httpContext.Request == null) - { - return null; + _distributedCache = distributedCache; + _blockIpService = blockIpService; + _options = options.Value; + _logger = logger; } - var s = string.Empty; - foreach (var header in httpContext.Request.Headers) + public override Task ReturnQuotaExceededResponse(HttpContext httpContext, RateLimitRule rule, string retryAfter) { - s += $"Header \"{header.Key}\": {header.Value} \n"; + var message = string.IsNullOrWhiteSpace(_options.QuotaExceededMessage) + ? $"Slow down! Too many requests. Try again in {rule.Period}." + : _options.QuotaExceededMessage; + httpContext.Response.Headers["Retry-After"] = retryAfter; + httpContext.Response.StatusCode = _options.HttpStatusCode; + var errorModel = new ErrorResponseModel { Message = message }; + return httpContext.Response.WriteAsJsonAsync(errorModel, httpContext.RequestAborted); } - foreach (var query in httpContext.Request.Query) + protected override void LogBlockedRequest(HttpContext httpContext, ClientRequestIdentity identity, + RateLimitCounter counter, RateLimitRule rule) { - s += $"Query \"{query.Key}\": {query.Value} \n"; + base.LogBlockedRequest(httpContext, identity, counter, rule); + var key = $"blockedIp_{identity.ClientIp}"; + + _distributedCache.TryGetValue(key, out int blockedCount); + + blockedCount++; + if (blockedCount > 10) + { + _blockIpService.BlockIpAsync(identity.ClientIp, false); + _logger.LogInformation(Constants.BypassFiltersEventId, null, + "Banned {0}. \nInfo: \n{1}", identity.ClientIp, GetRequestInfo(httpContext)); + } + else + { + _logger.LogInformation(Constants.BypassFiltersEventId, null, + "Request blocked {0}. \nInfo: \n{1}", identity.ClientIp, GetRequestInfo(httpContext)); + _distributedCache.Set(key, blockedCount, + new DistributedCacheEntryOptions().SetSlidingExpiration(new TimeSpan(0, 5, 0))); + } } - return s; + private string GetRequestInfo(HttpContext httpContext) + { + if (httpContext == null || httpContext.Request == null) + { + return null; + } + + var s = string.Empty; + foreach (var header in httpContext.Request.Headers) + { + s += $"Header \"{header.Key}\": {header.Value} \n"; + } + + foreach (var query in httpContext.Request.Query) + { + s += $"Query \"{query.Key}\": {query.Value} \n"; + } + + return s; + } } } diff --git a/src/Core/Utilities/DistributedCacheExtensions.cs b/src/Core/Utilities/DistributedCacheExtensions.cs index 28282b6a47..d27d0ee469 100644 --- a/src/Core/Utilities/DistributedCacheExtensions.cs +++ b/src/Core/Utilities/DistributedCacheExtensions.cs @@ -1,47 +1,48 @@ using System.Text.Json; using Microsoft.Extensions.Caching.Distributed; -namespace Bit.Core.Utilities; - -public static class DistributedCacheExtensions +namespace Bit.Core.Utilities { - public static void Set(this IDistributedCache cache, string key, T value) + public static class DistributedCacheExtensions { - Set(cache, key, value, new DistributedCacheEntryOptions()); - } - - public static void Set(this IDistributedCache cache, string key, T value, - DistributedCacheEntryOptions options) - { - var bytes = JsonSerializer.SerializeToUtf8Bytes(value); - cache.Set(key, bytes, options); - } - - public static Task SetAsync(this IDistributedCache cache, string key, T value) - { - return SetAsync(cache, key, value, new DistributedCacheEntryOptions()); - } - - public static Task SetAsync(this IDistributedCache cache, string key, T value, - DistributedCacheEntryOptions options) - { - var bytes = JsonSerializer.SerializeToUtf8Bytes(value); - return cache.SetAsync(key, bytes, options); - } - - public static bool TryGetValue(this IDistributedCache cache, string key, out T value) - { - var val = cache.Get(key); - value = default; - if (val == null) return false; - try + public static void Set(this IDistributedCache cache, string key, T value) { - value = JsonSerializer.Deserialize(val); + Set(cache, key, value, new DistributedCacheEntryOptions()); } - catch + + public static void Set(this IDistributedCache cache, string key, T value, + DistributedCacheEntryOptions options) { - return false; + var bytes = JsonSerializer.SerializeToUtf8Bytes(value); + cache.Set(key, bytes, options); + } + + public static Task SetAsync(this IDistributedCache cache, string key, T value) + { + return SetAsync(cache, key, value, new DistributedCacheEntryOptions()); + } + + public static Task SetAsync(this IDistributedCache cache, string key, T value, + DistributedCacheEntryOptions options) + { + var bytes = JsonSerializer.SerializeToUtf8Bytes(value); + return cache.SetAsync(key, bytes, options); + } + + public static bool TryGetValue(this IDistributedCache cache, string key, out T value) + { + var val = cache.Get(key); + value = default; + if (val == null) return false; + try + { + value = JsonSerializer.Deserialize(val); + } + catch + { + return false; + } + return true; } - return true; } } diff --git a/src/Core/Utilities/DuoApi.cs b/src/Core/Utilities/DuoApi.cs index b5a3f040d4..c98c98938f 100644 --- a/src/Core/Utilities/DuoApi.cs +++ b/src/Core/Utilities/DuoApi.cs @@ -16,293 +16,294 @@ using System.Text.Json; using System.Text.RegularExpressions; using System.Web; -namespace Bit.Core.Utilities.Duo; - -public class DuoApi +namespace Bit.Core.Utilities.Duo { - private const string UrlScheme = "https"; - private const string UserAgent = "Bitwarden_DuoAPICSharp/1.0 (.NET Core)"; - - private readonly string _host; - private readonly string _ikey; - private readonly string _skey; - - public DuoApi(string ikey, string skey, string host) + public class DuoApi { - _ikey = ikey; - _skey = skey; - _host = host; + private const string UrlScheme = "https"; + private const string UserAgent = "Bitwarden_DuoAPICSharp/1.0 (.NET Core)"; - if (!ValidHost(host)) + private readonly string _host; + private readonly string _ikey; + private readonly string _skey; + + public DuoApi(string ikey, string skey, string host) { - throw new DuoException("Invalid Duo host configured.", new ArgumentException(nameof(host))); - } - } + _ikey = ikey; + _skey = skey; + _host = host; - public static bool ValidHost(string host) - { - if (Uri.TryCreate($"https://{host}", UriKind.Absolute, out var uri)) - { - return (string.IsNullOrWhiteSpace(uri.PathAndQuery) || uri.PathAndQuery == "/") && - uri.Host.StartsWith("api-") && - (uri.Host.EndsWith(".duosecurity.com") || uri.Host.EndsWith(".duofederal.com")); - } - return false; - } - - public static string CanonicalizeParams(Dictionary parameters) - { - var ret = new List(); - foreach (var pair in parameters) - { - var p = string.Format("{0}={1}", HttpUtility.UrlEncode(pair.Key), HttpUtility.UrlEncode(pair.Value)); - // Signatures require upper-case hex digits. - p = Regex.Replace(p, "(%[0-9A-Fa-f][0-9A-Fa-f])", c => c.Value.ToUpperInvariant()); - // Escape only the expected characters. - p = Regex.Replace(p, "([!'()*])", c => "%" + Convert.ToByte(c.Value[0]).ToString("X")); - p = p.Replace("%7E", "~"); - // UrlEncode converts space (" ") to "+". The - // signature algorithm requires "%20" instead. Actual - // + has already been replaced with %2B. - p = p.Replace("+", "%20"); - ret.Add(p); - } - - ret.Sort(StringComparer.Ordinal); - return string.Join("&", ret.ToArray()); - } - - protected string CanonicalizeRequest(string method, string path, string canonParams, string date) - { - string[] lines = { - date, - method.ToUpperInvariant(), - _host.ToLower(), - path, - canonParams, - }; - return string.Join("\n", lines); - } - - public string Sign(string method, string path, string canonParams, string date) - { - var canon = CanonicalizeRequest(method, path, canonParams, date); - var sig = HmacSign(canon); - var auth = string.Concat(_ikey, ':', sig); - return string.Concat("Basic ", Encode64(auth)); - } - - public string ApiCall(string method, string path, Dictionary parameters = null) - { - return ApiCall(method, path, parameters, 0, out var statusCode); - } - - /// The request timeout, in milliseconds. - /// Specify 0 to use the system-default timeout. Use caution if - /// you choose to specify a custom timeout - some API - /// calls (particularly in the Auth APIs) will not - /// return a response until an out-of-band authentication process - /// has completed. In some cases, this may take as much as a - /// small number of minutes. - public string ApiCall(string method, string path, Dictionary parameters, int timeout, - out HttpStatusCode statusCode) - { - if (parameters == null) - { - parameters = new Dictionary(); - } - - var canonParams = CanonicalizeParams(parameters); - var query = string.Empty; - if (!method.Equals("POST") && !method.Equals("PUT")) - { - if (parameters.Count > 0) + if (!ValidHost(host)) { - query = "?" + canonParams; + throw new DuoException("Invalid Duo host configured.", new ArgumentException(nameof(host))); } } - var url = string.Format("{0}://{1}{2}{3}", UrlScheme, _host, path, query); - var dateString = RFC822UtcNow(); - var auth = Sign(method, path, canonParams, dateString); - - var request = (HttpWebRequest)WebRequest.Create(url); - request.Method = method; - request.Accept = "application/json"; - request.Headers.Add("Authorization", auth); - request.Headers.Add("X-Duo-Date", dateString); - request.UserAgent = UserAgent; - - if (method.Equals("POST") || method.Equals("PUT")) + public static bool ValidHost(string host) { - var data = Encoding.UTF8.GetBytes(canonParams); - request.ContentType = "application/x-www-form-urlencoded"; - request.ContentLength = data.Length; - using (var requestStream = request.GetRequestStream()) + if (Uri.TryCreate($"https://{host}", UriKind.Absolute, out var uri)) { - requestStream.Write(data, 0, data.Length); + return (string.IsNullOrWhiteSpace(uri.PathAndQuery) || uri.PathAndQuery == "/") && + uri.Host.StartsWith("api-") && + (uri.Host.EndsWith(".duosecurity.com") || uri.Host.EndsWith(".duofederal.com")); + } + return false; + } + + public static string CanonicalizeParams(Dictionary parameters) + { + var ret = new List(); + foreach (var pair in parameters) + { + var p = string.Format("{0}={1}", HttpUtility.UrlEncode(pair.Key), HttpUtility.UrlEncode(pair.Value)); + // Signatures require upper-case hex digits. + p = Regex.Replace(p, "(%[0-9A-Fa-f][0-9A-Fa-f])", c => c.Value.ToUpperInvariant()); + // Escape only the expected characters. + p = Regex.Replace(p, "([!'()*])", c => "%" + Convert.ToByte(c.Value[0]).ToString("X")); + p = p.Replace("%7E", "~"); + // UrlEncode converts space (" ") to "+". The + // signature algorithm requires "%20" instead. Actual + // + has already been replaced with %2B. + p = p.Replace("+", "%20"); + ret.Add(p); + } + + ret.Sort(StringComparer.Ordinal); + return string.Join("&", ret.ToArray()); + } + + protected string CanonicalizeRequest(string method, string path, string canonParams, string date) + { + string[] lines = { + date, + method.ToUpperInvariant(), + _host.ToLower(), + path, + canonParams, + }; + return string.Join("\n", lines); + } + + public string Sign(string method, string path, string canonParams, string date) + { + var canon = CanonicalizeRequest(method, path, canonParams, date); + var sig = HmacSign(canon); + var auth = string.Concat(_ikey, ':', sig); + return string.Concat("Basic ", Encode64(auth)); + } + + public string ApiCall(string method, string path, Dictionary parameters = null) + { + return ApiCall(method, path, parameters, 0, out var statusCode); + } + + /// The request timeout, in milliseconds. + /// Specify 0 to use the system-default timeout. Use caution if + /// you choose to specify a custom timeout - some API + /// calls (particularly in the Auth APIs) will not + /// return a response until an out-of-band authentication process + /// has completed. In some cases, this may take as much as a + /// small number of minutes. + public string ApiCall(string method, string path, Dictionary parameters, int timeout, + out HttpStatusCode statusCode) + { + if (parameters == null) + { + parameters = new Dictionary(); + } + + var canonParams = CanonicalizeParams(parameters); + var query = string.Empty; + if (!method.Equals("POST") && !method.Equals("PUT")) + { + if (parameters.Count > 0) + { + query = "?" + canonParams; + } + } + var url = string.Format("{0}://{1}{2}{3}", UrlScheme, _host, path, query); + + var dateString = RFC822UtcNow(); + var auth = Sign(method, path, canonParams, dateString); + + var request = (HttpWebRequest)WebRequest.Create(url); + request.Method = method; + request.Accept = "application/json"; + request.Headers.Add("Authorization", auth); + request.Headers.Add("X-Duo-Date", dateString); + request.UserAgent = UserAgent; + + if (method.Equals("POST") || method.Equals("PUT")) + { + var data = Encoding.UTF8.GetBytes(canonParams); + request.ContentType = "application/x-www-form-urlencoded"; + request.ContentLength = data.Length; + using (var requestStream = request.GetRequestStream()) + { + requestStream.Write(data, 0, data.Length); + } + } + if (timeout > 0) + { + request.Timeout = timeout; + } + + // Do the request and process the result. + HttpWebResponse response; + try + { + response = (HttpWebResponse)request.GetResponse(); + } + catch (WebException ex) + { + response = (HttpWebResponse)ex.Response; + if (response == null) + { + throw; + } + } + using (var reader = new StreamReader(response.GetResponseStream())) + { + statusCode = response.StatusCode; + return reader.ReadToEnd(); } } - if (timeout > 0) + + public T JSONApiCall(string method, string path, Dictionary parameters = null) + where T : class { - request.Timeout = timeout; + return JSONApiCall(method, path, parameters, 0); } - // Do the request and process the result. - HttpWebResponse response; - try + /// The request timeout, in milliseconds. + /// Specify 0 to use the system-default timeout. Use caution if + /// you choose to specify a custom timeout - some API + /// calls (particularly in the Auth APIs) will not + /// return a response until an out-of-band authentication process + /// has completed. In some cases, this may take as much as a + /// small number of minutes. + public T JSONApiCall(string method, string path, Dictionary parameters, int timeout) + where T : class { - response = (HttpWebResponse)request.GetResponse(); - } - catch (WebException ex) - { - response = (HttpWebResponse)ex.Response; - if (response == null) + var res = ApiCall(method, path, parameters, timeout, out var statusCode); + try + { + // TODO: We should deserialize this into our own DTO and not work on dictionaries. + var dict = JsonSerializer.Deserialize>(res); + if (dict["stat"].ToString() == "OK") + { + return JsonSerializer.Deserialize(dict["response"].ToString()); + } + + var check = ToNullableInt(dict["code"].ToString()); + var code = check.GetValueOrDefault(0); + var messageDetail = string.Empty; + if (dict.ContainsKey("message_detail")) + { + messageDetail = dict["message_detail"].ToString(); + } + throw new ApiException(code, (int)statusCode, dict["message"].ToString(), messageDetail); + } + catch (ApiException) { throw; } - } - using (var reader = new StreamReader(response.GetResponseStream())) - { - statusCode = response.StatusCode; - return reader.ReadToEnd(); - } - } - - public T JSONApiCall(string method, string path, Dictionary parameters = null) - where T : class - { - return JSONApiCall(method, path, parameters, 0); - } - - /// The request timeout, in milliseconds. - /// Specify 0 to use the system-default timeout. Use caution if - /// you choose to specify a custom timeout - some API - /// calls (particularly in the Auth APIs) will not - /// return a response until an out-of-band authentication process - /// has completed. In some cases, this may take as much as a - /// small number of minutes. - public T JSONApiCall(string method, string path, Dictionary parameters, int timeout) - where T : class - { - var res = ApiCall(method, path, parameters, timeout, out var statusCode); - try - { - // TODO: We should deserialize this into our own DTO and not work on dictionaries. - var dict = JsonSerializer.Deserialize>(res); - if (dict["stat"].ToString() == "OK") + catch (Exception e) { - return JsonSerializer.Deserialize(dict["response"].ToString()); + throw new BadResponseException((int)statusCode, e); } + } - var check = ToNullableInt(dict["code"].ToString()); - var code = check.GetValueOrDefault(0); - var messageDetail = string.Empty; - if (dict.ContainsKey("message_detail")) + private int? ToNullableInt(string s) + { + int i; + if (int.TryParse(s, out i)) { - messageDetail = dict["message_detail"].ToString(); + return i; } - throw new ApiException(code, (int)statusCode, dict["message"].ToString(), messageDetail); + return null; } - catch (ApiException) + + private string HmacSign(string data) { - throw; + var keyBytes = Encoding.ASCII.GetBytes(_skey); + var dataBytes = Encoding.ASCII.GetBytes(data); + + using (var hmac = new HMACSHA1(keyBytes)) + { + var hash = hmac.ComputeHash(dataBytes); + var hex = BitConverter.ToString(hash); + return hex.Replace("-", string.Empty).ToLower(); + } } - catch (Exception e) + + private static string Encode64(string plaintext) { - throw new BadResponseException((int)statusCode, e); + var plaintextBytes = Encoding.ASCII.GetBytes(plaintext); + return Convert.ToBase64String(plaintextBytes); + } + + private static string RFC822UtcNow() + { + // Can't use the "zzzz" format because it adds a ":" + // between the offset's hours and minutes. + var dateString = DateTime.UtcNow.ToString("ddd, dd MMM yyyy HH:mm:ss", CultureInfo.InvariantCulture); + var offset = 0; + var zone = "+" + offset.ToString(CultureInfo.InvariantCulture).PadLeft(2, '0'); + dateString += " " + zone.PadRight(5, '0'); + return dateString; } } - private int? ToNullableInt(string s) + public class DuoException : Exception { - int i; - if (int.TryParse(s, out i)) - { - return i; - } - return null; - } + public int HttpStatus { get; private set; } - private string HmacSign(string data) - { - var keyBytes = Encoding.ASCII.GetBytes(_skey); - var dataBytes = Encoding.ASCII.GetBytes(data); + public DuoException(string message, Exception inner) + : base(message, inner) + { } - using (var hmac = new HMACSHA1(keyBytes)) + public DuoException(int httpStatus, string message, Exception inner) + : base(message, inner) { - var hash = hmac.ComputeHash(dataBytes); - var hex = BitConverter.ToString(hash); - return hex.Replace("-", string.Empty).ToLower(); + HttpStatus = httpStatus; } } - private static string Encode64(string plaintext) + public class ApiException : DuoException { - var plaintextBytes = Encoding.ASCII.GetBytes(plaintext); - return Convert.ToBase64String(plaintextBytes); + public int Code { get; private set; } + public string ApiMessage { get; private set; } + public string ApiMessageDetail { get; private set; } + + public ApiException(int code, int httpStatus, string apiMessage, string apiMessageDetail) + : base(httpStatus, FormatMessage(code, apiMessage, apiMessageDetail), null) + { + Code = code; + ApiMessage = apiMessage; + ApiMessageDetail = apiMessageDetail; + } + + private static string FormatMessage(int code, string apiMessage, string apiMessageDetail) + { + return string.Format("Duo API Error {0}: '{1}' ('{2}')", code, apiMessage, apiMessageDetail); + } } - private static string RFC822UtcNow() + public class BadResponseException : DuoException { - // Can't use the "zzzz" format because it adds a ":" - // between the offset's hours and minutes. - var dateString = DateTime.UtcNow.ToString("ddd, dd MMM yyyy HH:mm:ss", CultureInfo.InvariantCulture); - var offset = 0; - var zone = "+" + offset.ToString(CultureInfo.InvariantCulture).PadLeft(2, '0'); - dateString += " " + zone.PadRight(5, '0'); - return dateString; - } -} - -public class DuoException : Exception -{ - public int HttpStatus { get; private set; } - - public DuoException(string message, Exception inner) - : base(message, inner) - { } - - public DuoException(int httpStatus, string message, Exception inner) - : base(message, inner) - { - HttpStatus = httpStatus; - } -} - -public class ApiException : DuoException -{ - public int Code { get; private set; } - public string ApiMessage { get; private set; } - public string ApiMessageDetail { get; private set; } - - public ApiException(int code, int httpStatus, string apiMessage, string apiMessageDetail) - : base(httpStatus, FormatMessage(code, apiMessage, apiMessageDetail), null) - { - Code = code; - ApiMessage = apiMessage; - ApiMessageDetail = apiMessageDetail; - } - - private static string FormatMessage(int code, string apiMessage, string apiMessageDetail) - { - return string.Format("Duo API Error {0}: '{1}' ('{2}')", code, apiMessage, apiMessageDetail); - } -} - -public class BadResponseException : DuoException -{ - public BadResponseException(int httpStatus, Exception inner) - : base(httpStatus, FormatMessage(httpStatus, inner), inner) - { } - - private static string FormatMessage(int httpStatus, Exception inner) - { - var innerMessage = "(null)"; - if (inner != null) - { - innerMessage = string.Format("'{0}'", inner.Message); - } - return string.Format("Got error {0} with HTTP Status {1}", innerMessage, httpStatus); + public BadResponseException(int httpStatus, Exception inner) + : base(httpStatus, FormatMessage(httpStatus, inner), inner) + { } + + private static string FormatMessage(int httpStatus, Exception inner) + { + var innerMessage = "(null)"; + if (inner != null) + { + innerMessage = string.Format("'{0}'", inner.Message); + } + return string.Format("Got error {0} with HTTP Status {1}", innerMessage, httpStatus); + } } } diff --git a/src/Core/Utilities/DuoWeb.cs b/src/Core/Utilities/DuoWeb.cs index 151f71a15f..f8259d200e 100644 --- a/src/Core/Utilities/DuoWeb.cs +++ b/src/Core/Utilities/DuoWeb.cs @@ -36,205 +36,206 @@ THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. using System.Security.Cryptography; using System.Text; -namespace Bit.Core.Utilities.Duo; - -public static class DuoWeb +namespace Bit.Core.Utilities.Duo { - private const string DuoProfix = "TX"; - private const string AppPrefix = "APP"; - private const string AuthPrefix = "AUTH"; - private const int DuoExpire = 300; - private const int AppExpire = 3600; - private const int IKeyLength = 20; - private const int SKeyLength = 40; - private const int AKeyLength = 40; - - public static string ErrorUser = "ERR|The username passed to sign_request() is invalid."; - public static string ErrorIKey = "ERR|The Duo integration key passed to sign_request() is invalid."; - public static string ErrorSKey = "ERR|The Duo secret key passed to sign_request() is invalid."; - public static string ErrorAKey = "ERR|The application secret key passed to sign_request() must be at least " + - "40 characters."; - public static string ErrorUnknown = "ERR|An unknown error has occurred."; - - // throw on invalid bytes - private static Encoding _encoding = new UTF8Encoding(false, true); - private static DateTime _epoc = new DateTime(1970, 1, 1); - - /// - /// Generate a signed request for Duo authentication. - /// The returned value should be passed into the Duo.init() call - /// in the rendered web page used for Duo authentication. - /// - /// Duo integration key - /// Duo secret key - /// Application secret key - /// Primary-authenticated username - /// (optional) The current UTC time - /// signed request - public static string SignRequest(string ikey, string skey, string akey, string username, - DateTime? currentTime = null) + public static class DuoWeb { - string duoSig; - string appSig; + private const string DuoProfix = "TX"; + private const string AppPrefix = "APP"; + private const string AuthPrefix = "AUTH"; + private const int DuoExpire = 300; + private const int AppExpire = 3600; + private const int IKeyLength = 20; + private const int SKeyLength = 40; + private const int AKeyLength = 40; - var currentTimeValue = currentTime ?? DateTime.UtcNow; + public static string ErrorUser = "ERR|The username passed to sign_request() is invalid."; + public static string ErrorIKey = "ERR|The Duo integration key passed to sign_request() is invalid."; + public static string ErrorSKey = "ERR|The Duo secret key passed to sign_request() is invalid."; + public static string ErrorAKey = "ERR|The application secret key passed to sign_request() must be at least " + + "40 characters."; + public static string ErrorUnknown = "ERR|An unknown error has occurred."; - if (username == string.Empty) + // throw on invalid bytes + private static Encoding _encoding = new UTF8Encoding(false, true); + private static DateTime _epoc = new DateTime(1970, 1, 1); + + /// + /// Generate a signed request for Duo authentication. + /// The returned value should be passed into the Duo.init() call + /// in the rendered web page used for Duo authentication. + /// + /// Duo integration key + /// Duo secret key + /// Application secret key + /// Primary-authenticated username + /// (optional) The current UTC time + /// signed request + public static string SignRequest(string ikey, string skey, string akey, string username, + DateTime? currentTime = null) { - return ErrorUser; - } - if (username.Contains("|")) - { - return ErrorUser; - } - if (ikey.Length != IKeyLength) - { - return ErrorIKey; - } - if (skey.Length != SKeyLength) - { - return ErrorSKey; - } - if (akey.Length < AKeyLength) - { - return ErrorAKey; + string duoSig; + string appSig; + + var currentTimeValue = currentTime ?? DateTime.UtcNow; + + if (username == string.Empty) + { + return ErrorUser; + } + if (username.Contains("|")) + { + return ErrorUser; + } + if (ikey.Length != IKeyLength) + { + return ErrorIKey; + } + if (skey.Length != SKeyLength) + { + return ErrorSKey; + } + if (akey.Length < AKeyLength) + { + return ErrorAKey; + } + + try + { + duoSig = SignVals(skey, username, ikey, DuoProfix, DuoExpire, currentTimeValue); + appSig = SignVals(akey, username, ikey, AppPrefix, AppExpire, currentTimeValue); + } + catch + { + return ErrorUnknown; + } + + return $"{duoSig}:{appSig}"; } - try + /// + /// Validate the signed response returned from Duo. + /// Returns the username of the authenticated user, or null. + /// + /// Duo integration key + /// Duo secret key + /// Application secret key + /// The signed response POST'ed to the server + /// (optional) The current UTC time + /// authenticated username, or null + public static string VerifyResponse(string ikey, string skey, string akey, string sigResponse, + DateTime? currentTime = null) { - duoSig = SignVals(skey, username, ikey, DuoProfix, DuoExpire, currentTimeValue); - appSig = SignVals(akey, username, ikey, AppPrefix, AppExpire, currentTimeValue); - } - catch - { - return ErrorUnknown; + string authUser = null; + string appUser = null; + var currentTimeValue = currentTime ?? DateTime.UtcNow; + + try + { + var sigs = sigResponse.Split(':'); + var authSig = sigs[0]; + var appSig = sigs[1]; + + authUser = ParseVals(skey, authSig, AuthPrefix, ikey, currentTimeValue); + appUser = ParseVals(akey, appSig, AppPrefix, ikey, currentTimeValue); + } + catch + { + return null; + } + + if (authUser != appUser) + { + return null; + } + + return authUser; } - return $"{duoSig}:{appSig}"; - } - - /// - /// Validate the signed response returned from Duo. - /// Returns the username of the authenticated user, or null. - /// - /// Duo integration key - /// Duo secret key - /// Application secret key - /// The signed response POST'ed to the server - /// (optional) The current UTC time - /// authenticated username, or null - public static string VerifyResponse(string ikey, string skey, string akey, string sigResponse, - DateTime? currentTime = null) - { - string authUser = null; - string appUser = null; - var currentTimeValue = currentTime ?? DateTime.UtcNow; - - try + private static string SignVals(string key, string username, string ikey, string prefix, long expire, + DateTime currentTime) { - var sigs = sigResponse.Split(':'); - var authSig = sigs[0]; - var appSig = sigs[1]; - - authUser = ParseVals(skey, authSig, AuthPrefix, ikey, currentTimeValue); - appUser = ParseVals(akey, appSig, AppPrefix, ikey, currentTimeValue); - } - catch - { - return null; + var ts = (long)(currentTime - _epoc).TotalSeconds; + expire = ts + expire; + var val = $"{username}|{ikey}|{expire.ToString()}"; + var cookie = $"{prefix}|{Encode64(val)}"; + var sig = Sign(key, cookie); + return $"{cookie}|{sig}"; } - if (authUser != appUser) + private static string ParseVals(string key, string val, string prefix, string ikey, DateTime currentTime) { - return null; + var ts = (long)(currentTime - _epoc).TotalSeconds; + + var parts = val.Split('|'); + if (parts.Length != 3) + { + return null; + } + + var uPrefix = parts[0]; + var uB64 = parts[1]; + var uSig = parts[2]; + + var sig = Sign(key, $"{uPrefix}|{uB64}"); + if (Sign(key, sig) != Sign(key, uSig)) + { + return null; + } + + if (uPrefix != prefix) + { + return null; + } + + var cookie = Decode64(uB64); + var cookieParts = cookie.Split('|'); + if (cookieParts.Length != 3) + { + return null; + } + + var username = cookieParts[0]; + var uIKey = cookieParts[1]; + var expire = cookieParts[2]; + + if (uIKey != ikey) + { + return null; + } + + var expireTs = Convert.ToInt32(expire); + if (ts >= expireTs) + { + return null; + } + + return username; } - return authUser; - } - - private static string SignVals(string key, string username, string ikey, string prefix, long expire, - DateTime currentTime) - { - var ts = (long)(currentTime - _epoc).TotalSeconds; - expire = ts + expire; - var val = $"{username}|{ikey}|{expire.ToString()}"; - var cookie = $"{prefix}|{Encode64(val)}"; - var sig = Sign(key, cookie); - return $"{cookie}|{sig}"; - } - - private static string ParseVals(string key, string val, string prefix, string ikey, DateTime currentTime) - { - var ts = (long)(currentTime - _epoc).TotalSeconds; - - var parts = val.Split('|'); - if (parts.Length != 3) + private static string Sign(string skey, string data) { - return null; + var keyBytes = Encoding.ASCII.GetBytes(skey); + var dataBytes = Encoding.ASCII.GetBytes(data); + + using (var hmac = new HMACSHA1(keyBytes)) + { + var hash = hmac.ComputeHash(dataBytes); + var hex = BitConverter.ToString(hash); + return hex.Replace("-", "").ToLower(); + } } - var uPrefix = parts[0]; - var uB64 = parts[1]; - var uSig = parts[2]; - - var sig = Sign(key, $"{uPrefix}|{uB64}"); - if (Sign(key, sig) != Sign(key, uSig)) + private static string Encode64(string plaintext) { - return null; + var plaintextBytes = _encoding.GetBytes(plaintext); + return Convert.ToBase64String(plaintextBytes); } - if (uPrefix != prefix) + private static string Decode64(string encoded) { - return null; + var plaintextBytes = Convert.FromBase64String(encoded); + return _encoding.GetString(plaintextBytes); } - - var cookie = Decode64(uB64); - var cookieParts = cookie.Split('|'); - if (cookieParts.Length != 3) - { - return null; - } - - var username = cookieParts[0]; - var uIKey = cookieParts[1]; - var expire = cookieParts[2]; - - if (uIKey != ikey) - { - return null; - } - - var expireTs = Convert.ToInt32(expire); - if (ts >= expireTs) - { - return null; - } - - return username; - } - - private static string Sign(string skey, string data) - { - var keyBytes = Encoding.ASCII.GetBytes(skey); - var dataBytes = Encoding.ASCII.GetBytes(data); - - using (var hmac = new HMACSHA1(keyBytes)) - { - var hash = hmac.ComputeHash(dataBytes); - var hex = BitConverter.ToString(hash); - return hex.Replace("-", "").ToLower(); - } - } - - private static string Encode64(string plaintext) - { - var plaintextBytes = _encoding.GetBytes(plaintext); - return Convert.ToBase64String(plaintextBytes); - } - - private static string Decode64(string encoded) - { - var plaintextBytes = Convert.FromBase64String(encoded); - return _encoding.GetString(plaintextBytes); } } diff --git a/src/Core/Utilities/EncryptedStringLengthAttribute.cs b/src/Core/Utilities/EncryptedStringLengthAttribute.cs index d7a8ffaec8..46170487d9 100644 --- a/src/Core/Utilities/EncryptedStringLengthAttribute.cs +++ b/src/Core/Utilities/EncryptedStringLengthAttribute.cs @@ -1,16 +1,17 @@ using System.ComponentModel.DataAnnotations; -namespace Bit.Core.Utilities; - -public class EncryptedStringLengthAttribute : StringLengthAttribute +namespace Bit.Core.Utilities { - public EncryptedStringLengthAttribute(int maximumLength) - : base(maximumLength) - { } - - public override string FormatErrorMessage(string name) + public class EncryptedStringLengthAttribute : StringLengthAttribute { - return string.Format("The field {0} exceeds the maximum encrypted value length of {1} characters.", - name, MaximumLength); + public EncryptedStringLengthAttribute(int maximumLength) + : base(maximumLength) + { } + + public override string FormatErrorMessage(string name) + { + return string.Format("The field {0} exceeds the maximum encrypted value length of {1} characters.", + name, MaximumLength); + } } } diff --git a/src/Core/Utilities/EncryptedValueAttribute.cs b/src/Core/Utilities/EncryptedValueAttribute.cs index ec0b218c5f..9ae43110ba 100644 --- a/src/Core/Utilities/EncryptedValueAttribute.cs +++ b/src/Core/Utilities/EncryptedValueAttribute.cs @@ -1,137 +1,138 @@ using System.ComponentModel.DataAnnotations; -namespace Bit.Core.Utilities; - -/// -/// Validates a string that is in encrypted form: "head.b64iv=|b64ct=|b64mac=" -/// -public class EncryptedStringAttribute : ValidationAttribute +namespace Bit.Core.Utilities { - public EncryptedStringAttribute() - : base("{0} is not a valid encrypted string.") - { } - - public override bool IsValid(object value) + /// + /// Validates a string that is in encrypted form: "head.b64iv=|b64ct=|b64mac=" + /// + public class EncryptedStringAttribute : ValidationAttribute { - if (value == null) - { - return true; - } + public EncryptedStringAttribute() + : base("{0} is not a valid encrypted string.") + { } - try + public override bool IsValid(object value) { - var encString = value?.ToString(); - if (string.IsNullOrWhiteSpace(encString)) + if (value == null) + { + return true; + } + + try + { + var encString = value?.ToString(); + if (string.IsNullOrWhiteSpace(encString)) + { + return false; + } + + var headerPieces = encString.Split('.'); + string[] encStringPieces = null; + var encType = Enums.EncryptionType.AesCbc256_B64; + + if (headerPieces.Length == 1) + { + encStringPieces = headerPieces[0].Split('|'); + if (encStringPieces.Length == 3) + { + encType = Enums.EncryptionType.AesCbc128_HmacSha256_B64; + } + else + { + encType = Enums.EncryptionType.AesCbc256_B64; + } + } + else if (headerPieces.Length == 2) + { + encStringPieces = headerPieces[1].Split('|'); + if (!Enum.TryParse(headerPieces[0], out encType)) + { + return false; + } + } + + switch (encType) + { + case Enums.EncryptionType.AesCbc256_B64: + case Enums.EncryptionType.Rsa2048_OaepSha1_HmacSha256_B64: + case Enums.EncryptionType.Rsa2048_OaepSha256_HmacSha256_B64: + if (encStringPieces.Length != 2) + { + return false; + } + break; + case Enums.EncryptionType.AesCbc128_HmacSha256_B64: + case Enums.EncryptionType.AesCbc256_HmacSha256_B64: + if (encStringPieces.Length != 3) + { + return false; + } + break; + case Enums.EncryptionType.Rsa2048_OaepSha256_B64: + case Enums.EncryptionType.Rsa2048_OaepSha1_B64: + if (encStringPieces.Length != 1) + { + return false; + } + break; + default: + return false; + } + + switch (encType) + { + case Enums.EncryptionType.AesCbc256_B64: + case Enums.EncryptionType.AesCbc128_HmacSha256_B64: + case Enums.EncryptionType.AesCbc256_HmacSha256_B64: + var iv = Convert.FromBase64String(encStringPieces[0]); + var ct = Convert.FromBase64String(encStringPieces[1]); + if (iv.Length < 1 || ct.Length < 1) + { + return false; + } + + if (encType == Enums.EncryptionType.AesCbc128_HmacSha256_B64 || + encType == Enums.EncryptionType.AesCbc256_HmacSha256_B64) + { + var mac = Convert.FromBase64String(encStringPieces[2]); + if (mac.Length < 1) + { + return false; + } + } + + break; + case Enums.EncryptionType.Rsa2048_OaepSha256_B64: + case Enums.EncryptionType.Rsa2048_OaepSha1_B64: + case Enums.EncryptionType.Rsa2048_OaepSha1_HmacSha256_B64: + case Enums.EncryptionType.Rsa2048_OaepSha256_HmacSha256_B64: + var rsaCt = Convert.FromBase64String(encStringPieces[0]); + if (rsaCt.Length < 1) + { + return false; + } + + if (encType == Enums.EncryptionType.Rsa2048_OaepSha1_HmacSha256_B64 || + encType == Enums.EncryptionType.Rsa2048_OaepSha256_HmacSha256_B64) + { + var mac = Convert.FromBase64String(encStringPieces[1]); + if (mac.Length < 1) + { + return false; + } + } + + break; + default: + return false; + } + } + catch { return false; } - var headerPieces = encString.Split('.'); - string[] encStringPieces = null; - var encType = Enums.EncryptionType.AesCbc256_B64; - - if (headerPieces.Length == 1) - { - encStringPieces = headerPieces[0].Split('|'); - if (encStringPieces.Length == 3) - { - encType = Enums.EncryptionType.AesCbc128_HmacSha256_B64; - } - else - { - encType = Enums.EncryptionType.AesCbc256_B64; - } - } - else if (headerPieces.Length == 2) - { - encStringPieces = headerPieces[1].Split('|'); - if (!Enum.TryParse(headerPieces[0], out encType)) - { - return false; - } - } - - switch (encType) - { - case Enums.EncryptionType.AesCbc256_B64: - case Enums.EncryptionType.Rsa2048_OaepSha1_HmacSha256_B64: - case Enums.EncryptionType.Rsa2048_OaepSha256_HmacSha256_B64: - if (encStringPieces.Length != 2) - { - return false; - } - break; - case Enums.EncryptionType.AesCbc128_HmacSha256_B64: - case Enums.EncryptionType.AesCbc256_HmacSha256_B64: - if (encStringPieces.Length != 3) - { - return false; - } - break; - case Enums.EncryptionType.Rsa2048_OaepSha256_B64: - case Enums.EncryptionType.Rsa2048_OaepSha1_B64: - if (encStringPieces.Length != 1) - { - return false; - } - break; - default: - return false; - } - - switch (encType) - { - case Enums.EncryptionType.AesCbc256_B64: - case Enums.EncryptionType.AesCbc128_HmacSha256_B64: - case Enums.EncryptionType.AesCbc256_HmacSha256_B64: - var iv = Convert.FromBase64String(encStringPieces[0]); - var ct = Convert.FromBase64String(encStringPieces[1]); - if (iv.Length < 1 || ct.Length < 1) - { - return false; - } - - if (encType == Enums.EncryptionType.AesCbc128_HmacSha256_B64 || - encType == Enums.EncryptionType.AesCbc256_HmacSha256_B64) - { - var mac = Convert.FromBase64String(encStringPieces[2]); - if (mac.Length < 1) - { - return false; - } - } - - break; - case Enums.EncryptionType.Rsa2048_OaepSha256_B64: - case Enums.EncryptionType.Rsa2048_OaepSha1_B64: - case Enums.EncryptionType.Rsa2048_OaepSha1_HmacSha256_B64: - case Enums.EncryptionType.Rsa2048_OaepSha256_HmacSha256_B64: - var rsaCt = Convert.FromBase64String(encStringPieces[0]); - if (rsaCt.Length < 1) - { - return false; - } - - if (encType == Enums.EncryptionType.Rsa2048_OaepSha1_HmacSha256_B64 || - encType == Enums.EncryptionType.Rsa2048_OaepSha256_HmacSha256_B64) - { - var mac = Convert.FromBase64String(encStringPieces[1]); - if (mac.Length < 1) - { - return false; - } - } - - break; - default: - return false; - } + return true; } - catch - { - return false; - } - - return true; } } diff --git a/src/Core/Utilities/EpochDateTimeJsonConverter.cs b/src/Core/Utilities/EpochDateTimeJsonConverter.cs index 035da04a77..a9354fa6f8 100644 --- a/src/Core/Utilities/EpochDateTimeJsonConverter.cs +++ b/src/Core/Utilities/EpochDateTimeJsonConverter.cs @@ -1,16 +1,17 @@ using System.Text.Json; using System.Text.Json.Serialization; -namespace Bit.Core.Utilities; - -public class EpochDateTimeJsonConverter : JsonConverter +namespace Bit.Core.Utilities { - public override DateTime Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) + public class EpochDateTimeJsonConverter : JsonConverter { - return CoreHelpers.FromEpocMilliseconds(reader.GetInt64()); - } - public override void Write(Utf8JsonWriter writer, DateTime value, JsonSerializerOptions options) - { - writer.WriteNumberValue(CoreHelpers.ToEpocMilliseconds(value)); + public override DateTime Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) + { + return CoreHelpers.FromEpocMilliseconds(reader.GetInt64()); + } + public override void Write(Utf8JsonWriter writer, DateTime value, JsonSerializerOptions options) + { + writer.WriteNumberValue(CoreHelpers.ToEpocMilliseconds(value)); + } } } diff --git a/src/Core/Utilities/HandlebarsObjectJsonConverter.cs b/src/Core/Utilities/HandlebarsObjectJsonConverter.cs index 5651da4dc9..2ba1d40029 100644 --- a/src/Core/Utilities/HandlebarsObjectJsonConverter.cs +++ b/src/Core/Utilities/HandlebarsObjectJsonConverter.cs @@ -1,17 +1,18 @@ using System.Text.Json; using System.Text.Json.Serialization; -namespace Bit.Core.Utilities; - -public class HandlebarsObjectJsonConverter : JsonConverter +namespace Bit.Core.Utilities { - public override bool CanConvert(Type typeToConvert) => true; - public override object Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) + public class HandlebarsObjectJsonConverter : JsonConverter { - return JsonSerializer.Deserialize>(ref reader, options); - } - public override void Write(Utf8JsonWriter writer, object value, JsonSerializerOptions options) - { - JsonSerializer.Serialize(writer, value, options); + public override bool CanConvert(Type typeToConvert) => true; + public override object Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) + { + return JsonSerializer.Deserialize>(ref reader, options); + } + public override void Write(Utf8JsonWriter writer, object value, JsonSerializerOptions options) + { + JsonSerializer.Serialize(writer, value, options); + } } } diff --git a/src/Core/Utilities/HostBuilderExtensions.cs b/src/Core/Utilities/HostBuilderExtensions.cs index 4806c4032a..2d54545edd 100644 --- a/src/Core/Utilities/HostBuilderExtensions.cs +++ b/src/Core/Utilities/HostBuilderExtensions.cs @@ -2,41 +2,42 @@ using Microsoft.Extensions.Configuration; using Microsoft.Extensions.Hosting; -namespace Bit.Core.Utilities; - -public static class HostBuilderExtensions +namespace Bit.Core.Utilities { - public static IHostBuilder ConfigureCustomAppConfiguration(this IHostBuilder hostBuilder, string[] args) + public static class HostBuilderExtensions { - // Reload app configuration with SelfHosted overrides. - return hostBuilder.ConfigureAppConfiguration((hostingContext, config) => + public static IHostBuilder ConfigureCustomAppConfiguration(this IHostBuilder hostBuilder, string[] args) { - if (Environment.GetEnvironmentVariable("globalSettings__selfHosted")?.ToLower() != "true") + // Reload app configuration with SelfHosted overrides. + return hostBuilder.ConfigureAppConfiguration((hostingContext, config) => { - return; - } - - var env = hostingContext.HostingEnvironment; - - config.AddJsonFile("appsettings.json", optional: true, reloadOnChange: true) - .AddJsonFile($"appsettings.{env.EnvironmentName}.json", optional: true, reloadOnChange: true) - .AddJsonFile("appsettings.SelfHosted.json", optional: true, reloadOnChange: true); - - if (env.IsDevelopment()) - { - var appAssembly = Assembly.Load(new AssemblyName(env.ApplicationName)); - if (appAssembly != null) + if (Environment.GetEnvironmentVariable("globalSettings__selfHosted")?.ToLower() != "true") { - config.AddUserSecrets(appAssembly, optional: true); + return; } - } - config.AddEnvironmentVariables(); + var env = hostingContext.HostingEnvironment; - if (args != null) - { - config.AddCommandLine(args); - } - }); + config.AddJsonFile("appsettings.json", optional: true, reloadOnChange: true) + .AddJsonFile($"appsettings.{env.EnvironmentName}.json", optional: true, reloadOnChange: true) + .AddJsonFile("appsettings.SelfHosted.json", optional: true, reloadOnChange: true); + + if (env.IsDevelopment()) + { + var appAssembly = Assembly.Load(new AssemblyName(env.ApplicationName)); + if (appAssembly != null) + { + config.AddUserSecrets(appAssembly, optional: true); + } + } + + config.AddEnvironmentVariables(); + + if (args != null) + { + config.AddCommandLine(args); + } + }); + } } } diff --git a/src/Core/Utilities/JsonHelpers.cs b/src/Core/Utilities/JsonHelpers.cs index ad7aefd257..b6a9481b64 100644 --- a/src/Core/Utilities/JsonHelpers.cs +++ b/src/Core/Utilities/JsonHelpers.cs @@ -3,204 +3,205 @@ using System.Text.Json; using System.Text.Json.Serialization; using NS = Newtonsoft.Json; -namespace Bit.Core.Utilities; - -public static class JsonHelpers +namespace Bit.Core.Utilities { - public static JsonSerializerOptions Default { get; } - public static JsonSerializerOptions Indented { get; } - public static JsonSerializerOptions IgnoreCase { get; } - public static JsonSerializerOptions IgnoreWritingNull { get; } - public static JsonSerializerOptions CamelCase { get; } - public static JsonSerializerOptions IgnoreWritingNullAndCamelCase { get; } - - static JsonHelpers() + public static class JsonHelpers { - Default = new JsonSerializerOptions(); + public static JsonSerializerOptions Default { get; } + public static JsonSerializerOptions Indented { get; } + public static JsonSerializerOptions IgnoreCase { get; } + public static JsonSerializerOptions IgnoreWritingNull { get; } + public static JsonSerializerOptions CamelCase { get; } + public static JsonSerializerOptions IgnoreWritingNullAndCamelCase { get; } - Indented = new JsonSerializerOptions + static JsonHelpers() { - WriteIndented = true, - }; + Default = new JsonSerializerOptions(); - IgnoreCase = new JsonSerializerOptions - { - PropertyNameCaseInsensitive = true, - }; - - IgnoreWritingNull = new JsonSerializerOptions - { - DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull, - }; - - CamelCase = new JsonSerializerOptions - { - PropertyNamingPolicy = JsonNamingPolicy.CamelCase, - }; - - IgnoreWritingNullAndCamelCase = new JsonSerializerOptions - { - PropertyNamingPolicy = JsonNamingPolicy.CamelCase, - DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull, - }; - } - - [Obsolete("This is built into .NET 6, it SHOULD be removed when we upgrade")] - public static T ToObject(this JsonElement element, JsonSerializerOptions options = null) - { - return JsonSerializer.Deserialize(element.GetRawText(), options ?? Default); - } - - [Obsolete("This is built into .NET 6, it SHOULD be removed when we upgrade")] - public static T ToObject(this JsonDocument document, JsonSerializerOptions options = null) - { - return JsonSerializer.Deserialize(document.RootElement.GetRawText(), options ?? default); - } - - public static T DeserializeOrNew(string json, JsonSerializerOptions options = null) - where T : new() - { - if (string.IsNullOrWhiteSpace(json)) - { - return new T(); - } - - return JsonSerializer.Deserialize(json, options); - } - - #region Legacy Newtonsoft.Json usage - private const string LegacyMessage = "Usage of Newtonsoft.Json should be kept to a minimum and will further be removed when we move to .NET 6"; - - [Obsolete(LegacyMessage)] - public static NS.JsonSerializerSettings LegacyEnumKeyResolver { get; } = new NS.JsonSerializerSettings - { - ContractResolver = new EnumKeyResolver(), - }; - - [Obsolete(LegacyMessage)] - public static string LegacySerialize(object value, NS.JsonSerializerSettings settings = null) - { - return NS.JsonConvert.SerializeObject(value, settings); - } - - [Obsolete(LegacyMessage)] - public static T LegacyDeserialize(string value, NS.JsonSerializerSettings settings = null) - { - return NS.JsonConvert.DeserializeObject(value, settings); - } - #endregion -} - -public class EnumKeyResolver : NS.Serialization.DefaultContractResolver - where T : struct -{ - protected override NS.Serialization.JsonDictionaryContract CreateDictionaryContract(Type objectType) - { - var contract = base.CreateDictionaryContract(objectType); - var keyType = contract.DictionaryKeyType; - - if (keyType.BaseType == typeof(Enum)) - { - contract.DictionaryKeyResolver = propName => ((T)Enum.Parse(keyType, propName)).ToString(); - } - - return contract; - } -} - -public class MsEpochConverter : JsonConverter -{ - public override DateTime? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) - { - if (reader.TokenType == JsonTokenType.Null) - { - return null; - } - - if (!long.TryParse(reader.GetString(), out var milliseconds)) - { - return null; - } - - return CoreHelpers.FromEpocMilliseconds(milliseconds); - } - - public override void Write(Utf8JsonWriter writer, DateTime? value, JsonSerializerOptions options) - { - if (!value.HasValue) - { - writer.WriteNullValue(); - } - - writer.WriteStringValue(CoreHelpers.ToEpocMilliseconds(value.Value).ToString()); - } -} - -/// -/// Allows reading a string from a JSON number or string, should only be used on properties -/// -public class PermissiveStringConverter : JsonConverter -{ - internal static readonly PermissiveStringConverter Instance = new(); - private static readonly CultureInfo _cultureInfo = new("en-US"); - - public override string Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) - { - return reader.TokenType switch - { - JsonTokenType.String => reader.GetString(), - JsonTokenType.Number => reader.GetDecimal().ToString(_cultureInfo), - JsonTokenType.True => bool.TrueString, - JsonTokenType.False => bool.FalseString, - _ => throw new JsonException($"Unsupported TokenType: {reader.TokenType}"), - }; - } - - public override void Write(Utf8JsonWriter writer, string value, JsonSerializerOptions options) - { - writer.WriteStringValue(value); - } -} - -/// -/// Allows reading a JSON array of number or string, should only be used on whose generic type is -/// -public class PermissiveStringEnumerableConverter : JsonConverter> -{ - public override IEnumerable Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) - { - var stringList = new List(); - - // Handle special cases or throw - if (reader.TokenType != JsonTokenType.StartArray) - { - // An array was expected but to be extra permissive allow reading from anything other than an object - if (reader.TokenType == JsonTokenType.StartObject) + Indented = new JsonSerializerOptions { - throw new JsonException("Cannot read JSON Object to an IEnumerable."); + WriteIndented = true, + }; + + IgnoreCase = new JsonSerializerOptions + { + PropertyNameCaseInsensitive = true, + }; + + IgnoreWritingNull = new JsonSerializerOptions + { + DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull, + }; + + CamelCase = new JsonSerializerOptions + { + PropertyNamingPolicy = JsonNamingPolicy.CamelCase, + }; + + IgnoreWritingNullAndCamelCase = new JsonSerializerOptions + { + PropertyNamingPolicy = JsonNamingPolicy.CamelCase, + DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull, + }; + } + + [Obsolete("This is built into .NET 6, it SHOULD be removed when we upgrade")] + public static T ToObject(this JsonElement element, JsonSerializerOptions options = null) + { + return JsonSerializer.Deserialize(element.GetRawText(), options ?? Default); + } + + [Obsolete("This is built into .NET 6, it SHOULD be removed when we upgrade")] + public static T ToObject(this JsonDocument document, JsonSerializerOptions options = null) + { + return JsonSerializer.Deserialize(document.RootElement.GetRawText(), options ?? default); + } + + public static T DeserializeOrNew(string json, JsonSerializerOptions options = null) + where T : new() + { + if (string.IsNullOrWhiteSpace(json)) + { + return new T(); + } + + return JsonSerializer.Deserialize(json, options); + } + + #region Legacy Newtonsoft.Json usage + private const string LegacyMessage = "Usage of Newtonsoft.Json should be kept to a minimum and will further be removed when we move to .NET 6"; + + [Obsolete(LegacyMessage)] + public static NS.JsonSerializerSettings LegacyEnumKeyResolver { get; } = new NS.JsonSerializerSettings + { + ContractResolver = new EnumKeyResolver(), + }; + + [Obsolete(LegacyMessage)] + public static string LegacySerialize(object value, NS.JsonSerializerSettings settings = null) + { + return NS.JsonConvert.SerializeObject(value, settings); + } + + [Obsolete(LegacyMessage)] + public static T LegacyDeserialize(string value, NS.JsonSerializerSettings settings = null) + { + return NS.JsonConvert.DeserializeObject(value, settings); + } + #endregion + } + + public class EnumKeyResolver : NS.Serialization.DefaultContractResolver + where T : struct + { + protected override NS.Serialization.JsonDictionaryContract CreateDictionaryContract(Type objectType) + { + var contract = base.CreateDictionaryContract(objectType); + var keyType = contract.DictionaryKeyType; + + if (keyType.BaseType == typeof(Enum)) + { + contract.DictionaryKeyResolver = propName => ((T)Enum.Parse(keyType, propName)).ToString(); + } + + return contract; + } + } + + public class MsEpochConverter : JsonConverter + { + public override DateTime? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) + { + if (reader.TokenType == JsonTokenType.Null) + { + return null; + } + + if (!long.TryParse(reader.GetString(), out var milliseconds)) + { + return null; + } + + return CoreHelpers.FromEpocMilliseconds(milliseconds); + } + + public override void Write(Utf8JsonWriter writer, DateTime? value, JsonSerializerOptions options) + { + if (!value.HasValue) + { + writer.WriteNullValue(); + } + + writer.WriteStringValue(CoreHelpers.ToEpocMilliseconds(value.Value).ToString()); + } + } + + /// + /// Allows reading a string from a JSON number or string, should only be used on properties + /// + public class PermissiveStringConverter : JsonConverter + { + internal static readonly PermissiveStringConverter Instance = new(); + private static readonly CultureInfo _cultureInfo = new("en-US"); + + public override string Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) + { + return reader.TokenType switch + { + JsonTokenType.String => reader.GetString(), + JsonTokenType.Number => reader.GetDecimal().ToString(_cultureInfo), + JsonTokenType.True => bool.TrueString, + JsonTokenType.False => bool.FalseString, + _ => throw new JsonException($"Unsupported TokenType: {reader.TokenType}"), + }; + } + + public override void Write(Utf8JsonWriter writer, string value, JsonSerializerOptions options) + { + writer.WriteStringValue(value); + } + } + + /// + /// Allows reading a JSON array of number or string, should only be used on whose generic type is + /// + public class PermissiveStringEnumerableConverter : JsonConverter> + { + public override IEnumerable Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) + { + var stringList = new List(); + + // Handle special cases or throw + if (reader.TokenType != JsonTokenType.StartArray) + { + // An array was expected but to be extra permissive allow reading from anything other than an object + if (reader.TokenType == JsonTokenType.StartObject) + { + throw new JsonException("Cannot read JSON Object to an IEnumerable."); + } + + stringList.Add(PermissiveStringConverter.Instance.Read(ref reader, typeof(string), options)); + return stringList; + } + + while (reader.Read() && reader.TokenType != JsonTokenType.EndArray) + { + stringList.Add(PermissiveStringConverter.Instance.Read(ref reader, typeof(string), options)); } - stringList.Add(PermissiveStringConverter.Instance.Read(ref reader, typeof(string), options)); return stringList; } - while (reader.Read() && reader.TokenType != JsonTokenType.EndArray) + public override void Write(Utf8JsonWriter writer, IEnumerable value, JsonSerializerOptions options) { - stringList.Add(PermissiveStringConverter.Instance.Read(ref reader, typeof(string), options)); + writer.WriteStartArray(); + + foreach (var str in value) + { + PermissiveStringConverter.Instance.Write(writer, str, options); + } + + writer.WriteEndArray(); } - - return stringList; - } - - public override void Write(Utf8JsonWriter writer, IEnumerable value, JsonSerializerOptions options) - { - writer.WriteStartArray(); - - foreach (var str in value) - { - PermissiveStringConverter.Instance.Write(writer, str, options); - } - - writer.WriteEndArray(); } } diff --git a/src/Core/Utilities/LoggerFactoryExtensions.cs b/src/Core/Utilities/LoggerFactoryExtensions.cs index 792225cdfb..98896c56eb 100644 --- a/src/Core/Utilities/LoggerFactoryExtensions.cs +++ b/src/Core/Utilities/LoggerFactoryExtensions.cs @@ -10,136 +10,137 @@ using Serilog; using Serilog.Events; using Serilog.Sinks.Syslog; -namespace Bit.Core.Utilities; - -public static class LoggerFactoryExtensions +namespace Bit.Core.Utilities { - public static void UseSerilog( - this IApplicationBuilder appBuilder, - IWebHostEnvironment env, - IHostApplicationLifetime applicationLifetime, - GlobalSettings globalSettings) + public static class LoggerFactoryExtensions { - if (env.IsDevelopment()) + public static void UseSerilog( + this IApplicationBuilder appBuilder, + IWebHostEnvironment env, + IHostApplicationLifetime applicationLifetime, + GlobalSettings globalSettings) { - return; + if (env.IsDevelopment()) + { + return; + } + + applicationLifetime.ApplicationStopped.Register(Log.CloseAndFlush); } - applicationLifetime.ApplicationStopped.Register(Log.CloseAndFlush); - } - - public static ILoggingBuilder AddSerilog( - this ILoggingBuilder builder, - WebHostBuilderContext context, - Func filter = null) - { - if (context.HostingEnvironment.IsDevelopment()) + public static ILoggingBuilder AddSerilog( + this ILoggingBuilder builder, + WebHostBuilderContext context, + Func filter = null) { + if (context.HostingEnvironment.IsDevelopment()) + { + return builder; + } + + bool inclusionPredicate(LogEvent e) + { + if (filter == null) + { + return true; + } + var eventId = e.Properties.ContainsKey("EventId") ? e.Properties["EventId"].ToString() : null; + if (eventId?.Contains(Constants.BypassFiltersEventId.ToString()) ?? false) + { + return true; + } + return filter(e); + } + + var globalSettings = new GlobalSettings(); + ConfigurationBinder.Bind(context.Configuration.GetSection("GlobalSettings"), globalSettings); + + var config = new LoggerConfiguration() + .Enrich.FromLogContext() + .Filter.ByIncludingOnly(inclusionPredicate); + + if (CoreHelpers.SettingHasValue(globalSettings?.DocumentDb.Uri) && + CoreHelpers.SettingHasValue(globalSettings?.DocumentDb.Key)) + { + config.WriteTo.AzureCosmosDB(new Uri(globalSettings.DocumentDb.Uri), + globalSettings.DocumentDb.Key, timeToLive: TimeSpan.FromDays(7), + partitionKey: "_partitionKey") + .Enrich.FromLogContext() + .Enrich.WithProperty("Project", globalSettings.ProjectName); + } + else if (CoreHelpers.SettingHasValue(globalSettings?.Sentry.Dsn)) + { + config.WriteTo.Sentry(globalSettings.Sentry.Dsn) + .Enrich.FromLogContext() + .Enrich.WithProperty("Project", globalSettings.ProjectName); + } + else if (CoreHelpers.SettingHasValue(globalSettings?.Syslog.Destination)) + { + // appending sitename to project name to allow eaiser identification in syslog. + var appName = $"{globalSettings.SiteName}-{globalSettings.ProjectName}"; + if (globalSettings.Syslog.Destination.Equals("local", StringComparison.OrdinalIgnoreCase)) + { + config.WriteTo.LocalSyslog(appName); + } + else if (Uri.TryCreate(globalSettings.Syslog.Destination, UriKind.Absolute, out var syslogAddress)) + { + // Syslog's standard port is 514 (both UDP and TCP). TLS does not have a standard port, so assume 514. + int port = syslogAddress.Port >= 0 + ? syslogAddress.Port + : 514; + + if (syslogAddress.Scheme.Equals("udp")) + { + config.WriteTo.UdpSyslog(syslogAddress.Host, port, appName); + } + else if (syslogAddress.Scheme.Equals("tcp")) + { + config.WriteTo.TcpSyslog(syslogAddress.Host, port, appName); + } + else if (syslogAddress.Scheme.Equals("tls")) + { + // TLS v1.1, v1.2 and v1.3 are explicitly selected (leaving out TLS v1.0) + const SslProtocols protocols = SslProtocols.Tls11 | SslProtocols.Tls12 | SslProtocols.Tls13; + + if (CoreHelpers.SettingHasValue(globalSettings.Syslog.CertificateThumbprint)) + { + config.WriteTo.TcpSyslog(syslogAddress.Host, port, appName, + secureProtocols: protocols, + certProvider: new CertificateStoreProvider(StoreName.My, StoreLocation.CurrentUser, + globalSettings.Syslog.CertificateThumbprint)); + } + else + { + config.WriteTo.TcpSyslog(syslogAddress.Host, port, appName, + secureProtocols: protocols, + certProvider: new CertificateFileProvider(globalSettings.Syslog.CertificatePath, + globalSettings.Syslog?.CertificatePassword ?? string.Empty)); + } + + } + } + } + else if (CoreHelpers.SettingHasValue(globalSettings.LogDirectory)) + { + if (globalSettings.LogRollBySizeLimit.HasValue) + { + config.WriteTo.File($"{globalSettings.LogDirectory}/{globalSettings.ProjectName}/log.txt", + rollOnFileSizeLimit: true, fileSizeLimitBytes: globalSettings.LogRollBySizeLimit); + } + else + { + config.WriteTo + .RollingFile($"{globalSettings.LogDirectory}/{globalSettings.ProjectName}/{{Date}}.txt"); + } + config + .Enrich.FromLogContext() + .Enrich.WithProperty("Project", globalSettings.ProjectName); + } + + var serilog = config.CreateLogger(); + builder.AddSerilog(serilog); + return builder; } - - bool inclusionPredicate(LogEvent e) - { - if (filter == null) - { - return true; - } - var eventId = e.Properties.ContainsKey("EventId") ? e.Properties["EventId"].ToString() : null; - if (eventId?.Contains(Constants.BypassFiltersEventId.ToString()) ?? false) - { - return true; - } - return filter(e); - } - - var globalSettings = new GlobalSettings(); - ConfigurationBinder.Bind(context.Configuration.GetSection("GlobalSettings"), globalSettings); - - var config = new LoggerConfiguration() - .Enrich.FromLogContext() - .Filter.ByIncludingOnly(inclusionPredicate); - - if (CoreHelpers.SettingHasValue(globalSettings?.DocumentDb.Uri) && - CoreHelpers.SettingHasValue(globalSettings?.DocumentDb.Key)) - { - config.WriteTo.AzureCosmosDB(new Uri(globalSettings.DocumentDb.Uri), - globalSettings.DocumentDb.Key, timeToLive: TimeSpan.FromDays(7), - partitionKey: "_partitionKey") - .Enrich.FromLogContext() - .Enrich.WithProperty("Project", globalSettings.ProjectName); - } - else if (CoreHelpers.SettingHasValue(globalSettings?.Sentry.Dsn)) - { - config.WriteTo.Sentry(globalSettings.Sentry.Dsn) - .Enrich.FromLogContext() - .Enrich.WithProperty("Project", globalSettings.ProjectName); - } - else if (CoreHelpers.SettingHasValue(globalSettings?.Syslog.Destination)) - { - // appending sitename to project name to allow eaiser identification in syslog. - var appName = $"{globalSettings.SiteName}-{globalSettings.ProjectName}"; - if (globalSettings.Syslog.Destination.Equals("local", StringComparison.OrdinalIgnoreCase)) - { - config.WriteTo.LocalSyslog(appName); - } - else if (Uri.TryCreate(globalSettings.Syslog.Destination, UriKind.Absolute, out var syslogAddress)) - { - // Syslog's standard port is 514 (both UDP and TCP). TLS does not have a standard port, so assume 514. - int port = syslogAddress.Port >= 0 - ? syslogAddress.Port - : 514; - - if (syslogAddress.Scheme.Equals("udp")) - { - config.WriteTo.UdpSyslog(syslogAddress.Host, port, appName); - } - else if (syslogAddress.Scheme.Equals("tcp")) - { - config.WriteTo.TcpSyslog(syslogAddress.Host, port, appName); - } - else if (syslogAddress.Scheme.Equals("tls")) - { - // TLS v1.1, v1.2 and v1.3 are explicitly selected (leaving out TLS v1.0) - const SslProtocols protocols = SslProtocols.Tls11 | SslProtocols.Tls12 | SslProtocols.Tls13; - - if (CoreHelpers.SettingHasValue(globalSettings.Syslog.CertificateThumbprint)) - { - config.WriteTo.TcpSyslog(syslogAddress.Host, port, appName, - secureProtocols: protocols, - certProvider: new CertificateStoreProvider(StoreName.My, StoreLocation.CurrentUser, - globalSettings.Syslog.CertificateThumbprint)); - } - else - { - config.WriteTo.TcpSyslog(syslogAddress.Host, port, appName, - secureProtocols: protocols, - certProvider: new CertificateFileProvider(globalSettings.Syslog.CertificatePath, - globalSettings.Syslog?.CertificatePassword ?? string.Empty)); - } - - } - } - } - else if (CoreHelpers.SettingHasValue(globalSettings.LogDirectory)) - { - if (globalSettings.LogRollBySizeLimit.HasValue) - { - config.WriteTo.File($"{globalSettings.LogDirectory}/{globalSettings.ProjectName}/log.txt", - rollOnFileSizeLimit: true, fileSizeLimitBytes: globalSettings.LogRollBySizeLimit); - } - else - { - config.WriteTo - .RollingFile($"{globalSettings.LogDirectory}/{globalSettings.ProjectName}/{{Date}}.txt"); - } - config - .Enrich.FromLogContext() - .Enrich.WithProperty("Project", globalSettings.ProjectName); - } - - var serilog = config.CreateLogger(); - builder.AddSerilog(serilog); - - return builder; } } diff --git a/src/Core/Utilities/LoggingExceptionHandlerFilterAttribute.cs b/src/Core/Utilities/LoggingExceptionHandlerFilterAttribute.cs index 6709bbb271..8df51b1e5f 100644 --- a/src/Core/Utilities/LoggingExceptionHandlerFilterAttribute.cs +++ b/src/Core/Utilities/LoggingExceptionHandlerFilterAttribute.cs @@ -2,21 +2,22 @@ using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; -namespace Bit.Core.Utilities; - -public class LoggingExceptionHandlerFilterAttribute : ExceptionFilterAttribute +namespace Bit.Core.Utilities { - public override void OnException(ExceptionContext context) + public class LoggingExceptionHandlerFilterAttribute : ExceptionFilterAttribute { - var exception = context.Exception; - if (exception == null) + public override void OnException(ExceptionContext context) { - // Should never happen. - return; - } + var exception = context.Exception; + if (exception == null) + { + // Should never happen. + return; + } - var logger = context.HttpContext.RequestServices - .GetRequiredService>(); - logger.LogError(0, exception, exception.Message); + var logger = context.HttpContext.RequestServices + .GetRequiredService>(); + logger.LogError(0, exception, exception.Message); + } } } diff --git a/src/Core/Utilities/SecurityHeadersMiddleware.cs b/src/Core/Utilities/SecurityHeadersMiddleware.cs index 19616e8a74..3a1cc477ec 100644 --- a/src/Core/Utilities/SecurityHeadersMiddleware.cs +++ b/src/Core/Utilities/SecurityHeadersMiddleware.cs @@ -1,28 +1,29 @@ using Microsoft.AspNetCore.Http; using Microsoft.Extensions.Primitives; -namespace Bit.Core.Utilities; - -public sealed class SecurityHeadersMiddleware +namespace Bit.Core.Utilities { - private readonly RequestDelegate _next; - - public SecurityHeadersMiddleware(RequestDelegate next) + public sealed class SecurityHeadersMiddleware { - _next = next; - } + private readonly RequestDelegate _next; - public Task Invoke(HttpContext context) - { - // https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-Frame-Options - context.Response.Headers.Add("x-frame-options", new StringValues("SAMEORIGIN")); + public SecurityHeadersMiddleware(RequestDelegate next) + { + _next = next; + } - // https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-XSS-Protection - context.Response.Headers.Add("x-xss-protection", new StringValues("1; mode=block")); + public Task Invoke(HttpContext context) + { + // https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-Frame-Options + context.Response.Headers.Add("x-frame-options", new StringValues("SAMEORIGIN")); - // https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-Content-Type-Options - context.Response.Headers.Add("x-content-type-options", new StringValues("nosniff")); + // https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-XSS-Protection + context.Response.Headers.Add("x-xss-protection", new StringValues("1; mode=block")); - return _next(context); + // https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-Content-Type-Options + context.Response.Headers.Add("x-content-type-options", new StringValues("nosniff")); + + return _next(context); + } } } diff --git a/src/Core/Utilities/SelfHostedAttribute.cs b/src/Core/Utilities/SelfHostedAttribute.cs index f4ea835922..13dc83fa72 100644 --- a/src/Core/Utilities/SelfHostedAttribute.cs +++ b/src/Core/Utilities/SelfHostedAttribute.cs @@ -3,23 +3,24 @@ using Bit.Core.Settings; using Microsoft.AspNetCore.Mvc.Filters; using Microsoft.Extensions.DependencyInjection; -namespace Bit.Core.Utilities; - -public class SelfHostedAttribute : ActionFilterAttribute +namespace Bit.Core.Utilities { - public bool SelfHostedOnly { get; set; } - public bool NotSelfHostedOnly { get; set; } - - public override void OnActionExecuting(ActionExecutingContext context) + public class SelfHostedAttribute : ActionFilterAttribute { - var globalSettings = context.HttpContext.RequestServices.GetRequiredService(); - if (SelfHostedOnly && !globalSettings.SelfHosted) + public bool SelfHostedOnly { get; set; } + public bool NotSelfHostedOnly { get; set; } + + public override void OnActionExecuting(ActionExecutingContext context) { - throw new BadRequestException("Only allowed when self hosted."); - } - else if (NotSelfHostedOnly && globalSettings.SelfHosted) - { - throw new BadRequestException("Only allowed when not self hosted."); + var globalSettings = context.HttpContext.RequestServices.GetRequiredService(); + if (SelfHostedOnly && !globalSettings.SelfHosted) + { + throw new BadRequestException("Only allowed when self hosted."); + } + else if (NotSelfHostedOnly && globalSettings.SelfHosted) + { + throw new BadRequestException("Only allowed when not self hosted."); + } } } } diff --git a/src/Core/Utilities/StaticStore.cs b/src/Core/Utilities/StaticStore.cs index 053f7ed452..0b8cb61bf4 100644 --- a/src/Core/Utilities/StaticStore.cs +++ b/src/Core/Utilities/StaticStore.cs @@ -2,500 +2,501 @@ using Bit.Core.Models.Data.Organizations.OrganizationUsers; using Bit.Core.Models.StaticStore; -namespace Bit.Core.Utilities; - -public class StaticStore +namespace Bit.Core.Utilities { - static StaticStore() + public class StaticStore { - #region Global Domains - - GlobalDomains = new Dictionary>(); - - GlobalDomains.Add(GlobalEquivalentDomainsType.Ameritrade, new List { "ameritrade.com", "tdameritrade.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.BoA, new List { "bankofamerica.com", "bofa.com", "mbna.com", "usecfo.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Sprint, new List { "sprint.com", "sprintpcs.com", "nextel.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Google, new List { "youtube.com", "google.com", "gmail.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Apple, new List { "apple.com", "icloud.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.WellsFargo, new List { "wellsfargo.com", "wf.com", "wellsfargoadvisors.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Merrill, new List { "mymerrill.com", "ml.com", "merrilledge.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Citi, new List { "accountonline.com", "citi.com", "citibank.com", "citicards.com", "citibankonline.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Cnet, new List { "cnet.com", "cnettv.com", "com.com", "download.com", "news.com", "search.com", "upload.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Gap, new List { "bananarepublic.com", "gap.com", "oldnavy.com", "piperlime.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Microsoft, new List { "bing.com", "hotmail.com", "live.com", "microsoft.com", "msn.com", "passport.net", "windows.com", "microsoftonline.com", "office.com", "office365.com", "microsoftstore.com", "xbox.com", "azure.com", "windowsazure.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.United, new List { "ua2go.com", "ual.com", "united.com", "unitedwifi.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Yahoo, new List { "overture.com", "yahoo.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Zonelabs, new List { "zonealarm.com", "zonelabs.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.PayPal, new List { "paypal.com", "paypal-search.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Avon, new List { "avon.com", "youravon.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Diapers, new List { "diapers.com", "soap.com", "wag.com", "yoyo.com", "beautybar.com", "casa.com", "afterschool.com", "vine.com", "bookworm.com", "look.com", "vinemarket.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Contacts, new List { "1800contacts.com", "800contacts.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Amazon, new List { "amazon.com", "amazon.ae", "amazon.ca", "amazon.co.uk", "amazon.com.au", "amazon.com.br", "amazon.com.mx", "amazon.com.tr", "amazon.de", "amazon.es", "amazon.fr", "amazon.in", "amazon.it", "amazon.nl", "amazon.pl", "amazon.sa", "amazon.se", "amazon.sg" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Cox, new List { "cox.com", "cox.net", "coxbusiness.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Norton, new List { "mynortonaccount.com", "norton.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Verizon, new List { "verizon.com", "verizon.net" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Buy, new List { "rakuten.com", "buy.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Sirius, new List { "siriusxm.com", "sirius.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Ea, new List { "ea.com", "origin.com", "play4free.com", "tiberiumalliance.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Basecamp, new List { "37signals.com", "basecamp.com", "basecamphq.com", "highrisehq.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Steam, new List { "steampowered.com", "steamcommunity.com", "steamgames.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Chart, new List { "chart.io", "chartio.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Gotomeeting, new List { "gotomeeting.com", "citrixonline.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Gogo, new List { "gogoair.com", "gogoinflight.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Oracle, new List { "mysql.com", "oracle.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Discover, new List { "discover.com", "discovercard.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Dcu, new List { "dcu.org", "dcu-online.org" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Healthcare, new List { "healthcare.gov", "cuidadodesalud.gov", "cms.gov" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Pepco, new List { "pepco.com", "pepcoholdings.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Century21, new List { "century21.com", "21online.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Comcast, new List { "comcast.com", "comcast.net", "xfinity.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Cricket, new List { "cricketwireless.com", "aiowireless.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Mtb, new List { "mandtbank.com", "mtb.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Dropbox, new List { "dropbox.com", "getdropbox.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Snapfish, new List { "snapfish.com", "snapfish.ca" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Alibaba, new List { "alibaba.com", "aliexpress.com", "aliyun.com", "net.cn" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Playstation, new List { "playstation.com", "sonyentertainmentnetwork.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Mercado, new List { "mercadolivre.com", "mercadolivre.com.br", "mercadolibre.com", "mercadolibre.com.ar", "mercadolibre.com.mx" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Zendesk, new List { "zendesk.com", "zopim.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Autodesk, new List { "autodesk.com", "tinkercad.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.RailNation, new List { "railnation.ru", "railnation.de", "rail-nation.com", "railnation.gr", "railnation.us", "trucknation.de", "traviangames.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Wpcu, new List { "wpcu.coop", "wpcuonline.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Mathletics, new List { "mathletics.com", "mathletics.com.au", "mathletics.co.uk" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Discountbank, new List { "discountbank.co.il", "telebank.co.il" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Mi, new List { "mi.com", "xiaomi.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Postepay, new List { "postepay.it", "poste.it" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Facebook, new List { "facebook.com", "messenger.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Skysports, new List { "skysports.com", "skybet.com", "skyvegas.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Disney, new List { "disneymoviesanywhere.com", "go.com", "disney.com", "dadt.com", "disneyplus.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Pokemon, new List { "pokemon-gl.com", "pokemon.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Uv, new List { "myuv.com", "uvvu.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Mdsol, new List { "mdsol.com", "imedidata.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Yahavo, new List { "bank-yahav.co.il", "bankhapoalim.co.il" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Sears, new List { "sears.com", "shld.net" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Xiami, new List { "xiami.com", "alipay.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Belkin, new List { "belkin.com", "seedonk.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Turbotax, new List { "turbotax.com", "intuit.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Shopify, new List { "shopify.com", "myshopify.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Ebay, new List { "ebay.com", "ebay.at", "ebay.be", "ebay.ca", "ebay.ch", "ebay.cn", "ebay.co.jp", "ebay.co.th", "ebay.co.uk", "ebay.com.au", "ebay.com.hk", "ebay.com.my", "ebay.com.sg", "ebay.com.tw", "ebay.de", "ebay.es", "ebay.fr", "ebay.ie", "ebay.in", "ebay.it", "ebay.nl", "ebay.ph", "ebay.pl" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Techdata, new List { "techdata.com", "techdata.ch" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Schwab, new List { "schwab.com", "schwabplan.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Tesla, new List { "tesla.com", "teslamotors.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.MorganStanley, new List { "morganstanley.com", "morganstanleyclientserv.com", "stockplanconnect.com", "ms.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.TaxAct, new List { "taxact.com", "taxactonline.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Wikimedia, new List { "mediawiki.org", "wikibooks.org", "wikidata.org", "wikimedia.org", "wikinews.org", "wikipedia.org", "wikiquote.org", "wikisource.org", "wikiversity.org", "wikivoyage.org", "wiktionary.org" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Airbnb, new List { "airbnb.at", "airbnb.be", "airbnb.ca", "airbnb.ch", "airbnb.cl", "airbnb.co.cr", "airbnb.co.id", "airbnb.co.in", "airbnb.co.kr", "airbnb.co.nz", "airbnb.co.uk", "airbnb.co.ve", "airbnb.com", "airbnb.com.ar", "airbnb.com.au", "airbnb.com.bo", "airbnb.com.br", "airbnb.com.bz", "airbnb.com.co", "airbnb.com.ec", "airbnb.com.gt", "airbnb.com.hk", "airbnb.com.hn", "airbnb.com.mt", "airbnb.com.my", "airbnb.com.ni", "airbnb.com.pa", "airbnb.com.pe", "airbnb.com.py", "airbnb.com.sg", "airbnb.com.sv", "airbnb.com.tr", "airbnb.com.tw", "airbnb.cz", "airbnb.de", "airbnb.dk", "airbnb.es", "airbnb.fi", "airbnb.fr", "airbnb.gr", "airbnb.gy", "airbnb.hu", "airbnb.ie", "airbnb.is", "airbnb.it", "airbnb.jp", "airbnb.mx", "airbnb.nl", "airbnb.no", "airbnb.pl", "airbnb.pt", "airbnb.ru", "airbnb.se" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Eventbrite, new List { "eventbrite.at", "eventbrite.be", "eventbrite.ca", "eventbrite.ch", "eventbrite.cl", "eventbrite.co", "eventbrite.co.nz", "eventbrite.co.uk", "eventbrite.com", "eventbrite.com.ar", "eventbrite.com.au", "eventbrite.com.br", "eventbrite.com.mx", "eventbrite.com.pe", "eventbrite.de", "eventbrite.dk", "eventbrite.es", "eventbrite.fi", "eventbrite.fr", "eventbrite.hk", "eventbrite.ie", "eventbrite.it", "eventbrite.nl", "eventbrite.pt", "eventbrite.se", "eventbrite.sg" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.StackExchange, new List { "stackexchange.com", "superuser.com", "stackoverflow.com", "serverfault.com", "mathoverflow.net", "askubuntu.com", "stackapps.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Docusign, new List { "docusign.com", "docusign.net" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Envato, new List { "envato.com", "themeforest.net", "codecanyon.net", "videohive.net", "audiojungle.net", "graphicriver.net", "photodune.net", "3docean.net" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.X10Hosting, new List { "x10hosting.com", "x10premium.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Cisco, new List { "dnsomatic.com", "opendns.com", "umbrella.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.CedarFair, new List { "cagreatamerica.com", "canadaswonderland.com", "carowinds.com", "cedarfair.com", "cedarpoint.com", "dorneypark.com", "kingsdominion.com", "knotts.com", "miadventure.com", "schlitterbahn.com", "valleyfair.com", "visitkingsisland.com", "worldsoffun.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Ubiquiti, new List { "ubnt.com", "ui.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Discord, new List { "discordapp.com", "discord.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Netcup, new List { "netcup.de", "netcup.eu", "customercontrolpanel.de" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Yandex, new List { "yandex.com", "ya.ru", "yandex.az", "yandex.by", "yandex.co.il", "yandex.com.am", "yandex.com.ge", "yandex.com.tr", "yandex.ee", "yandex.fi", "yandex.fr", "yandex.kg", "yandex.kz", "yandex.lt", "yandex.lv", "yandex.md", "yandex.pl", "yandex.ru", "yandex.tj", "yandex.tm", "yandex.ua", "yandex.uz" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Sony, new List { "sonyentertainmentnetwork.com", "sony.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Proton, new List { "proton.me", "protonmail.com", "protonvpn.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Ubisoft, new List { "ubisoft.com", "ubi.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.TransferWise, new List { "transferwise.com", "wise.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.TakeawayEU, new List { "takeaway.com", "just-eat.dk", "just-eat.no", "just-eat.fr", "just-eat.ch", "lieferando.de", "lieferando.at", "thuisbezorgd.nl", "pyszne.pl" }); - #endregion - - #region Plans - - Plans = new List + static StaticStore() { - new Plan + #region Global Domains + + GlobalDomains = new Dictionary>(); + + GlobalDomains.Add(GlobalEquivalentDomainsType.Ameritrade, new List { "ameritrade.com", "tdameritrade.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.BoA, new List { "bankofamerica.com", "bofa.com", "mbna.com", "usecfo.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Sprint, new List { "sprint.com", "sprintpcs.com", "nextel.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Google, new List { "youtube.com", "google.com", "gmail.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Apple, new List { "apple.com", "icloud.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.WellsFargo, new List { "wellsfargo.com", "wf.com", "wellsfargoadvisors.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Merrill, new List { "mymerrill.com", "ml.com", "merrilledge.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Citi, new List { "accountonline.com", "citi.com", "citibank.com", "citicards.com", "citibankonline.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Cnet, new List { "cnet.com", "cnettv.com", "com.com", "download.com", "news.com", "search.com", "upload.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Gap, new List { "bananarepublic.com", "gap.com", "oldnavy.com", "piperlime.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Microsoft, new List { "bing.com", "hotmail.com", "live.com", "microsoft.com", "msn.com", "passport.net", "windows.com", "microsoftonline.com", "office.com", "office365.com", "microsoftstore.com", "xbox.com", "azure.com", "windowsazure.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.United, new List { "ua2go.com", "ual.com", "united.com", "unitedwifi.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Yahoo, new List { "overture.com", "yahoo.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Zonelabs, new List { "zonealarm.com", "zonelabs.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.PayPal, new List { "paypal.com", "paypal-search.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Avon, new List { "avon.com", "youravon.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Diapers, new List { "diapers.com", "soap.com", "wag.com", "yoyo.com", "beautybar.com", "casa.com", "afterschool.com", "vine.com", "bookworm.com", "look.com", "vinemarket.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Contacts, new List { "1800contacts.com", "800contacts.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Amazon, new List { "amazon.com", "amazon.ae", "amazon.ca", "amazon.co.uk", "amazon.com.au", "amazon.com.br", "amazon.com.mx", "amazon.com.tr", "amazon.de", "amazon.es", "amazon.fr", "amazon.in", "amazon.it", "amazon.nl", "amazon.pl", "amazon.sa", "amazon.se", "amazon.sg" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Cox, new List { "cox.com", "cox.net", "coxbusiness.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Norton, new List { "mynortonaccount.com", "norton.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Verizon, new List { "verizon.com", "verizon.net" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Buy, new List { "rakuten.com", "buy.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Sirius, new List { "siriusxm.com", "sirius.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Ea, new List { "ea.com", "origin.com", "play4free.com", "tiberiumalliance.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Basecamp, new List { "37signals.com", "basecamp.com", "basecamphq.com", "highrisehq.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Steam, new List { "steampowered.com", "steamcommunity.com", "steamgames.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Chart, new List { "chart.io", "chartio.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Gotomeeting, new List { "gotomeeting.com", "citrixonline.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Gogo, new List { "gogoair.com", "gogoinflight.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Oracle, new List { "mysql.com", "oracle.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Discover, new List { "discover.com", "discovercard.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Dcu, new List { "dcu.org", "dcu-online.org" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Healthcare, new List { "healthcare.gov", "cuidadodesalud.gov", "cms.gov" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Pepco, new List { "pepco.com", "pepcoholdings.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Century21, new List { "century21.com", "21online.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Comcast, new List { "comcast.com", "comcast.net", "xfinity.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Cricket, new List { "cricketwireless.com", "aiowireless.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Mtb, new List { "mandtbank.com", "mtb.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Dropbox, new List { "dropbox.com", "getdropbox.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Snapfish, new List { "snapfish.com", "snapfish.ca" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Alibaba, new List { "alibaba.com", "aliexpress.com", "aliyun.com", "net.cn" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Playstation, new List { "playstation.com", "sonyentertainmentnetwork.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Mercado, new List { "mercadolivre.com", "mercadolivre.com.br", "mercadolibre.com", "mercadolibre.com.ar", "mercadolibre.com.mx" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Zendesk, new List { "zendesk.com", "zopim.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Autodesk, new List { "autodesk.com", "tinkercad.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.RailNation, new List { "railnation.ru", "railnation.de", "rail-nation.com", "railnation.gr", "railnation.us", "trucknation.de", "traviangames.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Wpcu, new List { "wpcu.coop", "wpcuonline.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Mathletics, new List { "mathletics.com", "mathletics.com.au", "mathletics.co.uk" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Discountbank, new List { "discountbank.co.il", "telebank.co.il" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Mi, new List { "mi.com", "xiaomi.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Postepay, new List { "postepay.it", "poste.it" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Facebook, new List { "facebook.com", "messenger.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Skysports, new List { "skysports.com", "skybet.com", "skyvegas.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Disney, new List { "disneymoviesanywhere.com", "go.com", "disney.com", "dadt.com", "disneyplus.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Pokemon, new List { "pokemon-gl.com", "pokemon.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Uv, new List { "myuv.com", "uvvu.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Mdsol, new List { "mdsol.com", "imedidata.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Yahavo, new List { "bank-yahav.co.il", "bankhapoalim.co.il" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Sears, new List { "sears.com", "shld.net" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Xiami, new List { "xiami.com", "alipay.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Belkin, new List { "belkin.com", "seedonk.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Turbotax, new List { "turbotax.com", "intuit.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Shopify, new List { "shopify.com", "myshopify.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Ebay, new List { "ebay.com", "ebay.at", "ebay.be", "ebay.ca", "ebay.ch", "ebay.cn", "ebay.co.jp", "ebay.co.th", "ebay.co.uk", "ebay.com.au", "ebay.com.hk", "ebay.com.my", "ebay.com.sg", "ebay.com.tw", "ebay.de", "ebay.es", "ebay.fr", "ebay.ie", "ebay.in", "ebay.it", "ebay.nl", "ebay.ph", "ebay.pl" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Techdata, new List { "techdata.com", "techdata.ch" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Schwab, new List { "schwab.com", "schwabplan.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Tesla, new List { "tesla.com", "teslamotors.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.MorganStanley, new List { "morganstanley.com", "morganstanleyclientserv.com", "stockplanconnect.com", "ms.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.TaxAct, new List { "taxact.com", "taxactonline.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Wikimedia, new List { "mediawiki.org", "wikibooks.org", "wikidata.org", "wikimedia.org", "wikinews.org", "wikipedia.org", "wikiquote.org", "wikisource.org", "wikiversity.org", "wikivoyage.org", "wiktionary.org" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Airbnb, new List { "airbnb.at", "airbnb.be", "airbnb.ca", "airbnb.ch", "airbnb.cl", "airbnb.co.cr", "airbnb.co.id", "airbnb.co.in", "airbnb.co.kr", "airbnb.co.nz", "airbnb.co.uk", "airbnb.co.ve", "airbnb.com", "airbnb.com.ar", "airbnb.com.au", "airbnb.com.bo", "airbnb.com.br", "airbnb.com.bz", "airbnb.com.co", "airbnb.com.ec", "airbnb.com.gt", "airbnb.com.hk", "airbnb.com.hn", "airbnb.com.mt", "airbnb.com.my", "airbnb.com.ni", "airbnb.com.pa", "airbnb.com.pe", "airbnb.com.py", "airbnb.com.sg", "airbnb.com.sv", "airbnb.com.tr", "airbnb.com.tw", "airbnb.cz", "airbnb.de", "airbnb.dk", "airbnb.es", "airbnb.fi", "airbnb.fr", "airbnb.gr", "airbnb.gy", "airbnb.hu", "airbnb.ie", "airbnb.is", "airbnb.it", "airbnb.jp", "airbnb.mx", "airbnb.nl", "airbnb.no", "airbnb.pl", "airbnb.pt", "airbnb.ru", "airbnb.se" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Eventbrite, new List { "eventbrite.at", "eventbrite.be", "eventbrite.ca", "eventbrite.ch", "eventbrite.cl", "eventbrite.co", "eventbrite.co.nz", "eventbrite.co.uk", "eventbrite.com", "eventbrite.com.ar", "eventbrite.com.au", "eventbrite.com.br", "eventbrite.com.mx", "eventbrite.com.pe", "eventbrite.de", "eventbrite.dk", "eventbrite.es", "eventbrite.fi", "eventbrite.fr", "eventbrite.hk", "eventbrite.ie", "eventbrite.it", "eventbrite.nl", "eventbrite.pt", "eventbrite.se", "eventbrite.sg" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.StackExchange, new List { "stackexchange.com", "superuser.com", "stackoverflow.com", "serverfault.com", "mathoverflow.net", "askubuntu.com", "stackapps.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Docusign, new List { "docusign.com", "docusign.net" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Envato, new List { "envato.com", "themeforest.net", "codecanyon.net", "videohive.net", "audiojungle.net", "graphicriver.net", "photodune.net", "3docean.net" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.X10Hosting, new List { "x10hosting.com", "x10premium.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Cisco, new List { "dnsomatic.com", "opendns.com", "umbrella.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.CedarFair, new List { "cagreatamerica.com", "canadaswonderland.com", "carowinds.com", "cedarfair.com", "cedarpoint.com", "dorneypark.com", "kingsdominion.com", "knotts.com", "miadventure.com", "schlitterbahn.com", "valleyfair.com", "visitkingsisland.com", "worldsoffun.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Ubiquiti, new List { "ubnt.com", "ui.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Discord, new List { "discordapp.com", "discord.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Netcup, new List { "netcup.de", "netcup.eu", "customercontrolpanel.de" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Yandex, new List { "yandex.com", "ya.ru", "yandex.az", "yandex.by", "yandex.co.il", "yandex.com.am", "yandex.com.ge", "yandex.com.tr", "yandex.ee", "yandex.fi", "yandex.fr", "yandex.kg", "yandex.kz", "yandex.lt", "yandex.lv", "yandex.md", "yandex.pl", "yandex.ru", "yandex.tj", "yandex.tm", "yandex.ua", "yandex.uz" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Sony, new List { "sonyentertainmentnetwork.com", "sony.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Proton, new List { "proton.me", "protonmail.com", "protonvpn.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Ubisoft, new List { "ubisoft.com", "ubi.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.TransferWise, new List { "transferwise.com", "wise.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.TakeawayEU, new List { "takeaway.com", "just-eat.dk", "just-eat.no", "just-eat.fr", "just-eat.ch", "lieferando.de", "lieferando.at", "thuisbezorgd.nl", "pyszne.pl" }); + #endregion + + #region Plans + + Plans = new List { - Type = PlanType.Free, - Product = ProductType.Free, - Name = "Free", - NameLocalizationKey = "planNameFree", - DescriptionLocalizationKey = "planDescFree", - BaseSeats = 2, - MaxCollections = 2, - MaxUsers = 2, + new Plan + { + Type = PlanType.Free, + Product = ProductType.Free, + Name = "Free", + NameLocalizationKey = "planNameFree", + DescriptionLocalizationKey = "planDescFree", + BaseSeats = 2, + MaxCollections = 2, + MaxUsers = 2, - UpgradeSortOrder = -1, // Always the lowest plan, cannot be upgraded to - DisplaySortOrder = -1, + UpgradeSortOrder = -1, // Always the lowest plan, cannot be upgraded to + DisplaySortOrder = -1, - AllowSeatAutoscale = false, - }, - new Plan + AllowSeatAutoscale = false, + }, + new Plan + { + Type = PlanType.FamiliesAnnually2019, + Product = ProductType.Families, + Name = "Families 2019", + IsAnnual = true, + NameLocalizationKey = "planNameFamilies", + DescriptionLocalizationKey = "planDescFamilies", + BaseSeats = 5, + BaseStorageGb = 1, + MaxUsers = 5, + + HasAdditionalStorageOption = true, + HasPremiumAccessOption = true, + TrialPeriodDays = 7, + + HasSelfHost = true, + HasTotp = true, + + UpgradeSortOrder = 1, + DisplaySortOrder = 1, + LegacyYear = 2020, + + StripePlanId = "personal-org-annually", + StripeStoragePlanId = "storage-gb-annually", + StripePremiumAccessPlanId = "personal-org-premium-access-annually", + BasePrice = 12, + AdditionalStoragePricePerGb = 4, + PremiumAccessOptionPrice = 40, + + AllowSeatAutoscale = false, + }, + new Plan + { + Type = PlanType.TeamsAnnually2019, + Product = ProductType.Teams, + Name = "Teams (Annually) 2019", + IsAnnual = true, + NameLocalizationKey = "planNameTeams", + DescriptionLocalizationKey = "planDescTeams", + CanBeUsedByBusiness = true, + BaseSeats = 5, + BaseStorageGb = 1, + + HasAdditionalSeatsOption = true, + HasAdditionalStorageOption = true, + TrialPeriodDays = 7, + + HasTotp = true, + + UpgradeSortOrder = 2, + DisplaySortOrder = 2, + LegacyYear = 2020, + + StripePlanId = "teams-org-annually", + StripeSeatPlanId = "teams-org-seat-annually", + StripeStoragePlanId = "storage-gb-annually", + BasePrice = 60, + SeatPrice = 24, + AdditionalStoragePricePerGb = 4, + + AllowSeatAutoscale = true, + }, + new Plan + { + Type = PlanType.TeamsMonthly2019, + Product = ProductType.Teams, + Name = "Teams (Monthly) 2019", + NameLocalizationKey = "planNameTeams", + DescriptionLocalizationKey = "planDescTeams", + CanBeUsedByBusiness = true, + BaseSeats = 5, + BaseStorageGb = 1, + + HasAdditionalSeatsOption = true, + HasAdditionalStorageOption = true, + TrialPeriodDays = 7, + + HasTotp = true, + + UpgradeSortOrder = 2, + DisplaySortOrder = 2, + LegacyYear = 2020, + + StripePlanId = "teams-org-monthly", + StripeSeatPlanId = "teams-org-seat-monthly", + StripeStoragePlanId = "storage-gb-monthly", + BasePrice = 8, + SeatPrice = 2.5M, + AdditionalStoragePricePerGb = 0.5M, + + AllowSeatAutoscale = true, + }, + new Plan + { + Type = PlanType.EnterpriseAnnually2019, + Name = "Enterprise (Annually) 2019", + IsAnnual = true, + Product = ProductType.Enterprise, + NameLocalizationKey = "planNameEnterprise", + DescriptionLocalizationKey = "planDescEnterprise", + CanBeUsedByBusiness = true, + BaseSeats = 0, + BaseStorageGb = 1, + + HasAdditionalSeatsOption = true, + HasAdditionalStorageOption = true, + TrialPeriodDays = 7, + + HasPolicies = true, + HasSelfHost = true, + HasGroups = true, + HasDirectory = true, + HasEvents = true, + HasTotp = true, + Has2fa = true, + HasApi = true, + UsersGetPremium = true, + + UpgradeSortOrder = 3, + DisplaySortOrder = 3, + LegacyYear = 2020, + + StripePlanId = null, + StripeSeatPlanId = "enterprise-org-seat-annually", + StripeStoragePlanId = "storage-gb-annually", + BasePrice = 0, + SeatPrice = 36, + AdditionalStoragePricePerGb = 4, + + AllowSeatAutoscale = true, + }, + new Plan + { + Type = PlanType.EnterpriseMonthly2019, + Product = ProductType.Enterprise, + Name = "Enterprise (Monthly) 2019", + NameLocalizationKey = "planNameEnterprise", + DescriptionLocalizationKey = "planDescEnterprise", + CanBeUsedByBusiness = true, + BaseSeats = 0, + BaseStorageGb = 1, + + HasAdditionalSeatsOption = true, + HasAdditionalStorageOption = true, + TrialPeriodDays = 7, + + HasPolicies = true, + HasGroups = true, + HasDirectory = true, + HasEvents = true, + HasTotp = true, + Has2fa = true, + HasApi = true, + HasSelfHost = true, + UsersGetPremium = true, + + UpgradeSortOrder = 3, + DisplaySortOrder = 3, + LegacyYear = 2020, + + StripePlanId = null, + StripeSeatPlanId = "enterprise-org-seat-monthly", + StripeStoragePlanId = "storage-gb-monthly", + BasePrice = 0, + SeatPrice = 4M, + AdditionalStoragePricePerGb = 0.5M, + + AllowSeatAutoscale = true, + }, + new Plan + { + Type = PlanType.FamiliesAnnually, + Product = ProductType.Families, + Name = "Families", + IsAnnual = true, + NameLocalizationKey = "planNameFamilies", + DescriptionLocalizationKey = "planDescFamilies", + BaseSeats = 6, + BaseStorageGb = 1, + MaxUsers = 6, + + HasAdditionalStorageOption = true, + TrialPeriodDays = 7, + + HasSelfHost = true, + HasTotp = true, + UsersGetPremium = true, + + UpgradeSortOrder = 1, + DisplaySortOrder = 1, + + StripePlanId = "2020-families-org-annually", + StripeStoragePlanId = "storage-gb-annually", + BasePrice = 40, + AdditionalStoragePricePerGb = 4, + + AllowSeatAutoscale = false, + }, + new Plan + { + Type = PlanType.TeamsAnnually, + Product = ProductType.Teams, + Name = "Teams (Annually)", + IsAnnual = true, + NameLocalizationKey = "planNameTeams", + DescriptionLocalizationKey = "planDescTeams", + CanBeUsedByBusiness = true, + BaseStorageGb = 1, + BaseSeats = 0, + + HasAdditionalSeatsOption = true, + HasAdditionalStorageOption = true, + TrialPeriodDays = 7, + + Has2fa = true, + HasApi = true, + HasDirectory = true, + HasEvents = true, + HasGroups = true, + HasTotp = true, + UsersGetPremium = true, + + UpgradeSortOrder = 2, + DisplaySortOrder = 2, + + StripeSeatPlanId = "2020-teams-org-seat-annually", + StripeStoragePlanId = "storage-gb-annually", + SeatPrice = 36, + AdditionalStoragePricePerGb = 4, + + AllowSeatAutoscale = true, + }, + new Plan + { + Type = PlanType.TeamsMonthly, + Product = ProductType.Teams, + Name = "Teams (Monthly)", + NameLocalizationKey = "planNameTeams", + DescriptionLocalizationKey = "planDescTeams", + CanBeUsedByBusiness = true, + BaseStorageGb = 1, + BaseSeats = 0, + + HasAdditionalSeatsOption = true, + HasAdditionalStorageOption = true, + TrialPeriodDays = 7, + + Has2fa = true, + HasApi = true, + HasDirectory = true, + HasEvents = true, + HasGroups = true, + HasTotp = true, + UsersGetPremium = true, + + UpgradeSortOrder = 2, + DisplaySortOrder = 2, + + StripeSeatPlanId = "2020-teams-org-seat-monthly", + StripeStoragePlanId = "storage-gb-monthly", + SeatPrice = 4, + AdditionalStoragePricePerGb = 0.5M, + + AllowSeatAutoscale = true, + }, + new Plan + { + Type = PlanType.EnterpriseAnnually, + Name = "Enterprise (Annually)", + Product = ProductType.Enterprise, + IsAnnual = true, + NameLocalizationKey = "planNameEnterprise", + DescriptionLocalizationKey = "planDescEnterprise", + CanBeUsedByBusiness = true, + BaseSeats = 0, + BaseStorageGb = 1, + + HasAdditionalSeatsOption = true, + HasAdditionalStorageOption = true, + TrialPeriodDays = 7, + + HasPolicies = true, + HasSelfHost = true, + HasGroups = true, + HasDirectory = true, + HasEvents = true, + HasTotp = true, + Has2fa = true, + HasApi = true, + HasSso = true, + HasKeyConnector = true, + HasScim = true, + HasResetPassword = true, + UsersGetPremium = true, + + UpgradeSortOrder = 3, + DisplaySortOrder = 3, + + StripeSeatPlanId = "2020-enterprise-org-seat-annually", + StripeStoragePlanId = "storage-gb-annually", + BasePrice = 0, + SeatPrice = 60, + AdditionalStoragePricePerGb = 4, + + AllowSeatAutoscale = true, + }, + new Plan + { + Type = PlanType.EnterpriseMonthly, + Product = ProductType.Enterprise, + Name = "Enterprise (Monthly)", + NameLocalizationKey = "planNameEnterprise", + DescriptionLocalizationKey = "planDescEnterprise", + CanBeUsedByBusiness = true, + BaseSeats = 0, + BaseStorageGb = 1, + + HasAdditionalSeatsOption = true, + HasAdditionalStorageOption = true, + TrialPeriodDays = 7, + + HasPolicies = true, + HasGroups = true, + HasDirectory = true, + HasEvents = true, + HasTotp = true, + Has2fa = true, + HasApi = true, + HasSelfHost = true, + HasSso = true, + HasKeyConnector = true, + HasScim = true, + HasResetPassword = true, + UsersGetPremium = true, + + UpgradeSortOrder = 3, + DisplaySortOrder = 3, + + StripeSeatPlanId = "2020-enterprise-org-seat-monthly", + StripeStoragePlanId = "storage-gb-monthly", + BasePrice = 0, + SeatPrice = 6, + AdditionalStoragePricePerGb = 0.5M, + + AllowSeatAutoscale = true, + }, + new Plan + { + Type = PlanType.Custom, + + AllowSeatAutoscale = true, + }, + }; + + #endregion + } + + public static IDictionary> GlobalDomains { get; set; } + public static IEnumerable Plans { get; set; } + public static IEnumerable SponsoredPlans { get; set; } = new[] { - Type = PlanType.FamiliesAnnually2019, - Product = ProductType.Families, - Name = "Families 2019", - IsAnnual = true, - NameLocalizationKey = "planNameFamilies", - DescriptionLocalizationKey = "planDescFamilies", - BaseSeats = 5, - BaseStorageGb = 1, - MaxUsers = 5, - - HasAdditionalStorageOption = true, - HasPremiumAccessOption = true, - TrialPeriodDays = 7, - - HasSelfHost = true, - HasTotp = true, - - UpgradeSortOrder = 1, - DisplaySortOrder = 1, - LegacyYear = 2020, - - StripePlanId = "personal-org-annually", - StripeStoragePlanId = "storage-gb-annually", - StripePremiumAccessPlanId = "personal-org-premium-access-annually", - BasePrice = 12, - AdditionalStoragePricePerGb = 4, - PremiumAccessOptionPrice = 40, - - AllowSeatAutoscale = false, - }, - new Plan - { - Type = PlanType.TeamsAnnually2019, - Product = ProductType.Teams, - Name = "Teams (Annually) 2019", - IsAnnual = true, - NameLocalizationKey = "planNameTeams", - DescriptionLocalizationKey = "planDescTeams", - CanBeUsedByBusiness = true, - BaseSeats = 5, - BaseStorageGb = 1, - - HasAdditionalSeatsOption = true, - HasAdditionalStorageOption = true, - TrialPeriodDays = 7, - - HasTotp = true, - - UpgradeSortOrder = 2, - DisplaySortOrder = 2, - LegacyYear = 2020, - - StripePlanId = "teams-org-annually", - StripeSeatPlanId = "teams-org-seat-annually", - StripeStoragePlanId = "storage-gb-annually", - BasePrice = 60, - SeatPrice = 24, - AdditionalStoragePricePerGb = 4, - - AllowSeatAutoscale = true, - }, - new Plan - { - Type = PlanType.TeamsMonthly2019, - Product = ProductType.Teams, - Name = "Teams (Monthly) 2019", - NameLocalizationKey = "planNameTeams", - DescriptionLocalizationKey = "planDescTeams", - CanBeUsedByBusiness = true, - BaseSeats = 5, - BaseStorageGb = 1, - - HasAdditionalSeatsOption = true, - HasAdditionalStorageOption = true, - TrialPeriodDays = 7, - - HasTotp = true, - - UpgradeSortOrder = 2, - DisplaySortOrder = 2, - LegacyYear = 2020, - - StripePlanId = "teams-org-monthly", - StripeSeatPlanId = "teams-org-seat-monthly", - StripeStoragePlanId = "storage-gb-monthly", - BasePrice = 8, - SeatPrice = 2.5M, - AdditionalStoragePricePerGb = 0.5M, - - AllowSeatAutoscale = true, - }, - new Plan - { - Type = PlanType.EnterpriseAnnually2019, - Name = "Enterprise (Annually) 2019", - IsAnnual = true, - Product = ProductType.Enterprise, - NameLocalizationKey = "planNameEnterprise", - DescriptionLocalizationKey = "planDescEnterprise", - CanBeUsedByBusiness = true, - BaseSeats = 0, - BaseStorageGb = 1, - - HasAdditionalSeatsOption = true, - HasAdditionalStorageOption = true, - TrialPeriodDays = 7, - - HasPolicies = true, - HasSelfHost = true, - HasGroups = true, - HasDirectory = true, - HasEvents = true, - HasTotp = true, - Has2fa = true, - HasApi = true, - UsersGetPremium = true, - - UpgradeSortOrder = 3, - DisplaySortOrder = 3, - LegacyYear = 2020, - - StripePlanId = null, - StripeSeatPlanId = "enterprise-org-seat-annually", - StripeStoragePlanId = "storage-gb-annually", - BasePrice = 0, - SeatPrice = 36, - AdditionalStoragePricePerGb = 4, - - AllowSeatAutoscale = true, - }, - new Plan - { - Type = PlanType.EnterpriseMonthly2019, - Product = ProductType.Enterprise, - Name = "Enterprise (Monthly) 2019", - NameLocalizationKey = "planNameEnterprise", - DescriptionLocalizationKey = "planDescEnterprise", - CanBeUsedByBusiness = true, - BaseSeats = 0, - BaseStorageGb = 1, - - HasAdditionalSeatsOption = true, - HasAdditionalStorageOption = true, - TrialPeriodDays = 7, - - HasPolicies = true, - HasGroups = true, - HasDirectory = true, - HasEvents = true, - HasTotp = true, - Has2fa = true, - HasApi = true, - HasSelfHost = true, - UsersGetPremium = true, - - UpgradeSortOrder = 3, - DisplaySortOrder = 3, - LegacyYear = 2020, - - StripePlanId = null, - StripeSeatPlanId = "enterprise-org-seat-monthly", - StripeStoragePlanId = "storage-gb-monthly", - BasePrice = 0, - SeatPrice = 4M, - AdditionalStoragePricePerGb = 0.5M, - - AllowSeatAutoscale = true, - }, - new Plan - { - Type = PlanType.FamiliesAnnually, - Product = ProductType.Families, - Name = "Families", - IsAnnual = true, - NameLocalizationKey = "planNameFamilies", - DescriptionLocalizationKey = "planDescFamilies", - BaseSeats = 6, - BaseStorageGb = 1, - MaxUsers = 6, - - HasAdditionalStorageOption = true, - TrialPeriodDays = 7, - - HasSelfHost = true, - HasTotp = true, - UsersGetPremium = true, - - UpgradeSortOrder = 1, - DisplaySortOrder = 1, - - StripePlanId = "2020-families-org-annually", - StripeStoragePlanId = "storage-gb-annually", - BasePrice = 40, - AdditionalStoragePricePerGb = 4, - - AllowSeatAutoscale = false, - }, - new Plan - { - Type = PlanType.TeamsAnnually, - Product = ProductType.Teams, - Name = "Teams (Annually)", - IsAnnual = true, - NameLocalizationKey = "planNameTeams", - DescriptionLocalizationKey = "planDescTeams", - CanBeUsedByBusiness = true, - BaseStorageGb = 1, - BaseSeats = 0, - - HasAdditionalSeatsOption = true, - HasAdditionalStorageOption = true, - TrialPeriodDays = 7, - - Has2fa = true, - HasApi = true, - HasDirectory = true, - HasEvents = true, - HasGroups = true, - HasTotp = true, - UsersGetPremium = true, - - UpgradeSortOrder = 2, - DisplaySortOrder = 2, - - StripeSeatPlanId = "2020-teams-org-seat-annually", - StripeStoragePlanId = "storage-gb-annually", - SeatPrice = 36, - AdditionalStoragePricePerGb = 4, - - AllowSeatAutoscale = true, - }, - new Plan - { - Type = PlanType.TeamsMonthly, - Product = ProductType.Teams, - Name = "Teams (Monthly)", - NameLocalizationKey = "planNameTeams", - DescriptionLocalizationKey = "planDescTeams", - CanBeUsedByBusiness = true, - BaseStorageGb = 1, - BaseSeats = 0, - - HasAdditionalSeatsOption = true, - HasAdditionalStorageOption = true, - TrialPeriodDays = 7, - - Has2fa = true, - HasApi = true, - HasDirectory = true, - HasEvents = true, - HasGroups = true, - HasTotp = true, - UsersGetPremium = true, - - UpgradeSortOrder = 2, - DisplaySortOrder = 2, - - StripeSeatPlanId = "2020-teams-org-seat-monthly", - StripeStoragePlanId = "storage-gb-monthly", - SeatPrice = 4, - AdditionalStoragePricePerGb = 0.5M, - - AllowSeatAutoscale = true, - }, - new Plan - { - Type = PlanType.EnterpriseAnnually, - Name = "Enterprise (Annually)", - Product = ProductType.Enterprise, - IsAnnual = true, - NameLocalizationKey = "planNameEnterprise", - DescriptionLocalizationKey = "planDescEnterprise", - CanBeUsedByBusiness = true, - BaseSeats = 0, - BaseStorageGb = 1, - - HasAdditionalSeatsOption = true, - HasAdditionalStorageOption = true, - TrialPeriodDays = 7, - - HasPolicies = true, - HasSelfHost = true, - HasGroups = true, - HasDirectory = true, - HasEvents = true, - HasTotp = true, - Has2fa = true, - HasApi = true, - HasSso = true, - HasKeyConnector = true, - HasScim = true, - HasResetPassword = true, - UsersGetPremium = true, - - UpgradeSortOrder = 3, - DisplaySortOrder = 3, - - StripeSeatPlanId = "2020-enterprise-org-seat-annually", - StripeStoragePlanId = "storage-gb-annually", - BasePrice = 0, - SeatPrice = 60, - AdditionalStoragePricePerGb = 4, - - AllowSeatAutoscale = true, - }, - new Plan - { - Type = PlanType.EnterpriseMonthly, - Product = ProductType.Enterprise, - Name = "Enterprise (Monthly)", - NameLocalizationKey = "planNameEnterprise", - DescriptionLocalizationKey = "planDescEnterprise", - CanBeUsedByBusiness = true, - BaseSeats = 0, - BaseStorageGb = 1, - - HasAdditionalSeatsOption = true, - HasAdditionalStorageOption = true, - TrialPeriodDays = 7, - - HasPolicies = true, - HasGroups = true, - HasDirectory = true, - HasEvents = true, - HasTotp = true, - Has2fa = true, - HasApi = true, - HasSelfHost = true, - HasSso = true, - HasKeyConnector = true, - HasScim = true, - HasResetPassword = true, - UsersGetPremium = true, - - UpgradeSortOrder = 3, - DisplaySortOrder = 3, - - StripeSeatPlanId = "2020-enterprise-org-seat-monthly", - StripeStoragePlanId = "storage-gb-monthly", - BasePrice = 0, - SeatPrice = 6, - AdditionalStoragePricePerGb = 0.5M, - - AllowSeatAutoscale = true, - }, - new Plan - { - Type = PlanType.Custom, - - AllowSeatAutoscale = true, - }, - }; - - #endregion + new SponsoredPlan + { + PlanSponsorshipType = PlanSponsorshipType.FamiliesForEnterprise, + SponsoredProductType = ProductType.Families, + SponsoringProductType = ProductType.Enterprise, + StripePlanId = "2021-family-for-enterprise-annually", + UsersCanSponsor = (OrganizationUserOrganizationDetails org) => + GetPlan(org.PlanType).Product == ProductType.Enterprise, + } + }; + public static Plan GetPlan(PlanType planType) => + Plans.FirstOrDefault(p => p.Type == planType); + public static SponsoredPlan GetSponsoredPlan(PlanSponsorshipType planSponsorshipType) => + SponsoredPlans.FirstOrDefault(p => p.PlanSponsorshipType == planSponsorshipType); } - - public static IDictionary> GlobalDomains { get; set; } - public static IEnumerable Plans { get; set; } - public static IEnumerable SponsoredPlans { get; set; } = new[] - { - new SponsoredPlan - { - PlanSponsorshipType = PlanSponsorshipType.FamiliesForEnterprise, - SponsoredProductType = ProductType.Families, - SponsoringProductType = ProductType.Enterprise, - StripePlanId = "2021-family-for-enterprise-annually", - UsersCanSponsor = (OrganizationUserOrganizationDetails org) => - GetPlan(org.PlanType).Product == ProductType.Enterprise, - } - }; - public static Plan GetPlan(PlanType planType) => - Plans.FirstOrDefault(p => p.Type == planType); - public static SponsoredPlan GetSponsoredPlan(PlanSponsorshipType planSponsorshipType) => - SponsoredPlans.FirstOrDefault(p => p.PlanSponsorshipType == planSponsorshipType); } diff --git a/src/Core/Utilities/StrictEmailAddressAttribute.cs b/src/Core/Utilities/StrictEmailAddressAttribute.cs index f84e41852d..15347ab836 100644 --- a/src/Core/Utilities/StrictEmailAddressAttribute.cs +++ b/src/Core/Utilities/StrictEmailAddressAttribute.cs @@ -2,51 +2,52 @@ using System.Text.RegularExpressions; using MimeKit; -namespace Bit.Core.Utilities; - -public class StrictEmailAddressAttribute : ValidationAttribute +namespace Bit.Core.Utilities { - public StrictEmailAddressAttribute() - : base("The {0} field is not a supported e-mail address format.") - { } - - public override bool IsValid(object value) + public class StrictEmailAddressAttribute : ValidationAttribute { - var emailAddress = value?.ToString(); - if (emailAddress == null) - { - return false; - } + public StrictEmailAddressAttribute() + : base("The {0} field is not a supported e-mail address format.") + { } - try + public override bool IsValid(object value) { - var parsedEmailAddress = MailboxAddress.Parse(emailAddress).Address; - if (parsedEmailAddress != emailAddress) + var emailAddress = value?.ToString(); + if (emailAddress == null) { return false; } - } - catch (ParseException) - { - return false; - } - /** - The regex below is intended to catch edge cases that are not handled by the general parsing check above. - This enforces the following rules: - * Requires ASCII only in the local-part (code points 0-127) - * Requires an @ symbol - * Allows any char in second-level domain name, including unicode and symbols - * Requires at least one period (.) separating SLD from TLD - * Must end in a letter (including unicode) - See the unit tests for examples of what is allowed. - **/ - var emailFormat = @"^[\x00-\x7F]+@.+\.\p{L}+$"; - if (!Regex.IsMatch(emailAddress, emailFormat)) - { - return false; - } + try + { + var parsedEmailAddress = MailboxAddress.Parse(emailAddress).Address; + if (parsedEmailAddress != emailAddress) + { + return false; + } + } + catch (ParseException) + { + return false; + } - return new EmailAddressAttribute().IsValid(emailAddress); + /** + The regex below is intended to catch edge cases that are not handled by the general parsing check above. + This enforces the following rules: + * Requires ASCII only in the local-part (code points 0-127) + * Requires an @ symbol + * Allows any char in second-level domain name, including unicode and symbols + * Requires at least one period (.) separating SLD from TLD + * Must end in a letter (including unicode) + See the unit tests for examples of what is allowed. + **/ + var emailFormat = @"^[\x00-\x7F]+@.+\.\p{L}+$"; + if (!Regex.IsMatch(emailAddress, emailFormat)) + { + return false; + } + + return new EmailAddressAttribute().IsValid(emailAddress); + } } } diff --git a/src/Core/Utilities/StrictEmailAddressListAttribute.cs b/src/Core/Utilities/StrictEmailAddressListAttribute.cs index 456980397a..dcff171cd9 100644 --- a/src/Core/Utilities/StrictEmailAddressListAttribute.cs +++ b/src/Core/Utilities/StrictEmailAddressListAttribute.cs @@ -1,38 +1,39 @@ using System.ComponentModel.DataAnnotations; -namespace Bit.Core.Utilities; - -public class StrictEmailAddressListAttribute : ValidationAttribute +namespace Bit.Core.Utilities { - protected override ValidationResult IsValid(object value, ValidationContext validationContext) + public class StrictEmailAddressListAttribute : ValidationAttribute { - var strictEmailAttribute = new StrictEmailAddressAttribute(); - var emails = value as IList; - - if (!emails?.Any() ?? true) + protected override ValidationResult IsValid(object value, ValidationContext validationContext) { - return new ValidationResult("An email is required."); - } + var strictEmailAttribute = new StrictEmailAddressAttribute(); + var emails = value as IList; - if (emails.Count() > 20) - { - return new ValidationResult("You can only submit up to 20 emails at a time."); - } - - for (var i = 0; i < emails.Count(); i++) - { - var email = emails.ElementAt(i); - if (!strictEmailAttribute.IsValid(email)) + if (!emails?.Any() ?? true) { - return new ValidationResult($"Email #{i + 1} is not valid."); + return new ValidationResult("An email is required."); } - if (email.Length > 256) + if (emails.Count() > 20) { - return new ValidationResult($"Email #{i + 1} is longer than 256 characters."); + return new ValidationResult("You can only submit up to 20 emails at a time."); } - } - return ValidationResult.Success; + for (var i = 0; i < emails.Count(); i++) + { + var email = emails.ElementAt(i); + if (!strictEmailAttribute.IsValid(email)) + { + return new ValidationResult($"Email #{i + 1} is not valid."); + } + + if (email.Length > 256) + { + return new ValidationResult($"Email #{i + 1} is longer than 256 characters."); + } + } + + return ValidationResult.Success; + } } } diff --git a/src/Events/Controllers/CollectController.cs b/src/Events/Controllers/CollectController.cs index aaed0b3584..f2599d26ec 100644 --- a/src/Events/Controllers/CollectController.cs +++ b/src/Events/Controllers/CollectController.cs @@ -8,98 +8,99 @@ using Bit.Events.Models; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; -namespace Bit.Events.Controllers; - -[Route("collect")] -[Authorize("Application")] -public class CollectController : Controller +namespace Bit.Events.Controllers { - private readonly ICurrentContext _currentContext; - private readonly IEventService _eventService; - private readonly ICipherRepository _cipherRepository; - private readonly IOrganizationRepository _organizationRepository; - - public CollectController( - ICurrentContext currentContext, - IEventService eventService, - ICipherRepository cipherRepository, - IOrganizationRepository organizationRepository) + [Route("collect")] + [Authorize("Application")] + public class CollectController : Controller { - _currentContext = currentContext; - _eventService = eventService; - _cipherRepository = cipherRepository; - _organizationRepository = organizationRepository; - } + private readonly ICurrentContext _currentContext; + private readonly IEventService _eventService; + private readonly ICipherRepository _cipherRepository; + private readonly IOrganizationRepository _organizationRepository; - [HttpPost] - public async Task Post([FromBody] IEnumerable model) - { - if (model == null || !model.Any()) + public CollectController( + ICurrentContext currentContext, + IEventService eventService, + ICipherRepository cipherRepository, + IOrganizationRepository organizationRepository) { - return new BadRequestResult(); + _currentContext = currentContext; + _eventService = eventService; + _cipherRepository = cipherRepository; + _organizationRepository = organizationRepository; } - var cipherEvents = new List>(); - var ciphersCache = new Dictionary(); - foreach (var eventModel in model) + + [HttpPost] + public async Task Post([FromBody] IEnumerable model) { - switch (eventModel.Type) + if (model == null || !model.Any()) { - // User events - case EventType.User_ClientExportedVault: - await _eventService.LogUserEventAsync(_currentContext.UserId.Value, eventModel.Type, eventModel.Date); - break; - // Cipher events - case EventType.Cipher_ClientAutofilled: - case EventType.Cipher_ClientCopiedHiddenField: - case EventType.Cipher_ClientCopiedPassword: - case EventType.Cipher_ClientCopiedCardCode: - case EventType.Cipher_ClientToggledCardCodeVisible: - case EventType.Cipher_ClientToggledHiddenFieldVisible: - case EventType.Cipher_ClientToggledPasswordVisible: - case EventType.Cipher_ClientViewed: - if (!eventModel.CipherId.HasValue) - { - continue; - } - Cipher cipher = null; - if (ciphersCache.ContainsKey(eventModel.CipherId.Value)) - { - cipher = ciphersCache[eventModel.CipherId.Value]; - } - else - { - cipher = await _cipherRepository.GetByIdAsync(eventModel.CipherId.Value, - _currentContext.UserId.Value); - } - if (cipher == null) - { - continue; - } - if (!ciphersCache.ContainsKey(eventModel.CipherId.Value)) - { - ciphersCache.Add(eventModel.CipherId.Value, cipher); - } - cipherEvents.Add(new Tuple(cipher, eventModel.Type, eventModel.Date)); - break; - case EventType.Organization_ClientExportedVault: - if (!eventModel.OrganizationId.HasValue) - { - continue; - } - var organization = await _organizationRepository.GetByIdAsync(eventModel.OrganizationId.Value); - await _eventService.LogOrganizationEventAsync(organization, eventModel.Type, eventModel.Date); - break; - default: - continue; + return new BadRequestResult(); } - } - if (cipherEvents.Any()) - { - foreach (var eventsBatch in cipherEvents.Batch(50)) + var cipherEvents = new List>(); + var ciphersCache = new Dictionary(); + foreach (var eventModel in model) { - await _eventService.LogCipherEventsAsync(eventsBatch); + switch (eventModel.Type) + { + // User events + case EventType.User_ClientExportedVault: + await _eventService.LogUserEventAsync(_currentContext.UserId.Value, eventModel.Type, eventModel.Date); + break; + // Cipher events + case EventType.Cipher_ClientAutofilled: + case EventType.Cipher_ClientCopiedHiddenField: + case EventType.Cipher_ClientCopiedPassword: + case EventType.Cipher_ClientCopiedCardCode: + case EventType.Cipher_ClientToggledCardCodeVisible: + case EventType.Cipher_ClientToggledHiddenFieldVisible: + case EventType.Cipher_ClientToggledPasswordVisible: + case EventType.Cipher_ClientViewed: + if (!eventModel.CipherId.HasValue) + { + continue; + } + Cipher cipher = null; + if (ciphersCache.ContainsKey(eventModel.CipherId.Value)) + { + cipher = ciphersCache[eventModel.CipherId.Value]; + } + else + { + cipher = await _cipherRepository.GetByIdAsync(eventModel.CipherId.Value, + _currentContext.UserId.Value); + } + if (cipher == null) + { + continue; + } + if (!ciphersCache.ContainsKey(eventModel.CipherId.Value)) + { + ciphersCache.Add(eventModel.CipherId.Value, cipher); + } + cipherEvents.Add(new Tuple(cipher, eventModel.Type, eventModel.Date)); + break; + case EventType.Organization_ClientExportedVault: + if (!eventModel.OrganizationId.HasValue) + { + continue; + } + var organization = await _organizationRepository.GetByIdAsync(eventModel.OrganizationId.Value); + await _eventService.LogOrganizationEventAsync(organization, eventModel.Type, eventModel.Date); + break; + default: + continue; + } } + if (cipherEvents.Any()) + { + foreach (var eventsBatch in cipherEvents.Batch(50)) + { + await _eventService.LogCipherEventsAsync(eventsBatch); + } + } + return new OkResult(); } - return new OkResult(); } } diff --git a/src/Events/Controllers/InfoController.cs b/src/Events/Controllers/InfoController.cs index 6d42f67579..23234c654c 100644 --- a/src/Events/Controllers/InfoController.cs +++ b/src/Events/Controllers/InfoController.cs @@ -1,20 +1,21 @@ using Bit.Core.Utilities; using Microsoft.AspNetCore.Mvc; -namespace Bit.Events.Controllers; - -public class InfoController : Controller +namespace Bit.Events.Controllers { - [HttpGet("~/alive")] - [HttpGet("~/now")] - public DateTime GetAlive() + public class InfoController : Controller { - return DateTime.UtcNow; - } + [HttpGet("~/alive")] + [HttpGet("~/now")] + public DateTime GetAlive() + { + return DateTime.UtcNow; + } - [HttpGet("~/version")] - public JsonResult GetVersion() - { - return Json(CoreHelpers.GetVersion()); + [HttpGet("~/version")] + public JsonResult GetVersion() + { + return Json(CoreHelpers.GetVersion()); + } } } diff --git a/src/Events/Models/EventModel.cs b/src/Events/Models/EventModel.cs index dc5cef0843..80b69398a3 100644 --- a/src/Events/Models/EventModel.cs +++ b/src/Events/Models/EventModel.cs @@ -1,11 +1,12 @@ using Bit.Core.Enums; -namespace Bit.Events.Models; - -public class EventModel +namespace Bit.Events.Models { - public EventType Type { get; set; } - public Guid? CipherId { get; set; } - public DateTime Date { get; set; } - public Guid? OrganizationId { get; set; } + public class EventModel + { + public EventType Type { get; set; } + public Guid? CipherId { get; set; } + public DateTime Date { get; set; } + public Guid? OrganizationId { get; set; } + } } diff --git a/src/Events/Program.cs b/src/Events/Program.cs index 74f82cd414..a6a95646a6 100644 --- a/src/Events/Program.cs +++ b/src/Events/Program.cs @@ -1,39 +1,40 @@ using Bit.Core.Utilities; using Serilog.Events; -namespace Bit.Events; - -public class Program +namespace Bit.Events { - public static void Main(string[] args) + public class Program { - Host - .CreateDefaultBuilder(args) - .ConfigureCustomAppConfiguration(args) - .ConfigureWebHostDefaults(webBuilder => - { - webBuilder.UseStartup(); - webBuilder.ConfigureLogging((hostingContext, logging) => - logging.AddSerilog(hostingContext, e => - { - var context = e.Properties["SourceContext"].ToString(); - if (context.Contains("IdentityServer4.Validation.TokenValidator") || - context.Contains("IdentityServer4.Validation.TokenRequestValidator")) + public static void Main(string[] args) + { + Host + .CreateDefaultBuilder(args) + .ConfigureCustomAppConfiguration(args) + .ConfigureWebHostDefaults(webBuilder => + { + webBuilder.UseStartup(); + webBuilder.ConfigureLogging((hostingContext, logging) => + logging.AddSerilog(hostingContext, e => { - return e.Level > LogEventLevel.Error; - } + var context = e.Properties["SourceContext"].ToString(); + if (context.Contains("IdentityServer4.Validation.TokenValidator") || + context.Contains("IdentityServer4.Validation.TokenRequestValidator")) + { + return e.Level > LogEventLevel.Error; + } - if (e.Properties.ContainsKey("RequestPath") && - !string.IsNullOrWhiteSpace(e.Properties["RequestPath"]?.ToString()) && - (context.Contains(".Server.Kestrel") || context.Contains(".Core.IISHttpServer"))) - { - return false; - } + if (e.Properties.ContainsKey("RequestPath") && + !string.IsNullOrWhiteSpace(e.Properties["RequestPath"]?.ToString()) && + (context.Contains(".Server.Kestrel") || context.Contains(".Core.IISHttpServer"))) + { + return false; + } - return e.Level >= LogEventLevel.Error; - })); - }) - .Build() - .Run(); + return e.Level >= LogEventLevel.Error; + })); + }) + .Build() + .Run(); + } } } diff --git a/src/Events/Startup.cs b/src/Events/Startup.cs index c44ca3c1a0..7c777abb42 100644 --- a/src/Events/Startup.cs +++ b/src/Events/Startup.cs @@ -6,112 +6,113 @@ using Bit.Core.Utilities; using Bit.SharedWeb.Utilities; using IdentityModel; -namespace Bit.Events; - -public class Startup +namespace Bit.Events { - public Startup(IWebHostEnvironment env, IConfiguration configuration) + public class Startup { - CultureInfo.DefaultThreadCurrentCulture = new CultureInfo("en-US"); - Configuration = configuration; - Environment = env; - } - - public IConfiguration Configuration { get; } - public IWebHostEnvironment Environment { get; set; } - - public void ConfigureServices(IServiceCollection services) - { - // Options - services.AddOptions(); - - // Settings - var globalSettings = services.AddGlobalSettingsServices(Configuration, Environment); - - // Repositories - services.AddSqlServerRepositories(globalSettings); - - // Context - services.AddScoped(); - - // Identity - services.AddIdentityAuthenticationServices(globalSettings, Environment, config => + public Startup(IWebHostEnvironment env, IConfiguration configuration) { - config.AddPolicy("Application", policy => + CultureInfo.DefaultThreadCurrentCulture = new CultureInfo("en-US"); + Configuration = configuration; + Environment = env; + } + + public IConfiguration Configuration { get; } + public IWebHostEnvironment Environment { get; set; } + + public void ConfigureServices(IServiceCollection services) + { + // Options + services.AddOptions(); + + // Settings + var globalSettings = services.AddGlobalSettingsServices(Configuration, Environment); + + // Repositories + services.AddSqlServerRepositories(globalSettings); + + // Context + services.AddScoped(); + + // Identity + services.AddIdentityAuthenticationServices(globalSettings, Environment, config => { - policy.RequireAuthenticatedUser(); - policy.RequireClaim(JwtClaimTypes.AuthenticationMethod, "Application", "external"); - policy.RequireClaim(JwtClaimTypes.Scope, "api"); + config.AddPolicy("Application", policy => + { + policy.RequireAuthenticatedUser(); + policy.RequireClaim(JwtClaimTypes.AuthenticationMethod, "Application", "external"); + policy.RequireClaim(JwtClaimTypes.Scope, "api"); + }); }); - }); - // Services - var usingServiceBusAppCache = CoreHelpers.SettingHasValue(globalSettings.ServiceBus.ConnectionString) && - CoreHelpers.SettingHasValue(globalSettings.ServiceBus.ApplicationCacheTopicName); - if (usingServiceBusAppCache) - { - services.AddSingleton(); - } - else - { - services.AddSingleton(); - } - services.AddScoped(); - if (!globalSettings.SelfHosted && CoreHelpers.SettingHasValue(globalSettings.Events.ConnectionString)) - { - services.AddSingleton(); - } - else - { - services.AddSingleton(); + // Services + var usingServiceBusAppCache = CoreHelpers.SettingHasValue(globalSettings.ServiceBus.ConnectionString) && + CoreHelpers.SettingHasValue(globalSettings.ServiceBus.ApplicationCacheTopicName); + if (usingServiceBusAppCache) + { + services.AddSingleton(); + } + else + { + services.AddSingleton(); + } + services.AddScoped(); + if (!globalSettings.SelfHosted && CoreHelpers.SettingHasValue(globalSettings.Events.ConnectionString)) + { + services.AddSingleton(); + } + else + { + services.AddSingleton(); + } + + // Mvc + services.AddMvc(config => + { + config.Filters.Add(new LoggingExceptionHandlerFilterAttribute()); + }); + + if (usingServiceBusAppCache) + { + services.AddHostedService(); + } } - // Mvc - services.AddMvc(config => + public void Configure( + IApplicationBuilder app, + IWebHostEnvironment env, + IHostApplicationLifetime appLifetime, + GlobalSettings globalSettings) { - config.Filters.Add(new LoggingExceptionHandlerFilterAttribute()); - }); + app.UseSerilog(env, appLifetime, globalSettings); - if (usingServiceBusAppCache) - { - services.AddHostedService(); + // Add general security headers + app.UseMiddleware(); + + if (env.IsDevelopment()) + { + app.UseDeveloperExceptionPage(); + } + + // Default Middleware + app.UseDefaultMiddleware(env, globalSettings); + + // Add routing + app.UseRouting(); + + // Add Cors + app.UseCors(policy => policy.SetIsOriginAllowed(o => CoreHelpers.IsCorsOriginAllowed(o, globalSettings)) + .AllowAnyMethod().AllowAnyHeader().AllowCredentials()); + + // Add authentication and authorization to the request pipeline. + app.UseAuthentication(); + app.UseAuthorization(); + + // Add current context + app.UseMiddleware(); + + // Add MVC to the request pipeline. + app.UseEndpoints(endpoints => endpoints.MapDefaultControllerRoute()); } } - - public void Configure( - IApplicationBuilder app, - IWebHostEnvironment env, - IHostApplicationLifetime appLifetime, - GlobalSettings globalSettings) - { - app.UseSerilog(env, appLifetime, globalSettings); - - // Add general security headers - app.UseMiddleware(); - - if (env.IsDevelopment()) - { - app.UseDeveloperExceptionPage(); - } - - // Default Middleware - app.UseDefaultMiddleware(env, globalSettings); - - // Add routing - app.UseRouting(); - - // Add Cors - app.UseCors(policy => policy.SetIsOriginAllowed(o => CoreHelpers.IsCorsOriginAllowed(o, globalSettings)) - .AllowAnyMethod().AllowAnyHeader().AllowCredentials()); - - // Add authentication and authorization to the request pipeline. - app.UseAuthentication(); - app.UseAuthorization(); - - // Add current context - app.UseMiddleware(); - - // Add MVC to the request pipeline. - app.UseEndpoints(endpoints => endpoints.MapDefaultControllerRoute()); - } } diff --git a/src/EventsProcessor/AzureQueueHostedService.cs b/src/EventsProcessor/AzureQueueHostedService.cs index 837e4ad148..41f203e9d1 100644 --- a/src/EventsProcessor/AzureQueueHostedService.cs +++ b/src/EventsProcessor/AzureQueueHostedService.cs @@ -5,121 +5,122 @@ using Bit.Core.Models.Data; using Bit.Core.Services; using Bit.Core.Utilities; -namespace Bit.EventsProcessor; - -public class AzureQueueHostedService : IHostedService, IDisposable +namespace Bit.EventsProcessor { - private readonly ILogger _logger; - private readonly IConfiguration _configuration; - - private Task _executingTask; - private CancellationTokenSource _cts; - private QueueClient _queueClient; - private IEventWriteService _eventWriteService; - - public AzureQueueHostedService( - ILogger logger, - IConfiguration configuration) + public class AzureQueueHostedService : IHostedService, IDisposable { - _logger = logger; - _configuration = configuration; - } + private readonly ILogger _logger; + private readonly IConfiguration _configuration; - public Task StartAsync(CancellationToken cancellationToken) - { - _logger.LogInformation(Constants.BypassFiltersEventId, "Starting service."); - _cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); - _executingTask = ExecuteAsync(_cts.Token); - return _executingTask.IsCompleted ? _executingTask : Task.CompletedTask; - } + private Task _executingTask; + private CancellationTokenSource _cts; + private QueueClient _queueClient; + private IEventWriteService _eventWriteService; - public async Task StopAsync(CancellationToken cancellationToken) - { - if (_executingTask == null) + public AzureQueueHostedService( + ILogger logger, + IConfiguration configuration) { - return; - } - _logger.LogWarning("Stopping service."); - _cts.Cancel(); - await Task.WhenAny(_executingTask, Task.Delay(-1, cancellationToken)); - cancellationToken.ThrowIfCancellationRequested(); - } - - public void Dispose() - { } - - private async Task ExecuteAsync(CancellationToken cancellationToken) - { - var storageConnectionString = _configuration["azureStorageConnectionString"]; - if (string.IsNullOrWhiteSpace(storageConnectionString)) - { - return; + _logger = logger; + _configuration = configuration; } - var repo = new Core.Repositories.TableStorage.EventRepository(storageConnectionString); - _eventWriteService = new RepositoryEventWriteService(repo); - _queueClient = new QueueClient(storageConnectionString, "event"); - - while (!cancellationToken.IsCancellationRequested) + public Task StartAsync(CancellationToken cancellationToken) { - try + _logger.LogInformation(Constants.BypassFiltersEventId, "Starting service."); + _cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); + _executingTask = ExecuteAsync(_cts.Token); + return _executingTask.IsCompleted ? _executingTask : Task.CompletedTask; + } + + public async Task StopAsync(CancellationToken cancellationToken) + { + if (_executingTask == null) { - var messages = await _queueClient.ReceiveMessagesAsync(32); - if (messages.Value?.Any() ?? false) + return; + } + _logger.LogWarning("Stopping service."); + _cts.Cancel(); + await Task.WhenAny(_executingTask, Task.Delay(-1, cancellationToken)); + cancellationToken.ThrowIfCancellationRequested(); + } + + public void Dispose() + { } + + private async Task ExecuteAsync(CancellationToken cancellationToken) + { + var storageConnectionString = _configuration["azureStorageConnectionString"]; + if (string.IsNullOrWhiteSpace(storageConnectionString)) + { + return; + } + + var repo = new Core.Repositories.TableStorage.EventRepository(storageConnectionString); + _eventWriteService = new RepositoryEventWriteService(repo); + _queueClient = new QueueClient(storageConnectionString, "event"); + + while (!cancellationToken.IsCancellationRequested) + { + try { - foreach (var message in messages.Value) + var messages = await _queueClient.ReceiveMessagesAsync(32); + if (messages.Value?.Any() ?? false) { - await ProcessQueueMessageAsync(message.DecodeMessageText(), cancellationToken); - await _queueClient.DeleteMessageAsync(message.MessageId, message.PopReceipt); + foreach (var message in messages.Value) + { + await ProcessQueueMessageAsync(message.DecodeMessageText(), cancellationToken); + await _queueClient.DeleteMessageAsync(message.MessageId, message.PopReceipt); + } + } + else + { + await Task.Delay(TimeSpan.FromSeconds(5), cancellationToken); } } - else + catch (Exception e) { + _logger.LogError(e, "Exception occurred: " + e.Message); await Task.Delay(TimeSpan.FromSeconds(5), cancellationToken); } } - catch (Exception e) - { - _logger.LogError(e, "Exception occurred: " + e.Message); - await Task.Delay(TimeSpan.FromSeconds(5), cancellationToken); - } + + _logger.LogWarning("Done processing."); } - _logger.LogWarning("Done processing."); - } - - public async Task ProcessQueueMessageAsync(string message, CancellationToken cancellationToken) - { - if (_eventWriteService == null || message == null || message.Length == 0) + public async Task ProcessQueueMessageAsync(string message, CancellationToken cancellationToken) { - return; - } - - try - { - _logger.LogInformation("Processing message."); - var events = new List(); - - using var jsonDocument = JsonDocument.Parse(message); - var root = jsonDocument.RootElement; - if (root.ValueKind == JsonValueKind.Array) + if (_eventWriteService == null || message == null || message.Length == 0) { - var indexedEntities = root.ToObject>() - .SelectMany(e => EventTableEntity.IndexEvent(e)); - events.AddRange(indexedEntities); - } - else if (root.ValueKind == JsonValueKind.Object) - { - var eventMessage = root.ToObject(); - events.AddRange(EventTableEntity.IndexEvent(eventMessage)); + return; } - await _eventWriteService.CreateManyAsync(events); - _logger.LogInformation("Processed message."); - } - catch (JsonException) - { - _logger.LogError("JsonReaderException: Unable to parse message."); + try + { + _logger.LogInformation("Processing message."); + var events = new List(); + + using var jsonDocument = JsonDocument.Parse(message); + var root = jsonDocument.RootElement; + if (root.ValueKind == JsonValueKind.Array) + { + var indexedEntities = root.ToObject>() + .SelectMany(e => EventTableEntity.IndexEvent(e)); + events.AddRange(indexedEntities); + } + else if (root.ValueKind == JsonValueKind.Object) + { + var eventMessage = root.ToObject(); + events.AddRange(EventTableEntity.IndexEvent(eventMessage)); + } + + await _eventWriteService.CreateManyAsync(events); + _logger.LogInformation("Processed message."); + } + catch (JsonException) + { + _logger.LogError("JsonReaderException: Unable to parse message."); + } } } } diff --git a/src/EventsProcessor/Program.cs b/src/EventsProcessor/Program.cs index 0cf2d17fab..a63c7742c4 100644 --- a/src/EventsProcessor/Program.cs +++ b/src/EventsProcessor/Program.cs @@ -1,21 +1,22 @@ using Bit.Core.Utilities; using Serilog.Events; -namespace Bit.EventsProcessor; - -public class Program +namespace Bit.EventsProcessor { - public static void Main(string[] args) + public class Program { - Host - .CreateDefaultBuilder(args) - .ConfigureWebHostDefaults(webBuilder => - { - webBuilder.UseStartup(); - webBuilder.ConfigureLogging((hostingContext, logging) => - logging.AddSerilog(hostingContext, e => e.Level >= LogEventLevel.Warning)); - }) - .Build() - .Run(); + public static void Main(string[] args) + { + Host + .CreateDefaultBuilder(args) + .ConfigureWebHostDefaults(webBuilder => + { + webBuilder.UseStartup(); + webBuilder.ConfigureLogging((hostingContext, logging) => + logging.AddSerilog(hostingContext, e => e.Level >= LogEventLevel.Warning)); + }) + .Build() + .Run(); + } } } diff --git a/src/EventsProcessor/Startup.cs b/src/EventsProcessor/Startup.cs index d0a624f737..e995816a01 100644 --- a/src/EventsProcessor/Startup.cs +++ b/src/EventsProcessor/Startup.cs @@ -4,52 +4,53 @@ using Bit.Core.Utilities; using Bit.SharedWeb.Utilities; using Microsoft.IdentityModel.Logging; -namespace Bit.EventsProcessor; - -public class Startup +namespace Bit.EventsProcessor { - public Startup(IWebHostEnvironment env, IConfiguration configuration) + public class Startup { - CultureInfo.DefaultThreadCurrentCulture = new CultureInfo("en-US"); - Configuration = configuration; - Environment = env; - } - - public IConfiguration Configuration { get; } - public IWebHostEnvironment Environment { get; set; } - - public void ConfigureServices(IServiceCollection services) - { - // Options - services.AddOptions(); - - // Settings - services.AddGlobalSettingsServices(Configuration, Environment); - - // Hosted Services - services.AddHostedService(); - } - - public void Configure( - IApplicationBuilder app, - IWebHostEnvironment env, - IHostApplicationLifetime appLifetime, - GlobalSettings globalSettings) - { - IdentityModelEventSource.ShowPII = true; - app.UseSerilog(env, appLifetime, globalSettings); - // Add general security headers - app.UseMiddleware(); - app.UseRouting(); - app.UseEndpoints(endpoints => + public Startup(IWebHostEnvironment env, IConfiguration configuration) { - endpoints.MapGet("/alive", - async context => await context.Response.WriteAsJsonAsync(System.DateTime.UtcNow)); - endpoints.MapGet("/now", - async context => await context.Response.WriteAsJsonAsync(System.DateTime.UtcNow)); - endpoints.MapGet("/version", - async context => await context.Response.WriteAsJsonAsync(CoreHelpers.GetVersion())); + CultureInfo.DefaultThreadCurrentCulture = new CultureInfo("en-US"); + Configuration = configuration; + Environment = env; + } - }); + public IConfiguration Configuration { get; } + public IWebHostEnvironment Environment { get; set; } + + public void ConfigureServices(IServiceCollection services) + { + // Options + services.AddOptions(); + + // Settings + services.AddGlobalSettingsServices(Configuration, Environment); + + // Hosted Services + services.AddHostedService(); + } + + public void Configure( + IApplicationBuilder app, + IWebHostEnvironment env, + IHostApplicationLifetime appLifetime, + GlobalSettings globalSettings) + { + IdentityModelEventSource.ShowPII = true; + app.UseSerilog(env, appLifetime, globalSettings); + // Add general security headers + app.UseMiddleware(); + app.UseRouting(); + app.UseEndpoints(endpoints => + { + endpoints.MapGet("/alive", + async context => await context.Response.WriteAsJsonAsync(System.DateTime.UtcNow)); + endpoints.MapGet("/now", + async context => await context.Response.WriteAsJsonAsync(System.DateTime.UtcNow)); + endpoints.MapGet("/version", + async context => await context.Response.WriteAsJsonAsync(CoreHelpers.GetVersion())); + + }); + } } } diff --git a/src/Icons/Controllers/IconsController.cs b/src/Icons/Controllers/IconsController.cs index ad9b6cfd4f..5e27ece56a 100644 --- a/src/Icons/Controllers/IconsController.cs +++ b/src/Icons/Controllers/IconsController.cs @@ -3,105 +3,106 @@ using Bit.Icons.Services; using Microsoft.AspNetCore.Mvc; using Microsoft.Extensions.Caching.Memory; -namespace Bit.Icons.Controllers; - -[Route("")] -public class IconsController : Controller +namespace Bit.Icons.Controllers { - // Basic bwi-globe icon - private static readonly byte[] _notFoundImage = Convert.FromBase64String("iVBORw0KGgoAAAANSUhEUg" + - "AAABMAAAATCAQAAADYWf5HAAABu0lEQVR42nXSvWuTURTH8R+t0heI9Y04aJycdBLNJNrBFBU7OFgUER3q21I0bXK+JwZ" + - "pXISm/QdcRB3EgqBBsNihsUbbgODQQSKCuKSDOApJuuhj8tCYQj/jvYfD795z1MZ+nBKrNKhSwrMxbZTrtRnqlEjZkB/x" + - "C/xmhZrlc71qS0Up8yVzTCGucFNKD1JhORVd70SZNU4okNx5d4+U2UXRIpJFWLClsR79YzN88wQvLWNzzPKEeS/wkQGpW" + - "VhhqhW8TtDJD3Mm1x/23zLSrZCdpBY8BueTNjHSbc+8wC9HlHgU5Aj5AW5zPdcVdpq0UcknWBSr/pjixO4gfp899Kd23p" + - "M2qQCH7LkCnqAqGh73OK/8NPOcaibr90LrW/yWAnaUhqjaOSl9nFR2r5rsqo22ypn1B5IN8VOUMHVgOnNQIX+d62plcz6" + - "rg1/jskK8CMb4we4pG6OWHtR/LBJkC2E4a7ZPkuX5ntumAOM2xxveclEhLvGH6XCmLPs735Eetrw63NnOgr9P9q1viC3x" + - "lRUGOjImqFDuOBvrYYoaZU9z1uPpYae5NfdvbNVG2ZjDIlXq/oMi46lo++4vjjPBl2Dlg00AAAAASUVORK5CYII="); - - private readonly IMemoryCache _memoryCache; - private readonly IDomainMappingService _domainMappingService; - private readonly IIconFetchingService _iconFetchingService; - private readonly ILogger _logger; - private readonly IconsSettings _iconsSettings; - - public IconsController( - IMemoryCache memoryCache, - IDomainMappingService domainMappingService, - IIconFetchingService iconFetchingService, - ILogger logger, - IconsSettings iconsSettings) + [Route("")] + public class IconsController : Controller { - _memoryCache = memoryCache; - _domainMappingService = domainMappingService; - _iconFetchingService = iconFetchingService; - _logger = logger; - _iconsSettings = iconsSettings; - } + // Basic bwi-globe icon + private static readonly byte[] _notFoundImage = Convert.FromBase64String("iVBORw0KGgoAAAANSUhEUg" + + "AAABMAAAATCAQAAADYWf5HAAABu0lEQVR42nXSvWuTURTH8R+t0heI9Y04aJycdBLNJNrBFBU7OFgUER3q21I0bXK+JwZ" + + "pXISm/QdcRB3EgqBBsNihsUbbgODQQSKCuKSDOApJuuhj8tCYQj/jvYfD795z1MZ+nBKrNKhSwrMxbZTrtRnqlEjZkB/x" + + "C/xmhZrlc71qS0Up8yVzTCGucFNKD1JhORVd70SZNU4okNx5d4+U2UXRIpJFWLClsR79YzN88wQvLWNzzPKEeS/wkQGpW" + + "VhhqhW8TtDJD3Mm1x/23zLSrZCdpBY8BueTNjHSbc+8wC9HlHgU5Aj5AW5zPdcVdpq0UcknWBSr/pjixO4gfp899Kd23p" + + "M2qQCH7LkCnqAqGh73OK/8NPOcaibr90LrW/yWAnaUhqjaOSl9nFR2r5rsqo22ypn1B5IN8VOUMHVgOnNQIX+d62plcz6" + + "rg1/jskK8CMb4we4pG6OWHtR/LBJkC2E4a7ZPkuX5ntumAOM2xxveclEhLvGH6XCmLPs735Eetrw63NnOgr9P9q1viC3x" + + "lRUGOjImqFDuOBvrYYoaZU9z1uPpYae5NfdvbNVG2ZjDIlXq/oMi46lo++4vjjPBl2Dlg00AAAAASUVORK5CYII="); - [HttpGet("~/config")] - public IActionResult GetConfig() - { - return new JsonResult(new + private readonly IMemoryCache _memoryCache; + private readonly IDomainMappingService _domainMappingService; + private readonly IIconFetchingService _iconFetchingService; + private readonly ILogger _logger; + private readonly IconsSettings _iconsSettings; + + public IconsController( + IMemoryCache memoryCache, + IDomainMappingService domainMappingService, + IIconFetchingService iconFetchingService, + ILogger logger, + IconsSettings iconsSettings) { - CacheEnabled = _iconsSettings.CacheEnabled, - CacheHours = _iconsSettings.CacheHours, - CacheSizeLimit = _iconsSettings.CacheSizeLimit - }); - } - - [HttpGet("{hostname}/icon.png")] - public async Task Get(string hostname) - { - if (string.IsNullOrWhiteSpace(hostname) || !hostname.Contains(".")) - { - return new BadRequestResult(); + _memoryCache = memoryCache; + _domainMappingService = domainMappingService; + _iconFetchingService = iconFetchingService; + _logger = logger; + _iconsSettings = iconsSettings; } - var url = $"http://{hostname}"; - if (!Uri.TryCreate(url, UriKind.Absolute, out var uri)) + [HttpGet("~/config")] + public IActionResult GetConfig() { - return new BadRequestResult(); + return new JsonResult(new + { + CacheEnabled = _iconsSettings.CacheEnabled, + CacheHours = _iconsSettings.CacheHours, + CacheSizeLimit = _iconsSettings.CacheSizeLimit + }); } - var domain = uri.Host; - // Convert sub.domain.com => domain.com - //if(DomainName.TryParseBaseDomain(domain, out var baseDomain)) - //{ - // domain = baseDomain; - //} - - var mappedDomain = _domainMappingService.MapDomain(domain); - if (!_iconsSettings.CacheEnabled || !_memoryCache.TryGetValue(mappedDomain, out Icon icon)) + [HttpGet("{hostname}/icon.png")] + public async Task Get(string hostname) { - var result = await _iconFetchingService.GetIconAsync(domain); - if (result == null) + if (string.IsNullOrWhiteSpace(hostname) || !hostname.Contains(".")) { - _logger.LogWarning("Null result returned for {0}.", domain); - icon = null; - } - else - { - icon = result.Icon; + return new BadRequestResult(); } - // Only cache not found and smaller images (<= 50kb) - if (_iconsSettings.CacheEnabled && (icon == null || icon.Image.Length <= 50012)) + var url = $"http://{hostname}"; + if (!Uri.TryCreate(url, UriKind.Absolute, out var uri)) { - _logger.LogInformation("Cache icon for {0}.", domain); - _memoryCache.Set(mappedDomain, icon, new MemoryCacheEntryOptions + return new BadRequestResult(); + } + + var domain = uri.Host; + // Convert sub.domain.com => domain.com + //if(DomainName.TryParseBaseDomain(domain, out var baseDomain)) + //{ + // domain = baseDomain; + //} + + var mappedDomain = _domainMappingService.MapDomain(domain); + if (!_iconsSettings.CacheEnabled || !_memoryCache.TryGetValue(mappedDomain, out Icon icon)) + { + var result = await _iconFetchingService.GetIconAsync(domain); + if (result == null) { - AbsoluteExpirationRelativeToNow = new TimeSpan(_iconsSettings.CacheHours, 0, 0), - Size = icon?.Image.Length ?? 0, - Priority = icon == null ? CacheItemPriority.High : CacheItemPriority.Normal - }); + _logger.LogWarning("Null result returned for {0}.", domain); + icon = null; + } + else + { + icon = result.Icon; + } + + // Only cache not found and smaller images (<= 50kb) + if (_iconsSettings.CacheEnabled && (icon == null || icon.Image.Length <= 50012)) + { + _logger.LogInformation("Cache icon for {0}.", domain); + _memoryCache.Set(mappedDomain, icon, new MemoryCacheEntryOptions + { + AbsoluteExpirationRelativeToNow = new TimeSpan(_iconsSettings.CacheHours, 0, 0), + Size = icon?.Image.Length ?? 0, + Priority = icon == null ? CacheItemPriority.High : CacheItemPriority.Normal + }); + } } - } - if (icon == null) - { - return new FileContentResult(_notFoundImage, "image/png"); - } + if (icon == null) + { + return new FileContentResult(_notFoundImage, "image/png"); + } - return new FileContentResult(icon.Image, icon.Format); + return new FileContentResult(icon.Image, icon.Format); + } } } diff --git a/src/Icons/Controllers/InfoController.cs b/src/Icons/Controllers/InfoController.cs index 1ebbd473a1..47c6ca553d 100644 --- a/src/Icons/Controllers/InfoController.cs +++ b/src/Icons/Controllers/InfoController.cs @@ -1,20 +1,21 @@ using Bit.Core.Utilities; using Microsoft.AspNetCore.Mvc; -namespace Bit.Icons.Controllers; - -public class InfoController : Controller +namespace Bit.Icons.Controllers { - [HttpGet("~/alive")] - [HttpGet("~/now")] - public DateTime GetAlive() + public class InfoController : Controller { - return DateTime.UtcNow; - } + [HttpGet("~/alive")] + [HttpGet("~/now")] + public DateTime GetAlive() + { + return DateTime.UtcNow; + } - [HttpGet("~/version")] - public JsonResult GetVersion() - { - return Json(CoreHelpers.GetVersion()); + [HttpGet("~/version")] + public JsonResult GetVersion() + { + return Json(CoreHelpers.GetVersion()); + } } } diff --git a/src/Icons/IconsSettings.cs b/src/Icons/IconsSettings.cs index 7cfd64d112..e6de866291 100644 --- a/src/Icons/IconsSettings.cs +++ b/src/Icons/IconsSettings.cs @@ -1,8 +1,9 @@ -namespace Bit.Icons; - -public class IconsSettings +namespace Bit.Icons { - public virtual bool CacheEnabled { get; set; } - public virtual int CacheHours { get; set; } - public virtual long? CacheSizeLimit { get; set; } + public class IconsSettings + { + public virtual bool CacheEnabled { get; set; } + public virtual int CacheHours { get; set; } + public virtual long? CacheSizeLimit { get; set; } + } } diff --git a/src/Icons/Models/DomainName.cs b/src/Icons/Models/DomainName.cs index b040110504..ee5a5f0d44 100644 --- a/src/Icons/Models/DomainName.cs +++ b/src/Icons/Models/DomainName.cs @@ -2,323 +2,324 @@ using System.Reflection; using System.Text.RegularExpressions; -namespace Bit.Icons.Models; - -// ref: https://github.com/danesparza/domainname-parser -public class DomainName +namespace Bit.Icons.Models { - private const string IpRegex = "^(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\\." + - "(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\\." + - "(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\\." + - "(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)$"; - - private string _subDomain = string.Empty; - private string _domain = string.Empty; - private string _tld = string.Empty; - private TLDRule _tldRule = null; - - public string SubDomain => _subDomain; - public string Domain => _domain; - public string SLD => _domain; - public string TLD => _tld; - public TLDRule Rule => _tldRule; - public string BaseDomain => $"{_domain}.{_tld}"; - - public DomainName(string TLD, string SLD, string SubDomain, TLDRule TLDRule) + // ref: https://github.com/danesparza/domainname-parser + public class DomainName { - _tld = TLD; - _domain = SLD; - _subDomain = SubDomain; - _tldRule = TLDRule; - } + private const string IpRegex = "^(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\\." + + "(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\\." + + "(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\\." + + "(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)$"; - public static bool TryParse(string domainString, out DomainName result) - { - var retval = false; + private string _subDomain = string.Empty; + private string _domain = string.Empty; + private string _tld = string.Empty; + private TLDRule _tldRule = null; - // Our temporary domain parts: - var tld = string.Empty; - var sld = string.Empty; - var subdomain = string.Empty; - TLDRule _tldrule = null; - result = null; + public string SubDomain => _subDomain; + public string Domain => _domain; + public string SLD => _domain; + public string TLD => _tld; + public TLDRule Rule => _tldRule; + public string BaseDomain => $"{_domain}.{_tld}"; - try + public DomainName(string TLD, string SLD, string SubDomain, TLDRule TLDRule) { - // Try parsing the domain name ... this might throw formatting exceptions - ParseDomainName(domainString, out tld, out sld, out subdomain, out _tldrule); - // Construct a new DomainName object and return it - result = new DomainName(tld, sld, subdomain, _tldrule); - // Return 'true' - retval = true; - } - catch - { - // Looks like something bad happened -- return 'false' - retval = false; + _tld = TLD; + _domain = SLD; + _subDomain = SubDomain; + _tldRule = TLDRule; } - return retval; - } - - public static bool TryParseBaseDomain(string domainString, out string result) - { - if (Regex.IsMatch(domainString, IpRegex)) + public static bool TryParse(string domainString, out DomainName result) { - result = domainString; - return true; - } + var retval = false; - DomainName domain; - var retval = TryParse(domainString, out domain); - result = domain?.BaseDomain; - return retval; - } + // Our temporary domain parts: + var tld = string.Empty; + var sld = string.Empty; + var subdomain = string.Empty; + TLDRule _tldrule = null; + result = null; - private static void ParseDomainName(string domainString, out string TLD, out string SLD, - out string SubDomain, out TLDRule MatchingRule) - { - // Make sure domain is all lowercase - domainString = domainString.ToLower(); - - TLD = string.Empty; - SLD = string.Empty; - SubDomain = string.Empty; - MatchingRule = null; - - // If the fqdn is empty, we have a problem already - if (domainString.Trim() == string.Empty) - { - throw new ArgumentException("The domain cannot be blank"); - } - - // Next, find the matching rule: - MatchingRule = FindMatchingTLDRule(domainString); - - // At this point, no rules match, we have a problem - if (MatchingRule == null) - { - throw new FormatException("The domain does not have a recognized TLD"); - } - - // Based on the tld rule found, get the domain (and possibly the subdomain) - var tempSudomainAndDomain = string.Empty; - var tldIndex = 0; - - // First, determine what type of rule we have, and set the TLD accordingly - switch (MatchingRule.Type) - { - case TLDRule.RuleType.Normal: - tldIndex = domainString.LastIndexOf("." + MatchingRule.Name); - tempSudomainAndDomain = domainString.Substring(0, tldIndex); - TLD = domainString.Substring(tldIndex + 1); - break; - case TLDRule.RuleType.Wildcard: - // This finds the last portion of the TLD... - tldIndex = domainString.LastIndexOf("." + MatchingRule.Name); - tempSudomainAndDomain = domainString.Substring(0, tldIndex); - - // But we need to find the wildcard portion of it: - tldIndex = tempSudomainAndDomain.LastIndexOf("."); - tempSudomainAndDomain = domainString.Substring(0, tldIndex); - TLD = domainString.Substring(tldIndex + 1); - break; - case TLDRule.RuleType.Exception: - tldIndex = domainString.LastIndexOf("."); - tempSudomainAndDomain = domainString.Substring(0, tldIndex); - TLD = domainString.Substring(tldIndex + 1); - break; - } - - // See if we have a subdomain: - List lstRemainingParts = new List(tempSudomainAndDomain.Split('.')); - - // If we have 0 parts left, there is just a tld and no domain or subdomain - // If we have 1 part, it's the domain, and there is no subdomain - // If we have 2+ parts, the last part is the domain, the other parts (combined) are the subdomain - if (lstRemainingParts.Count > 0) - { - // Set the domain: - SLD = lstRemainingParts[lstRemainingParts.Count - 1]; - - // Set the subdomain, if there is one to set: - if (lstRemainingParts.Count > 1) + try { - // We strip off the trailing period, too - SubDomain = tempSudomainAndDomain.Substring(0, tempSudomainAndDomain.Length - SLD.Length - 1); + // Try parsing the domain name ... this might throw formatting exceptions + ParseDomainName(domainString, out tld, out sld, out subdomain, out _tldrule); + // Construct a new DomainName object and return it + result = new DomainName(tld, sld, subdomain, _tldrule); + // Return 'true' + retval = true; } - } - } - - private static TLDRule FindMatchingTLDRule(string domainString) - { - // Split our domain into parts (based on the '.') - // ...Put these parts in a list - // ...Make sure these parts are in reverse order - // (we'll be checking rules from the right-most pat of the domain) - var lstDomainParts = domainString.Split('.').ToList(); - lstDomainParts.Reverse(); - - // Begin building our partial domain to check rules with: - var checkAgainst = string.Empty; - - // Our 'matches' collection: - var ruleMatches = new List(); - - foreach (string domainPart in lstDomainParts) - { - // Add on our next domain part: - checkAgainst = string.Format("{0}.{1}", domainPart, checkAgainst); - - // If we end in a period, strip it off: - if (checkAgainst.EndsWith(".")) + catch { - checkAgainst = checkAgainst.Substring(0, checkAgainst.Length - 1); + // Looks like something bad happened -- return 'false' + retval = false; } - var rules = Enum.GetValues(typeof(TLDRule.RuleType)).Cast(); - foreach (var rule in rules) + return retval; + } + + public static bool TryParseBaseDomain(string domainString, out string result) + { + if (Regex.IsMatch(domainString, IpRegex)) { - // Try to match rule: - TLDRule result; - if (TLDRulesCache.Instance.TLDRuleLists[rule].TryGetValue(checkAgainst, out result)) + result = domainString; + return true; + } + + DomainName domain; + var retval = TryParse(domainString, out domain); + result = domain?.BaseDomain; + return retval; + } + + private static void ParseDomainName(string domainString, out string TLD, out string SLD, + out string SubDomain, out TLDRule MatchingRule) + { + // Make sure domain is all lowercase + domainString = domainString.ToLower(); + + TLD = string.Empty; + SLD = string.Empty; + SubDomain = string.Empty; + MatchingRule = null; + + // If the fqdn is empty, we have a problem already + if (domainString.Trim() == string.Empty) + { + throw new ArgumentException("The domain cannot be blank"); + } + + // Next, find the matching rule: + MatchingRule = FindMatchingTLDRule(domainString); + + // At this point, no rules match, we have a problem + if (MatchingRule == null) + { + throw new FormatException("The domain does not have a recognized TLD"); + } + + // Based on the tld rule found, get the domain (and possibly the subdomain) + var tempSudomainAndDomain = string.Empty; + var tldIndex = 0; + + // First, determine what type of rule we have, and set the TLD accordingly + switch (MatchingRule.Type) + { + case TLDRule.RuleType.Normal: + tldIndex = domainString.LastIndexOf("." + MatchingRule.Name); + tempSudomainAndDomain = domainString.Substring(0, tldIndex); + TLD = domainString.Substring(tldIndex + 1); + break; + case TLDRule.RuleType.Wildcard: + // This finds the last portion of the TLD... + tldIndex = domainString.LastIndexOf("." + MatchingRule.Name); + tempSudomainAndDomain = domainString.Substring(0, tldIndex); + + // But we need to find the wildcard portion of it: + tldIndex = tempSudomainAndDomain.LastIndexOf("."); + tempSudomainAndDomain = domainString.Substring(0, tldIndex); + TLD = domainString.Substring(tldIndex + 1); + break; + case TLDRule.RuleType.Exception: + tldIndex = domainString.LastIndexOf("."); + tempSudomainAndDomain = domainString.Substring(0, tldIndex); + TLD = domainString.Substring(tldIndex + 1); + break; + } + + // See if we have a subdomain: + List lstRemainingParts = new List(tempSudomainAndDomain.Split('.')); + + // If we have 0 parts left, there is just a tld and no domain or subdomain + // If we have 1 part, it's the domain, and there is no subdomain + // If we have 2+ parts, the last part is the domain, the other parts (combined) are the subdomain + if (lstRemainingParts.Count > 0) + { + // Set the domain: + SLD = lstRemainingParts[lstRemainingParts.Count - 1]; + + // Set the subdomain, if there is one to set: + if (lstRemainingParts.Count > 1) { - ruleMatches.Add(result); + // We strip off the trailing period, too + SubDomain = tempSudomainAndDomain.Substring(0, tempSudomainAndDomain.Length - SLD.Length - 1); } } } - // Sort our matches list (longest rule wins, according to : - var results = from match in ruleMatches - orderby match.Name.Length descending - select match; - - // Take the top result (our primary match): - var primaryMatch = results.Take(1).SingleOrDefault(); - return primaryMatch; - } - - public class TLDRule : IComparable - { - public string Name { get; private set; } - public RuleType Type { get; private set; } - - public TLDRule(string RuleInfo) + private static TLDRule FindMatchingTLDRule(string domainString) { - // Parse the rule and set properties accordingly: - if (RuleInfo.StartsWith("*")) + // Split our domain into parts (based on the '.') + // ...Put these parts in a list + // ...Make sure these parts are in reverse order + // (we'll be checking rules from the right-most pat of the domain) + var lstDomainParts = domainString.Split('.').ToList(); + lstDomainParts.Reverse(); + + // Begin building our partial domain to check rules with: + var checkAgainst = string.Empty; + + // Our 'matches' collection: + var ruleMatches = new List(); + + foreach (string domainPart in lstDomainParts) { - Type = RuleType.Wildcard; - Name = RuleInfo.Substring(2); - } - else if (RuleInfo.StartsWith("!")) - { - Type = RuleType.Exception; - Name = RuleInfo.Substring(1); - } - else - { - Type = RuleType.Normal; - Name = RuleInfo; - } - } + // Add on our next domain part: + checkAgainst = string.Format("{0}.{1}", domainPart, checkAgainst); - public int CompareTo(TLDRule other) - { - if (other == null) - { - return -1; - } - - return Name.CompareTo(other.Name); - } - - public enum RuleType - { - Normal, - Wildcard, - Exception - } - } - - public class TLDRulesCache - { - private static volatile TLDRulesCache _uniqueInstance; - private static object _syncObj = new object(); - private static object _syncList = new object(); - - private TLDRulesCache() - { - // Initialize our internal list: - TLDRuleLists = GetTLDRules(); - } - - public static TLDRulesCache Instance - { - get - { - if (_uniqueInstance == null) + // If we end in a period, strip it off: + if (checkAgainst.EndsWith(".")) { - lock (_syncObj) + checkAgainst = checkAgainst.Substring(0, checkAgainst.Length - 1); + } + + var rules = Enum.GetValues(typeof(TLDRule.RuleType)).Cast(); + foreach (var rule in rules) + { + // Try to match rule: + TLDRule result; + if (TLDRulesCache.Instance.TLDRuleLists[rule].TryGetValue(checkAgainst, out result)) { - if (_uniqueInstance == null) - { - _uniqueInstance = new TLDRulesCache(); - } + ruleMatches.Add(result); } } - return (_uniqueInstance); } + + // Sort our matches list (longest rule wins, according to : + var results = from match in ruleMatches + orderby match.Name.Length descending + select match; + + // Take the top result (our primary match): + var primaryMatch = results.Take(1).SingleOrDefault(); + return primaryMatch; } - public IDictionary> TLDRuleLists { get; set; } - - public static void Reset() + public class TLDRule : IComparable { - lock (_syncObj) + public string Name { get; private set; } + public RuleType Type { get; private set; } + + public TLDRule(string RuleInfo) { - _uniqueInstance = null; - } - } - - private IDictionary> GetTLDRules() - { - var results = new Dictionary>(); - var rules = Enum.GetValues(typeof(TLDRule.RuleType)).Cast(); - foreach (var rule in rules) - { - results[rule] = new Dictionary(StringComparer.CurrentCultureIgnoreCase); - } - - var ruleStrings = ReadRulesData(); - - // Strip out any lines that are: - // a.) A comment - // b.) Blank - var rulesStrings = ruleStrings - .Where(ruleString => !ruleString.StartsWith("//") && ruleString.Trim().Length != 0); - foreach (var ruleString in rulesStrings) - { - var result = new TLDRule(ruleString); - results[result.Type][result.Name] = result; - } - - // Return our results: - Debug.WriteLine(string.Format("Loaded {0} rules into cache.", - results.Values.Sum(r => r.Values.Count))); - return results; - } - - private IEnumerable ReadRulesData() - { - var assembly = typeof(TLDRulesCache).GetTypeInfo().Assembly; - var stream = assembly.GetManifestResourceStream("Bit.Icons.Resources.public_suffix_list.dat"); - string line; - using (var reader = new StreamReader(stream)) - { - while ((line = reader.ReadLine()) != null) + // Parse the rule and set properties accordingly: + if (RuleInfo.StartsWith("*")) { - yield return line; + Type = RuleType.Wildcard; + Name = RuleInfo.Substring(2); + } + else if (RuleInfo.StartsWith("!")) + { + Type = RuleType.Exception; + Name = RuleInfo.Substring(1); + } + else + { + Type = RuleType.Normal; + Name = RuleInfo; + } + } + + public int CompareTo(TLDRule other) + { + if (other == null) + { + return -1; + } + + return Name.CompareTo(other.Name); + } + + public enum RuleType + { + Normal, + Wildcard, + Exception + } + } + + public class TLDRulesCache + { + private static volatile TLDRulesCache _uniqueInstance; + private static object _syncObj = new object(); + private static object _syncList = new object(); + + private TLDRulesCache() + { + // Initialize our internal list: + TLDRuleLists = GetTLDRules(); + } + + public static TLDRulesCache Instance + { + get + { + if (_uniqueInstance == null) + { + lock (_syncObj) + { + if (_uniqueInstance == null) + { + _uniqueInstance = new TLDRulesCache(); + } + } + } + return (_uniqueInstance); + } + } + + public IDictionary> TLDRuleLists { get; set; } + + public static void Reset() + { + lock (_syncObj) + { + _uniqueInstance = null; + } + } + + private IDictionary> GetTLDRules() + { + var results = new Dictionary>(); + var rules = Enum.GetValues(typeof(TLDRule.RuleType)).Cast(); + foreach (var rule in rules) + { + results[rule] = new Dictionary(StringComparer.CurrentCultureIgnoreCase); + } + + var ruleStrings = ReadRulesData(); + + // Strip out any lines that are: + // a.) A comment + // b.) Blank + var rulesStrings = ruleStrings + .Where(ruleString => !ruleString.StartsWith("//") && ruleString.Trim().Length != 0); + foreach (var ruleString in rulesStrings) + { + var result = new TLDRule(ruleString); + results[result.Type][result.Name] = result; + } + + // Return our results: + Debug.WriteLine(string.Format("Loaded {0} rules into cache.", + results.Values.Sum(r => r.Values.Count))); + return results; + } + + private IEnumerable ReadRulesData() + { + var assembly = typeof(TLDRulesCache).GetTypeInfo().Assembly; + var stream = assembly.GetManifestResourceStream("Bit.Icons.Resources.public_suffix_list.dat"); + string line; + using (var reader = new StreamReader(stream)) + { + while ((line = reader.ReadLine()) != null) + { + yield return line; + } } } } diff --git a/src/Icons/Models/Icon.cs b/src/Icons/Models/Icon.cs index 8bd23541fa..cca6d78d55 100644 --- a/src/Icons/Models/Icon.cs +++ b/src/Icons/Models/Icon.cs @@ -1,7 +1,8 @@ -namespace Bit.Icons.Models; - -public class Icon +namespace Bit.Icons.Models { - public byte[] Image { get; set; } - public string Format { get; set; } + public class Icon + { + public byte[] Image { get; set; } + public string Format { get; set; } + } } diff --git a/src/Icons/Models/IconResult.cs b/src/Icons/Models/IconResult.cs index ca1e6929ed..104c2627ad 100644 --- a/src/Icons/Models/IconResult.cs +++ b/src/Icons/Models/IconResult.cs @@ -1,65 +1,66 @@ -namespace Bit.Icons.Models; - -public class IconResult +namespace Bit.Icons.Models { - public IconResult(string href, string sizes) + public class IconResult { - Path = href; - if (!string.IsNullOrWhiteSpace(sizes)) + public IconResult(string href, string sizes) { - var sizeParts = sizes.Split('x'); - if (sizeParts.Length == 2 && int.TryParse(sizeParts[0].Trim(), out var width) && - int.TryParse(sizeParts[1].Trim(), out var height)) + Path = href; + if (!string.IsNullOrWhiteSpace(sizes)) { - DefinedWidth = width; - DefinedHeight = height; - - if (width == height) + var sizeParts = sizes.Split('x'); + if (sizeParts.Length == 2 && int.TryParse(sizeParts[0].Trim(), out var width) && + int.TryParse(sizeParts[1].Trim(), out var height)) { - if (width == 32) + DefinedWidth = width; + DefinedHeight = height; + + if (width == height) { - Priority = 1; - } - else if (width == 64) - { - Priority = 2; - } - else if (width >= 24 && width <= 128) - { - Priority = 3; - } - else if (width == 16) - { - Priority = 4; - } - else - { - Priority = 100; + if (width == 32) + { + Priority = 1; + } + else if (width == 64) + { + Priority = 2; + } + else if (width >= 24 && width <= 128) + { + Priority = 3; + } + else if (width == 16) + { + Priority = 4; + } + else + { + Priority = 100; + } } } } + + if (Priority == 0) + { + Priority = 200; + } } - if (Priority == 0) + public IconResult(Uri uri, byte[] bytes, string format) { - Priority = 200; + Path = uri.ToString(); + Icon = new Icon + { + Image = bytes, + Format = format + }; + Priority = 10; } - } - public IconResult(Uri uri, byte[] bytes, string format) - { - Path = uri.ToString(); - Icon = new Icon - { - Image = bytes, - Format = format - }; - Priority = 10; + public string Path { get; set; } + public int? DefinedWidth { get; set; } + public int? DefinedHeight { get; set; } + public Icon Icon { get; set; } + public int Priority { get; set; } } - - public string Path { get; set; } - public int? DefinedWidth { get; set; } - public int? DefinedHeight { get; set; } - public Icon Icon { get; set; } - public int Priority { get; set; } } diff --git a/src/Icons/Program.cs b/src/Icons/Program.cs index d57a6fd1cd..1f65ea4067 100644 --- a/src/Icons/Program.cs +++ b/src/Icons/Program.cs @@ -1,21 +1,22 @@ using Bit.Core.Utilities; using Serilog.Events; -namespace Bit.Icons; - -public class Program +namespace Bit.Icons { - public static void Main(string[] args) + public class Program { - Host - .CreateDefaultBuilder(args) - .ConfigureWebHostDefaults(webBuilder => - { - webBuilder.UseStartup(); - webBuilder.ConfigureLogging((hostingContext, logging) => - logging.AddSerilog(hostingContext, e => e.Level >= LogEventLevel.Error)); - }) - .Build() - .Run(); + public static void Main(string[] args) + { + Host + .CreateDefaultBuilder(args) + .ConfigureWebHostDefaults(webBuilder => + { + webBuilder.UseStartup(); + webBuilder.ConfigureLogging((hostingContext, logging) => + logging.AddSerilog(hostingContext, e => e.Level >= LogEventLevel.Error)); + }) + .Build() + .Run(); + } } } diff --git a/src/Icons/Services/DomainMappingService.cs b/src/Icons/Services/DomainMappingService.cs index b41d48233f..406145af92 100644 --- a/src/Icons/Services/DomainMappingService.cs +++ b/src/Icons/Services/DomainMappingService.cs @@ -1,23 +1,24 @@ -namespace Bit.Icons.Services; - -public class DomainMappingService : IDomainMappingService +namespace Bit.Icons.Services { - private readonly Dictionary _map = new Dictionary + public class DomainMappingService : IDomainMappingService { - ["login.yahoo.com"] = "yahoo.com", - ["accounts.google.com"] = "google.com", - ["photo.walgreens.com"] = "walgreens.com", - ["passport.yandex.com"] = "yandex.com", - // TODO: Add others here - }; - - public string MapDomain(string hostname) - { - if (_map.ContainsKey(hostname)) + private readonly Dictionary _map = new Dictionary { - return _map[hostname]; - } + ["login.yahoo.com"] = "yahoo.com", + ["accounts.google.com"] = "google.com", + ["photo.walgreens.com"] = "walgreens.com", + ["passport.yandex.com"] = "yandex.com", + // TODO: Add others here + }; - return hostname; + public string MapDomain(string hostname) + { + if (_map.ContainsKey(hostname)) + { + return _map[hostname]; + } + + return hostname; + } } } diff --git a/src/Icons/Services/IDomainMappingService.cs b/src/Icons/Services/IDomainMappingService.cs index 4ee3f45948..194ee8f641 100644 --- a/src/Icons/Services/IDomainMappingService.cs +++ b/src/Icons/Services/IDomainMappingService.cs @@ -1,6 +1,7 @@ -namespace Bit.Icons.Services; - -public interface IDomainMappingService +namespace Bit.Icons.Services { - string MapDomain(string hostname); + public interface IDomainMappingService + { + string MapDomain(string hostname); + } } diff --git a/src/Icons/Services/IIconFetchingService.cs b/src/Icons/Services/IIconFetchingService.cs index ff6704291f..4c15ddffb5 100644 --- a/src/Icons/Services/IIconFetchingService.cs +++ b/src/Icons/Services/IIconFetchingService.cs @@ -1,8 +1,9 @@ using Bit.Icons.Models; -namespace Bit.Icons.Services; - -public interface IIconFetchingService +namespace Bit.Icons.Services { - Task GetIconAsync(string domain); + public interface IIconFetchingService + { + Task GetIconAsync(string domain); + } } diff --git a/src/Icons/Services/IconFetchingService.cs b/src/Icons/Services/IconFetchingService.cs index 166d5a0aa7..e7ae384507 100644 --- a/src/Icons/Services/IconFetchingService.cs +++ b/src/Icons/Services/IconFetchingService.cs @@ -3,447 +3,448 @@ using System.Text; using AngleSharp.Html.Parser; using Bit.Icons.Models; -namespace Bit.Icons.Services; - -public class IconFetchingService : IIconFetchingService +namespace Bit.Icons.Services { - private readonly HashSet _iconRels = - new HashSet { "icon", "apple-touch-icon", "shortcut icon" }; - private readonly HashSet _blacklistedRels = - new HashSet { "preload", "image_src", "preconnect", "canonical", "alternate", "stylesheet" }; - private readonly HashSet _iconExtensions = - new HashSet { ".ico", ".png", ".jpg", ".jpeg" }; - - private readonly string _pngMediaType = "image/png"; - private readonly byte[] _pngHeader = new byte[] { 137, 80, 78, 71 }; - private readonly byte[] _webpHeader = Encoding.UTF8.GetBytes("RIFF"); - - private readonly string _icoMediaType = "image/x-icon"; - private readonly string _icoAltMediaType = "image/vnd.microsoft.icon"; - private readonly byte[] _icoHeader = new byte[] { 00, 00, 01, 00 }; - - private readonly string _jpegMediaType = "image/jpeg"; - private readonly byte[] _jpegHeader = new byte[] { 255, 216, 255 }; - - private readonly HashSet _allowedMediaTypes; - private readonly HttpClient _httpClient; - private readonly ILogger _logger; - - public IconFetchingService(ILogger logger) + public class IconFetchingService : IIconFetchingService { - _logger = logger; - _allowedMediaTypes = new HashSet - { - _pngMediaType, - _icoMediaType, - _icoAltMediaType, - _jpegMediaType - }; + private readonly HashSet _iconRels = + new HashSet { "icon", "apple-touch-icon", "shortcut icon" }; + private readonly HashSet _blacklistedRels = + new HashSet { "preload", "image_src", "preconnect", "canonical", "alternate", "stylesheet" }; + private readonly HashSet _iconExtensions = + new HashSet { ".ico", ".png", ".jpg", ".jpeg" }; - _httpClient = new HttpClient(new HttpClientHandler - { - AllowAutoRedirect = false, - AutomaticDecompression = DecompressionMethods.GZip | DecompressionMethods.Deflate, - }); - _httpClient.Timeout = TimeSpan.FromSeconds(20); - _httpClient.MaxResponseContentBufferSize = 5000000; // 5 MB - } + private readonly string _pngMediaType = "image/png"; + private readonly byte[] _pngHeader = new byte[] { 137, 80, 78, 71 }; + private readonly byte[] _webpHeader = Encoding.UTF8.GetBytes("RIFF"); - public async Task GetIconAsync(string domain) - { - if (IPAddress.TryParse(domain, out _)) - { - _logger.LogWarning("IP address: {0}.", domain); - return null; - } + private readonly string _icoMediaType = "image/x-icon"; + private readonly string _icoAltMediaType = "image/vnd.microsoft.icon"; + private readonly byte[] _icoHeader = new byte[] { 00, 00, 01, 00 }; - if (!Uri.TryCreate($"https://{domain}", UriKind.Absolute, out var parsedHttpsUri)) - { - _logger.LogWarning("Bad domain: {0}.", domain); - return null; - } + private readonly string _jpegMediaType = "image/jpeg"; + private readonly byte[] _jpegHeader = new byte[] { 255, 216, 255 }; - var uri = parsedHttpsUri; - var response = await GetAndFollowAsync(uri, 2); - if ((response == null || !response.IsSuccessStatusCode) && - Uri.TryCreate($"http://{parsedHttpsUri.Host}", UriKind.Absolute, out var parsedHttpUri)) - { - Cleanup(response); - uri = parsedHttpUri; - response = await GetAndFollowAsync(uri, 2); + private readonly HashSet _allowedMediaTypes; + private readonly HttpClient _httpClient; + private readonly ILogger _logger; - if (response == null || !response.IsSuccessStatusCode) + public IconFetchingService(ILogger logger) + { + _logger = logger; + _allowedMediaTypes = new HashSet { - var dotCount = domain.Count(c => c == '.'); - if (dotCount > 1 && DomainName.TryParseBaseDomain(domain, out var baseDomain) && - Uri.TryCreate($"https://{baseDomain}", UriKind.Absolute, out var parsedBaseUri)) - { - Cleanup(response); - uri = parsedBaseUri; - response = await GetAndFollowAsync(uri, 2); - } - else if (dotCount < 2 && - Uri.TryCreate($"https://www.{parsedHttpsUri.Host}", UriKind.Absolute, out var parsedWwwUri)) - { - Cleanup(response); - uri = parsedWwwUri; - response = await GetAndFollowAsync(uri, 2); - } - } - } + _pngMediaType, + _icoMediaType, + _icoAltMediaType, + _jpegMediaType + }; - if (response?.Content == null || !response.IsSuccessStatusCode) - { - _logger.LogWarning("Couldn't load a website for {0}: {1}.", domain, - response?.StatusCode.ToString() ?? "null"); - Cleanup(response); - return null; - } - - var parser = new HtmlParser(); - using (response) - using (var htmlStream = await response.Content.ReadAsStreamAsync()) - using (var document = await parser.ParseDocumentAsync(htmlStream)) - { - uri = response.RequestMessage.RequestUri; - if (document.DocumentElement == null) + _httpClient = new HttpClient(new HttpClientHandler { - _logger.LogWarning("No DocumentElement for {0}.", domain); + AllowAutoRedirect = false, + AutomaticDecompression = DecompressionMethods.GZip | DecompressionMethods.Deflate, + }); + _httpClient.Timeout = TimeSpan.FromSeconds(20); + _httpClient.MaxResponseContentBufferSize = 5000000; // 5 MB + } + + public async Task GetIconAsync(string domain) + { + if (IPAddress.TryParse(domain, out _)) + { + _logger.LogWarning("IP address: {0}.", domain); return null; } - var baseUrl = "/"; - var baseUrlNode = document.QuerySelector("head base[href]"); - if (baseUrlNode != null) + if (!Uri.TryCreate($"https://{domain}", UriKind.Absolute, out var parsedHttpsUri)) { - var hrefAttr = baseUrlNode.Attributes["href"]; - if (!string.IsNullOrWhiteSpace(hrefAttr?.Value)) - { - baseUrl = hrefAttr.Value; - } - - baseUrlNode = null; - hrefAttr = null; + _logger.LogWarning("Bad domain: {0}.", domain); + return null; } - var icons = new List(); - var links = document.QuerySelectorAll("head link[href]"); - if (links != null) + var uri = parsedHttpsUri; + var response = await GetAndFollowAsync(uri, 2); + if ((response == null || !response.IsSuccessStatusCode) && + Uri.TryCreate($"http://{parsedHttpsUri.Host}", UriKind.Absolute, out var parsedHttpUri)) { - foreach (var link in links.Take(200)) + Cleanup(response); + uri = parsedHttpUri; + response = await GetAndFollowAsync(uri, 2); + + if (response == null || !response.IsSuccessStatusCode) { - var hrefAttr = link.Attributes["href"]; - if (string.IsNullOrWhiteSpace(hrefAttr?.Value)) + var dotCount = domain.Count(c => c == '.'); + if (dotCount > 1 && DomainName.TryParseBaseDomain(domain, out var baseDomain) && + Uri.TryCreate($"https://{baseDomain}", UriKind.Absolute, out var parsedBaseUri)) { - continue; + Cleanup(response); + uri = parsedBaseUri; + response = await GetAndFollowAsync(uri, 2); + } + else if (dotCount < 2 && + Uri.TryCreate($"https://www.{parsedHttpsUri.Host}", UriKind.Absolute, out var parsedWwwUri)) + { + Cleanup(response); + uri = parsedWwwUri; + response = await GetAndFollowAsync(uri, 2); + } + } + } + + if (response?.Content == null || !response.IsSuccessStatusCode) + { + _logger.LogWarning("Couldn't load a website for {0}: {1}.", domain, + response?.StatusCode.ToString() ?? "null"); + Cleanup(response); + return null; + } + + var parser = new HtmlParser(); + using (response) + using (var htmlStream = await response.Content.ReadAsStreamAsync()) + using (var document = await parser.ParseDocumentAsync(htmlStream)) + { + uri = response.RequestMessage.RequestUri; + if (document.DocumentElement == null) + { + _logger.LogWarning("No DocumentElement for {0}.", domain); + return null; + } + + var baseUrl = "/"; + var baseUrlNode = document.QuerySelector("head base[href]"); + if (baseUrlNode != null) + { + var hrefAttr = baseUrlNode.Attributes["href"]; + if (!string.IsNullOrWhiteSpace(hrefAttr?.Value)) + { + baseUrl = hrefAttr.Value; } - var relAttr = link.Attributes["rel"]; - var sizesAttr = link.Attributes["sizes"]; - if (relAttr != null && _iconRels.Contains(relAttr.Value.ToLower())) - { - icons.Add(new IconResult(hrefAttr.Value, sizesAttr?.Value)); - } - else if (relAttr == null || !_blacklistedRels.Contains(relAttr.Value.ToLower())) - { - try - { - var extension = Path.GetExtension(hrefAttr.Value); - if (_iconExtensions.Contains(extension.ToLower())) - { - icons.Add(new IconResult(hrefAttr.Value, sizesAttr?.Value)); - } - } - catch (ArgumentException) { } - } - - sizesAttr = null; - relAttr = null; + baseUrlNode = null; hrefAttr = null; } - links = null; - } - - var iconResultTasks = new List(); - foreach (var icon in icons.OrderBy(i => i.Priority).Take(10)) - { - Uri iconUri = null; - if (icon.Path.StartsWith("//") && Uri.TryCreate($"{GetScheme(uri)}://{icon.Path.Substring(2)}", - UriKind.Absolute, out var slashUri)) + var icons = new List(); + var links = document.QuerySelectorAll("head link[href]"); + if (links != null) { - iconUri = slashUri; - } - else if (Uri.TryCreate(icon.Path, UriKind.Relative, out var relUri)) - { - iconUri = ResolveUri($"{GetScheme(uri)}://{uri.Host}", baseUrl, relUri.OriginalString); - } - else if (Uri.TryCreate(icon.Path, UriKind.Absolute, out var absUri)) - { - iconUri = absUri; - } - - if (iconUri != null) - { - var task = GetIconAsync(iconUri).ContinueWith(async (r) => + foreach (var link in links.Take(200)) { - var result = await r; - if (result != null) + var hrefAttr = link.Attributes["href"]; + if (string.IsNullOrWhiteSpace(hrefAttr?.Value)) { - icon.Path = iconUri.ToString(); - icon.Icon = result.Icon; + continue; } - }); - iconResultTasks.Add(task); - } - } - await Task.WhenAll(iconResultTasks); - if (!icons.Any(i => i.Icon != null)) - { - var faviconUri = ResolveUri($"{GetScheme(uri)}://{uri.Host}", "favicon.ico"); - var result = await GetIconAsync(faviconUri); - if (result != null) - { - icons.Add(result); + var relAttr = link.Attributes["rel"]; + var sizesAttr = link.Attributes["sizes"]; + if (relAttr != null && _iconRels.Contains(relAttr.Value.ToLower())) + { + icons.Add(new IconResult(hrefAttr.Value, sizesAttr?.Value)); + } + else if (relAttr == null || !_blacklistedRels.Contains(relAttr.Value.ToLower())) + { + try + { + var extension = Path.GetExtension(hrefAttr.Value); + if (_iconExtensions.Contains(extension.ToLower())) + { + icons.Add(new IconResult(hrefAttr.Value, sizesAttr?.Value)); + } + } + catch (ArgumentException) { } + } + + sizesAttr = null; + relAttr = null; + hrefAttr = null; + } + + links = null; } - else + + var iconResultTasks = new List(); + foreach (var icon in icons.OrderBy(i => i.Priority).Take(10)) { - _logger.LogWarning("No favicon.ico found for {0}.", uri.Host); + Uri iconUri = null; + if (icon.Path.StartsWith("//") && Uri.TryCreate($"{GetScheme(uri)}://{icon.Path.Substring(2)}", + UriKind.Absolute, out var slashUri)) + { + iconUri = slashUri; + } + else if (Uri.TryCreate(icon.Path, UriKind.Relative, out var relUri)) + { + iconUri = ResolveUri($"{GetScheme(uri)}://{uri.Host}", baseUrl, relUri.OriginalString); + } + else if (Uri.TryCreate(icon.Path, UriKind.Absolute, out var absUri)) + { + iconUri = absUri; + } + + if (iconUri != null) + { + var task = GetIconAsync(iconUri).ContinueWith(async (r) => + { + var result = await r; + if (result != null) + { + icon.Path = iconUri.ToString(); + icon.Icon = result.Icon; + } + }); + iconResultTasks.Add(task); + } + } + + await Task.WhenAll(iconResultTasks); + if (!icons.Any(i => i.Icon != null)) + { + var faviconUri = ResolveUri($"{GetScheme(uri)}://{uri.Host}", "favicon.ico"); + var result = await GetIconAsync(faviconUri); + if (result != null) + { + icons.Add(result); + } + else + { + _logger.LogWarning("No favicon.ico found for {0}.", uri.Host); + return null; + } + } + + return icons.Where(i => i.Icon != null).OrderBy(i => i.Priority).First(); + } + } + + private async Task GetIconAsync(Uri uri) + { + using (var response = await GetAndFollowAsync(uri, 2)) + { + if (response?.Content?.Headers == null || !response.IsSuccessStatusCode) + { + response?.Content?.Dispose(); return null; } + + var format = response.Content.Headers?.ContentType?.MediaType; + var bytes = await response.Content.ReadAsByteArrayAsync(); + response.Content.Dispose(); + if (format == null || !_allowedMediaTypes.Contains(format)) + { + if (HeaderMatch(bytes, _icoHeader)) + { + format = _icoMediaType; + } + else if (HeaderMatch(bytes, _pngHeader) || HeaderMatch(bytes, _webpHeader)) + { + format = _pngMediaType; + } + else if (HeaderMatch(bytes, _jpegHeader)) + { + format = _jpegMediaType; + } + else + { + return null; + } + } + + return new IconResult(uri, bytes, format); } - - return icons.Where(i => i.Icon != null).OrderBy(i => i.Priority).First(); - } - } - - private async Task GetIconAsync(Uri uri) - { - using (var response = await GetAndFollowAsync(uri, 2)) - { - if (response?.Content?.Headers == null || !response.IsSuccessStatusCode) - { - response?.Content?.Dispose(); - return null; - } - - var format = response.Content.Headers?.ContentType?.MediaType; - var bytes = await response.Content.ReadAsByteArrayAsync(); - response.Content.Dispose(); - if (format == null || !_allowedMediaTypes.Contains(format)) - { - if (HeaderMatch(bytes, _icoHeader)) - { - format = _icoMediaType; - } - else if (HeaderMatch(bytes, _pngHeader) || HeaderMatch(bytes, _webpHeader)) - { - format = _pngMediaType; - } - else if (HeaderMatch(bytes, _jpegHeader)) - { - format = _jpegMediaType; - } - else - { - return null; - } - } - - return new IconResult(uri, bytes, format); - } - } - - private async Task GetAndFollowAsync(Uri uri, int maxRedirectCount) - { - var response = await GetAsync(uri); - if (response == null) - { - return null; - } - return await FollowRedirectsAsync(response, maxRedirectCount); - } - - private async Task GetAsync(Uri uri) - { - if (uri == null) - { - return null; } - // Prevent non-http(s) and non-default ports - if ((uri.Scheme != "http" && uri.Scheme != "https") || !uri.IsDefaultPort) + private async Task GetAndFollowAsync(Uri uri, int maxRedirectCount) { - return null; - } - - // Prevent local hosts (localhost, bobs-pc, etc) and IP addresses - if (!uri.Host.Contains(".") || IPAddress.TryParse(uri.Host, out _)) - { - return null; - } - - // Resolve host to make sure it is not an internal/private IP address - try - { - var hostEntry = Dns.GetHostEntry(uri.Host); - if (hostEntry?.AddressList.Any(ip => IsInternal(ip)) ?? true) + var response = await GetAsync(uri); + if (response == null) { return null; } - } - catch - { - return null; + return await FollowRedirectsAsync(response, maxRedirectCount); } - using (var message = new HttpRequestMessage()) + private async Task GetAsync(Uri uri) { - message.RequestUri = uri; - message.Method = HttpMethod.Get; + if (uri == null) + { + return null; + } - // Let's add some headers to look like we're coming from a web browser request. Some websites - // will block our request without these. - message.Headers.Add("User-Agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 " + - "(KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.36 Edge/16.16299"); - message.Headers.Add("Accept-Language", "en-US,en;q=0.8"); - message.Headers.Add("Cache-Control", "no-cache"); - message.Headers.Add("Pragma", "no-cache"); - message.Headers.Add("Accept", "text/html,application/xhtml+xml,application/xml;" + - "q=0.9,image/webp,image/apng,*/*;q=0.8"); + // Prevent non-http(s) and non-default ports + if ((uri.Scheme != "http" && uri.Scheme != "https") || !uri.IsDefaultPort) + { + return null; + } + // Prevent local hosts (localhost, bobs-pc, etc) and IP addresses + if (!uri.Host.Contains(".") || IPAddress.TryParse(uri.Host, out _)) + { + return null; + } + + // Resolve host to make sure it is not an internal/private IP address try { - return await _httpClient.SendAsync(message); + var hostEntry = Dns.GetHostEntry(uri.Host); + if (hostEntry?.AddressList.Any(ip => IsInternal(ip)) ?? true) + { + return null; + } } catch { return null; } - } - } - private async Task FollowRedirectsAsync(HttpResponseMessage response, - int maxFollowCount, int followCount = 0) - { - if (response == null || response.IsSuccessStatusCode || followCount > maxFollowCount) - { - return response; - } - - if (!(response.StatusCode == HttpStatusCode.Redirect || - response.StatusCode == HttpStatusCode.MovedPermanently || - response.StatusCode == HttpStatusCode.RedirectKeepVerb || - response.StatusCode == HttpStatusCode.SeeOther) || - response.Headers.Location == null) - { - Cleanup(response); - return null; - } - - Uri location = null; - if (response.Headers.Location.IsAbsoluteUri) - { - if (response.Headers.Location.Scheme != "http" && response.Headers.Location.Scheme != "https") + using (var message = new HttpRequestMessage()) { - if (Uri.TryCreate($"https://{response.Headers.Location.OriginalString}", - UriKind.Absolute, out var newUri)) + message.RequestUri = uri; + message.Method = HttpMethod.Get; + + // Let's add some headers to look like we're coming from a web browser request. Some websites + // will block our request without these. + message.Headers.Add("User-Agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 " + + "(KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.36 Edge/16.16299"); + message.Headers.Add("Accept-Language", "en-US,en;q=0.8"); + message.Headers.Add("Cache-Control", "no-cache"); + message.Headers.Add("Pragma", "no-cache"); + message.Headers.Add("Accept", "text/html,application/xhtml+xml,application/xml;" + + "q=0.9,image/webp,image/apng,*/*;q=0.8"); + + try { - location = newUri; + return await _httpClient.SendAsync(message); + } + catch + { + return null; + } + } + } + + private async Task FollowRedirectsAsync(HttpResponseMessage response, + int maxFollowCount, int followCount = 0) + { + if (response == null || response.IsSuccessStatusCode || followCount > maxFollowCount) + { + return response; + } + + if (!(response.StatusCode == HttpStatusCode.Redirect || + response.StatusCode == HttpStatusCode.MovedPermanently || + response.StatusCode == HttpStatusCode.RedirectKeepVerb || + response.StatusCode == HttpStatusCode.SeeOther) || + response.Headers.Location == null) + { + Cleanup(response); + return null; + } + + Uri location = null; + if (response.Headers.Location.IsAbsoluteUri) + { + if (response.Headers.Location.Scheme != "http" && response.Headers.Location.Scheme != "https") + { + if (Uri.TryCreate($"https://{response.Headers.Location.OriginalString}", + UriKind.Absolute, out var newUri)) + { + location = newUri; + } + } + else + { + location = response.Headers.Location; } } else { - location = response.Headers.Location; + var requestUri = response.RequestMessage.RequestUri; + location = ResolveUri($"{GetScheme(requestUri)}://{requestUri.Host}", + response.Headers.Location.OriginalString); } - } - else - { - var requestUri = response.RequestMessage.RequestUri; - location = ResolveUri($"{GetScheme(requestUri)}://{requestUri.Host}", - response.Headers.Location.OriginalString); - } - Cleanup(response); - var newResponse = await GetAsync(location); - if (newResponse != null) - { - followCount++; - var redirectedResponse = await FollowRedirectsAsync(newResponse, maxFollowCount, followCount); - if (redirectedResponse != null) + Cleanup(response); + var newResponse = await GetAsync(location); + if (newResponse != null) { - if (redirectedResponse != newResponse) + followCount++; + var redirectedResponse = await FollowRedirectsAsync(newResponse, maxFollowCount, followCount); + if (redirectedResponse != null) { - Cleanup(newResponse); + if (redirectedResponse != newResponse) + { + Cleanup(newResponse); + } + return redirectedResponse; } - return redirectedResponse; } + + return null; } - return null; - } - - private bool HeaderMatch(byte[] imageBytes, byte[] header) - { - return imageBytes.Length >= header.Length && header.SequenceEqual(imageBytes.Take(header.Length)); - } - - private Uri ResolveUri(string baseUrl, params string[] paths) - { - var url = baseUrl; - foreach (var path in paths) + private bool HeaderMatch(byte[] imageBytes, byte[] header) { - if (Uri.TryCreate(new Uri(url), path, out var r)) + return imageBytes.Length >= header.Length && header.SequenceEqual(imageBytes.Take(header.Length)); + } + + private Uri ResolveUri(string baseUrl, params string[] paths) + { + var url = baseUrl; + foreach (var path in paths) { - url = r.ToString(); + if (Uri.TryCreate(new Uri(url), path, out var r)) + { + url = r.ToString(); + } } - } - return new Uri(url); - } - - private void Cleanup(IDisposable obj) - { - obj?.Dispose(); - obj = null; - } - - private string GetScheme(Uri uri) - { - return uri != null && uri.Scheme == "http" ? "http" : "https"; - } - - public static bool IsInternal(IPAddress ip) - { - if (IPAddress.IsLoopback(ip)) - { - return true; + return new Uri(url); } - var ipString = ip.ToString(); - if (ipString == "::1" || ipString == "::" || ipString.StartsWith("::ffff:")) + private void Cleanup(IDisposable obj) { - return true; + obj?.Dispose(); + obj = null; } - // IPv6 - if (ip.AddressFamily == System.Net.Sockets.AddressFamily.InterNetworkV6) + private string GetScheme(Uri uri) { - return ipString.StartsWith("fc") || ipString.StartsWith("fd") || - ipString.StartsWith("fe") || ipString.StartsWith("ff"); + return uri != null && uri.Scheme == "http" ? "http" : "https"; } - // IPv4 - var bytes = ip.GetAddressBytes(); - return (bytes[0]) switch + public static bool IsInternal(IPAddress ip) { - 0 => true, - 10 => true, - 127 => true, - 169 => bytes[1] == 254, // Cloud environments, such as AWS - 172 => bytes[1] < 32 && bytes[1] >= 16, - 192 => bytes[1] == 168, - _ => false, - }; + if (IPAddress.IsLoopback(ip)) + { + return true; + } + + var ipString = ip.ToString(); + if (ipString == "::1" || ipString == "::" || ipString.StartsWith("::ffff:")) + { + return true; + } + + // IPv6 + if (ip.AddressFamily == System.Net.Sockets.AddressFamily.InterNetworkV6) + { + return ipString.StartsWith("fc") || ipString.StartsWith("fd") || + ipString.StartsWith("fe") || ipString.StartsWith("ff"); + } + + // IPv4 + var bytes = ip.GetAddressBytes(); + return (bytes[0]) switch + { + 0 => true, + 10 => true, + 127 => true, + 169 => bytes[1] == 254, // Cloud environments, such as AWS + 172 => bytes[1] < 32 && bytes[1] >= 16, + 192 => bytes[1] == 168, + _ => false, + }; + } } } diff --git a/src/Icons/Startup.cs b/src/Icons/Startup.cs index f64ea07edf..71442772b3 100644 --- a/src/Icons/Startup.cs +++ b/src/Icons/Startup.cs @@ -5,72 +5,73 @@ using Bit.Icons.Services; using Bit.SharedWeb.Utilities; using Microsoft.Net.Http.Headers; -namespace Bit.Icons; - -public class Startup +namespace Bit.Icons { - public Startup(IWebHostEnvironment env, IConfiguration configuration) + public class Startup { - CultureInfo.DefaultThreadCurrentCulture = new CultureInfo("en-US"); - Configuration = configuration; - Environment = env; - } - - public IConfiguration Configuration { get; } - public IWebHostEnvironment Environment { get; } - - public void ConfigureServices(IServiceCollection services) - { - // Options - services.AddOptions(); - - // Settings - var globalSettings = services.AddGlobalSettingsServices(Configuration, Environment); - var iconsSettings = new IconsSettings(); - ConfigurationBinder.Bind(Configuration.GetSection("IconsSettings"), iconsSettings); - services.AddSingleton(s => iconsSettings); - - // Cache - services.AddMemoryCache(options => + public Startup(IWebHostEnvironment env, IConfiguration configuration) { - options.SizeLimit = iconsSettings.CacheSizeLimit; - }); - - // Services - services.AddSingleton(); - services.AddSingleton(); - - // Mvc - services.AddMvc(); - } - - public void Configure( - IApplicationBuilder app, - IWebHostEnvironment env, - IHostApplicationLifetime appLifetime, - GlobalSettings globalSettings) - { - app.UseSerilog(env, appLifetime, globalSettings); - - // Add general security headers - app.UseMiddleware(); - - if (env.IsDevelopment()) - { - app.UseDeveloperExceptionPage(); + CultureInfo.DefaultThreadCurrentCulture = new CultureInfo("en-US"); + Configuration = configuration; + Environment = env; } - app.Use(async (context, next) => - { - context.Response.GetTypedHeaders().CacheControl = new CacheControlHeaderValue - { - Public = true, - MaxAge = TimeSpan.FromDays(7) - }; - await next(); - }); + public IConfiguration Configuration { get; } + public IWebHostEnvironment Environment { get; } - app.UseRouting(); - app.UseEndpoints(endpoints => endpoints.MapDefaultControllerRoute()); + public void ConfigureServices(IServiceCollection services) + { + // Options + services.AddOptions(); + + // Settings + var globalSettings = services.AddGlobalSettingsServices(Configuration, Environment); + var iconsSettings = new IconsSettings(); + ConfigurationBinder.Bind(Configuration.GetSection("IconsSettings"), iconsSettings); + services.AddSingleton(s => iconsSettings); + + // Cache + services.AddMemoryCache(options => + { + options.SizeLimit = iconsSettings.CacheSizeLimit; + }); + + // Services + services.AddSingleton(); + services.AddSingleton(); + + // Mvc + services.AddMvc(); + } + + public void Configure( + IApplicationBuilder app, + IWebHostEnvironment env, + IHostApplicationLifetime appLifetime, + GlobalSettings globalSettings) + { + app.UseSerilog(env, appLifetime, globalSettings); + + // Add general security headers + app.UseMiddleware(); + + if (env.IsDevelopment()) + { + app.UseDeveloperExceptionPage(); + } + + app.Use(async (context, next) => + { + context.Response.GetTypedHeaders().CacheControl = new CacheControlHeaderValue + { + Public = true, + MaxAge = TimeSpan.FromDays(7) + }; + await next(); + }); + + app.UseRouting(); + app.UseEndpoints(endpoints => endpoints.MapDefaultControllerRoute()); + } } } diff --git a/src/Identity/Controllers/AccountsController.cs b/src/Identity/Controllers/AccountsController.cs index 940e2ab97e..d7151a3ee0 100644 --- a/src/Identity/Controllers/AccountsController.cs +++ b/src/Identity/Controllers/AccountsController.cs @@ -9,60 +9,61 @@ using Bit.Core.Utilities; using Bit.SharedWeb.Utilities; using Microsoft.AspNetCore.Mvc; -namespace Bit.Identity.Controllers; - -[Route("accounts")] -[ExceptionHandlerFilter] -public class AccountsController : Controller +namespace Bit.Identity.Controllers { - private readonly ILogger _logger; - private readonly IUserRepository _userRepository; - private readonly IUserService _userService; - - public AccountsController( - ILogger logger, - IUserRepository userRepository, - IUserService userService) + [Route("accounts")] + [ExceptionHandlerFilter] + public class AccountsController : Controller { - _logger = logger; - _userRepository = userRepository; - _userService = userService; - } + private readonly ILogger _logger; + private readonly IUserRepository _userRepository; + private readonly IUserService _userService; - // Moved from API, If you modify this endpoint, please update API as well. - [HttpPost("register")] - [CaptchaProtected] - public async Task PostRegister([FromBody] RegisterRequestModel model) - { - var result = await _userService.RegisterUserAsync(model.ToUser(), model.MasterPasswordHash, - model.Token, model.OrganizationUserId); - if (result.Succeeded) + public AccountsController( + ILogger logger, + IUserRepository userRepository, + IUserService userService) { - return; + _logger = logger; + _userRepository = userRepository; + _userService = userService; } - foreach (var error in result.Errors.Where(e => e.Code != "DuplicateUserName")) + // Moved from API, If you modify this endpoint, please update API as well. + [HttpPost("register")] + [CaptchaProtected] + public async Task PostRegister([FromBody] RegisterRequestModel model) { - ModelState.AddModelError(string.Empty, error.Description); - } - - await Task.Delay(2000); - throw new BadRequestException(ModelState); - } - - // Moved from API, If you modify this endpoint, please update API as well. - [HttpPost("prelogin")] - public async Task PostPrelogin([FromBody] PreloginRequestModel model) - { - var kdfInformation = await _userRepository.GetKdfInformationByEmailAsync(model.Email); - if (kdfInformation == null) - { - kdfInformation = new UserKdfInformation + var result = await _userService.RegisterUserAsync(model.ToUser(), model.MasterPasswordHash, + model.Token, model.OrganizationUserId); + if (result.Succeeded) { - Kdf = KdfType.PBKDF2_SHA256, - KdfIterations = 100000, - }; + return; + } + + foreach (var error in result.Errors.Where(e => e.Code != "DuplicateUserName")) + { + ModelState.AddModelError(string.Empty, error.Description); + } + + await Task.Delay(2000); + throw new BadRequestException(ModelState); + } + + // Moved from API, If you modify this endpoint, please update API as well. + [HttpPost("prelogin")] + public async Task PostPrelogin([FromBody] PreloginRequestModel model) + { + var kdfInformation = await _userRepository.GetKdfInformationByEmailAsync(model.Email); + if (kdfInformation == null) + { + kdfInformation = new UserKdfInformation + { + Kdf = KdfType.PBKDF2_SHA256, + KdfIterations = 100000, + }; + } + return new PreloginResponseModel(kdfInformation); } - return new PreloginResponseModel(kdfInformation); } } diff --git a/src/Identity/Controllers/InfoController.cs b/src/Identity/Controllers/InfoController.cs index c06812cdf5..d8c161c61d 100644 --- a/src/Identity/Controllers/InfoController.cs +++ b/src/Identity/Controllers/InfoController.cs @@ -1,20 +1,21 @@ using Bit.Core.Utilities; using Microsoft.AspNetCore.Mvc; -namespace Bit.Identity.Controllers; - -public class InfoController : Controller +namespace Bit.Identity.Controllers { - [HttpGet("~/alive")] - [HttpGet("~/now")] - public DateTime GetAlive() + public class InfoController : Controller { - return DateTime.UtcNow; - } + [HttpGet("~/alive")] + [HttpGet("~/now")] + public DateTime GetAlive() + { + return DateTime.UtcNow; + } - [HttpGet("~/version")] - public JsonResult GetVersion() - { - return Json(CoreHelpers.GetVersion()); + [HttpGet("~/version")] + public JsonResult GetVersion() + { + return Json(CoreHelpers.GetVersion()); + } } } diff --git a/src/Identity/Controllers/SsoController.cs b/src/Identity/Controllers/SsoController.cs index b61d89b86c..e3dc8f504f 100644 --- a/src/Identity/Controllers/SsoController.cs +++ b/src/Identity/Controllers/SsoController.cs @@ -11,264 +11,265 @@ using Microsoft.AspNetCore.Authentication; using Microsoft.AspNetCore.Localization; using Microsoft.AspNetCore.Mvc; -namespace Bit.Identity.Controllers; - -// TODO: 2022-01-12, Remove account alias -[Route("account/[action]")] -[Route("sso/[action]")] -public class SsoController : Controller +namespace Bit.Identity.Controllers { - private readonly IIdentityServerInteractionService _interaction; - private readonly ILogger _logger; - private readonly ISsoConfigRepository _ssoConfigRepository; - private readonly IUserRepository _userRepository; - private readonly IHttpClientFactory _clientFactory; - - public SsoController( - IIdentityServerInteractionService interaction, - ILogger logger, - ISsoConfigRepository ssoConfigRepository, - IUserRepository userRepository, - IHttpClientFactory clientFactory) + // TODO: 2022-01-12, Remove account alias + [Route("account/[action]")] + [Route("sso/[action]")] + public class SsoController : Controller { - _interaction = interaction; - _logger = logger; - _ssoConfigRepository = ssoConfigRepository; - _userRepository = userRepository; - _clientFactory = clientFactory; - } + private readonly IIdentityServerInteractionService _interaction; + private readonly ILogger _logger; + private readonly ISsoConfigRepository _ssoConfigRepository; + private readonly IUserRepository _userRepository; + private readonly IHttpClientFactory _clientFactory; - [HttpGet] - public async Task PreValidate(string domainHint) - { - if (string.IsNullOrWhiteSpace(domainHint)) + public SsoController( + IIdentityServerInteractionService interaction, + ILogger logger, + ISsoConfigRepository ssoConfigRepository, + IUserRepository userRepository, + IHttpClientFactory clientFactory) { - Response.StatusCode = 400; - return Json(new ErrorResponseModel("No domain hint was provided")); + _interaction = interaction; + _logger = logger; + _ssoConfigRepository = ssoConfigRepository; + _userRepository = userRepository; + _clientFactory = clientFactory; } - try - { - // Calls Sso Pre-Validate, assumes baseUri set - var requestCultureFeature = Request.HttpContext.Features.Get(); - var culture = requestCultureFeature.RequestCulture.Culture.Name; - var requestPath = $"/Account/PreValidate?domainHint={domainHint}&culture={culture}"; - var httpClient = _clientFactory.CreateClient("InternalSso"); - // Forward the internal SSO result - using var responseMessage = await httpClient.GetAsync(requestPath); - var responseJson = await responseMessage.Content.ReadAsStringAsync(); - Response.StatusCode = (int)responseMessage.StatusCode; - return Content(responseJson, "application/json"); - } - catch (Exception ex) + [HttpGet] + public async Task PreValidate(string domainHint) { - _logger.LogError(ex, "Error pre-validating against SSO service"); - Response.StatusCode = 500; - return Json(new ErrorResponseModel("Error pre-validating SSO authentication") + if (string.IsNullOrWhiteSpace(domainHint)) { - ExceptionMessage = ex.Message, - ExceptionStackTrace = ex.StackTrace, - InnerExceptionMessage = ex.InnerException?.Message, + Response.StatusCode = 400; + return Json(new ErrorResponseModel("No domain hint was provided")); + } + try + { + // Calls Sso Pre-Validate, assumes baseUri set + var requestCultureFeature = Request.HttpContext.Features.Get(); + var culture = requestCultureFeature.RequestCulture.Culture.Name; + var requestPath = $"/Account/PreValidate?domainHint={domainHint}&culture={culture}"; + var httpClient = _clientFactory.CreateClient("InternalSso"); + + // Forward the internal SSO result + using var responseMessage = await httpClient.GetAsync(requestPath); + var responseJson = await responseMessage.Content.ReadAsStringAsync(); + Response.StatusCode = (int)responseMessage.StatusCode; + return Content(responseJson, "application/json"); + } + catch (Exception ex) + { + _logger.LogError(ex, "Error pre-validating against SSO service"); + Response.StatusCode = 500; + return Json(new ErrorResponseModel("Error pre-validating SSO authentication") + { + ExceptionMessage = ex.Message, + ExceptionStackTrace = ex.StackTrace, + InnerExceptionMessage = ex.InnerException?.Message, + }); + } + } + + [HttpGet] + public async Task Login(string returnUrl) + { + var context = await _interaction.GetAuthorizationContextAsync(returnUrl); + + var domainHint = context.Parameters.AllKeys.Contains("domain_hint") ? + context.Parameters["domain_hint"] : null; + var ssoToken = context.Parameters[SsoTokenable.TokenIdentifier]; + + if (string.IsNullOrWhiteSpace(domainHint)) + { + throw new Exception("No domain_hint provided"); + } + + var userIdentifier = context.Parameters.AllKeys.Contains("user_identifier") ? + context.Parameters["user_identifier"] : null; + + return RedirectToAction(nameof(ExternalChallenge), new + { + domainHint = domainHint, + returnUrl, + userIdentifier, + ssoToken, }); } - } - [HttpGet] - public async Task Login(string returnUrl) - { - var context = await _interaction.GetAuthorizationContextAsync(returnUrl); - - var domainHint = context.Parameters.AllKeys.Contains("domain_hint") ? - context.Parameters["domain_hint"] : null; - var ssoToken = context.Parameters[SsoTokenable.TokenIdentifier]; - - if (string.IsNullOrWhiteSpace(domainHint)) + [HttpGet] + public async Task ExternalChallenge(string domainHint, string returnUrl, + string userIdentifier, string ssoToken) { - throw new Exception("No domain_hint provided"); - } - - var userIdentifier = context.Parameters.AllKeys.Contains("user_identifier") ? - context.Parameters["user_identifier"] : null; - - return RedirectToAction(nameof(ExternalChallenge), new - { - domainHint = domainHint, - returnUrl, - userIdentifier, - ssoToken, - }); - } - - [HttpGet] - public async Task ExternalChallenge(string domainHint, string returnUrl, - string userIdentifier, string ssoToken) - { - if (string.IsNullOrWhiteSpace(domainHint)) - { - throw new Exception("Invalid organization reference id."); - } - - var ssoConfig = await _ssoConfigRepository.GetByIdentifierAsync(domainHint); - if (ssoConfig == null || !ssoConfig.Enabled) - { - throw new Exception("Organization not found or SSO configuration not enabled"); - } - var organizationId = ssoConfig.OrganizationId.ToString(); - - var scheme = "sso"; - var props = new AuthenticationProperties - { - RedirectUri = Url.Action(nameof(ExternalCallback)), - Items = + if (string.IsNullOrWhiteSpace(domainHint)) { - { "return_url", returnUrl }, - { "domain_hint", domainHint }, - { "organizationId", organizationId }, - { "scheme", scheme }, - }, - Parameters = - { - { "ssoToken", ssoToken }, - } - }; - - if (!string.IsNullOrWhiteSpace(userIdentifier)) - { - props.Items.Add("user_identifier", userIdentifier); - } - - return Challenge(props, scheme); - } - - [HttpGet] - public async Task ExternalCallback() - { - // Read external identity from the temporary cookie - var result = await HttpContext.AuthenticateAsync( - Core.AuthenticationSchemes.BitwardenExternalCookieAuthenticationScheme); - if (result?.Succeeded != true) - { - throw new Exception("External authentication error"); - } - - // Debugging - var externalClaims = result.Principal.Claims.Select(c => $"{c.Type}: {c.Value}"); - _logger.LogDebug("External claims: {@claims}", externalClaims); - - var (user, provider, providerUserId, claims) = await FindUserFromExternalProviderAsync(result); - if (user == null) - { - // Should never happen - throw new Exception("Cannot find user."); - } - - // This allows us to collect any additional claims or properties - // for the specific protocols used and store them in the local auth cookie. - // this is typically used to store data needed for signout from those protocols. - var additionalLocalClaims = new List(); - var localSignInProps = new AuthenticationProperties - { - IsPersistent = true, - ExpiresUtc = DateTimeOffset.UtcNow.AddMinutes(1) - }; - if (result.Properties != null && result.Properties.Items.TryGetValue("organizationId", out var organization)) - { - additionalLocalClaims.Add(new Claim("organizationId", organization)); - } - ProcessLoginCallback(result, additionalLocalClaims, localSignInProps); - - // Issue authentication cookie for user - await HttpContext.SignInAsync(new IdentityServerUser(user.Id.ToString()) - { - DisplayName = user.Email, - IdentityProvider = provider, - AdditionalClaims = additionalLocalClaims.ToArray() - }, localSignInProps); - - // Delete temporary cookie used during external authentication - await HttpContext.SignOutAsync(Core.AuthenticationSchemes.BitwardenExternalCookieAuthenticationScheme); - - // Retrieve return URL - var returnUrl = result.Properties.Items["return_url"] ?? "~/"; - - var context = await _interaction.GetAuthorizationContextAsync(returnUrl); - if (context != null) - { - if (IsNativeClient(context)) - { - // The client is native, so this change in how to - // return the response is for better UX for the end user. - HttpContext.Response.StatusCode = 200; - HttpContext.Response.Headers["Location"] = string.Empty; - return View("Redirect", new RedirectViewModel { RedirectUrl = returnUrl }); + throw new Exception("Invalid organization reference id."); } - // We can trust model.ReturnUrl since GetAuthorizationContextAsync returned non-null - return Redirect(returnUrl); + var ssoConfig = await _ssoConfigRepository.GetByIdentifierAsync(domainHint); + if (ssoConfig == null || !ssoConfig.Enabled) + { + throw new Exception("Organization not found or SSO configuration not enabled"); + } + var organizationId = ssoConfig.OrganizationId.ToString(); + + var scheme = "sso"; + var props = new AuthenticationProperties + { + RedirectUri = Url.Action(nameof(ExternalCallback)), + Items = + { + { "return_url", returnUrl }, + { "domain_hint", domainHint }, + { "organizationId", organizationId }, + { "scheme", scheme }, + }, + Parameters = + { + { "ssoToken", ssoToken }, + } + }; + + if (!string.IsNullOrWhiteSpace(userIdentifier)) + { + props.Items.Add("user_identifier", userIdentifier); + } + + return Challenge(props, scheme); } - // Request for a local page - if (Url.IsLocalUrl(returnUrl)) + [HttpGet] + public async Task ExternalCallback() { - return Redirect(returnUrl); + // Read external identity from the temporary cookie + var result = await HttpContext.AuthenticateAsync( + Core.AuthenticationSchemes.BitwardenExternalCookieAuthenticationScheme); + if (result?.Succeeded != true) + { + throw new Exception("External authentication error"); + } + + // Debugging + var externalClaims = result.Principal.Claims.Select(c => $"{c.Type}: {c.Value}"); + _logger.LogDebug("External claims: {@claims}", externalClaims); + + var (user, provider, providerUserId, claims) = await FindUserFromExternalProviderAsync(result); + if (user == null) + { + // Should never happen + throw new Exception("Cannot find user."); + } + + // This allows us to collect any additional claims or properties + // for the specific protocols used and store them in the local auth cookie. + // this is typically used to store data needed for signout from those protocols. + var additionalLocalClaims = new List(); + var localSignInProps = new AuthenticationProperties + { + IsPersistent = true, + ExpiresUtc = DateTimeOffset.UtcNow.AddMinutes(1) + }; + if (result.Properties != null && result.Properties.Items.TryGetValue("organizationId", out var organization)) + { + additionalLocalClaims.Add(new Claim("organizationId", organization)); + } + ProcessLoginCallback(result, additionalLocalClaims, localSignInProps); + + // Issue authentication cookie for user + await HttpContext.SignInAsync(new IdentityServerUser(user.Id.ToString()) + { + DisplayName = user.Email, + IdentityProvider = provider, + AdditionalClaims = additionalLocalClaims.ToArray() + }, localSignInProps); + + // Delete temporary cookie used during external authentication + await HttpContext.SignOutAsync(Core.AuthenticationSchemes.BitwardenExternalCookieAuthenticationScheme); + + // Retrieve return URL + var returnUrl = result.Properties.Items["return_url"] ?? "~/"; + + var context = await _interaction.GetAuthorizationContextAsync(returnUrl); + if (context != null) + { + if (IsNativeClient(context)) + { + // The client is native, so this change in how to + // return the response is for better UX for the end user. + HttpContext.Response.StatusCode = 200; + HttpContext.Response.Headers["Location"] = string.Empty; + return View("Redirect", new RedirectViewModel { RedirectUrl = returnUrl }); + } + + // We can trust model.ReturnUrl since GetAuthorizationContextAsync returned non-null + return Redirect(returnUrl); + } + + // Request for a local page + if (Url.IsLocalUrl(returnUrl)) + { + return Redirect(returnUrl); + } + else if (string.IsNullOrEmpty(returnUrl)) + { + return Redirect("~/"); + } + else + { + // User might have clicked on a malicious link - should be logged + throw new Exception("invalid return URL"); + } } - else if (string.IsNullOrEmpty(returnUrl)) + + private async Task<(User user, string provider, string providerUserId, IEnumerable claims)> + FindUserFromExternalProviderAsync(AuthenticateResult result) { - return Redirect("~/"); + var externalUser = result.Principal; + + // Try to determine the unique id of the external user (issued by the provider) + // the most common claim type for that are the sub claim and the NameIdentifier + // depending on the external provider, some other claim type might be used + var userIdClaim = externalUser.FindFirst(JwtClaimTypes.Subject) ?? + externalUser.FindFirst(ClaimTypes.NameIdentifier) ?? + throw new Exception("Unknown userid"); + + // remove the user id claim so we don't include it as an extra claim if/when we provision the user + var claims = externalUser.Claims.ToList(); + claims.Remove(userIdClaim); + + var provider = result.Properties.Items["scheme"]; + var providerUserId = userIdClaim.Value; + var user = await _userRepository.GetByIdAsync(new Guid(providerUserId)); + + return (user, provider, providerUserId, claims); } - else + + private void ProcessLoginCallback(AuthenticateResult externalResult, List localClaims, + AuthenticationProperties localSignInProps) { - // User might have clicked on a malicious link - should be logged - throw new Exception("invalid return URL"); + // If the external system sent a session id claim, copy it over + // so we can use it for single sign-out + var sid = externalResult.Principal.Claims.FirstOrDefault(x => x.Type == JwtClaimTypes.SessionId); + if (sid != null) + { + localClaims.Add(new Claim(JwtClaimTypes.SessionId, sid.Value)); + } + + // If the external provider issued an idToken, we'll keep it for signout + var idToken = externalResult.Properties.GetTokenValue("id_token"); + if (idToken != null) + { + localSignInProps.StoreTokens( + new[] { new AuthenticationToken { Name = "id_token", Value = idToken } }); + } } - } - private async Task<(User user, string provider, string providerUserId, IEnumerable claims)> - FindUserFromExternalProviderAsync(AuthenticateResult result) - { - var externalUser = result.Principal; - - // Try to determine the unique id of the external user (issued by the provider) - // the most common claim type for that are the sub claim and the NameIdentifier - // depending on the external provider, some other claim type might be used - var userIdClaim = externalUser.FindFirst(JwtClaimTypes.Subject) ?? - externalUser.FindFirst(ClaimTypes.NameIdentifier) ?? - throw new Exception("Unknown userid"); - - // remove the user id claim so we don't include it as an extra claim if/when we provision the user - var claims = externalUser.Claims.ToList(); - claims.Remove(userIdClaim); - - var provider = result.Properties.Items["scheme"]; - var providerUserId = userIdClaim.Value; - var user = await _userRepository.GetByIdAsync(new Guid(providerUserId)); - - return (user, provider, providerUserId, claims); - } - - private void ProcessLoginCallback(AuthenticateResult externalResult, List localClaims, - AuthenticationProperties localSignInProps) - { - // If the external system sent a session id claim, copy it over - // so we can use it for single sign-out - var sid = externalResult.Principal.Claims.FirstOrDefault(x => x.Type == JwtClaimTypes.SessionId); - if (sid != null) + private bool IsNativeClient(IdentityServer4.Models.AuthorizationRequest context) { - localClaims.Add(new Claim(JwtClaimTypes.SessionId, sid.Value)); + return !context.RedirectUri.StartsWith("https", StringComparison.Ordinal) + && !context.RedirectUri.StartsWith("http", StringComparison.Ordinal); } - - // If the external provider issued an idToken, we'll keep it for signout - var idToken = externalResult.Properties.GetTokenValue("id_token"); - if (idToken != null) - { - localSignInProps.StoreTokens( - new[] { new AuthenticationToken { Name = "id_token", Value = idToken } }); - } - } - - private bool IsNativeClient(IdentityServer4.Models.AuthorizationRequest context) - { - return !context.RedirectUri.StartsWith("https", StringComparison.Ordinal) - && !context.RedirectUri.StartsWith("http", StringComparison.Ordinal); } } diff --git a/src/Identity/Models/RedirectViewModel.cs b/src/Identity/Models/RedirectViewModel.cs index 5cf7663b4b..848fdf871a 100644 --- a/src/Identity/Models/RedirectViewModel.cs +++ b/src/Identity/Models/RedirectViewModel.cs @@ -1,6 +1,7 @@ -namespace Bit.Identity.Models; - -public class RedirectViewModel +namespace Bit.Identity.Models { - public string RedirectUrl { get; set; } + public class RedirectViewModel + { + public string RedirectUrl { get; set; } + } } diff --git a/src/Identity/Program.cs b/src/Identity/Program.cs index e87f81aa62..540e3ac751 100644 --- a/src/Identity/Program.cs +++ b/src/Identity/Program.cs @@ -2,43 +2,44 @@ using Bit.Core.Utilities; using Serilog.Events; -namespace Bit.Identity; - -public class Program +namespace Bit.Identity { - public static void Main(string[] args) + public class Program { - CreateHostBuilder(args) - .Build() - .Run(); - } + public static void Main(string[] args) + { + CreateHostBuilder(args) + .Build() + .Run(); + } - public static IHostBuilder CreateHostBuilder(string[] args) - { - return Host - .CreateDefaultBuilder(args) - .ConfigureCustomAppConfiguration(args) - .ConfigureWebHostDefaults(webBuilder => - { - webBuilder.UseStartup(); - webBuilder.ConfigureLogging((hostingContext, logging) => - logging.AddSerilog(hostingContext, e => - { - var context = e.Properties["SourceContext"].ToString(); - if (context.Contains(typeof(IpRateLimitMiddleware).FullName) && - e.Level == LogEventLevel.Information) + public static IHostBuilder CreateHostBuilder(string[] args) + { + return Host + .CreateDefaultBuilder(args) + .ConfigureCustomAppConfiguration(args) + .ConfigureWebHostDefaults(webBuilder => + { + webBuilder.UseStartup(); + webBuilder.ConfigureLogging((hostingContext, logging) => + logging.AddSerilog(hostingContext, e => { - return true; - } + var context = e.Properties["SourceContext"].ToString(); + if (context.Contains(typeof(IpRateLimitMiddleware).FullName) && + e.Level == LogEventLevel.Information) + { + return true; + } - if (context.Contains("IdentityServer4.Validation.TokenValidator") || - context.Contains("IdentityServer4.Validation.TokenRequestValidator")) - { - return e.Level > LogEventLevel.Error; - } + if (context.Contains("IdentityServer4.Validation.TokenValidator") || + context.Contains("IdentityServer4.Validation.TokenRequestValidator")) + { + return e.Level > LogEventLevel.Error; + } - return e.Level >= LogEventLevel.Error; - })); - }); + return e.Level >= LogEventLevel.Error; + })); + }); + } } } diff --git a/src/Identity/Startup.cs b/src/Identity/Startup.cs index 170e2b931e..e355f01235 100644 --- a/src/Identity/Startup.cs +++ b/src/Identity/Startup.cs @@ -13,213 +13,214 @@ using Microsoft.Extensions.DependencyInjection.Extensions; using Microsoft.IdentityModel.Logging; using Microsoft.OpenApi.Models; -namespace Bit.Identity; - -public class Startup +namespace Bit.Identity { - public Startup(IWebHostEnvironment env, IConfiguration configuration) + public class Startup { - CultureInfo.DefaultThreadCurrentCulture = new CultureInfo("en-US"); - Configuration = configuration; - Environment = env; - } - - public IConfiguration Configuration { get; private set; } - public IWebHostEnvironment Environment { get; set; } - - public void ConfigureServices(IServiceCollection services) - { - // Options - services.AddOptions(); - - // Settings - var globalSettings = services.AddGlobalSettingsServices(Configuration, Environment); - if (!globalSettings.SelfHosted) + public Startup(IWebHostEnvironment env, IConfiguration configuration) { - services.Configure(Configuration.GetSection("IpRateLimitOptions")); - services.Configure(Configuration.GetSection("IpRateLimitPolicies")); + CultureInfo.DefaultThreadCurrentCulture = new CultureInfo("en-US"); + Configuration = configuration; + Environment = env; } - // Data Protection - services.AddCustomDataProtectionServices(Environment, globalSettings); + public IConfiguration Configuration { get; private set; } + public IWebHostEnvironment Environment { get; set; } - // Repositories - services.AddSqlServerRepositories(globalSettings); - - // Context - services.AddScoped(); - services.TryAddSingleton(); - - // Caching - services.AddMemoryCache(); - services.AddDistributedCache(globalSettings); - - // Mvc - services.AddMvc(config => + public void ConfigureServices(IServiceCollection services) { - config.Filters.Add(new ModelStateValidationFilterAttribute()); - }); + // Options + services.AddOptions(); - services.AddSwaggerGen(c => - { - c.SwaggerDoc("v1", new OpenApiInfo { Title = "Bitwarden Identity", Version = "v1" }); - }); - - if (!globalSettings.SelfHosted) - { - services.AddIpRateLimiting(globalSettings); - } - - // Cookies - if (Environment.IsDevelopment()) - { - services.Configure(options => + // Settings + var globalSettings = services.AddGlobalSettingsServices(Configuration, Environment); + if (!globalSettings.SelfHosted) { - options.MinimumSameSitePolicy = Microsoft.AspNetCore.Http.SameSiteMode.Unspecified; - options.OnAppendCookie = ctx => - { - ctx.CookieOptions.SameSite = Microsoft.AspNetCore.Http.SameSiteMode.Unspecified; - }; + services.Configure(Configuration.GetSection("IpRateLimitOptions")); + services.Configure(Configuration.GetSection("IpRateLimitPolicies")); + } + + // Data Protection + services.AddCustomDataProtectionServices(Environment, globalSettings); + + // Repositories + services.AddSqlServerRepositories(globalSettings); + + // Context + services.AddScoped(); + services.TryAddSingleton(); + + // Caching + services.AddMemoryCache(); + services.AddDistributedCache(globalSettings); + + // Mvc + services.AddMvc(config => + { + config.Filters.Add(new ModelStateValidationFilterAttribute()); }); - } - JwtSecurityTokenHandler.DefaultMapInboundClaims = false; - - // Authentication - services - .AddDistributedIdentityServices(globalSettings) - .AddAuthentication() - .AddCookie(AuthenticationSchemes.BitwardenExternalCookieAuthenticationScheme) - .AddOpenIdConnect("sso", "Single Sign On", options => + services.AddSwaggerGen(c => { - options.Authority = globalSettings.BaseServiceUri.InternalSso; - options.RequireHttpsMetadata = !Environment.IsDevelopment() && - globalSettings.BaseServiceUri.InternalIdentity.StartsWith("https"); - options.ClientId = "oidc-identity"; - options.ClientSecret = globalSettings.OidcIdentityClientKey; - options.ResponseMode = "form_post"; + c.SwaggerDoc("v1", new OpenApiInfo { Title = "Bitwarden Identity", Version = "v1" }); + }); - options.SignInScheme = AuthenticationSchemes.BitwardenExternalCookieAuthenticationScheme; - options.ResponseType = "code"; - options.SaveTokens = false; - options.GetClaimsFromUserInfoEndpoint = true; + if (!globalSettings.SelfHosted) + { + services.AddIpRateLimiting(globalSettings); + } - options.Events = new Microsoft.AspNetCore.Authentication.OpenIdConnect.OpenIdConnectEvents + // Cookies + if (Environment.IsDevelopment()) + { + services.Configure(options => { - OnRedirectToIdentityProvider = context => + options.MinimumSameSitePolicy = Microsoft.AspNetCore.Http.SameSiteMode.Unspecified; + options.OnAppendCookie = ctx => { - // Pass domain_hint onto the sso idp - context.ProtocolMessage.DomainHint = context.Properties.Items["domain_hint"]; - context.ProtocolMessage.Parameters.Add("organizationId", context.Properties.Items["organizationId"]); - if (context.Properties.Items.ContainsKey("user_identifier")) + ctx.CookieOptions.SameSite = Microsoft.AspNetCore.Http.SameSiteMode.Unspecified; + }; + }); + } + + JwtSecurityTokenHandler.DefaultMapInboundClaims = false; + + // Authentication + services + .AddDistributedIdentityServices(globalSettings) + .AddAuthentication() + .AddCookie(AuthenticationSchemes.BitwardenExternalCookieAuthenticationScheme) + .AddOpenIdConnect("sso", "Single Sign On", options => + { + options.Authority = globalSettings.BaseServiceUri.InternalSso; + options.RequireHttpsMetadata = !Environment.IsDevelopment() && + globalSettings.BaseServiceUri.InternalIdentity.StartsWith("https"); + options.ClientId = "oidc-identity"; + options.ClientSecret = globalSettings.OidcIdentityClientKey; + options.ResponseMode = "form_post"; + + options.SignInScheme = AuthenticationSchemes.BitwardenExternalCookieAuthenticationScheme; + options.ResponseType = "code"; + options.SaveTokens = false; + options.GetClaimsFromUserInfoEndpoint = true; + + options.Events = new Microsoft.AspNetCore.Authentication.OpenIdConnect.OpenIdConnectEvents + { + OnRedirectToIdentityProvider = context => { - context.ProtocolMessage.SessionState = context.Properties.Items["user_identifier"]; + // Pass domain_hint onto the sso idp + context.ProtocolMessage.DomainHint = context.Properties.Items["domain_hint"]; + context.ProtocolMessage.Parameters.Add("organizationId", context.Properties.Items["organizationId"]); + if (context.Properties.Items.ContainsKey("user_identifier")) + { + context.ProtocolMessage.SessionState = context.Properties.Items["user_identifier"]; + } + + if (context.Properties.Parameters.Count > 0 && + context.Properties.Parameters.TryGetValue(SsoTokenable.TokenIdentifier, out var tokenValue)) + { + var token = tokenValue?.ToString() ?? ""; + context.ProtocolMessage.Parameters.Add(SsoTokenable.TokenIdentifier, token); + } + return Task.FromResult(0); } + }; + }); - if (context.Properties.Parameters.Count > 0 && - context.Properties.Parameters.TryGetValue(SsoTokenable.TokenIdentifier, out var tokenValue)) - { - var token = tokenValue?.ToString() ?? ""; - context.ProtocolMessage.Parameters.Add(SsoTokenable.TokenIdentifier, token); - } - return Task.FromResult(0); - } - }; - }); + // IdentityServer + services.AddCustomIdentityServerServices(Environment, globalSettings); - // IdentityServer - services.AddCustomIdentityServerServices(Environment, globalSettings); + // Identity + services.AddCustomIdentityServices(globalSettings); - // Identity - services.AddCustomIdentityServices(globalSettings); + // Services + services.AddBaseServices(globalSettings); + services.AddDefaultServices(globalSettings); + services.AddCoreLocalizationServices(); - // Services - services.AddBaseServices(globalSettings); - services.AddDefaultServices(globalSettings); - services.AddCoreLocalizationServices(); - - if (CoreHelpers.SettingHasValue(globalSettings.ServiceBus.ConnectionString) && - CoreHelpers.SettingHasValue(globalSettings.ServiceBus.ApplicationCacheTopicName)) - { - services.AddHostedService(); - } - - // HttpClients - services.AddHttpClient("InternalSso", client => - { - client.BaseAddress = new Uri(globalSettings.BaseServiceUri.InternalSso); - }); - } - - public void Configure( - IApplicationBuilder app, - IWebHostEnvironment env, - IHostApplicationLifetime appLifetime, - GlobalSettings globalSettings, - ILogger logger) - { - IdentityModelEventSource.ShowPII = true; - - app.UseSerilog(env, appLifetime, globalSettings); - - // Add general security headers - app.UseMiddleware(); - - if (!env.IsDevelopment()) - { - var uri = new Uri(globalSettings.BaseServiceUri.Identity); - app.Use(async (ctx, next) => + if (CoreHelpers.SettingHasValue(globalSettings.ServiceBus.ConnectionString) && + CoreHelpers.SettingHasValue(globalSettings.ServiceBus.ApplicationCacheTopicName)) { - ctx.SetIdentityServerOrigin($"{uri.Scheme}://{uri.Host}"); - await next(); + services.AddHostedService(); + } + + // HttpClients + services.AddHttpClient("InternalSso", client => + { + client.BaseAddress = new Uri(globalSettings.BaseServiceUri.InternalSso); }); } - if (globalSettings.SelfHosted) + public void Configure( + IApplicationBuilder app, + IWebHostEnvironment env, + IHostApplicationLifetime appLifetime, + GlobalSettings globalSettings, + ILogger logger) { - app.UsePathBase("/identity"); - app.UseForwardedHeaders(globalSettings); + IdentityModelEventSource.ShowPII = true; + + app.UseSerilog(env, appLifetime, globalSettings); + + // Add general security headers + app.UseMiddleware(); + + if (!env.IsDevelopment()) + { + var uri = new Uri(globalSettings.BaseServiceUri.Identity); + app.Use(async (ctx, next) => + { + ctx.SetIdentityServerOrigin($"{uri.Scheme}://{uri.Host}"); + await next(); + }); + } + + if (globalSettings.SelfHosted) + { + app.UsePathBase("/identity"); + app.UseForwardedHeaders(globalSettings); + } + + // Default Middleware + app.UseDefaultMiddleware(env, globalSettings); + + if (!globalSettings.SelfHosted) + { + // Rate limiting + app.UseMiddleware(); + } + + if (env.IsDevelopment()) + { + app.UseSwagger(); + app.UseDeveloperExceptionPage(); + app.UseCookiePolicy(); + } + + // Add localization + app.UseCoreLocalization(); + + // Add static files to the request pipeline. + app.UseStaticFiles(); + + // Add routing + app.UseRouting(); + + // Add Cors + app.UseCors(policy => policy.SetIsOriginAllowed(o => CoreHelpers.IsCorsOriginAllowed(o, globalSettings)) + .AllowAnyMethod().AllowAnyHeader().AllowCredentials()); + + // Add current context + app.UseMiddleware(); + + // Add IdentityServer to the request pipeline. + app.UseIdentityServer(); + + // Add Mvc stuff + app.UseEndpoints(endpoints => endpoints.MapDefaultControllerRoute()); + + // Log startup + logger.LogInformation(Constants.BypassFiltersEventId, globalSettings.ProjectName + " started."); } - - // Default Middleware - app.UseDefaultMiddleware(env, globalSettings); - - if (!globalSettings.SelfHosted) - { - // Rate limiting - app.UseMiddleware(); - } - - if (env.IsDevelopment()) - { - app.UseSwagger(); - app.UseDeveloperExceptionPage(); - app.UseCookiePolicy(); - } - - // Add localization - app.UseCoreLocalization(); - - // Add static files to the request pipeline. - app.UseStaticFiles(); - - // Add routing - app.UseRouting(); - - // Add Cors - app.UseCors(policy => policy.SetIsOriginAllowed(o => CoreHelpers.IsCorsOriginAllowed(o, globalSettings)) - .AllowAnyMethod().AllowAnyHeader().AllowCredentials()); - - // Add current context - app.UseMiddleware(); - - // Add IdentityServer to the request pipeline. - app.UseIdentityServer(); - - // Add Mvc stuff - app.UseEndpoints(endpoints => endpoints.MapDefaultControllerRoute()); - - // Log startup - logger.LogInformation(Constants.BypassFiltersEventId, globalSettings.ProjectName + " started."); } } diff --git a/src/Identity/Utilities/DiscoveryResponseGenerator.cs b/src/Identity/Utilities/DiscoveryResponseGenerator.cs index da06180989..32a5e6ddb7 100644 --- a/src/Identity/Utilities/DiscoveryResponseGenerator.cs +++ b/src/Identity/Utilities/DiscoveryResponseGenerator.cs @@ -5,31 +5,32 @@ using IdentityServer4.Services; using IdentityServer4.Stores; using IdentityServer4.Validation; -namespace Bit.Identity.Utilities; - -public class DiscoveryResponseGenerator : IdentityServer4.ResponseHandling.DiscoveryResponseGenerator +namespace Bit.Identity.Utilities { - private readonly GlobalSettings _globalSettings; - - public DiscoveryResponseGenerator( - IdentityServerOptions options, - IResourceStore resourceStore, - IKeyMaterialService keys, - ExtensionGrantValidator extensionGrants, - ISecretsListParser secretParsers, - IResourceOwnerPasswordValidator resourceOwnerValidator, - ILogger logger, - GlobalSettings globalSettings) - : base(options, resourceStore, keys, extensionGrants, secretParsers, resourceOwnerValidator, logger) + public class DiscoveryResponseGenerator : IdentityServer4.ResponseHandling.DiscoveryResponseGenerator { - _globalSettings = globalSettings; - } + private readonly GlobalSettings _globalSettings; - public override async Task> CreateDiscoveryDocumentAsync( - string baseUrl, string issuerUri) - { - var dict = await base.CreateDiscoveryDocumentAsync(baseUrl, issuerUri); - return CoreHelpers.AdjustIdentityServerConfig(dict, _globalSettings.BaseServiceUri.Identity, - _globalSettings.BaseServiceUri.InternalIdentity); + public DiscoveryResponseGenerator( + IdentityServerOptions options, + IResourceStore resourceStore, + IKeyMaterialService keys, + ExtensionGrantValidator extensionGrants, + ISecretsListParser secretParsers, + IResourceOwnerPasswordValidator resourceOwnerValidator, + ILogger logger, + GlobalSettings globalSettings) + : base(options, resourceStore, keys, extensionGrants, secretParsers, resourceOwnerValidator, logger) + { + _globalSettings = globalSettings; + } + + public override async Task> CreateDiscoveryDocumentAsync( + string baseUrl, string issuerUri) + { + var dict = await base.CreateDiscoveryDocumentAsync(baseUrl, issuerUri); + return CoreHelpers.AdjustIdentityServerConfig(dict, _globalSettings.BaseServiceUri.Identity, + _globalSettings.BaseServiceUri.InternalIdentity); + } } } diff --git a/src/Identity/Utilities/ServiceCollectionExtensions.cs b/src/Identity/Utilities/ServiceCollectionExtensions.cs index df3a6dec82..82000ebcf5 100644 --- a/src/Identity/Utilities/ServiceCollectionExtensions.cs +++ b/src/Identity/Utilities/ServiceCollectionExtensions.cs @@ -5,47 +5,48 @@ using IdentityServer4.ResponseHandling; using IdentityServer4.Services; using IdentityServer4.Stores; -namespace Bit.Identity.Utilities; - -public static class ServiceCollectionExtensions +namespace Bit.Identity.Utilities { - public static IIdentityServerBuilder AddCustomIdentityServerServices(this IServiceCollection services, - IWebHostEnvironment env, GlobalSettings globalSettings) + public static class ServiceCollectionExtensions { - services.AddTransient(); + public static IIdentityServerBuilder AddCustomIdentityServerServices(this IServiceCollection services, + IWebHostEnvironment env, GlobalSettings globalSettings) + { + services.AddTransient(); - services.AddSingleton(); - services.AddTransient(); + services.AddSingleton(); + services.AddTransient(); - var issuerUri = new Uri(globalSettings.BaseServiceUri.InternalIdentity); - var identityServerBuilder = services - .AddIdentityServer(options => - { - options.Endpoints.EnableIntrospectionEndpoint = false; - options.Endpoints.EnableEndSessionEndpoint = false; - options.Endpoints.EnableUserInfoEndpoint = false; - options.Endpoints.EnableCheckSessionEndpoint = false; - options.Endpoints.EnableTokenRevocationEndpoint = false; - options.IssuerUri = $"{issuerUri.Scheme}://{issuerUri.Host}"; - options.Caching.ClientStoreExpiration = new TimeSpan(0, 5, 0); - if (env.IsDevelopment()) + var issuerUri = new Uri(globalSettings.BaseServiceUri.InternalIdentity); + var identityServerBuilder = services + .AddIdentityServer(options => { - options.Authentication.CookieSameSiteMode = Microsoft.AspNetCore.Http.SameSiteMode.Unspecified; - } - options.InputLengthRestrictions.UserName = 256; - }) - .AddInMemoryCaching() - .AddInMemoryApiResources(ApiResources.GetApiResources()) - .AddInMemoryApiScopes(ApiScopes.GetApiScopes()) - .AddClientStoreCache() - .AddCustomTokenRequestValidator() - .AddProfileService() - .AddResourceOwnerValidator() - .AddPersistedGrantStore() - .AddClientStore() - .AddIdentityServerCertificate(env, globalSettings); + options.Endpoints.EnableIntrospectionEndpoint = false; + options.Endpoints.EnableEndSessionEndpoint = false; + options.Endpoints.EnableUserInfoEndpoint = false; + options.Endpoints.EnableCheckSessionEndpoint = false; + options.Endpoints.EnableTokenRevocationEndpoint = false; + options.IssuerUri = $"{issuerUri.Scheme}://{issuerUri.Host}"; + options.Caching.ClientStoreExpiration = new TimeSpan(0, 5, 0); + if (env.IsDevelopment()) + { + options.Authentication.CookieSameSiteMode = Microsoft.AspNetCore.Http.SameSiteMode.Unspecified; + } + options.InputLengthRestrictions.UserName = 256; + }) + .AddInMemoryCaching() + .AddInMemoryApiResources(ApiResources.GetApiResources()) + .AddInMemoryApiScopes(ApiScopes.GetApiScopes()) + .AddClientStoreCache() + .AddCustomTokenRequestValidator() + .AddProfileService() + .AddResourceOwnerValidator() + .AddPersistedGrantStore() + .AddClientStore() + .AddIdentityServerCertificate(env, globalSettings); - services.AddTransient(); - return identityServerBuilder; + services.AddTransient(); + return identityServerBuilder; + } } } diff --git a/src/Infrastructure.Dapper/DapperHelpers.cs b/src/Infrastructure.Dapper/DapperHelpers.cs index 48949df671..7203556499 100644 --- a/src/Infrastructure.Dapper/DapperHelpers.cs +++ b/src/Infrastructure.Dapper/DapperHelpers.cs @@ -3,132 +3,133 @@ using Bit.Core.Entities; using Bit.Core.Models.Data; using Dapper; -namespace Bit.Infrastructure.Dapper; - -public static class DapperHelpers +namespace Bit.Infrastructure.Dapper { - public static DataTable ToGuidIdArrayTVP(this IEnumerable ids) + public static class DapperHelpers { - return ids.ToArrayTVP("GuidId"); - } - - public static DataTable ToArrayTVP(this IEnumerable values, string columnName) - { - var table = new DataTable(); - table.SetTypeName($"[dbo].[{columnName}Array]"); - table.Columns.Add(columnName, typeof(T)); - - if (values != null) + public static DataTable ToGuidIdArrayTVP(this IEnumerable ids) { - foreach (var value in values) + return ids.ToArrayTVP("GuidId"); + } + + public static DataTable ToArrayTVP(this IEnumerable values, string columnName) + { + var table = new DataTable(); + table.SetTypeName($"[dbo].[{columnName}Array]"); + table.Columns.Add(columnName, typeof(T)); + + if (values != null) { - table.Rows.Add(value); + foreach (var value in values) + { + table.Rows.Add(value); + } } + + return table; } - return table; - } - - public static DataTable ToArrayTVP(this IEnumerable values) - { - var table = new DataTable(); - table.SetTypeName("[dbo].[SelectionReadOnlyArray]"); - - var idColumn = new DataColumn("Id", typeof(Guid)); - table.Columns.Add(idColumn); - var readOnlyColumn = new DataColumn("ReadOnly", typeof(bool)); - table.Columns.Add(readOnlyColumn); - var hidePasswordsColumn = new DataColumn("HidePasswords", typeof(bool)); - table.Columns.Add(hidePasswordsColumn); - - if (values != null) + public static DataTable ToArrayTVP(this IEnumerable values) { - foreach (var value in values) + var table = new DataTable(); + table.SetTypeName("[dbo].[SelectionReadOnlyArray]"); + + var idColumn = new DataColumn("Id", typeof(Guid)); + table.Columns.Add(idColumn); + var readOnlyColumn = new DataColumn("ReadOnly", typeof(bool)); + table.Columns.Add(readOnlyColumn); + var hidePasswordsColumn = new DataColumn("HidePasswords", typeof(bool)); + table.Columns.Add(hidePasswordsColumn); + + if (values != null) { - var row = table.NewRow(); - row[idColumn] = value.Id; - row[readOnlyColumn] = value.ReadOnly; - row[hidePasswordsColumn] = value.HidePasswords; - table.Rows.Add(row); + foreach (var value in values) + { + var row = table.NewRow(); + row[idColumn] = value.Id; + row[readOnlyColumn] = value.ReadOnly; + row[hidePasswordsColumn] = value.HidePasswords; + table.Rows.Add(row); + } } + + return table; } - return table; - } - - public static DataTable ToTvp(this IEnumerable orgUsers) - { - var table = new DataTable(); - table.SetTypeName("[dbo].[OrganizationUserType]"); - - var columnData = new List<(string name, Type type, Func getter)> + public static DataTable ToTvp(this IEnumerable orgUsers) { - (nameof(OrganizationUser.Id), typeof(Guid), ou => ou.Id), - (nameof(OrganizationUser.OrganizationId), typeof(Guid), ou => ou.OrganizationId), - (nameof(OrganizationUser.UserId), typeof(Guid), ou => ou.UserId), - (nameof(OrganizationUser.Email), typeof(string), ou => ou.Email), - (nameof(OrganizationUser.Key), typeof(string), ou => ou.Key), - (nameof(OrganizationUser.Status), typeof(byte), ou => ou.Status), - (nameof(OrganizationUser.Type), typeof(byte), ou => ou.Type), - (nameof(OrganizationUser.AccessAll), typeof(bool), ou => ou.AccessAll), - (nameof(OrganizationUser.ExternalId), typeof(string), ou => ou.ExternalId), - (nameof(OrganizationUser.CreationDate), typeof(DateTime), ou => ou.CreationDate), - (nameof(OrganizationUser.RevisionDate), typeof(DateTime), ou => ou.RevisionDate), - (nameof(OrganizationUser.Permissions), typeof(string), ou => ou.Permissions), - (nameof(OrganizationUser.ResetPasswordKey), typeof(string), ou => ou.ResetPasswordKey), - }; + var table = new DataTable(); + table.SetTypeName("[dbo].[OrganizationUserType]"); - return orgUsers.BuildTable(table, columnData); - } + var columnData = new List<(string name, Type type, Func getter)> + { + (nameof(OrganizationUser.Id), typeof(Guid), ou => ou.Id), + (nameof(OrganizationUser.OrganizationId), typeof(Guid), ou => ou.OrganizationId), + (nameof(OrganizationUser.UserId), typeof(Guid), ou => ou.UserId), + (nameof(OrganizationUser.Email), typeof(string), ou => ou.Email), + (nameof(OrganizationUser.Key), typeof(string), ou => ou.Key), + (nameof(OrganizationUser.Status), typeof(byte), ou => ou.Status), + (nameof(OrganizationUser.Type), typeof(byte), ou => ou.Type), + (nameof(OrganizationUser.AccessAll), typeof(bool), ou => ou.AccessAll), + (nameof(OrganizationUser.ExternalId), typeof(string), ou => ou.ExternalId), + (nameof(OrganizationUser.CreationDate), typeof(DateTime), ou => ou.CreationDate), + (nameof(OrganizationUser.RevisionDate), typeof(DateTime), ou => ou.RevisionDate), + (nameof(OrganizationUser.Permissions), typeof(string), ou => ou.Permissions), + (nameof(OrganizationUser.ResetPasswordKey), typeof(string), ou => ou.ResetPasswordKey), + }; - public static DataTable ToTvp(this IEnumerable organizationSponsorships) - { - var table = new DataTable(); - table.SetTypeName("[dbo].[OrganizationSponsorshipType]"); - - var columnData = new List<(string name, Type type, Func getter)> - { - (nameof(OrganizationSponsorship.Id), typeof(Guid), ou => ou.Id), - (nameof(OrganizationSponsorship.SponsoringOrganizationId), typeof(Guid), ou => ou.SponsoringOrganizationId), - (nameof(OrganizationSponsorship.SponsoringOrganizationUserId), typeof(Guid), ou => ou.SponsoringOrganizationUserId), - (nameof(OrganizationSponsorship.SponsoredOrganizationId), typeof(Guid), ou => ou.SponsoredOrganizationId), - (nameof(OrganizationSponsorship.FriendlyName), typeof(string), ou => ou.FriendlyName), - (nameof(OrganizationSponsorship.OfferedToEmail), typeof(string), ou => ou.OfferedToEmail), - (nameof(OrganizationSponsorship.PlanSponsorshipType), typeof(byte), ou => ou.PlanSponsorshipType), - (nameof(OrganizationSponsorship.LastSyncDate), typeof(DateTime), ou => ou.LastSyncDate), - (nameof(OrganizationSponsorship.ValidUntil), typeof(DateTime), ou => ou.ValidUntil), - (nameof(OrganizationSponsorship.ToDelete), typeof(bool), ou => ou.ToDelete), - }; - - return organizationSponsorships.BuildTable(table, columnData); - } - - private static DataTable BuildTable(this IEnumerable entities, DataTable table, List<(string name, Type type, Func getter)> columnData) - { - foreach (var (name, type, getter) in columnData) - { - var column = new DataColumn(name, type); - table.Columns.Add(column); + return orgUsers.BuildTable(table, columnData); } - foreach (var entity in entities ?? new T[] { }) + public static DataTable ToTvp(this IEnumerable organizationSponsorships) + { + var table = new DataTable(); + table.SetTypeName("[dbo].[OrganizationSponsorshipType]"); + + var columnData = new List<(string name, Type type, Func getter)> + { + (nameof(OrganizationSponsorship.Id), typeof(Guid), ou => ou.Id), + (nameof(OrganizationSponsorship.SponsoringOrganizationId), typeof(Guid), ou => ou.SponsoringOrganizationId), + (nameof(OrganizationSponsorship.SponsoringOrganizationUserId), typeof(Guid), ou => ou.SponsoringOrganizationUserId), + (nameof(OrganizationSponsorship.SponsoredOrganizationId), typeof(Guid), ou => ou.SponsoredOrganizationId), + (nameof(OrganizationSponsorship.FriendlyName), typeof(string), ou => ou.FriendlyName), + (nameof(OrganizationSponsorship.OfferedToEmail), typeof(string), ou => ou.OfferedToEmail), + (nameof(OrganizationSponsorship.PlanSponsorshipType), typeof(byte), ou => ou.PlanSponsorshipType), + (nameof(OrganizationSponsorship.LastSyncDate), typeof(DateTime), ou => ou.LastSyncDate), + (nameof(OrganizationSponsorship.ValidUntil), typeof(DateTime), ou => ou.ValidUntil), + (nameof(OrganizationSponsorship.ToDelete), typeof(bool), ou => ou.ToDelete), + }; + + return organizationSponsorships.BuildTable(table, columnData); + } + + private static DataTable BuildTable(this IEnumerable entities, DataTable table, List<(string name, Type type, Func getter)> columnData) { - var row = table.NewRow(); foreach (var (name, type, getter) in columnData) { - var val = getter(entity); - if (val == null) - { - row[name] = DBNull.Value; - } - else - { - row[name] = val; - } + var column = new DataColumn(name, type); + table.Columns.Add(column); } - table.Rows.Add(row); - } - return table; + foreach (var entity in entities ?? new T[] { }) + { + var row = table.NewRow(); + foreach (var (name, type, getter) in columnData) + { + var val = getter(entity); + if (val == null) + { + row[name] = DBNull.Value; + } + else + { + row[name] = val; + } + } + table.Rows.Add(row); + } + + return table; + } } } diff --git a/src/Infrastructure.Dapper/DapperServiceCollectionExtensions.cs b/src/Infrastructure.Dapper/DapperServiceCollectionExtensions.cs index 9c138f7b02..a75a05fcc1 100644 --- a/src/Infrastructure.Dapper/DapperServiceCollectionExtensions.cs +++ b/src/Infrastructure.Dapper/DapperServiceCollectionExtensions.cs @@ -2,42 +2,43 @@ using Bit.Infrastructure.Dapper.Repositories; using Microsoft.Extensions.DependencyInjection; -namespace Bit.Infrastructure.Dapper; - -public static class DapperServiceCollectionExtensions +namespace Bit.Infrastructure.Dapper { - public static void AddDapperRepositories(this IServiceCollection services, bool selfHosted) + public static class DapperServiceCollectionExtensions { - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - - if (selfHosted) + public static void AddDapperRepositories(this IServiceCollection services, bool selfHosted) { - services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + + if (selfHosted) + { + services.AddSingleton(); + } } } } diff --git a/src/Infrastructure.Dapper/Repositories/BaseRepository.cs b/src/Infrastructure.Dapper/Repositories/BaseRepository.cs index 4a3694d859..135f024a6c 100644 --- a/src/Infrastructure.Dapper/Repositories/BaseRepository.cs +++ b/src/Infrastructure.Dapper/Repositories/BaseRepository.cs @@ -1,29 +1,30 @@ using Dapper; -namespace Bit.Infrastructure.Dapper.Repositories; - -public abstract class BaseRepository +namespace Bit.Infrastructure.Dapper.Repositories { - static BaseRepository() + public abstract class BaseRepository { - SqlMapper.AddTypeHandler(new DateTimeHandler()); - } - - public BaseRepository(string connectionString, string readOnlyConnectionString) - { - if (string.IsNullOrWhiteSpace(connectionString)) + static BaseRepository() { - throw new ArgumentNullException(nameof(connectionString)); - } - if (string.IsNullOrWhiteSpace(readOnlyConnectionString)) - { - throw new ArgumentNullException(nameof(readOnlyConnectionString)); + SqlMapper.AddTypeHandler(new DateTimeHandler()); } - ConnectionString = connectionString; - ReadOnlyConnectionString = readOnlyConnectionString; - } + public BaseRepository(string connectionString, string readOnlyConnectionString) + { + if (string.IsNullOrWhiteSpace(connectionString)) + { + throw new ArgumentNullException(nameof(connectionString)); + } + if (string.IsNullOrWhiteSpace(readOnlyConnectionString)) + { + throw new ArgumentNullException(nameof(readOnlyConnectionString)); + } - protected string ConnectionString { get; private set; } - protected string ReadOnlyConnectionString { get; private set; } + ConnectionString = connectionString; + ReadOnlyConnectionString = readOnlyConnectionString; + } + + protected string ConnectionString { get; private set; } + protected string ReadOnlyConnectionString { get; private set; } + } } diff --git a/src/Infrastructure.Dapper/Repositories/CipherRepository.cs b/src/Infrastructure.Dapper/Repositories/CipherRepository.cs index a2b757a712..33560aad5e 100644 --- a/src/Infrastructure.Dapper/Repositories/CipherRepository.cs +++ b/src/Infrastructure.Dapper/Repositories/CipherRepository.cs @@ -8,325 +8,325 @@ using Bit.Core.Settings; using Core.Models.Data; using Dapper; -namespace Bit.Infrastructure.Dapper.Repositories; - -public class CipherRepository : Repository, ICipherRepository +namespace Bit.Infrastructure.Dapper.Repositories { - public CipherRepository(GlobalSettings globalSettings) - : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) - { } - - public CipherRepository(string connectionString, string readOnlyConnectionString) - : base(connectionString, readOnlyConnectionString) - { } - - public async Task GetByIdAsync(Guid id, Guid userId) + public class CipherRepository : Repository, ICipherRepository { - using (var connection = new SqlConnection(ConnectionString)) + public CipherRepository(GlobalSettings globalSettings) + : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) + { } + + public CipherRepository(string connectionString, string readOnlyConnectionString) + : base(connectionString, readOnlyConnectionString) + { } + + public async Task GetByIdAsync(Guid id, Guid userId) { - var results = await connection.QueryAsync( - $"[{Schema}].[CipherDetails_ReadByIdUserId]", - new { Id = id, UserId = userId }, - commandType: CommandType.StoredProcedure); - - return results.FirstOrDefault(); - } - } - - public async Task GetOrganizationDetailsByIdAsync(Guid id) - { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - $"[{Schema}].[CipherOrganizationDetails_ReadById]", - new { Id = id }, - commandType: CommandType.StoredProcedure); - - return results.FirstOrDefault(); - } - } - - public async Task> GetManyOrganizationDetailsByOrganizationIdAsync( - Guid organizationId) - { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - $"[{Schema}].[CipherOrganizationDetails_ReadByOrganizationId]", - new { OrganizationId = organizationId }, - commandType: CommandType.StoredProcedure); - - return results.ToList(); - } - } - - public async Task GetCanEditByIdAsync(Guid userId, Guid cipherId) - { - using (var connection = new SqlConnection(ConnectionString)) - { - var result = await connection.QueryFirstOrDefaultAsync( - $"[{Schema}].[Cipher_ReadCanEditByIdUserId]", - new { UserId = userId, Id = cipherId }, - commandType: CommandType.StoredProcedure); - - return result; - } - } - - public async Task> GetManyByUserIdAsync(Guid userId, bool withOrganizations = true) - { - string sprocName = null; - if (withOrganizations) - { - sprocName = $"[{Schema}].[CipherDetails_ReadByUserId]"; - } - else - { - sprocName = $"[{Schema}].[CipherDetails_ReadWithoutOrganizationsByUserId]"; - } - - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - sprocName, - new { UserId = userId }, - commandType: CommandType.StoredProcedure); - - return results - .GroupBy(c => c.Id) - .Select(g => g.OrderByDescending(og => og.Edit).First()) - .ToList(); - } - } - - public async Task> GetManyByOrganizationIdAsync(Guid organizationId) - { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - $"[{Schema}].[Cipher_ReadByOrganizationId]", - new { OrganizationId = organizationId }, - commandType: CommandType.StoredProcedure); - - return results.ToList(); - } - } - - public async Task CreateAsync(Cipher cipher, IEnumerable collectionIds) - { - cipher.SetNewId(); - var objWithCollections = JsonSerializer.Deserialize( - JsonSerializer.Serialize(cipher)); - objWithCollections.CollectionIds = collectionIds.ToGuidIdArrayTVP(); - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.ExecuteAsync( - $"[{Schema}].[Cipher_CreateWithCollections]", - objWithCollections, - commandType: CommandType.StoredProcedure); - } - } - - public async Task CreateAsync(CipherDetails cipher) - { - cipher.SetNewId(); - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.ExecuteAsync( - $"[{Schema}].[CipherDetails_Create]", - cipher, - commandType: CommandType.StoredProcedure); - } - } - - public async Task CreateAsync(CipherDetails cipher, IEnumerable collectionIds) - { - cipher.SetNewId(); - var objWithCollections = JsonSerializer.Deserialize( - JsonSerializer.Serialize(cipher)); - objWithCollections.CollectionIds = collectionIds.ToGuidIdArrayTVP(); - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.ExecuteAsync( - $"[{Schema}].[CipherDetails_CreateWithCollections]", - objWithCollections, - commandType: CommandType.StoredProcedure); - } - } - - public async Task ReplaceAsync(CipherDetails obj) - { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.ExecuteAsync( - $"[{Schema}].[CipherDetails_Update]", - obj, - commandType: CommandType.StoredProcedure); - } - } - - public async Task UpsertAsync(CipherDetails cipher) - { - if (cipher.Id.Equals(default)) - { - await CreateAsync(cipher); - } - else - { - await ReplaceAsync(cipher); - } - } - - public async Task ReplaceAsync(Cipher obj, IEnumerable collectionIds) - { - var objWithCollections = JsonSerializer.Deserialize( - JsonSerializer.Serialize(obj)); - objWithCollections.CollectionIds = collectionIds.ToGuidIdArrayTVP(); - - using (var connection = new SqlConnection(ConnectionString)) - { - var result = await connection.ExecuteScalarAsync( - $"[{Schema}].[Cipher_UpdateWithCollections]", - objWithCollections, - commandType: CommandType.StoredProcedure); - return result >= 0; - } - } - - public async Task UpdatePartialAsync(Guid id, Guid userId, Guid? folderId, bool favorite) - { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.ExecuteAsync( - $"[{Schema}].[Cipher_UpdatePartial]", - new { Id = id, UserId = userId, FolderId = folderId, Favorite = favorite }, - commandType: CommandType.StoredProcedure); - } - } - - public async Task UpdateAttachmentAsync(CipherAttachment attachment) - { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.ExecuteAsync( - $"[{Schema}].[Cipher_UpdateAttachment]", - attachment, - commandType: CommandType.StoredProcedure); - } - } - - public async Task DeleteAttachmentAsync(Guid cipherId, string attachmentId) - { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.ExecuteAsync( - $"[{Schema}].[Cipher_DeleteAttachment]", - new { Id = cipherId, AttachmentId = attachmentId }, - commandType: CommandType.StoredProcedure); - } - } - - public async Task DeleteAsync(IEnumerable ids, Guid userId) - { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.ExecuteAsync( - $"[{Schema}].[Cipher_Delete]", - new { Ids = ids.ToGuidIdArrayTVP(), UserId = userId }, - commandType: CommandType.StoredProcedure); - } - } - - public async Task DeleteByIdsOrganizationIdAsync(IEnumerable ids, Guid organizationId) - { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.ExecuteAsync( - $"[{Schema}].[Cipher_DeleteByIdsOrganizationId]", - new { Ids = ids.ToGuidIdArrayTVP(), OrganizationId = organizationId }, - commandType: CommandType.StoredProcedure); - } - } - - public async Task SoftDeleteByIdsOrganizationIdAsync(IEnumerable ids, Guid organizationId) - { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.ExecuteAsync( - $"[{Schema}].[Cipher_SoftDeleteByIdsOrganizationId]", - new { Ids = ids.ToGuidIdArrayTVP(), OrganizationId = organizationId }, - commandType: CommandType.StoredProcedure); - } - } - - public async Task MoveAsync(IEnumerable ids, Guid? folderId, Guid userId) - { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.ExecuteAsync( - $"[{Schema}].[Cipher_Move]", - new { Ids = ids.ToGuidIdArrayTVP(), FolderId = folderId, UserId = userId }, - commandType: CommandType.StoredProcedure); - } - } - - public async Task DeleteByUserIdAsync(Guid userId) - { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.ExecuteAsync( - $"[{Schema}].[Cipher_DeleteByUserId]", - new { UserId = userId }, - commandType: CommandType.StoredProcedure); - } - } - - public async Task DeleteByOrganizationIdAsync(Guid organizationId) - { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.ExecuteAsync( - $"[{Schema}].[Cipher_DeleteByOrganizationId]", - new { OrganizationId = organizationId }, - commandType: CommandType.StoredProcedure); - } - } - - public Task UpdateUserKeysAndCiphersAsync(User user, IEnumerable ciphers, IEnumerable folders, IEnumerable sends) - { - using (var connection = new SqlConnection(ConnectionString)) - { - connection.Open(); - - using (var transaction = connection.BeginTransaction()) + using (var connection = new SqlConnection(ConnectionString)) { - try + var results = await connection.QueryAsync( + $"[{Schema}].[CipherDetails_ReadByIdUserId]", + new { Id = id, UserId = userId }, + commandType: CommandType.StoredProcedure); + + return results.FirstOrDefault(); + } + } + + public async Task GetOrganizationDetailsByIdAsync(Guid id) + { + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryAsync( + $"[{Schema}].[CipherOrganizationDetails_ReadById]", + new { Id = id }, + commandType: CommandType.StoredProcedure); + + return results.FirstOrDefault(); + } + } + + public async Task> GetManyOrganizationDetailsByOrganizationIdAsync( + Guid organizationId) + { + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryAsync( + $"[{Schema}].[CipherOrganizationDetails_ReadByOrganizationId]", + new { OrganizationId = organizationId }, + commandType: CommandType.StoredProcedure); + + return results.ToList(); + } + } + + public async Task GetCanEditByIdAsync(Guid userId, Guid cipherId) + { + using (var connection = new SqlConnection(ConnectionString)) + { + var result = await connection.QueryFirstOrDefaultAsync( + $"[{Schema}].[Cipher_ReadCanEditByIdUserId]", + new { UserId = userId, Id = cipherId }, + commandType: CommandType.StoredProcedure); + + return result; + } + } + + public async Task> GetManyByUserIdAsync(Guid userId, bool withOrganizations = true) + { + string sprocName = null; + if (withOrganizations) + { + sprocName = $"[{Schema}].[CipherDetails_ReadByUserId]"; + } + else + { + sprocName = $"[{Schema}].[CipherDetails_ReadWithoutOrganizationsByUserId]"; + } + + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryAsync( + sprocName, + new { UserId = userId }, + commandType: CommandType.StoredProcedure); + + return results + .GroupBy(c => c.Id) + .Select(g => g.OrderByDescending(og => og.Edit).First()) + .ToList(); + } + } + + public async Task> GetManyByOrganizationIdAsync(Guid organizationId) + { + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryAsync( + $"[{Schema}].[Cipher_ReadByOrganizationId]", + new { OrganizationId = organizationId }, + commandType: CommandType.StoredProcedure); + + return results.ToList(); + } + } + + public async Task CreateAsync(Cipher cipher, IEnumerable collectionIds) + { + cipher.SetNewId(); + var objWithCollections = JsonSerializer.Deserialize( + JsonSerializer.Serialize(cipher)); + objWithCollections.CollectionIds = collectionIds.ToGuidIdArrayTVP(); + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.ExecuteAsync( + $"[{Schema}].[Cipher_CreateWithCollections]", + objWithCollections, + commandType: CommandType.StoredProcedure); + } + } + + public async Task CreateAsync(CipherDetails cipher) + { + cipher.SetNewId(); + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.ExecuteAsync( + $"[{Schema}].[CipherDetails_Create]", + cipher, + commandType: CommandType.StoredProcedure); + } + } + + public async Task CreateAsync(CipherDetails cipher, IEnumerable collectionIds) + { + cipher.SetNewId(); + var objWithCollections = JsonSerializer.Deserialize( + JsonSerializer.Serialize(cipher)); + objWithCollections.CollectionIds = collectionIds.ToGuidIdArrayTVP(); + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.ExecuteAsync( + $"[{Schema}].[CipherDetails_CreateWithCollections]", + objWithCollections, + commandType: CommandType.StoredProcedure); + } + } + + public async Task ReplaceAsync(CipherDetails obj) + { + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.ExecuteAsync( + $"[{Schema}].[CipherDetails_Update]", + obj, + commandType: CommandType.StoredProcedure); + } + } + + public async Task UpsertAsync(CipherDetails cipher) + { + if (cipher.Id.Equals(default)) + { + await CreateAsync(cipher); + } + else + { + await ReplaceAsync(cipher); + } + } + + public async Task ReplaceAsync(Cipher obj, IEnumerable collectionIds) + { + var objWithCollections = JsonSerializer.Deserialize( + JsonSerializer.Serialize(obj)); + objWithCollections.CollectionIds = collectionIds.ToGuidIdArrayTVP(); + + using (var connection = new SqlConnection(ConnectionString)) + { + var result = await connection.ExecuteScalarAsync( + $"[{Schema}].[Cipher_UpdateWithCollections]", + objWithCollections, + commandType: CommandType.StoredProcedure); + return result >= 0; + } + } + + public async Task UpdatePartialAsync(Guid id, Guid userId, Guid? folderId, bool favorite) + { + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.ExecuteAsync( + $"[{Schema}].[Cipher_UpdatePartial]", + new { Id = id, UserId = userId, FolderId = folderId, Favorite = favorite }, + commandType: CommandType.StoredProcedure); + } + } + + public async Task UpdateAttachmentAsync(CipherAttachment attachment) + { + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.ExecuteAsync( + $"[{Schema}].[Cipher_UpdateAttachment]", + attachment, + commandType: CommandType.StoredProcedure); + } + } + + public async Task DeleteAttachmentAsync(Guid cipherId, string attachmentId) + { + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.ExecuteAsync( + $"[{Schema}].[Cipher_DeleteAttachment]", + new { Id = cipherId, AttachmentId = attachmentId }, + commandType: CommandType.StoredProcedure); + } + } + + public async Task DeleteAsync(IEnumerable ids, Guid userId) + { + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.ExecuteAsync( + $"[{Schema}].[Cipher_Delete]", + new { Ids = ids.ToGuidIdArrayTVP(), UserId = userId }, + commandType: CommandType.StoredProcedure); + } + } + + public async Task DeleteByIdsOrganizationIdAsync(IEnumerable ids, Guid organizationId) + { + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.ExecuteAsync( + $"[{Schema}].[Cipher_DeleteByIdsOrganizationId]", + new { Ids = ids.ToGuidIdArrayTVP(), OrganizationId = organizationId }, + commandType: CommandType.StoredProcedure); + } + } + + public async Task SoftDeleteByIdsOrganizationIdAsync(IEnumerable ids, Guid organizationId) + { + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.ExecuteAsync( + $"[{Schema}].[Cipher_SoftDeleteByIdsOrganizationId]", + new { Ids = ids.ToGuidIdArrayTVP(), OrganizationId = organizationId }, + commandType: CommandType.StoredProcedure); + } + } + + public async Task MoveAsync(IEnumerable ids, Guid? folderId, Guid userId) + { + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.ExecuteAsync( + $"[{Schema}].[Cipher_Move]", + new { Ids = ids.ToGuidIdArrayTVP(), FolderId = folderId, UserId = userId }, + commandType: CommandType.StoredProcedure); + } + } + + public async Task DeleteByUserIdAsync(Guid userId) + { + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.ExecuteAsync( + $"[{Schema}].[Cipher_DeleteByUserId]", + new { UserId = userId }, + commandType: CommandType.StoredProcedure); + } + } + + public async Task DeleteByOrganizationIdAsync(Guid organizationId) + { + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.ExecuteAsync( + $"[{Schema}].[Cipher_DeleteByOrganizationId]", + new { OrganizationId = organizationId }, + commandType: CommandType.StoredProcedure); + } + } + + public Task UpdateUserKeysAndCiphersAsync(User user, IEnumerable ciphers, IEnumerable folders, IEnumerable sends) + { + using (var connection = new SqlConnection(ConnectionString)) + { + connection.Open(); + + using (var transaction = connection.BeginTransaction()) { - // 1. Update user. - - using (var cmd = new SqlCommand("[dbo].[User_UpdateKeys]", connection, transaction)) + try { - cmd.CommandType = CommandType.StoredProcedure; - cmd.Parameters.Add("@Id", SqlDbType.UniqueIdentifier).Value = user.Id; - cmd.Parameters.Add("@SecurityStamp", SqlDbType.NVarChar).Value = user.SecurityStamp; - cmd.Parameters.Add("@Key", SqlDbType.VarChar).Value = user.Key; + // 1. Update user. - if (string.IsNullOrWhiteSpace(user.PrivateKey)) + using (var cmd = new SqlCommand("[dbo].[User_UpdateKeys]", connection, transaction)) { - cmd.Parameters.Add("@PrivateKey", SqlDbType.VarChar).Value = DBNull.Value; - } - else - { - cmd.Parameters.Add("@PrivateKey", SqlDbType.VarChar).Value = user.PrivateKey; + cmd.CommandType = CommandType.StoredProcedure; + cmd.Parameters.Add("@Id", SqlDbType.UniqueIdentifier).Value = user.Id; + cmd.Parameters.Add("@SecurityStamp", SqlDbType.NVarChar).Value = user.SecurityStamp; + cmd.Parameters.Add("@Key", SqlDbType.VarChar).Value = user.Key; + + if (string.IsNullOrWhiteSpace(user.PrivateKey)) + { + cmd.Parameters.Add("@PrivateKey", SqlDbType.VarChar).Value = DBNull.Value; + } + else + { + cmd.Parameters.Add("@PrivateKey", SqlDbType.VarChar).Value = user.PrivateKey; + } + + cmd.Parameters.Add("@RevisionDate", SqlDbType.DateTime2).Value = user.RevisionDate; + cmd.ExecuteNonQuery(); } - cmd.Parameters.Add("@RevisionDate", SqlDbType.DateTime2).Value = user.RevisionDate; - cmd.ExecuteNonQuery(); - } + // 2. Create temp tables to bulk copy into. - // 2. Create temp tables to bulk copy into. - - var sqlCreateTemp = @" + var sqlCreateTemp = @" SELECT TOP 0 * INTO #TempCipher FROM [dbo].[Cipher] @@ -339,50 +339,50 @@ public class CipherRepository : Repository, ICipherRepository INTO #TempSend FROM [dbo].[Send]"; - using (var cmd = new SqlCommand(sqlCreateTemp, connection, transaction)) - { - cmd.ExecuteNonQuery(); - } - - // 3. Bulk copy into temp tables. - - if (ciphers.Any()) - { - using (var bulkCopy = new SqlBulkCopy(connection, SqlBulkCopyOptions.KeepIdentity, transaction)) + using (var cmd = new SqlCommand(sqlCreateTemp, connection, transaction)) { - bulkCopy.DestinationTableName = "#TempCipher"; - var dataTable = BuildCiphersTable(bulkCopy, ciphers); - bulkCopy.WriteToServer(dataTable); + cmd.ExecuteNonQuery(); } - } - if (folders.Any()) - { - using (var bulkCopy = new SqlBulkCopy(connection, SqlBulkCopyOptions.KeepIdentity, transaction)) + // 3. Bulk copy into temp tables. + + if (ciphers.Any()) { - bulkCopy.DestinationTableName = "#TempFolder"; - var dataTable = BuildFoldersTable(bulkCopy, folders); - bulkCopy.WriteToServer(dataTable); + using (var bulkCopy = new SqlBulkCopy(connection, SqlBulkCopyOptions.KeepIdentity, transaction)) + { + bulkCopy.DestinationTableName = "#TempCipher"; + var dataTable = BuildCiphersTable(bulkCopy, ciphers); + bulkCopy.WriteToServer(dataTable); + } } - } - if (sends.Any()) - { - using (var bulkCopy = new SqlBulkCopy(connection, SqlBulkCopyOptions.KeepIdentity, transaction)) + if (folders.Any()) { - bulkCopy.DestinationTableName = "#TempSend"; - var dataTable = BuildSendsTable(bulkCopy, sends); - bulkCopy.WriteToServer(dataTable); + using (var bulkCopy = new SqlBulkCopy(connection, SqlBulkCopyOptions.KeepIdentity, transaction)) + { + bulkCopy.DestinationTableName = "#TempFolder"; + var dataTable = BuildFoldersTable(bulkCopy, folders); + bulkCopy.WriteToServer(dataTable); + } } - } - // 4. Insert into real tables from temp tables and clean up. + if (sends.Any()) + { + using (var bulkCopy = new SqlBulkCopy(connection, SqlBulkCopyOptions.KeepIdentity, transaction)) + { + bulkCopy.DestinationTableName = "#TempSend"; + var dataTable = BuildSendsTable(bulkCopy, sends); + bulkCopy.WriteToServer(dataTable); + } + } - var sql = string.Empty; + // 4. Insert into real tables from temp tables and clean up. - if (ciphers.Any()) - { - sql += @" + var sql = string.Empty; + + if (ciphers.Any()) + { + sql += @" UPDATE [dbo].[Cipher] SET @@ -395,11 +395,11 @@ public class CipherRepository : Repository, ICipherRepository #TempCipher TC ON C.Id = TC.Id WHERE C.[UserId] = @UserId"; - } + } - if (folders.Any()) - { - sql += @" + if (folders.Any()) + { + sql += @" UPDATE [dbo].[Folder] SET @@ -411,11 +411,11 @@ public class CipherRepository : Repository, ICipherRepository #TempFolder TF ON F.Id = TF.Id WHERE F.[UserId] = @UserId"; - } + } - if (sends.Any()) - { - sql += @" + if (sends.Any()) + { + sql += @" UPDATE [dbo].[Send] SET @@ -427,72 +427,72 @@ public class CipherRepository : Repository, ICipherRepository #TempSend TS ON S.Id = TS.Id WHERE S.[UserId] = @UserId"; - } + } - sql += @" + sql += @" DROP TABLE #TempCipher DROP TABLE #TempFolder DROP TABLE #TempSend"; - using (var cmd = new SqlCommand(sql, connection, transaction)) - { - cmd.Parameters.Add("@UserId", SqlDbType.UniqueIdentifier).Value = user.Id; - cmd.ExecuteNonQuery(); - } + using (var cmd = new SqlCommand(sql, connection, transaction)) + { + cmd.Parameters.Add("@UserId", SqlDbType.UniqueIdentifier).Value = user.Id; + cmd.ExecuteNonQuery(); + } - transaction.Commit(); - } - catch - { - transaction.Rollback(); - throw; + transaction.Commit(); + } + catch + { + transaction.Rollback(); + throw; + } } } + + return Task.FromResult(0); } - return Task.FromResult(0); - } - - public async Task UpdateCiphersAsync(Guid userId, IEnumerable ciphers) - { - if (!ciphers.Any()) + public async Task UpdateCiphersAsync(Guid userId, IEnumerable ciphers) { - return; - } - - using (var connection = new SqlConnection(ConnectionString)) - { - connection.Open(); - - using (var transaction = connection.BeginTransaction()) + if (!ciphers.Any()) { - try - { - // 1. Create temp tables to bulk copy into. + return; + } - var sqlCreateTemp = @" + using (var connection = new SqlConnection(ConnectionString)) + { + connection.Open(); + + using (var transaction = connection.BeginTransaction()) + { + try + { + // 1. Create temp tables to bulk copy into. + + var sqlCreateTemp = @" SELECT TOP 0 * INTO #TempCipher FROM [dbo].[Cipher]"; - using (var cmd = new SqlCommand(sqlCreateTemp, connection, transaction)) - { - cmd.ExecuteNonQuery(); - } + using (var cmd = new SqlCommand(sqlCreateTemp, connection, transaction)) + { + cmd.ExecuteNonQuery(); + } - // 2. Bulk copy into temp tables. - using (var bulkCopy = new SqlBulkCopy(connection, SqlBulkCopyOptions.KeepIdentity, transaction)) - { - bulkCopy.DestinationTableName = "#TempCipher"; - var dataTable = BuildCiphersTable(bulkCopy, ciphers); - bulkCopy.WriteToServer(dataTable); - } + // 2. Bulk copy into temp tables. + using (var bulkCopy = new SqlBulkCopy(connection, SqlBulkCopyOptions.KeepIdentity, transaction)) + { + bulkCopy.DestinationTableName = "#TempCipher"; + var dataTable = BuildCiphersTable(bulkCopy, ciphers); + bulkCopy.WriteToServer(dataTable); + } - // 3. Insert into real tables from temp tables and clean up. + // 3. Insert into real tables from temp tables and clean up. - // Intentionally not including Favorites, Folders, and CreationDate - // since those are not meant to be bulk updated at this time - var sql = @" + // Intentionally not including Favorites, Folders, and CreationDate + // since those are not meant to be bulk updated at this time + var sql = @" UPDATE [dbo].[Cipher] SET @@ -512,451 +512,452 @@ public class CipherRepository : Repository, ICipherRepository DROP TABLE #TempCipher"; - using (var cmd = new SqlCommand(sql, connection, transaction)) - { - cmd.Parameters.Add("@UserId", SqlDbType.UniqueIdentifier).Value = userId; - cmd.ExecuteNonQuery(); - } - - await connection.ExecuteAsync( - $"[{Schema}].[User_BumpAccountRevisionDate]", - new { Id = userId }, - commandType: CommandType.StoredProcedure, transaction: transaction); - - transaction.Commit(); - } - catch - { - transaction.Rollback(); - throw; - } - } - } - } - - public async Task CreateAsync(IEnumerable ciphers, IEnumerable folders) - { - if (!ciphers.Any()) - { - return; - } - - using (var connection = new SqlConnection(ConnectionString)) - { - connection.Open(); - - using (var transaction = connection.BeginTransaction()) - { - try - { - if (folders.Any()) - { - using (var bulkCopy = new SqlBulkCopy(connection, SqlBulkCopyOptions.KeepIdentity, transaction)) + using (var cmd = new SqlCommand(sql, connection, transaction)) { - bulkCopy.DestinationTableName = "[dbo].[Folder]"; - var dataTable = BuildFoldersTable(bulkCopy, folders); - bulkCopy.WriteToServer(dataTable); + cmd.Parameters.Add("@UserId", SqlDbType.UniqueIdentifier).Value = userId; + cmd.ExecuteNonQuery(); } - } - using (var bulkCopy = new SqlBulkCopy(connection, SqlBulkCopyOptions.KeepIdentity, transaction)) - { - bulkCopy.DestinationTableName = "[dbo].[Cipher]"; - var dataTable = BuildCiphersTable(bulkCopy, ciphers); - bulkCopy.WriteToServer(dataTable); - } - - await connection.ExecuteAsync( + await connection.ExecuteAsync( $"[{Schema}].[User_BumpAccountRevisionDate]", - new { Id = ciphers.First().UserId }, + new { Id = userId }, commandType: CommandType.StoredProcedure, transaction: transaction); - transaction.Commit(); - } - catch - { - transaction.Rollback(); - throw; + transaction.Commit(); + } + catch + { + transaction.Rollback(); + throw; + } } } } - } - public async Task CreateAsync(IEnumerable ciphers, IEnumerable collections, - IEnumerable collectionCiphers) - { - if (!ciphers.Any()) + public async Task CreateAsync(IEnumerable ciphers, IEnumerable folders) { - return; - } - - using (var connection = new SqlConnection(ConnectionString)) - { - connection.Open(); - - using (var transaction = connection.BeginTransaction()) + if (!ciphers.Any()) { - try + return; + } + + using (var connection = new SqlConnection(ConnectionString)) + { + connection.Open(); + + using (var transaction = connection.BeginTransaction()) { - using (var bulkCopy = new SqlBulkCopy(connection, SqlBulkCopyOptions.KeepIdentity, transaction)) + try { - bulkCopy.DestinationTableName = "[dbo].[Cipher]"; - var dataTable = BuildCiphersTable(bulkCopy, ciphers); - bulkCopy.WriteToServer(dataTable); - } - - if (collections.Any()) - { - using (var bulkCopy = new SqlBulkCopy(connection, SqlBulkCopyOptions.KeepIdentity, transaction)) - { - bulkCopy.DestinationTableName = "[dbo].[Collection]"; - var dataTable = BuildCollectionsTable(bulkCopy, collections); - bulkCopy.WriteToServer(dataTable); - } - - if (collectionCiphers.Any()) + if (folders.Any()) { using (var bulkCopy = new SqlBulkCopy(connection, SqlBulkCopyOptions.KeepIdentity, transaction)) { - bulkCopy.DestinationTableName = "[dbo].[CollectionCipher]"; - var dataTable = BuildCollectionCiphersTable(bulkCopy, collectionCiphers); + bulkCopy.DestinationTableName = "[dbo].[Folder]"; + var dataTable = BuildFoldersTable(bulkCopy, folders); bulkCopy.WriteToServer(dataTable); } } + + using (var bulkCopy = new SqlBulkCopy(connection, SqlBulkCopyOptions.KeepIdentity, transaction)) + { + bulkCopy.DestinationTableName = "[dbo].[Cipher]"; + var dataTable = BuildCiphersTable(bulkCopy, ciphers); + bulkCopy.WriteToServer(dataTable); + } + + await connection.ExecuteAsync( + $"[{Schema}].[User_BumpAccountRevisionDate]", + new { Id = ciphers.First().UserId }, + commandType: CommandType.StoredProcedure, transaction: transaction); + + transaction.Commit(); + } + catch + { + transaction.Rollback(); + throw; } - - await connection.ExecuteAsync( - $"[{Schema}].[User_BumpAccountRevisionDateByOrganizationId]", - new { OrganizationId = ciphers.First().OrganizationId }, - commandType: CommandType.StoredProcedure, transaction: transaction); - - transaction.Commit(); - } - catch - { - transaction.Rollback(); - throw; } } } - } - public async Task SoftDeleteAsync(IEnumerable ids, Guid userId) - { - using (var connection = new SqlConnection(ConnectionString)) + public async Task CreateAsync(IEnumerable ciphers, IEnumerable collections, + IEnumerable collectionCiphers) { - var results = await connection.ExecuteAsync( - $"[{Schema}].[Cipher_SoftDelete]", - new { Ids = ids.ToGuidIdArrayTVP(), UserId = userId }, - commandType: CommandType.StoredProcedure); - } - } + if (!ciphers.Any()) + { + return; + } - public async Task RestoreAsync(IEnumerable ids, Guid userId) - { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.ExecuteScalarAsync( - $"[{Schema}].[Cipher_Restore]", - new { Ids = ids.ToGuidIdArrayTVP(), UserId = userId }, - commandType: CommandType.StoredProcedure); + using (var connection = new SqlConnection(ConnectionString)) + { + connection.Open(); - return results; - } - } + using (var transaction = connection.BeginTransaction()) + { + try + { + using (var bulkCopy = new SqlBulkCopy(connection, SqlBulkCopyOptions.KeepIdentity, transaction)) + { + bulkCopy.DestinationTableName = "[dbo].[Cipher]"; + var dataTable = BuildCiphersTable(bulkCopy, ciphers); + bulkCopy.WriteToServer(dataTable); + } - public async Task DeleteDeletedAsync(DateTime deletedDateBefore) - { - using (var connection = new SqlConnection(ConnectionString)) - { - await connection.ExecuteAsync( - $"[{Schema}].[Cipher_DeleteDeleted]", - new { DeletedDateBefore = deletedDateBefore }, - commandType: CommandType.StoredProcedure, - commandTimeout: 43200); - } - } + if (collections.Any()) + { + using (var bulkCopy = new SqlBulkCopy(connection, SqlBulkCopyOptions.KeepIdentity, transaction)) + { + bulkCopy.DestinationTableName = "[dbo].[Collection]"; + var dataTable = BuildCollectionsTable(bulkCopy, collections); + bulkCopy.WriteToServer(dataTable); + } - private DataTable BuildCiphersTable(SqlBulkCopy bulkCopy, IEnumerable ciphers) - { - var c = ciphers.FirstOrDefault(); - if (c == null) - { - throw new ApplicationException("Must have some ciphers to bulk import."); + if (collectionCiphers.Any()) + { + using (var bulkCopy = new SqlBulkCopy(connection, SqlBulkCopyOptions.KeepIdentity, transaction)) + { + bulkCopy.DestinationTableName = "[dbo].[CollectionCipher]"; + var dataTable = BuildCollectionCiphersTable(bulkCopy, collectionCiphers); + bulkCopy.WriteToServer(dataTable); + } + } + } + + await connection.ExecuteAsync( + $"[{Schema}].[User_BumpAccountRevisionDateByOrganizationId]", + new { OrganizationId = ciphers.First().OrganizationId }, + commandType: CommandType.StoredProcedure, transaction: transaction); + + transaction.Commit(); + } + catch + { + transaction.Rollback(); + throw; + } + } + } } - var ciphersTable = new DataTable("CipherDataTable"); - - var idColumn = new DataColumn(nameof(c.Id), c.Id.GetType()); - ciphersTable.Columns.Add(idColumn); - var userIdColumn = new DataColumn(nameof(c.UserId), typeof(Guid)); - ciphersTable.Columns.Add(userIdColumn); - var organizationId = new DataColumn(nameof(c.OrganizationId), typeof(Guid)); - ciphersTable.Columns.Add(organizationId); - var typeColumn = new DataColumn(nameof(c.Type), typeof(short)); - ciphersTable.Columns.Add(typeColumn); - var dataColumn = new DataColumn(nameof(c.Data), typeof(string)); - ciphersTable.Columns.Add(dataColumn); - var favoritesColumn = new DataColumn(nameof(c.Favorites), typeof(string)); - ciphersTable.Columns.Add(favoritesColumn); - var foldersColumn = new DataColumn(nameof(c.Folders), typeof(string)); - ciphersTable.Columns.Add(foldersColumn); - var attachmentsColumn = new DataColumn(nameof(c.Attachments), typeof(string)); - ciphersTable.Columns.Add(attachmentsColumn); - var creationDateColumn = new DataColumn(nameof(c.CreationDate), c.CreationDate.GetType()); - ciphersTable.Columns.Add(creationDateColumn); - var revisionDateColumn = new DataColumn(nameof(c.RevisionDate), c.RevisionDate.GetType()); - ciphersTable.Columns.Add(revisionDateColumn); - var deletedDateColumn = new DataColumn(nameof(c.DeletedDate), typeof(DateTime)); - ciphersTable.Columns.Add(deletedDateColumn); - var repromptColumn = new DataColumn(nameof(c.Reprompt), typeof(short)); - ciphersTable.Columns.Add(repromptColumn); - - foreach (DataColumn col in ciphersTable.Columns) + public async Task SoftDeleteAsync(IEnumerable ids, Guid userId) { - bulkCopy.ColumnMappings.Add(col.ColumnName, col.ColumnName); + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.ExecuteAsync( + $"[{Schema}].[Cipher_SoftDelete]", + new { Ids = ids.ToGuidIdArrayTVP(), UserId = userId }, + commandType: CommandType.StoredProcedure); + } } - var keys = new DataColumn[1]; - keys[0] = idColumn; - ciphersTable.PrimaryKey = keys; - - foreach (var cipher in ciphers) + public async Task RestoreAsync(IEnumerable ids, Guid userId) { - var row = ciphersTable.NewRow(); + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.ExecuteScalarAsync( + $"[{Schema}].[Cipher_Restore]", + new { Ids = ids.ToGuidIdArrayTVP(), UserId = userId }, + commandType: CommandType.StoredProcedure); - row[idColumn] = cipher.Id; - row[userIdColumn] = cipher.UserId.HasValue ? (object)cipher.UserId.Value : DBNull.Value; - row[organizationId] = cipher.OrganizationId.HasValue ? (object)cipher.OrganizationId.Value : DBNull.Value; - row[typeColumn] = (short)cipher.Type; - row[dataColumn] = cipher.Data; - row[favoritesColumn] = cipher.Favorites; - row[foldersColumn] = cipher.Folders; - row[attachmentsColumn] = cipher.Attachments; - row[creationDateColumn] = cipher.CreationDate; - row[revisionDateColumn] = cipher.RevisionDate; - row[deletedDateColumn] = cipher.DeletedDate.HasValue ? (object)cipher.DeletedDate : DBNull.Value; - row[repromptColumn] = cipher.Reprompt; - - ciphersTable.Rows.Add(row); + return results; + } } - return ciphersTable; - } - - private DataTable BuildFoldersTable(SqlBulkCopy bulkCopy, IEnumerable folders) - { - var f = folders.FirstOrDefault(); - if (f == null) + public async Task DeleteDeletedAsync(DateTime deletedDateBefore) { - throw new ApplicationException("Must have some folders to bulk import."); + using (var connection = new SqlConnection(ConnectionString)) + { + await connection.ExecuteAsync( + $"[{Schema}].[Cipher_DeleteDeleted]", + new { DeletedDateBefore = deletedDateBefore }, + commandType: CommandType.StoredProcedure, + commandTimeout: 43200); + } } - var foldersTable = new DataTable("FolderDataTable"); - - var idColumn = new DataColumn(nameof(f.Id), f.Id.GetType()); - foldersTable.Columns.Add(idColumn); - var userIdColumn = new DataColumn(nameof(f.UserId), f.UserId.GetType()); - foldersTable.Columns.Add(userIdColumn); - var nameColumn = new DataColumn(nameof(f.Name), typeof(string)); - foldersTable.Columns.Add(nameColumn); - var creationDateColumn = new DataColumn(nameof(f.CreationDate), f.CreationDate.GetType()); - foldersTable.Columns.Add(creationDateColumn); - var revisionDateColumn = new DataColumn(nameof(f.RevisionDate), f.RevisionDate.GetType()); - foldersTable.Columns.Add(revisionDateColumn); - - foreach (DataColumn col in foldersTable.Columns) + private DataTable BuildCiphersTable(SqlBulkCopy bulkCopy, IEnumerable ciphers) { - bulkCopy.ColumnMappings.Add(col.ColumnName, col.ColumnName); + var c = ciphers.FirstOrDefault(); + if (c == null) + { + throw new ApplicationException("Must have some ciphers to bulk import."); + } + + var ciphersTable = new DataTable("CipherDataTable"); + + var idColumn = new DataColumn(nameof(c.Id), c.Id.GetType()); + ciphersTable.Columns.Add(idColumn); + var userIdColumn = new DataColumn(nameof(c.UserId), typeof(Guid)); + ciphersTable.Columns.Add(userIdColumn); + var organizationId = new DataColumn(nameof(c.OrganizationId), typeof(Guid)); + ciphersTable.Columns.Add(organizationId); + var typeColumn = new DataColumn(nameof(c.Type), typeof(short)); + ciphersTable.Columns.Add(typeColumn); + var dataColumn = new DataColumn(nameof(c.Data), typeof(string)); + ciphersTable.Columns.Add(dataColumn); + var favoritesColumn = new DataColumn(nameof(c.Favorites), typeof(string)); + ciphersTable.Columns.Add(favoritesColumn); + var foldersColumn = new DataColumn(nameof(c.Folders), typeof(string)); + ciphersTable.Columns.Add(foldersColumn); + var attachmentsColumn = new DataColumn(nameof(c.Attachments), typeof(string)); + ciphersTable.Columns.Add(attachmentsColumn); + var creationDateColumn = new DataColumn(nameof(c.CreationDate), c.CreationDate.GetType()); + ciphersTable.Columns.Add(creationDateColumn); + var revisionDateColumn = new DataColumn(nameof(c.RevisionDate), c.RevisionDate.GetType()); + ciphersTable.Columns.Add(revisionDateColumn); + var deletedDateColumn = new DataColumn(nameof(c.DeletedDate), typeof(DateTime)); + ciphersTable.Columns.Add(deletedDateColumn); + var repromptColumn = new DataColumn(nameof(c.Reprompt), typeof(short)); + ciphersTable.Columns.Add(repromptColumn); + + foreach (DataColumn col in ciphersTable.Columns) + { + bulkCopy.ColumnMappings.Add(col.ColumnName, col.ColumnName); + } + + var keys = new DataColumn[1]; + keys[0] = idColumn; + ciphersTable.PrimaryKey = keys; + + foreach (var cipher in ciphers) + { + var row = ciphersTable.NewRow(); + + row[idColumn] = cipher.Id; + row[userIdColumn] = cipher.UserId.HasValue ? (object)cipher.UserId.Value : DBNull.Value; + row[organizationId] = cipher.OrganizationId.HasValue ? (object)cipher.OrganizationId.Value : DBNull.Value; + row[typeColumn] = (short)cipher.Type; + row[dataColumn] = cipher.Data; + row[favoritesColumn] = cipher.Favorites; + row[foldersColumn] = cipher.Folders; + row[attachmentsColumn] = cipher.Attachments; + row[creationDateColumn] = cipher.CreationDate; + row[revisionDateColumn] = cipher.RevisionDate; + row[deletedDateColumn] = cipher.DeletedDate.HasValue ? (object)cipher.DeletedDate : DBNull.Value; + row[repromptColumn] = cipher.Reprompt; + + ciphersTable.Rows.Add(row); + } + + return ciphersTable; } - var keys = new DataColumn[1]; - keys[0] = idColumn; - foldersTable.PrimaryKey = keys; - - foreach (var folder in folders) + private DataTable BuildFoldersTable(SqlBulkCopy bulkCopy, IEnumerable folders) { - var row = foldersTable.NewRow(); + var f = folders.FirstOrDefault(); + if (f == null) + { + throw new ApplicationException("Must have some folders to bulk import."); + } - row[idColumn] = folder.Id; - row[userIdColumn] = folder.UserId; - row[nameColumn] = folder.Name; - row[creationDateColumn] = folder.CreationDate; - row[revisionDateColumn] = folder.RevisionDate; + var foldersTable = new DataTable("FolderDataTable"); - foldersTable.Rows.Add(row); + var idColumn = new DataColumn(nameof(f.Id), f.Id.GetType()); + foldersTable.Columns.Add(idColumn); + var userIdColumn = new DataColumn(nameof(f.UserId), f.UserId.GetType()); + foldersTable.Columns.Add(userIdColumn); + var nameColumn = new DataColumn(nameof(f.Name), typeof(string)); + foldersTable.Columns.Add(nameColumn); + var creationDateColumn = new DataColumn(nameof(f.CreationDate), f.CreationDate.GetType()); + foldersTable.Columns.Add(creationDateColumn); + var revisionDateColumn = new DataColumn(nameof(f.RevisionDate), f.RevisionDate.GetType()); + foldersTable.Columns.Add(revisionDateColumn); + + foreach (DataColumn col in foldersTable.Columns) + { + bulkCopy.ColumnMappings.Add(col.ColumnName, col.ColumnName); + } + + var keys = new DataColumn[1]; + keys[0] = idColumn; + foldersTable.PrimaryKey = keys; + + foreach (var folder in folders) + { + var row = foldersTable.NewRow(); + + row[idColumn] = folder.Id; + row[userIdColumn] = folder.UserId; + row[nameColumn] = folder.Name; + row[creationDateColumn] = folder.CreationDate; + row[revisionDateColumn] = folder.RevisionDate; + + foldersTable.Rows.Add(row); + } + + return foldersTable; } - return foldersTable; - } - - private DataTable BuildCollectionsTable(SqlBulkCopy bulkCopy, IEnumerable collections) - { - var c = collections.FirstOrDefault(); - if (c == null) + private DataTable BuildCollectionsTable(SqlBulkCopy bulkCopy, IEnumerable collections) { - throw new ApplicationException("Must have some collections to bulk import."); + var c = collections.FirstOrDefault(); + if (c == null) + { + throw new ApplicationException("Must have some collections to bulk import."); + } + + var collectionsTable = new DataTable("CollectionDataTable"); + + var idColumn = new DataColumn(nameof(c.Id), c.Id.GetType()); + collectionsTable.Columns.Add(idColumn); + var organizationIdColumn = new DataColumn(nameof(c.OrganizationId), c.OrganizationId.GetType()); + collectionsTable.Columns.Add(organizationIdColumn); + var nameColumn = new DataColumn(nameof(c.Name), typeof(string)); + collectionsTable.Columns.Add(nameColumn); + var creationDateColumn = new DataColumn(nameof(c.CreationDate), c.CreationDate.GetType()); + collectionsTable.Columns.Add(creationDateColumn); + var revisionDateColumn = new DataColumn(nameof(c.RevisionDate), c.RevisionDate.GetType()); + collectionsTable.Columns.Add(revisionDateColumn); + + foreach (DataColumn col in collectionsTable.Columns) + { + bulkCopy.ColumnMappings.Add(col.ColumnName, col.ColumnName); + } + + var keys = new DataColumn[1]; + keys[0] = idColumn; + collectionsTable.PrimaryKey = keys; + + foreach (var collection in collections) + { + var row = collectionsTable.NewRow(); + + row[idColumn] = collection.Id; + row[organizationIdColumn] = collection.OrganizationId; + row[nameColumn] = collection.Name; + row[creationDateColumn] = collection.CreationDate; + row[revisionDateColumn] = collection.RevisionDate; + + collectionsTable.Rows.Add(row); + } + + return collectionsTable; } - var collectionsTable = new DataTable("CollectionDataTable"); - - var idColumn = new DataColumn(nameof(c.Id), c.Id.GetType()); - collectionsTable.Columns.Add(idColumn); - var organizationIdColumn = new DataColumn(nameof(c.OrganizationId), c.OrganizationId.GetType()); - collectionsTable.Columns.Add(organizationIdColumn); - var nameColumn = new DataColumn(nameof(c.Name), typeof(string)); - collectionsTable.Columns.Add(nameColumn); - var creationDateColumn = new DataColumn(nameof(c.CreationDate), c.CreationDate.GetType()); - collectionsTable.Columns.Add(creationDateColumn); - var revisionDateColumn = new DataColumn(nameof(c.RevisionDate), c.RevisionDate.GetType()); - collectionsTable.Columns.Add(revisionDateColumn); - - foreach (DataColumn col in collectionsTable.Columns) + private DataTable BuildCollectionCiphersTable(SqlBulkCopy bulkCopy, IEnumerable collectionCiphers) { - bulkCopy.ColumnMappings.Add(col.ColumnName, col.ColumnName); + var cc = collectionCiphers.FirstOrDefault(); + if (cc == null) + { + throw new ApplicationException("Must have some collectionCiphers to bulk import."); + } + + var collectionCiphersTable = new DataTable("CollectionCipherDataTable"); + + var collectionIdColumn = new DataColumn(nameof(cc.CollectionId), cc.CollectionId.GetType()); + collectionCiphersTable.Columns.Add(collectionIdColumn); + var cipherIdColumn = new DataColumn(nameof(cc.CipherId), cc.CipherId.GetType()); + collectionCiphersTable.Columns.Add(cipherIdColumn); + + foreach (DataColumn col in collectionCiphersTable.Columns) + { + bulkCopy.ColumnMappings.Add(col.ColumnName, col.ColumnName); + } + + var keys = new DataColumn[2]; + keys[0] = collectionIdColumn; + keys[1] = cipherIdColumn; + collectionCiphersTable.PrimaryKey = keys; + + foreach (var collectionCipher in collectionCiphers) + { + var row = collectionCiphersTable.NewRow(); + + row[collectionIdColumn] = collectionCipher.CollectionId; + row[cipherIdColumn] = collectionCipher.CipherId; + + collectionCiphersTable.Rows.Add(row); + } + + return collectionCiphersTable; } - var keys = new DataColumn[1]; - keys[0] = idColumn; - collectionsTable.PrimaryKey = keys; - - foreach (var collection in collections) + private DataTable BuildSendsTable(SqlBulkCopy bulkCopy, IEnumerable sends) { - var row = collectionsTable.NewRow(); + var s = sends.FirstOrDefault(); + if (s == null) + { + throw new ApplicationException("Must have some Sends to bulk import."); + } - row[idColumn] = collection.Id; - row[organizationIdColumn] = collection.OrganizationId; - row[nameColumn] = collection.Name; - row[creationDateColumn] = collection.CreationDate; - row[revisionDateColumn] = collection.RevisionDate; + var sendsTable = new DataTable("SendsDataTable"); - collectionsTable.Rows.Add(row); + var idColumn = new DataColumn(nameof(s.Id), s.Id.GetType()); + sendsTable.Columns.Add(idColumn); + var userIdColumn = new DataColumn(nameof(s.UserId), typeof(Guid)); + sendsTable.Columns.Add(userIdColumn); + var organizationIdColumn = new DataColumn(nameof(s.OrganizationId), typeof(Guid)); + sendsTable.Columns.Add(organizationIdColumn); + var typeColumn = new DataColumn(nameof(s.Type), s.Type.GetType()); + sendsTable.Columns.Add(typeColumn); + var dataColumn = new DataColumn(nameof(s.Data), s.Data.GetType()); + sendsTable.Columns.Add(dataColumn); + var keyColumn = new DataColumn(nameof(s.Key), s.Key.GetType()); + sendsTable.Columns.Add(keyColumn); + var passwordColumn = new DataColumn(nameof(s.Password), typeof(string)); + sendsTable.Columns.Add(passwordColumn); + var maxAccessCountColumn = new DataColumn(nameof(s.MaxAccessCount), typeof(int)); + sendsTable.Columns.Add(maxAccessCountColumn); + var accessCountColumn = new DataColumn(nameof(s.AccessCount), s.AccessCount.GetType()); + sendsTable.Columns.Add(accessCountColumn); + var creationDateColumn = new DataColumn(nameof(s.CreationDate), s.CreationDate.GetType()); + sendsTable.Columns.Add(creationDateColumn); + var revisionDateColumn = new DataColumn(nameof(s.RevisionDate), s.RevisionDate.GetType()); + sendsTable.Columns.Add(revisionDateColumn); + var expirationDateColumn = new DataColumn(nameof(s.ExpirationDate), typeof(DateTime)); + sendsTable.Columns.Add(expirationDateColumn); + var deletionDateColumn = new DataColumn(nameof(s.DeletionDate), s.DeletionDate.GetType()); + sendsTable.Columns.Add(deletionDateColumn); + var disabledColumn = new DataColumn(nameof(s.Disabled), s.Disabled.GetType()); + sendsTable.Columns.Add(disabledColumn); + var hideEmailColumn = new DataColumn(nameof(s.HideEmail), typeof(bool)); + sendsTable.Columns.Add(hideEmailColumn); + + foreach (DataColumn col in sendsTable.Columns) + { + bulkCopy.ColumnMappings.Add(col.ColumnName, col.ColumnName); + } + + var keys = new DataColumn[1]; + keys[0] = idColumn; + sendsTable.PrimaryKey = keys; + + foreach (var send in sends) + { + var row = sendsTable.NewRow(); + + row[idColumn] = send.Id; + row[userIdColumn] = send.UserId.HasValue ? (object)send.UserId.Value : DBNull.Value; + row[organizationIdColumn] = send.OrganizationId.HasValue ? (object)send.OrganizationId.Value : DBNull.Value; + row[typeColumn] = (short)send.Type; + row[dataColumn] = send.Data; + row[keyColumn] = send.Key; + row[passwordColumn] = send.Password; + row[maxAccessCountColumn] = send.MaxAccessCount.HasValue ? (object)send.MaxAccessCount : DBNull.Value; + row[accessCountColumn] = send.AccessCount; + row[creationDateColumn] = send.CreationDate; + row[revisionDateColumn] = send.RevisionDate; + row[expirationDateColumn] = send.ExpirationDate.HasValue ? (object)send.ExpirationDate : DBNull.Value; + row[deletionDateColumn] = send.DeletionDate; + row[disabledColumn] = send.Disabled; + row[hideEmailColumn] = send.HideEmail.HasValue ? (object)send.HideEmail : DBNull.Value; + + sendsTable.Rows.Add(row); + } + + return sendsTable; } - return collectionsTable; - } - - private DataTable BuildCollectionCiphersTable(SqlBulkCopy bulkCopy, IEnumerable collectionCiphers) - { - var cc = collectionCiphers.FirstOrDefault(); - if (cc == null) + public class CipherDetailsWithCollections : CipherDetails { - throw new ApplicationException("Must have some collectionCiphers to bulk import."); + public DataTable CollectionIds { get; set; } } - var collectionCiphersTable = new DataTable("CollectionCipherDataTable"); - - var collectionIdColumn = new DataColumn(nameof(cc.CollectionId), cc.CollectionId.GetType()); - collectionCiphersTable.Columns.Add(collectionIdColumn); - var cipherIdColumn = new DataColumn(nameof(cc.CipherId), cc.CipherId.GetType()); - collectionCiphersTable.Columns.Add(cipherIdColumn); - - foreach (DataColumn col in collectionCiphersTable.Columns) + public class CipherWithCollections : Cipher { - bulkCopy.ColumnMappings.Add(col.ColumnName, col.ColumnName); + public DataTable CollectionIds { get; set; } } - - var keys = new DataColumn[2]; - keys[0] = collectionIdColumn; - keys[1] = cipherIdColumn; - collectionCiphersTable.PrimaryKey = keys; - - foreach (var collectionCipher in collectionCiphers) - { - var row = collectionCiphersTable.NewRow(); - - row[collectionIdColumn] = collectionCipher.CollectionId; - row[cipherIdColumn] = collectionCipher.CipherId; - - collectionCiphersTable.Rows.Add(row); - } - - return collectionCiphersTable; - } - - private DataTable BuildSendsTable(SqlBulkCopy bulkCopy, IEnumerable sends) - { - var s = sends.FirstOrDefault(); - if (s == null) - { - throw new ApplicationException("Must have some Sends to bulk import."); - } - - var sendsTable = new DataTable("SendsDataTable"); - - var idColumn = new DataColumn(nameof(s.Id), s.Id.GetType()); - sendsTable.Columns.Add(idColumn); - var userIdColumn = new DataColumn(nameof(s.UserId), typeof(Guid)); - sendsTable.Columns.Add(userIdColumn); - var organizationIdColumn = new DataColumn(nameof(s.OrganizationId), typeof(Guid)); - sendsTable.Columns.Add(organizationIdColumn); - var typeColumn = new DataColumn(nameof(s.Type), s.Type.GetType()); - sendsTable.Columns.Add(typeColumn); - var dataColumn = new DataColumn(nameof(s.Data), s.Data.GetType()); - sendsTable.Columns.Add(dataColumn); - var keyColumn = new DataColumn(nameof(s.Key), s.Key.GetType()); - sendsTable.Columns.Add(keyColumn); - var passwordColumn = new DataColumn(nameof(s.Password), typeof(string)); - sendsTable.Columns.Add(passwordColumn); - var maxAccessCountColumn = new DataColumn(nameof(s.MaxAccessCount), typeof(int)); - sendsTable.Columns.Add(maxAccessCountColumn); - var accessCountColumn = new DataColumn(nameof(s.AccessCount), s.AccessCount.GetType()); - sendsTable.Columns.Add(accessCountColumn); - var creationDateColumn = new DataColumn(nameof(s.CreationDate), s.CreationDate.GetType()); - sendsTable.Columns.Add(creationDateColumn); - var revisionDateColumn = new DataColumn(nameof(s.RevisionDate), s.RevisionDate.GetType()); - sendsTable.Columns.Add(revisionDateColumn); - var expirationDateColumn = new DataColumn(nameof(s.ExpirationDate), typeof(DateTime)); - sendsTable.Columns.Add(expirationDateColumn); - var deletionDateColumn = new DataColumn(nameof(s.DeletionDate), s.DeletionDate.GetType()); - sendsTable.Columns.Add(deletionDateColumn); - var disabledColumn = new DataColumn(nameof(s.Disabled), s.Disabled.GetType()); - sendsTable.Columns.Add(disabledColumn); - var hideEmailColumn = new DataColumn(nameof(s.HideEmail), typeof(bool)); - sendsTable.Columns.Add(hideEmailColumn); - - foreach (DataColumn col in sendsTable.Columns) - { - bulkCopy.ColumnMappings.Add(col.ColumnName, col.ColumnName); - } - - var keys = new DataColumn[1]; - keys[0] = idColumn; - sendsTable.PrimaryKey = keys; - - foreach (var send in sends) - { - var row = sendsTable.NewRow(); - - row[idColumn] = send.Id; - row[userIdColumn] = send.UserId.HasValue ? (object)send.UserId.Value : DBNull.Value; - row[organizationIdColumn] = send.OrganizationId.HasValue ? (object)send.OrganizationId.Value : DBNull.Value; - row[typeColumn] = (short)send.Type; - row[dataColumn] = send.Data; - row[keyColumn] = send.Key; - row[passwordColumn] = send.Password; - row[maxAccessCountColumn] = send.MaxAccessCount.HasValue ? (object)send.MaxAccessCount : DBNull.Value; - row[accessCountColumn] = send.AccessCount; - row[creationDateColumn] = send.CreationDate; - row[revisionDateColumn] = send.RevisionDate; - row[expirationDateColumn] = send.ExpirationDate.HasValue ? (object)send.ExpirationDate : DBNull.Value; - row[deletionDateColumn] = send.DeletionDate; - row[disabledColumn] = send.Disabled; - row[hideEmailColumn] = send.HideEmail.HasValue ? (object)send.HideEmail : DBNull.Value; - - sendsTable.Rows.Add(row); - } - - return sendsTable; - } - - public class CipherDetailsWithCollections : CipherDetails - { - public DataTable CollectionIds { get; set; } - } - - public class CipherWithCollections : Cipher - { - public DataTable CollectionIds { get; set; } } } diff --git a/src/Infrastructure.Dapper/Repositories/CollectionCipherRepository.cs b/src/Infrastructure.Dapper/Repositories/CollectionCipherRepository.cs index 1368be21e8..2876979482 100644 --- a/src/Infrastructure.Dapper/Repositories/CollectionCipherRepository.cs +++ b/src/Infrastructure.Dapper/Repositories/CollectionCipherRepository.cs @@ -5,94 +5,95 @@ using Bit.Core.Repositories; using Bit.Core.Settings; using Dapper; -namespace Bit.Infrastructure.Dapper.Repositories; - -public class CollectionCipherRepository : BaseRepository, ICollectionCipherRepository +namespace Bit.Infrastructure.Dapper.Repositories { - public CollectionCipherRepository(GlobalSettings globalSettings) - : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) - { } - - public CollectionCipherRepository(string connectionString, string readOnlyConnectionString) - : base(connectionString, readOnlyConnectionString) - { } - - public async Task> GetManyByUserIdAsync(Guid userId) + public class CollectionCipherRepository : BaseRepository, ICollectionCipherRepository { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - "[dbo].[CollectionCipher_ReadByUserId]", - new { UserId = userId }, - commandType: CommandType.StoredProcedure); + public CollectionCipherRepository(GlobalSettings globalSettings) + : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) + { } - return results.ToList(); + public CollectionCipherRepository(string connectionString, string readOnlyConnectionString) + : base(connectionString, readOnlyConnectionString) + { } + + public async Task> GetManyByUserIdAsync(Guid userId) + { + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryAsync( + "[dbo].[CollectionCipher_ReadByUserId]", + new { UserId = userId }, + commandType: CommandType.StoredProcedure); + + return results.ToList(); + } } - } - public async Task> GetManyByOrganizationIdAsync(Guid organizationId) - { - using (var connection = new SqlConnection(ConnectionString)) + public async Task> GetManyByOrganizationIdAsync(Guid organizationId) { - var results = await connection.QueryAsync( - "[dbo].[CollectionCipher_ReadByOrganizationId]", - new { OrganizationId = organizationId }, - commandType: CommandType.StoredProcedure); + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryAsync( + "[dbo].[CollectionCipher_ReadByOrganizationId]", + new { OrganizationId = organizationId }, + commandType: CommandType.StoredProcedure); - return results.ToList(); + return results.ToList(); + } } - } - public async Task> GetManyByUserIdCipherIdAsync(Guid userId, Guid cipherId) - { - using (var connection = new SqlConnection(ConnectionString)) + public async Task> GetManyByUserIdCipherIdAsync(Guid userId, Guid cipherId) { - var results = await connection.QueryAsync( - "[dbo].[CollectionCipher_ReadByUserIdCipherId]", - new { UserId = userId, CipherId = cipherId }, - commandType: CommandType.StoredProcedure); + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryAsync( + "[dbo].[CollectionCipher_ReadByUserIdCipherId]", + new { UserId = userId, CipherId = cipherId }, + commandType: CommandType.StoredProcedure); - return results.ToList(); + return results.ToList(); + } } - } - public async Task UpdateCollectionsAsync(Guid cipherId, Guid userId, IEnumerable collectionIds) - { - using (var connection = new SqlConnection(ConnectionString)) + public async Task UpdateCollectionsAsync(Guid cipherId, Guid userId, IEnumerable collectionIds) { - var results = await connection.ExecuteAsync( - "[dbo].[CollectionCipher_UpdateCollections]", - new { CipherId = cipherId, UserId = userId, CollectionIds = collectionIds.ToGuidIdArrayTVP() }, - commandType: CommandType.StoredProcedure); + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.ExecuteAsync( + "[dbo].[CollectionCipher_UpdateCollections]", + new { CipherId = cipherId, UserId = userId, CollectionIds = collectionIds.ToGuidIdArrayTVP() }, + commandType: CommandType.StoredProcedure); + } } - } - public async Task UpdateCollectionsForAdminAsync(Guid cipherId, Guid organizationId, IEnumerable collectionIds) - { - using (var connection = new SqlConnection(ConnectionString)) + public async Task UpdateCollectionsForAdminAsync(Guid cipherId, Guid organizationId, IEnumerable collectionIds) { - var results = await connection.ExecuteAsync( - "[dbo].[CollectionCipher_UpdateCollectionsAdmin]", - new { CipherId = cipherId, OrganizationId = organizationId, CollectionIds = collectionIds.ToGuidIdArrayTVP() }, - commandType: CommandType.StoredProcedure); + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.ExecuteAsync( + "[dbo].[CollectionCipher_UpdateCollectionsAdmin]", + new { CipherId = cipherId, OrganizationId = organizationId, CollectionIds = collectionIds.ToGuidIdArrayTVP() }, + commandType: CommandType.StoredProcedure); + } } - } - public async Task UpdateCollectionsForCiphersAsync(IEnumerable cipherIds, Guid userId, - Guid organizationId, IEnumerable collectionIds) - { - using (var connection = new SqlConnection(ConnectionString)) + public async Task UpdateCollectionsForCiphersAsync(IEnumerable cipherIds, Guid userId, + Guid organizationId, IEnumerable collectionIds) { - var results = await connection.ExecuteAsync( - "[dbo].[CollectionCipher_UpdateCollectionsForCiphers]", - new - { - CipherIds = cipherIds.ToGuidIdArrayTVP(), - UserId = userId, - OrganizationId = organizationId, - CollectionIds = collectionIds.ToGuidIdArrayTVP() - }, - commandType: CommandType.StoredProcedure); + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.ExecuteAsync( + "[dbo].[CollectionCipher_UpdateCollectionsForCiphers]", + new + { + CipherIds = cipherIds.ToGuidIdArrayTVP(), + UserId = userId, + OrganizationId = organizationId, + CollectionIds = collectionIds.ToGuidIdArrayTVP() + }, + commandType: CommandType.StoredProcedure); + } } } } diff --git a/src/Infrastructure.Dapper/Repositories/CollectionRepository.cs b/src/Infrastructure.Dapper/Repositories/CollectionRepository.cs index 3fd0a24300..ee8bc1e2e1 100644 --- a/src/Infrastructure.Dapper/Repositories/CollectionRepository.cs +++ b/src/Infrastructure.Dapper/Repositories/CollectionRepository.cs @@ -7,180 +7,181 @@ using Bit.Core.Repositories; using Bit.Core.Settings; using Dapper; -namespace Bit.Infrastructure.Dapper.Repositories; - -public class CollectionRepository : Repository, ICollectionRepository +namespace Bit.Infrastructure.Dapper.Repositories { - public CollectionRepository(GlobalSettings globalSettings) - : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) - { } - - public CollectionRepository(string connectionString, string readOnlyConnectionString) - : base(connectionString, readOnlyConnectionString) - { } - - public async Task GetCountByOrganizationIdAsync(Guid organizationId) + public class CollectionRepository : Repository, ICollectionRepository { - using (var connection = new SqlConnection(ConnectionString)) + public CollectionRepository(GlobalSettings globalSettings) + : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) + { } + + public CollectionRepository(string connectionString, string readOnlyConnectionString) + : base(connectionString, readOnlyConnectionString) + { } + + public async Task GetCountByOrganizationIdAsync(Guid organizationId) { - var results = await connection.ExecuteScalarAsync( - "[dbo].[Collection_ReadCountByOrganizationId]", - new { OrganizationId = organizationId }, - commandType: CommandType.StoredProcedure); + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.ExecuteScalarAsync( + "[dbo].[Collection_ReadCountByOrganizationId]", + new { OrganizationId = organizationId }, + commandType: CommandType.StoredProcedure); - return results; + return results; + } } - } - public async Task>> GetByIdWithGroupsAsync(Guid id) - { - using (var connection = new SqlConnection(ConnectionString)) + public async Task>> GetByIdWithGroupsAsync(Guid id) { - var results = await connection.QueryMultipleAsync( - $"[{Schema}].[Collection_ReadWithGroupsById]", - new { Id = id }, - commandType: CommandType.StoredProcedure); + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryMultipleAsync( + $"[{Schema}].[Collection_ReadWithGroupsById]", + new { Id = id }, + commandType: CommandType.StoredProcedure); - var collection = await results.ReadFirstOrDefaultAsync(); - var groups = (await results.ReadAsync()).ToList(); + var collection = await results.ReadFirstOrDefaultAsync(); + var groups = (await results.ReadAsync()).ToList(); - return new Tuple>(collection, groups); + return new Tuple>(collection, groups); + } } - } - public async Task>> GetByIdWithGroupsAsync( - Guid id, Guid userId) - { - using (var connection = new SqlConnection(ConnectionString)) + public async Task>> GetByIdWithGroupsAsync( + Guid id, Guid userId) { - var results = await connection.QueryMultipleAsync( - $"[{Schema}].[Collection_ReadWithGroupsByIdUserId]", - new { Id = id, UserId = userId }, - commandType: CommandType.StoredProcedure); + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryMultipleAsync( + $"[{Schema}].[Collection_ReadWithGroupsByIdUserId]", + new { Id = id, UserId = userId }, + commandType: CommandType.StoredProcedure); - var collection = await results.ReadFirstOrDefaultAsync(); - var groups = (await results.ReadAsync()).ToList(); + var collection = await results.ReadFirstOrDefaultAsync(); + var groups = (await results.ReadAsync()).ToList(); - return new Tuple>(collection, groups); + return new Tuple>(collection, groups); + } } - } - public async Task> GetManyByOrganizationIdAsync(Guid organizationId) - { - using (var connection = new SqlConnection(ConnectionString)) + public async Task> GetManyByOrganizationIdAsync(Guid organizationId) { - var results = await connection.QueryAsync( - $"[{Schema}].[{Table}_ReadByOrganizationId]", - new { OrganizationId = organizationId }, - commandType: CommandType.StoredProcedure); + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryAsync( + $"[{Schema}].[{Table}_ReadByOrganizationId]", + new { OrganizationId = organizationId }, + commandType: CommandType.StoredProcedure); - return results.ToList(); + return results.ToList(); + } } - } - public async Task GetByIdAsync(Guid id, Guid userId) - { - using (var connection = new SqlConnection(ConnectionString)) + public async Task GetByIdAsync(Guid id, Guid userId) { - var results = await connection.QueryAsync( - $"[{Schema}].[Collection_ReadByIdUserId]", - new { Id = id, UserId = userId }, - commandType: CommandType.StoredProcedure); + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryAsync( + $"[{Schema}].[Collection_ReadByIdUserId]", + new { Id = id, UserId = userId }, + commandType: CommandType.StoredProcedure); - return results.FirstOrDefault(); + return results.FirstOrDefault(); + } } - } - public async Task> GetManyByUserIdAsync(Guid userId) - { - using (var connection = new SqlConnection(ConnectionString)) + public async Task> GetManyByUserIdAsync(Guid userId) { - var results = await connection.QueryAsync( - $"[{Schema}].[Collection_ReadByUserId]", - new { UserId = userId }, - commandType: CommandType.StoredProcedure); + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryAsync( + $"[{Schema}].[Collection_ReadByUserId]", + new { UserId = userId }, + commandType: CommandType.StoredProcedure); - return results.ToList(); + return results.ToList(); + } } - } - public async Task CreateAsync(Collection obj, IEnumerable groups) - { - obj.SetNewId(); - var objWithGroups = JsonSerializer.Deserialize(JsonSerializer.Serialize(obj)); - objWithGroups.Groups = groups.ToArrayTVP(); - - using (var connection = new SqlConnection(ConnectionString)) + public async Task CreateAsync(Collection obj, IEnumerable groups) { - var results = await connection.ExecuteAsync( - $"[{Schema}].[Collection_CreateWithGroups]", - objWithGroups, - commandType: CommandType.StoredProcedure); + obj.SetNewId(); + var objWithGroups = JsonSerializer.Deserialize(JsonSerializer.Serialize(obj)); + objWithGroups.Groups = groups.ToArrayTVP(); + + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.ExecuteAsync( + $"[{Schema}].[Collection_CreateWithGroups]", + objWithGroups, + commandType: CommandType.StoredProcedure); + } } - } - public async Task ReplaceAsync(Collection obj, IEnumerable groups) - { - var objWithGroups = JsonSerializer.Deserialize(JsonSerializer.Serialize(obj)); - objWithGroups.Groups = groups.ToArrayTVP(); - - using (var connection = new SqlConnection(ConnectionString)) + public async Task ReplaceAsync(Collection obj, IEnumerable groups) { - var results = await connection.ExecuteAsync( - $"[{Schema}].[Collection_UpdateWithGroups]", - objWithGroups, - commandType: CommandType.StoredProcedure); - } - } + var objWithGroups = JsonSerializer.Deserialize(JsonSerializer.Serialize(obj)); + objWithGroups.Groups = groups.ToArrayTVP(); - public async Task CreateUserAsync(Guid collectionId, Guid organizationUserId) - { - using (var connection = new SqlConnection(ConnectionString)) + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.ExecuteAsync( + $"[{Schema}].[Collection_UpdateWithGroups]", + objWithGroups, + commandType: CommandType.StoredProcedure); + } + } + + public async Task CreateUserAsync(Guid collectionId, Guid organizationUserId) { - var results = await connection.ExecuteAsync( - $"[{Schema}].[CollectionUser_Create]", - new { CollectionId = collectionId, OrganizationUserId = organizationUserId }, - commandType: CommandType.StoredProcedure); + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.ExecuteAsync( + $"[{Schema}].[CollectionUser_Create]", + new { CollectionId = collectionId, OrganizationUserId = organizationUserId }, + commandType: CommandType.StoredProcedure); + } } - } - public async Task DeleteUserAsync(Guid collectionId, Guid organizationUserId) - { - using (var connection = new SqlConnection(ConnectionString)) + public async Task DeleteUserAsync(Guid collectionId, Guid organizationUserId) { - var results = await connection.ExecuteAsync( - $"[{Schema}].[CollectionUser_Delete]", - new { CollectionId = collectionId, OrganizationUserId = organizationUserId }, - commandType: CommandType.StoredProcedure); + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.ExecuteAsync( + $"[{Schema}].[CollectionUser_Delete]", + new { CollectionId = collectionId, OrganizationUserId = organizationUserId }, + commandType: CommandType.StoredProcedure); + } } - } - public async Task UpdateUsersAsync(Guid id, IEnumerable users) - { - using (var connection = new SqlConnection(ConnectionString)) + public async Task UpdateUsersAsync(Guid id, IEnumerable users) { - var results = await connection.ExecuteAsync( - $"[{Schema}].[CollectionUser_UpdateUsers]", - new { CollectionId = id, Users = users.ToArrayTVP() }, - commandType: CommandType.StoredProcedure); + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.ExecuteAsync( + $"[{Schema}].[CollectionUser_UpdateUsers]", + new { CollectionId = id, Users = users.ToArrayTVP() }, + commandType: CommandType.StoredProcedure); + } } - } - public async Task> GetManyUsersByIdAsync(Guid id) - { - using (var connection = new SqlConnection(ConnectionString)) + public async Task> GetManyUsersByIdAsync(Guid id) { - var results = await connection.QueryAsync( - $"[{Schema}].[CollectionUser_ReadByCollectionId]", - new { CollectionId = id }, - commandType: CommandType.StoredProcedure); + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryAsync( + $"[{Schema}].[CollectionUser_ReadByCollectionId]", + new { CollectionId = id }, + commandType: CommandType.StoredProcedure); - return results.ToList(); + return results.ToList(); + } } - } - public class CollectionWithGroups : Collection - { - public DataTable Groups { get; set; } + public class CollectionWithGroups : Collection + { + public DataTable Groups { get; set; } + } } } diff --git a/src/Infrastructure.Dapper/Repositories/DateTimeHandler.cs b/src/Infrastructure.Dapper/Repositories/DateTimeHandler.cs index ac48653ec6..8aedf23211 100644 --- a/src/Infrastructure.Dapper/Repositories/DateTimeHandler.cs +++ b/src/Infrastructure.Dapper/Repositories/DateTimeHandler.cs @@ -1,17 +1,18 @@ using System.Data; using Dapper; -namespace Bit.Infrastructure.Dapper.Repositories; - -public class DateTimeHandler : SqlMapper.TypeHandler +namespace Bit.Infrastructure.Dapper.Repositories { - public override void SetValue(IDbDataParameter parameter, DateTime value) + public class DateTimeHandler : SqlMapper.TypeHandler { - parameter.Value = value; - } + public override void SetValue(IDbDataParameter parameter, DateTime value) + { + parameter.Value = value; + } - public override DateTime Parse(object value) - { - return DateTime.SpecifyKind((DateTime)value, DateTimeKind.Utc); + public override DateTime Parse(object value) + { + return DateTime.SpecifyKind((DateTime)value, DateTimeKind.Utc); + } } } diff --git a/src/Infrastructure.Dapper/Repositories/DeviceRepository.cs b/src/Infrastructure.Dapper/Repositories/DeviceRepository.cs index 325cee3070..039ff90ae7 100644 --- a/src/Infrastructure.Dapper/Repositories/DeviceRepository.cs +++ b/src/Infrastructure.Dapper/Repositories/DeviceRepository.cs @@ -5,83 +5,84 @@ using Bit.Core.Repositories; using Bit.Core.Settings; using Dapper; -namespace Bit.Infrastructure.Dapper.Repositories; - -public class DeviceRepository : Repository, IDeviceRepository +namespace Bit.Infrastructure.Dapper.Repositories { - public DeviceRepository(GlobalSettings globalSettings) - : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) - { } - - public DeviceRepository(string connectionString, string readOnlyConnectionString) - : base(connectionString, readOnlyConnectionString) - { } - - public async Task GetByIdAsync(Guid id, Guid userId) + public class DeviceRepository : Repository, IDeviceRepository { - var device = await GetByIdAsync(id); - if (device == null || device.UserId != userId) + public DeviceRepository(GlobalSettings globalSettings) + : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) + { } + + public DeviceRepository(string connectionString, string readOnlyConnectionString) + : base(connectionString, readOnlyConnectionString) + { } + + public async Task GetByIdAsync(Guid id, Guid userId) { - return null; + var device = await GetByIdAsync(id); + if (device == null || device.UserId != userId) + { + return null; + } + + return device; } - return device; - } - - public async Task GetByIdentifierAsync(string identifier) - { - using (var connection = new SqlConnection(ConnectionString)) + public async Task GetByIdentifierAsync(string identifier) { - var results = await connection.QueryAsync( - $"[{Schema}].[{Table}_ReadByIdentifier]", - new - { - Identifier = identifier - }, - commandType: CommandType.StoredProcedure); + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryAsync( + $"[{Schema}].[{Table}_ReadByIdentifier]", + new + { + Identifier = identifier + }, + commandType: CommandType.StoredProcedure); - return results.FirstOrDefault(); + return results.FirstOrDefault(); + } } - } - public async Task GetByIdentifierAsync(string identifier, Guid userId) - { - using (var connection = new SqlConnection(ConnectionString)) + public async Task GetByIdentifierAsync(string identifier, Guid userId) { - var results = await connection.QueryAsync( - $"[{Schema}].[{Table}_ReadByIdentifierUserId]", - new - { - UserId = userId, - Identifier = identifier - }, - commandType: CommandType.StoredProcedure); + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryAsync( + $"[{Schema}].[{Table}_ReadByIdentifierUserId]", + new + { + UserId = userId, + Identifier = identifier + }, + commandType: CommandType.StoredProcedure); - return results.FirstOrDefault(); + return results.FirstOrDefault(); + } } - } - public async Task> GetManyByUserIdAsync(Guid userId) - { - using (var connection = new SqlConnection(ConnectionString)) + public async Task> GetManyByUserIdAsync(Guid userId) { - var results = await connection.QueryAsync( - $"[{Schema}].[{Table}_ReadByUserId]", - new { UserId = userId }, - commandType: CommandType.StoredProcedure); + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryAsync( + $"[{Schema}].[{Table}_ReadByUserId]", + new { UserId = userId }, + commandType: CommandType.StoredProcedure); - return results.ToList(); + return results.ToList(); + } } - } - public async Task ClearPushTokenAsync(Guid id) - { - using (var connection = new SqlConnection(ConnectionString)) + public async Task ClearPushTokenAsync(Guid id) { - await connection.ExecuteAsync( - $"[{Schema}].[{Table}_ClearPushTokenById]", - new { Id = id }, - commandType: CommandType.StoredProcedure); + using (var connection = new SqlConnection(ConnectionString)) + { + await connection.ExecuteAsync( + $"[{Schema}].[{Table}_ClearPushTokenById]", + new { Id = id }, + commandType: CommandType.StoredProcedure); + } } } } diff --git a/src/Infrastructure.Dapper/Repositories/EmergencyAccessRepository.cs b/src/Infrastructure.Dapper/Repositories/EmergencyAccessRepository.cs index 9f1f9a9715..c88664c7a6 100644 --- a/src/Infrastructure.Dapper/Repositories/EmergencyAccessRepository.cs +++ b/src/Infrastructure.Dapper/Repositories/EmergencyAccessRepository.cs @@ -6,91 +6,92 @@ using Bit.Core.Repositories; using Bit.Core.Settings; using Dapper; -namespace Bit.Infrastructure.Dapper.Repositories; - -public class EmergencyAccessRepository : Repository, IEmergencyAccessRepository +namespace Bit.Infrastructure.Dapper.Repositories { - public EmergencyAccessRepository(GlobalSettings globalSettings) - : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) - { } - - public EmergencyAccessRepository(string connectionString, string readOnlyConnectionString) - : base(connectionString, readOnlyConnectionString) - { } - - public async Task GetCountByGrantorIdEmailAsync(Guid grantorId, string email, bool onlyRegisteredUsers) + public class EmergencyAccessRepository : Repository, IEmergencyAccessRepository { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.ExecuteScalarAsync( - "[dbo].[EmergencyAccess_ReadCountByGrantorIdEmail]", - new { GrantorId = grantorId, Email = email, OnlyUsers = onlyRegisteredUsers }, - commandType: CommandType.StoredProcedure); + public EmergencyAccessRepository(GlobalSettings globalSettings) + : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) + { } - return results; + public EmergencyAccessRepository(string connectionString, string readOnlyConnectionString) + : base(connectionString, readOnlyConnectionString) + { } + + public async Task GetCountByGrantorIdEmailAsync(Guid grantorId, string email, bool onlyRegisteredUsers) + { + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.ExecuteScalarAsync( + "[dbo].[EmergencyAccess_ReadCountByGrantorIdEmail]", + new { GrantorId = grantorId, Email = email, OnlyUsers = onlyRegisteredUsers }, + commandType: CommandType.StoredProcedure); + + return results; + } } - } - public async Task> GetManyDetailsByGrantorIdAsync(Guid grantorId) - { - using (var connection = new SqlConnection(ConnectionString)) + public async Task> GetManyDetailsByGrantorIdAsync(Guid grantorId) { - var results = await connection.QueryAsync( - "[dbo].[EmergencyAccessDetails_ReadByGrantorId]", - new { GrantorId = grantorId }, - commandType: CommandType.StoredProcedure); + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryAsync( + "[dbo].[EmergencyAccessDetails_ReadByGrantorId]", + new { GrantorId = grantorId }, + commandType: CommandType.StoredProcedure); - return results.ToList(); + return results.ToList(); + } } - } - public async Task> GetManyDetailsByGranteeIdAsync(Guid granteeId) - { - using (var connection = new SqlConnection(ConnectionString)) + public async Task> GetManyDetailsByGranteeIdAsync(Guid granteeId) { - var results = await connection.QueryAsync( - "[dbo].[EmergencyAccessDetails_ReadByGranteeId]", - new { GranteeId = granteeId }, - commandType: CommandType.StoredProcedure); + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryAsync( + "[dbo].[EmergencyAccessDetails_ReadByGranteeId]", + new { GranteeId = granteeId }, + commandType: CommandType.StoredProcedure); - return results.ToList(); + return results.ToList(); + } } - } - public async Task GetDetailsByIdGrantorIdAsync(Guid id, Guid grantorId) - { - using (var connection = new SqlConnection(ConnectionString)) + public async Task GetDetailsByIdGrantorIdAsync(Guid id, Guid grantorId) { - var results = await connection.QueryAsync( - "[dbo].[EmergencyAccessDetails_ReadByIdGrantorId]", - new { Id = id, GrantorId = grantorId }, - commandType: CommandType.StoredProcedure); + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryAsync( + "[dbo].[EmergencyAccessDetails_ReadByIdGrantorId]", + new { Id = id, GrantorId = grantorId }, + commandType: CommandType.StoredProcedure); - return results.FirstOrDefault(); + return results.FirstOrDefault(); + } } - } - public async Task> GetManyToNotifyAsync() - { - using (var connection = new SqlConnection(ConnectionString)) + public async Task> GetManyToNotifyAsync() { - var results = await connection.QueryAsync( - "[dbo].[EmergencyAccess_ReadToNotify]", - commandType: CommandType.StoredProcedure); + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryAsync( + "[dbo].[EmergencyAccess_ReadToNotify]", + commandType: CommandType.StoredProcedure); - return results.ToList(); + return results.ToList(); + } } - } - public async Task> GetExpiredRecoveriesAsync() - { - using (var connection = new SqlConnection(ConnectionString)) + public async Task> GetExpiredRecoveriesAsync() { - var results = await connection.QueryAsync( - "[dbo].[EmergencyAccessDetails_ReadExpiredRecoveries]", - commandType: CommandType.StoredProcedure); + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryAsync( + "[dbo].[EmergencyAccessDetails_ReadExpiredRecoveries]", + commandType: CommandType.StoredProcedure); - return results.ToList(); + return results.ToList(); + } } } } diff --git a/src/Infrastructure.Dapper/Repositories/EventRepository.cs b/src/Infrastructure.Dapper/Repositories/EventRepository.cs index ba4c68b352..82491cb035 100644 --- a/src/Infrastructure.Dapper/Repositories/EventRepository.cs +++ b/src/Infrastructure.Dapper/Repositories/EventRepository.cs @@ -6,220 +6,221 @@ using Bit.Core.Repositories; using Bit.Core.Settings; using Dapper; -namespace Bit.Infrastructure.Dapper.Repositories; - -public class EventRepository : Repository, IEventRepository +namespace Bit.Infrastructure.Dapper.Repositories { - public EventRepository(GlobalSettings globalSettings) - : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) - { } - - public EventRepository(string connectionString, string readOnlyConnectionString) - : base(connectionString, readOnlyConnectionString) - { } - - public async Task> GetManyByUserAsync(Guid userId, DateTime startDate, DateTime endDate, - PageOptions pageOptions) + public class EventRepository : Repository, IEventRepository { - return await GetManyAsync($"[{Schema}].[Event_ReadPageByUserId]", - new Dictionary - { - ["@UserId"] = userId - }, startDate, endDate, pageOptions); - } + public EventRepository(GlobalSettings globalSettings) + : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) + { } - public async Task> GetManyByOrganizationAsync(Guid organizationId, - DateTime startDate, DateTime endDate, PageOptions pageOptions) - { - return await GetManyAsync($"[{Schema}].[Event_ReadPageByOrganizationId]", - new Dictionary - { - ["@OrganizationId"] = organizationId - }, startDate, endDate, pageOptions); - } + public EventRepository(string connectionString, string readOnlyConnectionString) + : base(connectionString, readOnlyConnectionString) + { } - public async Task> GetManyByOrganizationActingUserAsync(Guid organizationId, Guid actingUserId, - DateTime startDate, DateTime endDate, PageOptions pageOptions) - { - return await GetManyAsync($"[{Schema}].[Event_ReadPageByOrganizationIdActingUserId]", - new Dictionary - { - ["@OrganizationId"] = organizationId, - ["@ActingUserId"] = actingUserId - }, startDate, endDate, pageOptions); - } - - public async Task> GetManyByProviderAsync(Guid providerId, - DateTime startDate, DateTime endDate, PageOptions pageOptions) - { - return await GetManyAsync($"[{Schema}].[Event_ReadPageByProviderId]", - new Dictionary - { - ["@ProviderId"] = providerId - }, startDate, endDate, pageOptions); - } - - public async Task> GetManyByProviderActingUserAsync(Guid providerId, Guid actingUserId, - DateTime startDate, DateTime endDate, PageOptions pageOptions) - { - return await GetManyAsync($"[{Schema}].[Event_ReadPageByProviderIdActingUserId]", - new Dictionary - { - ["@ProviderId"] = providerId, - ["@ActingUserId"] = actingUserId - }, startDate, endDate, pageOptions); - } - - public async Task> GetManyByCipherAsync(Cipher cipher, DateTime startDate, DateTime endDate, - PageOptions pageOptions) - { - return await GetManyAsync($"[{Schema}].[Event_ReadPageByCipherId]", - new Dictionary - { - ["@OrganizationId"] = cipher.OrganizationId, - ["@UserId"] = cipher.UserId, - ["@CipherId"] = cipher.Id - }, startDate, endDate, pageOptions); - } - - public async Task CreateAsync(IEvent e) - { - if (!(e is Event ev)) + public async Task> GetManyByUserAsync(Guid userId, DateTime startDate, DateTime endDate, + PageOptions pageOptions) { - ev = new Event(e); + return await GetManyAsync($"[{Schema}].[Event_ReadPageByUserId]", + new Dictionary + { + ["@UserId"] = userId + }, startDate, endDate, pageOptions); } - await base.CreateAsync(ev); - } - - public async Task CreateManyAsync(IEnumerable entities) - { - if (!entities?.Any() ?? true) + public async Task> GetManyByOrganizationAsync(Guid organizationId, + DateTime startDate, DateTime endDate, PageOptions pageOptions) { - return; + return await GetManyAsync($"[{Schema}].[Event_ReadPageByOrganizationId]", + new Dictionary + { + ["@OrganizationId"] = organizationId + }, startDate, endDate, pageOptions); } - if (!entities.Skip(1).Any()) + public async Task> GetManyByOrganizationActingUserAsync(Guid organizationId, Guid actingUserId, + DateTime startDate, DateTime endDate, PageOptions pageOptions) { - await CreateAsync(entities.First()); - return; + return await GetManyAsync($"[{Schema}].[Event_ReadPageByOrganizationIdActingUserId]", + new Dictionary + { + ["@OrganizationId"] = organizationId, + ["@ActingUserId"] = actingUserId + }, startDate, endDate, pageOptions); } - using (var connection = new SqlConnection(ConnectionString)) + public async Task> GetManyByProviderAsync(Guid providerId, + DateTime startDate, DateTime endDate, PageOptions pageOptions) { - connection.Open(); - using (var bulkCopy = new SqlBulkCopy(connection, SqlBulkCopyOptions.KeepIdentity, null)) + return await GetManyAsync($"[{Schema}].[Event_ReadPageByProviderId]", + new Dictionary + { + ["@ProviderId"] = providerId + }, startDate, endDate, pageOptions); + } + + public async Task> GetManyByProviderActingUserAsync(Guid providerId, Guid actingUserId, + DateTime startDate, DateTime endDate, PageOptions pageOptions) + { + return await GetManyAsync($"[{Schema}].[Event_ReadPageByProviderIdActingUserId]", + new Dictionary + { + ["@ProviderId"] = providerId, + ["@ActingUserId"] = actingUserId + }, startDate, endDate, pageOptions); + } + + public async Task> GetManyByCipherAsync(Cipher cipher, DateTime startDate, DateTime endDate, + PageOptions pageOptions) + { + return await GetManyAsync($"[{Schema}].[Event_ReadPageByCipherId]", + new Dictionary + { + ["@OrganizationId"] = cipher.OrganizationId, + ["@UserId"] = cipher.UserId, + ["@CipherId"] = cipher.Id + }, startDate, endDate, pageOptions); + } + + public async Task CreateAsync(IEvent e) + { + if (!(e is Event ev)) { - bulkCopy.DestinationTableName = "[dbo].[Event]"; - var dataTable = BuildEventsTable(bulkCopy, entities.Select(e => e is Event ? e as Event : new Event(e))); - await bulkCopy.WriteToServerAsync(dataTable); + ev = new Event(e); + } + + await base.CreateAsync(ev); + } + + public async Task CreateManyAsync(IEnumerable entities) + { + if (!entities?.Any() ?? true) + { + return; + } + + if (!entities.Skip(1).Any()) + { + await CreateAsync(entities.First()); + return; + } + + using (var connection = new SqlConnection(ConnectionString)) + { + connection.Open(); + using (var bulkCopy = new SqlBulkCopy(connection, SqlBulkCopyOptions.KeepIdentity, null)) + { + bulkCopy.DestinationTableName = "[dbo].[Event]"; + var dataTable = BuildEventsTable(bulkCopy, entities.Select(e => e is Event ? e as Event : new Event(e))); + await bulkCopy.WriteToServerAsync(dataTable); + } } } - } - private async Task> GetManyAsync(string sprocName, - IDictionary sprocParams, DateTime startDate, DateTime endDate, PageOptions pageOptions) - { - DateTime? beforeDate = null; - if (!string.IsNullOrWhiteSpace(pageOptions.ContinuationToken) && - long.TryParse(pageOptions.ContinuationToken, out var binaryDate)) + private async Task> GetManyAsync(string sprocName, + IDictionary sprocParams, DateTime startDate, DateTime endDate, PageOptions pageOptions) { - beforeDate = DateTime.SpecifyKind(DateTime.FromBinary(binaryDate), DateTimeKind.Utc); - } - - var parameters = new DynamicParameters(sprocParams); - parameters.Add("@PageSize", pageOptions.PageSize, DbType.Int32); - // Explicitly use DbType.DateTime2 for proper precision. - // ref: https://github.com/StackExchange/Dapper/issues/229 - parameters.Add("@StartDate", startDate.ToUniversalTime(), DbType.DateTime2, null, 7); - parameters.Add("@EndDate", endDate.ToUniversalTime(), DbType.DateTime2, null, 7); - parameters.Add("@BeforeDate", beforeDate, DbType.DateTime2, null, 7); - - using (var connection = new SqlConnection(ConnectionString)) - { - var events = (await connection.QueryAsync(sprocName, parameters, - commandType: CommandType.StoredProcedure)).ToList(); - - var result = new PagedResult(); - if (events.Any() && events.Count >= pageOptions.PageSize) + DateTime? beforeDate = null; + if (!string.IsNullOrWhiteSpace(pageOptions.ContinuationToken) && + long.TryParse(pageOptions.ContinuationToken, out var binaryDate)) { - result.ContinuationToken = events.Last().Date.ToBinary().ToString(); + beforeDate = DateTime.SpecifyKind(DateTime.FromBinary(binaryDate), DateTimeKind.Utc); } - result.Data.AddRange(events); - return result; - } - } - private DataTable BuildEventsTable(SqlBulkCopy bulkCopy, IEnumerable events) - { - var e = events.FirstOrDefault(); - if (e == null) + var parameters = new DynamicParameters(sprocParams); + parameters.Add("@PageSize", pageOptions.PageSize, DbType.Int32); + // Explicitly use DbType.DateTime2 for proper precision. + // ref: https://github.com/StackExchange/Dapper/issues/229 + parameters.Add("@StartDate", startDate.ToUniversalTime(), DbType.DateTime2, null, 7); + parameters.Add("@EndDate", endDate.ToUniversalTime(), DbType.DateTime2, null, 7); + parameters.Add("@BeforeDate", beforeDate, DbType.DateTime2, null, 7); + + using (var connection = new SqlConnection(ConnectionString)) + { + var events = (await connection.QueryAsync(sprocName, parameters, + commandType: CommandType.StoredProcedure)).ToList(); + + var result = new PagedResult(); + if (events.Any() && events.Count >= pageOptions.PageSize) + { + result.ContinuationToken = events.Last().Date.ToBinary().ToString(); + } + result.Data.AddRange(events); + return result; + } + } + + private DataTable BuildEventsTable(SqlBulkCopy bulkCopy, IEnumerable events) { - throw new ApplicationException("Must have some events to bulk import."); + var e = events.FirstOrDefault(); + if (e == null) + { + throw new ApplicationException("Must have some events to bulk import."); + } + + var eventsTable = new DataTable("EventDataTable"); + + var idColumn = new DataColumn(nameof(e.Id), e.Id.GetType()); + eventsTable.Columns.Add(idColumn); + var typeColumn = new DataColumn(nameof(e.Type), typeof(int)); + eventsTable.Columns.Add(typeColumn); + var userIdColumn = new DataColumn(nameof(e.UserId), typeof(Guid)); + eventsTable.Columns.Add(userIdColumn); + var organizationIdColumn = new DataColumn(nameof(e.OrganizationId), typeof(Guid)); + eventsTable.Columns.Add(organizationIdColumn); + var cipherIdColumn = new DataColumn(nameof(e.CipherId), typeof(Guid)); + eventsTable.Columns.Add(cipherIdColumn); + var collectionIdColumn = new DataColumn(nameof(e.CollectionId), typeof(Guid)); + eventsTable.Columns.Add(collectionIdColumn); + var policyIdColumn = new DataColumn(nameof(e.PolicyId), typeof(Guid)); + eventsTable.Columns.Add(policyIdColumn); + var groupIdColumn = new DataColumn(nameof(e.GroupId), typeof(Guid)); + eventsTable.Columns.Add(groupIdColumn); + var organizationUserIdColumn = new DataColumn(nameof(e.OrganizationUserId), typeof(Guid)); + eventsTable.Columns.Add(organizationUserIdColumn); + var actingUserIdColumn = new DataColumn(nameof(e.ActingUserId), typeof(Guid)); + eventsTable.Columns.Add(actingUserIdColumn); + var deviceTypeColumn = new DataColumn(nameof(e.DeviceType), typeof(int)); + eventsTable.Columns.Add(deviceTypeColumn); + var ipAddressColumn = new DataColumn(nameof(e.IpAddress), typeof(string)); + eventsTable.Columns.Add(ipAddressColumn); + var dateColumn = new DataColumn(nameof(e.Date), typeof(DateTime)); + eventsTable.Columns.Add(dateColumn); + + foreach (DataColumn col in eventsTable.Columns) + { + bulkCopy.ColumnMappings.Add(col.ColumnName, col.ColumnName); + } + + var keys = new DataColumn[1]; + keys[0] = idColumn; + eventsTable.PrimaryKey = keys; + + foreach (var ev in events) + { + ev.SetNewId(); + + var row = eventsTable.NewRow(); + + row[idColumn] = ev.Id; + row[typeColumn] = (int)ev.Type; + row[userIdColumn] = ev.UserId.HasValue ? (object)ev.UserId.Value : DBNull.Value; + row[organizationIdColumn] = ev.OrganizationId.HasValue ? (object)ev.OrganizationId.Value : DBNull.Value; + row[cipherIdColumn] = ev.CipherId.HasValue ? (object)ev.CipherId.Value : DBNull.Value; + row[collectionIdColumn] = ev.CollectionId.HasValue ? (object)ev.CollectionId.Value : DBNull.Value; + row[policyIdColumn] = ev.PolicyId.HasValue ? (object)ev.PolicyId.Value : DBNull.Value; + row[groupIdColumn] = ev.GroupId.HasValue ? (object)ev.GroupId.Value : DBNull.Value; + row[organizationUserIdColumn] = ev.OrganizationUserId.HasValue ? + (object)ev.OrganizationUserId.Value : DBNull.Value; + row[actingUserIdColumn] = ev.ActingUserId.HasValue ? (object)ev.ActingUserId.Value : DBNull.Value; + row[deviceTypeColumn] = ev.DeviceType.HasValue ? (object)ev.DeviceType.Value : DBNull.Value; + row[ipAddressColumn] = ev.IpAddress != null ? (object)ev.IpAddress : DBNull.Value; + row[dateColumn] = ev.Date; + + eventsTable.Rows.Add(row); + } + + return eventsTable; } - - var eventsTable = new DataTable("EventDataTable"); - - var idColumn = new DataColumn(nameof(e.Id), e.Id.GetType()); - eventsTable.Columns.Add(idColumn); - var typeColumn = new DataColumn(nameof(e.Type), typeof(int)); - eventsTable.Columns.Add(typeColumn); - var userIdColumn = new DataColumn(nameof(e.UserId), typeof(Guid)); - eventsTable.Columns.Add(userIdColumn); - var organizationIdColumn = new DataColumn(nameof(e.OrganizationId), typeof(Guid)); - eventsTable.Columns.Add(organizationIdColumn); - var cipherIdColumn = new DataColumn(nameof(e.CipherId), typeof(Guid)); - eventsTable.Columns.Add(cipherIdColumn); - var collectionIdColumn = new DataColumn(nameof(e.CollectionId), typeof(Guid)); - eventsTable.Columns.Add(collectionIdColumn); - var policyIdColumn = new DataColumn(nameof(e.PolicyId), typeof(Guid)); - eventsTable.Columns.Add(policyIdColumn); - var groupIdColumn = new DataColumn(nameof(e.GroupId), typeof(Guid)); - eventsTable.Columns.Add(groupIdColumn); - var organizationUserIdColumn = new DataColumn(nameof(e.OrganizationUserId), typeof(Guid)); - eventsTable.Columns.Add(organizationUserIdColumn); - var actingUserIdColumn = new DataColumn(nameof(e.ActingUserId), typeof(Guid)); - eventsTable.Columns.Add(actingUserIdColumn); - var deviceTypeColumn = new DataColumn(nameof(e.DeviceType), typeof(int)); - eventsTable.Columns.Add(deviceTypeColumn); - var ipAddressColumn = new DataColumn(nameof(e.IpAddress), typeof(string)); - eventsTable.Columns.Add(ipAddressColumn); - var dateColumn = new DataColumn(nameof(e.Date), typeof(DateTime)); - eventsTable.Columns.Add(dateColumn); - - foreach (DataColumn col in eventsTable.Columns) - { - bulkCopy.ColumnMappings.Add(col.ColumnName, col.ColumnName); - } - - var keys = new DataColumn[1]; - keys[0] = idColumn; - eventsTable.PrimaryKey = keys; - - foreach (var ev in events) - { - ev.SetNewId(); - - var row = eventsTable.NewRow(); - - row[idColumn] = ev.Id; - row[typeColumn] = (int)ev.Type; - row[userIdColumn] = ev.UserId.HasValue ? (object)ev.UserId.Value : DBNull.Value; - row[organizationIdColumn] = ev.OrganizationId.HasValue ? (object)ev.OrganizationId.Value : DBNull.Value; - row[cipherIdColumn] = ev.CipherId.HasValue ? (object)ev.CipherId.Value : DBNull.Value; - row[collectionIdColumn] = ev.CollectionId.HasValue ? (object)ev.CollectionId.Value : DBNull.Value; - row[policyIdColumn] = ev.PolicyId.HasValue ? (object)ev.PolicyId.Value : DBNull.Value; - row[groupIdColumn] = ev.GroupId.HasValue ? (object)ev.GroupId.Value : DBNull.Value; - row[organizationUserIdColumn] = ev.OrganizationUserId.HasValue ? - (object)ev.OrganizationUserId.Value : DBNull.Value; - row[actingUserIdColumn] = ev.ActingUserId.HasValue ? (object)ev.ActingUserId.Value : DBNull.Value; - row[deviceTypeColumn] = ev.DeviceType.HasValue ? (object)ev.DeviceType.Value : DBNull.Value; - row[ipAddressColumn] = ev.IpAddress != null ? (object)ev.IpAddress : DBNull.Value; - row[dateColumn] = ev.Date; - - eventsTable.Rows.Add(row); - } - - return eventsTable; } } diff --git a/src/Infrastructure.Dapper/Repositories/FolderRepository.cs b/src/Infrastructure.Dapper/Repositories/FolderRepository.cs index 6500d35dd8..a0bd11c9d8 100644 --- a/src/Infrastructure.Dapper/Repositories/FolderRepository.cs +++ b/src/Infrastructure.Dapper/Repositories/FolderRepository.cs @@ -5,39 +5,40 @@ using Bit.Core.Repositories; using Bit.Core.Settings; using Dapper; -namespace Bit.Infrastructure.Dapper.Repositories; - -public class FolderRepository : Repository, IFolderRepository +namespace Bit.Infrastructure.Dapper.Repositories { - public FolderRepository(GlobalSettings globalSettings) - : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) - { } - - public FolderRepository(string connectionString, string readOnlyConnectionString) - : base(connectionString, readOnlyConnectionString) - { } - - public async Task GetByIdAsync(Guid id, Guid userId) + public class FolderRepository : Repository, IFolderRepository { - var folder = await GetByIdAsync(id); - if (folder == null || folder.UserId != userId) + public FolderRepository(GlobalSettings globalSettings) + : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) + { } + + public FolderRepository(string connectionString, string readOnlyConnectionString) + : base(connectionString, readOnlyConnectionString) + { } + + public async Task GetByIdAsync(Guid id, Guid userId) { - return null; + var folder = await GetByIdAsync(id); + if (folder == null || folder.UserId != userId) + { + return null; + } + + return folder; } - return folder; - } - - public async Task> GetManyByUserIdAsync(Guid userId) - { - using (var connection = new SqlConnection(ConnectionString)) + public async Task> GetManyByUserIdAsync(Guid userId) { - var results = await connection.QueryAsync( - $"[{Schema}].[Folder_ReadByUserId]", - new { UserId = userId }, - commandType: CommandType.StoredProcedure); + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryAsync( + $"[{Schema}].[Folder_ReadByUserId]", + new { UserId = userId }, + commandType: CommandType.StoredProcedure); - return results.ToList(); + return results.ToList(); + } } } } diff --git a/src/Infrastructure.Dapper/Repositories/GrantRepository.cs b/src/Infrastructure.Dapper/Repositories/GrantRepository.cs index 168576fa9b..6596fa5104 100644 --- a/src/Infrastructure.Dapper/Repositories/GrantRepository.cs +++ b/src/Infrastructure.Dapper/Repositories/GrantRepository.cs @@ -5,75 +5,76 @@ using Bit.Core.Repositories; using Bit.Core.Settings; using Dapper; -namespace Bit.Infrastructure.Dapper.Repositories; - -public class GrantRepository : BaseRepository, IGrantRepository +namespace Bit.Infrastructure.Dapper.Repositories { - public GrantRepository(GlobalSettings globalSettings) - : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) - { } - - public GrantRepository(string connectionString, string readOnlyConnectionString) - : base(connectionString, readOnlyConnectionString) - { } - - public async Task GetByKeyAsync(string key) + public class GrantRepository : BaseRepository, IGrantRepository { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - "[dbo].[Grant_ReadByKey]", - new { Key = key }, - commandType: CommandType.StoredProcedure); + public GrantRepository(GlobalSettings globalSettings) + : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) + { } - return results.SingleOrDefault(); + public GrantRepository(string connectionString, string readOnlyConnectionString) + : base(connectionString, readOnlyConnectionString) + { } + + public async Task GetByKeyAsync(string key) + { + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryAsync( + "[dbo].[Grant_ReadByKey]", + new { Key = key }, + commandType: CommandType.StoredProcedure); + + return results.SingleOrDefault(); + } } - } - public async Task> GetManyAsync(string subjectId, string sessionId, - string clientId, string type) - { - using (var connection = new SqlConnection(ConnectionString)) + public async Task> GetManyAsync(string subjectId, string sessionId, + string clientId, string type) { - var results = await connection.QueryAsync( - "[dbo].[Grant_Read]", - new { SubjectId = subjectId, SessionId = sessionId, ClientId = clientId, Type = type }, - commandType: CommandType.StoredProcedure); + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryAsync( + "[dbo].[Grant_Read]", + new { SubjectId = subjectId, SessionId = sessionId, ClientId = clientId, Type = type }, + commandType: CommandType.StoredProcedure); - return results.ToList(); + return results.ToList(); + } } - } - public async Task SaveAsync(Grant obj) - { - using (var connection = new SqlConnection(ConnectionString)) + public async Task SaveAsync(Grant obj) { - var results = await connection.ExecuteAsync( - "[dbo].[Grant_Save]", - obj, - commandType: CommandType.StoredProcedure); + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.ExecuteAsync( + "[dbo].[Grant_Save]", + obj, + commandType: CommandType.StoredProcedure); + } } - } - public async Task DeleteByKeyAsync(string key) - { - using (var connection = new SqlConnection(ConnectionString)) + public async Task DeleteByKeyAsync(string key) { - await connection.ExecuteAsync( - "[dbo].[Grant_DeleteByKey]", - new { Key = key }, - commandType: CommandType.StoredProcedure); + using (var connection = new SqlConnection(ConnectionString)) + { + await connection.ExecuteAsync( + "[dbo].[Grant_DeleteByKey]", + new { Key = key }, + commandType: CommandType.StoredProcedure); + } } - } - public async Task DeleteManyAsync(string subjectId, string sessionId, string clientId, string type) - { - using (var connection = new SqlConnection(ConnectionString)) + public async Task DeleteManyAsync(string subjectId, string sessionId, string clientId, string type) { - await connection.ExecuteAsync( - "[dbo].[Grant_Delete]", - new { SubjectId = subjectId, SessionId = sessionId, ClientId = clientId, Type = type }, - commandType: CommandType.StoredProcedure); + using (var connection = new SqlConnection(ConnectionString)) + { + await connection.ExecuteAsync( + "[dbo].[Grant_Delete]", + new { SubjectId = subjectId, SessionId = sessionId, ClientId = clientId, Type = type }, + commandType: CommandType.StoredProcedure); + } } } } diff --git a/src/Infrastructure.Dapper/Repositories/GroupRepository.cs b/src/Infrastructure.Dapper/Repositories/GroupRepository.cs index eb0482bf3b..31f6a29bb4 100644 --- a/src/Infrastructure.Dapper/Repositories/GroupRepository.cs +++ b/src/Infrastructure.Dapper/Repositories/GroupRepository.cs @@ -7,134 +7,135 @@ using Bit.Core.Repositories; using Bit.Core.Settings; using Dapper; -namespace Bit.Infrastructure.Dapper.Repositories; - -public class GroupRepository : Repository, IGroupRepository +namespace Bit.Infrastructure.Dapper.Repositories { - public GroupRepository(GlobalSettings globalSettings) - : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) - { } - - public GroupRepository(string connectionString, string readOnlyConnectionString) - : base(connectionString, readOnlyConnectionString) - { } - - public async Task>> GetByIdWithCollectionsAsync(Guid id) + public class GroupRepository : Repository, IGroupRepository { - using (var connection = new SqlConnection(ConnectionString)) + public GroupRepository(GlobalSettings globalSettings) + : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) + { } + + public GroupRepository(string connectionString, string readOnlyConnectionString) + : base(connectionString, readOnlyConnectionString) + { } + + public async Task>> GetByIdWithCollectionsAsync(Guid id) { - var results = await connection.QueryMultipleAsync( - $"[{Schema}].[Group_ReadWithCollectionsById]", - new { Id = id }, - commandType: CommandType.StoredProcedure); + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryMultipleAsync( + $"[{Schema}].[Group_ReadWithCollectionsById]", + new { Id = id }, + commandType: CommandType.StoredProcedure); - var group = await results.ReadFirstOrDefaultAsync(); - var colletions = (await results.ReadAsync()).ToList(); + var group = await results.ReadFirstOrDefaultAsync(); + var colletions = (await results.ReadAsync()).ToList(); - return new Tuple>(group, colletions); + return new Tuple>(group, colletions); + } } - } - public async Task> GetManyByOrganizationIdAsync(Guid organizationId) - { - using (var connection = new SqlConnection(ConnectionString)) + public async Task> GetManyByOrganizationIdAsync(Guid organizationId) { - var results = await connection.QueryAsync( - $"[{Schema}].[Group_ReadByOrganizationId]", - new { OrganizationId = organizationId }, - commandType: CommandType.StoredProcedure); + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryAsync( + $"[{Schema}].[Group_ReadByOrganizationId]", + new { OrganizationId = organizationId }, + commandType: CommandType.StoredProcedure); - return results.ToList(); + return results.ToList(); + } } - } - public async Task> GetManyIdsByUserIdAsync(Guid organizationUserId) - { - using (var connection = new SqlConnection(ConnectionString)) + public async Task> GetManyIdsByUserIdAsync(Guid organizationUserId) { - var results = await connection.QueryAsync( - $"[{Schema}].[GroupUser_ReadGroupIdsByOrganizationUserId]", - new { OrganizationUserId = organizationUserId }, - commandType: CommandType.StoredProcedure); + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryAsync( + $"[{Schema}].[GroupUser_ReadGroupIdsByOrganizationUserId]", + new { OrganizationUserId = organizationUserId }, + commandType: CommandType.StoredProcedure); - return results.ToList(); + return results.ToList(); + } } - } - public async Task> GetManyUserIdsByIdAsync(Guid id) - { - using (var connection = new SqlConnection(ConnectionString)) + public async Task> GetManyUserIdsByIdAsync(Guid id) { - var results = await connection.QueryAsync( - $"[{Schema}].[GroupUser_ReadOrganizationUserIdsByGroupId]", - new { GroupId = id }, - commandType: CommandType.StoredProcedure); + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryAsync( + $"[{Schema}].[GroupUser_ReadOrganizationUserIdsByGroupId]", + new { GroupId = id }, + commandType: CommandType.StoredProcedure); - return results.ToList(); + return results.ToList(); + } } - } - public async Task> GetManyGroupUsersByOrganizationIdAsync(Guid organizationId) - { - using (var connection = new SqlConnection(ConnectionString)) + public async Task> GetManyGroupUsersByOrganizationIdAsync(Guid organizationId) { - var results = await connection.QueryAsync( - $"[{Schema}].[GroupUser_ReadByOrganizationId]", - new { OrganizationId = organizationId }, - commandType: CommandType.StoredProcedure); + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryAsync( + $"[{Schema}].[GroupUser_ReadByOrganizationId]", + new { OrganizationId = organizationId }, + commandType: CommandType.StoredProcedure); - return results.ToList(); + return results.ToList(); + } } - } - public async Task CreateAsync(Group obj, IEnumerable collections) - { - obj.SetNewId(); - var objWithCollections = JsonSerializer.Deserialize(JsonSerializer.Serialize(obj)); - objWithCollections.Collections = collections.ToArrayTVP(); - - using (var connection = new SqlConnection(ConnectionString)) + public async Task CreateAsync(Group obj, IEnumerable collections) { - var results = await connection.ExecuteAsync( - $"[{Schema}].[Group_CreateWithCollections]", - objWithCollections, - commandType: CommandType.StoredProcedure); + obj.SetNewId(); + var objWithCollections = JsonSerializer.Deserialize(JsonSerializer.Serialize(obj)); + objWithCollections.Collections = collections.ToArrayTVP(); + + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.ExecuteAsync( + $"[{Schema}].[Group_CreateWithCollections]", + objWithCollections, + commandType: CommandType.StoredProcedure); + } } - } - public async Task ReplaceAsync(Group obj, IEnumerable collections) - { - var objWithCollections = JsonSerializer.Deserialize(JsonSerializer.Serialize(obj)); - objWithCollections.Collections = collections.ToArrayTVP(); - - using (var connection = new SqlConnection(ConnectionString)) + public async Task ReplaceAsync(Group obj, IEnumerable collections) { - var results = await connection.ExecuteAsync( - $"[{Schema}].[Group_UpdateWithCollections]", - objWithCollections, - commandType: CommandType.StoredProcedure); + var objWithCollections = JsonSerializer.Deserialize(JsonSerializer.Serialize(obj)); + objWithCollections.Collections = collections.ToArrayTVP(); + + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.ExecuteAsync( + $"[{Schema}].[Group_UpdateWithCollections]", + objWithCollections, + commandType: CommandType.StoredProcedure); + } } - } - public async Task DeleteUserAsync(Guid groupId, Guid organizationUserId) - { - using (var connection = new SqlConnection(ConnectionString)) + public async Task DeleteUserAsync(Guid groupId, Guid organizationUserId) { - var results = await connection.ExecuteAsync( - $"[{Schema}].[GroupUser_Delete]", - new { GroupId = groupId, OrganizationUserId = organizationUserId }, - commandType: CommandType.StoredProcedure); + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.ExecuteAsync( + $"[{Schema}].[GroupUser_Delete]", + new { GroupId = groupId, OrganizationUserId = organizationUserId }, + commandType: CommandType.StoredProcedure); + } } - } - public async Task UpdateUsersAsync(Guid groupId, IEnumerable organizationUserIds) - { - using (var connection = new SqlConnection(ConnectionString)) + public async Task UpdateUsersAsync(Guid groupId, IEnumerable organizationUserIds) { - var results = await connection.ExecuteAsync( - "[dbo].[GroupUser_UpdateUsers]", - new { GroupId = groupId, OrganizationUserIds = organizationUserIds.ToGuidIdArrayTVP() }, - commandType: CommandType.StoredProcedure); + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.ExecuteAsync( + "[dbo].[GroupUser_UpdateUsers]", + new { GroupId = groupId, OrganizationUserIds = organizationUserIds.ToGuidIdArrayTVP() }, + commandType: CommandType.StoredProcedure); + } } } } diff --git a/src/Infrastructure.Dapper/Repositories/InstallationRepository.cs b/src/Infrastructure.Dapper/Repositories/InstallationRepository.cs index 0bb38761cf..b82b13c497 100644 --- a/src/Infrastructure.Dapper/Repositories/InstallationRepository.cs +++ b/src/Infrastructure.Dapper/Repositories/InstallationRepository.cs @@ -2,15 +2,16 @@ using Bit.Core.Repositories; using Bit.Core.Settings; -namespace Bit.Infrastructure.Dapper.Repositories; - -public class InstallationRepository : Repository, IInstallationRepository +namespace Bit.Infrastructure.Dapper.Repositories { - public InstallationRepository(GlobalSettings globalSettings) - : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) - { } + public class InstallationRepository : Repository, IInstallationRepository + { + public InstallationRepository(GlobalSettings globalSettings) + : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) + { } - public InstallationRepository(string connectionString, string readOnlyConnectionString) - : base(connectionString, readOnlyConnectionString) - { } + public InstallationRepository(string connectionString, string readOnlyConnectionString) + : base(connectionString, readOnlyConnectionString) + { } + } } diff --git a/src/Infrastructure.Dapper/Repositories/MaintenanceRepository.cs b/src/Infrastructure.Dapper/Repositories/MaintenanceRepository.cs index fb5bf30918..05c6e9d634 100644 --- a/src/Infrastructure.Dapper/Repositories/MaintenanceRepository.cs +++ b/src/Infrastructure.Dapper/Repositories/MaintenanceRepository.cs @@ -4,73 +4,74 @@ using Bit.Core.Repositories; using Bit.Core.Settings; using Dapper; -namespace Bit.Infrastructure.Dapper.Repositories; - -public class MaintenanceRepository : BaseRepository, IMaintenanceRepository +namespace Bit.Infrastructure.Dapper.Repositories { - public MaintenanceRepository(GlobalSettings globalSettings) - : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) - { } - - public MaintenanceRepository(string connectionString, string readOnlyConnectionString) - : base(connectionString, readOnlyConnectionString) - { } - - public async Task UpdateStatisticsAsync() + public class MaintenanceRepository : BaseRepository, IMaintenanceRepository { - using (var connection = new SqlConnection(ConnectionString)) + public MaintenanceRepository(GlobalSettings globalSettings) + : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) + { } + + public MaintenanceRepository(string connectionString, string readOnlyConnectionString) + : base(connectionString, readOnlyConnectionString) + { } + + public async Task UpdateStatisticsAsync() { - await connection.ExecuteAsync( - "[dbo].[AzureSQLMaintenance]", - new { operation = "statistics", mode = "smart", LogToTable = true }, - commandType: CommandType.StoredProcedure, - commandTimeout: 172800); + using (var connection = new SqlConnection(ConnectionString)) + { + await connection.ExecuteAsync( + "[dbo].[AzureSQLMaintenance]", + new { operation = "statistics", mode = "smart", LogToTable = true }, + commandType: CommandType.StoredProcedure, + commandTimeout: 172800); + } } - } - public async Task DisableCipherAutoStatsAsync() - { - using (var connection = new SqlConnection(ConnectionString)) + public async Task DisableCipherAutoStatsAsync() { - await connection.ExecuteAsync( - "sp_autostats", - new { tblname = "[dbo].[Cipher]", flagc = "OFF" }, - commandType: CommandType.StoredProcedure); + using (var connection = new SqlConnection(ConnectionString)) + { + await connection.ExecuteAsync( + "sp_autostats", + new { tblname = "[dbo].[Cipher]", flagc = "OFF" }, + commandType: CommandType.StoredProcedure); + } } - } - public async Task RebuildIndexesAsync() - { - using (var connection = new SqlConnection(ConnectionString)) + public async Task RebuildIndexesAsync() { - await connection.ExecuteAsync( - "[dbo].[AzureSQLMaintenance]", - new { operation = "index", mode = "smart", LogToTable = true }, - commandType: CommandType.StoredProcedure, - commandTimeout: 172800); + using (var connection = new SqlConnection(ConnectionString)) + { + await connection.ExecuteAsync( + "[dbo].[AzureSQLMaintenance]", + new { operation = "index", mode = "smart", LogToTable = true }, + commandType: CommandType.StoredProcedure, + commandTimeout: 172800); + } } - } - public async Task DeleteExpiredGrantsAsync() - { - using (var connection = new SqlConnection(ConnectionString)) + public async Task DeleteExpiredGrantsAsync() { - await connection.ExecuteAsync( - "[dbo].[Grant_DeleteExpired]", - commandType: CommandType.StoredProcedure, - commandTimeout: 172800); + using (var connection = new SqlConnection(ConnectionString)) + { + await connection.ExecuteAsync( + "[dbo].[Grant_DeleteExpired]", + commandType: CommandType.StoredProcedure, + commandTimeout: 172800); + } } - } - public async Task DeleteExpiredSponsorshipsAsync(DateTime validUntilBeforeDate) - { - using (var connection = new SqlConnection(ConnectionString)) + public async Task DeleteExpiredSponsorshipsAsync(DateTime validUntilBeforeDate) { - await connection.ExecuteAsync( - "[dbo].[OrganizationSponsorship_DeleteExpired]", - new { ValidUntilBeforeDate = validUntilBeforeDate }, - commandType: CommandType.StoredProcedure, - commandTimeout: 172800); + using (var connection = new SqlConnection(ConnectionString)) + { + await connection.ExecuteAsync( + "[dbo].[OrganizationSponsorship_DeleteExpired]", + new { ValidUntilBeforeDate = validUntilBeforeDate }, + commandType: CommandType.StoredProcedure, + commandTimeout: 172800); + } } } } diff --git a/src/Infrastructure.Dapper/Repositories/OrganizationApiKeyRepository.cs b/src/Infrastructure.Dapper/Repositories/OrganizationApiKeyRepository.cs index 05eaac68fb..b0694862ff 100644 --- a/src/Infrastructure.Dapper/Repositories/OrganizationApiKeyRepository.cs +++ b/src/Infrastructure.Dapper/Repositories/OrganizationApiKeyRepository.cs @@ -6,32 +6,33 @@ using Bit.Core.Repositories; using Bit.Core.Settings; using Dapper; -namespace Bit.Infrastructure.Dapper.Repositories; - -public class OrganizationApiKeyRepository : Repository, IOrganizationApiKeyRepository +namespace Bit.Infrastructure.Dapper.Repositories { - public OrganizationApiKeyRepository(GlobalSettings globalSettings) - : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) + public class OrganizationApiKeyRepository : Repository, IOrganizationApiKeyRepository { - - } - - public OrganizationApiKeyRepository(string connectionString, string readOnlyConnectionString) - : base(connectionString, readOnlyConnectionString) - { } - - public async Task> GetManyByOrganizationIdTypeAsync(Guid organizationId, OrganizationApiKeyType? type = null) - { - using (var connection = new SqlConnection(ConnectionString)) + public OrganizationApiKeyRepository(GlobalSettings globalSettings) + : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) { - return await connection.QueryAsync( - "[dbo].[OrganizationApikey_ReadManyByOrganizationIdType]", - new - { - OrganizationId = organizationId, - Type = type, - }, - commandType: CommandType.StoredProcedure); + + } + + public OrganizationApiKeyRepository(string connectionString, string readOnlyConnectionString) + : base(connectionString, readOnlyConnectionString) + { } + + public async Task> GetManyByOrganizationIdTypeAsync(Guid organizationId, OrganizationApiKeyType? type = null) + { + using (var connection = new SqlConnection(ConnectionString)) + { + return await connection.QueryAsync( + "[dbo].[OrganizationApikey_ReadManyByOrganizationIdType]", + new + { + OrganizationId = organizationId, + Type = type, + }, + commandType: CommandType.StoredProcedure); + } } } } diff --git a/src/Infrastructure.Dapper/Repositories/OrganizationConnectionRepository.cs b/src/Infrastructure.Dapper/Repositories/OrganizationConnectionRepository.cs index 1cc9975889..6de4559fad 100644 --- a/src/Infrastructure.Dapper/Repositories/OrganizationConnectionRepository.cs +++ b/src/Infrastructure.Dapper/Repositories/OrganizationConnectionRepository.cs @@ -6,31 +6,32 @@ using Bit.Core.Repositories; using Bit.Core.Settings; using Dapper; -namespace Bit.Infrastructure.Dapper.Repositories; - -public class OrganizationConnectionRepository : Repository, IOrganizationConnectionRepository +namespace Bit.Infrastructure.Dapper.Repositories { - public OrganizationConnectionRepository(GlobalSettings globalSettings) - : base(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) - { } - - public async Task> GetByOrganizationIdTypeAsync(Guid organizationId, OrganizationConnectionType type) + public class OrganizationConnectionRepository : Repository, IOrganizationConnectionRepository { - using (var connection = new SqlConnection(ConnectionString)) + public OrganizationConnectionRepository(GlobalSettings globalSettings) + : base(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) + { } + + public async Task> GetByOrganizationIdTypeAsync(Guid organizationId, OrganizationConnectionType type) { - var results = await connection.QueryAsync( - $"[{Schema}].[OrganizationConnection_ReadByOrganizationIdType]", - new - { - OrganizationId = organizationId, - Type = type - }, - commandType: CommandType.StoredProcedure); + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryAsync( + $"[{Schema}].[OrganizationConnection_ReadByOrganizationIdType]", + new + { + OrganizationId = organizationId, + Type = type + }, + commandType: CommandType.StoredProcedure); - return results.ToList(); + return results.ToList(); + } } - } - public async Task> GetEnabledByOrganizationIdTypeAsync(Guid organizationId, OrganizationConnectionType type) => - (await GetByOrganizationIdTypeAsync(organizationId, type)).Where(c => c.Enabled).ToList(); + public async Task> GetEnabledByOrganizationIdTypeAsync(Guid organizationId, OrganizationConnectionType type) => + (await GetByOrganizationIdTypeAsync(organizationId, type)).Where(c => c.Enabled).ToList(); + } } diff --git a/src/Infrastructure.Dapper/Repositories/OrganizationRepository.cs b/src/Infrastructure.Dapper/Repositories/OrganizationRepository.cs index a4e294b292..05cc3c92ad 100644 --- a/src/Infrastructure.Dapper/Repositories/OrganizationRepository.cs +++ b/src/Infrastructure.Dapper/Repositories/OrganizationRepository.cs @@ -6,92 +6,93 @@ using Bit.Core.Repositories; using Bit.Core.Settings; using Dapper; -namespace Bit.Infrastructure.Dapper.Repositories; - -public class OrganizationRepository : Repository, IOrganizationRepository +namespace Bit.Infrastructure.Dapper.Repositories { - public OrganizationRepository(GlobalSettings globalSettings) - : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) - { } - - public OrganizationRepository(string connectionString, string readOnlyConnectionString) - : base(connectionString, readOnlyConnectionString) - { } - - public async Task GetByIdentifierAsync(string identifier) + public class OrganizationRepository : Repository, IOrganizationRepository { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - "[dbo].[Organization_ReadByIdentifier]", - new { Identifier = identifier }, - commandType: CommandType.StoredProcedure); + public OrganizationRepository(GlobalSettings globalSettings) + : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) + { } - return results.SingleOrDefault(); + public OrganizationRepository(string connectionString, string readOnlyConnectionString) + : base(connectionString, readOnlyConnectionString) + { } + + public async Task GetByIdentifierAsync(string identifier) + { + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryAsync( + "[dbo].[Organization_ReadByIdentifier]", + new { Identifier = identifier }, + commandType: CommandType.StoredProcedure); + + return results.SingleOrDefault(); + } } - } - public async Task> GetManyByEnabledAsync() - { - using (var connection = new SqlConnection(ConnectionString)) + public async Task> GetManyByEnabledAsync() { - var results = await connection.QueryAsync( - "[dbo].[Organization_ReadByEnabled]", - commandType: CommandType.StoredProcedure); + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryAsync( + "[dbo].[Organization_ReadByEnabled]", + commandType: CommandType.StoredProcedure); - return results.ToList(); + return results.ToList(); + } } - } - public async Task> GetManyByUserIdAsync(Guid userId) - { - using (var connection = new SqlConnection(ConnectionString)) + public async Task> GetManyByUserIdAsync(Guid userId) { - var results = await connection.QueryAsync( - "[dbo].[Organization_ReadByUserId]", - new { UserId = userId }, - commandType: CommandType.StoredProcedure); + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryAsync( + "[dbo].[Organization_ReadByUserId]", + new { UserId = userId }, + commandType: CommandType.StoredProcedure); - return results.ToList(); + return results.ToList(); + } } - } - public async Task> SearchAsync(string name, string userEmail, bool? paid, - int skip, int take) - { - using (var connection = new SqlConnection(ReadOnlyConnectionString)) + public async Task> SearchAsync(string name, string userEmail, bool? paid, + int skip, int take) { - var results = await connection.QueryAsync( - "[dbo].[Organization_Search]", - new { Name = name, UserEmail = userEmail, Paid = paid, Skip = skip, Take = take }, - commandType: CommandType.StoredProcedure, - commandTimeout: 120); + using (var connection = new SqlConnection(ReadOnlyConnectionString)) + { + var results = await connection.QueryAsync( + "[dbo].[Organization_Search]", + new { Name = name, UserEmail = userEmail, Paid = paid, Skip = skip, Take = take }, + commandType: CommandType.StoredProcedure, + commandTimeout: 120); - return results.ToList(); + return results.ToList(); + } } - } - public async Task UpdateStorageAsync(Guid id) - { - using (var connection = new SqlConnection(ConnectionString)) + public async Task UpdateStorageAsync(Guid id) { - await connection.ExecuteAsync( - "[dbo].[Organization_UpdateStorage]", - new { Id = id }, - commandType: CommandType.StoredProcedure, - commandTimeout: 180); + using (var connection = new SqlConnection(ConnectionString)) + { + await connection.ExecuteAsync( + "[dbo].[Organization_UpdateStorage]", + new { Id = id }, + commandType: CommandType.StoredProcedure, + commandTimeout: 180); + } } - } - public async Task> GetManyAbilitiesAsync() - { - using (var connection = new SqlConnection(ConnectionString)) + public async Task> GetManyAbilitiesAsync() { - var results = await connection.QueryAsync( - "[dbo].[Organization_ReadAbilities]", - commandType: CommandType.StoredProcedure); + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryAsync( + "[dbo].[Organization_ReadAbilities]", + commandType: CommandType.StoredProcedure); - return results.ToList(); + return results.ToList(); + } } } } diff --git a/src/Infrastructure.Dapper/Repositories/OrganizationSponsorshipRepository.cs b/src/Infrastructure.Dapper/Repositories/OrganizationSponsorshipRepository.cs index 11e453cacc..6e4ca9904f 100644 --- a/src/Infrastructure.Dapper/Repositories/OrganizationSponsorshipRepository.cs +++ b/src/Infrastructure.Dapper/Repositories/OrganizationSponsorshipRepository.cs @@ -5,142 +5,143 @@ using Bit.Core.Repositories; using Bit.Core.Settings; using Dapper; -namespace Bit.Infrastructure.Dapper.Repositories; - -public class OrganizationSponsorshipRepository : Repository, IOrganizationSponsorshipRepository +namespace Bit.Infrastructure.Dapper.Repositories { - public OrganizationSponsorshipRepository(GlobalSettings globalSettings) - : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) - { } - - public OrganizationSponsorshipRepository(string connectionString, string readOnlyConnectionString) - : base(connectionString, readOnlyConnectionString) - { } - - public async Task> CreateManyAsync(IEnumerable organizationSponsorships) + public class OrganizationSponsorshipRepository : Repository, IOrganizationSponsorshipRepository { - if (!organizationSponsorships.Any()) - { - return default; - } + public OrganizationSponsorshipRepository(GlobalSettings globalSettings) + : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) + { } - foreach (var organizationSponsorship in organizationSponsorships) - { - organizationSponsorship.SetNewId(); - } + public OrganizationSponsorshipRepository(string connectionString, string readOnlyConnectionString) + : base(connectionString, readOnlyConnectionString) + { } - var orgSponsorshipsTVP = organizationSponsorships.ToTvp(); - using (var connection = new SqlConnection(ConnectionString)) + public async Task> CreateManyAsync(IEnumerable organizationSponsorships) { - var results = await connection.ExecuteAsync( - $"[dbo].[OrganizationSponsorship_CreateMany]", - new { OrganizationSponsorshipsInput = orgSponsorshipsTVP }, - commandType: CommandType.StoredProcedure); - } - - return organizationSponsorships.Select(u => u.Id).ToList(); - } - - public async Task ReplaceManyAsync(IEnumerable organizationSponsorships) - { - if (!organizationSponsorships.Any()) - { - return; - } - - var orgSponsorshipsTVP = organizationSponsorships.ToTvp(); - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.ExecuteAsync( - $"[dbo].[OrganizationSponsorship_UpdateMany]", - new { OrganizationSponsorshipsInput = orgSponsorshipsTVP }, - commandType: CommandType.StoredProcedure); - } - } - - public async Task UpsertManyAsync(IEnumerable organizationSponsorships) - { - var createSponsorships = new List(); - var replaceSponsorships = new List(); - foreach (var organizationSponsorship in organizationSponsorships) - { - if (organizationSponsorship.Id.Equals(default)) + if (!organizationSponsorships.Any()) { - createSponsorships.Add(organizationSponsorship); + return default; } - else + + foreach (var organizationSponsorship in organizationSponsorships) { - replaceSponsorships.Add(organizationSponsorship); + organizationSponsorship.SetNewId(); + } + + var orgSponsorshipsTVP = organizationSponsorships.ToTvp(); + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.ExecuteAsync( + $"[dbo].[OrganizationSponsorship_CreateMany]", + new { OrganizationSponsorshipsInput = orgSponsorshipsTVP }, + commandType: CommandType.StoredProcedure); + } + + return organizationSponsorships.Select(u => u.Id).ToList(); + } + + public async Task ReplaceManyAsync(IEnumerable organizationSponsorships) + { + if (!organizationSponsorships.Any()) + { + return; + } + + var orgSponsorshipsTVP = organizationSponsorships.ToTvp(); + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.ExecuteAsync( + $"[dbo].[OrganizationSponsorship_UpdateMany]", + new { OrganizationSponsorshipsInput = orgSponsorshipsTVP }, + commandType: CommandType.StoredProcedure); } } - await CreateManyAsync(createSponsorships); - await ReplaceManyAsync(replaceSponsorships); - } - - public async Task DeleteManyAsync(IEnumerable organizationSponsorshipIds) - { - using (var connection = new SqlConnection(ConnectionString)) + public async Task UpsertManyAsync(IEnumerable organizationSponsorships) { - await connection.ExecuteAsync("[dbo].[OrganizationSponsorship_DeleteByIds]", - new { Ids = organizationSponsorshipIds.ToGuidIdArrayTVP() }, commandType: CommandType.StoredProcedure); - } - } - - public async Task GetBySponsoringOrganizationUserIdAsync(Guid sponsoringOrganizationUserId) - { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - "[dbo].[OrganizationSponsorship_ReadBySponsoringOrganizationUserId]", - new + var createSponsorships = new List(); + var replaceSponsorships = new List(); + foreach (var organizationSponsorship in organizationSponsorships) + { + if (organizationSponsorship.Id.Equals(default)) { - SponsoringOrganizationUserId = sponsoringOrganizationUserId - }, - commandType: CommandType.StoredProcedure); - - return results.SingleOrDefault(); - } - } - - public async Task GetBySponsoredOrganizationIdAsync(Guid sponsoredOrganizationId) - { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - "[dbo].[OrganizationSponsorship_ReadBySponsoredOrganizationId]", - new { SponsoredOrganizationId = sponsoredOrganizationId }, - commandType: CommandType.StoredProcedure); - - return results.SingleOrDefault(); - } - } - - public async Task GetLatestSyncDateBySponsoringOrganizationIdAsync(Guid sponsoringOrganizationId) - { - using (var connection = new SqlConnection(ConnectionString)) - { - return await connection.QuerySingleOrDefaultAsync( - "[dbo].[OrganizationSponsorship_ReadLatestBySponsoringOrganizationId]", - new { SponsoringOrganizationId = sponsoringOrganizationId }, - commandType: CommandType.StoredProcedure); - } - } - - public async Task> GetManyBySponsoringOrganizationAsync(Guid sponsoringOrganizationId) - { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - "[dbo].[OrganizationSponsorship_ReadBySponsoringOrganizationId]", - new + createSponsorships.Add(organizationSponsorship); + } + else { - SponsoringOrganizationId = sponsoringOrganizationId - }, - commandType: CommandType.StoredProcedure); + replaceSponsorships.Add(organizationSponsorship); + } + } - return results.ToList(); + await CreateManyAsync(createSponsorships); + await ReplaceManyAsync(replaceSponsorships); } - } + public async Task DeleteManyAsync(IEnumerable organizationSponsorshipIds) + { + using (var connection = new SqlConnection(ConnectionString)) + { + await connection.ExecuteAsync("[dbo].[OrganizationSponsorship_DeleteByIds]", + new { Ids = organizationSponsorshipIds.ToGuidIdArrayTVP() }, commandType: CommandType.StoredProcedure); + } + } + + public async Task GetBySponsoringOrganizationUserIdAsync(Guid sponsoringOrganizationUserId) + { + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryAsync( + "[dbo].[OrganizationSponsorship_ReadBySponsoringOrganizationUserId]", + new + { + SponsoringOrganizationUserId = sponsoringOrganizationUserId + }, + commandType: CommandType.StoredProcedure); + + return results.SingleOrDefault(); + } + } + + public async Task GetBySponsoredOrganizationIdAsync(Guid sponsoredOrganizationId) + { + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryAsync( + "[dbo].[OrganizationSponsorship_ReadBySponsoredOrganizationId]", + new { SponsoredOrganizationId = sponsoredOrganizationId }, + commandType: CommandType.StoredProcedure); + + return results.SingleOrDefault(); + } + } + + public async Task GetLatestSyncDateBySponsoringOrganizationIdAsync(Guid sponsoringOrganizationId) + { + using (var connection = new SqlConnection(ConnectionString)) + { + return await connection.QuerySingleOrDefaultAsync( + "[dbo].[OrganizationSponsorship_ReadLatestBySponsoringOrganizationId]", + new { SponsoringOrganizationId = sponsoringOrganizationId }, + commandType: CommandType.StoredProcedure); + } + } + + public async Task> GetManyBySponsoringOrganizationAsync(Guid sponsoringOrganizationId) + { + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryAsync( + "[dbo].[OrganizationSponsorship_ReadBySponsoringOrganizationId]", + new + { + SponsoringOrganizationId = sponsoringOrganizationId + }, + commandType: CommandType.StoredProcedure); + + return results.ToList(); + } + } + + } } diff --git a/src/Infrastructure.Dapper/Repositories/OrganizationUserRepository.cs b/src/Infrastructure.Dapper/Repositories/OrganizationUserRepository.cs index 06aede3da8..856fcb7a4f 100644 --- a/src/Infrastructure.Dapper/Repositories/OrganizationUserRepository.cs +++ b/src/Infrastructure.Dapper/Repositories/OrganizationUserRepository.cs @@ -9,422 +9,423 @@ using Bit.Core.Repositories; using Bit.Core.Settings; using Dapper; -namespace Bit.Infrastructure.Dapper.Repositories; - -public class OrganizationUserRepository : Repository, IOrganizationUserRepository +namespace Bit.Infrastructure.Dapper.Repositories { - /// - /// For use with methods with TDS stream issues. - /// This has been observed in Linux-hosted SqlServers with large table-valued-parameters - /// https://github.com/dotnet/SqlClient/issues/54 - /// - private string _marsConnectionString; - - public OrganizationUserRepository(GlobalSettings globalSettings) - : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) + public class OrganizationUserRepository : Repository, IOrganizationUserRepository { - var builder = new SqlConnectionStringBuilder(ConnectionString) + /// + /// For use with methods with TDS stream issues. + /// This has been observed in Linux-hosted SqlServers with large table-valued-parameters + /// https://github.com/dotnet/SqlClient/issues/54 + /// + private string _marsConnectionString; + + public OrganizationUserRepository(GlobalSettings globalSettings) + : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) { - MultipleActiveResultSets = true, - }; - _marsConnectionString = builder.ToString(); - } - - public OrganizationUserRepository(string connectionString, string readOnlyConnectionString) - : base(connectionString, readOnlyConnectionString) - { } - - public async Task GetCountByOrganizationIdAsync(Guid organizationId) - { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.ExecuteScalarAsync( - "[dbo].[OrganizationUser_ReadCountByOrganizationId]", - new { OrganizationId = organizationId }, - commandType: CommandType.StoredProcedure); - - return results; - } - } - - public async Task GetCountByFreeOrganizationAdminUserAsync(Guid userId) - { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.ExecuteScalarAsync( - "[dbo].[OrganizationUser_ReadCountByFreeOrganizationAdminUser]", - new { UserId = userId }, - commandType: CommandType.StoredProcedure); - - return results; - } - } - - public async Task GetCountByOnlyOwnerAsync(Guid userId) - { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.ExecuteScalarAsync( - "[dbo].[OrganizationUser_ReadCountByOnlyOwner]", - new { UserId = userId }, - commandType: CommandType.StoredProcedure); - - return results; - } - } - - public async Task GetCountByOrganizationAsync(Guid organizationId, string email, bool onlyRegisteredUsers) - { - using (var connection = new SqlConnection(ConnectionString)) - { - var result = await connection.ExecuteScalarAsync( - "[dbo].[OrganizationUser_ReadCountByOrganizationIdEmail]", - new { OrganizationId = organizationId, Email = email, OnlyUsers = onlyRegisteredUsers }, - commandType: CommandType.StoredProcedure); - - return result; - } - } - - public async Task> SelectKnownEmailsAsync(Guid organizationId, IEnumerable emails, - bool onlyRegisteredUsers) - { - var emailsTvp = emails.ToArrayTVP("Email"); - using (var connection = new SqlConnection(_marsConnectionString)) - { - var result = await connection.QueryAsync( - "[dbo].[OrganizationUser_SelectKnownEmails]", - new { OrganizationId = organizationId, Emails = emailsTvp, OnlyUsers = onlyRegisteredUsers }, - commandType: CommandType.StoredProcedure); - - // Return as a list to avoid timing out the sql connection - return result.ToList(); - } - } - - public async Task GetByOrganizationAsync(Guid organizationId, Guid userId) - { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - "[dbo].[OrganizationUser_ReadByOrganizationIdUserId]", - new { OrganizationId = organizationId, UserId = userId }, - commandType: CommandType.StoredProcedure); - - return results.SingleOrDefault(); - } - } - - public async Task> GetManyByUserAsync(Guid userId) - { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - "[dbo].[OrganizationUser_ReadByUserId]", - new { UserId = userId }, - commandType: CommandType.StoredProcedure); - - return results.ToList(); - } - } - - public async Task> GetManyByOrganizationAsync(Guid organizationId, - OrganizationUserType? type) - { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - "[dbo].[OrganizationUser_ReadByOrganizationId]", - new { OrganizationId = organizationId, Type = type }, - commandType: CommandType.StoredProcedure); - - return results.ToList(); - } - } - - public async Task>> GetByIdWithCollectionsAsync(Guid id) - { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryMultipleAsync( - "[dbo].[OrganizationUser_ReadWithCollectionsById]", - new { Id = id }, - commandType: CommandType.StoredProcedure); - - var user = (await results.ReadAsync()).SingleOrDefault(); - var collections = (await results.ReadAsync()).ToList(); - return new Tuple>(user, collections); - } - } - - public async Task GetDetailsByIdAsync(Guid id) - { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - "[dbo].[OrganizationUserUserDetails_ReadById]", - new { Id = id }, - commandType: CommandType.StoredProcedure); - - return results.SingleOrDefault(); - } - } - public async Task>> - GetDetailsByIdWithCollectionsAsync(Guid id) - { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryMultipleAsync( - "[dbo].[OrganizationUserUserDetails_ReadWithCollectionsById]", - new { Id = id }, - commandType: CommandType.StoredProcedure); - - var user = (await results.ReadAsync()).SingleOrDefault(); - var collections = (await results.ReadAsync()).ToList(); - return new Tuple>(user, collections); - } - } - - public async Task> GetManyDetailsByOrganizationAsync(Guid organizationId) - { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - "[dbo].[OrganizationUserUserDetails_ReadByOrganizationId]", - new { OrganizationId = organizationId }, - commandType: CommandType.StoredProcedure); - - return results.ToList(); - } - } - - public async Task> GetManyDetailsByUserAsync(Guid userId, - OrganizationUserStatusType? status = null) - { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - "[dbo].[OrganizationUserOrganizationDetails_ReadByUserIdStatus]", - new { UserId = userId, Status = status }, - commandType: CommandType.StoredProcedure); - - return results.ToList(); - } - } - - public async Task GetDetailsByUserAsync(Guid userId, - Guid organizationId, OrganizationUserStatusType? status = null) - { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - "[dbo].[OrganizationUserOrganizationDetails_ReadByUserIdStatusOrganizationId]", - new { UserId = userId, Status = status, OrganizationId = organizationId }, - commandType: CommandType.StoredProcedure); - - return results.SingleOrDefault(); - } - } - - public async Task UpdateGroupsAsync(Guid orgUserId, IEnumerable groupIds) - { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.ExecuteAsync( - "[dbo].[GroupUser_UpdateGroups]", - new { OrganizationUserId = orgUserId, GroupIds = groupIds.ToGuidIdArrayTVP() }, - commandType: CommandType.StoredProcedure); - } - } - - public async Task CreateAsync(OrganizationUser obj, IEnumerable collections) - { - obj.SetNewId(); - var objWithCollections = JsonSerializer.Deserialize( - JsonSerializer.Serialize(obj)); - objWithCollections.Collections = collections.ToArrayTVP(); - - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.ExecuteAsync( - $"[{Schema}].[OrganizationUser_CreateWithCollections]", - objWithCollections, - commandType: CommandType.StoredProcedure); - } - - return obj.Id; - } - - public async Task ReplaceAsync(OrganizationUser obj, IEnumerable collections) - { - var objWithCollections = JsonSerializer.Deserialize( - JsonSerializer.Serialize(obj)); - objWithCollections.Collections = collections.ToArrayTVP(); - - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.ExecuteAsync( - $"[{Schema}].[OrganizationUser_UpdateWithCollections]", - objWithCollections, - commandType: CommandType.StoredProcedure); - } - } - - public async Task> GetManyByManyUsersAsync(IEnumerable userIds) - { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - "[dbo].[OrganizationUser_ReadByUserIds]", - new { UserIds = userIds.ToGuidIdArrayTVP() }, - commandType: CommandType.StoredProcedure); - - return results.ToList(); - } - } - - public async Task> GetManyAsync(IEnumerable Ids) - { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - "[dbo].[OrganizationUser_ReadByIds]", - new { Ids = Ids.ToGuidIdArrayTVP() }, - commandType: CommandType.StoredProcedure); - - return results.ToList(); - } - } - - public async Task GetByOrganizationEmailAsync(Guid organizationId, string email) - { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - "[dbo].[OrganizationUser_ReadByOrganizationIdEmail]", - new { OrganizationId = organizationId, Email = email }, - commandType: CommandType.StoredProcedure); - - return results.SingleOrDefault(); - } - } - - public async Task DeleteManyAsync(IEnumerable organizationUserIds) - { - using (var connection = new SqlConnection(ConnectionString)) - { - await connection.ExecuteAsync("[dbo].[OrganizationUser_DeleteByIds]", - new { Ids = organizationUserIds.ToGuidIdArrayTVP() }, commandType: CommandType.StoredProcedure); - } - } - - public async Task UpsertManyAsync(IEnumerable organizationUsers) - { - var createUsers = new List(); - var replaceUsers = new List(); - foreach (var organizationUser in organizationUsers) - { - if (organizationUser.Id.Equals(default)) + var builder = new SqlConnectionStringBuilder(ConnectionString) { - createUsers.Add(organizationUser); - } - else + MultipleActiveResultSets = true, + }; + _marsConnectionString = builder.ToString(); + } + + public OrganizationUserRepository(string connectionString, string readOnlyConnectionString) + : base(connectionString, readOnlyConnectionString) + { } + + public async Task GetCountByOrganizationIdAsync(Guid organizationId) + { + using (var connection = new SqlConnection(ConnectionString)) { - replaceUsers.Add(organizationUser); + var results = await connection.ExecuteScalarAsync( + "[dbo].[OrganizationUser_ReadCountByOrganizationId]", + new { OrganizationId = organizationId }, + commandType: CommandType.StoredProcedure); + + return results; } } - await CreateManyAsync(createUsers); - await ReplaceManyAsync(replaceUsers); - } - - public async Task> CreateManyAsync(IEnumerable organizationUsers) - { - if (!organizationUsers.Any()) + public async Task GetCountByFreeOrganizationAdminUserAsync(Guid userId) { - return default; + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.ExecuteScalarAsync( + "[dbo].[OrganizationUser_ReadCountByFreeOrganizationAdminUser]", + new { UserId = userId }, + commandType: CommandType.StoredProcedure); + + return results; + } } - foreach (var organizationUser in organizationUsers) + public async Task GetCountByOnlyOwnerAsync(Guid userId) { - organizationUser.SetNewId(); + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.ExecuteScalarAsync( + "[dbo].[OrganizationUser_ReadCountByOnlyOwner]", + new { UserId = userId }, + commandType: CommandType.StoredProcedure); + + return results; + } } - var orgUsersTVP = organizationUsers.ToTvp(); - using (var connection = new SqlConnection(_marsConnectionString)) + public async Task GetCountByOrganizationAsync(Guid organizationId, string email, bool onlyRegisteredUsers) { - var results = await connection.ExecuteAsync( - $"[{Schema}].[{Table}_CreateMany]", - new { OrganizationUsersInput = orgUsersTVP }, - commandType: CommandType.StoredProcedure); + using (var connection = new SqlConnection(ConnectionString)) + { + var result = await connection.ExecuteScalarAsync( + "[dbo].[OrganizationUser_ReadCountByOrganizationIdEmail]", + new { OrganizationId = organizationId, Email = email, OnlyUsers = onlyRegisteredUsers }, + commandType: CommandType.StoredProcedure); + + return result; + } } - return organizationUsers.Select(u => u.Id).ToList(); - } - - public async Task ReplaceManyAsync(IEnumerable organizationUsers) - { - if (!organizationUsers.Any()) + public async Task> SelectKnownEmailsAsync(Guid organizationId, IEnumerable emails, + bool onlyRegisteredUsers) { - return; + var emailsTvp = emails.ToArrayTVP("Email"); + using (var connection = new SqlConnection(_marsConnectionString)) + { + var result = await connection.QueryAsync( + "[dbo].[OrganizationUser_SelectKnownEmails]", + new { OrganizationId = organizationId, Emails = emailsTvp, OnlyUsers = onlyRegisteredUsers }, + commandType: CommandType.StoredProcedure); + + // Return as a list to avoid timing out the sql connection + return result.ToList(); + } } - var orgUsersTVP = organizationUsers.ToTvp(); - using (var connection = new SqlConnection(_marsConnectionString)) + public async Task GetByOrganizationAsync(Guid organizationId, Guid userId) { - var results = await connection.ExecuteAsync( - $"[{Schema}].[{Table}_UpdateMany]", - new { OrganizationUsersInput = orgUsersTVP }, - commandType: CommandType.StoredProcedure); + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryAsync( + "[dbo].[OrganizationUser_ReadByOrganizationIdUserId]", + new { OrganizationId = organizationId, UserId = userId }, + commandType: CommandType.StoredProcedure); + + return results.SingleOrDefault(); + } } - } - public async Task> GetManyPublicKeysByOrganizationUserAsync( - Guid organizationId, IEnumerable Ids) - { - using (var connection = new SqlConnection(ConnectionString)) + public async Task> GetManyByUserAsync(Guid userId) { - var results = await connection.QueryAsync( - "[dbo].[User_ReadPublicKeysByOrganizationUserIds]", - new { OrganizationId = organizationId, OrganizationUserIds = Ids.ToGuidIdArrayTVP() }, - commandType: CommandType.StoredProcedure); + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryAsync( + "[dbo].[OrganizationUser_ReadByUserId]", + new { UserId = userId }, + commandType: CommandType.StoredProcedure); - return results.ToList(); + return results.ToList(); + } } - } - public async Task> GetManyByMinimumRoleAsync(Guid organizationId, OrganizationUserType minRole) - { - using (var connection = new SqlConnection(ConnectionString)) + public async Task> GetManyByOrganizationAsync(Guid organizationId, + OrganizationUserType? type) { - var results = await connection.QueryAsync( - "[dbo].[OrganizationUser_ReadByMinimumRole]", - new { OrganizationId = organizationId, MinRole = minRole }, - commandType: CommandType.StoredProcedure); + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryAsync( + "[dbo].[OrganizationUser_ReadByOrganizationId]", + new { OrganizationId = organizationId, Type = type }, + commandType: CommandType.StoredProcedure); - return results.ToList(); + return results.ToList(); + } } - } - public async Task RevokeAsync(Guid id) - { - using (var connection = new SqlConnection(ConnectionString)) + public async Task>> GetByIdWithCollectionsAsync(Guid id) { - var results = await connection.ExecuteAsync( - $"[{Schema}].[{Table}_Deactivate]", - new { Id = id }, - commandType: CommandType.StoredProcedure); + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryMultipleAsync( + "[dbo].[OrganizationUser_ReadWithCollectionsById]", + new { Id = id }, + commandType: CommandType.StoredProcedure); + + var user = (await results.ReadAsync()).SingleOrDefault(); + var collections = (await results.ReadAsync()).ToList(); + return new Tuple>(user, collections); + } } - } - public async Task RestoreAsync(Guid id, OrganizationUserStatusType status) - { - using (var connection = new SqlConnection(ConnectionString)) + public async Task GetDetailsByIdAsync(Guid id) { - var results = await connection.ExecuteAsync( - $"[{Schema}].[{Table}_Activate]", - new { Id = id, Status = status }, - commandType: CommandType.StoredProcedure); + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryAsync( + "[dbo].[OrganizationUserUserDetails_ReadById]", + new { Id = id }, + commandType: CommandType.StoredProcedure); + + return results.SingleOrDefault(); + } + } + public async Task>> + GetDetailsByIdWithCollectionsAsync(Guid id) + { + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryMultipleAsync( + "[dbo].[OrganizationUserUserDetails_ReadWithCollectionsById]", + new { Id = id }, + commandType: CommandType.StoredProcedure); + + var user = (await results.ReadAsync()).SingleOrDefault(); + var collections = (await results.ReadAsync()).ToList(); + return new Tuple>(user, collections); + } + } + + public async Task> GetManyDetailsByOrganizationAsync(Guid organizationId) + { + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryAsync( + "[dbo].[OrganizationUserUserDetails_ReadByOrganizationId]", + new { OrganizationId = organizationId }, + commandType: CommandType.StoredProcedure); + + return results.ToList(); + } + } + + public async Task> GetManyDetailsByUserAsync(Guid userId, + OrganizationUserStatusType? status = null) + { + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryAsync( + "[dbo].[OrganizationUserOrganizationDetails_ReadByUserIdStatus]", + new { UserId = userId, Status = status }, + commandType: CommandType.StoredProcedure); + + return results.ToList(); + } + } + + public async Task GetDetailsByUserAsync(Guid userId, + Guid organizationId, OrganizationUserStatusType? status = null) + { + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryAsync( + "[dbo].[OrganizationUserOrganizationDetails_ReadByUserIdStatusOrganizationId]", + new { UserId = userId, Status = status, OrganizationId = organizationId }, + commandType: CommandType.StoredProcedure); + + return results.SingleOrDefault(); + } + } + + public async Task UpdateGroupsAsync(Guid orgUserId, IEnumerable groupIds) + { + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.ExecuteAsync( + "[dbo].[GroupUser_UpdateGroups]", + new { OrganizationUserId = orgUserId, GroupIds = groupIds.ToGuidIdArrayTVP() }, + commandType: CommandType.StoredProcedure); + } + } + + public async Task CreateAsync(OrganizationUser obj, IEnumerable collections) + { + obj.SetNewId(); + var objWithCollections = JsonSerializer.Deserialize( + JsonSerializer.Serialize(obj)); + objWithCollections.Collections = collections.ToArrayTVP(); + + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.ExecuteAsync( + $"[{Schema}].[OrganizationUser_CreateWithCollections]", + objWithCollections, + commandType: CommandType.StoredProcedure); + } + + return obj.Id; + } + + public async Task ReplaceAsync(OrganizationUser obj, IEnumerable collections) + { + var objWithCollections = JsonSerializer.Deserialize( + JsonSerializer.Serialize(obj)); + objWithCollections.Collections = collections.ToArrayTVP(); + + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.ExecuteAsync( + $"[{Schema}].[OrganizationUser_UpdateWithCollections]", + objWithCollections, + commandType: CommandType.StoredProcedure); + } + } + + public async Task> GetManyByManyUsersAsync(IEnumerable userIds) + { + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryAsync( + "[dbo].[OrganizationUser_ReadByUserIds]", + new { UserIds = userIds.ToGuidIdArrayTVP() }, + commandType: CommandType.StoredProcedure); + + return results.ToList(); + } + } + + public async Task> GetManyAsync(IEnumerable Ids) + { + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryAsync( + "[dbo].[OrganizationUser_ReadByIds]", + new { Ids = Ids.ToGuidIdArrayTVP() }, + commandType: CommandType.StoredProcedure); + + return results.ToList(); + } + } + + public async Task GetByOrganizationEmailAsync(Guid organizationId, string email) + { + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryAsync( + "[dbo].[OrganizationUser_ReadByOrganizationIdEmail]", + new { OrganizationId = organizationId, Email = email }, + commandType: CommandType.StoredProcedure); + + return results.SingleOrDefault(); + } + } + + public async Task DeleteManyAsync(IEnumerable organizationUserIds) + { + using (var connection = new SqlConnection(ConnectionString)) + { + await connection.ExecuteAsync("[dbo].[OrganizationUser_DeleteByIds]", + new { Ids = organizationUserIds.ToGuidIdArrayTVP() }, commandType: CommandType.StoredProcedure); + } + } + + public async Task UpsertManyAsync(IEnumerable organizationUsers) + { + var createUsers = new List(); + var replaceUsers = new List(); + foreach (var organizationUser in organizationUsers) + { + if (organizationUser.Id.Equals(default)) + { + createUsers.Add(organizationUser); + } + else + { + replaceUsers.Add(organizationUser); + } + } + + await CreateManyAsync(createUsers); + await ReplaceManyAsync(replaceUsers); + } + + public async Task> CreateManyAsync(IEnumerable organizationUsers) + { + if (!organizationUsers.Any()) + { + return default; + } + + foreach (var organizationUser in organizationUsers) + { + organizationUser.SetNewId(); + } + + var orgUsersTVP = organizationUsers.ToTvp(); + using (var connection = new SqlConnection(_marsConnectionString)) + { + var results = await connection.ExecuteAsync( + $"[{Schema}].[{Table}_CreateMany]", + new { OrganizationUsersInput = orgUsersTVP }, + commandType: CommandType.StoredProcedure); + } + + return organizationUsers.Select(u => u.Id).ToList(); + } + + public async Task ReplaceManyAsync(IEnumerable organizationUsers) + { + if (!organizationUsers.Any()) + { + return; + } + + var orgUsersTVP = organizationUsers.ToTvp(); + using (var connection = new SqlConnection(_marsConnectionString)) + { + var results = await connection.ExecuteAsync( + $"[{Schema}].[{Table}_UpdateMany]", + new { OrganizationUsersInput = orgUsersTVP }, + commandType: CommandType.StoredProcedure); + } + } + + public async Task> GetManyPublicKeysByOrganizationUserAsync( + Guid organizationId, IEnumerable Ids) + { + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryAsync( + "[dbo].[User_ReadPublicKeysByOrganizationUserIds]", + new { OrganizationId = organizationId, OrganizationUserIds = Ids.ToGuidIdArrayTVP() }, + commandType: CommandType.StoredProcedure); + + return results.ToList(); + } + } + + public async Task> GetManyByMinimumRoleAsync(Guid organizationId, OrganizationUserType minRole) + { + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryAsync( + "[dbo].[OrganizationUser_ReadByMinimumRole]", + new { OrganizationId = organizationId, MinRole = minRole }, + commandType: CommandType.StoredProcedure); + + return results.ToList(); + } + } + + public async Task RevokeAsync(Guid id) + { + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.ExecuteAsync( + $"[{Schema}].[{Table}_Deactivate]", + new { Id = id }, + commandType: CommandType.StoredProcedure); + } + } + + public async Task RestoreAsync(Guid id, OrganizationUserStatusType status) + { + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.ExecuteAsync( + $"[{Schema}].[{Table}_Activate]", + new { Id = id, Status = status }, + commandType: CommandType.StoredProcedure); + } } } } diff --git a/src/Infrastructure.Dapper/Repositories/PolicyRepository.cs b/src/Infrastructure.Dapper/Repositories/PolicyRepository.cs index 59552e51e7..46cd6e29e4 100644 --- a/src/Infrastructure.Dapper/Repositories/PolicyRepository.cs +++ b/src/Infrastructure.Dapper/Repositories/PolicyRepository.cs @@ -6,82 +6,83 @@ using Bit.Core.Repositories; using Bit.Core.Settings; using Dapper; -namespace Bit.Infrastructure.Dapper.Repositories; - -public class PolicyRepository : Repository, IPolicyRepository +namespace Bit.Infrastructure.Dapper.Repositories { - public PolicyRepository(GlobalSettings globalSettings) - : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) - { } - - public PolicyRepository(string connectionString, string readOnlyConnectionString) - : base(connectionString, readOnlyConnectionString) - { } - - public async Task GetByOrganizationIdTypeAsync(Guid organizationId, PolicyType type) + public class PolicyRepository : Repository, IPolicyRepository { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - $"[{Schema}].[{Table}_ReadByOrganizationIdType]", - new { OrganizationId = organizationId, Type = (byte)type }, - commandType: CommandType.StoredProcedure); + public PolicyRepository(GlobalSettings globalSettings) + : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) + { } - return results.SingleOrDefault(); + public PolicyRepository(string connectionString, string readOnlyConnectionString) + : base(connectionString, readOnlyConnectionString) + { } + + public async Task GetByOrganizationIdTypeAsync(Guid organizationId, PolicyType type) + { + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryAsync( + $"[{Schema}].[{Table}_ReadByOrganizationIdType]", + new { OrganizationId = organizationId, Type = (byte)type }, + commandType: CommandType.StoredProcedure); + + return results.SingleOrDefault(); + } } - } - public async Task> GetManyByOrganizationIdAsync(Guid organizationId) - { - using (var connection = new SqlConnection(ConnectionString)) + public async Task> GetManyByOrganizationIdAsync(Guid organizationId) { - var results = await connection.QueryAsync( - $"[{Schema}].[{Table}_ReadByOrganizationId]", - new { OrganizationId = organizationId }, - commandType: CommandType.StoredProcedure); + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryAsync( + $"[{Schema}].[{Table}_ReadByOrganizationId]", + new { OrganizationId = organizationId }, + commandType: CommandType.StoredProcedure); - return results.ToList(); + return results.ToList(); + } } - } - public async Task> GetManyByUserIdAsync(Guid userId) - { - using (var connection = new SqlConnection(ConnectionString)) + public async Task> GetManyByUserIdAsync(Guid userId) { - var results = await connection.QueryAsync( - $"[{Schema}].[{Table}_ReadByUserId]", - new { UserId = userId }, - commandType: CommandType.StoredProcedure); + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryAsync( + $"[{Schema}].[{Table}_ReadByUserId]", + new { UserId = userId }, + commandType: CommandType.StoredProcedure); - return results.ToList(); + return results.ToList(); + } } - } - public async Task> GetManyByTypeApplicableToUserIdAsync(Guid userId, PolicyType policyType, - OrganizationUserStatusType minStatus) - { - using (var connection = new SqlConnection(ConnectionString)) + public async Task> GetManyByTypeApplicableToUserIdAsync(Guid userId, PolicyType policyType, + OrganizationUserStatusType minStatus) { - var results = await connection.QueryAsync( - $"[{Schema}].[{Table}_ReadByTypeApplicableToUser]", - new { UserId = userId, PolicyType = policyType, MinimumStatus = minStatus }, - commandType: CommandType.StoredProcedure); + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryAsync( + $"[{Schema}].[{Table}_ReadByTypeApplicableToUser]", + new { UserId = userId, PolicyType = policyType, MinimumStatus = minStatus }, + commandType: CommandType.StoredProcedure); - return results.ToList(); + return results.ToList(); + } } - } - public async Task GetCountByTypeApplicableToUserIdAsync(Guid userId, PolicyType policyType, - OrganizationUserStatusType minStatus) - { - using (var connection = new SqlConnection(ConnectionString)) + public async Task GetCountByTypeApplicableToUserIdAsync(Guid userId, PolicyType policyType, + OrganizationUserStatusType minStatus) { - var result = await connection.ExecuteScalarAsync( - $"[{Schema}].[{Table}_CountByTypeApplicableToUser]", - new { UserId = userId, PolicyType = policyType, MinimumStatus = minStatus }, - commandType: CommandType.StoredProcedure); + using (var connection = new SqlConnection(ConnectionString)) + { + var result = await connection.ExecuteScalarAsync( + $"[{Schema}].[{Table}_CountByTypeApplicableToUser]", + new { UserId = userId, PolicyType = policyType, MinimumStatus = minStatus }, + commandType: CommandType.StoredProcedure); - return result; + return result; + } } } } diff --git a/src/Infrastructure.Dapper/Repositories/ProviderOrganizationRepository.cs b/src/Infrastructure.Dapper/Repositories/ProviderOrganizationRepository.cs index 18ce678669..282f1d2734 100644 --- a/src/Infrastructure.Dapper/Repositories/ProviderOrganizationRepository.cs +++ b/src/Infrastructure.Dapper/Repositories/ProviderOrganizationRepository.cs @@ -6,41 +6,42 @@ using Bit.Core.Repositories; using Bit.Core.Settings; using Dapper; -namespace Bit.Infrastructure.Dapper.Repositories; - -public class ProviderOrganizationRepository : Repository, IProviderOrganizationRepository +namespace Bit.Infrastructure.Dapper.Repositories { - public ProviderOrganizationRepository(GlobalSettings globalSettings) - : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) - { } - - public ProviderOrganizationRepository(string connectionString, string readOnlyConnectionString) - : base(connectionString, readOnlyConnectionString) - { } - - public async Task> GetManyDetailsByProviderAsync(Guid providerId) + public class ProviderOrganizationRepository : Repository, IProviderOrganizationRepository { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - "[dbo].[ProviderOrganizationOrganizationDetails_ReadByProviderId]", - new { ProviderId = providerId }, - commandType: CommandType.StoredProcedure); + public ProviderOrganizationRepository(GlobalSettings globalSettings) + : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) + { } - return results.ToList(); + public ProviderOrganizationRepository(string connectionString, string readOnlyConnectionString) + : base(connectionString, readOnlyConnectionString) + { } + + public async Task> GetManyDetailsByProviderAsync(Guid providerId) + { + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryAsync( + "[dbo].[ProviderOrganizationOrganizationDetails_ReadByProviderId]", + new { ProviderId = providerId }, + commandType: CommandType.StoredProcedure); + + return results.ToList(); + } } - } - public async Task GetByOrganizationId(Guid organizationId) - { - using (var connection = new SqlConnection(ConnectionString)) + public async Task GetByOrganizationId(Guid organizationId) { - var results = await connection.QueryAsync( - "[dbo].[ProviderOrganization_ReadByOrganizationId]", - new { OrganizationId = organizationId }, - commandType: CommandType.StoredProcedure); + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryAsync( + "[dbo].[ProviderOrganization_ReadByOrganizationId]", + new { OrganizationId = organizationId }, + commandType: CommandType.StoredProcedure); - return results.SingleOrDefault(); + return results.SingleOrDefault(); + } } } } diff --git a/src/Infrastructure.Dapper/Repositories/ProviderRepository.cs b/src/Infrastructure.Dapper/Repositories/ProviderRepository.cs index 3bc38727c7..4619771a5f 100644 --- a/src/Infrastructure.Dapper/Repositories/ProviderRepository.cs +++ b/src/Infrastructure.Dapper/Repositories/ProviderRepository.cs @@ -6,41 +6,42 @@ using Bit.Core.Repositories; using Bit.Core.Settings; using Dapper; -namespace Bit.Infrastructure.Dapper.Repositories; - -public class ProviderRepository : Repository, IProviderRepository +namespace Bit.Infrastructure.Dapper.Repositories { - public ProviderRepository(GlobalSettings globalSettings) - : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) - { } - - public ProviderRepository(string connectionString, string readOnlyConnectionString) - : base(connectionString, readOnlyConnectionString) - { } - - public async Task> SearchAsync(string name, string userEmail, int skip, int take) + public class ProviderRepository : Repository, IProviderRepository { - using (var connection = new SqlConnection(ReadOnlyConnectionString)) - { - var results = await connection.QueryAsync( - "[dbo].[Provider_Search]", - new { Name = name, UserEmail = userEmail, Skip = skip, Take = take }, - commandType: CommandType.StoredProcedure, - commandTimeout: 120); + public ProviderRepository(GlobalSettings globalSettings) + : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) + { } - return results.ToList(); + public ProviderRepository(string connectionString, string readOnlyConnectionString) + : base(connectionString, readOnlyConnectionString) + { } + + public async Task> SearchAsync(string name, string userEmail, int skip, int take) + { + using (var connection = new SqlConnection(ReadOnlyConnectionString)) + { + var results = await connection.QueryAsync( + "[dbo].[Provider_Search]", + new { Name = name, UserEmail = userEmail, Skip = skip, Take = take }, + commandType: CommandType.StoredProcedure, + commandTimeout: 120); + + return results.ToList(); + } } - } - public async Task> GetManyAbilitiesAsync() - { - using (var connection = new SqlConnection(ConnectionString)) + public async Task> GetManyAbilitiesAsync() { - var results = await connection.QueryAsync( - "[dbo].[Provider_ReadAbilities]", - commandType: CommandType.StoredProcedure); + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryAsync( + "[dbo].[Provider_ReadAbilities]", + commandType: CommandType.StoredProcedure); - return results.ToList(); + return results.ToList(); + } } } } diff --git a/src/Infrastructure.Dapper/Repositories/ProviderUserRepository.cs b/src/Infrastructure.Dapper/Repositories/ProviderUserRepository.cs index 22a475321b..98375ab6a9 100644 --- a/src/Infrastructure.Dapper/Repositories/ProviderUserRepository.cs +++ b/src/Infrastructure.Dapper/Repositories/ProviderUserRepository.cs @@ -7,157 +7,158 @@ using Bit.Core.Repositories; using Bit.Core.Settings; using Dapper; -namespace Bit.Infrastructure.Dapper.Repositories; - -public class ProviderUserRepository : Repository, IProviderUserRepository +namespace Bit.Infrastructure.Dapper.Repositories { - public ProviderUserRepository(GlobalSettings globalSettings) - : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) - { } - - public ProviderUserRepository(string connectionString, string readOnlyConnectionString) - : base(connectionString, readOnlyConnectionString) - { } - - public async Task GetCountByProviderAsync(Guid providerId, string email, bool onlyRegisteredUsers) + public class ProviderUserRepository : Repository, IProviderUserRepository { - using (var connection = new SqlConnection(ConnectionString)) - { - var result = await connection.ExecuteScalarAsync( - "[dbo].[ProviderUser_ReadCountByProviderIdEmail]", - new { ProviderId = providerId, Email = email, OnlyUsers = onlyRegisteredUsers }, - commandType: CommandType.StoredProcedure); + public ProviderUserRepository(GlobalSettings globalSettings) + : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) + { } - return result; + public ProviderUserRepository(string connectionString, string readOnlyConnectionString) + : base(connectionString, readOnlyConnectionString) + { } + + public async Task GetCountByProviderAsync(Guid providerId, string email, bool onlyRegisteredUsers) + { + using (var connection = new SqlConnection(ConnectionString)) + { + var result = await connection.ExecuteScalarAsync( + "[dbo].[ProviderUser_ReadCountByProviderIdEmail]", + new { ProviderId = providerId, Email = email, OnlyUsers = onlyRegisteredUsers }, + commandType: CommandType.StoredProcedure); + + return result; + } } - } - public async Task> GetManyAsync(IEnumerable ids) - { - using (var connection = new SqlConnection(ConnectionString)) + public async Task> GetManyAsync(IEnumerable ids) { - var results = await connection.QueryAsync( - "[dbo].[ProviderUser_ReadByIds]", - new { Ids = ids.ToGuidIdArrayTVP() }, - commandType: CommandType.StoredProcedure); + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryAsync( + "[dbo].[ProviderUser_ReadByIds]", + new { Ids = ids.ToGuidIdArrayTVP() }, + commandType: CommandType.StoredProcedure); - return results.ToList(); + return results.ToList(); + } } - } - public async Task> GetManyByUserAsync(Guid userId) - { - using (var connection = new SqlConnection(ConnectionString)) + public async Task> GetManyByUserAsync(Guid userId) { - var results = await connection.QueryAsync( - "[dbo].[ProviderUser_ReadByUserId]", - new { UserId = userId }, - commandType: CommandType.StoredProcedure); + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryAsync( + "[dbo].[ProviderUser_ReadByUserId]", + new { UserId = userId }, + commandType: CommandType.StoredProcedure); - return results.ToList(); + return results.ToList(); + } } - } - public async Task GetByProviderUserAsync(Guid providerId, Guid userId) - { - using (var connection = new SqlConnection(ConnectionString)) + public async Task GetByProviderUserAsync(Guid providerId, Guid userId) { - var results = await connection.QueryAsync( - "[dbo].[ProviderUser_ReadByProviderIdUserId]", - new { ProviderId = providerId, UserId = userId }, - commandType: CommandType.StoredProcedure); + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryAsync( + "[dbo].[ProviderUser_ReadByProviderIdUserId]", + new { ProviderId = providerId, UserId = userId }, + commandType: CommandType.StoredProcedure); - return results.SingleOrDefault(); + return results.SingleOrDefault(); + } } - } - public async Task> GetManyByProviderAsync(Guid providerId, ProviderUserType? type) - { - using (var connection = new SqlConnection(ConnectionString)) + public async Task> GetManyByProviderAsync(Guid providerId, ProviderUserType? type) { - var results = await connection.QueryAsync( - "[dbo].[ProviderUser_ReadByProviderId]", - new { ProviderId = providerId, Type = type }, - commandType: CommandType.StoredProcedure); + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryAsync( + "[dbo].[ProviderUser_ReadByProviderId]", + new { ProviderId = providerId, Type = type }, + commandType: CommandType.StoredProcedure); - return results.ToList(); + return results.ToList(); + } } - } - public async Task> GetManyDetailsByProviderAsync(Guid providerId) - { - using (var connection = new SqlConnection(ConnectionString)) + public async Task> GetManyDetailsByProviderAsync(Guid providerId) { - var results = await connection.QueryAsync( - "[dbo].[ProviderUserUserDetails_ReadByProviderId]", - new { ProviderId = providerId }, - commandType: CommandType.StoredProcedure); + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryAsync( + "[dbo].[ProviderUserUserDetails_ReadByProviderId]", + new { ProviderId = providerId }, + commandType: CommandType.StoredProcedure); - return results.ToList(); + return results.ToList(); + } } - } - public async Task> GetManyDetailsByUserAsync(Guid userId, - ProviderUserStatusType? status = null) - { - using (var connection = new SqlConnection(ConnectionString)) + public async Task> GetManyDetailsByUserAsync(Guid userId, + ProviderUserStatusType? status = null) { - var results = await connection.QueryAsync( - "[dbo].[ProviderUserProviderDetails_ReadByUserIdStatus]", - new { UserId = userId, Status = status }, - commandType: CommandType.StoredProcedure); + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryAsync( + "[dbo].[ProviderUserProviderDetails_ReadByUserIdStatus]", + new { UserId = userId, Status = status }, + commandType: CommandType.StoredProcedure); - return results.ToList(); + return results.ToList(); + } } - } - public async Task> GetManyOrganizationDetailsByUserAsync(Guid userId, - ProviderUserStatusType? status = null) - { - using (var connection = new SqlConnection(ConnectionString)) + public async Task> GetManyOrganizationDetailsByUserAsync(Guid userId, + ProviderUserStatusType? status = null) { - var results = await connection.QueryAsync( - "[dbo].[ProviderUserProviderOrganizationDetails_ReadByUserIdStatus]", - new { UserId = userId, Status = status }, - commandType: CommandType.StoredProcedure); + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryAsync( + "[dbo].[ProviderUserProviderOrganizationDetails_ReadByUserIdStatus]", + new { UserId = userId, Status = status }, + commandType: CommandType.StoredProcedure); - return results.ToList(); + return results.ToList(); + } } - } - public async Task DeleteManyAsync(IEnumerable providerUserIds) - { - using (var connection = new SqlConnection(ConnectionString)) + public async Task DeleteManyAsync(IEnumerable providerUserIds) { - await connection.ExecuteAsync("[dbo].[ProviderUser_DeleteByIds]", - new { Ids = providerUserIds.ToGuidIdArrayTVP() }, commandType: CommandType.StoredProcedure); + using (var connection = new SqlConnection(ConnectionString)) + { + await connection.ExecuteAsync("[dbo].[ProviderUser_DeleteByIds]", + new { Ids = providerUserIds.ToGuidIdArrayTVP() }, commandType: CommandType.StoredProcedure); + } } - } - public async Task> GetManyPublicKeysByProviderUserAsync( - Guid providerId, IEnumerable Ids) - { - using (var connection = new SqlConnection(ConnectionString)) + public async Task> GetManyPublicKeysByProviderUserAsync( + Guid providerId, IEnumerable Ids) { - var results = await connection.QueryAsync( - "[dbo].[User_ReadPublicKeysByProviderUserIds]", - new { ProviderId = providerId, ProviderUserIds = Ids.ToGuidIdArrayTVP() }, - commandType: CommandType.StoredProcedure); + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryAsync( + "[dbo].[User_ReadPublicKeysByProviderUserIds]", + new { ProviderId = providerId, ProviderUserIds = Ids.ToGuidIdArrayTVP() }, + commandType: CommandType.StoredProcedure); - return results.ToList(); + return results.ToList(); + } } - } - public async Task GetCountByOnlyOwnerAsync(Guid userId) - { - using (var connection = new SqlConnection(ConnectionString)) + public async Task GetCountByOnlyOwnerAsync(Guid userId) { - var results = await connection.ExecuteScalarAsync( - "[dbo].[ProviderUser_ReadCountByOnlyOwner]", - new { UserId = userId }, - commandType: CommandType.StoredProcedure); + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.ExecuteScalarAsync( + "[dbo].[ProviderUser_ReadCountByOnlyOwner]", + new { UserId = userId }, + commandType: CommandType.StoredProcedure); - return results; + return results; + } } } } diff --git a/src/Infrastructure.Dapper/Repositories/Repository.cs b/src/Infrastructure.Dapper/Repositories/Repository.cs index 0c46a6d0a9..4bc0b91b1e 100644 --- a/src/Infrastructure.Dapper/Repositories/Repository.cs +++ b/src/Infrastructure.Dapper/Repositories/Repository.cs @@ -4,91 +4,92 @@ using Bit.Core.Entities; using Bit.Core.Repositories; using Dapper; -namespace Bit.Infrastructure.Dapper.Repositories; - -public abstract class Repository : BaseRepository, IRepository - where TId : IEquatable - where T : class, ITableObject +namespace Bit.Infrastructure.Dapper.Repositories { - public Repository(string connectionString, string readOnlyConnectionString, - string schema = null, string table = null) - : base(connectionString, readOnlyConnectionString) + public abstract class Repository : BaseRepository, IRepository + where TId : IEquatable + where T : class, ITableObject { - if (!string.IsNullOrWhiteSpace(table)) + public Repository(string connectionString, string readOnlyConnectionString, + string schema = null, string table = null) + : base(connectionString, readOnlyConnectionString) { - Table = table; + if (!string.IsNullOrWhiteSpace(table)) + { + Table = table; + } + + if (!string.IsNullOrWhiteSpace(schema)) + { + Schema = schema; + } } - if (!string.IsNullOrWhiteSpace(schema)) + protected string Schema { get; private set; } = "dbo"; + protected string Table { get; private set; } = typeof(T).Name; + + public virtual async Task GetByIdAsync(TId id) { - Schema = schema; + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryAsync( + $"[{Schema}].[{Table}_ReadById]", + new { Id = id }, + commandType: CommandType.StoredProcedure); + + return results.SingleOrDefault(); + } } - } - protected string Schema { get; private set; } = "dbo"; - protected string Table { get; private set; } = typeof(T).Name; - - public virtual async Task GetByIdAsync(TId id) - { - using (var connection = new SqlConnection(ConnectionString)) + public virtual async Task CreateAsync(T obj) { - var results = await connection.QueryAsync( - $"[{Schema}].[{Table}_ReadById]", - new { Id = id }, - commandType: CommandType.StoredProcedure); - - return results.SingleOrDefault(); + obj.SetNewId(); + using (var connection = new SqlConnection(ConnectionString)) + { + var parameters = new DynamicParameters(); + parameters.AddDynamicParams(obj); + parameters.Add("Id", obj.Id, direction: ParameterDirection.InputOutput); + var results = await connection.ExecuteAsync( + $"[{Schema}].[{Table}_Create]", + parameters, + commandType: CommandType.StoredProcedure); + obj.Id = parameters.Get(nameof(obj.Id)); + } + return obj; } - } - public virtual async Task CreateAsync(T obj) - { - obj.SetNewId(); - using (var connection = new SqlConnection(ConnectionString)) + public virtual async Task ReplaceAsync(T obj) { - var parameters = new DynamicParameters(); - parameters.AddDynamicParams(obj); - parameters.Add("Id", obj.Id, direction: ParameterDirection.InputOutput); - var results = await connection.ExecuteAsync( - $"[{Schema}].[{Table}_Create]", - parameters, - commandType: CommandType.StoredProcedure); - obj.Id = parameters.Get(nameof(obj.Id)); + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.ExecuteAsync( + $"[{Schema}].[{Table}_Update]", + obj, + commandType: CommandType.StoredProcedure); + } } - return obj; - } - public virtual async Task ReplaceAsync(T obj) - { - using (var connection = new SqlConnection(ConnectionString)) + public virtual async Task UpsertAsync(T obj) { - var results = await connection.ExecuteAsync( - $"[{Schema}].[{Table}_Update]", - obj, - commandType: CommandType.StoredProcedure); + if (obj.Id.Equals(default(TId))) + { + await CreateAsync(obj); + } + else + { + await ReplaceAsync(obj); + } } - } - public virtual async Task UpsertAsync(T obj) - { - if (obj.Id.Equals(default(TId))) + public virtual async Task DeleteAsync(T obj) { - await CreateAsync(obj); - } - else - { - await ReplaceAsync(obj); - } - } - - public virtual async Task DeleteAsync(T obj) - { - using (var connection = new SqlConnection(ConnectionString)) - { - await connection.ExecuteAsync( - $"[{Schema}].[{Table}_DeleteById]", - new { Id = obj.Id }, - commandType: CommandType.StoredProcedure); + using (var connection = new SqlConnection(ConnectionString)) + { + await connection.ExecuteAsync( + $"[{Schema}].[{Table}_DeleteById]", + new { Id = obj.Id }, + commandType: CommandType.StoredProcedure); + } } } } diff --git a/src/Infrastructure.Dapper/Repositories/SendRepository.cs b/src/Infrastructure.Dapper/Repositories/SendRepository.cs index b64af45cda..6d8ba6c190 100644 --- a/src/Infrastructure.Dapper/Repositories/SendRepository.cs +++ b/src/Infrastructure.Dapper/Repositories/SendRepository.cs @@ -5,41 +5,42 @@ using Bit.Core.Repositories; using Bit.Core.Settings; using Dapper; -namespace Bit.Infrastructure.Dapper.Repositories; - -public class SendRepository : Repository, ISendRepository +namespace Bit.Infrastructure.Dapper.Repositories { - public SendRepository(GlobalSettings globalSettings) - : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) - { } - - public SendRepository(string connectionString, string readOnlyConnectionString) - : base(connectionString, readOnlyConnectionString) - { } - - public async Task> GetManyByUserIdAsync(Guid userId) + public class SendRepository : Repository, ISendRepository { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - $"[{Schema}].[Send_ReadByUserId]", - new { UserId = userId }, - commandType: CommandType.StoredProcedure); + public SendRepository(GlobalSettings globalSettings) + : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) + { } - return results.ToList(); + public SendRepository(string connectionString, string readOnlyConnectionString) + : base(connectionString, readOnlyConnectionString) + { } + + public async Task> GetManyByUserIdAsync(Guid userId) + { + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryAsync( + $"[{Schema}].[Send_ReadByUserId]", + new { UserId = userId }, + commandType: CommandType.StoredProcedure); + + return results.ToList(); + } } - } - public async Task> GetManyByDeletionDateAsync(DateTime deletionDateBefore) - { - using (var connection = new SqlConnection(ConnectionString)) + public async Task> GetManyByDeletionDateAsync(DateTime deletionDateBefore) { - var results = await connection.QueryAsync( - $"[{Schema}].[Send_ReadByDeletionDateBefore]", - new { DeletionDate = deletionDateBefore }, - commandType: CommandType.StoredProcedure); + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryAsync( + $"[{Schema}].[Send_ReadByDeletionDateBefore]", + new { DeletionDate = deletionDateBefore }, + commandType: CommandType.StoredProcedure); - return results.ToList(); + return results.ToList(); + } } } } diff --git a/src/Infrastructure.Dapper/Repositories/SsoConfigRepository.cs b/src/Infrastructure.Dapper/Repositories/SsoConfigRepository.cs index 3b8a5a904e..70d527c1cf 100644 --- a/src/Infrastructure.Dapper/Repositories/SsoConfigRepository.cs +++ b/src/Infrastructure.Dapper/Repositories/SsoConfigRepository.cs @@ -5,54 +5,55 @@ using Bit.Core.Repositories; using Bit.Core.Settings; using Dapper; -namespace Bit.Infrastructure.Dapper.Repositories; - -public class SsoConfigRepository : Repository, ISsoConfigRepository +namespace Bit.Infrastructure.Dapper.Repositories { - public SsoConfigRepository(GlobalSettings globalSettings) - : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) - { } - - public SsoConfigRepository(string connectionString, string readOnlyConnectionString) - : base(connectionString, readOnlyConnectionString) - { } - - public async Task GetByOrganizationIdAsync(Guid organizationId) + public class SsoConfigRepository : Repository, ISsoConfigRepository { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - $"[{Schema}].[{Table}_ReadByOrganizationId]", - new { OrganizationId = organizationId }, - commandType: CommandType.StoredProcedure); + public SsoConfigRepository(GlobalSettings globalSettings) + : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) + { } - return results.SingleOrDefault(); + public SsoConfigRepository(string connectionString, string readOnlyConnectionString) + : base(connectionString, readOnlyConnectionString) + { } + + public async Task GetByOrganizationIdAsync(Guid organizationId) + { + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryAsync( + $"[{Schema}].[{Table}_ReadByOrganizationId]", + new { OrganizationId = organizationId }, + commandType: CommandType.StoredProcedure); + + return results.SingleOrDefault(); + } } - } - public async Task GetByIdentifierAsync(string identifier) - { - using (var connection = new SqlConnection(ConnectionString)) + public async Task GetByIdentifierAsync(string identifier) { - var results = await connection.QueryAsync( - $"[{Schema}].[{Table}_ReadByIdentifier]", - new { Identifier = identifier }, - commandType: CommandType.StoredProcedure); + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryAsync( + $"[{Schema}].[{Table}_ReadByIdentifier]", + new { Identifier = identifier }, + commandType: CommandType.StoredProcedure); - return results.SingleOrDefault(); + return results.SingleOrDefault(); + } } - } - public async Task> GetManyByRevisionNotBeforeDate(DateTime? notBefore) - { - using (var connection = new SqlConnection(ConnectionString)) + public async Task> GetManyByRevisionNotBeforeDate(DateTime? notBefore) { - var results = await connection.QueryAsync( - $"[{Schema}].[{Table}_ReadManyByNotBeforeRevisionDate]", - new { NotBefore = notBefore }, - commandType: CommandType.StoredProcedure); + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryAsync( + $"[{Schema}].[{Table}_ReadManyByNotBeforeRevisionDate]", + new { NotBefore = notBefore }, + commandType: CommandType.StoredProcedure); - return results.ToList(); + return results.ToList(); + } } } } diff --git a/src/Infrastructure.Dapper/Repositories/SsoUserRepository.cs b/src/Infrastructure.Dapper/Repositories/SsoUserRepository.cs index e393762fa5..fd32c708d1 100644 --- a/src/Infrastructure.Dapper/Repositories/SsoUserRepository.cs +++ b/src/Infrastructure.Dapper/Repositories/SsoUserRepository.cs @@ -5,39 +5,40 @@ using Bit.Core.Repositories; using Bit.Core.Settings; using Dapper; -namespace Bit.Infrastructure.Dapper.Repositories; - -public class SsoUserRepository : Repository, ISsoUserRepository +namespace Bit.Infrastructure.Dapper.Repositories { - public SsoUserRepository(GlobalSettings globalSettings) - : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) - { } - - public SsoUserRepository(string connectionString, string readOnlyConnectionString) - : base(connectionString, readOnlyConnectionString) - { } - - public async Task DeleteAsync(Guid userId, Guid? organizationId) + public class SsoUserRepository : Repository, ISsoUserRepository { - using (var connection = new SqlConnection(ConnectionString)) + public SsoUserRepository(GlobalSettings globalSettings) + : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) + { } + + public SsoUserRepository(string connectionString, string readOnlyConnectionString) + : base(connectionString, readOnlyConnectionString) + { } + + public async Task DeleteAsync(Guid userId, Guid? organizationId) { - var results = await connection.ExecuteAsync( - $"[{Schema}].[SsoUser_Delete]", - new { UserId = userId, OrganizationId = organizationId }, - commandType: CommandType.StoredProcedure); + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.ExecuteAsync( + $"[{Schema}].[SsoUser_Delete]", + new { UserId = userId, OrganizationId = organizationId }, + commandType: CommandType.StoredProcedure); + } } - } - public async Task GetByUserIdOrganizationIdAsync(Guid organizationId, Guid userId) - { - using (var connection = new SqlConnection(ConnectionString)) + public async Task GetByUserIdOrganizationIdAsync(Guid organizationId, Guid userId) { - var results = await connection.QueryAsync( - $"[{Schema}].[SsoUser_ReadByUserIdOrganizationId]", - new { UserId = userId, OrganizationId = organizationId }, - commandType: CommandType.StoredProcedure); + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryAsync( + $"[{Schema}].[SsoUser_ReadByUserIdOrganizationId]", + new { UserId = userId, OrganizationId = organizationId }, + commandType: CommandType.StoredProcedure); - return results.SingleOrDefault(); + return results.SingleOrDefault(); + } } } } diff --git a/src/Infrastructure.Dapper/Repositories/TaxRateRepository.cs b/src/Infrastructure.Dapper/Repositories/TaxRateRepository.cs index 7a9ad7d09b..1c9982e018 100644 --- a/src/Infrastructure.Dapper/Repositories/TaxRateRepository.cs +++ b/src/Infrastructure.Dapper/Repositories/TaxRateRepository.cs @@ -5,64 +5,65 @@ using Bit.Core.Repositories; using Bit.Core.Settings; using Dapper; -namespace Bit.Infrastructure.Dapper.Repositories; - -public class TaxRateRepository : Repository, ITaxRateRepository +namespace Bit.Infrastructure.Dapper.Repositories { - public TaxRateRepository(GlobalSettings globalSettings) - : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) - { } - - public TaxRateRepository(string connectionString, string readOnlyConnectionString) - : base(connectionString, readOnlyConnectionString) - { } - - public async Task> SearchAsync(int skip, int count) + public class TaxRateRepository : Repository, ITaxRateRepository { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - $"[{Schema}].[TaxRate_Search]", - new { Skip = skip, Count = count }, - commandType: CommandType.StoredProcedure); + public TaxRateRepository(GlobalSettings globalSettings) + : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) + { } - return results.ToList(); + public TaxRateRepository(string connectionString, string readOnlyConnectionString) + : base(connectionString, readOnlyConnectionString) + { } + + public async Task> SearchAsync(int skip, int count) + { + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryAsync( + $"[{Schema}].[TaxRate_Search]", + new { Skip = skip, Count = count }, + commandType: CommandType.StoredProcedure); + + return results.ToList(); + } } - } - public async Task> GetAllActiveAsync() - { - using (var connection = new SqlConnection(ConnectionString)) + public async Task> GetAllActiveAsync() { - var results = await connection.QueryAsync( - $"[{Schema}].[TaxRate_ReadAllActive]", - commandType: CommandType.StoredProcedure); + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryAsync( + $"[{Schema}].[TaxRate_ReadAllActive]", + commandType: CommandType.StoredProcedure); - return results.ToList(); + return results.ToList(); + } } - } - public async Task ArchiveAsync(TaxRate model) - { - using (var connection = new SqlConnection(ConnectionString)) + public async Task ArchiveAsync(TaxRate model) { - var results = await connection.ExecuteAsync( - $"[{Schema}].[TaxRate_Archive]", - new { Id = model.Id }, - commandType: CommandType.StoredProcedure); + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.ExecuteAsync( + $"[{Schema}].[TaxRate_Archive]", + new { Id = model.Id }, + commandType: CommandType.StoredProcedure); + } } - } - public async Task> GetByLocationAsync(TaxRate model) - { - using (var connection = new SqlConnection(ConnectionString)) + public async Task> GetByLocationAsync(TaxRate model) { - var results = await connection.QueryAsync( - $"[{Schema}].[TaxRate_ReadByLocation]", - new { Country = model.Country, PostalCode = model.PostalCode }, - commandType: CommandType.StoredProcedure); + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryAsync( + $"[{Schema}].[TaxRate_ReadByLocation]", + new { Country = model.Country, PostalCode = model.PostalCode }, + commandType: CommandType.StoredProcedure); - return results.ToList(); + return results.ToList(); + } } } } diff --git a/src/Infrastructure.Dapper/Repositories/TransactionRepository.cs b/src/Infrastructure.Dapper/Repositories/TransactionRepository.cs index ff9c900bfa..ed8b16f91c 100644 --- a/src/Infrastructure.Dapper/Repositories/TransactionRepository.cs +++ b/src/Infrastructure.Dapper/Repositories/TransactionRepository.cs @@ -6,54 +6,55 @@ using Bit.Core.Repositories; using Bit.Core.Settings; using Dapper; -namespace Bit.Infrastructure.Dapper.Repositories; - -public class TransactionRepository : Repository, ITransactionRepository +namespace Bit.Infrastructure.Dapper.Repositories { - public TransactionRepository(GlobalSettings globalSettings) - : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) - { } - - public TransactionRepository(string connectionString, string readOnlyConnectionString) - : base(connectionString, readOnlyConnectionString) - { } - - public async Task> GetManyByUserIdAsync(Guid userId) + public class TransactionRepository : Repository, ITransactionRepository { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - $"[{Schema}].[Transaction_ReadByUserId]", - new { UserId = userId }, - commandType: CommandType.StoredProcedure); + public TransactionRepository(GlobalSettings globalSettings) + : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) + { } - return results.ToList(); + public TransactionRepository(string connectionString, string readOnlyConnectionString) + : base(connectionString, readOnlyConnectionString) + { } + + public async Task> GetManyByUserIdAsync(Guid userId) + { + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryAsync( + $"[{Schema}].[Transaction_ReadByUserId]", + new { UserId = userId }, + commandType: CommandType.StoredProcedure); + + return results.ToList(); + } } - } - public async Task> GetManyByOrganizationIdAsync(Guid organizationId) - { - using (var connection = new SqlConnection(ConnectionString)) + public async Task> GetManyByOrganizationIdAsync(Guid organizationId) { - var results = await connection.QueryAsync( - $"[{Schema}].[Transaction_ReadByOrganizationId]", - new { OrganizationId = organizationId }, - commandType: CommandType.StoredProcedure); + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryAsync( + $"[{Schema}].[Transaction_ReadByOrganizationId]", + new { OrganizationId = organizationId }, + commandType: CommandType.StoredProcedure); - return results.ToList(); + return results.ToList(); + } } - } - public async Task GetByGatewayIdAsync(GatewayType gatewayType, string gatewayId) - { - using (var connection = new SqlConnection(ConnectionString)) + public async Task GetByGatewayIdAsync(GatewayType gatewayType, string gatewayId) { - var results = await connection.QueryAsync( - $"[{Schema}].[Transaction_ReadByGatewayId]", - new { Gateway = gatewayType, GatewayId = gatewayId }, - commandType: CommandType.StoredProcedure); + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryAsync( + $"[{Schema}].[Transaction_ReadByGatewayId]", + new { Gateway = gatewayType, GatewayId = gatewayId }, + commandType: CommandType.StoredProcedure); - return results.SingleOrDefault(); + return results.SingleOrDefault(); + } } } } diff --git a/src/Infrastructure.Dapper/Repositories/UserRepository.cs b/src/Infrastructure.Dapper/Repositories/UserRepository.cs index 19c7a83bea..077fdd59ae 100644 --- a/src/Infrastructure.Dapper/Repositories/UserRepository.cs +++ b/src/Infrastructure.Dapper/Repositories/UserRepository.cs @@ -6,165 +6,166 @@ using Bit.Core.Repositories; using Bit.Core.Settings; using Dapper; -namespace Bit.Infrastructure.Dapper.Repositories; - -public class UserRepository : Repository, IUserRepository +namespace Bit.Infrastructure.Dapper.Repositories { - public UserRepository(GlobalSettings globalSettings) - : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) - { } - - public UserRepository(string connectionString, string readOnlyConnectionString) - : base(connectionString, readOnlyConnectionString) - { } - - public override async Task GetByIdAsync(Guid id) + public class UserRepository : Repository, IUserRepository { - return await base.GetByIdAsync(id); - } + public UserRepository(GlobalSettings globalSettings) + : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) + { } - public async Task GetByEmailAsync(string email) - { - using (var connection = new SqlConnection(ConnectionString)) + public UserRepository(string connectionString, string readOnlyConnectionString) + : base(connectionString, readOnlyConnectionString) + { } + + public override async Task GetByIdAsync(Guid id) { - var results = await connection.QueryAsync( - $"[{Schema}].[{Table}_ReadByEmail]", - new { Email = email }, - commandType: CommandType.StoredProcedure); - - return results.SingleOrDefault(); + return await base.GetByIdAsync(id); } - } - public async Task GetBySsoUserAsync(string externalId, Guid? organizationId) - { - using (var connection = new SqlConnection(ConnectionString)) + public async Task GetByEmailAsync(string email) { - var results = await connection.QueryAsync( - $"[{Schema}].[{Table}_ReadBySsoUserOrganizationIdExternalId]", - new { OrganizationId = organizationId, ExternalId = externalId }, - commandType: CommandType.StoredProcedure); + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryAsync( + $"[{Schema}].[{Table}_ReadByEmail]", + new { Email = email }, + commandType: CommandType.StoredProcedure); - return results.SingleOrDefault(); + return results.SingleOrDefault(); + } } - } - public async Task GetKdfInformationByEmailAsync(string email) - { - using (var connection = new SqlConnection(ConnectionString)) + public async Task GetBySsoUserAsync(string externalId, Guid? organizationId) { - var results = await connection.QueryAsync( - $"[{Schema}].[{Table}_ReadKdfByEmail]", - new { Email = email }, - commandType: CommandType.StoredProcedure); + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryAsync( + $"[{Schema}].[{Table}_ReadBySsoUserOrganizationIdExternalId]", + new { OrganizationId = organizationId, ExternalId = externalId }, + commandType: CommandType.StoredProcedure); - return results.SingleOrDefault(); + return results.SingleOrDefault(); + } } - } - public async Task> SearchAsync(string email, int skip, int take) - { - using (var connection = new SqlConnection(ReadOnlyConnectionString)) + public async Task GetKdfInformationByEmailAsync(string email) { - var results = await connection.QueryAsync( - $"[{Schema}].[{Table}_Search]", - new { Email = email, Skip = skip, Take = take }, - commandType: CommandType.StoredProcedure, - commandTimeout: 120); + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryAsync( + $"[{Schema}].[{Table}_ReadKdfByEmail]", + new { Email = email }, + commandType: CommandType.StoredProcedure); - return results.ToList(); + return results.SingleOrDefault(); + } } - } - public async Task> GetManyByPremiumAsync(bool premium) - { - using (var connection = new SqlConnection(ConnectionString)) + public async Task> SearchAsync(string email, int skip, int take) { - var results = await connection.QueryAsync( - "[dbo].[User_ReadByPremium]", - new { Premium = premium }, - commandType: CommandType.StoredProcedure); + using (var connection = new SqlConnection(ReadOnlyConnectionString)) + { + var results = await connection.QueryAsync( + $"[{Schema}].[{Table}_Search]", + new { Email = email, Skip = skip, Take = take }, + commandType: CommandType.StoredProcedure, + commandTimeout: 120); - return results.ToList(); + return results.ToList(); + } } - } - public async Task GetPublicKeyAsync(Guid id) - { - using (var connection = new SqlConnection(ConnectionString)) + public async Task> GetManyByPremiumAsync(bool premium) { - var results = await connection.QueryAsync( - $"[{Schema}].[{Table}_ReadPublicKeyById]", - new { Id = id }, - commandType: CommandType.StoredProcedure); + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryAsync( + "[dbo].[User_ReadByPremium]", + new { Premium = premium }, + commandType: CommandType.StoredProcedure); - return results.SingleOrDefault(); + return results.ToList(); + } } - } - public async Task GetAccountRevisionDateAsync(Guid id) - { - using (var connection = new SqlConnection(ReadOnlyConnectionString)) + public async Task GetPublicKeyAsync(Guid id) { - var results = await connection.QueryAsync( - $"[{Schema}].[{Table}_ReadAccountRevisionDateById]", - new { Id = id }, - commandType: CommandType.StoredProcedure); + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryAsync( + $"[{Schema}].[{Table}_ReadPublicKeyById]", + new { Id = id }, + commandType: CommandType.StoredProcedure); - return results.SingleOrDefault(); + return results.SingleOrDefault(); + } } - } - public override async Task ReplaceAsync(User user) - { - await base.ReplaceAsync(user); - } - - public override async Task DeleteAsync(User user) - { - using (var connection = new SqlConnection(ConnectionString)) + public async Task GetAccountRevisionDateAsync(Guid id) { - await connection.ExecuteAsync( - $"[{Schema}].[{Table}_DeleteById]", - new { Id = user.Id }, - commandType: CommandType.StoredProcedure, - commandTimeout: 180); + using (var connection = new SqlConnection(ReadOnlyConnectionString)) + { + var results = await connection.QueryAsync( + $"[{Schema}].[{Table}_ReadAccountRevisionDateById]", + new { Id = id }, + commandType: CommandType.StoredProcedure); + + return results.SingleOrDefault(); + } } - } - public async Task UpdateStorageAsync(Guid id) - { - using (var connection = new SqlConnection(ConnectionString)) + public override async Task ReplaceAsync(User user) { - await connection.ExecuteAsync( - $"[{Schema}].[{Table}_UpdateStorage]", - new { Id = id }, - commandType: CommandType.StoredProcedure, - commandTimeout: 180); + await base.ReplaceAsync(user); } - } - public async Task UpdateRenewalReminderDateAsync(Guid id, DateTime renewalReminderDate) - { - using (var connection = new SqlConnection(ConnectionString)) + public override async Task DeleteAsync(User user) { - await connection.ExecuteAsync( - $"[{Schema}].[User_UpdateRenewalReminderDate]", - new { Id = id, RenewalReminderDate = renewalReminderDate }, - commandType: CommandType.StoredProcedure); + using (var connection = new SqlConnection(ConnectionString)) + { + await connection.ExecuteAsync( + $"[{Schema}].[{Table}_DeleteById]", + new { Id = user.Id }, + commandType: CommandType.StoredProcedure, + commandTimeout: 180); + } } - } - public async Task> GetManyAsync(IEnumerable ids) - { - using (var connection = new SqlConnection(ReadOnlyConnectionString)) + public async Task UpdateStorageAsync(Guid id) { - var results = await connection.QueryAsync( - $"[{Schema}].[{Table}_ReadByIds]", - new { Ids = ids.ToGuidIdArrayTVP() }, - commandType: CommandType.StoredProcedure); + using (var connection = new SqlConnection(ConnectionString)) + { + await connection.ExecuteAsync( + $"[{Schema}].[{Table}_UpdateStorage]", + new { Id = id }, + commandType: CommandType.StoredProcedure, + commandTimeout: 180); + } + } - return results.ToList(); + public async Task UpdateRenewalReminderDateAsync(Guid id, DateTime renewalReminderDate) + { + using (var connection = new SqlConnection(ConnectionString)) + { + await connection.ExecuteAsync( + $"[{Schema}].[User_UpdateRenewalReminderDate]", + new { Id = id, RenewalReminderDate = renewalReminderDate }, + commandType: CommandType.StoredProcedure); + } + } + + public async Task> GetManyAsync(IEnumerable ids) + { + using (var connection = new SqlConnection(ReadOnlyConnectionString)) + { + var results = await connection.QueryAsync( + $"[{Schema}].[{Table}_ReadByIds]", + new { Ids = ids.ToGuidIdArrayTVP() }, + commandType: CommandType.StoredProcedure); + + return results.ToList(); + } } } } diff --git a/src/Infrastructure.EntityFramework/EntityFrameworkServiceCollectionExtensions.cs b/src/Infrastructure.EntityFramework/EntityFrameworkServiceCollectionExtensions.cs index c8a99b2740..259deb2e14 100644 --- a/src/Infrastructure.EntityFramework/EntityFrameworkServiceCollectionExtensions.cs +++ b/src/Infrastructure.EntityFramework/EntityFrameworkServiceCollectionExtensions.cs @@ -5,61 +5,62 @@ using LinqToDB.EntityFrameworkCore; using Microsoft.EntityFrameworkCore; using Microsoft.Extensions.DependencyInjection; -namespace Bit.Infrastructure.EntityFramework; - -public static class EntityFrameworkServiceCollectionExtensions +namespace Bit.Infrastructure.EntityFramework { - public static void AddEFRepositories(this IServiceCollection services, bool selfHosted, string connectionString, - SupportedDatabaseProviders provider) + public static class EntityFrameworkServiceCollectionExtensions { - if (string.IsNullOrWhiteSpace(connectionString)) + public static void AddEFRepositories(this IServiceCollection services, bool selfHosted, string connectionString, + SupportedDatabaseProviders provider) { - throw new Exception($"Database provider type {provider} was selected but no connection string was found."); - } - LinqToDBForEFTools.Initialize(); - services.AddAutoMapper(typeof(UserRepository)); - services.AddDbContext(options => - { - if (provider == SupportedDatabaseProviders.Postgres) + if (string.IsNullOrWhiteSpace(connectionString)) { - options.UseNpgsql(connectionString); - // Handle NpgSql Legacy Support for `timestamp without timezone` issue - AppContext.SetSwitch("Npgsql.EnableLegacyTimestampBehavior", true); + throw new Exception($"Database provider type {provider} was selected but no connection string was found."); } - else if (provider == SupportedDatabaseProviders.MySql) + LinqToDBForEFTools.Initialize(); + services.AddAutoMapper(typeof(UserRepository)); + services.AddDbContext(options => { - options.UseMySql(connectionString, ServerVersion.AutoDetect(connectionString)); - } - }); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); + if (provider == SupportedDatabaseProviders.Postgres) + { + options.UseNpgsql(connectionString); + // Handle NpgSql Legacy Support for `timestamp without timezone` issue + AppContext.SetSwitch("Npgsql.EnableLegacyTimestampBehavior", true); + } + else if (provider == SupportedDatabaseProviders.MySql) + { + options.UseMySql(connectionString, ServerVersion.AutoDetect(connectionString)); + } + }); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); - if (selfHosted) - { - services.AddSingleton(); + if (selfHosted) + { + services.AddSingleton(); + } } } } diff --git a/src/Infrastructure.EntityFramework/Models/Cipher.cs b/src/Infrastructure.EntityFramework/Models/Cipher.cs index ec5ddc53d0..4cf008d523 100644 --- a/src/Infrastructure.EntityFramework/Models/Cipher.cs +++ b/src/Infrastructure.EntityFramework/Models/Cipher.cs @@ -1,18 +1,19 @@ using AutoMapper; -namespace Bit.Infrastructure.EntityFramework.Models; - -public class Cipher : Core.Entities.Cipher +namespace Bit.Infrastructure.EntityFramework.Models { - public virtual User User { get; set; } - public virtual Organization Organization { get; set; } - public virtual ICollection CollectionCiphers { get; set; } -} - -public class CipherMapperProfile : Profile -{ - public CipherMapperProfile() + public class Cipher : Core.Entities.Cipher { - CreateMap().ReverseMap(); + public virtual User User { get; set; } + public virtual Organization Organization { get; set; } + public virtual ICollection CollectionCiphers { get; set; } + } + + public class CipherMapperProfile : Profile + { + public CipherMapperProfile() + { + CreateMap().ReverseMap(); + } } } diff --git a/src/Infrastructure.EntityFramework/Models/Collection.cs b/src/Infrastructure.EntityFramework/Models/Collection.cs index 29495081d4..2e4337238e 100644 --- a/src/Infrastructure.EntityFramework/Models/Collection.cs +++ b/src/Infrastructure.EntityFramework/Models/Collection.cs @@ -1,19 +1,20 @@ using AutoMapper; -namespace Bit.Infrastructure.EntityFramework.Models; - -public class Collection : Core.Entities.Collection +namespace Bit.Infrastructure.EntityFramework.Models { - public virtual Organization Organization { get; set; } - public virtual ICollection CollectionUsers { get; set; } - public virtual ICollection CollectionCiphers { get; set; } - public virtual ICollection CollectionGroups { get; set; } -} - -public class CollectionMapperProfile : Profile -{ - public CollectionMapperProfile() + public class Collection : Core.Entities.Collection { - CreateMap().ReverseMap(); + public virtual Organization Organization { get; set; } + public virtual ICollection CollectionUsers { get; set; } + public virtual ICollection CollectionCiphers { get; set; } + public virtual ICollection CollectionGroups { get; set; } + } + + public class CollectionMapperProfile : Profile + { + public CollectionMapperProfile() + { + CreateMap().ReverseMap(); + } } } diff --git a/src/Infrastructure.EntityFramework/Models/CollectionCipher.cs b/src/Infrastructure.EntityFramework/Models/CollectionCipher.cs index 93d1deae1f..8a7de5a780 100644 --- a/src/Infrastructure.EntityFramework/Models/CollectionCipher.cs +++ b/src/Infrastructure.EntityFramework/Models/CollectionCipher.cs @@ -1,17 +1,18 @@ using AutoMapper; -namespace Bit.Infrastructure.EntityFramework.Models; - -public class CollectionCipher : Core.Entities.CollectionCipher +namespace Bit.Infrastructure.EntityFramework.Models { - public virtual Cipher Cipher { get; set; } - public virtual Collection Collection { get; set; } -} - -public class CollectionCipherMapperProfile : Profile -{ - public CollectionCipherMapperProfile() + public class CollectionCipher : Core.Entities.CollectionCipher { - CreateMap().ReverseMap(); + public virtual Cipher Cipher { get; set; } + public virtual Collection Collection { get; set; } + } + + public class CollectionCipherMapperProfile : Profile + { + public CollectionCipherMapperProfile() + { + CreateMap().ReverseMap(); + } } } diff --git a/src/Infrastructure.EntityFramework/Models/CollectionGroup.cs b/src/Infrastructure.EntityFramework/Models/CollectionGroup.cs index 623a5d8084..fdded35216 100644 --- a/src/Infrastructure.EntityFramework/Models/CollectionGroup.cs +++ b/src/Infrastructure.EntityFramework/Models/CollectionGroup.cs @@ -1,17 +1,18 @@ using AutoMapper; -namespace Bit.Infrastructure.EntityFramework.Models; - -public class CollectionGroup : Core.Entities.CollectionGroup +namespace Bit.Infrastructure.EntityFramework.Models { - public virtual Collection Collection { get; set; } - public virtual Group Group { get; set; } -} - -public class CollectionGroupMapperProfile : Profile -{ - public CollectionGroupMapperProfile() + public class CollectionGroup : Core.Entities.CollectionGroup { - CreateMap().ReverseMap(); + public virtual Collection Collection { get; set; } + public virtual Group Group { get; set; } + } + + public class CollectionGroupMapperProfile : Profile + { + public CollectionGroupMapperProfile() + { + CreateMap().ReverseMap(); + } } } diff --git a/src/Infrastructure.EntityFramework/Models/CollectionUser.cs b/src/Infrastructure.EntityFramework/Models/CollectionUser.cs index 308673492b..24d10c2a76 100644 --- a/src/Infrastructure.EntityFramework/Models/CollectionUser.cs +++ b/src/Infrastructure.EntityFramework/Models/CollectionUser.cs @@ -1,17 +1,18 @@ using AutoMapper; -namespace Bit.Infrastructure.EntityFramework.Models; - -public class CollectionUser : Core.Entities.CollectionUser +namespace Bit.Infrastructure.EntityFramework.Models { - public virtual Collection Collection { get; set; } - public virtual OrganizationUser OrganizationUser { get; set; } -} - -public class CollectionUserMapperProfile : Profile -{ - public CollectionUserMapperProfile() + public class CollectionUser : Core.Entities.CollectionUser { - CreateMap().ReverseMap(); + public virtual Collection Collection { get; set; } + public virtual OrganizationUser OrganizationUser { get; set; } + } + + public class CollectionUserMapperProfile : Profile + { + public CollectionUserMapperProfile() + { + CreateMap().ReverseMap(); + } } } diff --git a/src/Infrastructure.EntityFramework/Models/Device.cs b/src/Infrastructure.EntityFramework/Models/Device.cs index 1eace238d5..675ed917a8 100644 --- a/src/Infrastructure.EntityFramework/Models/Device.cs +++ b/src/Infrastructure.EntityFramework/Models/Device.cs @@ -1,16 +1,17 @@ using AutoMapper; -namespace Bit.Infrastructure.EntityFramework.Models; - -public class Device : Core.Entities.Device +namespace Bit.Infrastructure.EntityFramework.Models { - public virtual User User { get; set; } -} - -public class DeviceMapperProfile : Profile -{ - public DeviceMapperProfile() + public class Device : Core.Entities.Device { - CreateMap().ReverseMap(); + public virtual User User { get; set; } + } + + public class DeviceMapperProfile : Profile + { + public DeviceMapperProfile() + { + CreateMap().ReverseMap(); + } } } diff --git a/src/Infrastructure.EntityFramework/Models/EmergencyAccess.cs b/src/Infrastructure.EntityFramework/Models/EmergencyAccess.cs index 867912c5e3..e92eba8eef 100644 --- a/src/Infrastructure.EntityFramework/Models/EmergencyAccess.cs +++ b/src/Infrastructure.EntityFramework/Models/EmergencyAccess.cs @@ -1,17 +1,18 @@ using AutoMapper; -namespace Bit.Infrastructure.EntityFramework.Models; - -public class EmergencyAccess : Core.Entities.EmergencyAccess +namespace Bit.Infrastructure.EntityFramework.Models { - public virtual User Grantee { get; set; } - public virtual User Grantor { get; set; } -} - -public class EmergencyAccessMapperProfile : Profile -{ - public EmergencyAccessMapperProfile() + public class EmergencyAccess : Core.Entities.EmergencyAccess { - CreateMap().ReverseMap(); + public virtual User Grantee { get; set; } + public virtual User Grantor { get; set; } + } + + public class EmergencyAccessMapperProfile : Profile + { + public EmergencyAccessMapperProfile() + { + CreateMap().ReverseMap(); + } } } diff --git a/src/Infrastructure.EntityFramework/Models/Event.cs b/src/Infrastructure.EntityFramework/Models/Event.cs index b7bad9c789..558f2a2856 100644 --- a/src/Infrastructure.EntityFramework/Models/Event.cs +++ b/src/Infrastructure.EntityFramework/Models/Event.cs @@ -1,15 +1,16 @@ using AutoMapper; -namespace Bit.Infrastructure.EntityFramework.Models; - -public class Event : Core.Entities.Event +namespace Bit.Infrastructure.EntityFramework.Models { -} - -public class EventMapperProfile : Profile -{ - public EventMapperProfile() + public class Event : Core.Entities.Event { - CreateMap().ReverseMap(); + } + + public class EventMapperProfile : Profile + { + public EventMapperProfile() + { + CreateMap().ReverseMap(); + } } } diff --git a/src/Infrastructure.EntityFramework/Models/Folder.cs b/src/Infrastructure.EntityFramework/Models/Folder.cs index 4668337858..1918dfe733 100644 --- a/src/Infrastructure.EntityFramework/Models/Folder.cs +++ b/src/Infrastructure.EntityFramework/Models/Folder.cs @@ -1,16 +1,17 @@ using AutoMapper; -namespace Bit.Infrastructure.EntityFramework.Models; - -public class Folder : Core.Entities.Folder +namespace Bit.Infrastructure.EntityFramework.Models { - public virtual User User { get; set; } -} - -public class FolderMapperProfile : Profile -{ - public FolderMapperProfile() + public class Folder : Core.Entities.Folder { - CreateMap().ReverseMap(); + public virtual User User { get; set; } + } + + public class FolderMapperProfile : Profile + { + public FolderMapperProfile() + { + CreateMap().ReverseMap(); + } } } diff --git a/src/Infrastructure.EntityFramework/Models/Grant.cs b/src/Infrastructure.EntityFramework/Models/Grant.cs index 78b4b4582f..251d16437e 100644 --- a/src/Infrastructure.EntityFramework/Models/Grant.cs +++ b/src/Infrastructure.EntityFramework/Models/Grant.cs @@ -1,15 +1,16 @@ using AutoMapper; -namespace Bit.Infrastructure.EntityFramework.Models; - -public class Grant : Core.Entities.Grant +namespace Bit.Infrastructure.EntityFramework.Models { -} - -public class GrantMapperProfile : Profile -{ - public GrantMapperProfile() + public class Grant : Core.Entities.Grant { - CreateMap().ReverseMap(); + } + + public class GrantMapperProfile : Profile + { + public GrantMapperProfile() + { + CreateMap().ReverseMap(); + } } } diff --git a/src/Infrastructure.EntityFramework/Models/Group.cs b/src/Infrastructure.EntityFramework/Models/Group.cs index eaa41bed82..98f4820121 100644 --- a/src/Infrastructure.EntityFramework/Models/Group.cs +++ b/src/Infrastructure.EntityFramework/Models/Group.cs @@ -1,17 +1,18 @@ using AutoMapper; -namespace Bit.Infrastructure.EntityFramework.Models; - -public class Group : Core.Entities.Group +namespace Bit.Infrastructure.EntityFramework.Models { - public virtual Organization Organization { get; set; } - public virtual ICollection GroupUsers { get; set; } -} - -public class GroupMapperProfile : Profile -{ - public GroupMapperProfile() + public class Group : Core.Entities.Group { - CreateMap().ReverseMap(); + public virtual Organization Organization { get; set; } + public virtual ICollection GroupUsers { get; set; } + } + + public class GroupMapperProfile : Profile + { + public GroupMapperProfile() + { + CreateMap().ReverseMap(); + } } } diff --git a/src/Infrastructure.EntityFramework/Models/GroupUser.cs b/src/Infrastructure.EntityFramework/Models/GroupUser.cs index 3f25e7d876..5a81ed884e 100644 --- a/src/Infrastructure.EntityFramework/Models/GroupUser.cs +++ b/src/Infrastructure.EntityFramework/Models/GroupUser.cs @@ -1,18 +1,19 @@ using AutoMapper; -namespace Bit.Infrastructure.EntityFramework.Models; - -public class GroupUser : Core.Entities.GroupUser +namespace Bit.Infrastructure.EntityFramework.Models { - public virtual Group Group { get; set; } - public virtual OrganizationUser OrganizationUser { get; set; } -} - -public class GroupUserMapperProfile : Profile -{ - public GroupUserMapperProfile() + public class GroupUser : Core.Entities.GroupUser { - CreateMap().ReverseMap(); + public virtual Group Group { get; set; } + public virtual OrganizationUser OrganizationUser { get; set; } + } + + public class GroupUserMapperProfile : Profile + { + public GroupUserMapperProfile() + { + CreateMap().ReverseMap(); + } } } diff --git a/src/Infrastructure.EntityFramework/Models/Installation.cs b/src/Infrastructure.EntityFramework/Models/Installation.cs index 35223a33d7..92bbd2abbd 100644 --- a/src/Infrastructure.EntityFramework/Models/Installation.cs +++ b/src/Infrastructure.EntityFramework/Models/Installation.cs @@ -1,15 +1,16 @@ using AutoMapper; -namespace Bit.Infrastructure.EntityFramework.Models; - -public class Installation : Core.Entities.Installation +namespace Bit.Infrastructure.EntityFramework.Models { -} - -public class InstallationMapperProfile : Profile -{ - public InstallationMapperProfile() + public class Installation : Core.Entities.Installation { - CreateMap().ReverseMap(); + } + + public class InstallationMapperProfile : Profile + { + public InstallationMapperProfile() + { + CreateMap().ReverseMap(); + } } } diff --git a/src/Infrastructure.EntityFramework/Models/Organization.cs b/src/Infrastructure.EntityFramework/Models/Organization.cs index c1969cab0c..3d46027ef2 100644 --- a/src/Infrastructure.EntityFramework/Models/Organization.cs +++ b/src/Infrastructure.EntityFramework/Models/Organization.cs @@ -1,24 +1,25 @@ using AutoMapper; -namespace Bit.Infrastructure.EntityFramework.Models; - -public class Organization : Core.Entities.Organization +namespace Bit.Infrastructure.EntityFramework.Models { - public virtual ICollection Ciphers { get; set; } - public virtual ICollection OrganizationUsers { get; set; } - public virtual ICollection Groups { get; set; } - public virtual ICollection Policies { get; set; } - public virtual ICollection SsoConfigs { get; set; } - public virtual ICollection SsoUsers { get; set; } - public virtual ICollection Transactions { get; set; } - public virtual ICollection ApiKeys { get; set; } - public virtual ICollection Connections { get; set; } -} - -public class OrganizationMapperProfile : Profile -{ - public OrganizationMapperProfile() + public class Organization : Core.Entities.Organization { - CreateMap().ReverseMap(); + public virtual ICollection Ciphers { get; set; } + public virtual ICollection OrganizationUsers { get; set; } + public virtual ICollection Groups { get; set; } + public virtual ICollection Policies { get; set; } + public virtual ICollection SsoConfigs { get; set; } + public virtual ICollection SsoUsers { get; set; } + public virtual ICollection Transactions { get; set; } + public virtual ICollection ApiKeys { get; set; } + public virtual ICollection Connections { get; set; } + } + + public class OrganizationMapperProfile : Profile + { + public OrganizationMapperProfile() + { + CreateMap().ReverseMap(); + } } } diff --git a/src/Infrastructure.EntityFramework/Models/OrganizationApiKey.cs b/src/Infrastructure.EntityFramework/Models/OrganizationApiKey.cs index b8a4f4e746..c0e6c33e05 100644 --- a/src/Infrastructure.EntityFramework/Models/OrganizationApiKey.cs +++ b/src/Infrastructure.EntityFramework/Models/OrganizationApiKey.cs @@ -1,16 +1,17 @@ using AutoMapper; -namespace Bit.Infrastructure.EntityFramework.Models; - -public class OrganizationApiKey : Core.Entities.OrganizationApiKey +namespace Bit.Infrastructure.EntityFramework.Models { - public virtual Organization Organization { get; set; } -} - -public class OrganizationApiKeyMapperProfile : Profile -{ - public OrganizationApiKeyMapperProfile() + public class OrganizationApiKey : Core.Entities.OrganizationApiKey { - CreateMap().ReverseMap(); + public virtual Organization Organization { get; set; } + } + + public class OrganizationApiKeyMapperProfile : Profile + { + public OrganizationApiKeyMapperProfile() + { + CreateMap().ReverseMap(); + } } } diff --git a/src/Infrastructure.EntityFramework/Models/OrganizationConnection.cs b/src/Infrastructure.EntityFramework/Models/OrganizationConnection.cs index 5c41d5f6c1..f53ee711c0 100644 --- a/src/Infrastructure.EntityFramework/Models/OrganizationConnection.cs +++ b/src/Infrastructure.EntityFramework/Models/OrganizationConnection.cs @@ -1,16 +1,17 @@ using AutoMapper; -namespace Bit.Infrastructure.EntityFramework.Models; - -public class OrganizationConnection : Core.Entities.OrganizationConnection +namespace Bit.Infrastructure.EntityFramework.Models { - public virtual Organization Organization { get; set; } -} - -public class OrganizationConnectionMapperProfile : Profile -{ - public OrganizationConnectionMapperProfile() + public class OrganizationConnection : Core.Entities.OrganizationConnection { - CreateMap().ReverseMap(); + public virtual Organization Organization { get; set; } + } + + public class OrganizationConnectionMapperProfile : Profile + { + public OrganizationConnectionMapperProfile() + { + CreateMap().ReverseMap(); + } } } diff --git a/src/Infrastructure.EntityFramework/Models/OrganizationSponsorship.cs b/src/Infrastructure.EntityFramework/Models/OrganizationSponsorship.cs index 3d8b8acf77..c9eee03e50 100644 --- a/src/Infrastructure.EntityFramework/Models/OrganizationSponsorship.cs +++ b/src/Infrastructure.EntityFramework/Models/OrganizationSponsorship.cs @@ -1,17 +1,18 @@ using AutoMapper; -namespace Bit.Infrastructure.EntityFramework.Models; - -public class OrganizationSponsorship : Core.Entities.OrganizationSponsorship +namespace Bit.Infrastructure.EntityFramework.Models { - public virtual Organization SponsoringOrganization { get; set; } - public virtual Organization SponsoredOrganization { get; set; } -} - -public class OrganizationSponsorshipMapperProfile : Profile -{ - public OrganizationSponsorshipMapperProfile() + public class OrganizationSponsorship : Core.Entities.OrganizationSponsorship { - CreateMap().ReverseMap(); + public virtual Organization SponsoringOrganization { get; set; } + public virtual Organization SponsoredOrganization { get; set; } + } + + public class OrganizationSponsorshipMapperProfile : Profile + { + public OrganizationSponsorshipMapperProfile() + { + CreateMap().ReverseMap(); + } } } diff --git a/src/Infrastructure.EntityFramework/Models/OrganizationUser.cs b/src/Infrastructure.EntityFramework/Models/OrganizationUser.cs index abab1a4d5a..f1489bd462 100644 --- a/src/Infrastructure.EntityFramework/Models/OrganizationUser.cs +++ b/src/Infrastructure.EntityFramework/Models/OrganizationUser.cs @@ -1,18 +1,19 @@ using AutoMapper; -namespace Bit.Infrastructure.EntityFramework.Models; - -public class OrganizationUser : Core.Entities.OrganizationUser +namespace Bit.Infrastructure.EntityFramework.Models { - public virtual Organization Organization { get; set; } - public virtual User User { get; set; } - public virtual ICollection CollectionUsers { get; set; } -} - -public class OrganizationUserMapperProfile : Profile -{ - public OrganizationUserMapperProfile() + public class OrganizationUser : Core.Entities.OrganizationUser { - CreateMap().ReverseMap(); + public virtual Organization Organization { get; set; } + public virtual User User { get; set; } + public virtual ICollection CollectionUsers { get; set; } + } + + public class OrganizationUserMapperProfile : Profile + { + public OrganizationUserMapperProfile() + { + CreateMap().ReverseMap(); + } } } diff --git a/src/Infrastructure.EntityFramework/Models/Policy.cs b/src/Infrastructure.EntityFramework/Models/Policy.cs index 22b17c6f6d..953556cddc 100644 --- a/src/Infrastructure.EntityFramework/Models/Policy.cs +++ b/src/Infrastructure.EntityFramework/Models/Policy.cs @@ -1,16 +1,17 @@ using AutoMapper; -namespace Bit.Infrastructure.EntityFramework.Models; - -public class Policy : Core.Entities.Policy +namespace Bit.Infrastructure.EntityFramework.Models { - public virtual Organization Organization { get; set; } -} - -public class PolicyMapperProfile : Profile -{ - public PolicyMapperProfile() + public class Policy : Core.Entities.Policy { - CreateMap().ReverseMap(); + public virtual Organization Organization { get; set; } + } + + public class PolicyMapperProfile : Profile + { + public PolicyMapperProfile() + { + CreateMap().ReverseMap(); + } } } diff --git a/src/Infrastructure.EntityFramework/Models/Provider/Provider.cs b/src/Infrastructure.EntityFramework/Models/Provider/Provider.cs index d639d6d01d..8efa1558dd 100644 --- a/src/Infrastructure.EntityFramework/Models/Provider/Provider.cs +++ b/src/Infrastructure.EntityFramework/Models/Provider/Provider.cs @@ -1,15 +1,16 @@ using AutoMapper; -namespace Bit.Infrastructure.EntityFramework.Models; - -public class Provider : Core.Entities.Provider.Provider +namespace Bit.Infrastructure.EntityFramework.Models { -} - -public class ProviderMapperProfile : Profile -{ - public ProviderMapperProfile() + public class Provider : Core.Entities.Provider.Provider { - CreateMap().ReverseMap(); + } + + public class ProviderMapperProfile : Profile + { + public ProviderMapperProfile() + { + CreateMap().ReverseMap(); + } } } diff --git a/src/Infrastructure.EntityFramework/Models/Provider/ProviderOrganization.cs b/src/Infrastructure.EntityFramework/Models/Provider/ProviderOrganization.cs index af23ba978c..13aa521104 100644 --- a/src/Infrastructure.EntityFramework/Models/Provider/ProviderOrganization.cs +++ b/src/Infrastructure.EntityFramework/Models/Provider/ProviderOrganization.cs @@ -1,17 +1,18 @@ using AutoMapper; -namespace Bit.Infrastructure.EntityFramework.Models; - -public class ProviderOrganization : Core.Entities.Provider.ProviderOrganization +namespace Bit.Infrastructure.EntityFramework.Models { - public virtual Provider Provider { get; set; } - public virtual Organization Organization { get; set; } -} - -public class ProviderOrganizationMapperProfile : Profile -{ - public ProviderOrganizationMapperProfile() + public class ProviderOrganization : Core.Entities.Provider.ProviderOrganization { - CreateMap().ReverseMap(); + public virtual Provider Provider { get; set; } + public virtual Organization Organization { get; set; } + } + + public class ProviderOrganizationMapperProfile : Profile + { + public ProviderOrganizationMapperProfile() + { + CreateMap().ReverseMap(); + } } } diff --git a/src/Infrastructure.EntityFramework/Models/Provider/ProviderUser.cs b/src/Infrastructure.EntityFramework/Models/Provider/ProviderUser.cs index 5c53c4d979..9aac138be0 100644 --- a/src/Infrastructure.EntityFramework/Models/Provider/ProviderUser.cs +++ b/src/Infrastructure.EntityFramework/Models/Provider/ProviderUser.cs @@ -1,17 +1,18 @@ using AutoMapper; -namespace Bit.Infrastructure.EntityFramework.Models; - -public class ProviderUser : Core.Entities.Provider.ProviderUser +namespace Bit.Infrastructure.EntityFramework.Models { - public virtual User User { get; set; } - public virtual Provider Provider { get; set; } -} - -public class ProviderUserMapperProfile : Profile -{ - public ProviderUserMapperProfile() + public class ProviderUser : Core.Entities.Provider.ProviderUser { - CreateMap().ReverseMap(); + public virtual User User { get; set; } + public virtual Provider Provider { get; set; } + } + + public class ProviderUserMapperProfile : Profile + { + public ProviderUserMapperProfile() + { + CreateMap().ReverseMap(); + } } } diff --git a/src/Infrastructure.EntityFramework/Models/Role.cs b/src/Infrastructure.EntityFramework/Models/Role.cs index 4cc2e099c3..a92682e2e9 100644 --- a/src/Infrastructure.EntityFramework/Models/Role.cs +++ b/src/Infrastructure.EntityFramework/Models/Role.cs @@ -1,15 +1,16 @@ using AutoMapper; -namespace Bit.Infrastructure.EntityFramework.Models; - -public class Role : Core.Entities.Role +namespace Bit.Infrastructure.EntityFramework.Models { -} - -public class RoleMapperProfile : Profile -{ - public RoleMapperProfile() + public class Role : Core.Entities.Role { - CreateMap().ReverseMap(); + } + + public class RoleMapperProfile : Profile + { + public RoleMapperProfile() + { + CreateMap().ReverseMap(); + } } } diff --git a/src/Infrastructure.EntityFramework/Models/Send.cs b/src/Infrastructure.EntityFramework/Models/Send.cs index 13bfbb61b5..5732ac2a1c 100644 --- a/src/Infrastructure.EntityFramework/Models/Send.cs +++ b/src/Infrastructure.EntityFramework/Models/Send.cs @@ -1,17 +1,18 @@ using AutoMapper; -namespace Bit.Infrastructure.EntityFramework.Models; - -public class Send : Core.Entities.Send +namespace Bit.Infrastructure.EntityFramework.Models { - public virtual Organization Organization { get; set; } - public virtual User User { get; set; } -} - -public class SendMapperProfile : Profile -{ - public SendMapperProfile() + public class Send : Core.Entities.Send { - CreateMap().ReverseMap(); + public virtual Organization Organization { get; set; } + public virtual User User { get; set; } + } + + public class SendMapperProfile : Profile + { + public SendMapperProfile() + { + CreateMap().ReverseMap(); + } } } diff --git a/src/Infrastructure.EntityFramework/Models/SsoConfig.cs b/src/Infrastructure.EntityFramework/Models/SsoConfig.cs index 70e007b992..d748934f2d 100644 --- a/src/Infrastructure.EntityFramework/Models/SsoConfig.cs +++ b/src/Infrastructure.EntityFramework/Models/SsoConfig.cs @@ -1,16 +1,17 @@ using AutoMapper; -namespace Bit.Infrastructure.EntityFramework.Models; - -public class SsoConfig : Core.Entities.SsoConfig +namespace Bit.Infrastructure.EntityFramework.Models { - public virtual Organization Organization { get; set; } -} - -public class SsoConfigMapperProfile : Profile -{ - public SsoConfigMapperProfile() + public class SsoConfig : Core.Entities.SsoConfig { - CreateMap().ReverseMap(); + public virtual Organization Organization { get; set; } + } + + public class SsoConfigMapperProfile : Profile + { + public SsoConfigMapperProfile() + { + CreateMap().ReverseMap(); + } } } diff --git a/src/Infrastructure.EntityFramework/Models/SsoUser.cs b/src/Infrastructure.EntityFramework/Models/SsoUser.cs index 01333dbcae..eb02984422 100644 --- a/src/Infrastructure.EntityFramework/Models/SsoUser.cs +++ b/src/Infrastructure.EntityFramework/Models/SsoUser.cs @@ -1,17 +1,18 @@ using AutoMapper; -namespace Bit.Infrastructure.EntityFramework.Models; - -public class SsoUser : Core.Entities.SsoUser +namespace Bit.Infrastructure.EntityFramework.Models { - public virtual Organization Organization { get; set; } - public virtual User User { get; set; } -} - -public class SsoUserMapperProfile : Profile -{ - public SsoUserMapperProfile() + public class SsoUser : Core.Entities.SsoUser { - CreateMap().ReverseMap(); + public virtual Organization Organization { get; set; } + public virtual User User { get; set; } + } + + public class SsoUserMapperProfile : Profile + { + public SsoUserMapperProfile() + { + CreateMap().ReverseMap(); + } } } diff --git a/src/Infrastructure.EntityFramework/Models/TaxRate.cs b/src/Infrastructure.EntityFramework/Models/TaxRate.cs index d47a92237a..f464724aec 100644 --- a/src/Infrastructure.EntityFramework/Models/TaxRate.cs +++ b/src/Infrastructure.EntityFramework/Models/TaxRate.cs @@ -1,15 +1,16 @@ using AutoMapper; -namespace Bit.Infrastructure.EntityFramework.Models; - -public class TaxRate : Core.Entities.TaxRate +namespace Bit.Infrastructure.EntityFramework.Models { -} - -public class TaxRateMapperProfile : Profile -{ - public TaxRateMapperProfile() + public class TaxRate : Core.Entities.TaxRate { - CreateMap().ReverseMap(); + } + + public class TaxRateMapperProfile : Profile + { + public TaxRateMapperProfile() + { + CreateMap().ReverseMap(); + } } } diff --git a/src/Infrastructure.EntityFramework/Models/Transaction.cs b/src/Infrastructure.EntityFramework/Models/Transaction.cs index 4eb63646c3..b9d4bc954d 100644 --- a/src/Infrastructure.EntityFramework/Models/Transaction.cs +++ b/src/Infrastructure.EntityFramework/Models/Transaction.cs @@ -1,17 +1,18 @@ using AutoMapper; -namespace Bit.Infrastructure.EntityFramework.Models; - -public class Transaction : Core.Entities.Transaction +namespace Bit.Infrastructure.EntityFramework.Models { - public virtual Organization Organization { get; set; } - public virtual User User { get; set; } -} - -public class TransactionMapperProfile : Profile -{ - public TransactionMapperProfile() + public class Transaction : Core.Entities.Transaction { - CreateMap().ReverseMap(); + public virtual Organization Organization { get; set; } + public virtual User User { get; set; } + } + + public class TransactionMapperProfile : Profile + { + public TransactionMapperProfile() + { + CreateMap().ReverseMap(); + } } } diff --git a/src/Infrastructure.EntityFramework/Models/User.cs b/src/Infrastructure.EntityFramework/Models/User.cs index 1316acccfb..9ff81e90b4 100644 --- a/src/Infrastructure.EntityFramework/Models/User.cs +++ b/src/Infrastructure.EntityFramework/Models/User.cs @@ -1,22 +1,23 @@ using AutoMapper; -namespace Bit.Infrastructure.EntityFramework.Models; - -public class User : Core.Entities.User +namespace Bit.Infrastructure.EntityFramework.Models { - public virtual ICollection Ciphers { get; set; } - public virtual ICollection Folders { get; set; } - public virtual ICollection CollectionUsers { get; set; } - public virtual ICollection GroupUsers { get; set; } - public virtual ICollection OrganizationUsers { get; set; } - public virtual ICollection SsoUsers { get; set; } - public virtual ICollection Transactions { get; set; } -} - -public class UserMapperProfile : Profile -{ - public UserMapperProfile() + public class User : Core.Entities.User { - CreateMap().ReverseMap(); + public virtual ICollection Ciphers { get; set; } + public virtual ICollection Folders { get; set; } + public virtual ICollection CollectionUsers { get; set; } + public virtual ICollection GroupUsers { get; set; } + public virtual ICollection OrganizationUsers { get; set; } + public virtual ICollection SsoUsers { get; set; } + public virtual ICollection Transactions { get; set; } + } + + public class UserMapperProfile : Profile + { + public UserMapperProfile() + { + CreateMap().ReverseMap(); + } } } diff --git a/src/Infrastructure.EntityFramework/Repositories/BaseEntityFrameworkRepository.cs b/src/Infrastructure.EntityFramework/Repositories/BaseEntityFrameworkRepository.cs index 9dc7818d73..833994fc59 100644 --- a/src/Infrastructure.EntityFramework/Repositories/BaseEntityFrameworkRepository.cs +++ b/src/Infrastructure.EntityFramework/Repositories/BaseEntityFrameworkRepository.cs @@ -10,63 +10,80 @@ using Microsoft.Extensions.DependencyInjection; using Cipher = Bit.Core.Entities.Cipher; using User = Bit.Core.Entities.User; -namespace Bit.Infrastructure.EntityFramework.Repositories; - -public abstract class BaseEntityFrameworkRepository +namespace Bit.Infrastructure.EntityFramework.Repositories { - protected BulkCopyOptions DefaultBulkCopyOptions { get; set; } = new BulkCopyOptions + public abstract class BaseEntityFrameworkRepository { - KeepIdentity = true, - BulkCopyType = BulkCopyType.MultipleRows, - }; - - public BaseEntityFrameworkRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) - { - ServiceScopeFactory = serviceScopeFactory; - Mapper = mapper; - } - - protected IServiceScopeFactory ServiceScopeFactory { get; private set; } - protected IMapper Mapper { get; private set; } - - public DatabaseContext GetDatabaseContext(IServiceScope serviceScope) - { - return serviceScope.ServiceProvider.GetRequiredService(); - } - - public void ClearChangeTracking() - { - using (var scope = ServiceScopeFactory.CreateScope()) + protected BulkCopyOptions DefaultBulkCopyOptions { get; set; } = new BulkCopyOptions { - var dbContext = GetDatabaseContext(scope); - dbContext.ChangeTracker.Clear(); + KeepIdentity = true, + BulkCopyType = BulkCopyType.MultipleRows, + }; + + public BaseEntityFrameworkRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) + { + ServiceScopeFactory = serviceScopeFactory; + Mapper = mapper; } - } - public async Task GetCountFromQuery(IQuery query) - { - using (var scope = ServiceScopeFactory.CreateScope()) + protected IServiceScopeFactory ServiceScopeFactory { get; private set; } + protected IMapper Mapper { get; private set; } + + public DatabaseContext GetDatabaseContext(IServiceScope serviceScope) { - return await query.Run(GetDatabaseContext(scope)).CountAsync(); + return serviceScope.ServiceProvider.GetRequiredService(); } - } - protected async Task UserBumpAccountRevisionDateByCipherId(Cipher cipher) - { - var list = new List { cipher }; - await UserBumpAccountRevisionDateByCipherId(list); - } - - protected async Task UserBumpAccountRevisionDateByCipherId(IEnumerable ciphers) - { - using (var scope = ServiceScopeFactory.CreateScope()) + public void ClearChangeTracking() { - foreach (var cipher in ciphers) + using (var scope = ServiceScopeFactory.CreateScope()) { var dbContext = GetDatabaseContext(scope); - var query = new UserBumpAccountRevisionDateByCipherIdQuery(cipher); - var users = query.Run(dbContext); + dbContext.ChangeTracker.Clear(); + } + } + public async Task GetCountFromQuery(IQuery query) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + return await query.Run(GetDatabaseContext(scope)).CountAsync(); + } + } + + protected async Task UserBumpAccountRevisionDateByCipherId(Cipher cipher) + { + var list = new List { cipher }; + await UserBumpAccountRevisionDateByCipherId(list); + } + + protected async Task UserBumpAccountRevisionDateByCipherId(IEnumerable ciphers) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + foreach (var cipher in ciphers) + { + var dbContext = GetDatabaseContext(scope); + var query = new UserBumpAccountRevisionDateByCipherIdQuery(cipher); + var users = query.Run(dbContext); + + await users.ForEachAsync(e => + { + dbContext.Attach(e); + e.RevisionDate = DateTime.UtcNow; + }); + await dbContext.SaveChangesAsync(); + } + } + } + + protected async Task UserBumpAccountRevisionDateByOrganizationId(Guid organizationId) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var query = new UserBumpAccountRevisionDateByOrganizationIdQuery(organizationId); + var users = query.Run(dbContext); await users.ForEachAsync(e => { dbContext.Attach(e); @@ -75,191 +92,175 @@ public abstract class BaseEntityFrameworkRepository await dbContext.SaveChangesAsync(); } } - } - protected async Task UserBumpAccountRevisionDateByOrganizationId(Guid organizationId) - { - using (var scope = ServiceScopeFactory.CreateScope()) + protected async Task UserBumpAccountRevisionDate(Guid userId) { - var dbContext = GetDatabaseContext(scope); - var query = new UserBumpAccountRevisionDateByOrganizationIdQuery(organizationId); - var users = query.Run(dbContext); - await users.ForEachAsync(e => - { - dbContext.Attach(e); - e.RevisionDate = DateTime.UtcNow; - }); - await dbContext.SaveChangesAsync(); + await UserBumpManyAccountRevisionDates(new[] { userId }); } - } - protected async Task UserBumpAccountRevisionDate(Guid userId) - { - await UserBumpManyAccountRevisionDates(new[] { userId }); - } - - protected async Task UserBumpManyAccountRevisionDates(ICollection userIds) - { - using (var scope = ServiceScopeFactory.CreateScope()) + protected async Task UserBumpManyAccountRevisionDates(ICollection userIds) { - var dbContext = GetDatabaseContext(scope); - var users = dbContext.Users.Where(u => userIds.Contains(u.Id)); - await users.ForEachAsync(u => + using (var scope = ServiceScopeFactory.CreateScope()) { - dbContext.Attach(u); - u.RevisionDate = DateTime.UtcNow; - }); - await dbContext.SaveChangesAsync(); - } - } - - protected async Task OrganizationUpdateStorage(Guid organizationId) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var attachments = await dbContext.Ciphers - .Where(e => e.UserId == null && - e.OrganizationId == organizationId && - !string.IsNullOrWhiteSpace(e.Attachments)) - .Select(e => e.Attachments) - .ToListAsync(); - var storage = attachments.Sum(e => JsonDocument.Parse(e)?.RootElement.EnumerateArray() - .Sum(p => p.GetProperty("Size").GetInt64()) ?? 0); - var organization = new Organization - { - Id = organizationId, - RevisionDate = DateTime.UtcNow, - Storage = storage, - }; - dbContext.Organizations.Attach(organization); - var entry = dbContext.Entry(organization); - entry.Property(e => e.RevisionDate).IsModified = true; - entry.Property(e => e.Storage).IsModified = true; - await dbContext.SaveChangesAsync(); - } - } - - protected async Task UserUpdateStorage(Guid userId) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var attachments = await dbContext.Ciphers - .Where(e => e.UserId.HasValue && - e.UserId.Value == userId && - e.OrganizationId == null && - !string.IsNullOrWhiteSpace(e.Attachments)) - .Select(e => e.Attachments) - .ToListAsync(); - var storage = attachments.Sum(e => JsonDocument.Parse(e)?.RootElement.EnumerateArray() - .Sum(p => p.GetProperty("Size").GetInt64()) ?? 0); - var user = new Models.User - { - Id = userId, - RevisionDate = DateTime.UtcNow, - Storage = storage, - }; - dbContext.Users.Attach(user); - var entry = dbContext.Entry(user); - entry.Property(e => e.RevisionDate).IsModified = true; - entry.Property(e => e.Storage).IsModified = true; - await dbContext.SaveChangesAsync(); - } - } - - protected async Task UserUpdateKeys(User user) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var entity = await dbContext.Users.FindAsync(user.Id); - if (entity == null) - { - return; + var dbContext = GetDatabaseContext(scope); + var users = dbContext.Users.Where(u => userIds.Contains(u.Id)); + await users.ForEachAsync(u => + { + dbContext.Attach(u); + u.RevisionDate = DateTime.UtcNow; + }); + await dbContext.SaveChangesAsync(); } - entity.SecurityStamp = user.SecurityStamp; - entity.Key = user.Key; - entity.PrivateKey = user.PrivateKey; - entity.RevisionDate = DateTime.UtcNow; - await dbContext.SaveChangesAsync(); } - } - protected async Task UserBumpAccountRevisionDateByCollectionId(Guid collectionId, Guid organizationId) - { - using (var scope = ServiceScopeFactory.CreateScope()) + protected async Task OrganizationUpdateStorage(Guid organizationId) { - var dbContext = GetDatabaseContext(scope); - var query = from u in dbContext.Users - join ou in dbContext.OrganizationUsers - on u.Id equals ou.UserId - join cu in dbContext.CollectionUsers - on ou.Id equals cu.OrganizationUserId into cu_g - from cu in cu_g.DefaultIfEmpty() - where !ou.AccessAll && cu.CollectionId.Equals(collectionId) - join gu in dbContext.GroupUsers - on ou.Id equals gu.OrganizationUserId into gu_g - from gu in gu_g.DefaultIfEmpty() - where cu.CollectionId == default(Guid) && !ou.AccessAll - join g in dbContext.Groups - on gu.GroupId equals g.Id into g_g - from g in g_g.DefaultIfEmpty() - join cg in dbContext.CollectionGroups - on gu.GroupId equals cg.GroupId into cg_g - from cg in cg_g.DefaultIfEmpty() - where !g.AccessAll && cg.CollectionId == collectionId && - (ou.OrganizationId == organizationId && ou.Status == OrganizationUserStatusType.Confirmed && - (cu.CollectionId != default(Guid) || cg.CollectionId != default(Guid) || ou.AccessAll || g.AccessAll)) - select new { u, ou, cu, gu, g, cg }; - var users = query.Select(x => x.u); - await users.ForEachAsync(u => + using (var scope = ServiceScopeFactory.CreateScope()) { - dbContext.Attach(u); - u.RevisionDate = DateTime.UtcNow; - }); - await dbContext.SaveChangesAsync(); + var dbContext = GetDatabaseContext(scope); + var attachments = await dbContext.Ciphers + .Where(e => e.UserId == null && + e.OrganizationId == organizationId && + !string.IsNullOrWhiteSpace(e.Attachments)) + .Select(e => e.Attachments) + .ToListAsync(); + var storage = attachments.Sum(e => JsonDocument.Parse(e)?.RootElement.EnumerateArray() + .Sum(p => p.GetProperty("Size").GetInt64()) ?? 0); + var organization = new Organization + { + Id = organizationId, + RevisionDate = DateTime.UtcNow, + Storage = storage, + }; + dbContext.Organizations.Attach(organization); + var entry = dbContext.Entry(organization); + entry.Property(e => e.RevisionDate).IsModified = true; + entry.Property(e => e.Storage).IsModified = true; + await dbContext.SaveChangesAsync(); + } } - } - protected async Task UserBumpAccountRevisionDateByOrganizationUserId(Guid organizationUserId) - { - using (var scope = ServiceScopeFactory.CreateScope()) + protected async Task UserUpdateStorage(Guid userId) { - var dbContext = GetDatabaseContext(scope); - var query = from u in dbContext.Users - join ou in dbContext.OrganizationUsers - on u.Id equals ou.UserId - where ou.Id.Equals(organizationUserId) && ou.Status.Equals(OrganizationUserStatusType.Confirmed) - select new { u, ou }; - var users = query.Select(x => x.u); - await users.ForEachAsync(u => + using (var scope = ServiceScopeFactory.CreateScope()) { - dbContext.Attach(u); - u.AccountRevisionDate = DateTime.UtcNow; - }); - await dbContext.SaveChangesAsync(); + var dbContext = GetDatabaseContext(scope); + var attachments = await dbContext.Ciphers + .Where(e => e.UserId.HasValue && + e.UserId.Value == userId && + e.OrganizationId == null && + !string.IsNullOrWhiteSpace(e.Attachments)) + .Select(e => e.Attachments) + .ToListAsync(); + var storage = attachments.Sum(e => JsonDocument.Parse(e)?.RootElement.EnumerateArray() + .Sum(p => p.GetProperty("Size").GetInt64()) ?? 0); + var user = new Models.User + { + Id = userId, + RevisionDate = DateTime.UtcNow, + Storage = storage, + }; + dbContext.Users.Attach(user); + var entry = dbContext.Entry(user); + entry.Property(e => e.RevisionDate).IsModified = true; + entry.Property(e => e.Storage).IsModified = true; + await dbContext.SaveChangesAsync(); + } } - } - protected async Task UserBumpAccountRevisionDateByProviderUserIds(ICollection providerUserIds) - { - using (var scope = ServiceScopeFactory.CreateScope()) + protected async Task UserUpdateKeys(User user) { - var dbContext = GetDatabaseContext(scope); - var query = from pu in dbContext.ProviderUsers - join u in dbContext.Users - on pu.UserId equals u.Id - where pu.Status.Equals(ProviderUserStatusType.Confirmed) && - providerUserIds.Contains(pu.Id) - select new { pu, u }; - var users = query.Select(x => x.u); - await users.ForEachAsync(u => + using (var scope = ServiceScopeFactory.CreateScope()) { - dbContext.Attach(u); - u.AccountRevisionDate = DateTime.UtcNow; - }); - await dbContext.SaveChangesAsync(); + var dbContext = GetDatabaseContext(scope); + var entity = await dbContext.Users.FindAsync(user.Id); + if (entity == null) + { + return; + } + entity.SecurityStamp = user.SecurityStamp; + entity.Key = user.Key; + entity.PrivateKey = user.PrivateKey; + entity.RevisionDate = DateTime.UtcNow; + await dbContext.SaveChangesAsync(); + } + } + + protected async Task UserBumpAccountRevisionDateByCollectionId(Guid collectionId, Guid organizationId) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var query = from u in dbContext.Users + join ou in dbContext.OrganizationUsers + on u.Id equals ou.UserId + join cu in dbContext.CollectionUsers + on ou.Id equals cu.OrganizationUserId into cu_g + from cu in cu_g.DefaultIfEmpty() + where !ou.AccessAll && cu.CollectionId.Equals(collectionId) + join gu in dbContext.GroupUsers + on ou.Id equals gu.OrganizationUserId into gu_g + from gu in gu_g.DefaultIfEmpty() + where cu.CollectionId == default(Guid) && !ou.AccessAll + join g in dbContext.Groups + on gu.GroupId equals g.Id into g_g + from g in g_g.DefaultIfEmpty() + join cg in dbContext.CollectionGroups + on gu.GroupId equals cg.GroupId into cg_g + from cg in cg_g.DefaultIfEmpty() + where !g.AccessAll && cg.CollectionId == collectionId && + (ou.OrganizationId == organizationId && ou.Status == OrganizationUserStatusType.Confirmed && + (cu.CollectionId != default(Guid) || cg.CollectionId != default(Guid) || ou.AccessAll || g.AccessAll)) + select new { u, ou, cu, gu, g, cg }; + var users = query.Select(x => x.u); + await users.ForEachAsync(u => + { + dbContext.Attach(u); + u.RevisionDate = DateTime.UtcNow; + }); + await dbContext.SaveChangesAsync(); + } + } + + protected async Task UserBumpAccountRevisionDateByOrganizationUserId(Guid organizationUserId) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var query = from u in dbContext.Users + join ou in dbContext.OrganizationUsers + on u.Id equals ou.UserId + where ou.Id.Equals(organizationUserId) && ou.Status.Equals(OrganizationUserStatusType.Confirmed) + select new { u, ou }; + var users = query.Select(x => x.u); + await users.ForEachAsync(u => + { + dbContext.Attach(u); + u.AccountRevisionDate = DateTime.UtcNow; + }); + await dbContext.SaveChangesAsync(); + } + } + + protected async Task UserBumpAccountRevisionDateByProviderUserIds(ICollection providerUserIds) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var query = from pu in dbContext.ProviderUsers + join u in dbContext.Users + on pu.UserId equals u.Id + where pu.Status.Equals(ProviderUserStatusType.Confirmed) && + providerUserIds.Contains(pu.Id) + select new { pu, u }; + var users = query.Select(x => x.u); + await users.ForEachAsync(u => + { + dbContext.Attach(u); + u.AccountRevisionDate = DateTime.UtcNow; + }); + await dbContext.SaveChangesAsync(); + } } } } diff --git a/src/Infrastructure.EntityFramework/Repositories/CipherRepository.cs b/src/Infrastructure.EntityFramework/Repositories/CipherRepository.cs index 17aaedfac9..fdf528393f 100644 --- a/src/Infrastructure.EntityFramework/Repositories/CipherRepository.cs +++ b/src/Infrastructure.EntityFramework/Repositories/CipherRepository.cs @@ -13,637 +13,638 @@ using Newtonsoft.Json; using Newtonsoft.Json.Linq; using User = Bit.Core.Entities.User; -namespace Bit.Infrastructure.EntityFramework.Repositories; - -public class CipherRepository : Repository, ICipherRepository +namespace Bit.Infrastructure.EntityFramework.Repositories { - public CipherRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) - : base(serviceScopeFactory, mapper, (DatabaseContext context) => context.Ciphers) - { } - - public override async Task CreateAsync(Core.Entities.Cipher cipher) + public class CipherRepository : Repository, ICipherRepository { - cipher = await base.CreateAsync(cipher); - using (var scope = ServiceScopeFactory.CreateScope()) + public CipherRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) + : base(serviceScopeFactory, mapper, (DatabaseContext context) => context.Ciphers) + { } + + public override async Task CreateAsync(Core.Entities.Cipher cipher) { - var dbContext = GetDatabaseContext(scope); - if (cipher.OrganizationId.HasValue) + cipher = await base.CreateAsync(cipher); + using (var scope = ServiceScopeFactory.CreateScope()) { - await UserBumpAccountRevisionDateByCipherId(cipher); - } - else if (cipher.UserId.HasValue) - { - await UserBumpAccountRevisionDate(cipher.UserId.Value); - } - } - return cipher; - } - - public IQueryable GetBumpedAccountsByCipherId(Core.Entities.Cipher cipher) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var query = new UserBumpAccountRevisionDateByCipherIdQuery(cipher); - return query.Run(dbContext); - } - } - - public async Task CreateAsync(Core.Entities.Cipher cipher, IEnumerable collectionIds) - { - cipher = await base.CreateAsync(cipher); - await UpdateCollections(cipher, collectionIds); - } - - private async Task UpdateCollections(Core.Entities.Cipher cipher, IEnumerable collectionIds) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var cipherEntity = await dbContext.Ciphers.FindAsync(cipher.Id); - var query = new CipherUpdateCollectionsQuery(cipherEntity, collectionIds).Run(dbContext); - await dbContext.AddRangeAsync(query); - await dbContext.SaveChangesAsync(); - } - } - - public async Task CreateAsync(CipherDetails cipher) - { - await CreateAsyncReturnCipher(cipher); - } - - private async Task CreateAsyncReturnCipher(CipherDetails cipher) - { - cipher.SetNewId(); - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var userIdKey = $"\"{cipher.UserId}\""; - cipher.UserId = cipher.OrganizationId.HasValue ? null : cipher.UserId; - cipher.Favorites = cipher.Favorite ? - $"{{{userIdKey}:true}}" : - null; - cipher.Folders = cipher.FolderId.HasValue ? - $"{{{userIdKey}:\"{cipher.FolderId}\"}}" : - null; - var entity = Mapper.Map((Core.Entities.Cipher)cipher); - await dbContext.AddAsync(entity); - await dbContext.SaveChangesAsync(); - } - await UserBumpAccountRevisionDateByCipherId(cipher); - return cipher; - } - - public async Task CreateAsync(CipherDetails cipher, IEnumerable collectionIds) - { - cipher = await CreateAsyncReturnCipher(cipher); - await UpdateCollections(cipher, collectionIds); - } - - public async Task CreateAsync(IEnumerable ciphers, IEnumerable folders) - { - if (!ciphers.Any()) - { - return; - } - - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var folderEntities = Mapper.Map>(folders); - await dbContext.BulkCopyAsync(base.DefaultBulkCopyOptions, folderEntities); - var cipherEntities = Mapper.Map>(ciphers); - await dbContext.BulkCopyAsync(base.DefaultBulkCopyOptions, cipherEntities); - await UserBumpAccountRevisionDateByCipherId(ciphers); - } - } - - public async Task CreateAsync(IEnumerable ciphers, IEnumerable collections, IEnumerable collectionCiphers) - { - if (!ciphers.Any()) - { - return; - } - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var cipherEntities = Mapper.Map>(ciphers); - await dbContext.BulkCopyAsync(base.DefaultBulkCopyOptions, cipherEntities); - if (collections.Any()) - { - var collectionEntities = Mapper.Map>(collections); - await dbContext.BulkCopyAsync(base.DefaultBulkCopyOptions, collectionEntities); - - if (collectionCiphers.Any()) + var dbContext = GetDatabaseContext(scope); + if (cipher.OrganizationId.HasValue) { - var collectionCipherEntities = Mapper.Map>(collectionCiphers); - await dbContext.BulkCopyAsync(base.DefaultBulkCopyOptions, collectionCipherEntities); + await UserBumpAccountRevisionDateByCipherId(cipher); + } + else if (cipher.UserId.HasValue) + { + await UserBumpAccountRevisionDate(cipher.UserId.Value); } } - await UserBumpAccountRevisionDateByOrganizationId(ciphers.First().OrganizationId.Value); + return cipher; } - } - public async Task DeleteAsync(IEnumerable ids, Guid userId) - { - await ToggleCipherStates(ids, userId, CipherStateAction.HardDelete); - } - - public async Task DeleteAttachmentAsync(Guid cipherId, string attachmentId) - { - using (var scope = ServiceScopeFactory.CreateScope()) + public IQueryable GetBumpedAccountsByCipherId(Core.Entities.Cipher cipher) { - var dbContext = GetDatabaseContext(scope); - var cipher = await dbContext.Ciphers.FindAsync(cipherId); - var attachmentsJson = JObject.Parse(cipher.Attachments); - attachmentsJson.Remove(attachmentId); - cipher.Attachments = JsonConvert.SerializeObject(attachmentsJson); - await dbContext.SaveChangesAsync(); - - if (cipher.OrganizationId.HasValue) + using (var scope = ServiceScopeFactory.CreateScope()) { - await OrganizationUpdateStorage(cipher.OrganizationId.Value); - await UserBumpAccountRevisionDateByCipherId(cipher); - } - else if (cipher.UserId.HasValue) - { - await UserUpdateStorage(cipher.UserId.Value); - await UserBumpAccountRevisionDate(cipher.UserId.Value); + var dbContext = GetDatabaseContext(scope); + var query = new UserBumpAccountRevisionDateByCipherIdQuery(cipher); + return query.Run(dbContext); } } - } - public async Task DeleteByIdsOrganizationIdAsync(IEnumerable ids, Guid organizationId) - { - using (var scope = ServiceScopeFactory.CreateScope()) + public async Task CreateAsync(Core.Entities.Cipher cipher, IEnumerable collectionIds) { - var dbContext = GetDatabaseContext(scope); - var ciphers = from c in dbContext.Ciphers - where c.OrganizationId == organizationId && - ids.Contains(c.Id) - select c; - dbContext.RemoveRange(ciphers); - await dbContext.SaveChangesAsync(); - } - await OrganizationUpdateStorage(organizationId); - await UserBumpAccountRevisionDateByOrganizationId(organizationId); - } - - public async Task DeleteByOrganizationIdAsync(Guid organizationId) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - - var collectionCiphers = from cc in dbContext.CollectionCiphers - join c in dbContext.Collections - on cc.CollectionId equals c.Id - where c.OrganizationId == organizationId - select cc; - dbContext.RemoveRange(collectionCiphers); - - var ciphers = from c in dbContext.Ciphers - where c.OrganizationId == organizationId - select c; - dbContext.RemoveRange(ciphers); - - await dbContext.SaveChangesAsync(); - } - await OrganizationUpdateStorage(organizationId); - await UserBumpAccountRevisionDateByOrganizationId(organizationId); - } - - public async Task DeleteByUserIdAsync(Guid userId) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var ciphers = from c in dbContext.Ciphers - where c.UserId == userId - select c; - dbContext.RemoveRange(ciphers); - var folders = from f in dbContext.Folders - where f.UserId == userId - select f; - dbContext.RemoveRange(folders); - await dbContext.SaveChangesAsync(); - await UserUpdateStorage(userId); - await UserBumpAccountRevisionDate(userId); + cipher = await base.CreateAsync(cipher); + await UpdateCollections(cipher, collectionIds); } - } - - public async Task DeleteDeletedAsync(DateTime deletedDateBefore) - { - using (var scope = ServiceScopeFactory.CreateScope()) + private async Task UpdateCollections(Core.Entities.Cipher cipher, IEnumerable collectionIds) { - var dbContext = GetDatabaseContext(scope); - var query = dbContext.Ciphers.Where(c => c.DeletedDate < deletedDateBefore); - dbContext.RemoveRange(query); - await dbContext.SaveChangesAsync(); - } - } - - public async Task GetByIdAsync(Guid id, Guid userId) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var userCipherDetails = new UserCipherDetailsQuery(userId); - var data = await userCipherDetails.Run(dbContext).FirstOrDefaultAsync(c => c.Id == id); - return data; - } - } - - public async Task> GetManyOrganizationDetailsByOrganizationIdAsync( - Guid organizationId) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var query = new CipherOrganizationDetailsReadByIdQuery(organizationId); - var data = await query.Run(dbContext).ToListAsync(); - return data; - } - } - - public async Task GetCanEditByIdAsync(Guid userId, Guid cipherId) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var query = new CipherReadCanEditByIdUserIdQuery(userId, cipherId); - var canEdit = await query.Run(dbContext).AnyAsync(); - return canEdit; - } - } - - public async Task> GetManyByOrganizationIdAsync(Guid organizationId) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var query = dbContext.Ciphers.Where(x => !x.UserId.HasValue && x.OrganizationId == organizationId); - var data = await query.ToListAsync(); - return Mapper.Map>(data); - } - } - - public async Task> GetManyByUserIdAsync(Guid userId, bool withOrganizations = true) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - IQueryable cipherDetailsView = withOrganizations ? - new UserCipherDetailsQuery(userId).Run(dbContext) : - new CipherDetailsQuery(userId).Run(dbContext); - if (!withOrganizations) + using (var scope = ServiceScopeFactory.CreateScope()) { - cipherDetailsView = from c in cipherDetailsView - where c.UserId == userId - select new CipherDetails - { - Id = c.Id, - UserId = c.UserId, - OrganizationId = c.OrganizationId, - Type = c.Type, - Data = c.Data, - Attachments = c.Attachments, - CreationDate = c.CreationDate, - RevisionDate = c.RevisionDate, - DeletedDate = c.DeletedDate, - Favorite = c.Favorite, - FolderId = c.FolderId, - Edit = true, - ViewPassword = true, - OrganizationUseTotp = false, - }; - } - var ciphers = await cipherDetailsView.ToListAsync(); - return ciphers; - } - } - - public async Task GetOrganizationDetailsByIdAsync(Guid id) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var query = new CipherOrganizationDetailsReadByIdQuery(id); - var data = await query.Run(dbContext).FirstOrDefaultAsync(); - return data; - } - } - - public async Task MoveAsync(IEnumerable ids, Guid? folderId, Guid userId) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var cipherEntities = dbContext.Ciphers.Where(c => ids.Contains(c.Id)); - var userCipherDetails = new UserCipherDetailsQuery(userId).Run(dbContext); - var idsToMove = from ucd in userCipherDetails - join c in cipherEntities - on ucd.Id equals c.Id - where ucd.Edit - select c; - await idsToMove.ForEachAsync(cipher => - { - var foldersJson = string.IsNullOrWhiteSpace(cipher.Folders) ? - new JObject() : - JObject.Parse(cipher.Folders); - - if (folderId.HasValue) - { - foldersJson.Remove(userId.ToString()); - foldersJson.Add(userId.ToString(), folderId.Value.ToString()); - } - else if (!string.IsNullOrWhiteSpace(cipher.Folders)) - { - foldersJson.Remove(userId.ToString()); - } - dbContext.Attach(cipher); - cipher.Folders = JsonConvert.SerializeObject(foldersJson); - }); - await dbContext.SaveChangesAsync(); - await UserBumpAccountRevisionDate(userId); - } - } - - public async Task ReplaceAsync(CipherDetails cipher) - { - cipher.UserId = cipher.OrganizationId.HasValue ? - null : - cipher.UserId; - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var entity = await dbContext.Ciphers.FindAsync(cipher.Id); - if (entity != null) - { - var userIdKey = $"\"{cipher.UserId}\""; - if (cipher.Favorite) - { - if (cipher.Favorites == null) - { - cipher.Favorites = $"{{{userIdKey}:true}}"; - } - else - { - var favorites = CoreHelpers.LoadClassFromJsonData>(cipher.Favorites); - favorites.Add(cipher.UserId.Value, true); - cipher.Favorites = JsonConvert.SerializeObject(favorites); - } - } - else - { - if (cipher.Favorites != null && cipher.Favorites.Contains(cipher.UserId.Value.ToString())) - { - var favorites = CoreHelpers.LoadClassFromJsonData>(cipher.Favorites); - favorites.Remove(cipher.UserId.Value); - cipher.Favorites = JsonConvert.SerializeObject(favorites); - } - } - if (cipher.FolderId.HasValue) - { - if (cipher.Folders == null) - { - cipher.Folders = $"{{{userIdKey}:\"{cipher.FolderId}\"}}"; - } - else - { - var folders = CoreHelpers.LoadClassFromJsonData>(cipher.Folders); - folders.Add(cipher.UserId.Value, cipher.FolderId.Value); - cipher.Folders = JsonConvert.SerializeObject(folders); - } - } - else - { - if (cipher.Folders != null && cipher.Folders.Contains(cipher.UserId.Value.ToString())) - { - var folders = CoreHelpers.LoadClassFromJsonData>(cipher.Favorites); - folders.Remove(cipher.UserId.Value); - cipher.Favorites = JsonConvert.SerializeObject(folders); - } - } - var mappedEntity = Mapper.Map((Core.Entities.Cipher)cipher); - dbContext.Entry(entity).CurrentValues.SetValues(mappedEntity); - await UserBumpAccountRevisionDateByCipherId(cipher); + var dbContext = GetDatabaseContext(scope); + var cipherEntity = await dbContext.Ciphers.FindAsync(cipher.Id); + var query = new CipherUpdateCollectionsQuery(cipherEntity, collectionIds).Run(dbContext); + await dbContext.AddRangeAsync(query); await dbContext.SaveChangesAsync(); } } - } - public async Task ReplaceAsync(Core.Entities.Cipher obj, IEnumerable collectionIds) - { - await UpdateCollections(obj, collectionIds); - using (var scope = ServiceScopeFactory.CreateScope()) + public async Task CreateAsync(CipherDetails cipher) { - var dbContext = GetDatabaseContext(scope); - var cipher = await dbContext.Ciphers.FindAsync(obj.Id); - cipher.UserId = null; - cipher.OrganizationId = obj.OrganizationId; - cipher.Data = obj.Data; - cipher.Attachments = obj.Attachments; - cipher.RevisionDate = obj.RevisionDate; - cipher.DeletedDate = obj.DeletedDate; - await dbContext.SaveChangesAsync(); + await CreateAsyncReturnCipher(cipher); + } - if (!string.IsNullOrWhiteSpace(cipher.Attachments)) + private async Task CreateAsyncReturnCipher(CipherDetails cipher) + { + cipher.SetNewId(); + using (var scope = ServiceScopeFactory.CreateScope()) { + var dbContext = GetDatabaseContext(scope); + var userIdKey = $"\"{cipher.UserId}\""; + cipher.UserId = cipher.OrganizationId.HasValue ? null : cipher.UserId; + cipher.Favorites = cipher.Favorite ? + $"{{{userIdKey}:true}}" : + null; + cipher.Folders = cipher.FolderId.HasValue ? + $"{{{userIdKey}:\"{cipher.FolderId}\"}}" : + null; + var entity = Mapper.Map((Core.Entities.Cipher)cipher); + await dbContext.AddAsync(entity); + await dbContext.SaveChangesAsync(); + } + await UserBumpAccountRevisionDateByCipherId(cipher); + return cipher; + } + + public async Task CreateAsync(CipherDetails cipher, IEnumerable collectionIds) + { + cipher = await CreateAsyncReturnCipher(cipher); + await UpdateCollections(cipher, collectionIds); + } + + public async Task CreateAsync(IEnumerable ciphers, IEnumerable folders) + { + if (!ciphers.Any()) + { + return; + } + + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var folderEntities = Mapper.Map>(folders); + await dbContext.BulkCopyAsync(base.DefaultBulkCopyOptions, folderEntities); + var cipherEntities = Mapper.Map>(ciphers); + await dbContext.BulkCopyAsync(base.DefaultBulkCopyOptions, cipherEntities); + await UserBumpAccountRevisionDateByCipherId(ciphers); + } + } + + public async Task CreateAsync(IEnumerable ciphers, IEnumerable collections, IEnumerable collectionCiphers) + { + if (!ciphers.Any()) + { + return; + } + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var cipherEntities = Mapper.Map>(ciphers); + await dbContext.BulkCopyAsync(base.DefaultBulkCopyOptions, cipherEntities); + if (collections.Any()) + { + var collectionEntities = Mapper.Map>(collections); + await dbContext.BulkCopyAsync(base.DefaultBulkCopyOptions, collectionEntities); + + if (collectionCiphers.Any()) + { + var collectionCipherEntities = Mapper.Map>(collectionCiphers); + await dbContext.BulkCopyAsync(base.DefaultBulkCopyOptions, collectionCipherEntities); + } + } + await UserBumpAccountRevisionDateByOrganizationId(ciphers.First().OrganizationId.Value); + } + } + + public async Task DeleteAsync(IEnumerable ids, Guid userId) + { + await ToggleCipherStates(ids, userId, CipherStateAction.HardDelete); + } + + public async Task DeleteAttachmentAsync(Guid cipherId, string attachmentId) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var cipher = await dbContext.Ciphers.FindAsync(cipherId); + var attachmentsJson = JObject.Parse(cipher.Attachments); + attachmentsJson.Remove(attachmentId); + cipher.Attachments = JsonConvert.SerializeObject(attachmentsJson); + await dbContext.SaveChangesAsync(); + if (cipher.OrganizationId.HasValue) { await OrganizationUpdateStorage(cipher.OrganizationId.Value); + await UserBumpAccountRevisionDateByCipherId(cipher); } else if (cipher.UserId.HasValue) { await UserUpdateStorage(cipher.UserId.Value); + await UserBumpAccountRevisionDate(cipher.UserId.Value); } } - - await UserBumpAccountRevisionDateByCipherId(cipher); - return true; } - } - public async Task RestoreAsync(IEnumerable ids, Guid userId) - { - return await ToggleCipherStates(ids, userId, CipherStateAction.Restore); - } - - public async Task SoftDeleteAsync(IEnumerable ids, Guid userId) - { - await ToggleCipherStates(ids, userId, CipherStateAction.SoftDelete); - } - - private async Task ToggleCipherStates(IEnumerable ids, Guid userId, CipherStateAction action) - { - using (var scope = ServiceScopeFactory.CreateScope()) + public async Task DeleteByIdsOrganizationIdAsync(IEnumerable ids, Guid organizationId) { - var dbContext = GetDatabaseContext(scope); - var userCipherDetailsQuery = new UserCipherDetailsQuery(userId); - var cipherEntitiesToCheck = await (dbContext.Ciphers.Where(c => ids.Contains(c.Id))).ToListAsync(); - var query = from ucd in await (userCipherDetailsQuery.Run(dbContext)).ToListAsync() - join c in cipherEntitiesToCheck - on ucd.Id equals c.Id - where ucd.Edit && ucd.DeletedDate == null - select c; - - var utcNow = DateTime.UtcNow; - var cipherIdsToModify = query.Select(c => c.Id); - var cipherEntitiesToModify = dbContext.Ciphers.Where(x => cipherIdsToModify.Contains(x.Id)); - if (action == CipherStateAction.HardDelete) + using (var scope = ServiceScopeFactory.CreateScope()) { - dbContext.RemoveRange(cipherEntitiesToModify); + var dbContext = GetDatabaseContext(scope); + var ciphers = from c in dbContext.Ciphers + where c.OrganizationId == organizationId && + ids.Contains(c.Id) + select c; + dbContext.RemoveRange(ciphers); + await dbContext.SaveChangesAsync(); } - else - { - await cipherEntitiesToModify.ForEachAsync(cipher => - { - dbContext.Attach(cipher); - cipher.DeletedDate = action == CipherStateAction.Restore ? null : utcNow; - cipher.RevisionDate = utcNow; - }); - } - - var orgIds = query - .Where(c => c.OrganizationId.HasValue) - .GroupBy(c => c.OrganizationId).Select(x => x.Key); - - foreach (var orgId in orgIds) - { - await OrganizationUpdateStorage(orgId.Value); - await UserBumpAccountRevisionDateByOrganizationId(orgId.Value); - } - if (query.Any(c => c.UserId.HasValue && !string.IsNullOrWhiteSpace(c.Attachments))) - { - await UserUpdateStorage(userId); - } - await UserBumpAccountRevisionDate(userId); - await dbContext.SaveChangesAsync(); - return utcNow; - } - } - - public async Task SoftDeleteByIdsOrganizationIdAsync(IEnumerable ids, Guid organizationId) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var utcNow = DateTime.UtcNow; - var ciphers = dbContext.Ciphers.Where(c => ids.Contains(c.Id) && c.OrganizationId == organizationId); - await ciphers.ForEachAsync(cipher => - { - dbContext.Attach(cipher); - cipher.DeletedDate = utcNow; - cipher.RevisionDate = utcNow; - }); - await dbContext.SaveChangesAsync(); await OrganizationUpdateStorage(organizationId); await UserBumpAccountRevisionDateByOrganizationId(organizationId); } - } - public async Task UpdateAttachmentAsync(CipherAttachment attachment) - { - using (var scope = ServiceScopeFactory.CreateScope()) + public async Task DeleteByOrganizationIdAsync(Guid organizationId) { - var dbContext = GetDatabaseContext(scope); - var cipher = await dbContext.Ciphers.FindAsync(attachment.Id); - var attachmentsJson = string.IsNullOrWhiteSpace(cipher.Attachments) ? new JObject() : JObject.Parse(cipher.Attachments); - attachmentsJson.Add(attachment.AttachmentId, attachment.AttachmentData); - cipher.Attachments = JsonConvert.SerializeObject(attachmentsJson); - await dbContext.SaveChangesAsync(); + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); - if (attachment.OrganizationId.HasValue) - { - await OrganizationUpdateStorage(cipher.OrganizationId.Value); - await UserBumpAccountRevisionDateByCipherId(new List { cipher }); + var collectionCiphers = from cc in dbContext.CollectionCiphers + join c in dbContext.Collections + on cc.CollectionId equals c.Id + where c.OrganizationId == organizationId + select cc; + dbContext.RemoveRange(collectionCiphers); + + var ciphers = from c in dbContext.Ciphers + where c.OrganizationId == organizationId + select c; + dbContext.RemoveRange(ciphers); + + await dbContext.SaveChangesAsync(); } - else if (attachment.UserId.HasValue) + await OrganizationUpdateStorage(organizationId); + await UserBumpAccountRevisionDateByOrganizationId(organizationId); + } + + public async Task DeleteByUserIdAsync(Guid userId) + { + using (var scope = ServiceScopeFactory.CreateScope()) { - await UserUpdateStorage(attachment.UserId.Value); - await UserBumpAccountRevisionDate(attachment.UserId.Value); + var dbContext = GetDatabaseContext(scope); + var ciphers = from c in dbContext.Ciphers + where c.UserId == userId + select c; + dbContext.RemoveRange(ciphers); + var folders = from f in dbContext.Folders + where f.UserId == userId + select f; + dbContext.RemoveRange(folders); + await dbContext.SaveChangesAsync(); + await UserUpdateStorage(userId); + await UserBumpAccountRevisionDate(userId); + } + + } + + public async Task DeleteDeletedAsync(DateTime deletedDateBefore) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var query = dbContext.Ciphers.Where(c => c.DeletedDate < deletedDateBefore); + dbContext.RemoveRange(query); + await dbContext.SaveChangesAsync(); } } - } - public async Task UpdateCiphersAsync(Guid userId, IEnumerable ciphers) - { - if (!ciphers.Any()) + public async Task GetByIdAsync(Guid id, Guid userId) { - return; - } - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var entities = Mapper.Map>(ciphers); - await dbContext.BulkCopyAsync(base.DefaultBulkCopyOptions, entities); - await UserBumpAccountRevisionDate(userId); - } - } - - public async Task UpdatePartialAsync(Guid id, Guid userId, Guid? folderId, bool favorite) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var cipher = await dbContext.Ciphers.FindAsync(id); - - var foldersJson = JObject.Parse(cipher.Folders); - if (foldersJson == null && folderId.HasValue) + using (var scope = ServiceScopeFactory.CreateScope()) { - foldersJson.Add(userId.ToString(), folderId.Value); + var dbContext = GetDatabaseContext(scope); + var userCipherDetails = new UserCipherDetailsQuery(userId); + var data = await userCipherDetails.Run(dbContext).FirstOrDefaultAsync(c => c.Id == id); + return data; } - else if (foldersJson != null && folderId.HasValue) + } + + public async Task> GetManyOrganizationDetailsByOrganizationIdAsync( + Guid organizationId) + { + using (var scope = ServiceScopeFactory.CreateScope()) { - foldersJson[userId] = folderId.Value; + var dbContext = GetDatabaseContext(scope); + var query = new CipherOrganizationDetailsReadByIdQuery(organizationId); + var data = await query.Run(dbContext).ToListAsync(); + return data; + } + } + + public async Task GetCanEditByIdAsync(Guid userId, Guid cipherId) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var query = new CipherReadCanEditByIdUserIdQuery(userId, cipherId); + var canEdit = await query.Run(dbContext).AnyAsync(); + return canEdit; + } + } + + public async Task> GetManyByOrganizationIdAsync(Guid organizationId) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var query = dbContext.Ciphers.Where(x => !x.UserId.HasValue && x.OrganizationId == organizationId); + var data = await query.ToListAsync(); + return Mapper.Map>(data); + } + } + + public async Task> GetManyByUserIdAsync(Guid userId, bool withOrganizations = true) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + IQueryable cipherDetailsView = withOrganizations ? + new UserCipherDetailsQuery(userId).Run(dbContext) : + new CipherDetailsQuery(userId).Run(dbContext); + if (!withOrganizations) + { + cipherDetailsView = from c in cipherDetailsView + where c.UserId == userId + select new CipherDetails + { + Id = c.Id, + UserId = c.UserId, + OrganizationId = c.OrganizationId, + Type = c.Type, + Data = c.Data, + Attachments = c.Attachments, + CreationDate = c.CreationDate, + RevisionDate = c.RevisionDate, + DeletedDate = c.DeletedDate, + Favorite = c.Favorite, + FolderId = c.FolderId, + Edit = true, + ViewPassword = true, + OrganizationUseTotp = false, + }; + } + var ciphers = await cipherDetailsView.ToListAsync(); + return ciphers; + } + } + + public async Task GetOrganizationDetailsByIdAsync(Guid id) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var query = new CipherOrganizationDetailsReadByIdQuery(id); + var data = await query.Run(dbContext).FirstOrDefaultAsync(); + return data; + } + } + + public async Task MoveAsync(IEnumerable ids, Guid? folderId, Guid userId) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var cipherEntities = dbContext.Ciphers.Where(c => ids.Contains(c.Id)); + var userCipherDetails = new UserCipherDetailsQuery(userId).Run(dbContext); + var idsToMove = from ucd in userCipherDetails + join c in cipherEntities + on ucd.Id equals c.Id + where ucd.Edit + select c; + await idsToMove.ForEachAsync(cipher => + { + var foldersJson = string.IsNullOrWhiteSpace(cipher.Folders) ? + new JObject() : + JObject.Parse(cipher.Folders); + + if (folderId.HasValue) + { + foldersJson.Remove(userId.ToString()); + foldersJson.Add(userId.ToString(), folderId.Value.ToString()); + } + else if (!string.IsNullOrWhiteSpace(cipher.Folders)) + { + foldersJson.Remove(userId.ToString()); + } + dbContext.Attach(cipher); + cipher.Folders = JsonConvert.SerializeObject(foldersJson); + }); + await dbContext.SaveChangesAsync(); + await UserBumpAccountRevisionDate(userId); + } + } + + public async Task ReplaceAsync(CipherDetails cipher) + { + cipher.UserId = cipher.OrganizationId.HasValue ? + null : + cipher.UserId; + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var entity = await dbContext.Ciphers.FindAsync(cipher.Id); + if (entity != null) + { + var userIdKey = $"\"{cipher.UserId}\""; + if (cipher.Favorite) + { + if (cipher.Favorites == null) + { + cipher.Favorites = $"{{{userIdKey}:true}}"; + } + else + { + var favorites = CoreHelpers.LoadClassFromJsonData>(cipher.Favorites); + favorites.Add(cipher.UserId.Value, true); + cipher.Favorites = JsonConvert.SerializeObject(favorites); + } + } + else + { + if (cipher.Favorites != null && cipher.Favorites.Contains(cipher.UserId.Value.ToString())) + { + var favorites = CoreHelpers.LoadClassFromJsonData>(cipher.Favorites); + favorites.Remove(cipher.UserId.Value); + cipher.Favorites = JsonConvert.SerializeObject(favorites); + } + } + if (cipher.FolderId.HasValue) + { + if (cipher.Folders == null) + { + cipher.Folders = $"{{{userIdKey}:\"{cipher.FolderId}\"}}"; + } + else + { + var folders = CoreHelpers.LoadClassFromJsonData>(cipher.Folders); + folders.Add(cipher.UserId.Value, cipher.FolderId.Value); + cipher.Folders = JsonConvert.SerializeObject(folders); + } + } + else + { + if (cipher.Folders != null && cipher.Folders.Contains(cipher.UserId.Value.ToString())) + { + var folders = CoreHelpers.LoadClassFromJsonData>(cipher.Favorites); + folders.Remove(cipher.UserId.Value); + cipher.Favorites = JsonConvert.SerializeObject(folders); + } + } + var mappedEntity = Mapper.Map((Core.Entities.Cipher)cipher); + dbContext.Entry(entity).CurrentValues.SetValues(mappedEntity); + await UserBumpAccountRevisionDateByCipherId(cipher); + await dbContext.SaveChangesAsync(); + } + } + } + + public async Task ReplaceAsync(Core.Entities.Cipher obj, IEnumerable collectionIds) + { + await UpdateCollections(obj, collectionIds); + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var cipher = await dbContext.Ciphers.FindAsync(obj.Id); + cipher.UserId = null; + cipher.OrganizationId = obj.OrganizationId; + cipher.Data = obj.Data; + cipher.Attachments = obj.Attachments; + cipher.RevisionDate = obj.RevisionDate; + cipher.DeletedDate = obj.DeletedDate; + await dbContext.SaveChangesAsync(); + + if (!string.IsNullOrWhiteSpace(cipher.Attachments)) + { + if (cipher.OrganizationId.HasValue) + { + await OrganizationUpdateStorage(cipher.OrganizationId.Value); + } + else if (cipher.UserId.HasValue) + { + await UserUpdateStorage(cipher.UserId.Value); + } + } + + await UserBumpAccountRevisionDateByCipherId(cipher); + return true; + } + } + + public async Task RestoreAsync(IEnumerable ids, Guid userId) + { + return await ToggleCipherStates(ids, userId, CipherStateAction.Restore); + } + + public async Task SoftDeleteAsync(IEnumerable ids, Guid userId) + { + await ToggleCipherStates(ids, userId, CipherStateAction.SoftDelete); + } + + private async Task ToggleCipherStates(IEnumerable ids, Guid userId, CipherStateAction action) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var userCipherDetailsQuery = new UserCipherDetailsQuery(userId); + var cipherEntitiesToCheck = await (dbContext.Ciphers.Where(c => ids.Contains(c.Id))).ToListAsync(); + var query = from ucd in await (userCipherDetailsQuery.Run(dbContext)).ToListAsync() + join c in cipherEntitiesToCheck + on ucd.Id equals c.Id + where ucd.Edit && ucd.DeletedDate == null + select c; + + var utcNow = DateTime.UtcNow; + var cipherIdsToModify = query.Select(c => c.Id); + var cipherEntitiesToModify = dbContext.Ciphers.Where(x => cipherIdsToModify.Contains(x.Id)); + if (action == CipherStateAction.HardDelete) + { + dbContext.RemoveRange(cipherEntitiesToModify); + } + else + { + await cipherEntitiesToModify.ForEachAsync(cipher => + { + dbContext.Attach(cipher); + cipher.DeletedDate = action == CipherStateAction.Restore ? null : utcNow; + cipher.RevisionDate = utcNow; + }); + } + + var orgIds = query + .Where(c => c.OrganizationId.HasValue) + .GroupBy(c => c.OrganizationId).Select(x => x.Key); + + foreach (var orgId in orgIds) + { + await OrganizationUpdateStorage(orgId.Value); + await UserBumpAccountRevisionDateByOrganizationId(orgId.Value); + } + if (query.Any(c => c.UserId.HasValue && !string.IsNullOrWhiteSpace(c.Attachments))) + { + await UserUpdateStorage(userId); + } + await UserBumpAccountRevisionDate(userId); + await dbContext.SaveChangesAsync(); + return utcNow; + } + } + + public async Task SoftDeleteByIdsOrganizationIdAsync(IEnumerable ids, Guid organizationId) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var utcNow = DateTime.UtcNow; + var ciphers = dbContext.Ciphers.Where(c => ids.Contains(c.Id) && c.OrganizationId == organizationId); + await ciphers.ForEachAsync(cipher => + { + dbContext.Attach(cipher); + cipher.DeletedDate = utcNow; + cipher.RevisionDate = utcNow; + }); + await dbContext.SaveChangesAsync(); + await OrganizationUpdateStorage(organizationId); + await UserBumpAccountRevisionDateByOrganizationId(organizationId); + } + } + + public async Task UpdateAttachmentAsync(CipherAttachment attachment) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var cipher = await dbContext.Ciphers.FindAsync(attachment.Id); + var attachmentsJson = string.IsNullOrWhiteSpace(cipher.Attachments) ? new JObject() : JObject.Parse(cipher.Attachments); + attachmentsJson.Add(attachment.AttachmentId, attachment.AttachmentData); + cipher.Attachments = JsonConvert.SerializeObject(attachmentsJson); + await dbContext.SaveChangesAsync(); + + if (attachment.OrganizationId.HasValue) + { + await OrganizationUpdateStorage(cipher.OrganizationId.Value); + await UserBumpAccountRevisionDateByCipherId(new List { cipher }); + } + else if (attachment.UserId.HasValue) + { + await UserUpdateStorage(attachment.UserId.Value); + await UserBumpAccountRevisionDate(attachment.UserId.Value); + } + } + } + + public async Task UpdateCiphersAsync(Guid userId, IEnumerable ciphers) + { + if (!ciphers.Any()) + { + return; + } + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var entities = Mapper.Map>(ciphers); + await dbContext.BulkCopyAsync(base.DefaultBulkCopyOptions, entities); + await UserBumpAccountRevisionDate(userId); + } + } + + public async Task UpdatePartialAsync(Guid id, Guid userId, Guid? folderId, bool favorite) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var cipher = await dbContext.Ciphers.FindAsync(id); + + var foldersJson = JObject.Parse(cipher.Folders); + if (foldersJson == null && folderId.HasValue) + { + foldersJson.Add(userId.ToString(), folderId.Value); + } + else if (foldersJson != null && folderId.HasValue) + { + foldersJson[userId] = folderId.Value; + } + else + { + foldersJson.Remove(userId.ToString()); + } + + var favoritesJson = JObject.Parse(cipher.Favorites); + if (favorite) + { + favoritesJson.Add(userId.ToString(), favorite); + } + else + { + favoritesJson.Remove(userId.ToString()); + } + + await dbContext.SaveChangesAsync(); + await UserBumpAccountRevisionDate(userId); + } + } + + public async Task UpdateUserKeysAndCiphersAsync(User user, IEnumerable ciphers, IEnumerable folders, IEnumerable sends) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + await UserUpdateKeys(user); + var cipherEntities = Mapper.Map>(ciphers); + await dbContext.BulkCopyAsync(base.DefaultBulkCopyOptions, cipherEntities); + var folderEntities = Mapper.Map>(folders); + await dbContext.BulkCopyAsync(base.DefaultBulkCopyOptions, folderEntities); + var sendEntities = Mapper.Map>(sends); + await dbContext.BulkCopyAsync(base.DefaultBulkCopyOptions, sendEntities); + await dbContext.SaveChangesAsync(); + } + } + + public async Task UpsertAsync(CipherDetails cipher) + { + if (cipher.Id.Equals(default)) + { + await CreateAsync(cipher); } else { - foldersJson.Remove(userId.ToString()); + await ReplaceAsync(cipher); } - - var favoritesJson = JObject.Parse(cipher.Favorites); - if (favorite) - { - favoritesJson.Add(userId.ToString(), favorite); - } - else - { - favoritesJson.Remove(userId.ToString()); - } - - await dbContext.SaveChangesAsync(); - await UserBumpAccountRevisionDate(userId); - } - } - - public async Task UpdateUserKeysAndCiphersAsync(User user, IEnumerable ciphers, IEnumerable folders, IEnumerable sends) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - await UserUpdateKeys(user); - var cipherEntities = Mapper.Map>(ciphers); - await dbContext.BulkCopyAsync(base.DefaultBulkCopyOptions, cipherEntities); - var folderEntities = Mapper.Map>(folders); - await dbContext.BulkCopyAsync(base.DefaultBulkCopyOptions, folderEntities); - var sendEntities = Mapper.Map>(sends); - await dbContext.BulkCopyAsync(base.DefaultBulkCopyOptions, sendEntities); - await dbContext.SaveChangesAsync(); - } - } - - public async Task UpsertAsync(CipherDetails cipher) - { - if (cipher.Id.Equals(default)) - { - await CreateAsync(cipher); - } - else - { - await ReplaceAsync(cipher); } } } diff --git a/src/Infrastructure.EntityFramework/Repositories/CollectionCipherRepository.cs b/src/Infrastructure.EntityFramework/Repositories/CollectionCipherRepository.cs index fd23237a9c..1d717ce2ef 100644 --- a/src/Infrastructure.EntityFramework/Repositories/CollectionCipherRepository.cs +++ b/src/Infrastructure.EntityFramework/Repositories/CollectionCipherRepository.cs @@ -6,232 +6,233 @@ using Microsoft.EntityFrameworkCore; using Microsoft.Extensions.DependencyInjection; using CollectionCipher = Bit.Core.Entities.CollectionCipher; -namespace Bit.Infrastructure.EntityFramework.Repositories; - -public class CollectionCipherRepository : BaseEntityFrameworkRepository, ICollectionCipherRepository +namespace Bit.Infrastructure.EntityFramework.Repositories { - public CollectionCipherRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) - : base(serviceScopeFactory, mapper) - { } - - public async Task CreateAsync(CollectionCipher obj) + public class CollectionCipherRepository : BaseEntityFrameworkRepository, ICollectionCipherRepository { - using (var scope = ServiceScopeFactory.CreateScope()) + public CollectionCipherRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) + : base(serviceScopeFactory, mapper) + { } + + public async Task CreateAsync(CollectionCipher obj) { - var dbContext = GetDatabaseContext(scope); - var entity = Mapper.Map(obj); - dbContext.Add(entity); - await dbContext.SaveChangesAsync(); - var organizationId = (await dbContext.Ciphers.FirstOrDefaultAsync(c => c.Id.Equals(obj.CipherId))).OrganizationId; - if (organizationId.HasValue) + using (var scope = ServiceScopeFactory.CreateScope()) { - await UserBumpAccountRevisionDateByCollectionId(obj.CollectionId, organizationId.Value); - } - return obj; - } - } - - public async Task> GetManyByOrganizationIdAsync(Guid organizationId) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var data = await (from cc in dbContext.CollectionCiphers - join c in dbContext.Collections - on cc.CollectionId equals c.Id - where c.OrganizationId == organizationId - select cc).ToArrayAsync(); - return data; - } - } - - public async Task> GetManyByUserIdAsync(Guid userId) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var data = await new CollectionCipherReadByUserIdQuery(userId) - .Run(dbContext) - .ToArrayAsync(); - return data; - } - } - - public async Task> GetManyByUserIdCipherIdAsync(Guid userId, Guid cipherId) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var data = await new CollectionCipherReadByUserIdCipherIdQuery(userId, cipherId) - .Run(dbContext) - .ToArrayAsync(); - return data; - } - } - - public async Task UpdateCollectionsAsync(Guid cipherId, Guid userId, IEnumerable collectionIds) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var organizationId = (await dbContext.Ciphers.FindAsync(cipherId)).OrganizationId; - var availableCollectionsCte = from c in dbContext.Collections - join o in dbContext.Organizations - on c.OrganizationId equals o.Id - join ou in dbContext.OrganizationUsers - on o.Id equals ou.OrganizationId - where ou.UserId == userId - join cu in dbContext.CollectionUsers - on ou.Id equals cu.OrganizationUserId into cu_g - from cu in cu_g.DefaultIfEmpty() - where !ou.AccessAll && cu.CollectionId == c.Id - join gu in dbContext.GroupUsers - on ou.Id equals gu.OrganizationUserId into gu_g - from gu in gu_g.DefaultIfEmpty() - where cu.CollectionId == null && !ou.AccessAll - join g in dbContext.Groups - on gu.GroupId equals g.Id into g_g - from g in g_g.DefaultIfEmpty() - join cg in dbContext.CollectionGroups - on gu.GroupId equals cg.GroupId into cg_g - from cg in cg_g.DefaultIfEmpty() - where !g.AccessAll && cg.CollectionId == c.Id && - (o.Id == organizationId && o.Enabled && ou.Status == OrganizationUserStatusType.Confirmed && ( - ou.AccessAll || !cu.ReadOnly || g.AccessAll || !cg.ReadOnly)) - select new { c, o, cu, gu, g, cg }; - var target = from cc in dbContext.CollectionCiphers - where cc.CipherId == cipherId - select new { cc.CollectionId, cc.CipherId }; - var source = collectionIds.Select(x => new { CollectionId = x, CipherId = cipherId }); - var merge1 = from t in target - join s in source - on t.CollectionId equals s.CollectionId into s_g - from s in s_g.DefaultIfEmpty() - where t.CipherId == s.CipherId - select new { t, s }; - var merge2 = from s in source - join t in target - on s.CollectionId equals t.CollectionId into t_g - from t in t_g.DefaultIfEmpty() - where t.CipherId == s.CipherId - select new { t, s }; - var union = merge1.Union(merge2).Distinct(); - var insert = union - .Where(x => x.t == null && collectionIds.Contains(x.s.CollectionId)) - .Select(x => new Models.CollectionCipher + var dbContext = GetDatabaseContext(scope); + var entity = Mapper.Map(obj); + dbContext.Add(entity); + await dbContext.SaveChangesAsync(); + var organizationId = (await dbContext.Ciphers.FirstOrDefaultAsync(c => c.Id.Equals(obj.CipherId))).OrganizationId; + if (organizationId.HasValue) { - CollectionId = x.s.CollectionId, - CipherId = x.s.CipherId, - }); - var delete = union - .Where(x => x.s == null && x.t.CipherId == cipherId && collectionIds.Contains(x.t.CollectionId)) - .Select(x => new Models.CollectionCipher - { - CollectionId = x.t.CollectionId, - CipherId = x.t.CipherId, - }); - await dbContext.AddRangeAsync(insert); - dbContext.RemoveRange(delete); - await dbContext.SaveChangesAsync(); - - if (organizationId.HasValue) - { - await UserBumpAccountRevisionDateByOrganizationId(organizationId.Value); + await UserBumpAccountRevisionDateByCollectionId(obj.CollectionId, organizationId.Value); + } + return obj; } } - } - public async Task UpdateCollectionsForAdminAsync(Guid cipherId, Guid organizationId, IEnumerable collectionIds) - { - using (var scope = ServiceScopeFactory.CreateScope()) + public async Task> GetManyByOrganizationIdAsync(Guid organizationId) { - var dbContext = GetDatabaseContext(scope); - var availableCollectionsCte = from c in dbContext.Collections - where c.OrganizationId == organizationId - select c; - var target = from cc in dbContext.CollectionCiphers - where cc.CipherId == cipherId - select new { cc.CollectionId, cc.CipherId }; - var source = collectionIds.Select(x => new { CollectionId = x, CipherId = cipherId }); - var merge1 = from t in target - join s in source - on t.CollectionId equals s.CollectionId into s_g - from s in s_g.DefaultIfEmpty() - where t.CipherId == s.CipherId - select new { t, s }; - var merge2 = from s in source - join t in target - on s.CollectionId equals t.CollectionId into t_g - from t in t_g.DefaultIfEmpty() - where t.CipherId == s.CipherId - select new { t, s }; - var union = merge1.Union(merge2).Distinct(); - var insert = union - .Where(x => x.t == null && collectionIds.Contains(x.s.CollectionId)) - .Select(x => new Models.CollectionCipher - { - CollectionId = x.s.CollectionId, - CipherId = x.s.CipherId, - }); - var delete = union - .Where(x => x.s == null && x.t.CipherId == cipherId) - .Select(x => new Models.CollectionCipher - { - CollectionId = x.t.CollectionId, - CipherId = x.t.CipherId, - }); - await dbContext.AddRangeAsync(insert); - dbContext.RemoveRange(delete); - await dbContext.SaveChangesAsync(); - await UserBumpAccountRevisionDateByOrganizationId(organizationId); - } - } - - public async Task UpdateCollectionsForCiphersAsync(IEnumerable cipherIds, Guid userId, Guid organizationId, IEnumerable collectionIds) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var availibleCollections = from c in dbContext.Collections - join o in dbContext.Organizations - on c.OrganizationId equals o.Id - join ou in dbContext.OrganizationUsers - on o.Id equals ou.OrganizationId - where ou.UserId == userId - join cu in dbContext.CollectionUsers - on ou.Id equals cu.OrganizationUserId into cu_g - from cu in cu_g.DefaultIfEmpty() - where !ou.AccessAll && cu.CollectionId == c.Id - join gu in dbContext.GroupUsers - on ou.Id equals gu.OrganizationUserId into gu_g - from gu in gu_g.DefaultIfEmpty() - where cu.CollectionId == null && !ou.AccessAll - join g in dbContext.Groups - on gu.GroupId equals g.Id into g_g - from g in g_g.DefaultIfEmpty() - join cg in dbContext.CollectionGroups - on gu.GroupId equals cg.GroupId into cg_g - from cg in cg_g.DefaultIfEmpty() - where !g.AccessAll && cg.CollectionId == c.Id && - (o.Id == organizationId && o.Enabled && ou.Status == OrganizationUserStatusType.Confirmed && - (ou.AccessAll || !cu.ReadOnly || g.AccessAll || !cg.ReadOnly)) - select new { c, o, ou, cu, gu, g, cg }; - var count = await availibleCollections.CountAsync(); - if (await availibleCollections.CountAsync() < 1) + using (var scope = ServiceScopeFactory.CreateScope()) { - return; + var dbContext = GetDatabaseContext(scope); + var data = await (from cc in dbContext.CollectionCiphers + join c in dbContext.Collections + on cc.CollectionId equals c.Id + where c.OrganizationId == organizationId + select cc).ToArrayAsync(); + return data; } + } - var insertData = from collectionId in collectionIds - from cipherId in cipherIds - where availibleCollections.Select(x => x.c.Id).Contains(collectionId) - select new Models.CollectionCipher - { - CollectionId = collectionId, - CipherId = cipherId, - }; - await dbContext.AddRangeAsync(insertData); - await UserBumpAccountRevisionDateByOrganizationId(organizationId); + public async Task> GetManyByUserIdAsync(Guid userId) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var data = await new CollectionCipherReadByUserIdQuery(userId) + .Run(dbContext) + .ToArrayAsync(); + return data; + } + } + + public async Task> GetManyByUserIdCipherIdAsync(Guid userId, Guid cipherId) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var data = await new CollectionCipherReadByUserIdCipherIdQuery(userId, cipherId) + .Run(dbContext) + .ToArrayAsync(); + return data; + } + } + + public async Task UpdateCollectionsAsync(Guid cipherId, Guid userId, IEnumerable collectionIds) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var organizationId = (await dbContext.Ciphers.FindAsync(cipherId)).OrganizationId; + var availableCollectionsCte = from c in dbContext.Collections + join o in dbContext.Organizations + on c.OrganizationId equals o.Id + join ou in dbContext.OrganizationUsers + on o.Id equals ou.OrganizationId + where ou.UserId == userId + join cu in dbContext.CollectionUsers + on ou.Id equals cu.OrganizationUserId into cu_g + from cu in cu_g.DefaultIfEmpty() + where !ou.AccessAll && cu.CollectionId == c.Id + join gu in dbContext.GroupUsers + on ou.Id equals gu.OrganizationUserId into gu_g + from gu in gu_g.DefaultIfEmpty() + where cu.CollectionId == null && !ou.AccessAll + join g in dbContext.Groups + on gu.GroupId equals g.Id into g_g + from g in g_g.DefaultIfEmpty() + join cg in dbContext.CollectionGroups + on gu.GroupId equals cg.GroupId into cg_g + from cg in cg_g.DefaultIfEmpty() + where !g.AccessAll && cg.CollectionId == c.Id && + (o.Id == organizationId && o.Enabled && ou.Status == OrganizationUserStatusType.Confirmed && ( + ou.AccessAll || !cu.ReadOnly || g.AccessAll || !cg.ReadOnly)) + select new { c, o, cu, gu, g, cg }; + var target = from cc in dbContext.CollectionCiphers + where cc.CipherId == cipherId + select new { cc.CollectionId, cc.CipherId }; + var source = collectionIds.Select(x => new { CollectionId = x, CipherId = cipherId }); + var merge1 = from t in target + join s in source + on t.CollectionId equals s.CollectionId into s_g + from s in s_g.DefaultIfEmpty() + where t.CipherId == s.CipherId + select new { t, s }; + var merge2 = from s in source + join t in target + on s.CollectionId equals t.CollectionId into t_g + from t in t_g.DefaultIfEmpty() + where t.CipherId == s.CipherId + select new { t, s }; + var union = merge1.Union(merge2).Distinct(); + var insert = union + .Where(x => x.t == null && collectionIds.Contains(x.s.CollectionId)) + .Select(x => new Models.CollectionCipher + { + CollectionId = x.s.CollectionId, + CipherId = x.s.CipherId, + }); + var delete = union + .Where(x => x.s == null && x.t.CipherId == cipherId && collectionIds.Contains(x.t.CollectionId)) + .Select(x => new Models.CollectionCipher + { + CollectionId = x.t.CollectionId, + CipherId = x.t.CipherId, + }); + await dbContext.AddRangeAsync(insert); + dbContext.RemoveRange(delete); + await dbContext.SaveChangesAsync(); + + if (organizationId.HasValue) + { + await UserBumpAccountRevisionDateByOrganizationId(organizationId.Value); + } + } + } + + public async Task UpdateCollectionsForAdminAsync(Guid cipherId, Guid organizationId, IEnumerable collectionIds) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var availableCollectionsCte = from c in dbContext.Collections + where c.OrganizationId == organizationId + select c; + var target = from cc in dbContext.CollectionCiphers + where cc.CipherId == cipherId + select new { cc.CollectionId, cc.CipherId }; + var source = collectionIds.Select(x => new { CollectionId = x, CipherId = cipherId }); + var merge1 = from t in target + join s in source + on t.CollectionId equals s.CollectionId into s_g + from s in s_g.DefaultIfEmpty() + where t.CipherId == s.CipherId + select new { t, s }; + var merge2 = from s in source + join t in target + on s.CollectionId equals t.CollectionId into t_g + from t in t_g.DefaultIfEmpty() + where t.CipherId == s.CipherId + select new { t, s }; + var union = merge1.Union(merge2).Distinct(); + var insert = union + .Where(x => x.t == null && collectionIds.Contains(x.s.CollectionId)) + .Select(x => new Models.CollectionCipher + { + CollectionId = x.s.CollectionId, + CipherId = x.s.CipherId, + }); + var delete = union + .Where(x => x.s == null && x.t.CipherId == cipherId) + .Select(x => new Models.CollectionCipher + { + CollectionId = x.t.CollectionId, + CipherId = x.t.CipherId, + }); + await dbContext.AddRangeAsync(insert); + dbContext.RemoveRange(delete); + await dbContext.SaveChangesAsync(); + await UserBumpAccountRevisionDateByOrganizationId(organizationId); + } + } + + public async Task UpdateCollectionsForCiphersAsync(IEnumerable cipherIds, Guid userId, Guid organizationId, IEnumerable collectionIds) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var availibleCollections = from c in dbContext.Collections + join o in dbContext.Organizations + on c.OrganizationId equals o.Id + join ou in dbContext.OrganizationUsers + on o.Id equals ou.OrganizationId + where ou.UserId == userId + join cu in dbContext.CollectionUsers + on ou.Id equals cu.OrganizationUserId into cu_g + from cu in cu_g.DefaultIfEmpty() + where !ou.AccessAll && cu.CollectionId == c.Id + join gu in dbContext.GroupUsers + on ou.Id equals gu.OrganizationUserId into gu_g + from gu in gu_g.DefaultIfEmpty() + where cu.CollectionId == null && !ou.AccessAll + join g in dbContext.Groups + on gu.GroupId equals g.Id into g_g + from g in g_g.DefaultIfEmpty() + join cg in dbContext.CollectionGroups + on gu.GroupId equals cg.GroupId into cg_g + from cg in cg_g.DefaultIfEmpty() + where !g.AccessAll && cg.CollectionId == c.Id && + (o.Id == organizationId && o.Enabled && ou.Status == OrganizationUserStatusType.Confirmed && + (ou.AccessAll || !cu.ReadOnly || g.AccessAll || !cg.ReadOnly)) + select new { c, o, ou, cu, gu, g, cg }; + var count = await availibleCollections.CountAsync(); + if (await availibleCollections.CountAsync() < 1) + { + return; + } + + var insertData = from collectionId in collectionIds + from cipherId in cipherIds + where availibleCollections.Select(x => x.c.Id).Contains(collectionId) + select new Models.CollectionCipher + { + CollectionId = collectionId, + CipherId = cipherId, + }; + await dbContext.AddRangeAsync(insertData); + await UserBumpAccountRevisionDateByOrganizationId(organizationId); + } } } } diff --git a/src/Infrastructure.EntityFramework/Repositories/CollectionRepository.cs b/src/Infrastructure.EntityFramework/Repositories/CollectionRepository.cs index d8338b470b..74d714bb11 100644 --- a/src/Infrastructure.EntityFramework/Repositories/CollectionRepository.cs +++ b/src/Infrastructure.EntityFramework/Repositories/CollectionRepository.cs @@ -6,244 +6,245 @@ using Bit.Infrastructure.EntityFramework.Repositories.Queries; using Microsoft.EntityFrameworkCore; using Microsoft.Extensions.DependencyInjection; -namespace Bit.Infrastructure.EntityFramework.Repositories; - -public class CollectionRepository : Repository, ICollectionRepository +namespace Bit.Infrastructure.EntityFramework.Repositories { - public CollectionRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) - : base(serviceScopeFactory, mapper, (DatabaseContext context) => context.Collections) - { } - - public override async Task CreateAsync(Core.Entities.Collection obj) + public class CollectionRepository : Repository, ICollectionRepository { - await base.CreateAsync(obj); - await UserBumpAccountRevisionDateByCollectionId(obj.Id, obj.OrganizationId); - return obj; - } + public CollectionRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) + : base(serviceScopeFactory, mapper, (DatabaseContext context) => context.Collections) + { } - public async Task CreateAsync(Core.Entities.Collection obj, IEnumerable groups) - { - await base.CreateAsync(obj); - using (var scope = ServiceScopeFactory.CreateScope()) + public override async Task CreateAsync(Core.Entities.Collection obj) { - var dbContext = GetDatabaseContext(scope); - var availibleGroups = await (from g in dbContext.Groups - where g.OrganizationId == obj.OrganizationId - select g.Id).ToListAsync(); - var collectionGroups = groups - .Where(g => availibleGroups.Contains(g.Id)) - .Select(g => new CollectionGroup - { - CollectionId = obj.Id, - GroupId = g.Id, - ReadOnly = g.ReadOnly, - HidePasswords = g.HidePasswords, - }); - await dbContext.AddRangeAsync(collectionGroups); - await dbContext.SaveChangesAsync(); - await UserBumpAccountRevisionDateByOrganizationId(obj.OrganizationId); + await base.CreateAsync(obj); + await UserBumpAccountRevisionDateByCollectionId(obj.Id, obj.OrganizationId); + return obj; } - } - public async Task DeleteUserAsync(Guid collectionId, Guid organizationUserId) - { - using (var scope = ServiceScopeFactory.CreateScope()) + public async Task CreateAsync(Core.Entities.Collection obj, IEnumerable groups) { - var dbContext = GetDatabaseContext(scope); - var query = from cu in dbContext.CollectionUsers - where cu.CollectionId == collectionId && - cu.OrganizationUserId == organizationUserId - select cu; - dbContext.RemoveRange(await query.ToListAsync()); - await dbContext.SaveChangesAsync(); - await UserBumpAccountRevisionDateByOrganizationUserId(organizationUserId); - } - } - - public async Task GetByIdAsync(Guid id, Guid userId) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - return (await GetManyByUserIdAsync(userId)).FirstOrDefault(c => c.Id == id); - } - } - - public async Task>> GetByIdWithGroupsAsync(Guid id) - { - var collection = await base.GetByIdAsync(id); - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var collectionGroups = await (from cg in dbContext.CollectionGroups - where cg.CollectionId == id - select cg).ToListAsync(); - var selectionReadOnlys = collectionGroups.Select(cg => new SelectionReadOnly + await base.CreateAsync(obj); + using (var scope = ServiceScopeFactory.CreateScope()) { - Id = cg.GroupId, - ReadOnly = cg.ReadOnly, - HidePasswords = cg.HidePasswords, - }).ToList(); - return new Tuple>(collection, selectionReadOnlys); + var dbContext = GetDatabaseContext(scope); + var availibleGroups = await (from g in dbContext.Groups + where g.OrganizationId == obj.OrganizationId + select g.Id).ToListAsync(); + var collectionGroups = groups + .Where(g => availibleGroups.Contains(g.Id)) + .Select(g => new CollectionGroup + { + CollectionId = obj.Id, + GroupId = g.Id, + ReadOnly = g.ReadOnly, + HidePasswords = g.HidePasswords, + }); + await dbContext.AddRangeAsync(collectionGroups); + await dbContext.SaveChangesAsync(); + await UserBumpAccountRevisionDateByOrganizationId(obj.OrganizationId); + } } - } - public async Task>> GetByIdWithGroupsAsync(Guid id, Guid userId) - { - var collection = await GetByIdAsync(id, userId); - using (var scope = ServiceScopeFactory.CreateScope()) + public async Task DeleteUserAsync(Guid collectionId, Guid organizationUserId) { - var dbContext = GetDatabaseContext(scope); - var query = from cg in dbContext.CollectionGroups - where cg.CollectionId.Equals(id) - select new SelectionReadOnly - { - Id = cg.GroupId, - ReadOnly = cg.ReadOnly, - HidePasswords = cg.HidePasswords, - }; - var configurations = await query.ToArrayAsync(); - return new Tuple>(collection, configurations); - } - } - - public async Task GetCountByOrganizationIdAsync(Guid organizationId) - { - var query = new CollectionReadCountByOrganizationIdQuery(organizationId); - return await GetCountFromQuery(query); - } - - public async Task> GetManyByOrganizationIdAsync(Guid organizationId) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var query = from c in dbContext.Collections - where c.OrganizationId == organizationId - select c; - var collections = await query.ToArrayAsync(); - return collections; - } - } - - public async Task> GetManyByUserIdAsync(Guid userId) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - return (await new UserCollectionDetailsQuery(userId).Run(dbContext).ToListAsync()) - .GroupBy(c => c.Id) - .Select(g => new CollectionDetails - { - Id = g.Key, - OrganizationId = g.FirstOrDefault().OrganizationId, - Name = g.FirstOrDefault().Name, - ExternalId = g.FirstOrDefault().ExternalId, - CreationDate = g.FirstOrDefault().CreationDate, - RevisionDate = g.FirstOrDefault().RevisionDate, - ReadOnly = g.Min(c => c.ReadOnly), - HidePasswords = g.Min(c => c.HidePasswords) - }).ToList(); - } - } - - public async Task> GetManyUsersByIdAsync(Guid id) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var query = from cu in dbContext.CollectionUsers - where cu.CollectionId == id - select cu; - var collectionUsers = await query.ToListAsync(); - return collectionUsers.Select(cu => new SelectionReadOnly + using (var scope = ServiceScopeFactory.CreateScope()) { - Id = cu.OrganizationUserId, - ReadOnly = cu.ReadOnly, - HidePasswords = cu.HidePasswords, - }).ToArray(); + var dbContext = GetDatabaseContext(scope); + var query = from cu in dbContext.CollectionUsers + where cu.CollectionId == collectionId && + cu.OrganizationUserId == organizationUserId + select cu; + dbContext.RemoveRange(await query.ToListAsync()); + await dbContext.SaveChangesAsync(); + await UserBumpAccountRevisionDateByOrganizationUserId(organizationUserId); + } } - } - public async Task ReplaceAsync(Core.Entities.Collection collection, IEnumerable groups) - { - await base.ReplaceAsync(collection); - using (var scope = ServiceScopeFactory.CreateScope()) + public async Task GetByIdAsync(Guid id, Guid userId) { - var dbContext = GetDatabaseContext(scope); - var groupsInOrg = dbContext.Groups.Where(g => g.OrganizationId == collection.OrganizationId); - var modifiedGroupEntities = dbContext.Groups.Where(x => groups.Select(x => x.Id).Contains(x.Id)); - var target = (from cg in dbContext.CollectionGroups - join g in modifiedGroupEntities - on cg.CollectionId equals collection.Id into s_g - from g in s_g.DefaultIfEmpty() - where g == null || cg.GroupId == g.Id - select new { cg, g }).AsNoTracking(); - var source = (from g in modifiedGroupEntities - from cg in dbContext.CollectionGroups - .Where(cg => cg.CollectionId == collection.Id && cg.GroupId == g.Id).DefaultIfEmpty() - select new { cg, g }).AsNoTracking(); - var union = await target - .Union(source) - .Where(x => - x.cg == null || - ((x.g == null || x.g.Id == x.cg.GroupId) && - (x.cg.CollectionId == collection.Id))) - .AsNoTracking() - .ToListAsync(); - var insert = union.Where(x => x.cg == null && groupsInOrg.Any(c => x.g.Id == c.Id)) - .Select(x => new CollectionGroup + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + return (await GetManyByUserIdAsync(userId)).FirstOrDefault(c => c.Id == id); + } + } + + public async Task>> GetByIdWithGroupsAsync(Guid id) + { + var collection = await base.GetByIdAsync(id); + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var collectionGroups = await (from cg in dbContext.CollectionGroups + where cg.CollectionId == id + select cg).ToListAsync(); + var selectionReadOnlys = collectionGroups.Select(cg => new SelectionReadOnly { - CollectionId = collection.Id, - GroupId = x.g.Id, - ReadOnly = groups.FirstOrDefault(g => g.Id == x.g.Id).ReadOnly, - HidePasswords = groups.FirstOrDefault(g => g.Id == x.g.Id).HidePasswords, + Id = cg.GroupId, + ReadOnly = cg.ReadOnly, + HidePasswords = cg.HidePasswords, }).ToList(); - var update = union - .Where( - x => x.g != null && - x.cg != null && - (x.cg.ReadOnly != groups.FirstOrDefault(g => g.Id == x.g.Id).ReadOnly || - x.cg.HidePasswords != groups.FirstOrDefault(g => g.Id == x.g.Id).HidePasswords) - ) - .Select(x => new CollectionGroup - { - CollectionId = collection.Id, - GroupId = x.g.Id, - ReadOnly = groups.FirstOrDefault(g => g.Id == x.g.Id).ReadOnly, - HidePasswords = groups.FirstOrDefault(g => g.Id == x.g.Id).HidePasswords, - }); - var delete = union - .Where( - x => x.g == null && - x.cg.CollectionId == collection.Id - ) - .Select(x => new CollectionGroup - { - CollectionId = collection.Id, - GroupId = x.cg.GroupId, - }) - .ToList(); - - await dbContext.AddRangeAsync(insert); - dbContext.UpdateRange(update); - dbContext.RemoveRange(delete); - await dbContext.SaveChangesAsync(); - await UserBumpAccountRevisionDateByCollectionId(collection.Id, collection.OrganizationId); + return new Tuple>(collection, selectionReadOnlys); + } } - } - public async Task UpdateUsersAsync(Guid id, IEnumerable users) - { - using (var scope = ServiceScopeFactory.CreateScope()) + public async Task>> GetByIdWithGroupsAsync(Guid id, Guid userId) { - var dbContext = GetDatabaseContext(scope); - var procedure = new CollectionUserUpdateUsersQuery(id, users); - var updateData = await procedure.Update.BuildInMemory(dbContext); - dbContext.UpdateRange(updateData); - var insertData = await procedure.Insert.BuildInMemory(dbContext); - await dbContext.AddRangeAsync(insertData); - dbContext.RemoveRange(await procedure.Delete.Run(dbContext).ToListAsync()); + var collection = await GetByIdAsync(id, userId); + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var query = from cg in dbContext.CollectionGroups + where cg.CollectionId.Equals(id) + select new SelectionReadOnly + { + Id = cg.GroupId, + ReadOnly = cg.ReadOnly, + HidePasswords = cg.HidePasswords, + }; + var configurations = await query.ToArrayAsync(); + return new Tuple>(collection, configurations); + } + } + + public async Task GetCountByOrganizationIdAsync(Guid organizationId) + { + var query = new CollectionReadCountByOrganizationIdQuery(organizationId); + return await GetCountFromQuery(query); + } + + public async Task> GetManyByOrganizationIdAsync(Guid organizationId) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var query = from c in dbContext.Collections + where c.OrganizationId == organizationId + select c; + var collections = await query.ToArrayAsync(); + return collections; + } + } + + public async Task> GetManyByUserIdAsync(Guid userId) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + return (await new UserCollectionDetailsQuery(userId).Run(dbContext).ToListAsync()) + .GroupBy(c => c.Id) + .Select(g => new CollectionDetails + { + Id = g.Key, + OrganizationId = g.FirstOrDefault().OrganizationId, + Name = g.FirstOrDefault().Name, + ExternalId = g.FirstOrDefault().ExternalId, + CreationDate = g.FirstOrDefault().CreationDate, + RevisionDate = g.FirstOrDefault().RevisionDate, + ReadOnly = g.Min(c => c.ReadOnly), + HidePasswords = g.Min(c => c.HidePasswords) + }).ToList(); + } + } + + public async Task> GetManyUsersByIdAsync(Guid id) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var query = from cu in dbContext.CollectionUsers + where cu.CollectionId == id + select cu; + var collectionUsers = await query.ToListAsync(); + return collectionUsers.Select(cu => new SelectionReadOnly + { + Id = cu.OrganizationUserId, + ReadOnly = cu.ReadOnly, + HidePasswords = cu.HidePasswords, + }).ToArray(); + } + } + + public async Task ReplaceAsync(Core.Entities.Collection collection, IEnumerable groups) + { + await base.ReplaceAsync(collection); + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var groupsInOrg = dbContext.Groups.Where(g => g.OrganizationId == collection.OrganizationId); + var modifiedGroupEntities = dbContext.Groups.Where(x => groups.Select(x => x.Id).Contains(x.Id)); + var target = (from cg in dbContext.CollectionGroups + join g in modifiedGroupEntities + on cg.CollectionId equals collection.Id into s_g + from g in s_g.DefaultIfEmpty() + where g == null || cg.GroupId == g.Id + select new { cg, g }).AsNoTracking(); + var source = (from g in modifiedGroupEntities + from cg in dbContext.CollectionGroups + .Where(cg => cg.CollectionId == collection.Id && cg.GroupId == g.Id).DefaultIfEmpty() + select new { cg, g }).AsNoTracking(); + var union = await target + .Union(source) + .Where(x => + x.cg == null || + ((x.g == null || x.g.Id == x.cg.GroupId) && + (x.cg.CollectionId == collection.Id))) + .AsNoTracking() + .ToListAsync(); + var insert = union.Where(x => x.cg == null && groupsInOrg.Any(c => x.g.Id == c.Id)) + .Select(x => new CollectionGroup + { + CollectionId = collection.Id, + GroupId = x.g.Id, + ReadOnly = groups.FirstOrDefault(g => g.Id == x.g.Id).ReadOnly, + HidePasswords = groups.FirstOrDefault(g => g.Id == x.g.Id).HidePasswords, + }).ToList(); + var update = union + .Where( + x => x.g != null && + x.cg != null && + (x.cg.ReadOnly != groups.FirstOrDefault(g => g.Id == x.g.Id).ReadOnly || + x.cg.HidePasswords != groups.FirstOrDefault(g => g.Id == x.g.Id).HidePasswords) + ) + .Select(x => new CollectionGroup + { + CollectionId = collection.Id, + GroupId = x.g.Id, + ReadOnly = groups.FirstOrDefault(g => g.Id == x.g.Id).ReadOnly, + HidePasswords = groups.FirstOrDefault(g => g.Id == x.g.Id).HidePasswords, + }); + var delete = union + .Where( + x => x.g == null && + x.cg.CollectionId == collection.Id + ) + .Select(x => new CollectionGroup + { + CollectionId = collection.Id, + GroupId = x.cg.GroupId, + }) + .ToList(); + + await dbContext.AddRangeAsync(insert); + dbContext.UpdateRange(update); + dbContext.RemoveRange(delete); + await dbContext.SaveChangesAsync(); + await UserBumpAccountRevisionDateByCollectionId(collection.Id, collection.OrganizationId); + } + } + + public async Task UpdateUsersAsync(Guid id, IEnumerable users) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var procedure = new CollectionUserUpdateUsersQuery(id, users); + var updateData = await procedure.Update.BuildInMemory(dbContext); + dbContext.UpdateRange(updateData); + var insertData = await procedure.Insert.BuildInMemory(dbContext); + await dbContext.AddRangeAsync(insertData); + dbContext.RemoveRange(await procedure.Delete.Run(dbContext).ToListAsync()); + } } } } diff --git a/src/Infrastructure.EntityFramework/Repositories/DatabaseContext.cs b/src/Infrastructure.EntityFramework/Repositories/DatabaseContext.cs index 88c2bb464a..8d3af7a4f3 100644 --- a/src/Infrastructure.EntityFramework/Repositories/DatabaseContext.cs +++ b/src/Infrastructure.EntityFramework/Repositories/DatabaseContext.cs @@ -1,139 +1,140 @@ using Bit.Infrastructure.EntityFramework.Models; using Microsoft.EntityFrameworkCore; -namespace Bit.Infrastructure.EntityFramework.Repositories; - -public class DatabaseContext : DbContext +namespace Bit.Infrastructure.EntityFramework.Repositories { - public const string postgresIndetermanisticCollation = "postgresIndetermanisticCollation"; - - public DatabaseContext(DbContextOptions options) - : base(options) - { } - - public DbSet Ciphers { get; set; } - public DbSet Collections { get; set; } - public DbSet CollectionCiphers { get; set; } - public DbSet CollectionGroups { get; set; } - public DbSet CollectionUsers { get; set; } - public DbSet Devices { get; set; } - public DbSet EmergencyAccesses { get; set; } - public DbSet Events { get; set; } - public DbSet Folders { get; set; } - public DbSet Grants { get; set; } - public DbSet Groups { get; set; } - public DbSet GroupUsers { get; set; } - public DbSet Installations { get; set; } - public DbSet Organizations { get; set; } - public DbSet OrganizationApiKeys { get; set; } - public DbSet OrganizationSponsorships { get; set; } - public DbSet OrganizationConnections { get; set; } - public DbSet OrganizationUsers { get; set; } - public DbSet Policies { get; set; } - public DbSet Providers { get; set; } - public DbSet ProviderUsers { get; set; } - public DbSet ProviderOrganizations { get; set; } - public DbSet Sends { get; set; } - public DbSet SsoConfigs { get; set; } - public DbSet SsoUsers { get; set; } - public DbSet TaxRates { get; set; } - public DbSet Transactions { get; set; } - public DbSet Users { get; set; } - - protected override void OnModelCreating(ModelBuilder builder) + public class DatabaseContext : DbContext { - var eCipher = builder.Entity(); - var eCollection = builder.Entity(); - var eCollectionCipher = builder.Entity(); - var eCollectionUser = builder.Entity(); - var eCollectionGroup = builder.Entity(); - var eDevice = builder.Entity(); - var eEmergencyAccess = builder.Entity(); - var eEvent = builder.Entity(); - var eFolder = builder.Entity(); - var eGrant = builder.Entity(); - var eGroup = builder.Entity(); - var eGroupUser = builder.Entity(); - var eInstallation = builder.Entity(); - var eOrganization = builder.Entity(); - var eOrganizationSponsorship = builder.Entity(); - var eOrganizationUser = builder.Entity(); - var ePolicy = builder.Entity(); - var eProvider = builder.Entity(); - var eProviderUser = builder.Entity(); - var eProviderOrganization = builder.Entity(); - var eSend = builder.Entity(); - var eSsoConfig = builder.Entity(); - var eSsoUser = builder.Entity(); - var eTaxRate = builder.Entity(); - var eTransaction = builder.Entity(); - var eUser = builder.Entity(); - var eOrganizationApiKey = builder.Entity(); - var eOrganizationConnection = builder.Entity(); + public const string postgresIndetermanisticCollation = "postgresIndetermanisticCollation"; - eCipher.Property(c => c.Id).ValueGeneratedNever(); - eCollection.Property(c => c.Id).ValueGeneratedNever(); - eEmergencyAccess.Property(c => c.Id).ValueGeneratedNever(); - eEvent.Property(c => c.Id).ValueGeneratedNever(); - eFolder.Property(c => c.Id).ValueGeneratedNever(); - eGroup.Property(c => c.Id).ValueGeneratedNever(); - eInstallation.Property(c => c.Id).ValueGeneratedNever(); - eOrganization.Property(c => c.Id).ValueGeneratedNever(); - eOrganizationSponsorship.Property(c => c.Id).ValueGeneratedNever(); - eOrganizationUser.Property(c => c.Id).ValueGeneratedNever(); - ePolicy.Property(c => c.Id).ValueGeneratedNever(); - eProvider.Property(c => c.Id).ValueGeneratedNever(); - eProviderUser.Property(c => c.Id).ValueGeneratedNever(); - eProviderOrganization.Property(c => c.Id).ValueGeneratedNever(); - eSend.Property(c => c.Id).ValueGeneratedNever(); - eTransaction.Property(c => c.Id).ValueGeneratedNever(); - eUser.Property(c => c.Id).ValueGeneratedNever(); - eOrganizationApiKey.Property(c => c.Id).ValueGeneratedNever(); - eOrganizationConnection.Property(c => c.Id).ValueGeneratedNever(); + public DatabaseContext(DbContextOptions options) + : base(options) + { } - eCollectionCipher.HasKey(cc => new { cc.CollectionId, cc.CipherId }); - eCollectionUser.HasKey(cu => new { cu.CollectionId, cu.OrganizationUserId }); - eCollectionGroup.HasKey(cg => new { cg.CollectionId, cg.GroupId }); - eGrant.HasKey(x => x.Key); - eGroupUser.HasKey(gu => new { gu.GroupId, gu.OrganizationUserId }); + public DbSet Ciphers { get; set; } + public DbSet Collections { get; set; } + public DbSet CollectionCiphers { get; set; } + public DbSet CollectionGroups { get; set; } + public DbSet CollectionUsers { get; set; } + public DbSet Devices { get; set; } + public DbSet EmergencyAccesses { get; set; } + public DbSet Events { get; set; } + public DbSet Folders { get; set; } + public DbSet Grants { get; set; } + public DbSet Groups { get; set; } + public DbSet GroupUsers { get; set; } + public DbSet Installations { get; set; } + public DbSet Organizations { get; set; } + public DbSet OrganizationApiKeys { get; set; } + public DbSet OrganizationSponsorships { get; set; } + public DbSet OrganizationConnections { get; set; } + public DbSet OrganizationUsers { get; set; } + public DbSet Policies { get; set; } + public DbSet Providers { get; set; } + public DbSet ProviderUsers { get; set; } + public DbSet ProviderOrganizations { get; set; } + public DbSet Sends { get; set; } + public DbSet SsoConfigs { get; set; } + public DbSet SsoUsers { get; set; } + public DbSet TaxRates { get; set; } + public DbSet Transactions { get; set; } + public DbSet Users { get; set; } - - if (Database.IsNpgsql()) + protected override void OnModelCreating(ModelBuilder builder) { - // the postgres provider doesn't currently support database level non-deterministic collations. - // see https://www.npgsql.org/efcore/misc/collations-and-case-sensitivity.html#database-collation - builder.HasCollation(postgresIndetermanisticCollation, locale: "en-u-ks-primary", provider: "icu", deterministic: false); - eUser.Property(e => e.Email).UseCollation(postgresIndetermanisticCollation); - eSsoUser.Property(e => e.ExternalId).UseCollation(postgresIndetermanisticCollation); - eOrganization.Property(e => e.Identifier).UseCollation(postgresIndetermanisticCollation); - // - } + var eCipher = builder.Entity(); + var eCollection = builder.Entity(); + var eCollectionCipher = builder.Entity(); + var eCollectionUser = builder.Entity(); + var eCollectionGroup = builder.Entity(); + var eDevice = builder.Entity(); + var eEmergencyAccess = builder.Entity(); + var eEvent = builder.Entity(); + var eFolder = builder.Entity(); + var eGrant = builder.Entity(); + var eGroup = builder.Entity(); + var eGroupUser = builder.Entity(); + var eInstallation = builder.Entity(); + var eOrganization = builder.Entity(); + var eOrganizationSponsorship = builder.Entity(); + var eOrganizationUser = builder.Entity(); + var ePolicy = builder.Entity(); + var eProvider = builder.Entity(); + var eProviderUser = builder.Entity(); + var eProviderOrganization = builder.Entity(); + var eSend = builder.Entity(); + var eSsoConfig = builder.Entity(); + var eSsoUser = builder.Entity(); + var eTaxRate = builder.Entity(); + var eTransaction = builder.Entity(); + var eUser = builder.Entity(); + var eOrganizationApiKey = builder.Entity(); + var eOrganizationConnection = builder.Entity(); - eCipher.ToTable(nameof(Cipher)); - eCollection.ToTable(nameof(Collection)); - eCollectionCipher.ToTable(nameof(CollectionCipher)); - eDevice.ToTable(nameof(Device)); - eEmergencyAccess.ToTable(nameof(EmergencyAccess)); - eEvent.ToTable(nameof(Event)); - eFolder.ToTable(nameof(Folder)); - eGrant.ToTable(nameof(Grant)); - eGroup.ToTable(nameof(Group)); - eGroupUser.ToTable(nameof(GroupUser)); - eInstallation.ToTable(nameof(Installation)); - eOrganization.ToTable(nameof(Organization)); - eOrganizationSponsorship.ToTable(nameof(OrganizationSponsorship)); - eOrganizationUser.ToTable(nameof(OrganizationUser)); - ePolicy.ToTable(nameof(Policy)); - eProvider.ToTable(nameof(Provider)); - eProviderUser.ToTable(nameof(ProviderUser)); - eProviderOrganization.ToTable(nameof(ProviderOrganization)); - eSend.ToTable(nameof(Send)); - eSsoConfig.ToTable(nameof(SsoConfig)); - eSsoUser.ToTable(nameof(SsoUser)); - eTaxRate.ToTable(nameof(TaxRate)); - eTransaction.ToTable(nameof(Transaction)); - eUser.ToTable(nameof(User)); - eOrganizationApiKey.ToTable(nameof(OrganizationApiKey)); - eOrganizationConnection.ToTable(nameof(OrganizationConnection)); + eCipher.Property(c => c.Id).ValueGeneratedNever(); + eCollection.Property(c => c.Id).ValueGeneratedNever(); + eEmergencyAccess.Property(c => c.Id).ValueGeneratedNever(); + eEvent.Property(c => c.Id).ValueGeneratedNever(); + eFolder.Property(c => c.Id).ValueGeneratedNever(); + eGroup.Property(c => c.Id).ValueGeneratedNever(); + eInstallation.Property(c => c.Id).ValueGeneratedNever(); + eOrganization.Property(c => c.Id).ValueGeneratedNever(); + eOrganizationSponsorship.Property(c => c.Id).ValueGeneratedNever(); + eOrganizationUser.Property(c => c.Id).ValueGeneratedNever(); + ePolicy.Property(c => c.Id).ValueGeneratedNever(); + eProvider.Property(c => c.Id).ValueGeneratedNever(); + eProviderUser.Property(c => c.Id).ValueGeneratedNever(); + eProviderOrganization.Property(c => c.Id).ValueGeneratedNever(); + eSend.Property(c => c.Id).ValueGeneratedNever(); + eTransaction.Property(c => c.Id).ValueGeneratedNever(); + eUser.Property(c => c.Id).ValueGeneratedNever(); + eOrganizationApiKey.Property(c => c.Id).ValueGeneratedNever(); + eOrganizationConnection.Property(c => c.Id).ValueGeneratedNever(); + + eCollectionCipher.HasKey(cc => new { cc.CollectionId, cc.CipherId }); + eCollectionUser.HasKey(cu => new { cu.CollectionId, cu.OrganizationUserId }); + eCollectionGroup.HasKey(cg => new { cg.CollectionId, cg.GroupId }); + eGrant.HasKey(x => x.Key); + eGroupUser.HasKey(gu => new { gu.GroupId, gu.OrganizationUserId }); + + + if (Database.IsNpgsql()) + { + // the postgres provider doesn't currently support database level non-deterministic collations. + // see https://www.npgsql.org/efcore/misc/collations-and-case-sensitivity.html#database-collation + builder.HasCollation(postgresIndetermanisticCollation, locale: "en-u-ks-primary", provider: "icu", deterministic: false); + eUser.Property(e => e.Email).UseCollation(postgresIndetermanisticCollation); + eSsoUser.Property(e => e.ExternalId).UseCollation(postgresIndetermanisticCollation); + eOrganization.Property(e => e.Identifier).UseCollation(postgresIndetermanisticCollation); + // + } + + eCipher.ToTable(nameof(Cipher)); + eCollection.ToTable(nameof(Collection)); + eCollectionCipher.ToTable(nameof(CollectionCipher)); + eDevice.ToTable(nameof(Device)); + eEmergencyAccess.ToTable(nameof(EmergencyAccess)); + eEvent.ToTable(nameof(Event)); + eFolder.ToTable(nameof(Folder)); + eGrant.ToTable(nameof(Grant)); + eGroup.ToTable(nameof(Group)); + eGroupUser.ToTable(nameof(GroupUser)); + eInstallation.ToTable(nameof(Installation)); + eOrganization.ToTable(nameof(Organization)); + eOrganizationSponsorship.ToTable(nameof(OrganizationSponsorship)); + eOrganizationUser.ToTable(nameof(OrganizationUser)); + ePolicy.ToTable(nameof(Policy)); + eProvider.ToTable(nameof(Provider)); + eProviderUser.ToTable(nameof(ProviderUser)); + eProviderOrganization.ToTable(nameof(ProviderOrganization)); + eSend.ToTable(nameof(Send)); + eSsoConfig.ToTable(nameof(SsoConfig)); + eSsoUser.ToTable(nameof(SsoUser)); + eTaxRate.ToTable(nameof(TaxRate)); + eTransaction.ToTable(nameof(Transaction)); + eUser.ToTable(nameof(User)); + eOrganizationApiKey.ToTable(nameof(OrganizationApiKey)); + eOrganizationConnection.ToTable(nameof(OrganizationConnection)); + } } } diff --git a/src/Infrastructure.EntityFramework/Repositories/DeviceRepository.cs b/src/Infrastructure.EntityFramework/Repositories/DeviceRepository.cs index cc664aa1b7..79ad608181 100644 --- a/src/Infrastructure.EntityFramework/Repositories/DeviceRepository.cs +++ b/src/Infrastructure.EntityFramework/Repositories/DeviceRepository.cs @@ -4,67 +4,68 @@ using Bit.Infrastructure.EntityFramework.Models; using Microsoft.EntityFrameworkCore; using Microsoft.Extensions.DependencyInjection; -namespace Bit.Infrastructure.EntityFramework.Repositories; - -public class DeviceRepository : Repository, IDeviceRepository +namespace Bit.Infrastructure.EntityFramework.Repositories { - public DeviceRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) - : base(serviceScopeFactory, mapper, (DatabaseContext context) => context.Devices) - { } - - public async Task ClearPushTokenAsync(Guid id) + public class DeviceRepository : Repository, IDeviceRepository { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var query = dbContext.Devices.Where(d => d.Id == id); - dbContext.AttachRange(query); - await query.ForEachAsync(x => x.PushToken = null); - await dbContext.SaveChangesAsync(); - } - } + public DeviceRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) + : base(serviceScopeFactory, mapper, (DatabaseContext context) => context.Devices) + { } - public async Task GetByIdAsync(Guid id, Guid userId) - { - var device = await base.GetByIdAsync(id); - if (device == null || device.UserId != userId) + public async Task ClearPushTokenAsync(Guid id) { - return null; + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var query = dbContext.Devices.Where(d => d.Id == id); + dbContext.AttachRange(query); + await query.ForEachAsync(x => x.PushToken = null); + await dbContext.SaveChangesAsync(); + } } - return Mapper.Map(device); - } - - public async Task GetByIdentifierAsync(string identifier) - { - using (var scope = ServiceScopeFactory.CreateScope()) + public async Task GetByIdAsync(Guid id, Guid userId) { - var dbContext = GetDatabaseContext(scope); - var query = dbContext.Devices.Where(d => d.Identifier == identifier); - var device = await query.FirstOrDefaultAsync(); + var device = await base.GetByIdAsync(id); + if (device == null || device.UserId != userId) + { + return null; + } + return Mapper.Map(device); } - } - public async Task GetByIdentifierAsync(string identifier, Guid userId) - { - using (var scope = ServiceScopeFactory.CreateScope()) + public async Task GetByIdentifierAsync(string identifier) { - var dbContext = GetDatabaseContext(scope); - var query = dbContext.Devices.Where(d => d.Identifier == identifier && d.UserId == userId); - var device = await query.FirstOrDefaultAsync(); - return Mapper.Map(device); + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var query = dbContext.Devices.Where(d => d.Identifier == identifier); + var device = await query.FirstOrDefaultAsync(); + return Mapper.Map(device); + } } - } - public async Task> GetManyByUserIdAsync(Guid userId) - { - using (var scope = ServiceScopeFactory.CreateScope()) + public async Task GetByIdentifierAsync(string identifier, Guid userId) { - var dbContext = GetDatabaseContext(scope); - var query = dbContext.Devices.Where(d => d.UserId == userId); - var devices = await query.ToListAsync(); - return Mapper.Map>(devices); + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var query = dbContext.Devices.Where(d => d.Identifier == identifier && d.UserId == userId); + var device = await query.FirstOrDefaultAsync(); + return Mapper.Map(device); + } + } + + public async Task> GetManyByUserIdAsync(Guid userId) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var query = dbContext.Devices.Where(d => d.UserId == userId); + var devices = await query.ToListAsync(); + return Mapper.Map>(devices); + } } } } diff --git a/src/Infrastructure.EntityFramework/Repositories/EmergencyAccessRepository.cs b/src/Infrastructure.EntityFramework/Repositories/EmergencyAccessRepository.cs index 028ce222f3..4ace885608 100644 --- a/src/Infrastructure.EntityFramework/Repositories/EmergencyAccessRepository.cs +++ b/src/Infrastructure.EntityFramework/Repositories/EmergencyAccessRepository.cs @@ -7,101 +7,102 @@ using Bit.Infrastructure.EntityFramework.Repositories.Queries; using Microsoft.EntityFrameworkCore; using Microsoft.Extensions.DependencyInjection; -namespace Bit.Infrastructure.EntityFramework.Repositories; - -public class EmergencyAccessRepository : Repository, IEmergencyAccessRepository +namespace Bit.Infrastructure.EntityFramework.Repositories { - public EmergencyAccessRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) - : base(serviceScopeFactory, mapper, (DatabaseContext context) => context.EmergencyAccesses) - { } - - public async Task GetCountByGrantorIdEmailAsync(Guid grantorId, string email, bool onlyRegisteredUsers) + public class EmergencyAccessRepository : Repository, IEmergencyAccessRepository { - var query = new EmergencyAccessReadCountByGrantorIdEmailQuery(grantorId, email, onlyRegisteredUsers); - return await GetCountFromQuery(query); - } + public EmergencyAccessRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) + : base(serviceScopeFactory, mapper, (DatabaseContext context) => context.EmergencyAccesses) + { } - public async Task GetDetailsByIdGrantorIdAsync(Guid id, Guid grantorId) - { - using (var scope = ServiceScopeFactory.CreateScope()) + public async Task GetCountByGrantorIdEmailAsync(Guid grantorId, string email, bool onlyRegisteredUsers) { - var dbContext = GetDatabaseContext(scope); - var view = new EmergencyAccessDetailsViewQuery(); - var query = view.Run(dbContext).Where(ea => - ea.Id == id && - ea.GrantorId == grantorId - ); - return await query.FirstOrDefaultAsync(); + var query = new EmergencyAccessReadCountByGrantorIdEmailQuery(grantorId, email, onlyRegisteredUsers); + return await GetCountFromQuery(query); } - } - public async Task> GetExpiredRecoveriesAsync() - { - using (var scope = ServiceScopeFactory.CreateScope()) + public async Task GetDetailsByIdGrantorIdAsync(Guid id, Guid grantorId) { - var dbContext = GetDatabaseContext(scope); - var view = new EmergencyAccessDetailsViewQuery(); - var query = view.Run(dbContext).Where(ea => - ea.Status == EmergencyAccessStatusType.RecoveryInitiated - ); - return await query.ToListAsync(); - } - } - - public async Task> GetManyDetailsByGranteeIdAsync(Guid granteeId) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var view = new EmergencyAccessDetailsViewQuery(); - var query = view.Run(dbContext).Where(ea => - ea.GranteeId == granteeId - ); - return await query.ToListAsync(); - } - } - - public async Task> GetManyDetailsByGrantorIdAsync(Guid grantorId) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var view = new EmergencyAccessDetailsViewQuery(); - var query = view.Run(dbContext).Where(ea => - ea.GrantorId == grantorId - ); - return await query.ToListAsync(); - } - } - - public async Task> GetManyToNotifyAsync() - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var view = new EmergencyAccessDetailsViewQuery(); - var query = view.Run(dbContext).Where(ea => - ea.Status == EmergencyAccessStatusType.RecoveryInitiated - ); - var notifies = await query.Select(ea => new EmergencyAccessNotify + using (var scope = ServiceScopeFactory.CreateScope()) { - Id = ea.Id, - GrantorId = ea.GrantorId, - GranteeId = ea.GranteeId, - Email = ea.Email, - KeyEncrypted = ea.KeyEncrypted, - Type = ea.Type, - Status = ea.Status, - WaitTimeDays = ea.WaitTimeDays, - RecoveryInitiatedDate = ea.RecoveryInitiatedDate, - LastNotificationDate = ea.LastNotificationDate, - CreationDate = ea.CreationDate, - RevisionDate = ea.RevisionDate, - GranteeName = ea.GranteeName, - GranteeEmail = ea.GranteeEmail, - GrantorEmail = ea.GrantorEmail, - }).ToListAsync(); - return notifies; + var dbContext = GetDatabaseContext(scope); + var view = new EmergencyAccessDetailsViewQuery(); + var query = view.Run(dbContext).Where(ea => + ea.Id == id && + ea.GrantorId == grantorId + ); + return await query.FirstOrDefaultAsync(); + } + } + + public async Task> GetExpiredRecoveriesAsync() + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var view = new EmergencyAccessDetailsViewQuery(); + var query = view.Run(dbContext).Where(ea => + ea.Status == EmergencyAccessStatusType.RecoveryInitiated + ); + return await query.ToListAsync(); + } + } + + public async Task> GetManyDetailsByGranteeIdAsync(Guid granteeId) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var view = new EmergencyAccessDetailsViewQuery(); + var query = view.Run(dbContext).Where(ea => + ea.GranteeId == granteeId + ); + return await query.ToListAsync(); + } + } + + public async Task> GetManyDetailsByGrantorIdAsync(Guid grantorId) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var view = new EmergencyAccessDetailsViewQuery(); + var query = view.Run(dbContext).Where(ea => + ea.GrantorId == grantorId + ); + return await query.ToListAsync(); + } + } + + public async Task> GetManyToNotifyAsync() + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var view = new EmergencyAccessDetailsViewQuery(); + var query = view.Run(dbContext).Where(ea => + ea.Status == EmergencyAccessStatusType.RecoveryInitiated + ); + var notifies = await query.Select(ea => new EmergencyAccessNotify + { + Id = ea.Id, + GrantorId = ea.GrantorId, + GranteeId = ea.GranteeId, + Email = ea.Email, + KeyEncrypted = ea.KeyEncrypted, + Type = ea.Type, + Status = ea.Status, + WaitTimeDays = ea.WaitTimeDays, + RecoveryInitiatedDate = ea.RecoveryInitiatedDate, + LastNotificationDate = ea.LastNotificationDate, + CreationDate = ea.CreationDate, + RevisionDate = ea.RevisionDate, + GranteeName = ea.GranteeName, + GranteeEmail = ea.GranteeEmail, + GrantorEmail = ea.GrantorEmail, + }).ToListAsync(); + return notifies; + } } } } diff --git a/src/Infrastructure.EntityFramework/Repositories/EventRepository.cs b/src/Infrastructure.EntityFramework/Repositories/EventRepository.cs index cb49f8535e..712885245c 100644 --- a/src/Infrastructure.EntityFramework/Repositories/EventRepository.cs +++ b/src/Infrastructure.EntityFramework/Repositories/EventRepository.cs @@ -8,195 +8,196 @@ using Microsoft.EntityFrameworkCore; using Microsoft.Extensions.DependencyInjection; using Cipher = Bit.Core.Entities.Cipher; -namespace Bit.Infrastructure.EntityFramework.Repositories; - -public class EventRepository : Repository, IEventRepository +namespace Bit.Infrastructure.EntityFramework.Repositories { - public EventRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) - : base(serviceScopeFactory, mapper, (DatabaseContext context) => context.Events) - { } - - public async Task CreateAsync(IEvent e) + public class EventRepository : Repository, IEventRepository { - if (e is not Core.Entities.Event ev) - { - ev = new Core.Entities.Event(e); - } + public EventRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) + : base(serviceScopeFactory, mapper, (DatabaseContext context) => context.Events) + { } - await base.CreateAsync(ev); - } - - public async Task CreateManyAsync(IEnumerable entities) - { - if (!entities?.Any() ?? true) + public async Task CreateAsync(IEvent e) { - return; - } - - if (!entities.Skip(1).Any()) - { - await CreateAsync(entities.First()); - return; - } - - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var tableEvents = entities.Select(e => e as Core.Entities.Event ?? new Core.Entities.Event(e)); - var entityEvents = Mapper.Map>(tableEvents); - entityEvents.ForEach(e => e.SetNewId()); - await dbContext.BulkCopyAsync(entityEvents); - } - } - - public async Task> GetManyByCipherAsync(Cipher cipher, DateTime startDate, DateTime endDate, PageOptions pageOptions) - { - DateTime? beforeDate = null; - if (!string.IsNullOrWhiteSpace(pageOptions.ContinuationToken) && - long.TryParse(pageOptions.ContinuationToken, out var binaryDate)) - { - beforeDate = DateTime.SpecifyKind(DateTime.FromBinary(binaryDate), DateTimeKind.Utc); - } - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var query = new EventReadPageByCipherIdQuery(cipher, startDate, endDate, beforeDate, pageOptions); - var events = await query.Run(dbContext).ToListAsync(); - - var result = new PagedResult(); - if (events.Any() && events.Count >= pageOptions.PageSize) + if (e is not Core.Entities.Event ev) { - result.ContinuationToken = events.Last().Date.ToBinary().ToString(); + ev = new Core.Entities.Event(e); } - result.Data.AddRange(events); - return result; + + await base.CreateAsync(ev); } - } - - public async Task> GetManyByOrganizationActingUserAsync(Guid organizationId, Guid actingUserId, DateTime startDate, DateTime endDate, PageOptions pageOptions) - { - DateTime? beforeDate = null; - if (!string.IsNullOrWhiteSpace(pageOptions.ContinuationToken) && - long.TryParse(pageOptions.ContinuationToken, out var binaryDate)) + public async Task CreateManyAsync(IEnumerable entities) { - beforeDate = DateTime.SpecifyKind(DateTime.FromBinary(binaryDate), DateTimeKind.Utc); - } - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var query = new EventReadPageByOrganizationIdActingUserIdQuery(organizationId, actingUserId, - startDate, endDate, beforeDate, pageOptions); - var events = await query.Run(dbContext).ToListAsync(); - - var result = new PagedResult(); - if (events.Any() && events.Count >= pageOptions.PageSize) + if (!entities?.Any() ?? true) { - result.ContinuationToken = events.Last().Date.ToBinary().ToString(); + return; } - result.Data.AddRange(events); - return result; - } - } - public async Task> GetManyByProviderAsync(Guid providerId, DateTime startDate, DateTime endDate, PageOptions pageOptions) - { - DateTime? beforeDate = null; - if (!string.IsNullOrWhiteSpace(pageOptions.ContinuationToken) && - long.TryParse(pageOptions.ContinuationToken, out var binaryDate)) - { - beforeDate = DateTime.SpecifyKind(DateTime.FromBinary(binaryDate), DateTimeKind.Utc); - } - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var query = new EventReadPageByProviderIdQuery(providerId, startDate, - endDate, beforeDate, pageOptions); - var events = await query.Run(dbContext).ToListAsync(); - - var result = new PagedResult(); - if (events.Any() && events.Count >= pageOptions.PageSize) + if (!entities.Skip(1).Any()) { - result.ContinuationToken = events.Last().Date.ToBinary().ToString(); + await CreateAsync(entities.First()); + return; } - result.Data.AddRange(events); - return result; - } - } - public async Task> GetManyByProviderActingUserAsync(Guid providerId, Guid actingUserId, - DateTime startDate, DateTime endDate, PageOptions pageOptions) - { - DateTime? beforeDate = null; - if (!string.IsNullOrWhiteSpace(pageOptions.ContinuationToken) && - long.TryParse(pageOptions.ContinuationToken, out var binaryDate)) - { - beforeDate = DateTime.SpecifyKind(DateTime.FromBinary(binaryDate), DateTimeKind.Utc); - } - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var query = new EventReadPageByProviderIdActingUserIdQuery(providerId, actingUserId, - startDate, endDate, beforeDate, pageOptions); - var events = await query.Run(dbContext).ToListAsync(); - - var result = new PagedResult(); - if (events.Any() && events.Count >= pageOptions.PageSize) + using (var scope = ServiceScopeFactory.CreateScope()) { - result.ContinuationToken = events.Last().Date.ToBinary().ToString(); + var dbContext = GetDatabaseContext(scope); + var tableEvents = entities.Select(e => e as Core.Entities.Event ?? new Core.Entities.Event(e)); + var entityEvents = Mapper.Map>(tableEvents); + entityEvents.ForEach(e => e.SetNewId()); + await dbContext.BulkCopyAsync(entityEvents); } - result.Data.AddRange(events); - return result; } - } - public async Task> GetManyByOrganizationAsync(Guid organizationId, DateTime startDate, DateTime endDate, PageOptions pageOptions) - { - DateTime? beforeDate = null; - if (!string.IsNullOrWhiteSpace(pageOptions.ContinuationToken) && - long.TryParse(pageOptions.ContinuationToken, out var binaryDate)) + public async Task> GetManyByCipherAsync(Cipher cipher, DateTime startDate, DateTime endDate, PageOptions pageOptions) { - beforeDate = DateTime.SpecifyKind(DateTime.FromBinary(binaryDate), DateTimeKind.Utc); - } - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var query = new EventReadPageByOrganizationIdQuery(organizationId, startDate, - endDate, beforeDate, pageOptions); - var events = await query.Run(dbContext).ToListAsync(); - - var result = new PagedResult(); - if (events.Any() && events.Count >= pageOptions.PageSize) + DateTime? beforeDate = null; + if (!string.IsNullOrWhiteSpace(pageOptions.ContinuationToken) && + long.TryParse(pageOptions.ContinuationToken, out var binaryDate)) { - result.ContinuationToken = events.Last().Date.ToBinary().ToString(); + beforeDate = DateTime.SpecifyKind(DateTime.FromBinary(binaryDate), DateTimeKind.Utc); } - result.Data.AddRange(events); - return result; - } - } - - public async Task> GetManyByUserAsync(Guid userId, DateTime startDate, DateTime endDate, PageOptions pageOptions) - { - DateTime? beforeDate = null; - if (!string.IsNullOrWhiteSpace(pageOptions.ContinuationToken) && - long.TryParse(pageOptions.ContinuationToken, out var binaryDate)) - { - beforeDate = DateTime.SpecifyKind(DateTime.FromBinary(binaryDate), DateTimeKind.Utc); - } - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var query = new EventReadPageByUserIdQuery(userId, startDate, - endDate, beforeDate, pageOptions); - var events = await query.Run(dbContext).ToListAsync(); - - var result = new PagedResult(); - if (events.Any() && events.Count >= pageOptions.PageSize) + using (var scope = ServiceScopeFactory.CreateScope()) { - result.ContinuationToken = events.Last().Date.ToBinary().ToString(); + var dbContext = GetDatabaseContext(scope); + var query = new EventReadPageByCipherIdQuery(cipher, startDate, endDate, beforeDate, pageOptions); + var events = await query.Run(dbContext).ToListAsync(); + + var result = new PagedResult(); + if (events.Any() && events.Count >= pageOptions.PageSize) + { + result.ContinuationToken = events.Last().Date.ToBinary().ToString(); + } + result.Data.AddRange(events); + return result; + } + } + + + public async Task> GetManyByOrganizationActingUserAsync(Guid organizationId, Guid actingUserId, DateTime startDate, DateTime endDate, PageOptions pageOptions) + { + DateTime? beforeDate = null; + if (!string.IsNullOrWhiteSpace(pageOptions.ContinuationToken) && + long.TryParse(pageOptions.ContinuationToken, out var binaryDate)) + { + beforeDate = DateTime.SpecifyKind(DateTime.FromBinary(binaryDate), DateTimeKind.Utc); + } + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var query = new EventReadPageByOrganizationIdActingUserIdQuery(organizationId, actingUserId, + startDate, endDate, beforeDate, pageOptions); + var events = await query.Run(dbContext).ToListAsync(); + + var result = new PagedResult(); + if (events.Any() && events.Count >= pageOptions.PageSize) + { + result.ContinuationToken = events.Last().Date.ToBinary().ToString(); + } + result.Data.AddRange(events); + return result; + } + } + + public async Task> GetManyByProviderAsync(Guid providerId, DateTime startDate, DateTime endDate, PageOptions pageOptions) + { + DateTime? beforeDate = null; + if (!string.IsNullOrWhiteSpace(pageOptions.ContinuationToken) && + long.TryParse(pageOptions.ContinuationToken, out var binaryDate)) + { + beforeDate = DateTime.SpecifyKind(DateTime.FromBinary(binaryDate), DateTimeKind.Utc); + } + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var query = new EventReadPageByProviderIdQuery(providerId, startDate, + endDate, beforeDate, pageOptions); + var events = await query.Run(dbContext).ToListAsync(); + + var result = new PagedResult(); + if (events.Any() && events.Count >= pageOptions.PageSize) + { + result.ContinuationToken = events.Last().Date.ToBinary().ToString(); + } + result.Data.AddRange(events); + return result; + } + } + + public async Task> GetManyByProviderActingUserAsync(Guid providerId, Guid actingUserId, + DateTime startDate, DateTime endDate, PageOptions pageOptions) + { + DateTime? beforeDate = null; + if (!string.IsNullOrWhiteSpace(pageOptions.ContinuationToken) && + long.TryParse(pageOptions.ContinuationToken, out var binaryDate)) + { + beforeDate = DateTime.SpecifyKind(DateTime.FromBinary(binaryDate), DateTimeKind.Utc); + } + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var query = new EventReadPageByProviderIdActingUserIdQuery(providerId, actingUserId, + startDate, endDate, beforeDate, pageOptions); + var events = await query.Run(dbContext).ToListAsync(); + + var result = new PagedResult(); + if (events.Any() && events.Count >= pageOptions.PageSize) + { + result.ContinuationToken = events.Last().Date.ToBinary().ToString(); + } + result.Data.AddRange(events); + return result; + } + } + + public async Task> GetManyByOrganizationAsync(Guid organizationId, DateTime startDate, DateTime endDate, PageOptions pageOptions) + { + DateTime? beforeDate = null; + if (!string.IsNullOrWhiteSpace(pageOptions.ContinuationToken) && + long.TryParse(pageOptions.ContinuationToken, out var binaryDate)) + { + beforeDate = DateTime.SpecifyKind(DateTime.FromBinary(binaryDate), DateTimeKind.Utc); + } + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var query = new EventReadPageByOrganizationIdQuery(organizationId, startDate, + endDate, beforeDate, pageOptions); + var events = await query.Run(dbContext).ToListAsync(); + + var result = new PagedResult(); + if (events.Any() && events.Count >= pageOptions.PageSize) + { + result.ContinuationToken = events.Last().Date.ToBinary().ToString(); + } + result.Data.AddRange(events); + return result; + } + } + + public async Task> GetManyByUserAsync(Guid userId, DateTime startDate, DateTime endDate, PageOptions pageOptions) + { + DateTime? beforeDate = null; + if (!string.IsNullOrWhiteSpace(pageOptions.ContinuationToken) && + long.TryParse(pageOptions.ContinuationToken, out var binaryDate)) + { + beforeDate = DateTime.SpecifyKind(DateTime.FromBinary(binaryDate), DateTimeKind.Utc); + } + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var query = new EventReadPageByUserIdQuery(userId, startDate, + endDate, beforeDate, pageOptions); + var events = await query.Run(dbContext).ToListAsync(); + + var result = new PagedResult(); + if (events.Any() && events.Count >= pageOptions.PageSize) + { + result.ContinuationToken = events.Last().Date.ToBinary().ToString(); + } + result.Data.AddRange(events); + return result; } - result.Data.AddRange(events); - return result; } } } diff --git a/src/Infrastructure.EntityFramework/Repositories/FolderRepository.cs b/src/Infrastructure.EntityFramework/Repositories/FolderRepository.cs index 9f1f862bf0..dae64f9c23 100644 --- a/src/Infrastructure.EntityFramework/Repositories/FolderRepository.cs +++ b/src/Infrastructure.EntityFramework/Repositories/FolderRepository.cs @@ -4,35 +4,36 @@ using Bit.Infrastructure.EntityFramework.Models; using Microsoft.EntityFrameworkCore; using Microsoft.Extensions.DependencyInjection; -namespace Bit.Infrastructure.EntityFramework.Repositories; - -public class FolderRepository : Repository, IFolderRepository +namespace Bit.Infrastructure.EntityFramework.Repositories { - public FolderRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) - : base(serviceScopeFactory, mapper, (DatabaseContext context) => context.Folders) - { } - - public async Task GetByIdAsync(Guid id, Guid userId) + public class FolderRepository : Repository, IFolderRepository { - var folder = await base.GetByIdAsync(id); - if (folder == null || folder.UserId != userId) + public FolderRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) + : base(serviceScopeFactory, mapper, (DatabaseContext context) => context.Folders) + { } + + public async Task GetByIdAsync(Guid id, Guid userId) { - return null; + var folder = await base.GetByIdAsync(id); + if (folder == null || folder.UserId != userId) + { + return null; + } + + return folder; } - return folder; - } - - public async Task> GetManyByUserIdAsync(Guid userId) - { - using (var scope = ServiceScopeFactory.CreateScope()) + public async Task> GetManyByUserIdAsync(Guid userId) { - var dbContext = GetDatabaseContext(scope); - var query = from f in dbContext.Folders - where f.UserId == userId - select f; - var folders = await query.ToListAsync(); - return Mapper.Map>(folders); + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var query = from f in dbContext.Folders + where f.UserId == userId + select f; + var folders = await query.ToListAsync(); + return Mapper.Map>(folders); + } } } } diff --git a/src/Infrastructure.EntityFramework/Repositories/GrantRepository.cs b/src/Infrastructure.EntityFramework/Repositories/GrantRepository.cs index 2edb62d9cb..0f8f197fe6 100644 --- a/src/Infrastructure.EntityFramework/Repositories/GrantRepository.cs +++ b/src/Infrastructure.EntityFramework/Repositories/GrantRepository.cs @@ -4,91 +4,92 @@ using Bit.Infrastructure.EntityFramework.Models; using Microsoft.EntityFrameworkCore; using Microsoft.Extensions.DependencyInjection; -namespace Bit.Infrastructure.EntityFramework.Repositories; - -public class GrantRepository : BaseEntityFrameworkRepository, IGrantRepository +namespace Bit.Infrastructure.EntityFramework.Repositories { - public GrantRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) - : base(serviceScopeFactory, mapper) - { } - - public async Task DeleteByKeyAsync(string key) + public class GrantRepository : BaseEntityFrameworkRepository, IGrantRepository { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var query = from g in dbContext.Grants - where g.Key == key - select g; - dbContext.Remove(query); - await dbContext.SaveChangesAsync(); - } - } + public GrantRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) + : base(serviceScopeFactory, mapper) + { } - public async Task DeleteManyAsync(string subjectId, string sessionId, string clientId, string type) - { - using (var scope = ServiceScopeFactory.CreateScope()) + public async Task DeleteByKeyAsync(string key) { - var dbContext = GetDatabaseContext(scope); - var query = from g in dbContext.Grants - where g.SubjectId == subjectId && - g.ClientId == clientId && - g.SessionId == sessionId && - g.Type == type - select g; - dbContext.Remove(query); - await dbContext.SaveChangesAsync(); - } - } - - public async Task GetByKeyAsync(string key) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var query = from g in dbContext.Grants - where g.Key == key - select g; - var grant = await query.FirstOrDefaultAsync(); - return grant; - } - } - - public async Task> GetManyAsync(string subjectId, string sessionId, string clientId, string type) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var query = from g in dbContext.Grants - where g.SubjectId == subjectId && - g.ClientId == clientId && - g.SessionId == sessionId && - g.Type == type - select g; - var grants = await query.ToListAsync(); - return (ICollection)grants; - } - } - - public async Task SaveAsync(Core.Entities.Grant obj) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var existingGrant = await (from g in dbContext.Grants - where g.Key == obj.Key - select g).FirstOrDefaultAsync(); - if (existingGrant != null) + using (var scope = ServiceScopeFactory.CreateScope()) { - dbContext.Entry(existingGrant).CurrentValues.SetValues(obj); - } - else - { - var entity = Mapper.Map(obj); - await dbContext.AddAsync(entity); + var dbContext = GetDatabaseContext(scope); + var query = from g in dbContext.Grants + where g.Key == key + select g; + dbContext.Remove(query); await dbContext.SaveChangesAsync(); } } + + public async Task DeleteManyAsync(string subjectId, string sessionId, string clientId, string type) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var query = from g in dbContext.Grants + where g.SubjectId == subjectId && + g.ClientId == clientId && + g.SessionId == sessionId && + g.Type == type + select g; + dbContext.Remove(query); + await dbContext.SaveChangesAsync(); + } + } + + public async Task GetByKeyAsync(string key) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var query = from g in dbContext.Grants + where g.Key == key + select g; + var grant = await query.FirstOrDefaultAsync(); + return grant; + } + } + + public async Task> GetManyAsync(string subjectId, string sessionId, string clientId, string type) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var query = from g in dbContext.Grants + where g.SubjectId == subjectId && + g.ClientId == clientId && + g.SessionId == sessionId && + g.Type == type + select g; + var grants = await query.ToListAsync(); + return (ICollection)grants; + } + } + + public async Task SaveAsync(Core.Entities.Grant obj) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var existingGrant = await (from g in dbContext.Grants + where g.Key == obj.Key + select g).FirstOrDefaultAsync(); + if (existingGrant != null) + { + dbContext.Entry(existingGrant).CurrentValues.SetValues(obj); + } + else + { + var entity = Mapper.Map(obj); + await dbContext.AddAsync(entity); + await dbContext.SaveChangesAsync(); + } + } + } } } diff --git a/src/Infrastructure.EntityFramework/Repositories/GroupRepository.cs b/src/Infrastructure.EntityFramework/Repositories/GroupRepository.cs index d41f078045..b471a5fdbc 100644 --- a/src/Infrastructure.EntityFramework/Repositories/GroupRepository.cs +++ b/src/Infrastructure.EntityFramework/Repositories/GroupRepository.cs @@ -5,163 +5,164 @@ using Bit.Infrastructure.EntityFramework.Models; using Microsoft.EntityFrameworkCore; using Microsoft.Extensions.DependencyInjection; -namespace Bit.Infrastructure.EntityFramework.Repositories; - -public class GroupRepository : Repository, IGroupRepository +namespace Bit.Infrastructure.EntityFramework.Repositories { - public GroupRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) - : base(serviceScopeFactory, mapper, (DatabaseContext context) => context.Groups) - { } - - public async Task CreateAsync(Core.Entities.Group obj, IEnumerable collections) + public class GroupRepository : Repository, IGroupRepository { - var grp = await base.CreateAsync(obj); - using (var scope = ServiceScopeFactory.CreateScope()) + public GroupRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) + : base(serviceScopeFactory, mapper, (DatabaseContext context) => context.Groups) + { } + + public async Task CreateAsync(Core.Entities.Group obj, IEnumerable collections) { - var dbContext = GetDatabaseContext(scope); - var availibleCollections = await ( - from c in dbContext.Collections - where c.OrganizationId == grp.OrganizationId - select c).ToListAsync(); - var filteredCollections = collections.Where(c => availibleCollections.Any(a => c.Id == a.Id)); - var collectionGroups = filteredCollections.Select(y => new CollectionGroup + var grp = await base.CreateAsync(obj); + using (var scope = ServiceScopeFactory.CreateScope()) { - CollectionId = y.Id, - GroupId = grp.Id, - ReadOnly = y.ReadOnly, - HidePasswords = y.HidePasswords, - }); - await dbContext.CollectionGroups.AddRangeAsync(collectionGroups); - await dbContext.SaveChangesAsync(); + var dbContext = GetDatabaseContext(scope); + var availibleCollections = await ( + from c in dbContext.Collections + where c.OrganizationId == grp.OrganizationId + select c).ToListAsync(); + var filteredCollections = collections.Where(c => availibleCollections.Any(a => c.Id == a.Id)); + var collectionGroups = filteredCollections.Select(y => new CollectionGroup + { + CollectionId = y.Id, + GroupId = grp.Id, + ReadOnly = y.ReadOnly, + HidePasswords = y.HidePasswords, + }); + await dbContext.CollectionGroups.AddRangeAsync(collectionGroups); + await dbContext.SaveChangesAsync(); + } } - } - public async Task DeleteUserAsync(Guid groupId, Guid organizationUserId) - { - using (var scope = ServiceScopeFactory.CreateScope()) + public async Task DeleteUserAsync(Guid groupId, Guid organizationUserId) { - var dbContext = GetDatabaseContext(scope); - var query = from gu in dbContext.GroupUsers - where gu.GroupId == groupId && - gu.OrganizationUserId == organizationUserId - select gu; - dbContext.RemoveRange(await query.ToListAsync()); - await dbContext.SaveChangesAsync(); - } - } - - public async Task>> GetByIdWithCollectionsAsync(Guid id) - { - var grp = await base.GetByIdAsync(id); - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var query = await ( - from cg in dbContext.CollectionGroups - where cg.GroupId == id - select cg).ToListAsync(); - var collections = query.Select(c => new SelectionReadOnly + using (var scope = ServiceScopeFactory.CreateScope()) { - Id = c.CollectionId, - ReadOnly = c.ReadOnly, - HidePasswords = c.HidePasswords, - }).ToList(); - return new Tuple>( - grp, collections); + var dbContext = GetDatabaseContext(scope); + var query = from gu in dbContext.GroupUsers + where gu.GroupId == groupId && + gu.OrganizationUserId == organizationUserId + select gu; + dbContext.RemoveRange(await query.ToListAsync()); + await dbContext.SaveChangesAsync(); + } } - } - public async Task> GetManyByOrganizationIdAsync(Guid organizationId) - { - using (var scope = ServiceScopeFactory.CreateScope()) + public async Task>> GetByIdWithCollectionsAsync(Guid id) { - var dbContext = GetDatabaseContext(scope); - var data = await ( - from g in dbContext.Groups - where g.OrganizationId == organizationId - select g).ToListAsync(); - return Mapper.Map>(data); + var grp = await base.GetByIdAsync(id); + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var query = await ( + from cg in dbContext.CollectionGroups + where cg.GroupId == id + select cg).ToListAsync(); + var collections = query.Select(c => new SelectionReadOnly + { + Id = c.CollectionId, + ReadOnly = c.ReadOnly, + HidePasswords = c.HidePasswords, + }).ToList(); + return new Tuple>( + grp, collections); + } } - } - public async Task> GetManyGroupUsersByOrganizationIdAsync(Guid organizationId) - { - using (var scope = ServiceScopeFactory.CreateScope()) + public async Task> GetManyByOrganizationIdAsync(Guid organizationId) { - var dbContext = GetDatabaseContext(scope); - var query = - from gu in dbContext.GroupUsers - join g in dbContext.Groups - on gu.GroupId equals g.Id - where g.OrganizationId == organizationId - select gu; - var groupUsers = await query.ToListAsync(); - return Mapper.Map>(groupUsers); + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var data = await ( + from g in dbContext.Groups + where g.OrganizationId == organizationId + select g).ToListAsync(); + return Mapper.Map>(data); + } } - } - public async Task> GetManyIdsByUserIdAsync(Guid organizationUserId) - { - using (var scope = ServiceScopeFactory.CreateScope()) + public async Task> GetManyGroupUsersByOrganizationIdAsync(Guid organizationId) { - var dbContext = GetDatabaseContext(scope); - var query = - from gu in dbContext.GroupUsers - where gu.OrganizationUserId == organizationUserId - select gu; - var groupIds = await query.Select(x => x.GroupId).ToListAsync(); - return groupIds; + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var query = + from gu in dbContext.GroupUsers + join g in dbContext.Groups + on gu.GroupId equals g.Id + where g.OrganizationId == organizationId + select gu; + var groupUsers = await query.ToListAsync(); + return Mapper.Map>(groupUsers); + } } - } - public async Task> GetManyUserIdsByIdAsync(Guid id) - { - using (var scope = ServiceScopeFactory.CreateScope()) + public async Task> GetManyIdsByUserIdAsync(Guid organizationUserId) { - var dbContext = GetDatabaseContext(scope); - var query = - from gu in dbContext.GroupUsers - where gu.GroupId == id - select gu; - var groupIds = await query.Select(x => x.OrganizationUserId).ToListAsync(); - return groupIds; + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var query = + from gu in dbContext.GroupUsers + where gu.OrganizationUserId == organizationUserId + select gu; + var groupIds = await query.Select(x => x.GroupId).ToListAsync(); + return groupIds; + } } - } - public async Task ReplaceAsync(Core.Entities.Group obj, IEnumerable collections) - { - await base.ReplaceAsync(obj); - using (var scope = ServiceScopeFactory.CreateScope()) + public async Task> GetManyUserIdsByIdAsync(Guid id) { - var dbContext = GetDatabaseContext(scope); - await UserBumpAccountRevisionDateByOrganizationId(obj.OrganizationId); + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var query = + from gu in dbContext.GroupUsers + where gu.GroupId == id + select gu; + var groupIds = await query.Select(x => x.OrganizationUserId).ToListAsync(); + return groupIds; + } } - } - public async Task UpdateUsersAsync(Guid groupId, IEnumerable organizationUserIds) - { - using (var scope = ServiceScopeFactory.CreateScope()) + public async Task ReplaceAsync(Core.Entities.Group obj, IEnumerable collections) { - var dbContext = GetDatabaseContext(scope); - var orgId = (await dbContext.Groups.FindAsync(groupId)).OrganizationId; - var insert = from ou in dbContext.OrganizationUsers - where organizationUserIds.Contains(ou.Id) && - ou.OrganizationId == orgId && - !dbContext.GroupUsers.Any(gu => gu.GroupId == groupId && ou.Id == gu.OrganizationUserId) - select new GroupUser - { - GroupId = groupId, - OrganizationUserId = ou.Id, - }; - await dbContext.AddRangeAsync(insert); + await base.ReplaceAsync(obj); + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + await UserBumpAccountRevisionDateByOrganizationId(obj.OrganizationId); + } + } - var delete = from gu in dbContext.GroupUsers - where gu.GroupId == groupId && - !organizationUserIds.Contains(gu.OrganizationUserId) - select gu; - dbContext.RemoveRange(delete); - await dbContext.SaveChangesAsync(); - await UserBumpAccountRevisionDateByOrganizationId(orgId); + public async Task UpdateUsersAsync(Guid groupId, IEnumerable organizationUserIds) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var orgId = (await dbContext.Groups.FindAsync(groupId)).OrganizationId; + var insert = from ou in dbContext.OrganizationUsers + where organizationUserIds.Contains(ou.Id) && + ou.OrganizationId == orgId && + !dbContext.GroupUsers.Any(gu => gu.GroupId == groupId && ou.Id == gu.OrganizationUserId) + select new GroupUser + { + GroupId = groupId, + OrganizationUserId = ou.Id, + }; + await dbContext.AddRangeAsync(insert); + + var delete = from gu in dbContext.GroupUsers + where gu.GroupId == groupId && + !organizationUserIds.Contains(gu.OrganizationUserId) + select gu; + dbContext.RemoveRange(delete); + await dbContext.SaveChangesAsync(); + await UserBumpAccountRevisionDateByOrganizationId(orgId); + } } } } diff --git a/src/Infrastructure.EntityFramework/Repositories/InstallationRepository.cs b/src/Infrastructure.EntityFramework/Repositories/InstallationRepository.cs index 292e98f851..1cc4808c32 100644 --- a/src/Infrastructure.EntityFramework/Repositories/InstallationRepository.cs +++ b/src/Infrastructure.EntityFramework/Repositories/InstallationRepository.cs @@ -3,11 +3,12 @@ using Bit.Core.Repositories; using Bit.Infrastructure.EntityFramework.Models; using Microsoft.Extensions.DependencyInjection; -namespace Bit.Infrastructure.EntityFramework.Repositories; - -public class InstallationRepository : Repository, IInstallationRepository +namespace Bit.Infrastructure.EntityFramework.Repositories { - public InstallationRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) - : base(serviceScopeFactory, mapper, (DatabaseContext context) => context.Installations) - { } + public class InstallationRepository : Repository, IInstallationRepository + { + public InstallationRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) + : base(serviceScopeFactory, mapper, (DatabaseContext context) => context.Installations) + { } + } } diff --git a/src/Infrastructure.EntityFramework/Repositories/MaintenanceRepository.cs b/src/Infrastructure.EntityFramework/Repositories/MaintenanceRepository.cs index 340834ca59..e91d775cc4 100644 --- a/src/Infrastructure.EntityFramework/Repositories/MaintenanceRepository.cs +++ b/src/Infrastructure.EntityFramework/Repositories/MaintenanceRepository.cs @@ -2,52 +2,53 @@ using Bit.Core.Repositories; using Microsoft.Extensions.DependencyInjection; -namespace Bit.Infrastructure.EntityFramework.Repositories; - -public class MaintenanceRepository : BaseEntityFrameworkRepository, IMaintenanceRepository +namespace Bit.Infrastructure.EntityFramework.Repositories { - public MaintenanceRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) - : base(serviceScopeFactory, mapper) - { } - - public async Task DeleteExpiredGrantsAsync() + public class MaintenanceRepository : BaseEntityFrameworkRepository, IMaintenanceRepository { - using (var scope = ServiceScopeFactory.CreateScope()) + public MaintenanceRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) + : base(serviceScopeFactory, mapper) + { } + + public async Task DeleteExpiredGrantsAsync() { - var dbContext = GetDatabaseContext(scope); - var query = from g in dbContext.Grants - where g.ExpirationDate < DateTime.UtcNow - select g; - dbContext.RemoveRange(query); - await dbContext.SaveChangesAsync(); + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var query = from g in dbContext.Grants + where g.ExpirationDate < DateTime.UtcNow + select g; + dbContext.RemoveRange(query); + await dbContext.SaveChangesAsync(); + } } - } - public async Task DeleteExpiredSponsorshipsAsync(DateTime validUntilBeforeDate) - { - using (var scope = ServiceScopeFactory.CreateScope()) + public async Task DeleteExpiredSponsorshipsAsync(DateTime validUntilBeforeDate) { - var dbContext = GetDatabaseContext(scope); - var query = from s in dbContext.OrganizationSponsorships - where s.ValidUntil < validUntilBeforeDate - select s; - dbContext.RemoveRange(query); - await dbContext.SaveChangesAsync(); + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var query = from s in dbContext.OrganizationSponsorships + where s.ValidUntil < validUntilBeforeDate + select s; + dbContext.RemoveRange(query); + await dbContext.SaveChangesAsync(); + } } - } - public Task DisableCipherAutoStatsAsync() - { - return Task.CompletedTask; - } + public Task DisableCipherAutoStatsAsync() + { + return Task.CompletedTask; + } - public Task RebuildIndexesAsync() - { - return Task.CompletedTask; - } + public Task RebuildIndexesAsync() + { + return Task.CompletedTask; + } - public Task UpdateStatisticsAsync() - { - return Task.CompletedTask; + public Task UpdateStatisticsAsync() + { + return Task.CompletedTask; + } } } diff --git a/src/Infrastructure.EntityFramework/Repositories/OrganizationApiKeyRepository.cs b/src/Infrastructure.EntityFramework/Repositories/OrganizationApiKeyRepository.cs index 52cf3d5e6a..8bc462adf7 100644 --- a/src/Infrastructure.EntityFramework/Repositories/OrganizationApiKeyRepository.cs +++ b/src/Infrastructure.EntityFramework/Repositories/OrganizationApiKeyRepository.cs @@ -5,25 +5,26 @@ using Bit.Core.Repositories; using Microsoft.EntityFrameworkCore; using Microsoft.Extensions.DependencyInjection; -namespace Bit.Infrastructure.EntityFramework.Repositories; - -public class OrganizationApiKeyRepository : Repository, IOrganizationApiKeyRepository +namespace Bit.Infrastructure.EntityFramework.Repositories { - public OrganizationApiKeyRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) - : base(serviceScopeFactory, mapper, db => db.OrganizationApiKeys) + public class OrganizationApiKeyRepository : Repository, IOrganizationApiKeyRepository { - - } - - public async Task> GetManyByOrganizationIdTypeAsync(Guid organizationId, OrganizationApiKeyType? type = null) - { - using (var scope = ServiceScopeFactory.CreateScope()) + public OrganizationApiKeyRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) + : base(serviceScopeFactory, mapper, db => db.OrganizationApiKeys) { - var dbContext = GetDatabaseContext(scope); - var apiKeys = await dbContext.OrganizationApiKeys - .Where(o => o.OrganizationId == organizationId && (type == null || o.Type == type)) - .ToListAsync(); - return Mapper.Map>(apiKeys); + + } + + public async Task> GetManyByOrganizationIdTypeAsync(Guid organizationId, OrganizationApiKeyType? type = null) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var apiKeys = await dbContext.OrganizationApiKeys + .Where(o => o.OrganizationId == organizationId && (type == null || o.Type == type)) + .ToListAsync(); + return Mapper.Map>(apiKeys); + } } } } diff --git a/src/Infrastructure.EntityFramework/Repositories/OrganizationConnectionRepository.cs b/src/Infrastructure.EntityFramework/Repositories/OrganizationConnectionRepository.cs index 298e28e029..5acd8807db 100644 --- a/src/Infrastructure.EntityFramework/Repositories/OrganizationConnectionRepository.cs +++ b/src/Infrastructure.EntityFramework/Repositories/OrganizationConnectionRepository.cs @@ -5,37 +5,38 @@ using Bit.Core.Repositories; using Microsoft.EntityFrameworkCore; using Microsoft.Extensions.DependencyInjection; -namespace Bit.Infrastructure.EntityFramework.Repositories; - -public class OrganizationConnectionRepository : Repository, IOrganizationConnectionRepository +namespace Bit.Infrastructure.EntityFramework.Repositories { - public OrganizationConnectionRepository(IServiceScopeFactory serviceScopeFactory, - IMapper mapper) - : base(serviceScopeFactory, mapper, context => context.OrganizationConnections) + public class OrganizationConnectionRepository : Repository, IOrganizationConnectionRepository { - } - - public async Task> GetByOrganizationIdTypeAsync(Guid organizationId, OrganizationConnectionType type) - { - using (var scope = ServiceScopeFactory.CreateScope()) + public OrganizationConnectionRepository(IServiceScopeFactory serviceScopeFactory, + IMapper mapper) + : base(serviceScopeFactory, mapper, context => context.OrganizationConnections) { - var dbContext = GetDatabaseContext(scope); - var connections = await dbContext.OrganizationConnections - .Where(oc => oc.OrganizationId == organizationId && oc.Type == type) - .ToListAsync(); - return Mapper.Map>(connections); } - } - public async Task> GetEnabledByOrganizationIdTypeAsync(Guid organizationId, OrganizationConnectionType type) - { - using (var scope = ServiceScopeFactory.CreateScope()) + public async Task> GetByOrganizationIdTypeAsync(Guid organizationId, OrganizationConnectionType type) { - var dbContext = GetDatabaseContext(scope); - var connections = await dbContext.OrganizationConnections - .Where(oc => oc.OrganizationId == organizationId && oc.Type == type && oc.Enabled) - .ToListAsync(); - return Mapper.Map>(connections); + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var connections = await dbContext.OrganizationConnections + .Where(oc => oc.OrganizationId == organizationId && oc.Type == type) + .ToListAsync(); + return Mapper.Map>(connections); + } + } + + public async Task> GetEnabledByOrganizationIdTypeAsync(Guid organizationId, OrganizationConnectionType type) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var connections = await dbContext.OrganizationConnections + .Where(oc => oc.OrganizationId == organizationId && oc.Type == type && oc.Enabled) + .ToListAsync(); + return Mapper.Map>(connections); + } } } } diff --git a/src/Infrastructure.EntityFramework/Repositories/OrganizationRepository.cs b/src/Infrastructure.EntityFramework/Repositories/OrganizationRepository.cs index bd60b53e8e..b12ff65bfd 100644 --- a/src/Infrastructure.EntityFramework/Repositories/OrganizationRepository.cs +++ b/src/Infrastructure.EntityFramework/Repositories/OrganizationRepository.cs @@ -5,104 +5,105 @@ using Bit.Infrastructure.EntityFramework.Models; using Microsoft.EntityFrameworkCore; using Microsoft.Extensions.DependencyInjection; -namespace Bit.Infrastructure.EntityFramework.Repositories; - -public class OrganizationRepository : Repository, IOrganizationRepository +namespace Bit.Infrastructure.EntityFramework.Repositories { - public OrganizationRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) - : base(serviceScopeFactory, mapper, (DatabaseContext context) => context.Organizations) - { } - - public async Task GetByIdentifierAsync(string identifier) + public class OrganizationRepository : Repository, IOrganizationRepository { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var organization = await GetDbSet(dbContext).Where(e => e.Identifier == identifier) - .FirstOrDefaultAsync(); - return organization; - } - } + public OrganizationRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) + : base(serviceScopeFactory, mapper, (DatabaseContext context) => context.Organizations) + { } - public async Task> GetManyByEnabledAsync() - { - using (var scope = ServiceScopeFactory.CreateScope()) + public async Task GetByIdentifierAsync(string identifier) { - var dbContext = GetDatabaseContext(scope); - var organizations = await GetDbSet(dbContext).Where(e => e.Enabled).ToListAsync(); - return Mapper.Map>(organizations); - } - } - - public async Task> GetManyByUserIdAsync(Guid userId) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var organizations = await GetDbSet(dbContext) - .Select(e => e.OrganizationUsers - .Where(ou => ou.UserId == userId) - .Select(ou => ou.Organization)) - .ToListAsync(); - return Mapper.Map>(organizations); - } - } - - public async Task> SearchAsync(string name, string userEmail, - bool? paid, int skip, int take) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var organizations = await GetDbSet(dbContext) - .Where(e => name == null || e.Name.Contains(name)) - .Where(e => userEmail == null || e.OrganizationUsers.Any(u => u.Email == userEmail)) - .Where(e => paid == null || - (paid == true && !string.IsNullOrWhiteSpace(e.GatewaySubscriptionId)) || - (paid == false && e.GatewaySubscriptionId == null)) - .OrderBy(e => e.CreationDate) - .Skip(skip).Take(take) - .ToListAsync(); - return Mapper.Map>(organizations); - } - } - - public async Task> GetManyAbilitiesAsync() - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - return await GetDbSet(dbContext) - .Select(e => new OrganizationAbility + using (var scope = ServiceScopeFactory.CreateScope()) { - Enabled = e.Enabled, - Id = e.Id, - Use2fa = e.Use2fa, - UseEvents = e.UseEvents, - UsersGetPremium = e.UsersGetPremium, - Using2fa = e.Use2fa && e.TwoFactorProviders != null, - UseSso = e.UseSso, - UseKeyConnector = e.UseKeyConnector, - UseResetPassword = e.UseResetPassword, - UseScim = e.UseScim, - }).ToListAsync(); + var dbContext = GetDatabaseContext(scope); + var organization = await GetDbSet(dbContext).Where(e => e.Identifier == identifier) + .FirstOrDefaultAsync(); + return organization; + } } - } - public async Task UpdateStorageAsync(Guid id) - { - await OrganizationUpdateStorage(id); - } - - public override async Task DeleteAsync(Core.Entities.Organization organization) - { - using (var scope = ServiceScopeFactory.CreateScope()) + public async Task> GetManyByEnabledAsync() { - var dbContext = GetDatabaseContext(scope); - var orgEntity = await dbContext.FindAsync(organization.Id); + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var organizations = await GetDbSet(dbContext).Where(e => e.Enabled).ToListAsync(); + return Mapper.Map>(organizations); + } + } - dbContext.Remove(orgEntity); - await dbContext.SaveChangesAsync(); + public async Task> GetManyByUserIdAsync(Guid userId) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var organizations = await GetDbSet(dbContext) + .Select(e => e.OrganizationUsers + .Where(ou => ou.UserId == userId) + .Select(ou => ou.Organization)) + .ToListAsync(); + return Mapper.Map>(organizations); + } + } + + public async Task> SearchAsync(string name, string userEmail, + bool? paid, int skip, int take) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var organizations = await GetDbSet(dbContext) + .Where(e => name == null || e.Name.Contains(name)) + .Where(e => userEmail == null || e.OrganizationUsers.Any(u => u.Email == userEmail)) + .Where(e => paid == null || + (paid == true && !string.IsNullOrWhiteSpace(e.GatewaySubscriptionId)) || + (paid == false && e.GatewaySubscriptionId == null)) + .OrderBy(e => e.CreationDate) + .Skip(skip).Take(take) + .ToListAsync(); + return Mapper.Map>(organizations); + } + } + + public async Task> GetManyAbilitiesAsync() + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + return await GetDbSet(dbContext) + .Select(e => new OrganizationAbility + { + Enabled = e.Enabled, + Id = e.Id, + Use2fa = e.Use2fa, + UseEvents = e.UseEvents, + UsersGetPremium = e.UsersGetPremium, + Using2fa = e.Use2fa && e.TwoFactorProviders != null, + UseSso = e.UseSso, + UseKeyConnector = e.UseKeyConnector, + UseResetPassword = e.UseResetPassword, + UseScim = e.UseScim, + }).ToListAsync(); + } + } + + public async Task UpdateStorageAsync(Guid id) + { + await OrganizationUpdateStorage(id); + } + + public override async Task DeleteAsync(Core.Entities.Organization organization) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var orgEntity = await dbContext.FindAsync(organization.Id); + + dbContext.Remove(orgEntity); + await dbContext.SaveChangesAsync(); + } } } } diff --git a/src/Infrastructure.EntityFramework/Repositories/OrganizationSponsorshipRepository.cs b/src/Infrastructure.EntityFramework/Repositories/OrganizationSponsorshipRepository.cs index de0af89dfb..9e00d924d1 100644 --- a/src/Infrastructure.EntityFramework/Repositories/OrganizationSponsorshipRepository.cs +++ b/src/Infrastructure.EntityFramework/Repositories/OrganizationSponsorshipRepository.cs @@ -4,137 +4,138 @@ using Bit.Infrastructure.EntityFramework.Models; using Microsoft.EntityFrameworkCore; using Microsoft.Extensions.DependencyInjection; -namespace Bit.Infrastructure.EntityFramework.Repositories; - -public class OrganizationSponsorshipRepository : Repository, IOrganizationSponsorshipRepository +namespace Bit.Infrastructure.EntityFramework.Repositories { - public OrganizationSponsorshipRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) - : base(serviceScopeFactory, mapper, (DatabaseContext context) => context.OrganizationSponsorships) - { } - - public async Task> CreateManyAsync(IEnumerable organizationSponsorships) + public class OrganizationSponsorshipRepository : Repository, IOrganizationSponsorshipRepository { - if (!organizationSponsorships.Any()) - { - return new List(); - } + public OrganizationSponsorshipRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) + : base(serviceScopeFactory, mapper, (DatabaseContext context) => context.OrganizationSponsorships) + { } - foreach (var organizationSponsorship in organizationSponsorships) + public async Task> CreateManyAsync(IEnumerable organizationSponsorships) { - organizationSponsorship.SetNewId(); - } - - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var entities = Mapper.Map>(organizationSponsorships); - await dbContext.AddRangeAsync(entities); - await dbContext.SaveChangesAsync(); - } - - return organizationSponsorships.Select(u => u.Id).ToList(); - } - - public async Task ReplaceManyAsync(IEnumerable organizationSponsorships) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - dbContext.UpdateRange(organizationSponsorships); - await dbContext.SaveChangesAsync(); - } - } - - public async Task UpsertManyAsync(IEnumerable organizationSponsorships) - { - var createSponsorships = new List(); - var replaceSponsorships = new List(); - foreach (var organizationSponsorship in organizationSponsorships) - { - if (organizationSponsorship.Id.Equals(default)) + if (!organizationSponsorships.Any()) { - createSponsorships.Add(organizationSponsorship); + return new List(); } - else + + foreach (var organizationSponsorship in organizationSponsorships) { - replaceSponsorships.Add(organizationSponsorship); + organizationSponsorship.SetNewId(); + } + + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var entities = Mapper.Map>(organizationSponsorships); + await dbContext.AddRangeAsync(entities); + await dbContext.SaveChangesAsync(); + } + + return organizationSponsorships.Select(u => u.Id).ToList(); + } + + public async Task ReplaceManyAsync(IEnumerable organizationSponsorships) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + dbContext.UpdateRange(organizationSponsorships); + await dbContext.SaveChangesAsync(); } } - await CreateManyAsync(createSponsorships); - await ReplaceManyAsync(replaceSponsorships); - } - - public async Task DeleteManyAsync(IEnumerable organizationSponsorshipIds) - { - using (var scope = ServiceScopeFactory.CreateScope()) + public async Task UpsertManyAsync(IEnumerable organizationSponsorships) { - var dbContext = GetDatabaseContext(scope); - var entities = await dbContext.OrganizationSponsorships - .Where(os => organizationSponsorshipIds.Contains(os.Id)) - .ToListAsync(); + var createSponsorships = new List(); + var replaceSponsorships = new List(); + foreach (var organizationSponsorship in organizationSponsorships) + { + if (organizationSponsorship.Id.Equals(default)) + { + createSponsorships.Add(organizationSponsorship); + } + else + { + replaceSponsorships.Add(organizationSponsorship); + } + } - dbContext.OrganizationSponsorships.RemoveRange(entities); - await dbContext.SaveChangesAsync(); + await CreateManyAsync(createSponsorships); + await ReplaceManyAsync(replaceSponsorships); } - } - public async Task GetByOfferedToEmailAsync(string email) - { - using (var scope = ServiceScopeFactory.CreateScope()) + public async Task DeleteManyAsync(IEnumerable organizationSponsorshipIds) { - var dbContext = GetDatabaseContext(scope); - var orgSponsorship = await GetDbSet(dbContext).Where(e => e.OfferedToEmail == email) - .FirstOrDefaultAsync(); - return orgSponsorship; - } - } + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var entities = await dbContext.OrganizationSponsorships + .Where(os => organizationSponsorshipIds.Contains(os.Id)) + .ToListAsync(); - public async Task GetBySponsoredOrganizationIdAsync(Guid sponsoredOrganizationId) - { - using (var scope = ServiceScopeFactory.CreateScope()) + dbContext.OrganizationSponsorships.RemoveRange(entities); + await dbContext.SaveChangesAsync(); + } + } + + public async Task GetByOfferedToEmailAsync(string email) { - var dbContext = GetDatabaseContext(scope); - var orgSponsorship = await GetDbSet(dbContext).Where(e => e.SponsoredOrganizationId == sponsoredOrganizationId) - .FirstOrDefaultAsync(); - return orgSponsorship; + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var orgSponsorship = await GetDbSet(dbContext).Where(e => e.OfferedToEmail == email) + .FirstOrDefaultAsync(); + return orgSponsorship; + } } - } - public async Task GetBySponsoringOrganizationUserIdAsync(Guid sponsoringOrganizationUserId) - { - using (var scope = ServiceScopeFactory.CreateScope()) + public async Task GetBySponsoredOrganizationIdAsync(Guid sponsoredOrganizationId) { - var dbContext = GetDatabaseContext(scope); - var orgSponsorship = await GetDbSet(dbContext).Where(e => e.SponsoringOrganizationUserId == sponsoringOrganizationUserId) - .FirstOrDefaultAsync(); - return orgSponsorship; + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var orgSponsorship = await GetDbSet(dbContext).Where(e => e.SponsoredOrganizationId == sponsoredOrganizationId) + .FirstOrDefaultAsync(); + return orgSponsorship; + } } - } - public async Task GetLatestSyncDateBySponsoringOrganizationIdAsync(Guid sponsoringOrganizationId) - { - using (var scope = ServiceScopeFactory.CreateScope()) + public async Task GetBySponsoringOrganizationUserIdAsync(Guid sponsoringOrganizationUserId) { - var dbContext = GetDatabaseContext(scope); - return await GetDbSet(dbContext).Where(e => e.SponsoringOrganizationId == sponsoringOrganizationId && e.LastSyncDate != null) - .OrderByDescending(e => e.LastSyncDate) - .Select(e => e.LastSyncDate) - .FirstOrDefaultAsync(); - + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var orgSponsorship = await GetDbSet(dbContext).Where(e => e.SponsoringOrganizationUserId == sponsoringOrganizationUserId) + .FirstOrDefaultAsync(); + return orgSponsorship; + } } - } - public async Task> GetManyBySponsoringOrganizationAsync(Guid sponsoringOrganizationId) - { - using (var scope = ServiceScopeFactory.CreateScope()) + public async Task GetLatestSyncDateBySponsoringOrganizationIdAsync(Guid sponsoringOrganizationId) { - var dbContext = GetDatabaseContext(scope); - var query = from os in dbContext.OrganizationSponsorships - where os.SponsoringOrganizationId == sponsoringOrganizationId - select os; - return Mapper.Map>(await query.ToListAsync()); - } - } + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + return await GetDbSet(dbContext).Where(e => e.SponsoringOrganizationId == sponsoringOrganizationId && e.LastSyncDate != null) + .OrderByDescending(e => e.LastSyncDate) + .Select(e => e.LastSyncDate) + .FirstOrDefaultAsync(); + } + } + + public async Task> GetManyBySponsoringOrganizationAsync(Guid sponsoringOrganizationId) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var query = from os in dbContext.OrganizationSponsorships + where os.SponsoringOrganizationId == sponsoringOrganizationId + select os; + return Mapper.Map>(await query.ToListAsync()); + } + } + + } } diff --git a/src/Infrastructure.EntityFramework/Repositories/OrganizationUserRepository.cs b/src/Infrastructure.EntityFramework/Repositories/OrganizationUserRepository.cs index 70f4401eed..0c0383182e 100644 --- a/src/Infrastructure.EntityFramework/Repositories/OrganizationUserRepository.cs +++ b/src/Infrastructure.EntityFramework/Repositories/OrganizationUserRepository.cs @@ -8,455 +8,456 @@ using Bit.Infrastructure.EntityFramework.Repositories.Queries; using Microsoft.EntityFrameworkCore; using Microsoft.Extensions.DependencyInjection; -namespace Bit.Infrastructure.EntityFramework.Repositories; - -public class OrganizationUserRepository : Repository, IOrganizationUserRepository +namespace Bit.Infrastructure.EntityFramework.Repositories { - public OrganizationUserRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) - : base(serviceScopeFactory, mapper, (DatabaseContext context) => context.OrganizationUsers) - { } - - public async Task CreateAsync(Core.Entities.OrganizationUser obj, IEnumerable collections) + public class OrganizationUserRepository : Repository, IOrganizationUserRepository { - var organizationUser = await base.CreateAsync(obj); - using (var scope = ServiceScopeFactory.CreateScope()) + public OrganizationUserRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) + : base(serviceScopeFactory, mapper, (DatabaseContext context) => context.OrganizationUsers) + { } + + public async Task CreateAsync(Core.Entities.OrganizationUser obj, IEnumerable collections) { - var dbContext = GetDatabaseContext(scope); - var availibleCollections = await ( - from c in dbContext.Collections - where c.OrganizationId == organizationUser.OrganizationId - select c).ToListAsync(); - var filteredCollections = collections.Where(c => availibleCollections.Any(a => c.Id == a.Id)); - var collectionUsers = filteredCollections.Select(y => new CollectionUser + var organizationUser = await base.CreateAsync(obj); + using (var scope = ServiceScopeFactory.CreateScope()) { - CollectionId = y.Id, - OrganizationUserId = organizationUser.Id, - ReadOnly = y.ReadOnly, - HidePasswords = y.HidePasswords, - }); - await dbContext.CollectionUsers.AddRangeAsync(collectionUsers); - await dbContext.SaveChangesAsync(); - } - - return organizationUser.Id; - } - - public async Task> CreateManyAsync(IEnumerable organizationUsers) - { - if (!organizationUsers.Any()) - { - return new List(); - } - - foreach (var organizationUser in organizationUsers) - { - organizationUser.SetNewId(); - } - - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var entities = Mapper.Map>(organizationUsers); - await dbContext.AddRangeAsync(entities); - await dbContext.SaveChangesAsync(); - } - - return organizationUsers.Select(u => u.Id).ToList(); - } - - public override async Task DeleteAsync(Core.Entities.OrganizationUser organizationUser) => await DeleteAsync(organizationUser.Id); - public async Task DeleteAsync(Guid organizationUserId) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var orgUser = await dbContext.FindAsync(organizationUserId); - - dbContext.Remove(orgUser); - await dbContext.SaveChangesAsync(); - } - } - - public async Task DeleteManyAsync(IEnumerable organizationUserIds) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var entities = await dbContext.OrganizationUsers - .Where(ou => organizationUserIds.Contains(ou.Id)) - .ToListAsync(); - - dbContext.OrganizationUsers.RemoveRange(entities); - await dbContext.SaveChangesAsync(); - } - } - - public async Task>> GetByIdWithCollectionsAsync(Guid id) - { - var organizationUser = await base.GetByIdAsync(id); - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var query = await ( - from ou in dbContext.OrganizationUsers - join cu in dbContext.CollectionUsers - on ou.Id equals cu.OrganizationUserId - where !ou.AccessAll && - ou.Id == id - select cu).ToListAsync(); - var collections = query.Select(cu => new SelectionReadOnly - { - Id = cu.CollectionId, - ReadOnly = cu.ReadOnly, - HidePasswords = cu.HidePasswords, - }); - return new Tuple>( - organizationUser, collections.ToList()); - } - } - - public async Task GetByOrganizationAsync(Guid organizationId, Guid userId) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var entity = await GetDbSet(dbContext) - .FirstOrDefaultAsync(e => e.OrganizationId == organizationId && e.UserId == userId); - return entity; - } - } - - public async Task GetByOrganizationEmailAsync(Guid organizationId, string email) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var entity = await GetDbSet(dbContext) - .FirstOrDefaultAsync(ou => ou.OrganizationId == organizationId && - !string.IsNullOrWhiteSpace(ou.Email) && - ou.Email == email); - return entity; - } - } - - public async Task GetCountByFreeOrganizationAdminUserAsync(Guid userId) - { - var query = new OrganizationUserReadCountByFreeOrganizationAdminUserQuery(userId); - return await GetCountFromQuery(query); - } - - public async Task GetCountByOnlyOwnerAsync(Guid userId) - { - var query = new OrganizationUserReadCountByOnlyOwnerQuery(userId); - return await GetCountFromQuery(query); - } - - public async Task GetCountByOrganizationAsync(Guid organizationId, string email, bool onlyRegisteredUsers) - { - var query = new OrganizationUserReadCountByOrganizationIdEmailQuery(organizationId, email, onlyRegisteredUsers); - return await GetCountFromQuery(query); - } - - public async Task GetCountByOrganizationIdAsync(Guid organizationId) - { - var query = new OrganizationUserReadCountByOrganizationIdQuery(organizationId); - return await GetCountFromQuery(query); - } - - public async Task GetDetailsByIdAsync(Guid id) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var view = new OrganizationUserUserDetailsViewQuery(); - var entity = await view.Run(dbContext).FirstOrDefaultAsync(ou => ou.Id == id); - return entity; - } - } - - public async Task>> GetDetailsByIdWithCollectionsAsync(Guid id) - { - var organizationUserUserDetails = await GetDetailsByIdAsync(id); - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var query = from ou in dbContext.OrganizationUsers - join cu in dbContext.CollectionUsers on ou.Id equals cu.OrganizationUserId - where !ou.AccessAll && ou.Id == id - select cu; - var collections = await query.Select(cu => new SelectionReadOnly - { - Id = cu.CollectionId, - ReadOnly = cu.ReadOnly, - HidePasswords = cu.HidePasswords, - }).ToListAsync(); - return new Tuple>(organizationUserUserDetails, collections); - } - } - - public async Task GetDetailsByUserAsync(Guid userId, Guid organizationId, OrganizationUserStatusType? status = null) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var view = new OrganizationUserOrganizationDetailsViewQuery(); - var t = await (view.Run(dbContext)).ToArrayAsync(); - var entity = await view.Run(dbContext) - .FirstOrDefaultAsync(o => o.UserId == userId && - o.OrganizationId == organizationId && - (status == null || o.Status == status)); - return entity; - } - } - - public async Task> GetManyAsync(IEnumerable Ids) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var query = from ou in dbContext.OrganizationUsers - where Ids.Contains(ou.Id) - select ou; - var data = await query.ToArrayAsync(); - return data; - } - } - - public async Task> GetManyByManyUsersAsync(IEnumerable userIds) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var query = from ou in dbContext.OrganizationUsers - where userIds.Contains(ou.Id) - select ou; - return Mapper.Map>(await query.ToListAsync()); - } - } - - public async Task> GetManyByOrganizationAsync(Guid organizationId, OrganizationUserType? type) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var query = from ou in dbContext.OrganizationUsers - where ou.OrganizationId == organizationId && - (type == null || ou.Type == type) - select ou; - return Mapper.Map>(await query.ToListAsync()); - } - } - - public async Task> GetManyByUserAsync(Guid userId) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var query = from ou in dbContext.OrganizationUsers - where ou.UserId == userId - select ou; - return Mapper.Map>(await query.ToListAsync()); - } - } - - public async Task> GetManyDetailsByOrganizationAsync(Guid organizationId) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var view = new OrganizationUserUserDetailsViewQuery(); - var query = from ou in view.Run(dbContext) - where ou.OrganizationId == organizationId - select ou; - return await query.ToListAsync(); - } - } - - public async Task> GetManyDetailsByUserAsync(Guid userId, - OrganizationUserStatusType? status = null) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var view = new OrganizationUserOrganizationDetailsViewQuery(); - var query = from ou in view.Run(dbContext) - where ou.UserId == userId && - (status == null || ou.Status == status) - select ou; - var organizationUsers = await query.ToListAsync(); - return organizationUsers; - } - } - - public async Task> GetManyPublicKeysByOrganizationUserAsync(Guid organizationId, IEnumerable Ids) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var query = from ou in dbContext.OrganizationUsers - where Ids.Contains(ou.Id) && ou.Status == OrganizationUserStatusType.Accepted - join u in dbContext.Users - on ou.UserId equals u.Id - where ou.OrganizationId == organizationId - select new { ou, u }; - var data = await query - .Select(x => new OrganizationUserPublicKey() + var dbContext = GetDatabaseContext(scope); + var availibleCollections = await ( + from c in dbContext.Collections + where c.OrganizationId == organizationUser.OrganizationId + select c).ToListAsync(); + var filteredCollections = collections.Where(c => availibleCollections.Any(a => c.Id == a.Id)); + var collectionUsers = filteredCollections.Select(y => new CollectionUser { - Id = x.ou.Id, - PublicKey = x.u.PublicKey, - }).ToListAsync(); - return data; - } - } - - public async Task ReplaceAsync(Core.Entities.OrganizationUser obj, IEnumerable collections) - { - await base.ReplaceAsync(obj); - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - - var procedure = new OrganizationUserUpdateWithCollectionsQuery(obj, collections); - - var update = procedure.Update.Run(dbContext); - dbContext.UpdateRange(await update.ToListAsync()); - - var insert = procedure.Insert.Run(dbContext); - await dbContext.AddRangeAsync(await insert.ToListAsync()); - - dbContext.RemoveRange(await procedure.Delete.Run(dbContext).ToListAsync()); - await dbContext.SaveChangesAsync(); - } - } - - public async Task ReplaceManyAsync(IEnumerable organizationUsers) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - dbContext.UpdateRange(organizationUsers); - await dbContext.SaveChangesAsync(); - await UserBumpManyAccountRevisionDates(organizationUsers - .Where(ou => ou.UserId.HasValue) - .Select(ou => ou.UserId.Value).ToArray()); - } - } - - public async Task> SelectKnownEmailsAsync(Guid organizationId, IEnumerable emails, bool onlyRegisteredUsers) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var usersQuery = from ou in dbContext.OrganizationUsers - join u in dbContext.Users - on ou.UserId equals u.Id into u_g - from u in u_g - where ou.OrganizationId == organizationId - select new { ou, u }; - var ouu = await usersQuery.ToListAsync(); - var ouEmails = ouu.Select(x => x.ou.Email); - var uEmails = ouu.Select(x => x.u.Email); - var knownEmails = from e in emails - where (ouEmails.Contains(e) || uEmails.Contains(e)) && - (!onlyRegisteredUsers && (uEmails.Contains(e) || ouEmails.Contains(e))) || - (onlyRegisteredUsers && uEmails.Contains(e)) - select e; - return knownEmails.ToList(); - } - } - - public async Task UpdateGroupsAsync(Guid orgUserId, IEnumerable groupIds) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - - var procedure = new GroupUserUpdateGroupsQuery(orgUserId, groupIds); - - var insert = procedure.Insert.Run(dbContext); - var data = await insert.ToListAsync(); - await dbContext.AddRangeAsync(data); - - var delete = procedure.Delete.Run(dbContext); - var deleteData = await delete.ToListAsync(); - dbContext.RemoveRange(deleteData); - await UserBumpAccountRevisionDateByOrganizationUserId(orgUserId); - await dbContext.SaveChangesAsync(); - } - } - - public async Task UpsertManyAsync(IEnumerable organizationUsers) - { - var createUsers = new List(); - var replaceUsers = new List(); - foreach (var organizationUser in organizationUsers) - { - if (organizationUser.Id.Equals(default)) - { - createUsers.Add(organizationUser); - } - else - { - replaceUsers.Add(organizationUser); - } - } - - await CreateManyAsync(createUsers); - await ReplaceManyAsync(replaceUsers); - } - - public async Task> GetManyByMinimumRoleAsync(Guid organizationId, OrganizationUserType minRole) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var query = dbContext.OrganizationUsers - .Include(e => e.User) - .Where(e => e.OrganizationId.Equals(organizationId) && - e.Type <= minRole && - e.Status == OrganizationUserStatusType.Confirmed) - .Select(e => new OrganizationUserUserDetails() - { - Id = e.Id, - Email = e.Email ?? e.User.Email + CollectionId = y.Id, + OrganizationUserId = organizationUser.Id, + ReadOnly = y.ReadOnly, + HidePasswords = y.HidePasswords, }); - return await query.ToListAsync(); - } - } - - public async Task RevokeAsync(Guid id) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var orgUser = await GetDbSet(dbContext).FindAsync(id); - if (orgUser != null) - { - dbContext.Update(orgUser); - orgUser.Status = OrganizationUserStatusType.Revoked; + await dbContext.CollectionUsers.AddRangeAsync(collectionUsers); await dbContext.SaveChangesAsync(); - if (orgUser.UserId.HasValue) + } + + return organizationUser.Id; + } + + public async Task> CreateManyAsync(IEnumerable organizationUsers) + { + if (!organizationUsers.Any()) + { + return new List(); + } + + foreach (var organizationUser in organizationUsers) + { + organizationUser.SetNewId(); + } + + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var entities = Mapper.Map>(organizationUsers); + await dbContext.AddRangeAsync(entities); + await dbContext.SaveChangesAsync(); + } + + return organizationUsers.Select(u => u.Id).ToList(); + } + + public override async Task DeleteAsync(Core.Entities.OrganizationUser organizationUser) => await DeleteAsync(organizationUser.Id); + public async Task DeleteAsync(Guid organizationUserId) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var orgUser = await dbContext.FindAsync(organizationUserId); + + dbContext.Remove(orgUser); + await dbContext.SaveChangesAsync(); + } + } + + public async Task DeleteManyAsync(IEnumerable organizationUserIds) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var entities = await dbContext.OrganizationUsers + .Where(ou => organizationUserIds.Contains(ou.Id)) + .ToListAsync(); + + dbContext.OrganizationUsers.RemoveRange(entities); + await dbContext.SaveChangesAsync(); + } + } + + public async Task>> GetByIdWithCollectionsAsync(Guid id) + { + var organizationUser = await base.GetByIdAsync(id); + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var query = await ( + from ou in dbContext.OrganizationUsers + join cu in dbContext.CollectionUsers + on ou.Id equals cu.OrganizationUserId + where !ou.AccessAll && + ou.Id == id + select cu).ToListAsync(); + var collections = query.Select(cu => new SelectionReadOnly { - await UserBumpAccountRevisionDate(orgUser.UserId.Value); + Id = cu.CollectionId, + ReadOnly = cu.ReadOnly, + HidePasswords = cu.HidePasswords, + }); + return new Tuple>( + organizationUser, collections.ToList()); + } + } + + public async Task GetByOrganizationAsync(Guid organizationId, Guid userId) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var entity = await GetDbSet(dbContext) + .FirstOrDefaultAsync(e => e.OrganizationId == organizationId && e.UserId == userId); + return entity; + } + } + + public async Task GetByOrganizationEmailAsync(Guid organizationId, string email) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var entity = await GetDbSet(dbContext) + .FirstOrDefaultAsync(ou => ou.OrganizationId == organizationId && + !string.IsNullOrWhiteSpace(ou.Email) && + ou.Email == email); + return entity; + } + } + + public async Task GetCountByFreeOrganizationAdminUserAsync(Guid userId) + { + var query = new OrganizationUserReadCountByFreeOrganizationAdminUserQuery(userId); + return await GetCountFromQuery(query); + } + + public async Task GetCountByOnlyOwnerAsync(Guid userId) + { + var query = new OrganizationUserReadCountByOnlyOwnerQuery(userId); + return await GetCountFromQuery(query); + } + + public async Task GetCountByOrganizationAsync(Guid organizationId, string email, bool onlyRegisteredUsers) + { + var query = new OrganizationUserReadCountByOrganizationIdEmailQuery(organizationId, email, onlyRegisteredUsers); + return await GetCountFromQuery(query); + } + + public async Task GetCountByOrganizationIdAsync(Guid organizationId) + { + var query = new OrganizationUserReadCountByOrganizationIdQuery(organizationId); + return await GetCountFromQuery(query); + } + + public async Task GetDetailsByIdAsync(Guid id) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var view = new OrganizationUserUserDetailsViewQuery(); + var entity = await view.Run(dbContext).FirstOrDefaultAsync(ou => ou.Id == id); + return entity; + } + } + + public async Task>> GetDetailsByIdWithCollectionsAsync(Guid id) + { + var organizationUserUserDetails = await GetDetailsByIdAsync(id); + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var query = from ou in dbContext.OrganizationUsers + join cu in dbContext.CollectionUsers on ou.Id equals cu.OrganizationUserId + where !ou.AccessAll && ou.Id == id + select cu; + var collections = await query.Select(cu => new SelectionReadOnly + { + Id = cu.CollectionId, + ReadOnly = cu.ReadOnly, + HidePasswords = cu.HidePasswords, + }).ToListAsync(); + return new Tuple>(organizationUserUserDetails, collections); + } + } + + public async Task GetDetailsByUserAsync(Guid userId, Guid organizationId, OrganizationUserStatusType? status = null) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var view = new OrganizationUserOrganizationDetailsViewQuery(); + var t = await (view.Run(dbContext)).ToArrayAsync(); + var entity = await view.Run(dbContext) + .FirstOrDefaultAsync(o => o.UserId == userId && + o.OrganizationId == organizationId && + (status == null || o.Status == status)); + return entity; + } + } + + public async Task> GetManyAsync(IEnumerable Ids) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var query = from ou in dbContext.OrganizationUsers + where Ids.Contains(ou.Id) + select ou; + var data = await query.ToArrayAsync(); + return data; + } + } + + public async Task> GetManyByManyUsersAsync(IEnumerable userIds) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var query = from ou in dbContext.OrganizationUsers + where userIds.Contains(ou.Id) + select ou; + return Mapper.Map>(await query.ToListAsync()); + } + } + + public async Task> GetManyByOrganizationAsync(Guid organizationId, OrganizationUserType? type) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var query = from ou in dbContext.OrganizationUsers + where ou.OrganizationId == organizationId && + (type == null || ou.Type == type) + select ou; + return Mapper.Map>(await query.ToListAsync()); + } + } + + public async Task> GetManyByUserAsync(Guid userId) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var query = from ou in dbContext.OrganizationUsers + where ou.UserId == userId + select ou; + return Mapper.Map>(await query.ToListAsync()); + } + } + + public async Task> GetManyDetailsByOrganizationAsync(Guid organizationId) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var view = new OrganizationUserUserDetailsViewQuery(); + var query = from ou in view.Run(dbContext) + where ou.OrganizationId == organizationId + select ou; + return await query.ToListAsync(); + } + } + + public async Task> GetManyDetailsByUserAsync(Guid userId, + OrganizationUserStatusType? status = null) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var view = new OrganizationUserOrganizationDetailsViewQuery(); + var query = from ou in view.Run(dbContext) + where ou.UserId == userId && + (status == null || ou.Status == status) + select ou; + var organizationUsers = await query.ToListAsync(); + return organizationUsers; + } + } + + public async Task> GetManyPublicKeysByOrganizationUserAsync(Guid organizationId, IEnumerable Ids) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var query = from ou in dbContext.OrganizationUsers + where Ids.Contains(ou.Id) && ou.Status == OrganizationUserStatusType.Accepted + join u in dbContext.Users + on ou.UserId equals u.Id + where ou.OrganizationId == organizationId + select new { ou, u }; + var data = await query + .Select(x => new OrganizationUserPublicKey() + { + Id = x.ou.Id, + PublicKey = x.u.PublicKey, + }).ToListAsync(); + return data; + } + } + + public async Task ReplaceAsync(Core.Entities.OrganizationUser obj, IEnumerable collections) + { + await base.ReplaceAsync(obj); + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + + var procedure = new OrganizationUserUpdateWithCollectionsQuery(obj, collections); + + var update = procedure.Update.Run(dbContext); + dbContext.UpdateRange(await update.ToListAsync()); + + var insert = procedure.Insert.Run(dbContext); + await dbContext.AddRangeAsync(await insert.ToListAsync()); + + dbContext.RemoveRange(await procedure.Delete.Run(dbContext).ToListAsync()); + await dbContext.SaveChangesAsync(); + } + } + + public async Task ReplaceManyAsync(IEnumerable organizationUsers) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + dbContext.UpdateRange(organizationUsers); + await dbContext.SaveChangesAsync(); + await UserBumpManyAccountRevisionDates(organizationUsers + .Where(ou => ou.UserId.HasValue) + .Select(ou => ou.UserId.Value).ToArray()); + } + } + + public async Task> SelectKnownEmailsAsync(Guid organizationId, IEnumerable emails, bool onlyRegisteredUsers) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var usersQuery = from ou in dbContext.OrganizationUsers + join u in dbContext.Users + on ou.UserId equals u.Id into u_g + from u in u_g + where ou.OrganizationId == organizationId + select new { ou, u }; + var ouu = await usersQuery.ToListAsync(); + var ouEmails = ouu.Select(x => x.ou.Email); + var uEmails = ouu.Select(x => x.u.Email); + var knownEmails = from e in emails + where (ouEmails.Contains(e) || uEmails.Contains(e)) && + (!onlyRegisteredUsers && (uEmails.Contains(e) || ouEmails.Contains(e))) || + (onlyRegisteredUsers && uEmails.Contains(e)) + select e; + return knownEmails.ToList(); + } + } + + public async Task UpdateGroupsAsync(Guid orgUserId, IEnumerable groupIds) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + + var procedure = new GroupUserUpdateGroupsQuery(orgUserId, groupIds); + + var insert = procedure.Insert.Run(dbContext); + var data = await insert.ToListAsync(); + await dbContext.AddRangeAsync(data); + + var delete = procedure.Delete.Run(dbContext); + var deleteData = await delete.ToListAsync(); + dbContext.RemoveRange(deleteData); + await UserBumpAccountRevisionDateByOrganizationUserId(orgUserId); + await dbContext.SaveChangesAsync(); + } + } + + public async Task UpsertManyAsync(IEnumerable organizationUsers) + { + var createUsers = new List(); + var replaceUsers = new List(); + foreach (var organizationUser in organizationUsers) + { + if (organizationUser.Id.Equals(default)) + { + createUsers.Add(organizationUser); + } + else + { + replaceUsers.Add(organizationUser); + } + } + + await CreateManyAsync(createUsers); + await ReplaceManyAsync(replaceUsers); + } + + public async Task> GetManyByMinimumRoleAsync(Guid organizationId, OrganizationUserType minRole) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var query = dbContext.OrganizationUsers + .Include(e => e.User) + .Where(e => e.OrganizationId.Equals(organizationId) && + e.Type <= minRole && + e.Status == OrganizationUserStatusType.Confirmed) + .Select(e => new OrganizationUserUserDetails() + { + Id = e.Id, + Email = e.Email ?? e.User.Email + }); + return await query.ToListAsync(); + } + } + + public async Task RevokeAsync(Guid id) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var orgUser = await GetDbSet(dbContext).FindAsync(id); + if (orgUser != null) + { + dbContext.Update(orgUser); + orgUser.Status = OrganizationUserStatusType.Revoked; + await dbContext.SaveChangesAsync(); + if (orgUser.UserId.HasValue) + { + await UserBumpAccountRevisionDate(orgUser.UserId.Value); + } } } } - } - public async Task RestoreAsync(Guid id, OrganizationUserStatusType status) - { - using (var scope = ServiceScopeFactory.CreateScope()) + public async Task RestoreAsync(Guid id, OrganizationUserStatusType status) { - var dbContext = GetDatabaseContext(scope); - var orgUser = await GetDbSet(dbContext).FindAsync(id); - if (orgUser != null) + using (var scope = ServiceScopeFactory.CreateScope()) { - dbContext.Update(orgUser); - orgUser.Status = status; - await dbContext.SaveChangesAsync(); - if (orgUser.UserId.HasValue) + var dbContext = GetDatabaseContext(scope); + var orgUser = await GetDbSet(dbContext).FindAsync(id); + if (orgUser != null) { - await UserBumpAccountRevisionDate(orgUser.UserId.Value); + dbContext.Update(orgUser); + orgUser.Status = status; + await dbContext.SaveChangesAsync(); + if (orgUser.UserId.HasValue) + { + await UserBumpAccountRevisionDate(orgUser.UserId.Value); + } } } } diff --git a/src/Infrastructure.EntityFramework/Repositories/PolicyRepository.cs b/src/Infrastructure.EntityFramework/Repositories/PolicyRepository.cs index 1a02c6aa7c..8d6a928096 100644 --- a/src/Infrastructure.EntityFramework/Repositories/PolicyRepository.cs +++ b/src/Infrastructure.EntityFramework/Repositories/PolicyRepository.cs @@ -6,71 +6,72 @@ using Bit.Infrastructure.EntityFramework.Repositories.Queries; using Microsoft.EntityFrameworkCore; using Microsoft.Extensions.DependencyInjection; -namespace Bit.Infrastructure.EntityFramework.Repositories; - -public class PolicyRepository : Repository, IPolicyRepository +namespace Bit.Infrastructure.EntityFramework.Repositories { - public PolicyRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) - : base(serviceScopeFactory, mapper, (DatabaseContext context) => context.Policies) - { } - - public async Task GetByOrganizationIdTypeAsync(Guid organizationId, PolicyType type) + public class PolicyRepository : Repository, IPolicyRepository { - using (var scope = ServiceScopeFactory.CreateScope()) + public PolicyRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) + : base(serviceScopeFactory, mapper, (DatabaseContext context) => context.Policies) + { } + + public async Task GetByOrganizationIdTypeAsync(Guid organizationId, PolicyType type) { - var dbContext = GetDatabaseContext(scope); - var results = await dbContext.Policies - .FirstOrDefaultAsync(p => p.OrganizationId == organizationId && p.Type == type); - return Mapper.Map(results); + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var results = await dbContext.Policies + .FirstOrDefaultAsync(p => p.OrganizationId == organizationId && p.Type == type); + return Mapper.Map(results); + } } - } - public async Task> GetManyByOrganizationIdAsync(Guid organizationId) - { - using (var scope = ServiceScopeFactory.CreateScope()) + public async Task> GetManyByOrganizationIdAsync(Guid organizationId) { - var dbContext = GetDatabaseContext(scope); - var results = await dbContext.Policies - .Where(p => p.OrganizationId == organizationId) - .ToListAsync(); - return Mapper.Map>(results); + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var results = await dbContext.Policies + .Where(p => p.OrganizationId == organizationId) + .ToListAsync(); + return Mapper.Map>(results); + } } - } - public async Task> GetManyByUserIdAsync(Guid userId) - { - using (var scope = ServiceScopeFactory.CreateScope()) + public async Task> GetManyByUserIdAsync(Guid userId) { - var dbContext = GetDatabaseContext(scope); + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); - var query = new PolicyReadByUserIdQuery(userId); - var results = await query.Run(dbContext).ToListAsync(); - return Mapper.Map>(results); + var query = new PolicyReadByUserIdQuery(userId); + var results = await query.Run(dbContext).ToListAsync(); + return Mapper.Map>(results); + } } - } - public async Task> GetManyByTypeApplicableToUserIdAsync(Guid userId, PolicyType policyType, - OrganizationUserStatusType minStatus) - { - using (var scope = ServiceScopeFactory.CreateScope()) + public async Task> GetManyByTypeApplicableToUserIdAsync(Guid userId, PolicyType policyType, + OrganizationUserStatusType minStatus) { - var dbContext = GetDatabaseContext(scope); + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); - var query = new PolicyReadByTypeApplicableToUserQuery(userId, policyType, minStatus); - var results = await query.Run(dbContext).ToListAsync(); - return Mapper.Map>(results); + var query = new PolicyReadByTypeApplicableToUserQuery(userId, policyType, minStatus); + var results = await query.Run(dbContext).ToListAsync(); + return Mapper.Map>(results); + } } - } - public async Task GetCountByTypeApplicableToUserIdAsync(Guid userId, PolicyType policyType, - OrganizationUserStatusType minStatus) - { - using (var scope = ServiceScopeFactory.CreateScope()) + public async Task GetCountByTypeApplicableToUserIdAsync(Guid userId, PolicyType policyType, + OrganizationUserStatusType minStatus) { - var dbContext = GetDatabaseContext(scope); + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); - var query = new PolicyReadByTypeApplicableToUserQuery(userId, policyType, minStatus); - return await GetCountFromQuery(query); + var query = new PolicyReadByTypeApplicableToUserQuery(userId, policyType, minStatus); + return await GetCountFromQuery(query); + } } } } diff --git a/src/Infrastructure.EntityFramework/Repositories/ProviderOrganizationRepository.cs b/src/Infrastructure.EntityFramework/Repositories/ProviderOrganizationRepository.cs index 5d17d38bbf..dd5271fcda 100644 --- a/src/Infrastructure.EntityFramework/Repositories/ProviderOrganizationRepository.cs +++ b/src/Infrastructure.EntityFramework/Repositories/ProviderOrganizationRepository.cs @@ -6,30 +6,31 @@ using Bit.Infrastructure.EntityFramework.Repositories.Queries; using Microsoft.EntityFrameworkCore; using Microsoft.Extensions.DependencyInjection; -namespace Bit.Infrastructure.EntityFramework.Repositories; - -public class ProviderOrganizationRepository : - Repository, IProviderOrganizationRepository +namespace Bit.Infrastructure.EntityFramework.Repositories { - public ProviderOrganizationRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) - : base(serviceScopeFactory, mapper, context => context.ProviderOrganizations) - { } - - public async Task> GetManyDetailsByProviderAsync(Guid providerId) + public class ProviderOrganizationRepository : + Repository, IProviderOrganizationRepository { - using (var scope = ServiceScopeFactory.CreateScope()) + public ProviderOrganizationRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) + : base(serviceScopeFactory, mapper, context => context.ProviderOrganizations) + { } + + public async Task> GetManyDetailsByProviderAsync(Guid providerId) { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var query = new ProviderOrganizationOrganizationDetailsReadByProviderIdQuery(providerId); + var data = await query.Run(dbContext).ToListAsync(); + return data; + } + } + + public async Task GetByOrganizationId(Guid organizationId) + { + using var scope = ServiceScopeFactory.CreateScope(); var dbContext = GetDatabaseContext(scope); - var query = new ProviderOrganizationOrganizationDetailsReadByProviderIdQuery(providerId); - var data = await query.Run(dbContext).ToListAsync(); - return data; + return await GetDbSet(dbContext).Where(po => po.OrganizationId == organizationId).FirstOrDefaultAsync(); } } - - public async Task GetByOrganizationId(Guid organizationId) - { - using var scope = ServiceScopeFactory.CreateScope(); - var dbContext = GetDatabaseContext(scope); - return await GetDbSet(dbContext).Where(po => po.OrganizationId == organizationId).FirstOrDefaultAsync(); - } } diff --git a/src/Infrastructure.EntityFramework/Repositories/ProviderRepository.cs b/src/Infrastructure.EntityFramework/Repositories/ProviderRepository.cs index cf015b273f..75c8788c60 100644 --- a/src/Infrastructure.EntityFramework/Repositories/ProviderRepository.cs +++ b/src/Infrastructure.EntityFramework/Repositories/ProviderRepository.cs @@ -5,51 +5,52 @@ using Bit.Core.Repositories; using Microsoft.EntityFrameworkCore; using Microsoft.Extensions.DependencyInjection; -namespace Bit.Infrastructure.EntityFramework.Repositories; - -public class ProviderRepository : Repository, IProviderRepository +namespace Bit.Infrastructure.EntityFramework.Repositories { - - public ProviderRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) - : base(serviceScopeFactory, mapper, context => context.Providers) - { } - - public async Task> SearchAsync(string name, string userEmail, int skip, int take) + public class ProviderRepository : Repository, IProviderRepository { - using (var scope = ServiceScopeFactory.CreateScope()) + + public ProviderRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) + : base(serviceScopeFactory, mapper, context => context.Providers) + { } + + public async Task> SearchAsync(string name, string userEmail, int skip, int take) { - var dbContext = GetDatabaseContext(scope); - var query = !string.IsNullOrWhiteSpace(userEmail) ? - (from p in dbContext.Providers - join pu in dbContext.ProviderUsers - on p.Id equals pu.ProviderId - join u in dbContext.Users - on pu.UserId equals u.Id - where (string.IsNullOrWhiteSpace(name) || p.Name.Contains(name)) && - u.Email == userEmail - orderby p.CreationDate descending - select new { p, pu, u }).Skip(skip).Take(take).Select(x => x.p) : - (from p in dbContext.Providers - where string.IsNullOrWhiteSpace(name) || p.Name.Contains(name) - orderby p.CreationDate descending - select new { p }).Skip(skip).Take(take).Select(x => x.p); - var providers = await query.ToArrayAsync(); - return Mapper.Map>(providers); + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var query = !string.IsNullOrWhiteSpace(userEmail) ? + (from p in dbContext.Providers + join pu in dbContext.ProviderUsers + on p.Id equals pu.ProviderId + join u in dbContext.Users + on pu.UserId equals u.Id + where (string.IsNullOrWhiteSpace(name) || p.Name.Contains(name)) && + u.Email == userEmail + orderby p.CreationDate descending + select new { p, pu, u }).Skip(skip).Take(take).Select(x => x.p) : + (from p in dbContext.Providers + where string.IsNullOrWhiteSpace(name) || p.Name.Contains(name) + orderby p.CreationDate descending + select new { p }).Skip(skip).Take(take).Select(x => x.p); + var providers = await query.ToArrayAsync(); + return Mapper.Map>(providers); + } } - } - public async Task> GetManyAbilitiesAsync() - { - using (var scope = ServiceScopeFactory.CreateScope()) + public async Task> GetManyAbilitiesAsync() { - var dbContext = GetDatabaseContext(scope); - return await GetDbSet(dbContext) - .Select(e => new ProviderAbility - { - Enabled = e.Enabled, - Id = e.Id, - UseEvents = e.UseEvents, - }).ToListAsync(); + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + return await GetDbSet(dbContext) + .Select(e => new ProviderAbility + { + Enabled = e.Enabled, + Id = e.Id, + UseEvents = e.UseEvents, + }).ToListAsync(); + } } } } diff --git a/src/Infrastructure.EntityFramework/Repositories/ProviderUserRepository.cs b/src/Infrastructure.EntityFramework/Repositories/ProviderUserRepository.cs index 3aac5cca9d..87d82a542e 100644 --- a/src/Infrastructure.EntityFramework/Repositories/ProviderUserRepository.cs +++ b/src/Infrastructure.EntityFramework/Repositories/ProviderUserRepository.cs @@ -7,153 +7,154 @@ using Bit.Infrastructure.EntityFramework.Repositories.Queries; using Microsoft.EntityFrameworkCore; using Microsoft.Extensions.DependencyInjection; -namespace Bit.Infrastructure.EntityFramework.Repositories; - -public class ProviderUserRepository : - Repository, IProviderUserRepository +namespace Bit.Infrastructure.EntityFramework.Repositories { - public ProviderUserRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) - : base(serviceScopeFactory, mapper, (DatabaseContext context) => context.ProviderUsers) - { } + public class ProviderUserRepository : + Repository, IProviderUserRepository + { + public ProviderUserRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) + : base(serviceScopeFactory, mapper, (DatabaseContext context) => context.ProviderUsers) + { } - public async Task GetCountByProviderAsync(Guid providerId, string email, bool onlyRegisteredUsers) - { - using (var scope = ServiceScopeFactory.CreateScope()) + public async Task GetCountByProviderAsync(Guid providerId, string email, bool onlyRegisteredUsers) { - var dbContext = GetDatabaseContext(scope); - var query = from pu in dbContext.ProviderUsers - join u in dbContext.Users - on pu.UserId equals u.Id into u_g - from u in u_g.DefaultIfEmpty() - where pu.ProviderId == providerId && - ((!onlyRegisteredUsers && (pu.Email == email || u.Email == email)) || - (onlyRegisteredUsers && u.Email == email)) - select new { pu, u }; - return await query.CountAsync(); - } - } - - public async Task> GetManyAsync(IEnumerable ids) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var query = dbContext.ProviderUsers.Where(item => ids.Contains(item.Id)); - return await query.ToArrayAsync(); - } - } - - public async Task> GetManyByProviderAsync(Guid providerId, ProviderUserType? type = null) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var query = dbContext.ProviderUsers.Where(pu => pu.ProviderId.Equals(providerId) && - (type != null && pu.Type.Equals(type))); - return await query.ToArrayAsync(); - } - } - - public async Task DeleteManyAsync(IEnumerable providerUserIds) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - await UserBumpAccountRevisionDateByProviderUserIds(providerUserIds.ToArray()); - var entities = dbContext.ProviderUsers.Where(pu => providerUserIds.Contains(pu.Id)); - dbContext.ProviderUsers.RemoveRange(entities); - await dbContext.SaveChangesAsync(); - } - } - - public async Task> GetManyByUserAsync(Guid userId) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var query = from pu in dbContext.ProviderUsers - where pu.UserId == userId - select pu; - return await query.ToArrayAsync(); - } - } - public async Task GetByProviderUserAsync(Guid providerId, Guid userId) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var query = from pu in dbContext.ProviderUsers - where pu.UserId == userId && - pu.ProviderId == providerId - select pu; - return await query.FirstOrDefaultAsync(); - } - } - public async Task> GetManyDetailsByProviderAsync(Guid providerId) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var view = from pu in dbContext.ProviderUsers - join u in dbContext.Users - on pu.UserId equals u.Id into u_g - from u in u_g.DefaultIfEmpty() - select new { pu, u }; - var data = await view.Where(e => e.pu.ProviderId == providerId).Select(e => new ProviderUserUserDetails + using (var scope = ServiceScopeFactory.CreateScope()) { - Id = e.pu.Id, - UserId = e.pu.UserId, - ProviderId = e.pu.ProviderId, - Name = e.u.Name, - Email = e.u.Email ?? e.pu.Email, - Status = e.pu.Status, - Type = e.pu.Type, - Permissions = e.pu.Permissions, - }).ToArrayAsync(); - return data; + var dbContext = GetDatabaseContext(scope); + var query = from pu in dbContext.ProviderUsers + join u in dbContext.Users + on pu.UserId equals u.Id into u_g + from u in u_g.DefaultIfEmpty() + where pu.ProviderId == providerId && + ((!onlyRegisteredUsers && (pu.Email == email || u.Email == email)) || + (onlyRegisteredUsers && u.Email == email)) + select new { pu, u }; + return await query.CountAsync(); + } } - } - public async Task> GetManyDetailsByUserAsync(Guid userId, ProviderUserStatusType? status = null) - { - using (var scope = ServiceScopeFactory.CreateScope()) + public async Task> GetManyAsync(IEnumerable ids) { - var dbContext = GetDatabaseContext(scope); - var query = new ProviderUserProviderDetailsReadByUserIdStatusQuery(userId, status); - var data = await query.Run(dbContext).ToArrayAsync(); - return data; + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var query = dbContext.ProviderUsers.Where(item => ids.Contains(item.Id)); + return await query.ToArrayAsync(); + } } - } - public async Task> GetManyPublicKeysByProviderUserAsync(Guid providerId, IEnumerable Ids) - { - using (var scope = ServiceScopeFactory.CreateScope()) + public async Task> GetManyByProviderAsync(Guid providerId, ProviderUserType? type = null) { - var dbContext = GetDatabaseContext(scope); - var query = new UserReadPublicKeysByProviderUserIdsQuery(providerId, Ids); - var data = await query.Run(dbContext).ToListAsync(); - return data; + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var query = dbContext.ProviderUsers.Where(pu => pu.ProviderId.Equals(providerId) && + (type != null && pu.Type.Equals(type))); + return await query.ToArrayAsync(); + } } - } - public async Task> GetManyOrganizationDetailsByUserAsync(Guid userId, ProviderUserStatusType? status = null) - { - using (var scope = ServiceScopeFactory.CreateScope()) + public async Task DeleteManyAsync(IEnumerable providerUserIds) { - var dbContext = GetDatabaseContext(scope); - var view = new ProviderUserOrganizationDetailsViewQuery(); - var query = from ou in view.Run(dbContext) - where ou.UserId == userId && - (status == null || ou.Status == status) - select ou; - var organizationUsers = await query.ToListAsync(); - return organizationUsers; + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + await UserBumpAccountRevisionDateByProviderUserIds(providerUserIds.ToArray()); + var entities = dbContext.ProviderUsers.Where(pu => providerUserIds.Contains(pu.Id)); + dbContext.ProviderUsers.RemoveRange(entities); + await dbContext.SaveChangesAsync(); + } } - } - public async Task GetCountByOnlyOwnerAsync(Guid userId) - { - var query = new ProviderUserReadCountByOnlyOwnerQuery(userId); - return await GetCountFromQuery(query); + public async Task> GetManyByUserAsync(Guid userId) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var query = from pu in dbContext.ProviderUsers + where pu.UserId == userId + select pu; + return await query.ToArrayAsync(); + } + } + public async Task GetByProviderUserAsync(Guid providerId, Guid userId) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var query = from pu in dbContext.ProviderUsers + where pu.UserId == userId && + pu.ProviderId == providerId + select pu; + return await query.FirstOrDefaultAsync(); + } + } + public async Task> GetManyDetailsByProviderAsync(Guid providerId) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var view = from pu in dbContext.ProviderUsers + join u in dbContext.Users + on pu.UserId equals u.Id into u_g + from u in u_g.DefaultIfEmpty() + select new { pu, u }; + var data = await view.Where(e => e.pu.ProviderId == providerId).Select(e => new ProviderUserUserDetails + { + Id = e.pu.Id, + UserId = e.pu.UserId, + ProviderId = e.pu.ProviderId, + Name = e.u.Name, + Email = e.u.Email ?? e.pu.Email, + Status = e.pu.Status, + Type = e.pu.Type, + Permissions = e.pu.Permissions, + }).ToArrayAsync(); + return data; + } + } + + public async Task> GetManyDetailsByUserAsync(Guid userId, ProviderUserStatusType? status = null) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var query = new ProviderUserProviderDetailsReadByUserIdStatusQuery(userId, status); + var data = await query.Run(dbContext).ToArrayAsync(); + return data; + } + } + + public async Task> GetManyPublicKeysByProviderUserAsync(Guid providerId, IEnumerable Ids) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var query = new UserReadPublicKeysByProviderUserIdsQuery(providerId, Ids); + var data = await query.Run(dbContext).ToListAsync(); + return data; + } + } + + public async Task> GetManyOrganizationDetailsByUserAsync(Guid userId, ProviderUserStatusType? status = null) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var view = new ProviderUserOrganizationDetailsViewQuery(); + var query = from ou in view.Run(dbContext) + where ou.UserId == userId && + (status == null || ou.Status == status) + select ou; + var organizationUsers = await query.ToListAsync(); + return organizationUsers; + } + } + + public async Task GetCountByOnlyOwnerAsync(Guid userId) + { + var query = new ProviderUserReadCountByOnlyOwnerQuery(userId); + return await GetCountFromQuery(query); + } } } diff --git a/src/Infrastructure.EntityFramework/Repositories/Queries/CipherDetailsQuery.cs b/src/Infrastructure.EntityFramework/Repositories/Queries/CipherDetailsQuery.cs index 7d676c0210..38c451c3f9 100644 --- a/src/Infrastructure.EntityFramework/Repositories/Queries/CipherDetailsQuery.cs +++ b/src/Infrastructure.EntityFramework/Repositories/Queries/CipherDetailsQuery.cs @@ -1,36 +1,37 @@ using Bit.Core.Utilities; using Core.Models.Data; -namespace Bit.Infrastructure.EntityFramework.Repositories.Queries; - -public class CipherDetailsQuery : IQuery +namespace Bit.Infrastructure.EntityFramework.Repositories.Queries { - private readonly Guid? _userId; - private readonly bool _ignoreFolders; - public CipherDetailsQuery(Guid? userId, bool ignoreFolders = false) + public class CipherDetailsQuery : IQuery { - _userId = userId; - _ignoreFolders = ignoreFolders; - } - public virtual IQueryable Run(DatabaseContext dbContext) - { - var query = from c in dbContext.Ciphers - select new CipherDetails - { - Id = c.Id, - UserId = c.UserId, - OrganizationId = c.OrganizationId, - Type = c.Type, - Data = c.Data, - Attachments = c.Attachments, - CreationDate = c.CreationDate, - RevisionDate = c.RevisionDate, - DeletedDate = c.DeletedDate, - Favorite = _userId.HasValue && c.Favorites != null && c.Favorites.Contains($"\"{_userId}\":true"), - FolderId = (_ignoreFolders || !_userId.HasValue || c.Folders == null || !c.Folders.Contains(_userId.Value.ToString())) ? - null : - CoreHelpers.LoadClassFromJsonData>(c.Folders)[_userId.Value], - }; - return query; + private readonly Guid? _userId; + private readonly bool _ignoreFolders; + public CipherDetailsQuery(Guid? userId, bool ignoreFolders = false) + { + _userId = userId; + _ignoreFolders = ignoreFolders; + } + public virtual IQueryable Run(DatabaseContext dbContext) + { + var query = from c in dbContext.Ciphers + select new CipherDetails + { + Id = c.Id, + UserId = c.UserId, + OrganizationId = c.OrganizationId, + Type = c.Type, + Data = c.Data, + Attachments = c.Attachments, + CreationDate = c.CreationDate, + RevisionDate = c.RevisionDate, + DeletedDate = c.DeletedDate, + Favorite = _userId.HasValue && c.Favorites != null && c.Favorites.Contains($"\"{_userId}\":true"), + FolderId = (_ignoreFolders || !_userId.HasValue || c.Folders == null || !c.Folders.Contains(_userId.Value.ToString())) ? + null : + CoreHelpers.LoadClassFromJsonData>(c.Folders)[_userId.Value], + }; + return query; + } } } diff --git a/src/Infrastructure.EntityFramework/Repositories/Queries/CipherOrganizationDetailsReadByIdQuery.cs b/src/Infrastructure.EntityFramework/Repositories/Queries/CipherOrganizationDetailsReadByIdQuery.cs index b93954a521..2bca0c11a3 100644 --- a/src/Infrastructure.EntityFramework/Repositories/Queries/CipherOrganizationDetailsReadByIdQuery.cs +++ b/src/Infrastructure.EntityFramework/Repositories/Queries/CipherOrganizationDetailsReadByIdQuery.cs @@ -1,38 +1,39 @@ using Core.Models.Data; -namespace Bit.Infrastructure.EntityFramework.Repositories.Queries; - -public class CipherOrganizationDetailsReadByIdQuery : IQuery +namespace Bit.Infrastructure.EntityFramework.Repositories.Queries { - private readonly Guid _cipherId; - - public CipherOrganizationDetailsReadByIdQuery(Guid cipherId) + public class CipherOrganizationDetailsReadByIdQuery : IQuery { - _cipherId = cipherId; - } + private readonly Guid _cipherId; - public virtual IQueryable Run(DatabaseContext dbContext) - { - var query = from c in dbContext.Ciphers - join o in dbContext.Organizations - on c.OrganizationId equals o.Id into o_g - from o in o_g.DefaultIfEmpty() - where c.Id == _cipherId - select new CipherOrganizationDetails - { - Id = c.Id, - UserId = c.UserId, - OrganizationId = c.OrganizationId, - Type = c.Type, - Data = c.Data, - Favorites = c.Favorites, - Folders = c.Folders, - Attachments = c.Attachments, - CreationDate = c.CreationDate, - RevisionDate = c.RevisionDate, - DeletedDate = c.DeletedDate, - OrganizationUseTotp = o.UseTotp, - }; - return query; + public CipherOrganizationDetailsReadByIdQuery(Guid cipherId) + { + _cipherId = cipherId; + } + + public virtual IQueryable Run(DatabaseContext dbContext) + { + var query = from c in dbContext.Ciphers + join o in dbContext.Organizations + on c.OrganizationId equals o.Id into o_g + from o in o_g.DefaultIfEmpty() + where c.Id == _cipherId + select new CipherOrganizationDetails + { + Id = c.Id, + UserId = c.UserId, + OrganizationId = c.OrganizationId, + Type = c.Type, + Data = c.Data, + Favorites = c.Favorites, + Folders = c.Folders, + Attachments = c.Attachments, + CreationDate = c.CreationDate, + RevisionDate = c.RevisionDate, + DeletedDate = c.DeletedDate, + OrganizationUseTotp = o.UseTotp, + }; + return query; + } } } diff --git a/src/Infrastructure.EntityFramework/Repositories/Queries/CipherOrganizationDetailsReadByOrgizationIdQuery.cs b/src/Infrastructure.EntityFramework/Repositories/Queries/CipherOrganizationDetailsReadByOrgizationIdQuery.cs index 578bd7701a..84d2779a1c 100644 --- a/src/Infrastructure.EntityFramework/Repositories/Queries/CipherOrganizationDetailsReadByOrgizationIdQuery.cs +++ b/src/Infrastructure.EntityFramework/Repositories/Queries/CipherOrganizationDetailsReadByOrgizationIdQuery.cs @@ -1,37 +1,38 @@ using Core.Models.Data; -namespace Bit.Infrastructure.EntityFramework.Repositories.Queries; - -public class CipherOrganizationDetailsReadByOrgizationIdQuery : IQuery +namespace Bit.Infrastructure.EntityFramework.Repositories.Queries { - private readonly Guid _organizationId; + public class CipherOrganizationDetailsReadByOrgizationIdQuery : IQuery + { + private readonly Guid _organizationId; - public CipherOrganizationDetailsReadByOrgizationIdQuery(Guid organizationId) - { - _organizationId = organizationId; - } - public virtual IQueryable Run(DatabaseContext dbContext) - { - var query = from c in dbContext.Ciphers - join o in dbContext.Organizations - on c.OrganizationId equals o.Id into o_g - from o in o_g.DefaultIfEmpty() - where c.OrganizationId == _organizationId - select new CipherOrganizationDetails - { - Id = c.Id, - UserId = c.UserId, - OrganizationId = c.OrganizationId, - Type = c.Type, - Data = c.Data, - Favorites = c.Favorites, - Folders = c.Folders, - Attachments = c.Attachments, - CreationDate = c.CreationDate, - RevisionDate = c.RevisionDate, - DeletedDate = c.DeletedDate, - OrganizationUseTotp = o.UseTotp, - }; - return query; + public CipherOrganizationDetailsReadByOrgizationIdQuery(Guid organizationId) + { + _organizationId = organizationId; + } + public virtual IQueryable Run(DatabaseContext dbContext) + { + var query = from c in dbContext.Ciphers + join o in dbContext.Organizations + on c.OrganizationId equals o.Id into o_g + from o in o_g.DefaultIfEmpty() + where c.OrganizationId == _organizationId + select new CipherOrganizationDetails + { + Id = c.Id, + UserId = c.UserId, + OrganizationId = c.OrganizationId, + Type = c.Type, + Data = c.Data, + Favorites = c.Favorites, + Folders = c.Folders, + Attachments = c.Attachments, + CreationDate = c.CreationDate, + RevisionDate = c.RevisionDate, + DeletedDate = c.DeletedDate, + OrganizationUseTotp = o.UseTotp, + }; + return query; + } } } diff --git a/src/Infrastructure.EntityFramework/Repositories/Queries/CipherReadCanEditByIdUserIdQuery.cs b/src/Infrastructure.EntityFramework/Repositories/Queries/CipherReadCanEditByIdUserIdQuery.cs index 4ac3718c71..ab9a32b52e 100644 --- a/src/Infrastructure.EntityFramework/Repositories/Queries/CipherReadCanEditByIdUserIdQuery.cs +++ b/src/Infrastructure.EntityFramework/Repositories/Queries/CipherReadCanEditByIdUserIdQuery.cs @@ -1,55 +1,56 @@ using Bit.Core.Enums; using Bit.Infrastructure.EntityFramework.Models; -namespace Bit.Infrastructure.EntityFramework.Repositories.Queries; - -public class CipherReadCanEditByIdUserIdQuery : IQuery +namespace Bit.Infrastructure.EntityFramework.Repositories.Queries { - private readonly Guid _userId; - private readonly Guid _cipherId; - - public CipherReadCanEditByIdUserIdQuery(Guid userId, Guid cipherId) + public class CipherReadCanEditByIdUserIdQuery : IQuery { - _userId = userId; - _cipherId = cipherId; - } + private readonly Guid _userId; + private readonly Guid _cipherId; - public virtual IQueryable Run(DatabaseContext dbContext) - { - var query = from c in dbContext.Ciphers - join o in dbContext.Organizations - on c.OrganizationId equals o.Id into o_g - from o in o_g.DefaultIfEmpty() - where !c.UserId.HasValue - join ou in dbContext.OrganizationUsers - on o.Id equals ou.OrganizationId into ou_g - from ou in ou_g.DefaultIfEmpty() - where ou.UserId == _userId - join cc in dbContext.CollectionCiphers - on c.Id equals cc.CipherId into cc_g - from cc in cc_g.DefaultIfEmpty() - where !c.UserId.HasValue && !ou.AccessAll - join cu in dbContext.CollectionUsers - on cc.CollectionId equals cu.CollectionId into cu_g - from cu in cu_g.DefaultIfEmpty() - where ou.Id == cu.OrganizationUserId - join gu in dbContext.GroupUsers - on ou.Id equals gu.OrganizationUserId into gu_g - from gu in gu_g.DefaultIfEmpty() - where !c.UserId.HasValue && cu.CollectionId == null && !ou.AccessAll - join g in dbContext.Groups - on gu.GroupId equals g.Id into g_g - from g in g_g.DefaultIfEmpty() - join cg in dbContext.CollectionGroups - on gu.GroupId equals cg.GroupId into cg_g - from cg in cg_g.DefaultIfEmpty() - where !g.AccessAll && cg.CollectionId == cc.CollectionId && - (c.Id == _cipherId && - (c.UserId == _userId || - (!c.UserId.HasValue && ou.Status == OrganizationUserStatusType.Confirmed && o.Enabled && - (ou.AccessAll || cu.CollectionId != null || g.AccessAll || cg.CollectionId != null)))) && - (c.UserId.HasValue || ou.AccessAll || !cu.ReadOnly || g.AccessAll || !cg.ReadOnly) - select c; - return query; + public CipherReadCanEditByIdUserIdQuery(Guid userId, Guid cipherId) + { + _userId = userId; + _cipherId = cipherId; + } + + public virtual IQueryable Run(DatabaseContext dbContext) + { + var query = from c in dbContext.Ciphers + join o in dbContext.Organizations + on c.OrganizationId equals o.Id into o_g + from o in o_g.DefaultIfEmpty() + where !c.UserId.HasValue + join ou in dbContext.OrganizationUsers + on o.Id equals ou.OrganizationId into ou_g + from ou in ou_g.DefaultIfEmpty() + where ou.UserId == _userId + join cc in dbContext.CollectionCiphers + on c.Id equals cc.CipherId into cc_g + from cc in cc_g.DefaultIfEmpty() + where !c.UserId.HasValue && !ou.AccessAll + join cu in dbContext.CollectionUsers + on cc.CollectionId equals cu.CollectionId into cu_g + from cu in cu_g.DefaultIfEmpty() + where ou.Id == cu.OrganizationUserId + join gu in dbContext.GroupUsers + on ou.Id equals gu.OrganizationUserId into gu_g + from gu in gu_g.DefaultIfEmpty() + where !c.UserId.HasValue && cu.CollectionId == null && !ou.AccessAll + join g in dbContext.Groups + on gu.GroupId equals g.Id into g_g + from g in g_g.DefaultIfEmpty() + join cg in dbContext.CollectionGroups + on gu.GroupId equals cg.GroupId into cg_g + from cg in cg_g.DefaultIfEmpty() + where !g.AccessAll && cg.CollectionId == cc.CollectionId && + (c.Id == _cipherId && + (c.UserId == _userId || + (!c.UserId.HasValue && ou.Status == OrganizationUserStatusType.Confirmed && o.Enabled && + (ou.AccessAll || cu.CollectionId != null || g.AccessAll || cg.CollectionId != null)))) && + (c.UserId.HasValue || ou.AccessAll || !cu.ReadOnly || g.AccessAll || !cg.ReadOnly) + select c; + return query; + } } } diff --git a/src/Infrastructure.EntityFramework/Repositories/Queries/CipherUpdateCollectionsQuery.cs b/src/Infrastructure.EntityFramework/Repositories/Queries/CipherUpdateCollectionsQuery.cs index 859b26182c..25be9135a4 100644 --- a/src/Infrastructure.EntityFramework/Repositories/Queries/CipherUpdateCollectionsQuery.cs +++ b/src/Infrastructure.EntityFramework/Repositories/Queries/CipherUpdateCollectionsQuery.cs @@ -2,64 +2,65 @@ using Bit.Core.Enums; using CollectionCipher = Bit.Infrastructure.EntityFramework.Models.CollectionCipher; -namespace Bit.Infrastructure.EntityFramework.Repositories.Queries; - -public class CipherUpdateCollectionsQuery : IQuery +namespace Bit.Infrastructure.EntityFramework.Repositories.Queries { - private readonly Cipher _cipher; - private readonly IEnumerable _collectionIds; - - public CipherUpdateCollectionsQuery(Cipher cipher, IEnumerable collectionIds) + public class CipherUpdateCollectionsQuery : IQuery { - _cipher = cipher; - _collectionIds = collectionIds; - } + private readonly Cipher _cipher; + private readonly IEnumerable _collectionIds; - public virtual IQueryable Run(DatabaseContext dbContext) - { - if (!_cipher.OrganizationId.HasValue || !_collectionIds.Any()) + public CipherUpdateCollectionsQuery(Cipher cipher, IEnumerable collectionIds) { - return null; + _cipher = cipher; + _collectionIds = collectionIds; } - var availibleCollections = !_cipher.UserId.HasValue ? - from c in dbContext.Collections - where c.OrganizationId == _cipher.OrganizationId - select c.Id : - from c in dbContext.Collections - join o in dbContext.Organizations - on c.OrganizationId equals o.Id - join ou in dbContext.OrganizationUsers - on o.Id equals ou.OrganizationId - where ou.UserId == _cipher.UserId - join cu in dbContext.CollectionUsers - on c.Id equals cu.CollectionId into cu_g - from cu in cu_g.DefaultIfEmpty() - where !ou.AccessAll && cu.OrganizationUserId == ou.Id - join gu in dbContext.GroupUsers - on ou.Id equals gu.OrganizationUserId into gu_g - from gu in gu_g.DefaultIfEmpty() - where cu.CollectionId == null && !ou.AccessAll - join g in dbContext.Groups - on gu.GroupId equals g.Id into g_g - from g in g_g.DefaultIfEmpty() - join cg in dbContext.CollectionGroups - on c.Id equals cg.CollectionId into cg_g - from cg in cg_g.DefaultIfEmpty() - where !g.AccessAll && gu.GroupId == cg.GroupId && - o.Id == _cipher.OrganizationId && - o.Enabled && - ou.Status == OrganizationUserStatusType.Confirmed && - (ou.AccessAll || !cu.ReadOnly || g.AccessAll || !cg.ReadOnly) - select c.Id; - - if (!availibleCollections.Any()) + public virtual IQueryable Run(DatabaseContext dbContext) { - return null; - } + if (!_cipher.OrganizationId.HasValue || !_collectionIds.Any()) + { + return null; + } - var query = from c in availibleCollections - select new CollectionCipher { CollectionId = c, CipherId = _cipher.Id }; - return query; + var availibleCollections = !_cipher.UserId.HasValue ? + from c in dbContext.Collections + where c.OrganizationId == _cipher.OrganizationId + select c.Id : + from c in dbContext.Collections + join o in dbContext.Organizations + on c.OrganizationId equals o.Id + join ou in dbContext.OrganizationUsers + on o.Id equals ou.OrganizationId + where ou.UserId == _cipher.UserId + join cu in dbContext.CollectionUsers + on c.Id equals cu.CollectionId into cu_g + from cu in cu_g.DefaultIfEmpty() + where !ou.AccessAll && cu.OrganizationUserId == ou.Id + join gu in dbContext.GroupUsers + on ou.Id equals gu.OrganizationUserId into gu_g + from gu in gu_g.DefaultIfEmpty() + where cu.CollectionId == null && !ou.AccessAll + join g in dbContext.Groups + on gu.GroupId equals g.Id into g_g + from g in g_g.DefaultIfEmpty() + join cg in dbContext.CollectionGroups + on c.Id equals cg.CollectionId into cg_g + from cg in cg_g.DefaultIfEmpty() + where !g.AccessAll && gu.GroupId == cg.GroupId && + o.Id == _cipher.OrganizationId && + o.Enabled && + ou.Status == OrganizationUserStatusType.Confirmed && + (ou.AccessAll || !cu.ReadOnly || g.AccessAll || !cg.ReadOnly) + select c.Id; + + if (!availibleCollections.Any()) + { + return null; + } + + var query = from c in availibleCollections + select new CollectionCipher { CollectionId = c, CipherId = _cipher.Id }; + return query; + } } } diff --git a/src/Infrastructure.EntityFramework/Repositories/Queries/CollectionCipherReadByUserIdCipherIdQuery.cs b/src/Infrastructure.EntityFramework/Repositories/Queries/CollectionCipherReadByUserIdCipherIdQuery.cs index e494aec1f3..51fcb15fd5 100644 --- a/src/Infrastructure.EntityFramework/Repositories/Queries/CollectionCipherReadByUserIdCipherIdQuery.cs +++ b/src/Infrastructure.EntityFramework/Repositories/Queries/CollectionCipherReadByUserIdCipherIdQuery.cs @@ -1,19 +1,20 @@ using Bit.Infrastructure.EntityFramework.Models; -namespace Bit.Infrastructure.EntityFramework.Repositories.Queries; - -public class CollectionCipherReadByUserIdCipherIdQuery : CollectionCipherReadByUserIdQuery +namespace Bit.Infrastructure.EntityFramework.Repositories.Queries { - private readonly Guid _cipherId; - - public CollectionCipherReadByUserIdCipherIdQuery(Guid userId, Guid cipherId) : base(userId) + public class CollectionCipherReadByUserIdCipherIdQuery : CollectionCipherReadByUserIdQuery { - _cipherId = cipherId; - } + private readonly Guid _cipherId; - public override IQueryable Run(DatabaseContext dbContext) - { - var query = base.Run(dbContext); - return query.Where(x => x.CipherId == _cipherId); + public CollectionCipherReadByUserIdCipherIdQuery(Guid userId, Guid cipherId) : base(userId) + { + _cipherId = cipherId; + } + + public override IQueryable Run(DatabaseContext dbContext) + { + var query = base.Run(dbContext); + return query.Where(x => x.CipherId == _cipherId); + } } } diff --git a/src/Infrastructure.EntityFramework/Repositories/Queries/CollectionCipherReadByUserIdQuery.cs b/src/Infrastructure.EntityFramework/Repositories/Queries/CollectionCipherReadByUserIdQuery.cs index 156707b46c..6c8e17372d 100644 --- a/src/Infrastructure.EntityFramework/Repositories/Queries/CollectionCipherReadByUserIdQuery.cs +++ b/src/Infrastructure.EntityFramework/Repositories/Queries/CollectionCipherReadByUserIdQuery.cs @@ -1,43 +1,44 @@ using Bit.Core.Enums; using Bit.Infrastructure.EntityFramework.Models; -namespace Bit.Infrastructure.EntityFramework.Repositories.Queries; - -public class CollectionCipherReadByUserIdQuery : IQuery +namespace Bit.Infrastructure.EntityFramework.Repositories.Queries { - private readonly Guid _userId; - - public CollectionCipherReadByUserIdQuery(Guid userId) + public class CollectionCipherReadByUserIdQuery : IQuery { - _userId = userId; - } + private readonly Guid _userId; - public virtual IQueryable Run(DatabaseContext dbContext) - { - var query = from cc in dbContext.CollectionCiphers - join c in dbContext.Collections - on cc.CollectionId equals c.Id - join ou in dbContext.OrganizationUsers - on c.OrganizationId equals ou.OrganizationId - where ou.UserId == _userId - join cu in dbContext.CollectionUsers - on c.Id equals cu.CollectionId into cu_g - from cu in cu_g - where ou.AccessAll && cu.OrganizationUserId == ou.Id - join gu in dbContext.GroupUsers - on ou.Id equals gu.OrganizationUserId into gu_g - from gu in gu_g - where cu.CollectionId == null && !ou.AccessAll - join g in dbContext.Groups - on gu.GroupId equals g.Id into g_g - from g in g_g - join cg in dbContext.CollectionGroups - on cc.CollectionId equals cg.CollectionId into cg_g - from cg in cg_g - where g.AccessAll && cg.GroupId == gu.GroupId && - ou.Status == OrganizationUserStatusType.Confirmed && - (ou.AccessAll || cu.CollectionId != null || g.AccessAll || cg.CollectionId != null) - select cc; - return query; + public CollectionCipherReadByUserIdQuery(Guid userId) + { + _userId = userId; + } + + public virtual IQueryable Run(DatabaseContext dbContext) + { + var query = from cc in dbContext.CollectionCiphers + join c in dbContext.Collections + on cc.CollectionId equals c.Id + join ou in dbContext.OrganizationUsers + on c.OrganizationId equals ou.OrganizationId + where ou.UserId == _userId + join cu in dbContext.CollectionUsers + on c.Id equals cu.CollectionId into cu_g + from cu in cu_g + where ou.AccessAll && cu.OrganizationUserId == ou.Id + join gu in dbContext.GroupUsers + on ou.Id equals gu.OrganizationUserId into gu_g + from gu in gu_g + where cu.CollectionId == null && !ou.AccessAll + join g in dbContext.Groups + on gu.GroupId equals g.Id into g_g + from g in g_g + join cg in dbContext.CollectionGroups + on cc.CollectionId equals cg.CollectionId into cg_g + from cg in cg_g + where g.AccessAll && cg.GroupId == gu.GroupId && + ou.Status == OrganizationUserStatusType.Confirmed && + (ou.AccessAll || cu.CollectionId != null || g.AccessAll || cg.CollectionId != null) + select cc; + return query; + } } } diff --git a/src/Infrastructure.EntityFramework/Repositories/Queries/CollectionReadCountByOrganizationIdQuery.cs b/src/Infrastructure.EntityFramework/Repositories/Queries/CollectionReadCountByOrganizationIdQuery.cs index 90e800398a..de878db34c 100644 --- a/src/Infrastructure.EntityFramework/Repositories/Queries/CollectionReadCountByOrganizationIdQuery.cs +++ b/src/Infrastructure.EntityFramework/Repositories/Queries/CollectionReadCountByOrganizationIdQuery.cs @@ -1,21 +1,22 @@ using Bit.Infrastructure.EntityFramework.Models; -namespace Bit.Infrastructure.EntityFramework.Repositories.Queries; - -public class CollectionReadCountByOrganizationIdQuery : IQuery +namespace Bit.Infrastructure.EntityFramework.Repositories.Queries { - private readonly Guid _organizationId; - - public CollectionReadCountByOrganizationIdQuery(Guid organizationId) + public class CollectionReadCountByOrganizationIdQuery : IQuery { - _organizationId = organizationId; - } + private readonly Guid _organizationId; - public IQueryable Run(DatabaseContext dbContext) - { - var query = from c in dbContext.Collections - where c.OrganizationId == _organizationId - select c; - return query; + public CollectionReadCountByOrganizationIdQuery(Guid organizationId) + { + _organizationId = organizationId; + } + + public IQueryable Run(DatabaseContext dbContext) + { + var query = from c in dbContext.Collections + where c.OrganizationId == _organizationId + select c; + return query; + } } } diff --git a/src/Infrastructure.EntityFramework/Repositories/Queries/CollectionUserUpdateUsersQuery.cs b/src/Infrastructure.EntityFramework/Repositories/Queries/CollectionUserUpdateUsersQuery.cs index db2d911906..45023772a3 100644 --- a/src/Infrastructure.EntityFramework/Repositories/Queries/CollectionUserUpdateUsersQuery.cs +++ b/src/Infrastructure.EntityFramework/Repositories/Queries/CollectionUserUpdateUsersQuery.cs @@ -2,115 +2,116 @@ using Bit.Infrastructure.EntityFramework.Models; using Microsoft.EntityFrameworkCore; -namespace Bit.Infrastructure.EntityFramework.Repositories.Queries; - -public class CollectionUserUpdateUsersQuery +namespace Bit.Infrastructure.EntityFramework.Repositories.Queries { - public readonly CollectionUserUpdateUsersInsertQuery Insert; - public readonly CollectionUserUpdateUsersUpdateQuery Update; - public readonly CollectionUserUpdateUsersDeleteQuery Delete; - - public CollectionUserUpdateUsersQuery(Guid collectionId, IEnumerable users) + public class CollectionUserUpdateUsersQuery { - Insert = new CollectionUserUpdateUsersInsertQuery(collectionId, users); - Update = new CollectionUserUpdateUsersUpdateQuery(collectionId, users); - Delete = new CollectionUserUpdateUsersDeleteQuery(collectionId, users); - } -} + public readonly CollectionUserUpdateUsersInsertQuery Insert; + public readonly CollectionUserUpdateUsersUpdateQuery Update; + public readonly CollectionUserUpdateUsersDeleteQuery Delete; -public class CollectionUserUpdateUsersInsertQuery : IQuery -{ - private readonly Guid _collectionId; - private readonly IEnumerable _users; - - public CollectionUserUpdateUsersInsertQuery(Guid collectionId, IEnumerable users) - { - _collectionId = collectionId; - _users = users; - } - - public IQueryable Run(DatabaseContext dbContext) - { - var orgId = dbContext.Collections.FirstOrDefault(c => c.Id == _collectionId)?.OrganizationId; - var organizationUserIds = _users.Select(u => u.Id); - var insertQuery = from ou in dbContext.OrganizationUsers - where - organizationUserIds.Contains(ou.Id) && - ou.OrganizationId == orgId && - !dbContext.CollectionUsers.Any( - x => x.CollectionId != _collectionId && x.OrganizationUserId == ou.Id) - select ou; - return insertQuery; - } - - public async Task> BuildInMemory(DatabaseContext dbContext) - { - var data = await Run(dbContext).ToListAsync(); - var collectionUsers = data.Select(x => new CollectionUser() + public CollectionUserUpdateUsersQuery(Guid collectionId, IEnumerable users) { - CollectionId = _collectionId, - OrganizationUserId = x.Id, - ReadOnly = _users.FirstOrDefault(u => u.Id.Equals(x.Id)).ReadOnly, - HidePasswords = _users.FirstOrDefault(u => u.Id.Equals(x.Id)).HidePasswords, - }); - return collectionUsers; - } -} - -public class CollectionUserUpdateUsersUpdateQuery : IQuery -{ - private readonly Guid _collectionId; - private readonly IEnumerable _users; - - public CollectionUserUpdateUsersUpdateQuery(Guid collectionId, IEnumerable users) - { - _collectionId = collectionId; - _users = users; + Insert = new CollectionUserUpdateUsersInsertQuery(collectionId, users); + Update = new CollectionUserUpdateUsersUpdateQuery(collectionId, users); + Delete = new CollectionUserUpdateUsersDeleteQuery(collectionId, users); + } } - public IQueryable Run(DatabaseContext dbContext) + public class CollectionUserUpdateUsersInsertQuery : IQuery { - var orgId = dbContext.Collections.FirstOrDefault(c => c.Id == _collectionId)?.OrganizationId; - var ids = _users.Select(x => x.Id); - var updateQuery = from target in dbContext.CollectionUsers - where target.CollectionId == _collectionId && - ids.Contains(target.OrganizationUserId) - select target; - return updateQuery; - } + private readonly Guid _collectionId; + private readonly IEnumerable _users; - public async Task> BuildInMemory(DatabaseContext dbContext) - { - var data = await Run(dbContext).ToListAsync(); - var collectionUsers = data.Select(x => new CollectionUser + public CollectionUserUpdateUsersInsertQuery(Guid collectionId, IEnumerable users) { - CollectionId = _collectionId, - OrganizationUserId = x.OrganizationUserId, - ReadOnly = _users.FirstOrDefault(u => u.Id.Equals(x.OrganizationUserId)).ReadOnly, - HidePasswords = _users.FirstOrDefault(u => u.Id.Equals(x.OrganizationUserId)).HidePasswords, - }); - return collectionUsers; - } -} - -public class CollectionUserUpdateUsersDeleteQuery : IQuery -{ - private readonly Guid _collectionId; - private readonly IEnumerable _users; - - public CollectionUserUpdateUsersDeleteQuery(Guid collectionId, IEnumerable users) - { - _collectionId = collectionId; - _users = users; - } - - public IQueryable Run(DatabaseContext dbContext) - { - var orgId = dbContext.Collections.FirstOrDefault(c => c.Id == _collectionId)?.OrganizationId; - var deleteQuery = from cu in dbContext.CollectionUsers - where !dbContext.Users.Any( - u => u.Id == cu.OrganizationUserId) - select cu; - return deleteQuery; + _collectionId = collectionId; + _users = users; + } + + public IQueryable Run(DatabaseContext dbContext) + { + var orgId = dbContext.Collections.FirstOrDefault(c => c.Id == _collectionId)?.OrganizationId; + var organizationUserIds = _users.Select(u => u.Id); + var insertQuery = from ou in dbContext.OrganizationUsers + where + organizationUserIds.Contains(ou.Id) && + ou.OrganizationId == orgId && + !dbContext.CollectionUsers.Any( + x => x.CollectionId != _collectionId && x.OrganizationUserId == ou.Id) + select ou; + return insertQuery; + } + + public async Task> BuildInMemory(DatabaseContext dbContext) + { + var data = await Run(dbContext).ToListAsync(); + var collectionUsers = data.Select(x => new CollectionUser() + { + CollectionId = _collectionId, + OrganizationUserId = x.Id, + ReadOnly = _users.FirstOrDefault(u => u.Id.Equals(x.Id)).ReadOnly, + HidePasswords = _users.FirstOrDefault(u => u.Id.Equals(x.Id)).HidePasswords, + }); + return collectionUsers; + } + } + + public class CollectionUserUpdateUsersUpdateQuery : IQuery + { + private readonly Guid _collectionId; + private readonly IEnumerable _users; + + public CollectionUserUpdateUsersUpdateQuery(Guid collectionId, IEnumerable users) + { + _collectionId = collectionId; + _users = users; + } + + public IQueryable Run(DatabaseContext dbContext) + { + var orgId = dbContext.Collections.FirstOrDefault(c => c.Id == _collectionId)?.OrganizationId; + var ids = _users.Select(x => x.Id); + var updateQuery = from target in dbContext.CollectionUsers + where target.CollectionId == _collectionId && + ids.Contains(target.OrganizationUserId) + select target; + return updateQuery; + } + + public async Task> BuildInMemory(DatabaseContext dbContext) + { + var data = await Run(dbContext).ToListAsync(); + var collectionUsers = data.Select(x => new CollectionUser + { + CollectionId = _collectionId, + OrganizationUserId = x.OrganizationUserId, + ReadOnly = _users.FirstOrDefault(u => u.Id.Equals(x.OrganizationUserId)).ReadOnly, + HidePasswords = _users.FirstOrDefault(u => u.Id.Equals(x.OrganizationUserId)).HidePasswords, + }); + return collectionUsers; + } + } + + public class CollectionUserUpdateUsersDeleteQuery : IQuery + { + private readonly Guid _collectionId; + private readonly IEnumerable _users; + + public CollectionUserUpdateUsersDeleteQuery(Guid collectionId, IEnumerable users) + { + _collectionId = collectionId; + _users = users; + } + + public IQueryable Run(DatabaseContext dbContext) + { + var orgId = dbContext.Collections.FirstOrDefault(c => c.Id == _collectionId)?.OrganizationId; + var deleteQuery = from cu in dbContext.CollectionUsers + where !dbContext.Users.Any( + u => u.Id == cu.OrganizationUserId) + select cu; + return deleteQuery; + } } } diff --git a/src/Infrastructure.EntityFramework/Repositories/Queries/EmergencyAccessDetailsViewQuery.cs b/src/Infrastructure.EntityFramework/Repositories/Queries/EmergencyAccessDetailsViewQuery.cs index 2ad2149aef..24c1bda8da 100644 --- a/src/Infrastructure.EntityFramework/Repositories/Queries/EmergencyAccessDetailsViewQuery.cs +++ b/src/Infrastructure.EntityFramework/Repositories/Queries/EmergencyAccessDetailsViewQuery.cs @@ -1,37 +1,38 @@ using Bit.Core.Models.Data; -namespace Bit.Infrastructure.EntityFramework.Repositories.Queries; - -public class EmergencyAccessDetailsViewQuery : IQuery +namespace Bit.Infrastructure.EntityFramework.Repositories.Queries { - public IQueryable Run(DatabaseContext dbContext) + public class EmergencyAccessDetailsViewQuery : IQuery { - var query = from ea in dbContext.EmergencyAccesses - join grantee in dbContext.Users - on ea.GranteeId equals grantee.Id into grantee_g - from grantee in grantee_g.DefaultIfEmpty() - join grantor in dbContext.Users - on ea.GrantorId equals grantor.Id into grantor_g - from grantor in grantor_g.DefaultIfEmpty() - select new { ea, grantee, grantor }; - return query.Select(x => new EmergencyAccessDetails + public IQueryable Run(DatabaseContext dbContext) { - Id = x.ea.Id, - GrantorId = x.ea.GrantorId, - GranteeId = x.ea.GranteeId, - Email = x.ea.Email, - KeyEncrypted = x.ea.KeyEncrypted, - Type = x.ea.Type, - Status = x.ea.Status, - WaitTimeDays = x.ea.WaitTimeDays, - RecoveryInitiatedDate = x.ea.RecoveryInitiatedDate, - LastNotificationDate = x.ea.LastNotificationDate, - CreationDate = x.ea.CreationDate, - RevisionDate = x.ea.RevisionDate, - GranteeName = x.grantee.Name, - GranteeEmail = x.grantee.Email, - GrantorName = x.grantor.Name, - GrantorEmail = x.grantor.Email, - }); + var query = from ea in dbContext.EmergencyAccesses + join grantee in dbContext.Users + on ea.GranteeId equals grantee.Id into grantee_g + from grantee in grantee_g.DefaultIfEmpty() + join grantor in dbContext.Users + on ea.GrantorId equals grantor.Id into grantor_g + from grantor in grantor_g.DefaultIfEmpty() + select new { ea, grantee, grantor }; + return query.Select(x => new EmergencyAccessDetails + { + Id = x.ea.Id, + GrantorId = x.ea.GrantorId, + GranteeId = x.ea.GranteeId, + Email = x.ea.Email, + KeyEncrypted = x.ea.KeyEncrypted, + Type = x.ea.Type, + Status = x.ea.Status, + WaitTimeDays = x.ea.WaitTimeDays, + RecoveryInitiatedDate = x.ea.RecoveryInitiatedDate, + LastNotificationDate = x.ea.LastNotificationDate, + CreationDate = x.ea.CreationDate, + RevisionDate = x.ea.RevisionDate, + GranteeName = x.grantee.Name, + GranteeEmail = x.grantee.Email, + GrantorName = x.grantor.Name, + GrantorEmail = x.grantor.Email, + }); + } } } diff --git a/src/Infrastructure.EntityFramework/Repositories/Queries/EmergencyAccessReadCountByGrantorIdEmailQuery.cs b/src/Infrastructure.EntityFramework/Repositories/Queries/EmergencyAccessReadCountByGrantorIdEmailQuery.cs index 3a09fa8577..d28ce13728 100644 --- a/src/Infrastructure.EntityFramework/Repositories/Queries/EmergencyAccessReadCountByGrantorIdEmailQuery.cs +++ b/src/Infrastructure.EntityFramework/Repositories/Queries/EmergencyAccessReadCountByGrantorIdEmailQuery.cs @@ -1,30 +1,31 @@ using Bit.Infrastructure.EntityFramework.Models; -namespace Bit.Infrastructure.EntityFramework.Repositories.Queries; - -public class EmergencyAccessReadCountByGrantorIdEmailQuery : IQuery +namespace Bit.Infrastructure.EntityFramework.Repositories.Queries { - private readonly Guid _grantorId; - private readonly string _email; - private readonly bool _onlyRegisteredUsers; - - public EmergencyAccessReadCountByGrantorIdEmailQuery(Guid grantorId, string email, bool onlyRegisteredUsers) + public class EmergencyAccessReadCountByGrantorIdEmailQuery : IQuery { - _grantorId = grantorId; - _email = email; - _onlyRegisteredUsers = onlyRegisteredUsers; - } + private readonly Guid _grantorId; + private readonly string _email; + private readonly bool _onlyRegisteredUsers; - public IQueryable Run(DatabaseContext dbContext) - { - var query = from ea in dbContext.EmergencyAccesses - join u in dbContext.Users - on ea.GranteeId equals u.Id into u_g - from u in u_g.DefaultIfEmpty() - where ea.GrantorId == _grantorId && - ((!_onlyRegisteredUsers && (ea.Email == _email || u.Email == _email)) - || (_onlyRegisteredUsers && u.Email == _email)) - select ea; - return query; + public EmergencyAccessReadCountByGrantorIdEmailQuery(Guid grantorId, string email, bool onlyRegisteredUsers) + { + _grantorId = grantorId; + _email = email; + _onlyRegisteredUsers = onlyRegisteredUsers; + } + + public IQueryable Run(DatabaseContext dbContext) + { + var query = from ea in dbContext.EmergencyAccesses + join u in dbContext.Users + on ea.GranteeId equals u.Id into u_g + from u in u_g.DefaultIfEmpty() + where ea.GrantorId == _grantorId && + ((!_onlyRegisteredUsers && (ea.Email == _email || u.Email == _email)) + || (_onlyRegisteredUsers && u.Email == _email)) + select ea; + return query; + } } } diff --git a/src/Infrastructure.EntityFramework/Repositories/Queries/EventReadPageByCipherIdQuery.cs b/src/Infrastructure.EntityFramework/Repositories/Queries/EventReadPageByCipherIdQuery.cs index 570f3a2494..d94d130ef6 100644 --- a/src/Infrastructure.EntityFramework/Repositories/Queries/EventReadPageByCipherIdQuery.cs +++ b/src/Infrastructure.EntityFramework/Repositories/Queries/EventReadPageByCipherIdQuery.cs @@ -2,46 +2,47 @@ using Bit.Core.Models.Data; using Event = Bit.Infrastructure.EntityFramework.Models.Event; -namespace Bit.Infrastructure.EntityFramework.Repositories.Queries; - -public class EventReadPageByCipherIdQuery : IQuery +namespace Bit.Infrastructure.EntityFramework.Repositories.Queries { - private readonly Cipher _cipher; - private readonly DateTime _startDate; - private readonly DateTime _endDate; - private readonly DateTime? _beforeDate; - private readonly PageOptions _pageOptions; - - public EventReadPageByCipherIdQuery(Cipher cipher, DateTime startDate, DateTime endDate, PageOptions pageOptions) + public class EventReadPageByCipherIdQuery : IQuery { - _cipher = cipher; - _startDate = startDate; - _endDate = endDate; - _beforeDate = null; - _pageOptions = pageOptions; - } + private readonly Cipher _cipher; + private readonly DateTime _startDate; + private readonly DateTime _endDate; + private readonly DateTime? _beforeDate; + private readonly PageOptions _pageOptions; - public EventReadPageByCipherIdQuery(Cipher cipher, DateTime startDate, DateTime endDate, DateTime? beforeDate, PageOptions pageOptions) - { - _cipher = cipher; - _startDate = startDate; - _endDate = endDate; - _beforeDate = beforeDate; - _pageOptions = pageOptions; - } + public EventReadPageByCipherIdQuery(Cipher cipher, DateTime startDate, DateTime endDate, PageOptions pageOptions) + { + _cipher = cipher; + _startDate = startDate; + _endDate = endDate; + _beforeDate = null; + _pageOptions = pageOptions; + } - public IQueryable Run(DatabaseContext dbContext) - { - var q = from e in dbContext.Events - where e.Date >= _startDate && - (_beforeDate == null || e.Date < _beforeDate.Value) && - ((!_cipher.OrganizationId.HasValue && !e.OrganizationId.HasValue) || - (_cipher.OrganizationId.HasValue && _cipher.OrganizationId == e.OrganizationId)) && - ((!_cipher.UserId.HasValue && !e.UserId.HasValue) || - (_cipher.UserId.HasValue && _cipher.UserId == e.UserId)) && - _cipher.Id == e.CipherId - orderby e.Date descending - select e; - return q.Skip(0).Take(_pageOptions.PageSize); + public EventReadPageByCipherIdQuery(Cipher cipher, DateTime startDate, DateTime endDate, DateTime? beforeDate, PageOptions pageOptions) + { + _cipher = cipher; + _startDate = startDate; + _endDate = endDate; + _beforeDate = beforeDate; + _pageOptions = pageOptions; + } + + public IQueryable Run(DatabaseContext dbContext) + { + var q = from e in dbContext.Events + where e.Date >= _startDate && + (_beforeDate == null || e.Date < _beforeDate.Value) && + ((!_cipher.OrganizationId.HasValue && !e.OrganizationId.HasValue) || + (_cipher.OrganizationId.HasValue && _cipher.OrganizationId == e.OrganizationId)) && + ((!_cipher.UserId.HasValue && !e.UserId.HasValue) || + (_cipher.UserId.HasValue && _cipher.UserId == e.UserId)) && + _cipher.Id == e.CipherId + orderby e.Date descending + select e; + return q.Skip(0).Take(_pageOptions.PageSize); + } } } diff --git a/src/Infrastructure.EntityFramework/Repositories/Queries/EventReadPageByOrganizationIdActingUserIdQuery.cs b/src/Infrastructure.EntityFramework/Repositories/Queries/EventReadPageByOrganizationIdActingUserIdQuery.cs index 8e49ca2394..f9553dd383 100644 --- a/src/Infrastructure.EntityFramework/Repositories/Queries/EventReadPageByOrganizationIdActingUserIdQuery.cs +++ b/src/Infrastructure.EntityFramework/Repositories/Queries/EventReadPageByOrganizationIdActingUserIdQuery.cs @@ -1,38 +1,39 @@ using Bit.Core.Models.Data; using Bit.Infrastructure.EntityFramework.Models; -namespace Bit.Infrastructure.EntityFramework.Repositories.Queries; - -public class EventReadPageByOrganizationIdActingUserIdQuery : IQuery +namespace Bit.Infrastructure.EntityFramework.Repositories.Queries { - private readonly Guid _organizationId; - private readonly Guid _actingUserId; - private readonly DateTime _startDate; - private readonly DateTime _endDate; - private readonly DateTime? _beforeDate; - private readonly PageOptions _pageOptions; - - public EventReadPageByOrganizationIdActingUserIdQuery(Guid organizationId, Guid actingUserId, - DateTime startDate, DateTime endDate, DateTime? beforeDate, PageOptions pageOptions) + public class EventReadPageByOrganizationIdActingUserIdQuery : IQuery { - _organizationId = organizationId; - _actingUserId = actingUserId; - _startDate = startDate; - _endDate = endDate; - _beforeDate = beforeDate; - _pageOptions = pageOptions; - } + private readonly Guid _organizationId; + private readonly Guid _actingUserId; + private readonly DateTime _startDate; + private readonly DateTime _endDate; + private readonly DateTime? _beforeDate; + private readonly PageOptions _pageOptions; - public IQueryable Run(DatabaseContext dbContext) - { - var q = from e in dbContext.Events - where e.Date >= _startDate && - (_beforeDate != null || e.Date <= _endDate) && - (_beforeDate == null || e.Date < _beforeDate.Value) && - e.OrganizationId == _organizationId && - e.ActingUserId == _actingUserId - orderby e.Date descending - select e; - return q.Skip(0).Take(_pageOptions.PageSize); + public EventReadPageByOrganizationIdActingUserIdQuery(Guid organizationId, Guid actingUserId, + DateTime startDate, DateTime endDate, DateTime? beforeDate, PageOptions pageOptions) + { + _organizationId = organizationId; + _actingUserId = actingUserId; + _startDate = startDate; + _endDate = endDate; + _beforeDate = beforeDate; + _pageOptions = pageOptions; + } + + public IQueryable Run(DatabaseContext dbContext) + { + var q = from e in dbContext.Events + where e.Date >= _startDate && + (_beforeDate != null || e.Date <= _endDate) && + (_beforeDate == null || e.Date < _beforeDate.Value) && + e.OrganizationId == _organizationId && + e.ActingUserId == _actingUserId + orderby e.Date descending + select e; + return q.Skip(0).Take(_pageOptions.PageSize); + } } } diff --git a/src/Infrastructure.EntityFramework/Repositories/Queries/EventReadPageByOrganizationIdQuery.cs b/src/Infrastructure.EntityFramework/Repositories/Queries/EventReadPageByOrganizationIdQuery.cs index ce0de6afc6..261bef32a2 100644 --- a/src/Infrastructure.EntityFramework/Repositories/Queries/EventReadPageByOrganizationIdQuery.cs +++ b/src/Infrastructure.EntityFramework/Repositories/Queries/EventReadPageByOrganizationIdQuery.cs @@ -1,35 +1,36 @@ using Bit.Core.Models.Data; using Bit.Infrastructure.EntityFramework.Models; -namespace Bit.Infrastructure.EntityFramework.Repositories.Queries; - -public class EventReadPageByOrganizationIdQuery : IQuery +namespace Bit.Infrastructure.EntityFramework.Repositories.Queries { - private readonly Guid _organizationId; - private readonly DateTime _startDate; - private readonly DateTime _endDate; - private readonly DateTime? _beforeDate; - private readonly PageOptions _pageOptions; - - public EventReadPageByOrganizationIdQuery(Guid organizationId, DateTime startDate, - DateTime endDate, DateTime? beforeDate, PageOptions pageOptions) + public class EventReadPageByOrganizationIdQuery : IQuery { - _organizationId = organizationId; - _startDate = startDate; - _endDate = endDate; - _beforeDate = beforeDate; - _pageOptions = pageOptions; - } + private readonly Guid _organizationId; + private readonly DateTime _startDate; + private readonly DateTime _endDate; + private readonly DateTime? _beforeDate; + private readonly PageOptions _pageOptions; - public IQueryable Run(DatabaseContext dbContext) - { - var q = from e in dbContext.Events - where e.Date >= _startDate && - (_beforeDate != null || e.Date <= _endDate) && - (_beforeDate == null || e.Date < _beforeDate.Value) && - e.OrganizationId == _organizationId - orderby e.Date descending - select e; - return q.Skip(0).Take(_pageOptions.PageSize); + public EventReadPageByOrganizationIdQuery(Guid organizationId, DateTime startDate, + DateTime endDate, DateTime? beforeDate, PageOptions pageOptions) + { + _organizationId = organizationId; + _startDate = startDate; + _endDate = endDate; + _beforeDate = beforeDate; + _pageOptions = pageOptions; + } + + public IQueryable Run(DatabaseContext dbContext) + { + var q = from e in dbContext.Events + where e.Date >= _startDate && + (_beforeDate != null || e.Date <= _endDate) && + (_beforeDate == null || e.Date < _beforeDate.Value) && + e.OrganizationId == _organizationId + orderby e.Date descending + select e; + return q.Skip(0).Take(_pageOptions.PageSize); + } } } diff --git a/src/Infrastructure.EntityFramework/Repositories/Queries/EventReadPageByProviderIdActingUserIdQuery.cs b/src/Infrastructure.EntityFramework/Repositories/Queries/EventReadPageByProviderIdActingUserIdQuery.cs index 171b4e26c7..4b08ecf2ba 100644 --- a/src/Infrastructure.EntityFramework/Repositories/Queries/EventReadPageByProviderIdActingUserIdQuery.cs +++ b/src/Infrastructure.EntityFramework/Repositories/Queries/EventReadPageByProviderIdActingUserIdQuery.cs @@ -1,38 +1,39 @@ using Bit.Core.Models.Data; using Bit.Infrastructure.EntityFramework.Models; -namespace Bit.Infrastructure.EntityFramework.Repositories.Queries; - -public class EventReadPageByProviderIdActingUserIdQuery : IQuery +namespace Bit.Infrastructure.EntityFramework.Repositories.Queries { - private readonly Guid _providerId; - private readonly Guid _actingUserId; - private readonly DateTime _startDate; - private readonly DateTime _endDate; - private readonly DateTime? _beforeDate; - private readonly PageOptions _pageOptions; - - public EventReadPageByProviderIdActingUserIdQuery(Guid providerId, Guid actingUserId, - DateTime startDate, DateTime endDate, DateTime? beforeDate, PageOptions pageOptions) + public class EventReadPageByProviderIdActingUserIdQuery : IQuery { - _providerId = providerId; - _actingUserId = actingUserId; - _startDate = startDate; - _endDate = endDate; - _beforeDate = beforeDate; - _pageOptions = pageOptions; - } + private readonly Guid _providerId; + private readonly Guid _actingUserId; + private readonly DateTime _startDate; + private readonly DateTime _endDate; + private readonly DateTime? _beforeDate; + private readonly PageOptions _pageOptions; - public IQueryable Run(DatabaseContext dbContext) - { - var q = from e in dbContext.Events - where e.Date >= _startDate && - (_beforeDate != null || e.Date <= _endDate) && - (_beforeDate == null || e.Date < _beforeDate.Value) && - e.ProviderId == _providerId && - e.ActingUserId == _actingUserId - orderby e.Date descending - select e; - return q.Skip(0).Take(_pageOptions.PageSize); + public EventReadPageByProviderIdActingUserIdQuery(Guid providerId, Guid actingUserId, + DateTime startDate, DateTime endDate, DateTime? beforeDate, PageOptions pageOptions) + { + _providerId = providerId; + _actingUserId = actingUserId; + _startDate = startDate; + _endDate = endDate; + _beforeDate = beforeDate; + _pageOptions = pageOptions; + } + + public IQueryable Run(DatabaseContext dbContext) + { + var q = from e in dbContext.Events + where e.Date >= _startDate && + (_beforeDate != null || e.Date <= _endDate) && + (_beforeDate == null || e.Date < _beforeDate.Value) && + e.ProviderId == _providerId && + e.ActingUserId == _actingUserId + orderby e.Date descending + select e; + return q.Skip(0).Take(_pageOptions.PageSize); + } } } diff --git a/src/Infrastructure.EntityFramework/Repositories/Queries/EventReadPageByProviderIdQuery.cs b/src/Infrastructure.EntityFramework/Repositories/Queries/EventReadPageByProviderIdQuery.cs index 52421b9e94..49e8f518b3 100644 --- a/src/Infrastructure.EntityFramework/Repositories/Queries/EventReadPageByProviderIdQuery.cs +++ b/src/Infrastructure.EntityFramework/Repositories/Queries/EventReadPageByProviderIdQuery.cs @@ -1,35 +1,36 @@ using Bit.Core.Models.Data; using Bit.Infrastructure.EntityFramework.Models; -namespace Bit.Infrastructure.EntityFramework.Repositories.Queries; - -public class EventReadPageByProviderIdQuery : IQuery +namespace Bit.Infrastructure.EntityFramework.Repositories.Queries { - private readonly Guid _providerId; - private readonly DateTime _startDate; - private readonly DateTime _endDate; - private readonly DateTime? _beforeDate; - private readonly PageOptions _pageOptions; - - public EventReadPageByProviderIdQuery(Guid providerId, DateTime startDate, - DateTime endDate, DateTime? beforeDate, PageOptions pageOptions) + public class EventReadPageByProviderIdQuery : IQuery { - _providerId = providerId; - _startDate = startDate; - _endDate = endDate; - _beforeDate = beforeDate; - _pageOptions = pageOptions; - } + private readonly Guid _providerId; + private readonly DateTime _startDate; + private readonly DateTime _endDate; + private readonly DateTime? _beforeDate; + private readonly PageOptions _pageOptions; - public IQueryable Run(DatabaseContext dbContext) - { - var q = from e in dbContext.Events - where e.Date >= _startDate && - (_beforeDate != null || e.Date <= _endDate) && - (_beforeDate == null || e.Date < _beforeDate.Value) && - e.ProviderId == _providerId && e.OrganizationId == null - orderby e.Date descending - select e; - return q.Skip(0).Take(_pageOptions.PageSize); + public EventReadPageByProviderIdQuery(Guid providerId, DateTime startDate, + DateTime endDate, DateTime? beforeDate, PageOptions pageOptions) + { + _providerId = providerId; + _startDate = startDate; + _endDate = endDate; + _beforeDate = beforeDate; + _pageOptions = pageOptions; + } + + public IQueryable Run(DatabaseContext dbContext) + { + var q = from e in dbContext.Events + where e.Date >= _startDate && + (_beforeDate != null || e.Date <= _endDate) && + (_beforeDate == null || e.Date < _beforeDate.Value) && + e.ProviderId == _providerId && e.OrganizationId == null + orderby e.Date descending + select e; + return q.Skip(0).Take(_pageOptions.PageSize); + } } } diff --git a/src/Infrastructure.EntityFramework/Repositories/Queries/EventReadPageByUserIdQuery.cs b/src/Infrastructure.EntityFramework/Repositories/Queries/EventReadPageByUserIdQuery.cs index d173c4842b..3e7ff4cc3c 100644 --- a/src/Infrastructure.EntityFramework/Repositories/Queries/EventReadPageByUserIdQuery.cs +++ b/src/Infrastructure.EntityFramework/Repositories/Queries/EventReadPageByUserIdQuery.cs @@ -1,36 +1,37 @@ using Bit.Core.Models.Data; using Bit.Infrastructure.EntityFramework.Models; -namespace Bit.Infrastructure.EntityFramework.Repositories.Queries; - -public class EventReadPageByUserIdQuery : IQuery +namespace Bit.Infrastructure.EntityFramework.Repositories.Queries { - private readonly Guid _userId; - private readonly DateTime _startDate; - private readonly DateTime _endDate; - private readonly DateTime? _beforeDate; - private readonly PageOptions _pageOptions; - - public EventReadPageByUserIdQuery(Guid userId, DateTime startDate, - DateTime endDate, DateTime? beforeDate, PageOptions pageOptions) + public class EventReadPageByUserIdQuery : IQuery { - _userId = userId; - _startDate = startDate; - _endDate = endDate; - _beforeDate = beforeDate; - _pageOptions = pageOptions; - } + private readonly Guid _userId; + private readonly DateTime _startDate; + private readonly DateTime _endDate; + private readonly DateTime? _beforeDate; + private readonly PageOptions _pageOptions; - public IQueryable Run(DatabaseContext dbContext) - { - var q = from e in dbContext.Events - where e.Date >= _startDate && - (_beforeDate != null || e.Date <= _endDate) && - (_beforeDate == null || e.Date < _beforeDate.Value) && - !e.OrganizationId.HasValue && - e.ActingUserId == _userId - orderby e.Date descending - select e; - return q.Skip(0).Take(_pageOptions.PageSize); + public EventReadPageByUserIdQuery(Guid userId, DateTime startDate, + DateTime endDate, DateTime? beforeDate, PageOptions pageOptions) + { + _userId = userId; + _startDate = startDate; + _endDate = endDate; + _beforeDate = beforeDate; + _pageOptions = pageOptions; + } + + public IQueryable Run(DatabaseContext dbContext) + { + var q = from e in dbContext.Events + where e.Date >= _startDate && + (_beforeDate != null || e.Date <= _endDate) && + (_beforeDate == null || e.Date < _beforeDate.Value) && + !e.OrganizationId.HasValue && + e.ActingUserId == _userId + orderby e.Date descending + select e; + return q.Skip(0).Take(_pageOptions.PageSize); + } } } diff --git a/src/Infrastructure.EntityFramework/Repositories/Queries/GroupUserUpdateGroupsQuery.cs b/src/Infrastructure.EntityFramework/Repositories/Queries/GroupUserUpdateGroupsQuery.cs index dacbabb280..580199cafe 100644 --- a/src/Infrastructure.EntityFramework/Repositories/Queries/GroupUserUpdateGroupsQuery.cs +++ b/src/Infrastructure.EntityFramework/Repositories/Queries/GroupUserUpdateGroupsQuery.cs @@ -1,68 +1,69 @@ using Bit.Infrastructure.EntityFramework.Models; -namespace Bit.Infrastructure.EntityFramework.Repositories.Queries; - -public class GroupUserUpdateGroupsQuery +namespace Bit.Infrastructure.EntityFramework.Repositories.Queries { - public readonly GroupUserUpdateGroupsInsertQuery Insert; - public readonly GroupUserUpdateGroupsDeleteQuery Delete; - - public GroupUserUpdateGroupsQuery(Guid organizationUserId, IEnumerable groupIds) + public class GroupUserUpdateGroupsQuery { - Insert = new GroupUserUpdateGroupsInsertQuery(organizationUserId, groupIds); - Delete = new GroupUserUpdateGroupsDeleteQuery(organizationUserId, groupIds); - } -} + public readonly GroupUserUpdateGroupsInsertQuery Insert; + public readonly GroupUserUpdateGroupsDeleteQuery Delete; -public class GroupUserUpdateGroupsInsertQuery : IQuery -{ - private readonly Guid _organizationUserId; - private readonly IEnumerable _groupIds; - - public GroupUserUpdateGroupsInsertQuery(Guid organizationUserId, IEnumerable collections) - { - _organizationUserId = organizationUserId; - _groupIds = collections; - } - - public IQueryable Run(DatabaseContext dbContext) - { - var orgUser = from ou in dbContext.OrganizationUsers - where ou.Id == _organizationUserId - select ou; - var groupIdEntities = dbContext.Groups.Where(x => _groupIds.Contains(x.Id)); - var query = from g in dbContext.Groups - join ou in orgUser - on g.OrganizationId equals ou.OrganizationId - join gie in groupIdEntities - on g.Id equals gie.Id - where !dbContext.GroupUsers.Any(gu => _groupIds.Contains(gu.GroupId) && gu.OrganizationUserId == _organizationUserId) - select g; - return query.Select(x => new GroupUser + public GroupUserUpdateGroupsQuery(Guid organizationUserId, IEnumerable groupIds) { - GroupId = x.Id, - OrganizationUserId = _organizationUserId, - }); - } -} - -public class GroupUserUpdateGroupsDeleteQuery : IQuery -{ - private readonly Guid _organizationUserId; - private readonly IEnumerable _groupIds; - - public GroupUserUpdateGroupsDeleteQuery(Guid organizationUserId, IEnumerable groupIds) - { - _organizationUserId = organizationUserId; - _groupIds = groupIds; - } - - public IQueryable Run(DatabaseContext dbContext) - { - var deleteQuery = from gu in dbContext.GroupUsers - where gu.OrganizationUserId == _organizationUserId && - !_groupIds.Any(x => gu.GroupId == x) - select gu; - return deleteQuery; + Insert = new GroupUserUpdateGroupsInsertQuery(organizationUserId, groupIds); + Delete = new GroupUserUpdateGroupsDeleteQuery(organizationUserId, groupIds); + } + } + + public class GroupUserUpdateGroupsInsertQuery : IQuery + { + private readonly Guid _organizationUserId; + private readonly IEnumerable _groupIds; + + public GroupUserUpdateGroupsInsertQuery(Guid organizationUserId, IEnumerable collections) + { + _organizationUserId = organizationUserId; + _groupIds = collections; + } + + public IQueryable Run(DatabaseContext dbContext) + { + var orgUser = from ou in dbContext.OrganizationUsers + where ou.Id == _organizationUserId + select ou; + var groupIdEntities = dbContext.Groups.Where(x => _groupIds.Contains(x.Id)); + var query = from g in dbContext.Groups + join ou in orgUser + on g.OrganizationId equals ou.OrganizationId + join gie in groupIdEntities + on g.Id equals gie.Id + where !dbContext.GroupUsers.Any(gu => _groupIds.Contains(gu.GroupId) && gu.OrganizationUserId == _organizationUserId) + select g; + return query.Select(x => new GroupUser + { + GroupId = x.Id, + OrganizationUserId = _organizationUserId, + }); + } + } + + public class GroupUserUpdateGroupsDeleteQuery : IQuery + { + private readonly Guid _organizationUserId; + private readonly IEnumerable _groupIds; + + public GroupUserUpdateGroupsDeleteQuery(Guid organizationUserId, IEnumerable groupIds) + { + _organizationUserId = organizationUserId; + _groupIds = groupIds; + } + + public IQueryable Run(DatabaseContext dbContext) + { + var deleteQuery = from gu in dbContext.GroupUsers + where gu.OrganizationUserId == _organizationUserId && + !_groupIds.Any(x => gu.GroupId == x) + select gu; + return deleteQuery; + } } } diff --git a/src/Infrastructure.EntityFramework/Repositories/Queries/IQuery.cs b/src/Infrastructure.EntityFramework/Repositories/Queries/IQuery.cs index 554efe0b72..8729f5b158 100644 --- a/src/Infrastructure.EntityFramework/Repositories/Queries/IQuery.cs +++ b/src/Infrastructure.EntityFramework/Repositories/Queries/IQuery.cs @@ -1,6 +1,7 @@ -namespace Bit.Infrastructure.EntityFramework.Repositories.Queries; - -public interface IQuery +namespace Bit.Infrastructure.EntityFramework.Repositories.Queries { - IQueryable Run(DatabaseContext dbContext); + public interface IQuery + { + IQueryable Run(DatabaseContext dbContext); + } } diff --git a/src/Infrastructure.EntityFramework/Repositories/Queries/OrganizationUserOrganizationDetailsViewQuery.cs b/src/Infrastructure.EntityFramework/Repositories/Queries/OrganizationUserOrganizationDetailsViewQuery.cs index 84dc4a7ad6..cd01229702 100644 --- a/src/Infrastructure.EntityFramework/Repositories/Queries/OrganizationUserOrganizationDetailsViewQuery.cs +++ b/src/Infrastructure.EntityFramework/Repositories/Queries/OrganizationUserOrganizationDetailsViewQuery.cs @@ -1,64 +1,65 @@ using Bit.Core.Models.Data.Organizations.OrganizationUsers; -namespace Bit.Infrastructure.EntityFramework.Repositories.Queries; - -public class OrganizationUserOrganizationDetailsViewQuery : IQuery +namespace Bit.Infrastructure.EntityFramework.Repositories.Queries { - public IQueryable Run(DatabaseContext dbContext) + public class OrganizationUserOrganizationDetailsViewQuery : IQuery { - var query = from ou in dbContext.OrganizationUsers - join o in dbContext.Organizations on ou.OrganizationId equals o.Id - join su in dbContext.SsoUsers on ou.UserId equals su.UserId into su_g - from su in su_g.DefaultIfEmpty() - join po in dbContext.ProviderOrganizations on o.Id equals po.OrganizationId into po_g - from po in po_g.DefaultIfEmpty() - join p in dbContext.Providers on po.ProviderId equals p.Id into p_g - from p in p_g.DefaultIfEmpty() - join os in dbContext.OrganizationSponsorships on ou.Id equals os.SponsoringOrganizationUserId into os_g - from os in os_g.DefaultIfEmpty() - join ss in dbContext.SsoConfigs on ou.OrganizationId equals ss.OrganizationId into ss_g - from ss in ss_g.DefaultIfEmpty() - where ((su == null || !su.OrganizationId.HasValue) || su.OrganizationId == ou.OrganizationId) - select new { ou, o, su, p, ss, os }; - - return query.Select(x => new OrganizationUserOrganizationDetails + public IQueryable Run(DatabaseContext dbContext) { - OrganizationId = x.ou.OrganizationId, - UserId = x.ou.UserId, - Name = x.o.Name, - Enabled = x.o.Enabled, - PlanType = x.o.PlanType, - UsePolicies = x.o.UsePolicies, - UseSso = x.o.UseSso, - UseKeyConnector = x.o.UseKeyConnector, - UseScim = x.o.UseScim, - UseGroups = x.o.UseGroups, - UseDirectory = x.o.UseDirectory, - UseEvents = x.o.UseEvents, - UseTotp = x.o.UseTotp, - Use2fa = x.o.Use2fa, - UseApi = x.o.UseApi, - SelfHost = x.o.SelfHost, - UsersGetPremium = x.o.UsersGetPremium, - Seats = x.o.Seats, - MaxCollections = x.o.MaxCollections, - MaxStorageGb = x.o.MaxStorageGb, - Identifier = x.o.Identifier, - Key = x.ou.Key, - ResetPasswordKey = x.ou.ResetPasswordKey, - Status = x.ou.Status, - Type = x.ou.Type, - SsoExternalId = x.su.ExternalId, - Permissions = x.ou.Permissions, - PublicKey = x.o.PublicKey, - PrivateKey = x.o.PrivateKey, - ProviderId = x.p.Id, - ProviderName = x.p.Name, - SsoConfig = x.ss.Data, - FamilySponsorshipFriendlyName = x.os.FriendlyName, - FamilySponsorshipLastSyncDate = x.os.LastSyncDate, - FamilySponsorshipToDelete = x.os.ToDelete, - FamilySponsorshipValidUntil = x.os.ValidUntil - }); + var query = from ou in dbContext.OrganizationUsers + join o in dbContext.Organizations on ou.OrganizationId equals o.Id + join su in dbContext.SsoUsers on ou.UserId equals su.UserId into su_g + from su in su_g.DefaultIfEmpty() + join po in dbContext.ProviderOrganizations on o.Id equals po.OrganizationId into po_g + from po in po_g.DefaultIfEmpty() + join p in dbContext.Providers on po.ProviderId equals p.Id into p_g + from p in p_g.DefaultIfEmpty() + join os in dbContext.OrganizationSponsorships on ou.Id equals os.SponsoringOrganizationUserId into os_g + from os in os_g.DefaultIfEmpty() + join ss in dbContext.SsoConfigs on ou.OrganizationId equals ss.OrganizationId into ss_g + from ss in ss_g.DefaultIfEmpty() + where ((su == null || !su.OrganizationId.HasValue) || su.OrganizationId == ou.OrganizationId) + select new { ou, o, su, p, ss, os }; + + return query.Select(x => new OrganizationUserOrganizationDetails + { + OrganizationId = x.ou.OrganizationId, + UserId = x.ou.UserId, + Name = x.o.Name, + Enabled = x.o.Enabled, + PlanType = x.o.PlanType, + UsePolicies = x.o.UsePolicies, + UseSso = x.o.UseSso, + UseKeyConnector = x.o.UseKeyConnector, + UseScim = x.o.UseScim, + UseGroups = x.o.UseGroups, + UseDirectory = x.o.UseDirectory, + UseEvents = x.o.UseEvents, + UseTotp = x.o.UseTotp, + Use2fa = x.o.Use2fa, + UseApi = x.o.UseApi, + SelfHost = x.o.SelfHost, + UsersGetPremium = x.o.UsersGetPremium, + Seats = x.o.Seats, + MaxCollections = x.o.MaxCollections, + MaxStorageGb = x.o.MaxStorageGb, + Identifier = x.o.Identifier, + Key = x.ou.Key, + ResetPasswordKey = x.ou.ResetPasswordKey, + Status = x.ou.Status, + Type = x.ou.Type, + SsoExternalId = x.su.ExternalId, + Permissions = x.ou.Permissions, + PublicKey = x.o.PublicKey, + PrivateKey = x.o.PrivateKey, + ProviderId = x.p.Id, + ProviderName = x.p.Name, + SsoConfig = x.ss.Data, + FamilySponsorshipFriendlyName = x.os.FriendlyName, + FamilySponsorshipLastSyncDate = x.os.LastSyncDate, + FamilySponsorshipToDelete = x.os.ToDelete, + FamilySponsorshipValidUntil = x.os.ValidUntil + }); + } } } diff --git a/src/Infrastructure.EntityFramework/Repositories/Queries/OrganizationUserReadCountByFreeOrganizationAdminUserQuery.cs b/src/Infrastructure.EntityFramework/Repositories/Queries/OrganizationUserReadCountByFreeOrganizationAdminUserQuery.cs index c1656d3dfd..26c66fce9c 100644 --- a/src/Infrastructure.EntityFramework/Repositories/Queries/OrganizationUserReadCountByFreeOrganizationAdminUserQuery.cs +++ b/src/Infrastructure.EntityFramework/Repositories/Queries/OrganizationUserReadCountByFreeOrganizationAdminUserQuery.cs @@ -1,28 +1,29 @@ using Bit.Core.Enums; using Bit.Infrastructure.EntityFramework.Models; -namespace Bit.Infrastructure.EntityFramework.Repositories.Queries; - -public class OrganizationUserReadCountByFreeOrganizationAdminUserQuery : IQuery +namespace Bit.Infrastructure.EntityFramework.Repositories.Queries { - private readonly Guid _userId; - - public OrganizationUserReadCountByFreeOrganizationAdminUserQuery(Guid userId) + public class OrganizationUserReadCountByFreeOrganizationAdminUserQuery : IQuery { - _userId = userId; - } + private readonly Guid _userId; - public IQueryable Run(DatabaseContext dbContext) - { - var query = from ou in dbContext.OrganizationUsers - join o in dbContext.Organizations - on ou.OrganizationId equals o.Id - where ou.UserId == _userId && - (ou.Type == OrganizationUserType.Owner || ou.Type == OrganizationUserType.Admin) && - o.PlanType == PlanType.Free && - ou.Status == OrganizationUserStatusType.Confirmed - select ou; + public OrganizationUserReadCountByFreeOrganizationAdminUserQuery(Guid userId) + { + _userId = userId; + } - return query; + public IQueryable Run(DatabaseContext dbContext) + { + var query = from ou in dbContext.OrganizationUsers + join o in dbContext.Organizations + on ou.OrganizationId equals o.Id + where ou.UserId == _userId && + (ou.Type == OrganizationUserType.Owner || ou.Type == OrganizationUserType.Admin) && + o.PlanType == PlanType.Free && + ou.Status == OrganizationUserStatusType.Confirmed + select ou; + + return query; + } } } diff --git a/src/Infrastructure.EntityFramework/Repositories/Queries/OrganizationUserReadCountByOnlyOwnerQuery.cs b/src/Infrastructure.EntityFramework/Repositories/Queries/OrganizationUserReadCountByOnlyOwnerQuery.cs index 53977e8b74..6cf8e3c3b7 100644 --- a/src/Infrastructure.EntityFramework/Repositories/Queries/OrganizationUserReadCountByOnlyOwnerQuery.cs +++ b/src/Infrastructure.EntityFramework/Repositories/Queries/OrganizationUserReadCountByOnlyOwnerQuery.cs @@ -1,36 +1,37 @@ using Bit.Core.Enums; using Bit.Infrastructure.EntityFramework.Models; -namespace Bit.Infrastructure.EntityFramework.Repositories.Queries; - -public class OrganizationUserReadCountByOnlyOwnerQuery : IQuery +namespace Bit.Infrastructure.EntityFramework.Repositories.Queries { - private readonly Guid _userId; - - public OrganizationUserReadCountByOnlyOwnerQuery(Guid userId) + public class OrganizationUserReadCountByOnlyOwnerQuery : IQuery { - _userId = userId; - } + private readonly Guid _userId; - public IQueryable Run(DatabaseContext dbContext) - { - var owners = from ou in dbContext.OrganizationUsers - where ou.Type == OrganizationUserType.Owner && - ou.Status == OrganizationUserStatusType.Confirmed - group ou by ou.OrganizationId into g - select new - { - OrgUser = g.Select(x => new { x.UserId, x.Id }).FirstOrDefault(), - ConfirmedOwnerCount = g.Count(), - }; + public OrganizationUserReadCountByOnlyOwnerQuery(Guid userId) + { + _userId = userId; + } - var query = from owner in owners - join ou in dbContext.OrganizationUsers - on owner.OrgUser.Id equals ou.Id - where owner.OrgUser.UserId == _userId && - owner.ConfirmedOwnerCount == 1 - select ou; + public IQueryable Run(DatabaseContext dbContext) + { + var owners = from ou in dbContext.OrganizationUsers + where ou.Type == OrganizationUserType.Owner && + ou.Status == OrganizationUserStatusType.Confirmed + group ou by ou.OrganizationId into g + select new + { + OrgUser = g.Select(x => new { x.UserId, x.Id }).FirstOrDefault(), + ConfirmedOwnerCount = g.Count(), + }; - return query; + var query = from owner in owners + join ou in dbContext.OrganizationUsers + on owner.OrgUser.Id equals ou.Id + where owner.OrgUser.UserId == _userId && + owner.ConfirmedOwnerCount == 1 + select ou; + + return query; + } } } diff --git a/src/Infrastructure.EntityFramework/Repositories/Queries/OrganizationUserReadCountByOrganizationIdEmailQuery.cs b/src/Infrastructure.EntityFramework/Repositories/Queries/OrganizationUserReadCountByOrganizationIdEmailQuery.cs index 0cb2abc46a..ed4c786c74 100644 --- a/src/Infrastructure.EntityFramework/Repositories/Queries/OrganizationUserReadCountByOrganizationIdEmailQuery.cs +++ b/src/Infrastructure.EntityFramework/Repositories/Queries/OrganizationUserReadCountByOrganizationIdEmailQuery.cs @@ -1,30 +1,31 @@ using Bit.Infrastructure.EntityFramework.Models; -namespace Bit.Infrastructure.EntityFramework.Repositories.Queries; - -public class OrganizationUserReadCountByOrganizationIdEmailQuery : IQuery +namespace Bit.Infrastructure.EntityFramework.Repositories.Queries { - private readonly Guid _organizationId; - private readonly string _email; - private readonly bool _onlyUsers; - - public OrganizationUserReadCountByOrganizationIdEmailQuery(Guid organizationId, string email, bool onlyUsers) + public class OrganizationUserReadCountByOrganizationIdEmailQuery : IQuery { - _organizationId = organizationId; - _email = email; - _onlyUsers = onlyUsers; - } + private readonly Guid _organizationId; + private readonly string _email; + private readonly bool _onlyUsers; - public IQueryable Run(DatabaseContext dbContext) - { - var query = from ou in dbContext.OrganizationUsers - join u in dbContext.Users - on ou.UserId equals u.Id into u_g - from u in u_g.DefaultIfEmpty() - where ou.OrganizationId == _organizationId && - ((!_onlyUsers && (ou.Email == _email || u.Email == _email)) - || (_onlyUsers && u.Email == _email)) - select ou; - return query; + public OrganizationUserReadCountByOrganizationIdEmailQuery(Guid organizationId, string email, bool onlyUsers) + { + _organizationId = organizationId; + _email = email; + _onlyUsers = onlyUsers; + } + + public IQueryable Run(DatabaseContext dbContext) + { + var query = from ou in dbContext.OrganizationUsers + join u in dbContext.Users + on ou.UserId equals u.Id into u_g + from u in u_g.DefaultIfEmpty() + where ou.OrganizationId == _organizationId && + ((!_onlyUsers && (ou.Email == _email || u.Email == _email)) + || (_onlyUsers && u.Email == _email)) + select ou; + return query; + } } } diff --git a/src/Infrastructure.EntityFramework/Repositories/Queries/OrganizationUserReadCountByOrganizationIdQuery.cs b/src/Infrastructure.EntityFramework/Repositories/Queries/OrganizationUserReadCountByOrganizationIdQuery.cs index a4ab7cb85d..05c6dd0496 100644 --- a/src/Infrastructure.EntityFramework/Repositories/Queries/OrganizationUserReadCountByOrganizationIdQuery.cs +++ b/src/Infrastructure.EntityFramework/Repositories/Queries/OrganizationUserReadCountByOrganizationIdQuery.cs @@ -1,21 +1,22 @@ using Bit.Infrastructure.EntityFramework.Models; -namespace Bit.Infrastructure.EntityFramework.Repositories.Queries; - -public class OrganizationUserReadCountByOrganizationIdQuery : IQuery +namespace Bit.Infrastructure.EntityFramework.Repositories.Queries { - private readonly Guid _organizationId; - - public OrganizationUserReadCountByOrganizationIdQuery(Guid organizationId) + public class OrganizationUserReadCountByOrganizationIdQuery : IQuery { - _organizationId = organizationId; - } + private readonly Guid _organizationId; - public IQueryable Run(DatabaseContext dbContext) - { - var query = from ou in dbContext.OrganizationUsers - where ou.OrganizationId == _organizationId - select ou; - return query; + public OrganizationUserReadCountByOrganizationIdQuery(Guid organizationId) + { + _organizationId = organizationId; + } + + public IQueryable Run(DatabaseContext dbContext) + { + var query = from ou in dbContext.OrganizationUsers + where ou.OrganizationId == _organizationId + select ou; + return query; + } } } diff --git a/src/Infrastructure.EntityFramework/Repositories/Queries/OrganizationUserUpdateWithCollectionsQuery.cs b/src/Infrastructure.EntityFramework/Repositories/Queries/OrganizationUserUpdateWithCollectionsQuery.cs index 0a21514d62..10dbf88d7d 100644 --- a/src/Infrastructure.EntityFramework/Repositories/Queries/OrganizationUserUpdateWithCollectionsQuery.cs +++ b/src/Infrastructure.EntityFramework/Repositories/Queries/OrganizationUserUpdateWithCollectionsQuery.cs @@ -2,104 +2,105 @@ using Bit.Core.Models.Data; using CollectionUser = Bit.Infrastructure.EntityFramework.Models.CollectionUser; -namespace Bit.Infrastructure.EntityFramework.Repositories.Queries; - -public class OrganizationUserUpdateWithCollectionsQuery +namespace Bit.Infrastructure.EntityFramework.Repositories.Queries { - public OrganizationUserUpdateWithCollectionsInsertQuery Insert { get; set; } - public OrganizationUserUpdateWithCollectionsUpdateQuery Update { get; set; } - public OrganizationUserUpdateWithCollectionsDeleteQuery Delete { get; set; } - - public OrganizationUserUpdateWithCollectionsQuery(OrganizationUser organizationUser, - IEnumerable collections) + public class OrganizationUserUpdateWithCollectionsQuery { - Insert = new OrganizationUserUpdateWithCollectionsInsertQuery(organizationUser, collections); - Update = new OrganizationUserUpdateWithCollectionsUpdateQuery(organizationUser, collections); - Delete = new OrganizationUserUpdateWithCollectionsDeleteQuery(organizationUser, collections); - } -} + public OrganizationUserUpdateWithCollectionsInsertQuery Insert { get; set; } + public OrganizationUserUpdateWithCollectionsUpdateQuery Update { get; set; } + public OrganizationUserUpdateWithCollectionsDeleteQuery Delete { get; set; } -public class OrganizationUserUpdateWithCollectionsInsertQuery : IQuery -{ - private readonly OrganizationUser _organizationUser; - private readonly IEnumerable _collections; - - public OrganizationUserUpdateWithCollectionsInsertQuery(OrganizationUser organizationUser, IEnumerable collections) - { - _organizationUser = organizationUser; - _collections = collections; - } - - public IQueryable Run(DatabaseContext dbContext) - { - var collectionIds = _collections.Select(c => c.Id).ToArray(); - var t = (from cu in dbContext.CollectionUsers - where collectionIds.Contains(cu.CollectionId) && - cu.OrganizationUserId == _organizationUser.Id - select cu).AsEnumerable(); - var insertQuery = (from c in dbContext.Collections - where collectionIds.Contains(c.Id) && - c.OrganizationId == _organizationUser.OrganizationId && - !t.Any() - select c).AsEnumerable(); - return insertQuery.Select(x => new CollectionUser + public OrganizationUserUpdateWithCollectionsQuery(OrganizationUser organizationUser, + IEnumerable collections) { - CollectionId = x.Id, - OrganizationUserId = _organizationUser.Id, - ReadOnly = _collections.FirstOrDefault(c => c.Id == x.Id).ReadOnly, - HidePasswords = _collections.FirstOrDefault(c => c.Id == x.Id).HidePasswords, - }).AsQueryable(); - } -} - -public class OrganizationUserUpdateWithCollectionsUpdateQuery : IQuery -{ - private readonly OrganizationUser _organizationUser; - private readonly IEnumerable _collections; - - public OrganizationUserUpdateWithCollectionsUpdateQuery(OrganizationUser organizationUser, IEnumerable collections) - { - _organizationUser = organizationUser; - _collections = collections; + Insert = new OrganizationUserUpdateWithCollectionsInsertQuery(organizationUser, collections); + Update = new OrganizationUserUpdateWithCollectionsUpdateQuery(organizationUser, collections); + Delete = new OrganizationUserUpdateWithCollectionsDeleteQuery(organizationUser, collections); + } } - public IQueryable Run(DatabaseContext dbContext) + public class OrganizationUserUpdateWithCollectionsInsertQuery : IQuery { - var collectionIds = _collections.Select(c => c.Id).ToArray(); - var updateQuery = (from target in dbContext.CollectionUsers - where collectionIds.Contains(target.CollectionId) && - target.OrganizationUserId == _organizationUser.Id - select new { target }).AsEnumerable(); - updateQuery = updateQuery.Where(cu => - cu.target.ReadOnly == _collections.FirstOrDefault(u => u.Id == cu.target.CollectionId).ReadOnly && - cu.target.HidePasswords == _collections.FirstOrDefault(u => u.Id == cu.target.CollectionId).HidePasswords); - return updateQuery.Select(x => new CollectionUser + private readonly OrganizationUser _organizationUser; + private readonly IEnumerable _collections; + + public OrganizationUserUpdateWithCollectionsInsertQuery(OrganizationUser organizationUser, IEnumerable collections) { - CollectionId = x.target.CollectionId, - OrganizationUserId = _organizationUser.Id, - ReadOnly = x.target.ReadOnly, - HidePasswords = x.target.HidePasswords, - }).AsQueryable(); - } -} - -public class OrganizationUserUpdateWithCollectionsDeleteQuery : IQuery -{ - private readonly OrganizationUser _organizationUser; - private readonly IEnumerable _collections; - - public OrganizationUserUpdateWithCollectionsDeleteQuery(OrganizationUser organizationUser, IEnumerable collections) - { - _organizationUser = organizationUser; - _collections = collections; - } - - public IQueryable Run(DatabaseContext dbContext) - { - var deleteQuery = from cu in dbContext.CollectionUsers - where !_collections.Any( - c => c.Id == cu.CollectionId) - select cu; - return deleteQuery; + _organizationUser = organizationUser; + _collections = collections; + } + + public IQueryable Run(DatabaseContext dbContext) + { + var collectionIds = _collections.Select(c => c.Id).ToArray(); + var t = (from cu in dbContext.CollectionUsers + where collectionIds.Contains(cu.CollectionId) && + cu.OrganizationUserId == _organizationUser.Id + select cu).AsEnumerable(); + var insertQuery = (from c in dbContext.Collections + where collectionIds.Contains(c.Id) && + c.OrganizationId == _organizationUser.OrganizationId && + !t.Any() + select c).AsEnumerable(); + return insertQuery.Select(x => new CollectionUser + { + CollectionId = x.Id, + OrganizationUserId = _organizationUser.Id, + ReadOnly = _collections.FirstOrDefault(c => c.Id == x.Id).ReadOnly, + HidePasswords = _collections.FirstOrDefault(c => c.Id == x.Id).HidePasswords, + }).AsQueryable(); + } + } + + public class OrganizationUserUpdateWithCollectionsUpdateQuery : IQuery + { + private readonly OrganizationUser _organizationUser; + private readonly IEnumerable _collections; + + public OrganizationUserUpdateWithCollectionsUpdateQuery(OrganizationUser organizationUser, IEnumerable collections) + { + _organizationUser = organizationUser; + _collections = collections; + } + + public IQueryable Run(DatabaseContext dbContext) + { + var collectionIds = _collections.Select(c => c.Id).ToArray(); + var updateQuery = (from target in dbContext.CollectionUsers + where collectionIds.Contains(target.CollectionId) && + target.OrganizationUserId == _organizationUser.Id + select new { target }).AsEnumerable(); + updateQuery = updateQuery.Where(cu => + cu.target.ReadOnly == _collections.FirstOrDefault(u => u.Id == cu.target.CollectionId).ReadOnly && + cu.target.HidePasswords == _collections.FirstOrDefault(u => u.Id == cu.target.CollectionId).HidePasswords); + return updateQuery.Select(x => new CollectionUser + { + CollectionId = x.target.CollectionId, + OrganizationUserId = _organizationUser.Id, + ReadOnly = x.target.ReadOnly, + HidePasswords = x.target.HidePasswords, + }).AsQueryable(); + } + } + + public class OrganizationUserUpdateWithCollectionsDeleteQuery : IQuery + { + private readonly OrganizationUser _organizationUser; + private readonly IEnumerable _collections; + + public OrganizationUserUpdateWithCollectionsDeleteQuery(OrganizationUser organizationUser, IEnumerable collections) + { + _organizationUser = organizationUser; + _collections = collections; + } + + public IQueryable Run(DatabaseContext dbContext) + { + var deleteQuery = from cu in dbContext.CollectionUsers + where !_collections.Any( + c => c.Id == cu.CollectionId) + select cu; + return deleteQuery; + } } } diff --git a/src/Infrastructure.EntityFramework/Repositories/Queries/OrganizationUserUserViewQuery.cs b/src/Infrastructure.EntityFramework/Repositories/Queries/OrganizationUserUserViewQuery.cs index 2a5bf06fd1..248957196e 100644 --- a/src/Infrastructure.EntityFramework/Repositories/Queries/OrganizationUserUserViewQuery.cs +++ b/src/Infrastructure.EntityFramework/Repositories/Queries/OrganizationUserUserViewQuery.cs @@ -1,34 +1,35 @@ using Bit.Core.Models.Data.Organizations.OrganizationUsers; -namespace Bit.Infrastructure.EntityFramework.Repositories.Queries; - -public class OrganizationUserUserDetailsViewQuery : IQuery +namespace Bit.Infrastructure.EntityFramework.Repositories.Queries { - public IQueryable Run(DatabaseContext dbContext) + public class OrganizationUserUserDetailsViewQuery : IQuery { - var query = from ou in dbContext.OrganizationUsers - join u in dbContext.Users on ou.UserId equals u.Id into u_g - from u in u_g.DefaultIfEmpty() - join su in dbContext.SsoUsers on u.Id equals su.UserId into su_g - from su in su_g.DefaultIfEmpty() - select new { ou, u, su }; - return query.Select(x => new OrganizationUserUserDetails + public IQueryable Run(DatabaseContext dbContext) { - Id = x.ou.Id, - OrganizationId = x.ou.OrganizationId, - UserId = x.ou.UserId, - Name = x.u.Name, - Email = x.u.Email ?? x.ou.Email, - TwoFactorProviders = x.u.TwoFactorProviders, - Premium = x.u.Premium, - Status = x.ou.Status, - Type = x.ou.Type, - AccessAll = x.ou.AccessAll, - ExternalId = x.ou.ExternalId, - SsoExternalId = x.su.ExternalId, - Permissions = x.ou.Permissions, - ResetPasswordKey = x.ou.ResetPasswordKey, - UsesKeyConnector = x.u != null && x.u.UsesKeyConnector, - }); + var query = from ou in dbContext.OrganizationUsers + join u in dbContext.Users on ou.UserId equals u.Id into u_g + from u in u_g.DefaultIfEmpty() + join su in dbContext.SsoUsers on u.Id equals su.UserId into su_g + from su in su_g.DefaultIfEmpty() + select new { ou, u, su }; + return query.Select(x => new OrganizationUserUserDetails + { + Id = x.ou.Id, + OrganizationId = x.ou.OrganizationId, + UserId = x.ou.UserId, + Name = x.u.Name, + Email = x.u.Email ?? x.ou.Email, + TwoFactorProviders = x.u.TwoFactorProviders, + Premium = x.u.Premium, + Status = x.ou.Status, + Type = x.ou.Type, + AccessAll = x.ou.AccessAll, + ExternalId = x.ou.ExternalId, + SsoExternalId = x.su.ExternalId, + Permissions = x.ou.Permissions, + ResetPasswordKey = x.ou.ResetPasswordKey, + UsesKeyConnector = x.u != null && x.u.UsesKeyConnector, + }); + } } } diff --git a/src/Infrastructure.EntityFramework/Repositories/Queries/PolicyReadByTypeApplicableToUserQuery.cs b/src/Infrastructure.EntityFramework/Repositories/Queries/PolicyReadByTypeApplicableToUserQuery.cs index 21e5f9a281..bd69570ec5 100644 --- a/src/Infrastructure.EntityFramework/Repositories/Queries/PolicyReadByTypeApplicableToUserQuery.cs +++ b/src/Infrastructure.EntityFramework/Repositories/Queries/PolicyReadByTypeApplicableToUserQuery.cs @@ -1,50 +1,51 @@ using Bit.Core.Enums; using Bit.Infrastructure.EntityFramework.Models; -namespace Bit.Infrastructure.EntityFramework.Repositories.Queries; - -public class PolicyReadByTypeApplicableToUserQuery : IQuery +namespace Bit.Infrastructure.EntityFramework.Repositories.Queries { - private readonly Guid _userId; - private readonly PolicyType _policyType; - private readonly OrganizationUserStatusType _minimumStatus; - - public PolicyReadByTypeApplicableToUserQuery(Guid userId, PolicyType policyType, OrganizationUserStatusType minimumStatus) + public class PolicyReadByTypeApplicableToUserQuery : IQuery { - _userId = userId; - _policyType = policyType; - _minimumStatus = minimumStatus; - } + private readonly Guid _userId; + private readonly PolicyType _policyType; + private readonly OrganizationUserStatusType _minimumStatus; - public IQueryable Run(DatabaseContext dbContext) - { - var providerOrganizations = from pu in dbContext.ProviderUsers - where pu.UserId == _userId - join po in dbContext.ProviderOrganizations - on pu.ProviderId equals po.ProviderId - select po; - - string userEmail = null; - if (_minimumStatus == OrganizationUserStatusType.Invited) + public PolicyReadByTypeApplicableToUserQuery(Guid userId, PolicyType policyType, OrganizationUserStatusType minimumStatus) { - // Invited orgUsers do not have a UserId associated with them, so we have to match up their email - userEmail = dbContext.Users.Find(_userId)?.Email; + _userId = userId; + _policyType = policyType; + _minimumStatus = minimumStatus; } - var query = from p in dbContext.Policies - join ou in dbContext.OrganizationUsers - on p.OrganizationId equals ou.OrganizationId - where - ((_minimumStatus > OrganizationUserStatusType.Invited && ou.UserId == _userId) || - (_minimumStatus == OrganizationUserStatusType.Invited && ou.Email == userEmail)) && - p.Type == _policyType && - p.Enabled && - ou.Status >= _minimumStatus && - ou.Type >= OrganizationUserType.User && - (ou.Permissions == null || - ou.Permissions.Contains($"\"managePolicies\":false")) && - !providerOrganizations.Any(po => po.OrganizationId == p.OrganizationId) - select p; - return query; + public IQueryable Run(DatabaseContext dbContext) + { + var providerOrganizations = from pu in dbContext.ProviderUsers + where pu.UserId == _userId + join po in dbContext.ProviderOrganizations + on pu.ProviderId equals po.ProviderId + select po; + + string userEmail = null; + if (_minimumStatus == OrganizationUserStatusType.Invited) + { + // Invited orgUsers do not have a UserId associated with them, so we have to match up their email + userEmail = dbContext.Users.Find(_userId)?.Email; + } + + var query = from p in dbContext.Policies + join ou in dbContext.OrganizationUsers + on p.OrganizationId equals ou.OrganizationId + where + ((_minimumStatus > OrganizationUserStatusType.Invited && ou.UserId == _userId) || + (_minimumStatus == OrganizationUserStatusType.Invited && ou.Email == userEmail)) && + p.Type == _policyType && + p.Enabled && + ou.Status >= _minimumStatus && + ou.Type >= OrganizationUserType.User && + (ou.Permissions == null || + ou.Permissions.Contains($"\"managePolicies\":false")) && + !providerOrganizations.Any(po => po.OrganizationId == p.OrganizationId) + select p; + return query; + } } } diff --git a/src/Infrastructure.EntityFramework/Repositories/Queries/PolicyReadByUserIdQuery.cs b/src/Infrastructure.EntityFramework/Repositories/Queries/PolicyReadByUserIdQuery.cs index 58c06395a9..e910e2f75d 100644 --- a/src/Infrastructure.EntityFramework/Repositories/Queries/PolicyReadByUserIdQuery.cs +++ b/src/Infrastructure.EntityFramework/Repositories/Queries/PolicyReadByUserIdQuery.cs @@ -1,29 +1,30 @@ using Bit.Core.Enums; using Bit.Infrastructure.EntityFramework.Models; -namespace Bit.Infrastructure.EntityFramework.Repositories.Queries; - -public class PolicyReadByUserIdQuery : IQuery +namespace Bit.Infrastructure.EntityFramework.Repositories.Queries { - private readonly Guid _userId; - - public PolicyReadByUserIdQuery(Guid userId) + public class PolicyReadByUserIdQuery : IQuery { - _userId = userId; - } + private readonly Guid _userId; - public IQueryable Run(DatabaseContext dbContext) - { - var query = from p in dbContext.Policies - join ou in dbContext.OrganizationUsers - on p.OrganizationId equals ou.OrganizationId - join o in dbContext.Organizations - on ou.OrganizationId equals o.Id - where ou.UserId == _userId && - ou.Status == OrganizationUserStatusType.Confirmed && - o.Enabled == true - select p; + public PolicyReadByUserIdQuery(Guid userId) + { + _userId = userId; + } - return query; + public IQueryable Run(DatabaseContext dbContext) + { + var query = from p in dbContext.Policies + join ou in dbContext.OrganizationUsers + on p.OrganizationId equals ou.OrganizationId + join o in dbContext.Organizations + on ou.OrganizationId equals o.Id + where ou.UserId == _userId && + ou.Status == OrganizationUserStatusType.Confirmed && + o.Enabled == true + select p; + + return query; + } } } diff --git a/src/Infrastructure.EntityFramework/Repositories/Queries/ProviderOrganizationOrganizationDetailsReadByProviderIdQuery.cs b/src/Infrastructure.EntityFramework/Repositories/Queries/ProviderOrganizationOrganizationDetailsReadByProviderIdQuery.cs index 1429a136cc..03ada03c34 100644 --- a/src/Infrastructure.EntityFramework/Repositories/Queries/ProviderOrganizationOrganizationDetailsReadByProviderIdQuery.cs +++ b/src/Infrastructure.EntityFramework/Repositories/Queries/ProviderOrganizationOrganizationDetailsReadByProviderIdQuery.cs @@ -1,37 +1,38 @@ using Bit.Core.Models.Data; -namespace Bit.Infrastructure.EntityFramework.Repositories.Queries; - -public class ProviderOrganizationOrganizationDetailsReadByProviderIdQuery : IQuery +namespace Bit.Infrastructure.EntityFramework.Repositories.Queries { - private readonly Guid _providerId; - public ProviderOrganizationOrganizationDetailsReadByProviderIdQuery(Guid providerId) + public class ProviderOrganizationOrganizationDetailsReadByProviderIdQuery : IQuery { - _providerId = providerId; - } - - public IQueryable Run(DatabaseContext dbContext) - { - var query = from po in dbContext.ProviderOrganizations - join o in dbContext.Organizations - on po.OrganizationId equals o.Id - join ou in dbContext.OrganizationUsers - on po.OrganizationId equals ou.OrganizationId - where po.ProviderId == _providerId - select new { po, o }; - return query.Select(x => new ProviderOrganizationOrganizationDetails() + private readonly Guid _providerId; + public ProviderOrganizationOrganizationDetailsReadByProviderIdQuery(Guid providerId) { - Id = x.po.Id, - ProviderId = x.po.ProviderId, - OrganizationId = x.po.OrganizationId, - OrganizationName = x.o.Name, - Key = x.po.Key, - Settings = x.po.Settings, - CreationDate = x.po.CreationDate, - RevisionDate = x.po.RevisionDate, - UserCount = x.o.OrganizationUsers.Count(ou => ou.Status == Core.Enums.OrganizationUserStatusType.Confirmed), - Seats = x.o.Seats, - Plan = x.o.Plan - }); + _providerId = providerId; + } + + public IQueryable Run(DatabaseContext dbContext) + { + var query = from po in dbContext.ProviderOrganizations + join o in dbContext.Organizations + on po.OrganizationId equals o.Id + join ou in dbContext.OrganizationUsers + on po.OrganizationId equals ou.OrganizationId + where po.ProviderId == _providerId + select new { po, o }; + return query.Select(x => new ProviderOrganizationOrganizationDetails() + { + Id = x.po.Id, + ProviderId = x.po.ProviderId, + OrganizationId = x.po.OrganizationId, + OrganizationName = x.o.Name, + Key = x.po.Key, + Settings = x.po.Settings, + CreationDate = x.po.CreationDate, + RevisionDate = x.po.RevisionDate, + UserCount = x.o.OrganizationUsers.Count(ou => ou.Status == Core.Enums.OrganizationUserStatusType.Confirmed), + Seats = x.o.Seats, + Plan = x.o.Plan + }); + } } } diff --git a/src/Infrastructure.EntityFramework/Repositories/Queries/ProviderUserOrganizationDetailsViewQuery.cs b/src/Infrastructure.EntityFramework/Repositories/Queries/ProviderUserOrganizationDetailsViewQuery.cs index dfd5f61924..8f3e71861f 100644 --- a/src/Infrastructure.EntityFramework/Repositories/Queries/ProviderUserOrganizationDetailsViewQuery.cs +++ b/src/Infrastructure.EntityFramework/Repositories/Queries/ProviderUserOrganizationDetailsViewQuery.cs @@ -1,45 +1,46 @@ using Bit.Core.Models.Data; -namespace Bit.Infrastructure.EntityFramework.Repositories.Queries; - -public class ProviderUserOrganizationDetailsViewQuery : IQuery +namespace Bit.Infrastructure.EntityFramework.Repositories.Queries { - public IQueryable Run(DatabaseContext dbContext) + public class ProviderUserOrganizationDetailsViewQuery : IQuery { - var query = from pu in dbContext.ProviderUsers - join po in dbContext.ProviderOrganizations on pu.ProviderId equals po.ProviderId - join o in dbContext.Organizations on po.OrganizationId equals o.Id - join p in dbContext.Providers on pu.ProviderId equals p.Id - select new { pu, po, o, p }; - return query.Select(x => new ProviderUserOrganizationDetails + public IQueryable Run(DatabaseContext dbContext) { - OrganizationId = x.po.OrganizationId, - UserId = x.pu.UserId, - Name = x.o.Name, - Enabled = x.o.Enabled, - UsePolicies = x.o.UsePolicies, - UseSso = x.o.UseSso, - UseKeyConnector = x.o.UseKeyConnector, - UseScim = x.o.UseScim, - UseGroups = x.o.UseGroups, - UseDirectory = x.o.UseDirectory, - UseEvents = x.o.UseEvents, - UseTotp = x.o.UseTotp, - Use2fa = x.o.Use2fa, - UseApi = x.o.UseApi, - SelfHost = x.o.SelfHost, - UsersGetPremium = x.o.UsersGetPremium, - Seats = x.o.Seats, - MaxCollections = x.o.MaxCollections, - MaxStorageGb = x.o.MaxStorageGb, - Identifier = x.o.Identifier, - Key = x.po.Key, - Status = x.pu.Status, - Type = x.pu.Type, - PublicKey = x.o.PublicKey, - PrivateKey = x.o.PrivateKey, - ProviderId = x.p.Id, - ProviderName = x.p.Name, - }); + var query = from pu in dbContext.ProviderUsers + join po in dbContext.ProviderOrganizations on pu.ProviderId equals po.ProviderId + join o in dbContext.Organizations on po.OrganizationId equals o.Id + join p in dbContext.Providers on pu.ProviderId equals p.Id + select new { pu, po, o, p }; + return query.Select(x => new ProviderUserOrganizationDetails + { + OrganizationId = x.po.OrganizationId, + UserId = x.pu.UserId, + Name = x.o.Name, + Enabled = x.o.Enabled, + UsePolicies = x.o.UsePolicies, + UseSso = x.o.UseSso, + UseKeyConnector = x.o.UseKeyConnector, + UseScim = x.o.UseScim, + UseGroups = x.o.UseGroups, + UseDirectory = x.o.UseDirectory, + UseEvents = x.o.UseEvents, + UseTotp = x.o.UseTotp, + Use2fa = x.o.Use2fa, + UseApi = x.o.UseApi, + SelfHost = x.o.SelfHost, + UsersGetPremium = x.o.UsersGetPremium, + Seats = x.o.Seats, + MaxCollections = x.o.MaxCollections, + MaxStorageGb = x.o.MaxStorageGb, + Identifier = x.o.Identifier, + Key = x.po.Key, + Status = x.pu.Status, + Type = x.pu.Type, + PublicKey = x.o.PublicKey, + PrivateKey = x.o.PrivateKey, + ProviderId = x.p.Id, + ProviderName = x.p.Name, + }); + } } } diff --git a/src/Infrastructure.EntityFramework/Repositories/Queries/ProviderUserProviderDetailsReadByUserIdStatusQuery.cs b/src/Infrastructure.EntityFramework/Repositories/Queries/ProviderUserProviderDetailsReadByUserIdStatusQuery.cs index 1cae8437ac..efa7681230 100644 --- a/src/Infrastructure.EntityFramework/Repositories/Queries/ProviderUserProviderDetailsReadByUserIdStatusQuery.cs +++ b/src/Infrastructure.EntityFramework/Repositories/Queries/ProviderUserProviderDetailsReadByUserIdStatusQuery.cs @@ -1,38 +1,39 @@ using Bit.Core.Enums.Provider; using Bit.Core.Models.Data; -namespace Bit.Infrastructure.EntityFramework.Repositories.Queries; - -public class ProviderUserProviderDetailsReadByUserIdStatusQuery : IQuery +namespace Bit.Infrastructure.EntityFramework.Repositories.Queries { - private readonly Guid _userId; - private readonly ProviderUserStatusType? _status; - public ProviderUserProviderDetailsReadByUserIdStatusQuery(Guid userId, ProviderUserStatusType? status) + public class ProviderUserProviderDetailsReadByUserIdStatusQuery : IQuery { - _userId = userId; - _status = status; - } - - public IQueryable Run(DatabaseContext dbContext) - { - var query = from pu in dbContext.ProviderUsers - join p in dbContext.Providers - on pu.ProviderId equals p.Id into p_g - from p in p_g.DefaultIfEmpty() - where pu.UserId == _userId && p.Status != ProviderStatusType.Pending && (_status == null || pu.Status == _status) - select new { pu, p }; - return query.Select(x => new ProviderUserProviderDetails() + private readonly Guid _userId; + private readonly ProviderUserStatusType? _status; + public ProviderUserProviderDetailsReadByUserIdStatusQuery(Guid userId, ProviderUserStatusType? status) { - UserId = x.pu.UserId, - ProviderId = x.pu.ProviderId, - Name = x.p.Name, - Key = x.pu.Key, - Status = x.pu.Status, - Type = x.pu.Type, - Enabled = x.p.Enabled, - Permissions = x.pu.Permissions, - UseEvents = x.p.UseEvents, - ProviderStatus = x.p.Status, - }); + _userId = userId; + _status = status; + } + + public IQueryable Run(DatabaseContext dbContext) + { + var query = from pu in dbContext.ProviderUsers + join p in dbContext.Providers + on pu.ProviderId equals p.Id into p_g + from p in p_g.DefaultIfEmpty() + where pu.UserId == _userId && p.Status != ProviderStatusType.Pending && (_status == null || pu.Status == _status) + select new { pu, p }; + return query.Select(x => new ProviderUserProviderDetails() + { + UserId = x.pu.UserId, + ProviderId = x.pu.ProviderId, + Name = x.p.Name, + Key = x.pu.Key, + Status = x.pu.Status, + Type = x.pu.Type, + Enabled = x.p.Enabled, + Permissions = x.pu.Permissions, + UseEvents = x.p.UseEvents, + ProviderStatus = x.p.Status, + }); + } } } diff --git a/src/Infrastructure.EntityFramework/Repositories/Queries/ProviderUserReadCountByOnlyOwnerQuery.cs b/src/Infrastructure.EntityFramework/Repositories/Queries/ProviderUserReadCountByOnlyOwnerQuery.cs index 899c78b547..fd6e8521d7 100644 --- a/src/Infrastructure.EntityFramework/Repositories/Queries/ProviderUserReadCountByOnlyOwnerQuery.cs +++ b/src/Infrastructure.EntityFramework/Repositories/Queries/ProviderUserReadCountByOnlyOwnerQuery.cs @@ -1,36 +1,37 @@ using Bit.Core.Enums.Provider; using Bit.Infrastructure.EntityFramework.Models; -namespace Bit.Infrastructure.EntityFramework.Repositories.Queries; - -public class ProviderUserReadCountByOnlyOwnerQuery : IQuery +namespace Bit.Infrastructure.EntityFramework.Repositories.Queries { - private readonly Guid _userId; - - public ProviderUserReadCountByOnlyOwnerQuery(Guid userId) + public class ProviderUserReadCountByOnlyOwnerQuery : IQuery { - _userId = userId; - } + private readonly Guid _userId; - public IQueryable Run(DatabaseContext dbContext) - { - var owners = from pu in dbContext.ProviderUsers - where pu.Type == ProviderUserType.ProviderAdmin && - pu.Status == ProviderUserStatusType.Confirmed - group pu by pu.ProviderId into g - select new - { - ProviderUser = g.Select(x => new { x.UserId, x.Id }).FirstOrDefault(), - ConfirmedOwnerCount = g.Count(), - }; + public ProviderUserReadCountByOnlyOwnerQuery(Guid userId) + { + _userId = userId; + } - var query = from owner in owners - join pu in dbContext.ProviderUsers - on owner.ProviderUser.Id equals pu.Id - where owner.ProviderUser.UserId == _userId && - owner.ConfirmedOwnerCount == 1 - select pu; + public IQueryable Run(DatabaseContext dbContext) + { + var owners = from pu in dbContext.ProviderUsers + where pu.Type == ProviderUserType.ProviderAdmin && + pu.Status == ProviderUserStatusType.Confirmed + group pu by pu.ProviderId into g + select new + { + ProviderUser = g.Select(x => new { x.UserId, x.Id }).FirstOrDefault(), + ConfirmedOwnerCount = g.Count(), + }; - return query; + var query = from owner in owners + join pu in dbContext.ProviderUsers + on owner.ProviderUser.Id equals pu.Id + where owner.ProviderUser.UserId == _userId && + owner.ConfirmedOwnerCount == 1 + select pu; + + return query; + } } } diff --git a/src/Infrastructure.EntityFramework/Repositories/Queries/UserBumpAccountRevisionDateByCipherIdQuery.cs b/src/Infrastructure.EntityFramework/Repositories/Queries/UserBumpAccountRevisionDateByCipherIdQuery.cs index 00369c5940..6fcf15ab40 100644 --- a/src/Infrastructure.EntityFramework/Repositories/Queries/UserBumpAccountRevisionDateByCipherIdQuery.cs +++ b/src/Infrastructure.EntityFramework/Repositories/Queries/UserBumpAccountRevisionDateByCipherIdQuery.cs @@ -2,50 +2,51 @@ using Bit.Core.Enums; using User = Bit.Infrastructure.EntityFramework.Models.User; -namespace Bit.Infrastructure.EntityFramework.Repositories.Queries; - -public class UserBumpAccountRevisionDateByCipherIdQuery : IQuery +namespace Bit.Infrastructure.EntityFramework.Repositories.Queries { - private readonly Cipher _cipher; - - public UserBumpAccountRevisionDateByCipherIdQuery(Cipher cipher) + public class UserBumpAccountRevisionDateByCipherIdQuery : IQuery { - _cipher = cipher; - } + private readonly Cipher _cipher; - public IQueryable Run(DatabaseContext dbContext) - { - var query = from u in dbContext.Users - join ou in dbContext.OrganizationUsers - on u.Id equals ou.UserId - join collectionCipher in dbContext.CollectionCiphers - on _cipher.Id equals collectionCipher.CipherId into cc_g - from cc in cc_g.DefaultIfEmpty() - join collectionUser in dbContext.CollectionUsers - on cc.CollectionId equals collectionUser.CollectionId into cu_g - from cu in cu_g.DefaultIfEmpty() - where ou.AccessAll && - cu.OrganizationUserId == ou.Id - join groupUser in dbContext.GroupUsers - on ou.Id equals groupUser.OrganizationUserId into gu_g - from gu in gu_g.DefaultIfEmpty() - where cu.CollectionId == null && - !ou.AccessAll - join grp in dbContext.Groups - on gu.GroupId equals grp.Id into g_g - from g in g_g.DefaultIfEmpty() - join collectionGroup in dbContext.CollectionGroups - on cc.CollectionId equals collectionGroup.CollectionId into cg_g - from cg in cg_g.DefaultIfEmpty() - where !g.AccessAll && - cg.GroupId == gu.GroupId - where ou.OrganizationId == _cipher.OrganizationId && - ou.Status == OrganizationUserStatusType.Confirmed && - (cu.CollectionId != null || - cg.CollectionId != null || - ou.AccessAll || - g.AccessAll) - select u; - return query; + public UserBumpAccountRevisionDateByCipherIdQuery(Cipher cipher) + { + _cipher = cipher; + } + + public IQueryable Run(DatabaseContext dbContext) + { + var query = from u in dbContext.Users + join ou in dbContext.OrganizationUsers + on u.Id equals ou.UserId + join collectionCipher in dbContext.CollectionCiphers + on _cipher.Id equals collectionCipher.CipherId into cc_g + from cc in cc_g.DefaultIfEmpty() + join collectionUser in dbContext.CollectionUsers + on cc.CollectionId equals collectionUser.CollectionId into cu_g + from cu in cu_g.DefaultIfEmpty() + where ou.AccessAll && + cu.OrganizationUserId == ou.Id + join groupUser in dbContext.GroupUsers + on ou.Id equals groupUser.OrganizationUserId into gu_g + from gu in gu_g.DefaultIfEmpty() + where cu.CollectionId == null && + !ou.AccessAll + join grp in dbContext.Groups + on gu.GroupId equals grp.Id into g_g + from g in g_g.DefaultIfEmpty() + join collectionGroup in dbContext.CollectionGroups + on cc.CollectionId equals collectionGroup.CollectionId into cg_g + from cg in cg_g.DefaultIfEmpty() + where !g.AccessAll && + cg.GroupId == gu.GroupId + where ou.OrganizationId == _cipher.OrganizationId && + ou.Status == OrganizationUserStatusType.Confirmed && + (cu.CollectionId != null || + cg.CollectionId != null || + ou.AccessAll || + g.AccessAll) + select u; + return query; + } } } diff --git a/src/Infrastructure.EntityFramework/Repositories/Queries/UserBumpAccountRevisionDateByOrganizationIdQuery.cs b/src/Infrastructure.EntityFramework/Repositories/Queries/UserBumpAccountRevisionDateByOrganizationIdQuery.cs index d18cdc064b..87c2bcf086 100644 --- a/src/Infrastructure.EntityFramework/Repositories/Queries/UserBumpAccountRevisionDateByOrganizationIdQuery.cs +++ b/src/Infrastructure.EntityFramework/Repositories/Queries/UserBumpAccountRevisionDateByOrganizationIdQuery.cs @@ -1,26 +1,27 @@ using Bit.Core.Enums; using Bit.Infrastructure.EntityFramework.Models; -namespace Bit.Infrastructure.EntityFramework.Repositories.Queries; - -public class UserBumpAccountRevisionDateByOrganizationIdQuery : IQuery +namespace Bit.Infrastructure.EntityFramework.Repositories.Queries { - private readonly Guid _organizationId; - - public UserBumpAccountRevisionDateByOrganizationIdQuery(Guid organizationId) + public class UserBumpAccountRevisionDateByOrganizationIdQuery : IQuery { - _organizationId = organizationId; - } + private readonly Guid _organizationId; - public IQueryable Run(DatabaseContext dbContext) - { - var query = from u in dbContext.Users - join ou in dbContext.OrganizationUsers - on u.Id equals ou.UserId - where ou.OrganizationId == _organizationId && - ou.Status == OrganizationUserStatusType.Confirmed - select u; + public UserBumpAccountRevisionDateByOrganizationIdQuery(Guid organizationId) + { + _organizationId = organizationId; + } - return query; + public IQueryable Run(DatabaseContext dbContext) + { + var query = from u in dbContext.Users + join ou in dbContext.OrganizationUsers + on u.Id equals ou.UserId + where ou.OrganizationId == _organizationId && + ou.Status == OrganizationUserStatusType.Confirmed + select u; + + return query; + } } } diff --git a/src/Infrastructure.EntityFramework/Repositories/Queries/UserCipherDetailsQuery.cs b/src/Infrastructure.EntityFramework/Repositories/Queries/UserCipherDetailsQuery.cs index b74060ba3d..3417abd30c 100644 --- a/src/Infrastructure.EntityFramework/Repositories/Queries/UserCipherDetailsQuery.cs +++ b/src/Infrastructure.EntityFramework/Repositories/Queries/UserCipherDetailsQuery.cs @@ -2,70 +2,71 @@ using Core.Models.Data; using Newtonsoft.Json.Linq; -namespace Bit.Infrastructure.EntityFramework.Repositories.Queries; - -public class UserCipherDetailsQuery : IQuery +namespace Bit.Infrastructure.EntityFramework.Repositories.Queries { - private readonly Guid? _userId; - public UserCipherDetailsQuery(Guid? userId) + public class UserCipherDetailsQuery : IQuery { - _userId = userId; - } - public virtual IQueryable Run(DatabaseContext dbContext) - { - var query = from c in dbContext.Ciphers - join ou in dbContext.OrganizationUsers - on c.OrganizationId equals ou.OrganizationId - where ou.UserId == _userId && - ou.Status == OrganizationUserStatusType.Confirmed - join o in dbContext.Organizations - on c.OrganizationId equals o.Id - where o.Id == ou.OrganizationId && o.Enabled - join cc in dbContext.CollectionCiphers - on c.Id equals cc.CipherId into cc_g - from cc in cc_g.DefaultIfEmpty() - where ou.AccessAll - join cu in dbContext.CollectionUsers - on cc.CollectionId equals cu.CollectionId into cu_g - from cu in cu_g.DefaultIfEmpty() - where cu.OrganizationUserId == ou.Id - join gu in dbContext.GroupUsers - on ou.Id equals gu.OrganizationUserId into gu_g - from gu in gu_g.DefaultIfEmpty() - where cu.CollectionId == null && !ou.AccessAll - join g in dbContext.Groups - on gu.GroupId equals g.Id into g_g - from g in g_g.DefaultIfEmpty() - join cg in dbContext.CollectionGroups - on cc.CollectionId equals cg.CollectionId into cg_g - from cg in cg_g.DefaultIfEmpty() - where !g.AccessAll && cg.GroupId == gu.GroupId && - ou.AccessAll || cu.CollectionId != null || g.AccessAll || cg.CollectionId != null - select new { c, ou, o, cc, cu, gu, g, cg }.c; - - var query2 = from c in dbContext.Ciphers - where c.UserId == _userId - select c; - - var union = query.Union(query2).Select(c => new CipherDetails + private readonly Guid? _userId; + public UserCipherDetailsQuery(Guid? userId) { - Id = c.Id, - UserId = c.UserId, - OrganizationId = c.OrganizationId, - Type = c.Type, - Data = c.Data, - Attachments = c.Attachments, - CreationDate = c.CreationDate, - RevisionDate = c.RevisionDate, - DeletedDate = c.DeletedDate, - Favorite = _userId.HasValue && c.Favorites != null && c.Favorites.Contains($"\"{_userId}\":true"), - FolderId = _userId.HasValue && !string.IsNullOrWhiteSpace(c.Folders) ? - Guid.Parse(JObject.Parse(c.Folders)[_userId.Value.ToString()].Value()) : - null, - Edit = true, - ViewPassword = true, - OrganizationUseTotp = false, - }); - return union; + _userId = userId; + } + public virtual IQueryable Run(DatabaseContext dbContext) + { + var query = from c in dbContext.Ciphers + join ou in dbContext.OrganizationUsers + on c.OrganizationId equals ou.OrganizationId + where ou.UserId == _userId && + ou.Status == OrganizationUserStatusType.Confirmed + join o in dbContext.Organizations + on c.OrganizationId equals o.Id + where o.Id == ou.OrganizationId && o.Enabled + join cc in dbContext.CollectionCiphers + on c.Id equals cc.CipherId into cc_g + from cc in cc_g.DefaultIfEmpty() + where ou.AccessAll + join cu in dbContext.CollectionUsers + on cc.CollectionId equals cu.CollectionId into cu_g + from cu in cu_g.DefaultIfEmpty() + where cu.OrganizationUserId == ou.Id + join gu in dbContext.GroupUsers + on ou.Id equals gu.OrganizationUserId into gu_g + from gu in gu_g.DefaultIfEmpty() + where cu.CollectionId == null && !ou.AccessAll + join g in dbContext.Groups + on gu.GroupId equals g.Id into g_g + from g in g_g.DefaultIfEmpty() + join cg in dbContext.CollectionGroups + on cc.CollectionId equals cg.CollectionId into cg_g + from cg in cg_g.DefaultIfEmpty() + where !g.AccessAll && cg.GroupId == gu.GroupId && + ou.AccessAll || cu.CollectionId != null || g.AccessAll || cg.CollectionId != null + select new { c, ou, o, cc, cu, gu, g, cg }.c; + + var query2 = from c in dbContext.Ciphers + where c.UserId == _userId + select c; + + var union = query.Union(query2).Select(c => new CipherDetails + { + Id = c.Id, + UserId = c.UserId, + OrganizationId = c.OrganizationId, + Type = c.Type, + Data = c.Data, + Attachments = c.Attachments, + CreationDate = c.CreationDate, + RevisionDate = c.RevisionDate, + DeletedDate = c.DeletedDate, + Favorite = _userId.HasValue && c.Favorites != null && c.Favorites.Contains($"\"{_userId}\":true"), + FolderId = _userId.HasValue && !string.IsNullOrWhiteSpace(c.Folders) ? + Guid.Parse(JObject.Parse(c.Folders)[_userId.Value.ToString()].Value()) : + null, + Edit = true, + ViewPassword = true, + OrganizationUseTotp = false, + }); + return union; + } } } diff --git a/src/Infrastructure.EntityFramework/Repositories/Queries/UserCollectionDetailsQuery.cs b/src/Infrastructure.EntityFramework/Repositories/Queries/UserCollectionDetailsQuery.cs index 7004a6f752..c2325f10f2 100644 --- a/src/Infrastructure.EntityFramework/Repositories/Queries/UserCollectionDetailsQuery.cs +++ b/src/Infrastructure.EntityFramework/Repositories/Queries/UserCollectionDetailsQuery.cs @@ -1,52 +1,53 @@ using Bit.Core.Enums; using Bit.Core.Models.Data; -namespace Bit.Infrastructure.EntityFramework.Repositories.Queries; - -public class UserCollectionDetailsQuery : IQuery +namespace Bit.Infrastructure.EntityFramework.Repositories.Queries { - private readonly Guid? _userId; - public UserCollectionDetailsQuery(Guid? userId) + public class UserCollectionDetailsQuery : IQuery { - _userId = userId; - } - public virtual IQueryable Run(DatabaseContext dbContext) - { - var query = from c in dbContext.Collections - join ou in dbContext.OrganizationUsers - on c.OrganizationId equals ou.OrganizationId - join o in dbContext.Organizations - on c.OrganizationId equals o.Id - join cu in dbContext.CollectionUsers - on c.Id equals cu.CollectionId into cu_g - from cu in cu_g.DefaultIfEmpty() - where ou.AccessAll && cu.OrganizationUserId == ou.Id - join gu in dbContext.GroupUsers - on ou.Id equals gu.OrganizationUserId into gu_g - from gu in gu_g.DefaultIfEmpty() - where cu.CollectionId == null && !ou.AccessAll - join g in dbContext.Groups - on gu.GroupId equals g.Id into g_g - from g in g_g.DefaultIfEmpty() - join cg in dbContext.CollectionGroups - on gu.GroupId equals cg.GroupId into cg_g - from cg in cg_g.DefaultIfEmpty() - where !g.AccessAll && cg.CollectionId == c.Id && - ou.UserId == _userId && - ou.Status == OrganizationUserStatusType.Confirmed && - o.Enabled && - (ou.AccessAll || cu.CollectionId != null || g.AccessAll || cg.CollectionId != null) - select new { c, ou, o, cu, gu, g, cg }; - return query.Select(x => new CollectionDetails + private readonly Guid? _userId; + public UserCollectionDetailsQuery(Guid? userId) { - Id = x.c.Id, - OrganizationId = x.c.OrganizationId, - Name = x.c.Name, - ExternalId = x.c.ExternalId, - CreationDate = x.c.CreationDate, - RevisionDate = x.c.RevisionDate, - ReadOnly = !x.ou.AccessAll || !x.g.AccessAll || (x.cu.ReadOnly || x.cg.ReadOnly), - HidePasswords = !x.ou.AccessAll || !x.g.AccessAll || (x.cu.HidePasswords || x.cg.HidePasswords), - }); + _userId = userId; + } + public virtual IQueryable Run(DatabaseContext dbContext) + { + var query = from c in dbContext.Collections + join ou in dbContext.OrganizationUsers + on c.OrganizationId equals ou.OrganizationId + join o in dbContext.Organizations + on c.OrganizationId equals o.Id + join cu in dbContext.CollectionUsers + on c.Id equals cu.CollectionId into cu_g + from cu in cu_g.DefaultIfEmpty() + where ou.AccessAll && cu.OrganizationUserId == ou.Id + join gu in dbContext.GroupUsers + on ou.Id equals gu.OrganizationUserId into gu_g + from gu in gu_g.DefaultIfEmpty() + where cu.CollectionId == null && !ou.AccessAll + join g in dbContext.Groups + on gu.GroupId equals g.Id into g_g + from g in g_g.DefaultIfEmpty() + join cg in dbContext.CollectionGroups + on gu.GroupId equals cg.GroupId into cg_g + from cg in cg_g.DefaultIfEmpty() + where !g.AccessAll && cg.CollectionId == c.Id && + ou.UserId == _userId && + ou.Status == OrganizationUserStatusType.Confirmed && + o.Enabled && + (ou.AccessAll || cu.CollectionId != null || g.AccessAll || cg.CollectionId != null) + select new { c, ou, o, cu, gu, g, cg }; + return query.Select(x => new CollectionDetails + { + Id = x.c.Id, + OrganizationId = x.c.OrganizationId, + Name = x.c.Name, + ExternalId = x.c.ExternalId, + CreationDate = x.c.CreationDate, + RevisionDate = x.c.RevisionDate, + ReadOnly = !x.ou.AccessAll || !x.g.AccessAll || (x.cu.ReadOnly || x.cg.ReadOnly), + HidePasswords = !x.ou.AccessAll || !x.g.AccessAll || (x.cu.HidePasswords || x.cg.HidePasswords), + }); + } } } diff --git a/src/Infrastructure.EntityFramework/Repositories/Queries/UserReadPublicKeysByProviderUserIdsQuery.cs b/src/Infrastructure.EntityFramework/Repositories/Queries/UserReadPublicKeysByProviderUserIdsQuery.cs index db347b99b4..10782e0ea5 100644 --- a/src/Infrastructure.EntityFramework/Repositories/Queries/UserReadPublicKeysByProviderUserIdsQuery.cs +++ b/src/Infrastructure.EntityFramework/Repositories/Queries/UserReadPublicKeysByProviderUserIdsQuery.cs @@ -1,32 +1,33 @@ using Bit.Core.Enums.Provider; using Bit.Core.Models.Data; -namespace Bit.Infrastructure.EntityFramework.Repositories.Queries; - -public class UserReadPublicKeysByProviderUserIdsQuery : IQuery +namespace Bit.Infrastructure.EntityFramework.Repositories.Queries { - private readonly Guid _providerId; - private readonly IEnumerable _ids; - - public UserReadPublicKeysByProviderUserIdsQuery(Guid providerId, IEnumerable Ids) + public class UserReadPublicKeysByProviderUserIdsQuery : IQuery { - _providerId = providerId; - _ids = Ids; - } + private readonly Guid _providerId; + private readonly IEnumerable _ids; - public virtual IQueryable Run(DatabaseContext dbContext) - { - var query = from pu in dbContext.ProviderUsers - join u in dbContext.Users - on pu.UserId equals u.Id - where _ids.Contains(pu.Id) && - pu.Status == ProviderUserStatusType.Accepted && - pu.ProviderId == _providerId - select new { pu, u }; - return query.Select(x => new ProviderUserPublicKey + public UserReadPublicKeysByProviderUserIdsQuery(Guid providerId, IEnumerable Ids) { - Id = x.pu.Id, - PublicKey = x.u.PublicKey, - }); + _providerId = providerId; + _ids = Ids; + } + + public virtual IQueryable Run(DatabaseContext dbContext) + { + var query = from pu in dbContext.ProviderUsers + join u in dbContext.Users + on pu.UserId equals u.Id + where _ids.Contains(pu.Id) && + pu.Status == ProviderUserStatusType.Accepted && + pu.ProviderId == _providerId + select new { pu, u }; + return query.Select(x => new ProviderUserPublicKey + { + Id = x.pu.Id, + PublicKey = x.u.PublicKey, + }); + } } } diff --git a/src/Infrastructure.EntityFramework/Repositories/Repository.cs b/src/Infrastructure.EntityFramework/Repositories/Repository.cs index 4c509540d7..2d933da805 100644 --- a/src/Infrastructure.EntityFramework/Repositories/Repository.cs +++ b/src/Infrastructure.EntityFramework/Repositories/Repository.cs @@ -5,117 +5,118 @@ using Bit.Infrastructure.EntityFramework.Repositories.Queries; using Microsoft.EntityFrameworkCore; using Microsoft.Extensions.DependencyInjection; -namespace Bit.Infrastructure.EntityFramework.Repositories; - -public abstract class Repository : BaseEntityFrameworkRepository, IRepository - where TId : IEquatable - where T : class, ITableObject - where TEntity : class, ITableObject +namespace Bit.Infrastructure.EntityFramework.Repositories { - public Repository(IServiceScopeFactory serviceScopeFactory, IMapper mapper, Func> getDbSet) - : base(serviceScopeFactory, mapper) + public abstract class Repository : BaseEntityFrameworkRepository, IRepository + where TId : IEquatable + where T : class, ITableObject + where TEntity : class, ITableObject { - GetDbSet = getDbSet; - } - - protected Func> GetDbSet { get; private set; } - - public virtual async Task GetByIdAsync(TId id) - { - using (var scope = ServiceScopeFactory.CreateScope()) + public Repository(IServiceScopeFactory serviceScopeFactory, IMapper mapper, Func> getDbSet) + : base(serviceScopeFactory, mapper) { - var dbContext = GetDatabaseContext(scope); - var entity = await GetDbSet(dbContext).FindAsync(id); - return Mapper.Map(entity); + GetDbSet = getDbSet; } - } - public virtual async Task CreateAsync(T obj) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - obj.SetNewId(); - var entity = Mapper.Map(obj); - await dbContext.AddAsync(entity); - await dbContext.SaveChangesAsync(); - obj.Id = entity.Id; - return obj; - } - } + protected Func> GetDbSet { get; private set; } - public virtual async Task ReplaceAsync(T obj) - { - using (var scope = ServiceScopeFactory.CreateScope()) + public virtual async Task GetByIdAsync(TId id) { - var dbContext = GetDatabaseContext(scope); - var entity = await GetDbSet(dbContext).FindAsync(obj.Id); - if (entity != null) + using (var scope = ServiceScopeFactory.CreateScope()) { - var mappedEntity = Mapper.Map(obj); - dbContext.Entry(entity).CurrentValues.SetValues(mappedEntity); + var dbContext = GetDatabaseContext(scope); + var entity = await GetDbSet(dbContext).FindAsync(id); + return Mapper.Map(entity); + } + } + + public virtual async Task CreateAsync(T obj) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + obj.SetNewId(); + var entity = Mapper.Map(obj); + await dbContext.AddAsync(entity); + await dbContext.SaveChangesAsync(); + obj.Id = entity.Id; + return obj; + } + } + + public virtual async Task ReplaceAsync(T obj) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var entity = await GetDbSet(dbContext).FindAsync(obj.Id); + if (entity != null) + { + var mappedEntity = Mapper.Map(obj); + dbContext.Entry(entity).CurrentValues.SetValues(mappedEntity); + await dbContext.SaveChangesAsync(); + } + } + } + + public virtual async Task UpsertAsync(T obj) + { + if (obj.Id.Equals(default(TId))) + { + await CreateAsync(obj); + } + else + { + await ReplaceAsync(obj); + } + } + + public virtual async Task DeleteAsync(T obj) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var entity = Mapper.Map(obj); + dbContext.Remove(entity); await dbContext.SaveChangesAsync(); } } - } - public virtual async Task UpsertAsync(T obj) - { - if (obj.Id.Equals(default(TId))) + public virtual async Task RefreshDb() { - await CreateAsync(obj); - } - else - { - await ReplaceAsync(obj); - } - } - - public virtual async Task DeleteAsync(T obj) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var entity = Mapper.Map(obj); - dbContext.Remove(entity); - await dbContext.SaveChangesAsync(); - } - } - - public virtual async Task RefreshDb() - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var context = GetDatabaseContext(scope); - await context.Database.EnsureDeletedAsync(); - await context.Database.EnsureCreatedAsync(); - } - } - - public virtual async Task> CreateMany(List objs) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var entities = new List(); - foreach (var o in objs) + using (var scope = ServiceScopeFactory.CreateScope()) { - o.SetNewId(); - var entity = Mapper.Map(o); - entities.Add(entity); + var context = GetDatabaseContext(scope); + await context.Database.EnsureDeletedAsync(); + await context.Database.EnsureCreatedAsync(); } - var dbContext = GetDatabaseContext(scope); - await GetDbSet(dbContext).AddRangeAsync(entities); - await dbContext.SaveChangesAsync(); - return objs; } - } - public IQueryable Run(IQuery query) - { - using (var scope = ServiceScopeFactory.CreateScope()) + public virtual async Task> CreateMany(List objs) { - var dbContext = GetDatabaseContext(scope); - return query.Run(dbContext); + using (var scope = ServiceScopeFactory.CreateScope()) + { + var entities = new List(); + foreach (var o in objs) + { + o.SetNewId(); + var entity = Mapper.Map(o); + entities.Add(entity); + } + var dbContext = GetDatabaseContext(scope); + await GetDbSet(dbContext).AddRangeAsync(entities); + await dbContext.SaveChangesAsync(); + return objs; + } + } + + public IQueryable Run(IQuery query) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + return query.Run(dbContext); + } } } } diff --git a/src/Infrastructure.EntityFramework/Repositories/SendRepository.cs b/src/Infrastructure.EntityFramework/Repositories/SendRepository.cs index e102ddea7a..691a86c3b2 100644 --- a/src/Infrastructure.EntityFramework/Repositories/SendRepository.cs +++ b/src/Infrastructure.EntityFramework/Repositories/SendRepository.cs @@ -4,42 +4,43 @@ using Bit.Infrastructure.EntityFramework.Models; using Microsoft.EntityFrameworkCore; using Microsoft.Extensions.DependencyInjection; -namespace Bit.Infrastructure.EntityFramework.Repositories; - -public class SendRepository : Repository, ISendRepository +namespace Bit.Infrastructure.EntityFramework.Repositories { - public SendRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) - : base(serviceScopeFactory, mapper, (DatabaseContext context) => context.Sends) - { } - - public override async Task CreateAsync(Core.Entities.Send send) + public class SendRepository : Repository, ISendRepository { - send = await base.CreateAsync(send); - if (send.UserId.HasValue) + public SendRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) + : base(serviceScopeFactory, mapper, (DatabaseContext context) => context.Sends) + { } + + public override async Task CreateAsync(Core.Entities.Send send) { - await UserUpdateStorage(send.UserId.Value); - await UserBumpAccountRevisionDate(send.UserId.Value); + send = await base.CreateAsync(send); + if (send.UserId.HasValue) + { + await UserUpdateStorage(send.UserId.Value); + await UserBumpAccountRevisionDate(send.UserId.Value); + } + return send; } - return send; - } - public async Task> GetManyByDeletionDateAsync(DateTime deletionDateBefore) - { - using (var scope = ServiceScopeFactory.CreateScope()) + public async Task> GetManyByDeletionDateAsync(DateTime deletionDateBefore) { - var dbContext = GetDatabaseContext(scope); - var results = await dbContext.Sends.Where(s => s.DeletionDate < deletionDateBefore).ToListAsync(); - return Mapper.Map>(results); + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var results = await dbContext.Sends.Where(s => s.DeletionDate < deletionDateBefore).ToListAsync(); + return Mapper.Map>(results); + } } - } - public async Task> GetManyByUserIdAsync(Guid userId) - { - using (var scope = ServiceScopeFactory.CreateScope()) + public async Task> GetManyByUserIdAsync(Guid userId) { - var dbContext = GetDatabaseContext(scope); - var results = await dbContext.Sends.Where(s => s.UserId == userId).ToListAsync(); - return Mapper.Map>(results); + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var results = await dbContext.Sends.Where(s => s.UserId == userId).ToListAsync(); + return Mapper.Map>(results); + } } } } diff --git a/src/Infrastructure.EntityFramework/Repositories/SsoConfigRepository.cs b/src/Infrastructure.EntityFramework/Repositories/SsoConfigRepository.cs index c9a772e9a7..8c0d221656 100644 --- a/src/Infrastructure.EntityFramework/Repositories/SsoConfigRepository.cs +++ b/src/Infrastructure.EntityFramework/Repositories/SsoConfigRepository.cs @@ -4,42 +4,43 @@ using Bit.Infrastructure.EntityFramework.Models; using Microsoft.EntityFrameworkCore; using Microsoft.Extensions.DependencyInjection; -namespace Bit.Infrastructure.EntityFramework.Repositories; - -public class SsoConfigRepository : Repository, ISsoConfigRepository +namespace Bit.Infrastructure.EntityFramework.Repositories { - public SsoConfigRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) - : base(serviceScopeFactory, mapper, (DatabaseContext context) => context.SsoConfigs) - { } - - public async Task GetByOrganizationIdAsync(Guid organizationId) + public class SsoConfigRepository : Repository, ISsoConfigRepository { - using (var scope = ServiceScopeFactory.CreateScope()) + public SsoConfigRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) + : base(serviceScopeFactory, mapper, (DatabaseContext context) => context.SsoConfigs) + { } + + public async Task GetByOrganizationIdAsync(Guid organizationId) { - var dbContext = GetDatabaseContext(scope); - var ssoConfig = await GetDbSet(dbContext).SingleOrDefaultAsync(sc => sc.OrganizationId == organizationId); - return Mapper.Map(ssoConfig); + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var ssoConfig = await GetDbSet(dbContext).SingleOrDefaultAsync(sc => sc.OrganizationId == organizationId); + return Mapper.Map(ssoConfig); + } } - } - public async Task GetByIdentifierAsync(string identifier) - { - - using (var scope = ServiceScopeFactory.CreateScope()) + public async Task GetByIdentifierAsync(string identifier) { - var dbContext = GetDatabaseContext(scope); - var ssoConfig = await GetDbSet(dbContext).SingleOrDefaultAsync(sc => sc.Organization.Identifier == identifier); - return Mapper.Map(ssoConfig); + + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var ssoConfig = await GetDbSet(dbContext).SingleOrDefaultAsync(sc => sc.Organization.Identifier == identifier); + return Mapper.Map(ssoConfig); + } } - } - public async Task> GetManyByRevisionNotBeforeDate(DateTime? notBefore) - { - using (var scope = ServiceScopeFactory.CreateScope()) + public async Task> GetManyByRevisionNotBeforeDate(DateTime? notBefore) { - var dbContext = GetDatabaseContext(scope); - var ssoConfigs = await GetDbSet(dbContext).Where(sc => sc.Enabled && sc.RevisionDate >= notBefore).ToListAsync(); - return Mapper.Map>(ssoConfigs); + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var ssoConfigs = await GetDbSet(dbContext).Where(sc => sc.Enabled && sc.RevisionDate >= notBefore).ToListAsync(); + return Mapper.Map>(ssoConfigs); + } } } } diff --git a/src/Infrastructure.EntityFramework/Repositories/SsoUserRepository.cs b/src/Infrastructure.EntityFramework/Repositories/SsoUserRepository.cs index f41f0d540e..c413648dc6 100644 --- a/src/Infrastructure.EntityFramework/Repositories/SsoUserRepository.cs +++ b/src/Infrastructure.EntityFramework/Repositories/SsoUserRepository.cs @@ -4,33 +4,34 @@ using Bit.Infrastructure.EntityFramework.Models; using Microsoft.EntityFrameworkCore; using Microsoft.Extensions.DependencyInjection; -namespace Bit.Infrastructure.EntityFramework.Repositories; - -public class SsoUserRepository : Repository, ISsoUserRepository +namespace Bit.Infrastructure.EntityFramework.Repositories { - public SsoUserRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) - : base(serviceScopeFactory, mapper, (DatabaseContext context) => context.SsoUsers) - { } - - public async Task DeleteAsync(Guid userId, Guid? organizationId) + public class SsoUserRepository : Repository, ISsoUserRepository { - using (var scope = ServiceScopeFactory.CreateScope()) + public SsoUserRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) + : base(serviceScopeFactory, mapper, (DatabaseContext context) => context.SsoUsers) + { } + + public async Task DeleteAsync(Guid userId, Guid? organizationId) { - var dbContext = GetDatabaseContext(scope); - var entity = await GetDbSet(dbContext).SingleOrDefaultAsync(su => su.UserId == userId && su.OrganizationId == organizationId); - dbContext.Entry(entity).State = EntityState.Deleted; - await dbContext.SaveChangesAsync(); + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var entity = await GetDbSet(dbContext).SingleOrDefaultAsync(su => su.UserId == userId && su.OrganizationId == organizationId); + dbContext.Entry(entity).State = EntityState.Deleted; + await dbContext.SaveChangesAsync(); + } } - } - public async Task GetByUserIdOrganizationIdAsync(Guid organizationId, Guid userId) - { - using (var scope = ServiceScopeFactory.CreateScope()) + public async Task GetByUserIdOrganizationIdAsync(Guid organizationId, Guid userId) { - var dbContext = GetDatabaseContext(scope); - var entity = await GetDbSet(dbContext) - .FirstOrDefaultAsync(e => e.OrganizationId == organizationId && e.UserId == userId); - return entity; + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var entity = await GetDbSet(dbContext) + .FirstOrDefaultAsync(e => e.OrganizationId == organizationId && e.UserId == userId); + return entity; + } } } } diff --git a/src/Infrastructure.EntityFramework/Repositories/TaxRateRepository.cs b/src/Infrastructure.EntityFramework/Repositories/TaxRateRepository.cs index fcf4014a10..a575892d43 100644 --- a/src/Infrastructure.EntityFramework/Repositories/TaxRateRepository.cs +++ b/src/Infrastructure.EntityFramework/Repositories/TaxRateRepository.cs @@ -4,63 +4,64 @@ using Bit.Infrastructure.EntityFramework.Models; using Microsoft.EntityFrameworkCore; using Microsoft.Extensions.DependencyInjection; -namespace Bit.Infrastructure.EntityFramework.Repositories; - -public class TaxRateRepository : Repository, ITaxRateRepository +namespace Bit.Infrastructure.EntityFramework.Repositories { - public TaxRateRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) - : base(serviceScopeFactory, mapper, (DatabaseContext context) => context.TaxRates) - { } - - public async Task ArchiveAsync(Core.Entities.TaxRate model) + public class TaxRateRepository : Repository, ITaxRateRepository { - using (var scope = ServiceScopeFactory.CreateScope()) + public TaxRateRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) + : base(serviceScopeFactory, mapper, (DatabaseContext context) => context.TaxRates) + { } + + public async Task ArchiveAsync(Core.Entities.TaxRate model) { - var dbContext = GetDatabaseContext(scope); - var entity = await dbContext.FindAsync(model); - entity.Active = false; - await dbContext.SaveChangesAsync(); + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var entity = await dbContext.FindAsync(model); + entity.Active = false; + await dbContext.SaveChangesAsync(); + } } - } - public async Task> GetAllActiveAsync() - { - using (var scope = ServiceScopeFactory.CreateScope()) + public async Task> GetAllActiveAsync() { - var dbContext = GetDatabaseContext(scope); - var results = await dbContext.TaxRates - .Where(t => t.Active) - .ToListAsync(); - return Mapper.Map>(results); + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var results = await dbContext.TaxRates + .Where(t => t.Active) + .ToListAsync(); + return Mapper.Map>(results); + } } - } - public async Task> GetByLocationAsync(Core.Entities.TaxRate taxRate) - { - using (var scope = ServiceScopeFactory.CreateScope()) + public async Task> GetByLocationAsync(Core.Entities.TaxRate taxRate) { - var dbContext = GetDatabaseContext(scope); - var results = await dbContext.TaxRates - .Where(t => t.Active && - t.Country == taxRate.Country && - t.PostalCode == taxRate.PostalCode) - .ToListAsync(); - return Mapper.Map>(results); + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var results = await dbContext.TaxRates + .Where(t => t.Active && + t.Country == taxRate.Country && + t.PostalCode == taxRate.PostalCode) + .ToListAsync(); + return Mapper.Map>(results); + } } - } - public async Task> SearchAsync(int skip, int count) - { - using (var scope = ServiceScopeFactory.CreateScope()) + public async Task> SearchAsync(int skip, int count) { - var dbContext = GetDatabaseContext(scope); - var results = await dbContext.TaxRates - .Skip(skip) - .Take(count) - .Where(t => t.Active) - .OrderBy(t => t.Country).ThenByDescending(t => t.PostalCode) - .ToListAsync(); - return Mapper.Map>(results); + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var results = await dbContext.TaxRates + .Skip(skip) + .Take(count) + .Where(t => t.Active) + .OrderBy(t => t.Country).ThenByDescending(t => t.PostalCode) + .ToListAsync(); + return Mapper.Map>(results); + } } } } diff --git a/src/Infrastructure.EntityFramework/Repositories/TransactionRepository.cs b/src/Infrastructure.EntityFramework/Repositories/TransactionRepository.cs index 45c052cbbd..fad4389ed0 100644 --- a/src/Infrastructure.EntityFramework/Repositories/TransactionRepository.cs +++ b/src/Infrastructure.EntityFramework/Repositories/TransactionRepository.cs @@ -5,46 +5,47 @@ using Bit.Infrastructure.EntityFramework.Models; using Microsoft.EntityFrameworkCore; using Microsoft.Extensions.DependencyInjection; -namespace Bit.Infrastructure.EntityFramework.Repositories; - -public class TransactionRepository : Repository, ITransactionRepository +namespace Bit.Infrastructure.EntityFramework.Repositories { - public TransactionRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) - : base(serviceScopeFactory, mapper, (DatabaseContext context) => context.Transactions) - { } - - public async Task GetByGatewayIdAsync(GatewayType gatewayType, string gatewayId) + public class TransactionRepository : Repository, ITransactionRepository { - using (var scope = ServiceScopeFactory.CreateScope()) + public TransactionRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) + : base(serviceScopeFactory, mapper, (DatabaseContext context) => context.Transactions) + { } + + public async Task GetByGatewayIdAsync(GatewayType gatewayType, string gatewayId) { - var dbContext = GetDatabaseContext(scope); - var results = await dbContext.Transactions - .FirstOrDefaultAsync(t => (t.GatewayId == gatewayId && t.Gateway == gatewayType)); - return Mapper.Map(results); + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var results = await dbContext.Transactions + .FirstOrDefaultAsync(t => (t.GatewayId == gatewayId && t.Gateway == gatewayType)); + return Mapper.Map(results); + } } - } - public async Task> GetManyByOrganizationIdAsync(Guid organizationId) - { - using (var scope = ServiceScopeFactory.CreateScope()) + public async Task> GetManyByOrganizationIdAsync(Guid organizationId) { - var dbContext = GetDatabaseContext(scope); - var results = await dbContext.Transactions - .Where(t => (t.OrganizationId == organizationId && !t.UserId.HasValue)) - .ToListAsync(); - return Mapper.Map>(results); + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var results = await dbContext.Transactions + .Where(t => (t.OrganizationId == organizationId && !t.UserId.HasValue)) + .ToListAsync(); + return Mapper.Map>(results); + } } - } - public async Task> GetManyByUserIdAsync(Guid userId) - { - using (var scope = ServiceScopeFactory.CreateScope()) + public async Task> GetManyByUserIdAsync(Guid userId) { - var dbContext = GetDatabaseContext(scope); - var results = await dbContext.Transactions - .Where(t => (t.UserId == userId)) - .ToListAsync(); - return Mapper.Map>(results); + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var results = await dbContext.Transactions + .Where(t => (t.UserId == userId)) + .ToListAsync(); + return Mapper.Map>(results); + } } } } diff --git a/src/Infrastructure.EntityFramework/Repositories/UserRepository.cs b/src/Infrastructure.EntityFramework/Repositories/UserRepository.cs index be75e85ff8..4074debf25 100644 --- a/src/Infrastructure.EntityFramework/Repositories/UserRepository.cs +++ b/src/Infrastructure.EntityFramework/Repositories/UserRepository.cs @@ -5,141 +5,142 @@ using Microsoft.EntityFrameworkCore; using Microsoft.Extensions.DependencyInjection; using DataModel = Bit.Core.Models.Data; -namespace Bit.Infrastructure.EntityFramework.Repositories; - -public class UserRepository : Repository, IUserRepository +namespace Bit.Infrastructure.EntityFramework.Repositories { - public UserRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) - : base(serviceScopeFactory, mapper, (DatabaseContext context) => context.Users) - { } - - public async Task GetByEmailAsync(string email) + public class UserRepository : Repository, IUserRepository { - using (var scope = ServiceScopeFactory.CreateScope()) + public UserRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) + : base(serviceScopeFactory, mapper, (DatabaseContext context) => context.Users) + { } + + public async Task GetByEmailAsync(string email) { - var dbContext = GetDatabaseContext(scope); - var entity = await GetDbSet(dbContext).FirstOrDefaultAsync(e => e.Email == email); - return Mapper.Map(entity); + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var entity = await GetDbSet(dbContext).FirstOrDefaultAsync(e => e.Email == email); + return Mapper.Map(entity); + } } - } - public async Task GetKdfInformationByEmailAsync(string email) - { - using (var scope = ServiceScopeFactory.CreateScope()) + public async Task GetKdfInformationByEmailAsync(string email) { - var dbContext = GetDatabaseContext(scope); - return await GetDbSet(dbContext).Where(e => e.Email == email) - .Select(e => new DataModel.UserKdfInformation + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + return await GetDbSet(dbContext).Where(e => e.Email == email) + .Select(e => new DataModel.UserKdfInformation + { + Kdf = e.Kdf, + KdfIterations = e.KdfIterations + }).SingleOrDefaultAsync(); + } + } + + public async Task> SearchAsync(string email, int skip, int take) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + List users; + if (dbContext.Database.IsNpgsql()) { - Kdf = e.Kdf, - KdfIterations = e.KdfIterations - }).SingleOrDefaultAsync(); - } - } - - public async Task> SearchAsync(string email, int skip, int take) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - List users; - if (dbContext.Database.IsNpgsql()) - { - users = await GetDbSet(dbContext) - .Where(e => e.Email == null || - EF.Functions.ILike(EF.Functions.Collate(e.Email, "default"), "a%")) - .OrderBy(e => e.Email) - .Skip(skip).Take(take) - .ToListAsync(); + users = await GetDbSet(dbContext) + .Where(e => e.Email == null || + EF.Functions.ILike(EF.Functions.Collate(e.Email, "default"), "a%")) + .OrderBy(e => e.Email) + .Skip(skip).Take(take) + .ToListAsync(); + } + else + { + users = await GetDbSet(dbContext) + .Where(e => email == null || e.Email.StartsWith(email)) + .OrderBy(e => e.Email) + .Skip(skip).Take(take) + .ToListAsync(); + } + return Mapper.Map>(users); } - else + } + + public async Task> GetManyByPremiumAsync(bool premium) + { + using (var scope = ServiceScopeFactory.CreateScope()) { - users = await GetDbSet(dbContext) - .Where(e => email == null || e.Email.StartsWith(email)) - .OrderBy(e => e.Email) - .Skip(skip).Take(take) - .ToListAsync(); + var dbContext = GetDatabaseContext(scope); + var users = await GetDbSet(dbContext).Where(e => e.Premium == premium).ToListAsync(); + return Mapper.Map>(users); } - return Mapper.Map>(users); } - } - public async Task> GetManyByPremiumAsync(bool premium) - { - using (var scope = ServiceScopeFactory.CreateScope()) + public async Task GetPublicKeyAsync(Guid id) { - var dbContext = GetDatabaseContext(scope); - var users = await GetDbSet(dbContext).Where(e => e.Premium == premium).ToListAsync(); - return Mapper.Map>(users); - } - } - - public async Task GetPublicKeyAsync(Guid id) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - return await GetDbSet(dbContext).Where(e => e.Id == id).Select(e => e.PublicKey).SingleOrDefaultAsync(); - } - } - - public async Task GetAccountRevisionDateAsync(Guid id) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - return await GetDbSet(dbContext).Where(e => e.Id == id).Select(e => e.AccountRevisionDate) - .SingleOrDefaultAsync(); - } - } - - public async Task UpdateStorageAsync(Guid id) - { - await base.UserUpdateStorage(id); - } - - public async Task UpdateRenewalReminderDateAsync(Guid id, DateTime renewalReminderDate) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var user = new User + using (var scope = ServiceScopeFactory.CreateScope()) { - Id = id, - RenewalReminderDate = renewalReminderDate, - }; - var set = GetDbSet(dbContext); - set.Attach(user); - dbContext.Entry(user).Property(e => e.RenewalReminderDate).IsModified = true; - await dbContext.SaveChangesAsync(); - } - } - - public async Task GetBySsoUserAsync(string externalId, Guid? organizationId) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var ssoUser = await dbContext.SsoUsers.SingleOrDefaultAsync(e => - e.OrganizationId == organizationId && e.ExternalId == externalId); - - if (ssoUser == null) - { - return null; + var dbContext = GetDatabaseContext(scope); + return await GetDbSet(dbContext).Where(e => e.Id == id).Select(e => e.PublicKey).SingleOrDefaultAsync(); } - - var entity = await dbContext.Users.SingleOrDefaultAsync(e => e.Id == ssoUser.UserId); - return Mapper.Map(entity); } - } - public async Task> GetManyAsync(IEnumerable ids) - { - using (var scope = ServiceScopeFactory.CreateScope()) + public async Task GetAccountRevisionDateAsync(Guid id) { - var dbContext = GetDatabaseContext(scope); - var users = dbContext.Users.Where(x => ids.Contains(x.Id)); - return await users.ToListAsync(); + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + return await GetDbSet(dbContext).Where(e => e.Id == id).Select(e => e.AccountRevisionDate) + .SingleOrDefaultAsync(); + } + } + + public async Task UpdateStorageAsync(Guid id) + { + await base.UserUpdateStorage(id); + } + + public async Task UpdateRenewalReminderDateAsync(Guid id, DateTime renewalReminderDate) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var user = new User + { + Id = id, + RenewalReminderDate = renewalReminderDate, + }; + var set = GetDbSet(dbContext); + set.Attach(user); + dbContext.Entry(user).Property(e => e.RenewalReminderDate).IsModified = true; + await dbContext.SaveChangesAsync(); + } + } + + public async Task GetBySsoUserAsync(string externalId, Guid? organizationId) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var ssoUser = await dbContext.SsoUsers.SingleOrDefaultAsync(e => + e.OrganizationId == organizationId && e.ExternalId == externalId); + + if (ssoUser == null) + { + return null; + } + + var entity = await dbContext.Users.SingleOrDefaultAsync(e => e.Id == ssoUser.UserId); + return Mapper.Map(entity); + } + } + + public async Task> GetManyAsync(IEnumerable ids) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var users = dbContext.Users.Where(x => ids.Contains(x.Id)); + return await users.ToListAsync(); + } } } } diff --git a/src/Notifications/AzureQueueHostedService.cs b/src/Notifications/AzureQueueHostedService.cs index ba2e38d2c0..edc735d86f 100644 --- a/src/Notifications/AzureQueueHostedService.cs +++ b/src/Notifications/AzureQueueHostedService.cs @@ -3,90 +3,91 @@ using Bit.Core.Settings; using Bit.Core.Utilities; using Microsoft.AspNetCore.SignalR; -namespace Bit.Notifications; - -public class AzureQueueHostedService : IHostedService, IDisposable +namespace Bit.Notifications { - private readonly ILogger _logger; - private readonly IHubContext _hubContext; - private readonly GlobalSettings _globalSettings; - - private Task _executingTask; - private CancellationTokenSource _cts; - private QueueClient _queueClient; - - public AzureQueueHostedService( - ILogger logger, - IHubContext hubContext, - GlobalSettings globalSettings) + public class AzureQueueHostedService : IHostedService, IDisposable { - _logger = logger; - _hubContext = hubContext; - _globalSettings = globalSettings; - } + private readonly ILogger _logger; + private readonly IHubContext _hubContext; + private readonly GlobalSettings _globalSettings; - public Task StartAsync(CancellationToken cancellationToken) - { - _cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); - _executingTask = ExecuteAsync(_cts.Token); - return _executingTask.IsCompleted ? _executingTask : Task.CompletedTask; - } + private Task _executingTask; + private CancellationTokenSource _cts; + private QueueClient _queueClient; - public async Task StopAsync(CancellationToken cancellationToken) - { - if (_executingTask == null) + public AzureQueueHostedService( + ILogger logger, + IHubContext hubContext, + GlobalSettings globalSettings) { - return; + _logger = logger; + _hubContext = hubContext; + _globalSettings = globalSettings; } - _logger.LogWarning("Stopping service."); - _cts.Cancel(); - await Task.WhenAny(_executingTask, Task.Delay(-1, cancellationToken)); - cancellationToken.ThrowIfCancellationRequested(); - } - public void Dispose() - { } - - private async Task ExecuteAsync(CancellationToken cancellationToken) - { - _queueClient = new QueueClient(_globalSettings.Notifications.ConnectionString, "notifications"); - while (!cancellationToken.IsCancellationRequested) + public Task StartAsync(CancellationToken cancellationToken) { - try + _cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); + _executingTask = ExecuteAsync(_cts.Token); + return _executingTask.IsCompleted ? _executingTask : Task.CompletedTask; + } + + public async Task StopAsync(CancellationToken cancellationToken) + { + if (_executingTask == null) { - var messages = await _queueClient.ReceiveMessagesAsync(32); - if (messages.Value?.Any() ?? false) + return; + } + _logger.LogWarning("Stopping service."); + _cts.Cancel(); + await Task.WhenAny(_executingTask, Task.Delay(-1, cancellationToken)); + cancellationToken.ThrowIfCancellationRequested(); + } + + public void Dispose() + { } + + private async Task ExecuteAsync(CancellationToken cancellationToken) + { + _queueClient = new QueueClient(_globalSettings.Notifications.ConnectionString, "notifications"); + while (!cancellationToken.IsCancellationRequested) + { + try { - foreach (var message in messages.Value) + var messages = await _queueClient.ReceiveMessagesAsync(32); + if (messages.Value?.Any() ?? false) { - try + foreach (var message in messages.Value) { - await HubHelpers.SendNotificationToHubAsync( - message.DecodeMessageText(), _hubContext, cancellationToken); - await _queueClient.DeleteMessageAsync(message.MessageId, message.PopReceipt); - } - catch (Exception e) - { - _logger.LogError("Error processing dequeued message: " + - $"{message.MessageId} x{message.DequeueCount}. {e.Message}", e); - if (message.DequeueCount > 2) + try { + await HubHelpers.SendNotificationToHubAsync( + message.DecodeMessageText(), _hubContext, cancellationToken); await _queueClient.DeleteMessageAsync(message.MessageId, message.PopReceipt); } + catch (Exception e) + { + _logger.LogError("Error processing dequeued message: " + + $"{message.MessageId} x{message.DequeueCount}. {e.Message}", e); + if (message.DequeueCount > 2) + { + await _queueClient.DeleteMessageAsync(message.MessageId, message.PopReceipt); + } + } } } + else + { + await Task.Delay(TimeSpan.FromSeconds(5), cancellationToken); + } } - else + catch (Exception e) { - await Task.Delay(TimeSpan.FromSeconds(5), cancellationToken); + _logger.LogError("Error processing messages.", e); } } - catch (Exception e) - { - _logger.LogError("Error processing messages.", e); - } - } - _logger.LogWarning("Done processing."); + _logger.LogWarning("Done processing."); + } } } diff --git a/src/Notifications/ConnectionCounter.cs b/src/Notifications/ConnectionCounter.cs index 25d3156168..330b8ee6dc 100644 --- a/src/Notifications/ConnectionCounter.cs +++ b/src/Notifications/ConnectionCounter.cs @@ -1,26 +1,27 @@ -namespace Bit.Notifications; - -public class ConnectionCounter +namespace Bit.Notifications { - private int _count = 0; - - public void Increment() + public class ConnectionCounter { - Interlocked.Increment(ref _count); - } + private int _count = 0; - public void Decrement() - { - Interlocked.Decrement(ref _count); - } + public void Increment() + { + Interlocked.Increment(ref _count); + } - public void Reset() - { - _count = 0; - } + public void Decrement() + { + Interlocked.Decrement(ref _count); + } - public int GetCount() - { - return _count; + public void Reset() + { + _count = 0; + } + + public int GetCount() + { + return _count; + } } } diff --git a/src/Notifications/Controllers/InfoController.cs b/src/Notifications/Controllers/InfoController.cs index 6a8eaf2827..402fc49376 100644 --- a/src/Notifications/Controllers/InfoController.cs +++ b/src/Notifications/Controllers/InfoController.cs @@ -1,20 +1,21 @@ using Bit.Core.Utilities; using Microsoft.AspNetCore.Mvc; -namespace Bit.Notifications.Controllers; - -public class InfoController : Controller +namespace Bit.Notifications.Controllers { - [HttpGet("~/alive")] - [HttpGet("~/now")] - public DateTime GetAlive() + public class InfoController : Controller { - return DateTime.UtcNow; - } + [HttpGet("~/alive")] + [HttpGet("~/now")] + public DateTime GetAlive() + { + return DateTime.UtcNow; + } - [HttpGet("~/version")] - public JsonResult GetVersion() - { - return Json(CoreHelpers.GetVersion()); + [HttpGet("~/version")] + public JsonResult GetVersion() + { + return Json(CoreHelpers.GetVersion()); + } } } diff --git a/src/Notifications/Controllers/SendController.cs b/src/Notifications/Controllers/SendController.cs index 90fdac7d09..81698c9113 100644 --- a/src/Notifications/Controllers/SendController.cs +++ b/src/Notifications/Controllers/SendController.cs @@ -4,28 +4,29 @@ using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; using Microsoft.AspNetCore.SignalR; -namespace Bit.Notifications; - -[Authorize("Internal")] -public class SendController : Controller +namespace Bit.Notifications { - private readonly IHubContext _hubContext; - - public SendController(IHubContext hubContext) + [Authorize("Internal")] + public class SendController : Controller { - _hubContext = hubContext; - } + private readonly IHubContext _hubContext; - [HttpPost("~/send")] - [SelfHosted(SelfHostedOnly = true)] - public async Task PostSend() - { - using (var reader = new StreamReader(Request.Body, Encoding.UTF8)) + public SendController(IHubContext hubContext) { - var notificationJson = await reader.ReadToEndAsync(); - if (!string.IsNullOrWhiteSpace(notificationJson)) + _hubContext = hubContext; + } + + [HttpPost("~/send")] + [SelfHosted(SelfHostedOnly = true)] + public async Task PostSend() + { + using (var reader = new StreamReader(Request.Body, Encoding.UTF8)) { - await HubHelpers.SendNotificationToHubAsync(notificationJson, _hubContext); + var notificationJson = await reader.ReadToEndAsync(); + if (!string.IsNullOrWhiteSpace(notificationJson)) + { + await HubHelpers.SendNotificationToHubAsync(notificationJson, _hubContext); + } } } } diff --git a/src/Notifications/HeartbeatHostedService.cs b/src/Notifications/HeartbeatHostedService.cs index 717fdeb784..e916669263 100644 --- a/src/Notifications/HeartbeatHostedService.cs +++ b/src/Notifications/HeartbeatHostedService.cs @@ -1,56 +1,57 @@ using Bit.Core.Settings; using Microsoft.AspNetCore.SignalR; -namespace Bit.Notifications; - -public class HeartbeatHostedService : IHostedService, IDisposable +namespace Bit.Notifications { - private readonly ILogger _logger; - private readonly IHubContext _hubContext; - private readonly GlobalSettings _globalSettings; - - private Task _executingTask; - private CancellationTokenSource _cts; - - public HeartbeatHostedService( - ILogger logger, - IHubContext hubContext, - GlobalSettings globalSettings) + public class HeartbeatHostedService : IHostedService, IDisposable { - _logger = logger; - _hubContext = hubContext; - _globalSettings = globalSettings; - } + private readonly ILogger _logger; + private readonly IHubContext _hubContext; + private readonly GlobalSettings _globalSettings; - public Task StartAsync(CancellationToken cancellationToken) - { - _cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); - _executingTask = ExecuteAsync(_cts.Token); - return _executingTask.IsCompleted ? _executingTask : Task.CompletedTask; - } + private Task _executingTask; + private CancellationTokenSource _cts; - public async Task StopAsync(CancellationToken cancellationToken) - { - if (_executingTask == null) + public HeartbeatHostedService( + ILogger logger, + IHubContext hubContext, + GlobalSettings globalSettings) { - return; + _logger = logger; + _hubContext = hubContext; + _globalSettings = globalSettings; } - _logger.LogWarning("Stopping service."); - _cts.Cancel(); - await Task.WhenAny(_executingTask, Task.Delay(-1, cancellationToken)); - cancellationToken.ThrowIfCancellationRequested(); - } - public void Dispose() - { } - - private async Task ExecuteAsync(CancellationToken cancellationToken) - { - while (!cancellationToken.IsCancellationRequested) + public Task StartAsync(CancellationToken cancellationToken) { - await _hubContext.Clients.All.SendAsync("Heartbeat"); - await Task.Delay(120000); + _cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); + _executingTask = ExecuteAsync(_cts.Token); + return _executingTask.IsCompleted ? _executingTask : Task.CompletedTask; + } + + public async Task StopAsync(CancellationToken cancellationToken) + { + if (_executingTask == null) + { + return; + } + _logger.LogWarning("Stopping service."); + _cts.Cancel(); + await Task.WhenAny(_executingTask, Task.Delay(-1, cancellationToken)); + cancellationToken.ThrowIfCancellationRequested(); + } + + public void Dispose() + { } + + private async Task ExecuteAsync(CancellationToken cancellationToken) + { + while (!cancellationToken.IsCancellationRequested) + { + await _hubContext.Clients.All.SendAsync("Heartbeat"); + await Task.Delay(120000); + } + _logger.LogWarning("Done with heartbeat."); } - _logger.LogWarning("Done with heartbeat."); } } diff --git a/src/Notifications/HubHelpers.cs b/src/Notifications/HubHelpers.cs index 38b87e2276..2ba0037f4e 100644 --- a/src/Notifications/HubHelpers.cs +++ b/src/Notifications/HubHelpers.cs @@ -3,66 +3,67 @@ using Bit.Core.Enums; using Bit.Core.Models; using Microsoft.AspNetCore.SignalR; -namespace Bit.Notifications; - -public static class HubHelpers +namespace Bit.Notifications { - public static async Task SendNotificationToHubAsync(string notificationJson, - IHubContext hubContext, CancellationToken cancellationToken = default(CancellationToken)) + public static class HubHelpers { - var notification = JsonSerializer.Deserialize>(notificationJson); - switch (notification.Type) + public static async Task SendNotificationToHubAsync(string notificationJson, + IHubContext hubContext, CancellationToken cancellationToken = default(CancellationToken)) { - case PushType.SyncCipherUpdate: - case PushType.SyncCipherCreate: - case PushType.SyncCipherDelete: - case PushType.SyncLoginDelete: - var cipherNotification = - JsonSerializer.Deserialize>( - notificationJson); - if (cipherNotification.Payload.UserId.HasValue) - { - await hubContext.Clients.User(cipherNotification.Payload.UserId.ToString()) - .SendAsync("ReceiveMessage", cipherNotification, cancellationToken); - } - else if (cipherNotification.Payload.OrganizationId.HasValue) - { - await hubContext.Clients.Group( - $"Organization_{cipherNotification.Payload.OrganizationId}") - .SendAsync("ReceiveMessage", cipherNotification, cancellationToken); - } - break; - case PushType.SyncFolderUpdate: - case PushType.SyncFolderCreate: - case PushType.SyncFolderDelete: - var folderNotification = - JsonSerializer.Deserialize>( - notificationJson); - await hubContext.Clients.User(folderNotification.Payload.UserId.ToString()) - .SendAsync("ReceiveMessage", folderNotification, cancellationToken); - break; - case PushType.SyncCiphers: - case PushType.SyncVault: - case PushType.SyncOrgKeys: - case PushType.SyncSettings: - case PushType.LogOut: - var userNotification = - JsonSerializer.Deserialize>( - notificationJson); - await hubContext.Clients.User(userNotification.Payload.UserId.ToString()) - .SendAsync("ReceiveMessage", userNotification, cancellationToken); - break; - case PushType.SyncSendCreate: - case PushType.SyncSendUpdate: - case PushType.SyncSendDelete: - var sendNotification = - JsonSerializer.Deserialize>( + var notification = JsonSerializer.Deserialize>(notificationJson); + switch (notification.Type) + { + case PushType.SyncCipherUpdate: + case PushType.SyncCipherCreate: + case PushType.SyncCipherDelete: + case PushType.SyncLoginDelete: + var cipherNotification = + JsonSerializer.Deserialize>( notificationJson); - await hubContext.Clients.User(sendNotification.Payload.UserId.ToString()) - .SendAsync("ReceiveMessage", sendNotification, cancellationToken); - break; - default: - break; + if (cipherNotification.Payload.UserId.HasValue) + { + await hubContext.Clients.User(cipherNotification.Payload.UserId.ToString()) + .SendAsync("ReceiveMessage", cipherNotification, cancellationToken); + } + else if (cipherNotification.Payload.OrganizationId.HasValue) + { + await hubContext.Clients.Group( + $"Organization_{cipherNotification.Payload.OrganizationId}") + .SendAsync("ReceiveMessage", cipherNotification, cancellationToken); + } + break; + case PushType.SyncFolderUpdate: + case PushType.SyncFolderCreate: + case PushType.SyncFolderDelete: + var folderNotification = + JsonSerializer.Deserialize>( + notificationJson); + await hubContext.Clients.User(folderNotification.Payload.UserId.ToString()) + .SendAsync("ReceiveMessage", folderNotification, cancellationToken); + break; + case PushType.SyncCiphers: + case PushType.SyncVault: + case PushType.SyncOrgKeys: + case PushType.SyncSettings: + case PushType.LogOut: + var userNotification = + JsonSerializer.Deserialize>( + notificationJson); + await hubContext.Clients.User(userNotification.Payload.UserId.ToString()) + .SendAsync("ReceiveMessage", userNotification, cancellationToken); + break; + case PushType.SyncSendCreate: + case PushType.SyncSendUpdate: + case PushType.SyncSendDelete: + var sendNotification = + JsonSerializer.Deserialize>( + notificationJson); + await hubContext.Clients.User(sendNotification.Payload.UserId.ToString()) + .SendAsync("ReceiveMessage", sendNotification, cancellationToken); + break; + default: + break; + } } } } diff --git a/src/Notifications/Jobs/JobsHostedService.cs b/src/Notifications/Jobs/JobsHostedService.cs index a1f84e18b7..326e8dce0e 100644 --- a/src/Notifications/Jobs/JobsHostedService.cs +++ b/src/Notifications/Jobs/JobsHostedService.cs @@ -2,35 +2,36 @@ using Bit.Core.Settings; using Quartz; -namespace Bit.Notifications.Jobs; - -public class JobsHostedService : BaseJobsHostedService +namespace Bit.Notifications.Jobs { - public JobsHostedService( - GlobalSettings globalSettings, - IServiceProvider serviceProvider, - ILogger logger, - ILogger listenerLogger) - : base(globalSettings, serviceProvider, logger, listenerLogger) { } - - public override async Task StartAsync(CancellationToken cancellationToken) + public class JobsHostedService : BaseJobsHostedService { - var everyFiveMinutesTrigger = TriggerBuilder.Create() - .WithIdentity("EveryFiveMinutesTrigger") - .StartNow() - .WithCronSchedule("0 */30 * * * ?") - .Build(); + public JobsHostedService( + GlobalSettings globalSettings, + IServiceProvider serviceProvider, + ILogger logger, + ILogger listenerLogger) + : base(globalSettings, serviceProvider, logger, listenerLogger) { } - Jobs = new List> + public override async Task StartAsync(CancellationToken cancellationToken) { - new Tuple(typeof(LogConnectionCounterJob), everyFiveMinutesTrigger) - }; + var everyFiveMinutesTrigger = TriggerBuilder.Create() + .WithIdentity("EveryFiveMinutesTrigger") + .StartNow() + .WithCronSchedule("0 */30 * * * ?") + .Build(); - await base.StartAsync(cancellationToken); - } + Jobs = new List> + { + new Tuple(typeof(LogConnectionCounterJob), everyFiveMinutesTrigger) + }; - public static void AddJobsServices(IServiceCollection services) - { - services.AddTransient(); + await base.StartAsync(cancellationToken); + } + + public static void AddJobsServices(IServiceCollection services) + { + services.AddTransient(); + } } } diff --git a/src/Notifications/Jobs/LogConnectionCounterJob.cs b/src/Notifications/Jobs/LogConnectionCounterJob.cs index 9b4e2ee4fd..6b7bc70fff 100644 --- a/src/Notifications/Jobs/LogConnectionCounterJob.cs +++ b/src/Notifications/Jobs/LogConnectionCounterJob.cs @@ -2,24 +2,25 @@ using Bit.Core.Jobs; using Quartz; -namespace Bit.Notifications.Jobs; - -public class LogConnectionCounterJob : BaseJob +namespace Bit.Notifications.Jobs { - private readonly ConnectionCounter _connectionCounter; - - public LogConnectionCounterJob( - ILogger logger, - ConnectionCounter connectionCounter) - : base(logger) + public class LogConnectionCounterJob : BaseJob { - _connectionCounter = connectionCounter; - } + private readonly ConnectionCounter _connectionCounter; - protected override Task ExecuteJobAsync(IJobExecutionContext context) - { - _logger.LogInformation(Constants.BypassFiltersEventId, - "Connection count for server {0}: {1}", Environment.MachineName, _connectionCounter.GetCount()); - return Task.FromResult(0); + public LogConnectionCounterJob( + ILogger logger, + ConnectionCounter connectionCounter) + : base(logger) + { + _connectionCounter = connectionCounter; + } + + protected override Task ExecuteJobAsync(IJobExecutionContext context) + { + _logger.LogInformation(Constants.BypassFiltersEventId, + "Connection count for server {0}: {1}", Environment.MachineName, _connectionCounter.GetCount()); + return Task.FromResult(0); + } } } diff --git a/src/Notifications/NotificationsHub.cs b/src/Notifications/NotificationsHub.cs index 6d7a66b894..7d6e94a429 100644 --- a/src/Notifications/NotificationsHub.cs +++ b/src/Notifications/NotificationsHub.cs @@ -2,47 +2,48 @@ using Bit.Core.Settings; using Microsoft.AspNetCore.Authorization; -namespace Bit.Notifications; - -[Authorize("Application")] -public class NotificationsHub : Microsoft.AspNetCore.SignalR.Hub +namespace Bit.Notifications { - private readonly ConnectionCounter _connectionCounter; - private readonly GlobalSettings _globalSettings; - - public NotificationsHub(ConnectionCounter connectionCounter, GlobalSettings globalSettings) + [Authorize("Application")] + public class NotificationsHub : Microsoft.AspNetCore.SignalR.Hub { - _connectionCounter = connectionCounter; - _globalSettings = globalSettings; - } + private readonly ConnectionCounter _connectionCounter; + private readonly GlobalSettings _globalSettings; - public override async Task OnConnectedAsync() - { - var currentContext = new CurrentContext(null); - await currentContext.BuildAsync(Context.User, _globalSettings); - if (currentContext.Organizations != null) + public NotificationsHub(ConnectionCounter connectionCounter, GlobalSettings globalSettings) { - foreach (var org in currentContext.Organizations) - { - await Groups.AddToGroupAsync(Context.ConnectionId, $"Organization_{org.Id}"); - } + _connectionCounter = connectionCounter; + _globalSettings = globalSettings; } - _connectionCounter.Increment(); - await base.OnConnectedAsync(); - } - public override async Task OnDisconnectedAsync(Exception exception) - { - var currentContext = new CurrentContext(null); - await currentContext.BuildAsync(Context.User, _globalSettings); - if (currentContext.Organizations != null) + public override async Task OnConnectedAsync() { - foreach (var org in currentContext.Organizations) + var currentContext = new CurrentContext(null); + await currentContext.BuildAsync(Context.User, _globalSettings); + if (currentContext.Organizations != null) { - await Groups.RemoveFromGroupAsync(Context.ConnectionId, $"Organization_{org.Id}"); + foreach (var org in currentContext.Organizations) + { + await Groups.AddToGroupAsync(Context.ConnectionId, $"Organization_{org.Id}"); + } } + _connectionCounter.Increment(); + await base.OnConnectedAsync(); + } + + public override async Task OnDisconnectedAsync(Exception exception) + { + var currentContext = new CurrentContext(null); + await currentContext.BuildAsync(Context.User, _globalSettings); + if (currentContext.Organizations != null) + { + foreach (var org in currentContext.Organizations) + { + await Groups.RemoveFromGroupAsync(Context.ConnectionId, $"Organization_{org.Id}"); + } + } + _connectionCounter.Decrement(); + await base.OnDisconnectedAsync(exception); } - _connectionCounter.Decrement(); - await base.OnDisconnectedAsync(exception); } } diff --git a/src/Notifications/Program.cs b/src/Notifications/Program.cs index 4834972abb..8ea3a5a1b7 100644 --- a/src/Notifications/Program.cs +++ b/src/Notifications/Program.cs @@ -1,50 +1,51 @@ using Bit.Core.Utilities; using Serilog.Events; -namespace Bit.Notifications; - -public class Program +namespace Bit.Notifications { - public static void Main(string[] args) + public class Program { - Host - .CreateDefaultBuilder(args) - .ConfigureCustomAppConfiguration(args) - .ConfigureWebHostDefaults(webBuilder => - { - webBuilder.UseStartup(); - webBuilder.ConfigureLogging((hostingContext, logging) => - logging.AddSerilog(hostingContext, e => - { - var context = e.Properties["SourceContext"].ToString(); - if (context.Contains("IdentityServer4.Validation.TokenValidator") || - context.Contains("IdentityServer4.Validation.TokenRequestValidator")) + public static void Main(string[] args) + { + Host + .CreateDefaultBuilder(args) + .ConfigureCustomAppConfiguration(args) + .ConfigureWebHostDefaults(webBuilder => + { + webBuilder.UseStartup(); + webBuilder.ConfigureLogging((hostingContext, logging) => + logging.AddSerilog(hostingContext, e => { - return e.Level > LogEventLevel.Error; - } + var context = e.Properties["SourceContext"].ToString(); + if (context.Contains("IdentityServer4.Validation.TokenValidator") || + context.Contains("IdentityServer4.Validation.TokenRequestValidator")) + { + return e.Level > LogEventLevel.Error; + } - if (e.Level == LogEventLevel.Error && - e.MessageTemplate.Text == "Failed connection handshake.") - { - return false; - } + if (e.Level == LogEventLevel.Error && + e.MessageTemplate.Text == "Failed connection handshake.") + { + return false; + } - if (e.Level == LogEventLevel.Error && - e.MessageTemplate.Text.StartsWith("Failed writing message.")) - { - return false; - } + if (e.Level == LogEventLevel.Error && + e.MessageTemplate.Text.StartsWith("Failed writing message.")) + { + return false; + } - if (e.Level == LogEventLevel.Warning && - e.MessageTemplate.Text.StartsWith("Heartbeat took longer")) - { - return false; - } + if (e.Level == LogEventLevel.Warning && + e.MessageTemplate.Text.StartsWith("Heartbeat took longer")) + { + return false; + } - return e.Level >= LogEventLevel.Warning; - })); - }) - .Build() - .Run(); + return e.Level >= LogEventLevel.Warning; + })); + }) + .Build() + .Run(); + } } } diff --git a/src/Notifications/Startup.cs b/src/Notifications/Startup.cs index e2509e46c5..c42b9fcd22 100644 --- a/src/Notifications/Startup.cs +++ b/src/Notifications/Startup.cs @@ -6,114 +6,115 @@ using IdentityModel; using Microsoft.AspNetCore.SignalR; using Microsoft.IdentityModel.Logging; -namespace Bit.Notifications; - -public class Startup +namespace Bit.Notifications { - public Startup(IWebHostEnvironment env, IConfiguration configuration) + public class Startup { - CultureInfo.DefaultThreadCurrentCulture = new CultureInfo("en-US"); - Configuration = configuration; - Environment = env; - } - - public IConfiguration Configuration { get; } - public IWebHostEnvironment Environment { get; set; } - - public void ConfigureServices(IServiceCollection services) - { - // Options - services.AddOptions(); - - // Settings - var globalSettings = services.AddGlobalSettingsServices(Configuration, Environment); - - // Identity - services.AddIdentityAuthenticationServices(globalSettings, Environment, config => + public Startup(IWebHostEnvironment env, IConfiguration configuration) { - config.AddPolicy("Application", policy => - { - policy.RequireAuthenticatedUser(); - policy.RequireClaim(JwtClaimTypes.AuthenticationMethod, "Application", "external"); - policy.RequireClaim(JwtClaimTypes.Scope, "api"); - }); - config.AddPolicy("Internal", policy => - { - policy.RequireAuthenticatedUser(); - policy.RequireClaim(JwtClaimTypes.Scope, "internal"); - }); - }); - - // SignalR - var signalRServerBuilder = services.AddSignalR().AddMessagePackProtocol(options => - { - options.SerializerOptions = MessagePack.MessagePackSerializerOptions.Standard - .WithResolver(MessagePack.Resolvers.ContractlessStandardResolver.Instance); - }); - if (CoreHelpers.SettingHasValue(globalSettings.Notifications?.RedisConnectionString)) - { - signalRServerBuilder.AddStackExchangeRedis(globalSettings.Notifications.RedisConnectionString, - options => - { - options.Configuration.ChannelPrefix = "Notifications"; - }); + CultureInfo.DefaultThreadCurrentCulture = new CultureInfo("en-US"); + Configuration = configuration; + Environment = env; } - services.AddSingleton(); - services.AddSingleton(); - // Mvc - services.AddMvc(); + public IConfiguration Configuration { get; } + public IWebHostEnvironment Environment { get; set; } - services.AddHostedService(); - if (!globalSettings.SelfHosted) + public void ConfigureServices(IServiceCollection services) { - // Hosted Services - Jobs.JobsHostedService.AddJobsServices(services); - services.AddHostedService(); - if (CoreHelpers.SettingHasValue(globalSettings.Notifications?.ConnectionString)) + // Options + services.AddOptions(); + + // Settings + var globalSettings = services.AddGlobalSettingsServices(Configuration, Environment); + + // Identity + services.AddIdentityAuthenticationServices(globalSettings, Environment, config => { - services.AddHostedService(); + config.AddPolicy("Application", policy => + { + policy.RequireAuthenticatedUser(); + policy.RequireClaim(JwtClaimTypes.AuthenticationMethod, "Application", "external"); + policy.RequireClaim(JwtClaimTypes.Scope, "api"); + }); + config.AddPolicy("Internal", policy => + { + policy.RequireAuthenticatedUser(); + policy.RequireClaim(JwtClaimTypes.Scope, "internal"); + }); + }); + + // SignalR + var signalRServerBuilder = services.AddSignalR().AddMessagePackProtocol(options => + { + options.SerializerOptions = MessagePack.MessagePackSerializerOptions.Standard + .WithResolver(MessagePack.Resolvers.ContractlessStandardResolver.Instance); + }); + if (CoreHelpers.SettingHasValue(globalSettings.Notifications?.RedisConnectionString)) + { + signalRServerBuilder.AddStackExchangeRedis(globalSettings.Notifications.RedisConnectionString, + options => + { + options.Configuration.ChannelPrefix = "Notifications"; + }); + } + services.AddSingleton(); + services.AddSingleton(); + + // Mvc + services.AddMvc(); + + services.AddHostedService(); + if (!globalSettings.SelfHosted) + { + // Hosted Services + Jobs.JobsHostedService.AddJobsServices(services); + services.AddHostedService(); + if (CoreHelpers.SettingHasValue(globalSettings.Notifications?.ConnectionString)) + { + services.AddHostedService(); + } } } - } - public void Configure( - IApplicationBuilder app, - IWebHostEnvironment env, - IHostApplicationLifetime appLifetime, - GlobalSettings globalSettings) - { - IdentityModelEventSource.ShowPII = true; - app.UseSerilog(env, appLifetime, globalSettings); - - // Add general security headers - app.UseMiddleware(); - - if (env.IsDevelopment()) + public void Configure( + IApplicationBuilder app, + IWebHostEnvironment env, + IHostApplicationLifetime appLifetime, + GlobalSettings globalSettings) { - app.UseDeveloperExceptionPage(); - } + IdentityModelEventSource.ShowPII = true; + app.UseSerilog(env, appLifetime, globalSettings); - // Add routing - app.UseRouting(); + // Add general security headers + app.UseMiddleware(); - // Add Cors - app.UseCors(policy => policy.SetIsOriginAllowed(o => CoreHelpers.IsCorsOriginAllowed(o, globalSettings)) - .AllowAnyMethod().AllowAnyHeader().AllowCredentials()); - - // Add authentication to the request pipeline. - app.UseAuthentication(); - app.UseAuthorization(); - - // Add endpoints to the request pipeline. - app.UseEndpoints(endpoints => - { - endpoints.MapHub("/hub", options => + if (env.IsDevelopment()) { - options.ApplicationMaxBufferSize = 2048; // client => server messages are not even used - options.TransportMaxBufferSize = 4096; + app.UseDeveloperExceptionPage(); + } + + // Add routing + app.UseRouting(); + + // Add Cors + app.UseCors(policy => policy.SetIsOriginAllowed(o => CoreHelpers.IsCorsOriginAllowed(o, globalSettings)) + .AllowAnyMethod().AllowAnyHeader().AllowCredentials()); + + // Add authentication to the request pipeline. + app.UseAuthentication(); + app.UseAuthorization(); + + // Add endpoints to the request pipeline. + app.UseEndpoints(endpoints => + { + endpoints.MapHub("/hub", options => + { + options.ApplicationMaxBufferSize = 2048; // client => server messages are not even used + options.TransportMaxBufferSize = 4096; + }); + endpoints.MapDefaultControllerRoute(); }); - endpoints.MapDefaultControllerRoute(); - }); + } } } diff --git a/src/Notifications/SubjectUserIdProvider.cs b/src/Notifications/SubjectUserIdProvider.cs index 261394d06c..ee6ab6be57 100644 --- a/src/Notifications/SubjectUserIdProvider.cs +++ b/src/Notifications/SubjectUserIdProvider.cs @@ -1,12 +1,13 @@ using IdentityModel; using Microsoft.AspNetCore.SignalR; -namespace Bit.Notifications; - -public class SubjectUserIdProvider : IUserIdProvider +namespace Bit.Notifications { - public string GetUserId(HubConnectionContext connection) + public class SubjectUserIdProvider : IUserIdProvider { - return connection.User?.FindFirst(JwtClaimTypes.Subject)?.Value; + public string GetUserId(HubConnectionContext connection) + { + return connection.User?.FindFirst(JwtClaimTypes.Subject)?.Value; + } } } diff --git a/src/SharedWeb/Utilities/ExceptionHandlerFilterAttribute.cs b/src/SharedWeb/Utilities/ExceptionHandlerFilterAttribute.cs index f43544bca4..2de2a4d732 100644 --- a/src/SharedWeb/Utilities/ExceptionHandlerFilterAttribute.cs +++ b/src/SharedWeb/Utilities/ExceptionHandlerFilterAttribute.cs @@ -8,83 +8,84 @@ using Microsoft.Extensions.Logging; using Microsoft.IdentityModel.Tokens; using InternalApi = Bit.Core.Models.Api; -namespace Bit.SharedWeb.Utilities; - -public class ExceptionHandlerFilterAttribute : ExceptionFilterAttribute +namespace Bit.SharedWeb.Utilities { - public ExceptionHandlerFilterAttribute() + public class ExceptionHandlerFilterAttribute : ExceptionFilterAttribute { - } - - public override void OnException(ExceptionContext context) - { - var errorMessage = "An error has occurred."; - - var exception = context.Exception; - if (exception == null) + public ExceptionHandlerFilterAttribute() { - // Should never happen. - return; } - InternalApi.ErrorResponseModel internalErrorModel = null; - if (exception is BadRequestException badRequestException) + public override void OnException(ExceptionContext context) { - context.HttpContext.Response.StatusCode = 400; - if (badRequestException.ModelState != null) + var errorMessage = "An error has occurred."; + + var exception = context.Exception; + if (exception == null) { - internalErrorModel = new InternalApi.ErrorResponseModel(badRequestException.ModelState); + // Should never happen. + return; + } + + InternalApi.ErrorResponseModel internalErrorModel = null; + if (exception is BadRequestException badRequestException) + { + context.HttpContext.Response.StatusCode = 400; + if (badRequestException.ModelState != null) + { + internalErrorModel = new InternalApi.ErrorResponseModel(badRequestException.ModelState); + } + else + { + errorMessage = badRequestException.Message; + } + } + else if (exception is GatewayException) + { + errorMessage = exception.Message; + context.HttpContext.Response.StatusCode = 400; + } + else if (exception is NotSupportedException && !string.IsNullOrWhiteSpace(exception.Message)) + { + errorMessage = exception.Message; + context.HttpContext.Response.StatusCode = 400; + } + else if (exception is ApplicationException) + { + context.HttpContext.Response.StatusCode = 402; + } + else if (exception is NotFoundException) + { + errorMessage = "Resource not found."; + context.HttpContext.Response.StatusCode = 404; + } + else if (exception is SecurityTokenValidationException) + { + errorMessage = "Invalid token."; + context.HttpContext.Response.StatusCode = 403; + } + else if (exception is UnauthorizedAccessException) + { + errorMessage = "Unauthorized."; + context.HttpContext.Response.StatusCode = 401; } else { - errorMessage = badRequestException.Message; + var logger = context.HttpContext.RequestServices.GetRequiredService>(); + logger.LogError(0, exception, exception.Message); + errorMessage = "An unhandled server error has occurred."; + context.HttpContext.Response.StatusCode = 500; } - } - else if (exception is GatewayException) - { - errorMessage = exception.Message; - context.HttpContext.Response.StatusCode = 400; - } - else if (exception is NotSupportedException && !string.IsNullOrWhiteSpace(exception.Message)) - { - errorMessage = exception.Message; - context.HttpContext.Response.StatusCode = 400; - } - else if (exception is ApplicationException) - { - context.HttpContext.Response.StatusCode = 402; - } - else if (exception is NotFoundException) - { - errorMessage = "Resource not found."; - context.HttpContext.Response.StatusCode = 404; - } - else if (exception is SecurityTokenValidationException) - { - errorMessage = "Invalid token."; - context.HttpContext.Response.StatusCode = 403; - } - else if (exception is UnauthorizedAccessException) - { - errorMessage = "Unauthorized."; - context.HttpContext.Response.StatusCode = 401; - } - else - { - var logger = context.HttpContext.RequestServices.GetRequiredService>(); - logger.LogError(0, exception, exception.Message); - errorMessage = "An unhandled server error has occurred."; - context.HttpContext.Response.StatusCode = 500; - } - var errorModel = internalErrorModel ?? new InternalApi.ErrorResponseModel(errorMessage); - var env = context.HttpContext.RequestServices.GetRequiredService(); - if (env.IsDevelopment()) - { - errorModel.ExceptionMessage = exception.Message; - errorModel.ExceptionStackTrace = exception.StackTrace; - errorModel.InnerExceptionMessage = exception?.InnerException?.Message; + var errorModel = internalErrorModel ?? new InternalApi.ErrorResponseModel(errorMessage); + var env = context.HttpContext.RequestServices.GetRequiredService(); + if (env.IsDevelopment()) + { + errorModel.ExceptionMessage = exception.Message; + errorModel.ExceptionStackTrace = exception.StackTrace; + errorModel.InnerExceptionMessage = exception?.InnerException?.Message; + } + context.Result = new ObjectResult(errorModel); } - context.Result = new ObjectResult(errorModel); } } diff --git a/src/SharedWeb/Utilities/ModelStateValidationFilterAttribute.cs b/src/SharedWeb/Utilities/ModelStateValidationFilterAttribute.cs index c4dfbfb89e..11d642f321 100644 --- a/src/SharedWeb/Utilities/ModelStateValidationFilterAttribute.cs +++ b/src/SharedWeb/Utilities/ModelStateValidationFilterAttribute.cs @@ -2,30 +2,31 @@ using Microsoft.AspNetCore.Mvc; using Microsoft.AspNetCore.Mvc.Filters; -namespace Bit.SharedWeb.Utilities; - -public class ModelStateValidationFilterAttribute : ActionFilterAttribute +namespace Bit.SharedWeb.Utilities { - public ModelStateValidationFilterAttribute() + public class ModelStateValidationFilterAttribute : ActionFilterAttribute { - } - - public override void OnActionExecuting(ActionExecutingContext context) - { - var model = context.ActionArguments.FirstOrDefault(a => a.Key == "model"); - if (model.Key == "model" && model.Value == null) + public ModelStateValidationFilterAttribute() { - context.ModelState.AddModelError(string.Empty, "Body is empty."); } - if (!context.ModelState.IsValid) + public override void OnActionExecuting(ActionExecutingContext context) { - OnModelStateInvalid(context); - } - } + var model = context.ActionArguments.FirstOrDefault(a => a.Key == "model"); + if (model.Key == "model" && model.Value == null) + { + context.ModelState.AddModelError(string.Empty, "Body is empty."); + } - protected virtual void OnModelStateInvalid(ActionExecutingContext context) - { - context.Result = new BadRequestObjectResult(new ErrorResponseModel(context.ModelState)); + if (!context.ModelState.IsValid) + { + OnModelStateInvalid(context); + } + } + + protected virtual void OnModelStateInvalid(ActionExecutingContext context) + { + context.Result = new BadRequestObjectResult(new ErrorResponseModel(context.ModelState)); + } } } diff --git a/src/SharedWeb/Utilities/ServiceCollectionExtensions.cs b/src/SharedWeb/Utilities/ServiceCollectionExtensions.cs index b2efe511d0..9102a3a18c 100644 --- a/src/SharedWeb/Utilities/ServiceCollectionExtensions.cs +++ b/src/SharedWeb/Utilities/ServiceCollectionExtensions.cs @@ -43,612 +43,613 @@ using NoopRepos = Bit.Core.Repositories.Noop; using Role = Bit.Core.Entities.Role; using TableStorageRepos = Bit.Core.Repositories.TableStorage; -namespace Bit.SharedWeb.Utilities; - -public static class ServiceCollectionExtensions +namespace Bit.SharedWeb.Utilities { - public static void AddSqlServerRepositories(this IServiceCollection services, GlobalSettings globalSettings) + public static class ServiceCollectionExtensions { - var selectedDatabaseProvider = globalSettings.DatabaseProvider; - var provider = SupportedDatabaseProviders.SqlServer; - var connectionString = string.Empty; - if (!string.IsNullOrWhiteSpace(selectedDatabaseProvider)) + public static void AddSqlServerRepositories(this IServiceCollection services, GlobalSettings globalSettings) { - switch (selectedDatabaseProvider.ToLowerInvariant()) + var selectedDatabaseProvider = globalSettings.DatabaseProvider; + var provider = SupportedDatabaseProviders.SqlServer; + var connectionString = string.Empty; + if (!string.IsNullOrWhiteSpace(selectedDatabaseProvider)) { - case "postgres": - case "postgresql": - provider = SupportedDatabaseProviders.Postgres; - connectionString = globalSettings.PostgreSql.ConnectionString; - break; - case "mysql": - case "mariadb": - provider = SupportedDatabaseProviders.MySql; - connectionString = globalSettings.MySql.ConnectionString; - break; - default: - break; + switch (selectedDatabaseProvider.ToLowerInvariant()) + { + case "postgres": + case "postgresql": + provider = SupportedDatabaseProviders.Postgres; + connectionString = globalSettings.PostgreSql.ConnectionString; + break; + case "mysql": + case "mariadb": + provider = SupportedDatabaseProviders.MySql; + connectionString = globalSettings.MySql.ConnectionString; + break; + default: + break; + } + } + + var useEf = (provider != SupportedDatabaseProviders.SqlServer); + + if (useEf) + { + services.AddEFRepositories(globalSettings.SelfHosted, connectionString, provider); + } + else + { + services.AddDapperRepositories(globalSettings.SelfHosted); + } + + if (globalSettings.SelfHosted) + { + services.AddSingleton(); + services.AddSingleton(); + } + else + { + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); } } - var useEf = (provider != SupportedDatabaseProviders.SqlServer); - - if (useEf) + public static void AddBaseServices(this IServiceCollection services, IGlobalSettings globalSettings) { - services.AddEFRepositories(globalSettings.SelfHosted, connectionString, provider); - } - else - { - services.AddDapperRepositories(globalSettings.SelfHosted); + services.AddScoped(); + services.AddScoped(); + services.AddOrganizationServices(globalSettings); + services.AddScoped(); + services.AddScoped(); + services.AddScoped(); + services.AddScoped(); + services.AddScoped(); + services.AddSingleton(); + services.AddSingleton(); + services.AddScoped(); + services.AddScoped(); } - if (globalSettings.SelfHosted) + public static void AddTokenizers(this IServiceCollection services) { - services.AddSingleton(); - services.AddSingleton(); + services.AddSingleton>(serviceProvider => + new DataProtectorTokenFactory( + EmergencyAccessInviteTokenable.ClearTextPrefix, + EmergencyAccessInviteTokenable.DataProtectorPurpose, + serviceProvider.GetDataProtectionProvider()) + ); + services.AddSingleton>(serviceProvider => + new DataProtectorTokenFactory( + HCaptchaTokenable.ClearTextPrefix, + HCaptchaTokenable.DataProtectorPurpose, + serviceProvider.GetDataProtectionProvider()) + ); + services.AddSingleton>(serviceProvider => + new DataProtectorTokenFactory( + SsoTokenable.ClearTextPrefix, + SsoTokenable.DataProtectorPurpose, + serviceProvider.GetDataProtectionProvider())); } - else + + public static void AddDefaultServices(this IServiceCollection services, GlobalSettings globalSettings) { - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - } - } + // Required for UserService + services.AddWebAuthn(globalSettings); + // Required for HTTP calls + services.AddHttpClient(); - public static void AddBaseServices(this IServiceCollection services, IGlobalSettings globalSettings) - { - services.AddScoped(); - services.AddScoped(); - services.AddOrganizationServices(globalSettings); - services.AddScoped(); - services.AddScoped(); - services.AddScoped(); - services.AddScoped(); - services.AddScoped(); - services.AddSingleton(); - services.AddSingleton(); - services.AddScoped(); - services.AddScoped(); - } - - public static void AddTokenizers(this IServiceCollection services) - { - services.AddSingleton>(serviceProvider => - new DataProtectorTokenFactory( - EmergencyAccessInviteTokenable.ClearTextPrefix, - EmergencyAccessInviteTokenable.DataProtectorPurpose, - serviceProvider.GetDataProtectionProvider()) - ); - services.AddSingleton>(serviceProvider => - new DataProtectorTokenFactory( - HCaptchaTokenable.ClearTextPrefix, - HCaptchaTokenable.DataProtectorPurpose, - serviceProvider.GetDataProtectionProvider()) - ); - services.AddSingleton>(serviceProvider => - new DataProtectorTokenFactory( - SsoTokenable.ClearTextPrefix, - SsoTokenable.DataProtectorPurpose, - serviceProvider.GetDataProtectionProvider())); - } - - public static void AddDefaultServices(this IServiceCollection services, GlobalSettings globalSettings) - { - // Required for UserService - services.AddWebAuthn(globalSettings); - // Required for HTTP calls - services.AddHttpClient(); - - services.AddSingleton(); - services.AddSingleton((serviceProvider) => - { - return new Braintree.BraintreeGateway + services.AddSingleton(); + services.AddSingleton((serviceProvider) => { - Environment = globalSettings.Braintree.Production ? - Braintree.Environment.PRODUCTION : Braintree.Environment.SANDBOX, - MerchantId = globalSettings.Braintree.MerchantId, - PublicKey = globalSettings.Braintree.PublicKey, - PrivateKey = globalSettings.Braintree.PrivateKey - }; - }); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddTokenizers(); + return new Braintree.BraintreeGateway + { + Environment = globalSettings.Braintree.Production ? + Braintree.Environment.PRODUCTION : Braintree.Environment.SANDBOX, + MerchantId = globalSettings.Braintree.MerchantId, + PublicKey = globalSettings.Braintree.PublicKey, + PrivateKey = globalSettings.Braintree.PrivateKey + }; + }); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddTokenizers(); - if (CoreHelpers.SettingHasValue(globalSettings.ServiceBus.ConnectionString) && - CoreHelpers.SettingHasValue(globalSettings.ServiceBus.ApplicationCacheTopicName)) - { - services.AddSingleton(); - } - else - { - services.AddSingleton(); + if (CoreHelpers.SettingHasValue(globalSettings.ServiceBus.ConnectionString) && + CoreHelpers.SettingHasValue(globalSettings.ServiceBus.ApplicationCacheTopicName)) + { + services.AddSingleton(); + } + else + { + services.AddSingleton(); + } + + var awsConfigured = CoreHelpers.SettingHasValue(globalSettings.Amazon?.AccessKeySecret); + if (awsConfigured && CoreHelpers.SettingHasValue(globalSettings.Mail?.SendGridApiKey)) + { + services.AddSingleton(); + } + else if (awsConfigured) + { + services.AddSingleton(); + } + else if (CoreHelpers.SettingHasValue(globalSettings.Mail?.Smtp?.Host)) + { + services.AddSingleton(); + } + else + { + services.AddSingleton(); + } + + services.AddSingleton(); + if (globalSettings.SelfHosted && + CoreHelpers.SettingHasValue(globalSettings.PushRelayBaseUri) && + globalSettings.Installation?.Id != null && + CoreHelpers.SettingHasValue(globalSettings.Installation?.Key)) + { + services.AddSingleton(); + } + else if (!globalSettings.SelfHosted) + { + services.AddSingleton(); + } + else + { + services.AddSingleton(); + } + + if (!globalSettings.SelfHosted && CoreHelpers.SettingHasValue(globalSettings.Storage?.ConnectionString)) + { + services.AddSingleton(); + } + else if (!globalSettings.SelfHosted && CoreHelpers.SettingHasValue(globalSettings.Amazon?.AccessKeySecret)) + { + services.AddSingleton(); + } + else + { + services.AddSingleton(); + } + + if (!globalSettings.SelfHosted && CoreHelpers.SettingHasValue(globalSettings.Mail.ConnectionString)) + { + services.AddSingleton(); + } + else + { + services.AddSingleton(); + } + + if (!globalSettings.SelfHosted && CoreHelpers.SettingHasValue(globalSettings.Events.ConnectionString)) + { + services.AddSingleton(); + } + else if (globalSettings.SelfHosted) + { + services.AddSingleton(); + } + else + { + services.AddSingleton(); + } + + if (CoreHelpers.SettingHasValue(globalSettings.Attachment.ConnectionString)) + { + services.AddSingleton(); + } + else if (CoreHelpers.SettingHasValue(globalSettings.Attachment.BaseDirectory)) + { + services.AddSingleton(); + } + else + { + services.AddSingleton(); + } + + if (CoreHelpers.SettingHasValue(globalSettings.Send.ConnectionString)) + { + services.AddSingleton(); + } + else if (CoreHelpers.SettingHasValue(globalSettings.Send.BaseDirectory)) + { + services.AddSingleton(); + } + else + { + services.AddSingleton(); + } + + if (globalSettings.SelfHosted) + { + services.AddSingleton(); + } + else + { + services.AddSingleton(); + } + + if (CoreHelpers.SettingHasValue(globalSettings.Captcha?.HCaptchaSecretKey) && + CoreHelpers.SettingHasValue(globalSettings.Captcha?.HCaptchaSiteKey)) + { + services.AddSingleton(); + } + else + { + services.AddSingleton(); + } } - var awsConfigured = CoreHelpers.SettingHasValue(globalSettings.Amazon?.AccessKeySecret); - if (awsConfigured && CoreHelpers.SettingHasValue(globalSettings.Mail?.SendGridApiKey)) + public static void AddOosServices(this IServiceCollection services) { - services.AddSingleton(); + services.AddScoped(); } - else if (awsConfigured) - { - services.AddSingleton(); - } - else if (CoreHelpers.SettingHasValue(globalSettings.Mail?.Smtp?.Host)) - { - services.AddSingleton(); - } - else + + public static void AddNoopServices(this IServiceCollection services) { + services.AddSingleton(); services.AddSingleton(); - } - - services.AddSingleton(); - if (globalSettings.SelfHosted && - CoreHelpers.SettingHasValue(globalSettings.PushRelayBaseUri) && - globalSettings.Installation?.Id != null && - CoreHelpers.SettingHasValue(globalSettings.Installation?.Key)) - { - services.AddSingleton(); - } - else if (!globalSettings.SelfHosted) - { - services.AddSingleton(); - } - else - { - services.AddSingleton(); - } - - if (!globalSettings.SelfHosted && CoreHelpers.SettingHasValue(globalSettings.Storage?.ConnectionString)) - { - services.AddSingleton(); - } - else if (!globalSettings.SelfHosted && CoreHelpers.SettingHasValue(globalSettings.Amazon?.AccessKeySecret)) - { - services.AddSingleton(); - } - else - { + services.AddSingleton(); services.AddSingleton(); - } - - if (!globalSettings.SelfHosted && CoreHelpers.SettingHasValue(globalSettings.Mail.ConnectionString)) - { - services.AddSingleton(); - } - else - { - services.AddSingleton(); - } - - if (!globalSettings.SelfHosted && CoreHelpers.SettingHasValue(globalSettings.Events.ConnectionString)) - { - services.AddSingleton(); - } - else if (globalSettings.SelfHosted) - { - services.AddSingleton(); - } - else - { + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); services.AddSingleton(); } - if (CoreHelpers.SettingHasValue(globalSettings.Attachment.ConnectionString)) + public static IdentityBuilder AddCustomIdentityServices( + this IServiceCollection services, GlobalSettings globalSettings) { - services.AddSingleton(); - } - else if (CoreHelpers.SettingHasValue(globalSettings.Attachment.BaseDirectory)) - { - services.AddSingleton(); - } - else - { - services.AddSingleton(); - } - - if (CoreHelpers.SettingHasValue(globalSettings.Send.ConnectionString)) - { - services.AddSingleton(); - } - else if (CoreHelpers.SettingHasValue(globalSettings.Send.BaseDirectory)) - { - services.AddSingleton(); - } - else - { - services.AddSingleton(); - } - - if (globalSettings.SelfHosted) - { - services.AddSingleton(); - } - else - { - services.AddSingleton(); - } - - if (CoreHelpers.SettingHasValue(globalSettings.Captcha?.HCaptchaSecretKey) && - CoreHelpers.SettingHasValue(globalSettings.Captcha?.HCaptchaSiteKey)) - { - services.AddSingleton(); - } - else - { - services.AddSingleton(); - } - } - - public static void AddOosServices(this IServiceCollection services) - { - services.AddScoped(); - } - - public static void AddNoopServices(this IServiceCollection services) - { - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - } - - public static IdentityBuilder AddCustomIdentityServices( - this IServiceCollection services, GlobalSettings globalSettings) - { - services.AddSingleton(); - services.Configure(options => options.IterationCount = 100000); - services.Configure(options => - { - options.TokenLifespan = TimeSpan.FromDays(30); - }); - - var identityBuilder = services.AddIdentityWithoutCookieAuth(options => - { - options.User = new UserOptions + services.AddSingleton(); + services.Configure(options => options.IterationCount = 100000); + services.Configure(options => { - RequireUniqueEmail = true, - AllowedUserNameCharacters = null // all - }; - options.Password = new PasswordOptions - { - RequireDigit = false, - RequireLowercase = false, - RequiredLength = 8, - RequireNonAlphanumeric = false, - RequireUppercase = false - }; - options.ClaimsIdentity = new ClaimsIdentityOptions - { - SecurityStampClaimType = "sstamp", - UserNameClaimType = JwtClaimTypes.Email, - UserIdClaimType = JwtClaimTypes.Subject - }; - options.Tokens.ChangeEmailTokenProvider = TokenOptions.DefaultEmailProvider; - }); - - identityBuilder - .AddUserStore() - .AddRoleStore() - .AddTokenProvider>(TokenOptions.DefaultProvider) - .AddTokenProvider( - CoreHelpers.CustomProviderName(TwoFactorProviderType.Authenticator)) - .AddTokenProvider( - CoreHelpers.CustomProviderName(TwoFactorProviderType.Email)) - .AddTokenProvider( - CoreHelpers.CustomProviderName(TwoFactorProviderType.YubiKey)) - .AddTokenProvider( - CoreHelpers.CustomProviderName(TwoFactorProviderType.Duo)) - .AddTokenProvider( - CoreHelpers.CustomProviderName(TwoFactorProviderType.Remember)) - .AddTokenProvider>(TokenOptions.DefaultEmailProvider) - .AddTokenProvider( - CoreHelpers.CustomProviderName(TwoFactorProviderType.WebAuthn)); - - return identityBuilder; - } - - public static Tuple AddPasswordlessIdentityServices( - this IServiceCollection services, GlobalSettings globalSettings) where TUserStore : class - { - services.TryAddTransient(); - services.Configure(options => - { - options.TokenLifespan = TimeSpan.FromMinutes(15); - }); - - var passwordlessIdentityBuilder = services.AddIdentity() - .AddUserStore() - .AddRoleStore() - .AddDefaultTokenProviders(); - - var regularIdentityBuilder = services.AddIdentityCore() - .AddUserStore(); - - services.TryAddScoped, PasswordlessSignInManager>(); - - services.ConfigureApplicationCookie(options => - { - options.LoginPath = "/login"; - options.LogoutPath = "/"; - options.AccessDeniedPath = "/login?accessDenied=true"; - options.Cookie.Name = $"Bitwarden_{globalSettings.ProjectName}"; - options.Cookie.HttpOnly = true; - options.ExpireTimeSpan = TimeSpan.FromDays(2); - options.ReturnUrlParameter = "returnUrl"; - options.SlidingExpiration = true; - }); - - return new Tuple(passwordlessIdentityBuilder, regularIdentityBuilder); - } - - public static void AddIdentityAuthenticationServices( - this IServiceCollection services, GlobalSettings globalSettings, IWebHostEnvironment environment, - Action addAuthorization) - { - services - .AddAuthentication(IdentityServerAuthenticationDefaults.AuthenticationScheme) - .AddIdentityServerAuthentication(options => - { - options.Authority = globalSettings.BaseServiceUri.InternalIdentity; - options.RequireHttpsMetadata = !environment.IsDevelopment() && - globalSettings.BaseServiceUri.InternalIdentity.StartsWith("https"); - options.TokenRetriever = TokenRetrieval.FromAuthorizationHeaderOrQueryString(); - options.NameClaimType = ClaimTypes.Email; - options.SupportedTokens = SupportedTokens.Jwt; + options.TokenLifespan = TimeSpan.FromDays(30); }); - if (addAuthorization != null) - { - services.AddAuthorization(config => + var identityBuilder = services.AddIdentityWithoutCookieAuth(options => { - addAuthorization.Invoke(config); + options.User = new UserOptions + { + RequireUniqueEmail = true, + AllowedUserNameCharacters = null // all + }; + options.Password = new PasswordOptions + { + RequireDigit = false, + RequireLowercase = false, + RequiredLength = 8, + RequireNonAlphanumeric = false, + RequireUppercase = false + }; + options.ClaimsIdentity = new ClaimsIdentityOptions + { + SecurityStampClaimType = "sstamp", + UserNameClaimType = JwtClaimTypes.Email, + UserIdClaimType = JwtClaimTypes.Subject + }; + options.Tokens.ChangeEmailTokenProvider = TokenOptions.DefaultEmailProvider; + }); + + identityBuilder + .AddUserStore() + .AddRoleStore() + .AddTokenProvider>(TokenOptions.DefaultProvider) + .AddTokenProvider( + CoreHelpers.CustomProviderName(TwoFactorProviderType.Authenticator)) + .AddTokenProvider( + CoreHelpers.CustomProviderName(TwoFactorProviderType.Email)) + .AddTokenProvider( + CoreHelpers.CustomProviderName(TwoFactorProviderType.YubiKey)) + .AddTokenProvider( + CoreHelpers.CustomProviderName(TwoFactorProviderType.Duo)) + .AddTokenProvider( + CoreHelpers.CustomProviderName(TwoFactorProviderType.Remember)) + .AddTokenProvider>(TokenOptions.DefaultEmailProvider) + .AddTokenProvider( + CoreHelpers.CustomProviderName(TwoFactorProviderType.WebAuthn)); + + return identityBuilder; + } + + public static Tuple AddPasswordlessIdentityServices( + this IServiceCollection services, GlobalSettings globalSettings) where TUserStore : class + { + services.TryAddTransient(); + services.Configure(options => + { + options.TokenLifespan = TimeSpan.FromMinutes(15); + }); + + var passwordlessIdentityBuilder = services.AddIdentity() + .AddUserStore() + .AddRoleStore() + .AddDefaultTokenProviders(); + + var regularIdentityBuilder = services.AddIdentityCore() + .AddUserStore(); + + services.TryAddScoped, PasswordlessSignInManager>(); + + services.ConfigureApplicationCookie(options => + { + options.LoginPath = "/login"; + options.LogoutPath = "/"; + options.AccessDeniedPath = "/login?accessDenied=true"; + options.Cookie.Name = $"Bitwarden_{globalSettings.ProjectName}"; + options.Cookie.HttpOnly = true; + options.ExpireTimeSpan = TimeSpan.FromDays(2); + options.ReturnUrlParameter = "returnUrl"; + options.SlidingExpiration = true; + }); + + return new Tuple(passwordlessIdentityBuilder, regularIdentityBuilder); + } + + public static void AddIdentityAuthenticationServices( + this IServiceCollection services, GlobalSettings globalSettings, IWebHostEnvironment environment, + Action addAuthorization) + { + services + .AddAuthentication(IdentityServerAuthenticationDefaults.AuthenticationScheme) + .AddIdentityServerAuthentication(options => + { + options.Authority = globalSettings.BaseServiceUri.InternalIdentity; + options.RequireHttpsMetadata = !environment.IsDevelopment() && + globalSettings.BaseServiceUri.InternalIdentity.StartsWith("https"); + options.TokenRetriever = TokenRetrieval.FromAuthorizationHeaderOrQueryString(); + options.NameClaimType = ClaimTypes.Email; + options.SupportedTokens = SupportedTokens.Jwt; + }); + + if (addAuthorization != null) + { + services.AddAuthorization(config => + { + addAuthorization.Invoke(config); + }); + } + + if (environment.IsDevelopment()) + { + Microsoft.IdentityModel.Logging.IdentityModelEventSource.ShowPII = true; + } + } + + public static void AddCustomDataProtectionServices( + this IServiceCollection services, IWebHostEnvironment env, GlobalSettings globalSettings) + { + var builder = services.AddDataProtection(options => options.ApplicationDiscriminator = "Bitwarden"); + if (env.IsDevelopment()) + { + return; + } + + if (globalSettings.SelfHosted && CoreHelpers.SettingHasValue(globalSettings.DataProtection.Directory)) + { + builder.PersistKeysToFileSystem(new DirectoryInfo(globalSettings.DataProtection.Directory)); + } + + if (!globalSettings.SelfHosted && CoreHelpers.SettingHasValue(globalSettings.Storage?.ConnectionString)) + { + X509Certificate2 dataProtectionCert = null; + if (CoreHelpers.SettingHasValue(globalSettings.DataProtection.CertificateThumbprint)) + { + dataProtectionCert = CoreHelpers.GetCertificate( + globalSettings.DataProtection.CertificateThumbprint); + } + else if (CoreHelpers.SettingHasValue(globalSettings.DataProtection.CertificatePassword)) + { + dataProtectionCert = CoreHelpers.GetBlobCertificateAsync(globalSettings.Storage.ConnectionString, "certificates", + "dataprotection.pfx", globalSettings.DataProtection.CertificatePassword) + .GetAwaiter().GetResult(); + } + //TODO djsmith85 Check if this is the correct container name + builder + .PersistKeysToAzureBlobStorage(globalSettings.Storage.ConnectionString, "aspnet-dataprotection", "keys.xml") + .ProtectKeysWithCertificate(dataProtectionCert); + } + } + + public static IIdentityServerBuilder AddIdentityServerCertificate( + this IIdentityServerBuilder identityServerBuilder, IWebHostEnvironment env, GlobalSettings globalSettings) + { + var certificate = CoreHelpers.GetIdentityServerCertificate(globalSettings); + if (certificate != null) + { + identityServerBuilder.AddSigningCredential(certificate); + } + else if (env.IsDevelopment()) + { + identityServerBuilder.AddDeveloperSigningCredential(false); + } + else + { + throw new Exception("No identity certificate to use."); + } + return identityServerBuilder; + } + + public static GlobalSettings AddGlobalSettingsServices(this IServiceCollection services, + IConfiguration configuration, IWebHostEnvironment environment) + { + var globalSettings = new GlobalSettings(); + ConfigurationBinder.Bind(configuration.GetSection("GlobalSettings"), globalSettings); + + if (environment.IsDevelopment() && configuration.GetValue("developSelfHosted")) + { + // Override settings with selfHostedOverride settings + ConfigurationBinder.Bind(configuration.GetSection("Dev:SelfHostOverride:GlobalSettings"), globalSettings); + } + + services.AddSingleton(s => globalSettings); + services.AddSingleton(s => globalSettings); + return globalSettings; + } + + public static void UseDefaultMiddleware(this IApplicationBuilder app, + IWebHostEnvironment env, GlobalSettings globalSettings) + { + string GetHeaderValue(HttpContext httpContext, string header) + { + if (httpContext.Request.Headers.ContainsKey(header)) + { + return httpContext.Request.Headers[header]; + } + return null; + } + + // Add version information to response headers + app.Use(async (httpContext, next) => + { + using (LogContext.PushProperty("IPAddress", httpContext.GetIpAddress(globalSettings))) + using (LogContext.PushProperty("UserAgent", GetHeaderValue(httpContext, "user-agent"))) + using (LogContext.PushProperty("DeviceType", GetHeaderValue(httpContext, "device-type"))) + using (LogContext.PushProperty("Origin", GetHeaderValue(httpContext, "origin"))) + { + httpContext.Response.OnStarting((state) => + { + httpContext.Response.Headers.Append("Server-Version", CoreHelpers.GetVersion()); + return Task.FromResult(0); + }, null); + await next.Invoke(); + } }); } - if (environment.IsDevelopment()) + public static void UseForwardedHeaders(this IApplicationBuilder app, GlobalSettings globalSettings) { - Microsoft.IdentityModel.Logging.IdentityModelEventSource.ShowPII = true; - } - } - - public static void AddCustomDataProtectionServices( - this IServiceCollection services, IWebHostEnvironment env, GlobalSettings globalSettings) - { - var builder = services.AddDataProtection(options => options.ApplicationDiscriminator = "Bitwarden"); - if (env.IsDevelopment()) - { - return; - } - - if (globalSettings.SelfHosted && CoreHelpers.SettingHasValue(globalSettings.DataProtection.Directory)) - { - builder.PersistKeysToFileSystem(new DirectoryInfo(globalSettings.DataProtection.Directory)); - } - - if (!globalSettings.SelfHosted && CoreHelpers.SettingHasValue(globalSettings.Storage?.ConnectionString)) - { - X509Certificate2 dataProtectionCert = null; - if (CoreHelpers.SettingHasValue(globalSettings.DataProtection.CertificateThumbprint)) + var options = new ForwardedHeadersOptions { - dataProtectionCert = CoreHelpers.GetCertificate( - globalSettings.DataProtection.CertificateThumbprint); - } - else if (CoreHelpers.SettingHasValue(globalSettings.DataProtection.CertificatePassword)) + ForwardedHeaders = ForwardedHeaders.XForwardedFor | ForwardedHeaders.XForwardedProto + }; + if (!string.IsNullOrWhiteSpace(globalSettings.KnownProxies)) { - dataProtectionCert = CoreHelpers.GetBlobCertificateAsync(globalSettings.Storage.ConnectionString, "certificates", - "dataprotection.pfx", globalSettings.DataProtection.CertificatePassword) - .GetAwaiter().GetResult(); - } - //TODO djsmith85 Check if this is the correct container name - builder - .PersistKeysToAzureBlobStorage(globalSettings.Storage.ConnectionString, "aspnet-dataprotection", "keys.xml") - .ProtectKeysWithCertificate(dataProtectionCert); - } - } - - public static IIdentityServerBuilder AddIdentityServerCertificate( - this IIdentityServerBuilder identityServerBuilder, IWebHostEnvironment env, GlobalSettings globalSettings) - { - var certificate = CoreHelpers.GetIdentityServerCertificate(globalSettings); - if (certificate != null) - { - identityServerBuilder.AddSigningCredential(certificate); - } - else if (env.IsDevelopment()) - { - identityServerBuilder.AddDeveloperSigningCredential(false); - } - else - { - throw new Exception("No identity certificate to use."); - } - return identityServerBuilder; - } - - public static GlobalSettings AddGlobalSettingsServices(this IServiceCollection services, - IConfiguration configuration, IWebHostEnvironment environment) - { - var globalSettings = new GlobalSettings(); - ConfigurationBinder.Bind(configuration.GetSection("GlobalSettings"), globalSettings); - - if (environment.IsDevelopment() && configuration.GetValue("developSelfHosted")) - { - // Override settings with selfHostedOverride settings - ConfigurationBinder.Bind(configuration.GetSection("Dev:SelfHostOverride:GlobalSettings"), globalSettings); - } - - services.AddSingleton(s => globalSettings); - services.AddSingleton(s => globalSettings); - return globalSettings; - } - - public static void UseDefaultMiddleware(this IApplicationBuilder app, - IWebHostEnvironment env, GlobalSettings globalSettings) - { - string GetHeaderValue(HttpContext httpContext, string header) - { - if (httpContext.Request.Headers.ContainsKey(header)) - { - return httpContext.Request.Headers[header]; - } - return null; - } - - // Add version information to response headers - app.Use(async (httpContext, next) => - { - using (LogContext.PushProperty("IPAddress", httpContext.GetIpAddress(globalSettings))) - using (LogContext.PushProperty("UserAgent", GetHeaderValue(httpContext, "user-agent"))) - using (LogContext.PushProperty("DeviceType", GetHeaderValue(httpContext, "device-type"))) - using (LogContext.PushProperty("Origin", GetHeaderValue(httpContext, "origin"))) - { - httpContext.Response.OnStarting((state) => + var proxies = globalSettings.KnownProxies.Split(','); + foreach (var proxy in proxies) { - httpContext.Response.Headers.Append("Server-Version", CoreHelpers.GetVersion()); - return Task.FromResult(0); - }, null); - await next.Invoke(); - } - }); - } - - public static void UseForwardedHeaders(this IApplicationBuilder app, GlobalSettings globalSettings) - { - var options = new ForwardedHeadersOptions - { - ForwardedHeaders = ForwardedHeaders.XForwardedFor | ForwardedHeaders.XForwardedProto - }; - if (!string.IsNullOrWhiteSpace(globalSettings.KnownProxies)) - { - var proxies = globalSettings.KnownProxies.Split(','); - foreach (var proxy in proxies) - { - if (System.Net.IPAddress.TryParse(proxy.Trim(), out var ip)) - { - options.KnownProxies.Add(ip); + if (System.Net.IPAddress.TryParse(proxy.Trim(), out var ip)) + { + options.KnownProxies.Add(ip); + } } } - } - if (options.KnownProxies.Count > 1) - { - options.ForwardLimit = null; - } - app.UseForwardedHeaders(options); - } - - public static void AddCoreLocalizationServices(this IServiceCollection services) - { - services.AddTransient(); - services.AddLocalization(options => options.ResourcesPath = "Resources"); - } - - public static IApplicationBuilder UseCoreLocalization(this IApplicationBuilder app) - { - var supportedCultures = new[] { "en" }; - return app.UseRequestLocalization(options => options - .SetDefaultCulture(supportedCultures[0]) - .AddSupportedCultures(supportedCultures) - .AddSupportedUICultures(supportedCultures)); - } - - public static IMvcBuilder AddViewAndDataAnnotationLocalization(this IMvcBuilder mvc) - { - mvc.Services.AddTransient(); - return mvc.AddViewLocalization(options => options.ResourcesPath = "Resources") - .AddDataAnnotationsLocalization(options => - options.DataAnnotationLocalizerProvider = (type, factory) => - { - var assemblyName = new AssemblyName(typeof(SharedResources).GetTypeInfo().Assembly.FullName); - return factory.Create("SharedResources", assemblyName.Name); - }); - } - - public static IServiceCollection AddDistributedIdentityServices(this IServiceCollection services, GlobalSettings globalSettings) - { - services.AddOidcStateDataFormatterCache(); - services.AddSession(); - services.ConfigureApplicationCookie(configure => configure.CookieManager = new DistributedCacheCookieManager()); - services.ConfigureExternalCookie(configure => configure.CookieManager = new DistributedCacheCookieManager()); - services.AddSingleton>( - svcs => new ConfigureOpenIdConnectDistributedOptions( - svcs.GetRequiredService(), - globalSettings, - svcs.GetRequiredService()) - ); - - return services; - } - - public static void AddWebAuthn(this IServiceCollection services, GlobalSettings globalSettings) - { - services.AddFido2(options => - { - options.ServerDomain = new Uri(globalSettings.BaseServiceUri.Vault).Host; - options.ServerName = "Bitwarden"; - options.Origins = new HashSet { globalSettings.BaseServiceUri.Vault, }; - options.TimestampDriftTolerance = 300000; - }); - } - - /// - /// Adds either an in-memory or distributed IP rate limiter depending if a Redis connection string is available. - /// - /// - /// - public static void AddIpRateLimiting(this IServiceCollection services, - GlobalSettings globalSettings) - { - services.AddHostedService(); - services.AddSingleton(); - - if (string.IsNullOrEmpty(globalSettings.Redis.ConnectionString)) - { - services.AddInMemoryRateLimiting(); - } - else - { - services.AddRedisRateLimiting(); // Requires a registered IConnectionMultiplexer - } - } - - /// - /// Adds an implementation of to the service collection. Uses a memory - /// cache if self hosted or no Redis connection string is available in GlobalSettings. - /// - public static void AddDistributedCache( - this IServiceCollection services, - GlobalSettings globalSettings) - { - if (globalSettings.SelfHosted || string.IsNullOrEmpty(globalSettings.Redis.ConnectionString)) - { - services.AddDistributedMemoryCache(); - return; - } - - // Register the IConnectionMultiplexer explicitly so it can be accessed via DI - // (e.g. for the IP rate limiting store) - services.AddSingleton( - _ => ConnectionMultiplexer.Connect(globalSettings.Redis.ConnectionString)); - - // Explicitly register IDistributedCache to re-use existing IConnectionMultiplexer - // to reduce the number of redundant connections to the Redis instance - services.AddSingleton(s => - { - return new RedisCache(new RedisCacheOptions + if (options.KnownProxies.Count > 1) { - // Use "ProjectName:" as an instance name to namespace keys and avoid conflicts between projects - InstanceName = $"{globalSettings.ProjectName}:", - ConnectionMultiplexerFactory = () => - Task.FromResult(s.GetRequiredService()) + options.ForwardLimit = null; + } + app.UseForwardedHeaders(options); + } + + public static void AddCoreLocalizationServices(this IServiceCollection services) + { + services.AddTransient(); + services.AddLocalization(options => options.ResourcesPath = "Resources"); + } + + public static IApplicationBuilder UseCoreLocalization(this IApplicationBuilder app) + { + var supportedCultures = new[] { "en" }; + return app.UseRequestLocalization(options => options + .SetDefaultCulture(supportedCultures[0]) + .AddSupportedCultures(supportedCultures) + .AddSupportedUICultures(supportedCultures)); + } + + public static IMvcBuilder AddViewAndDataAnnotationLocalization(this IMvcBuilder mvc) + { + mvc.Services.AddTransient(); + return mvc.AddViewLocalization(options => options.ResourcesPath = "Resources") + .AddDataAnnotationsLocalization(options => + options.DataAnnotationLocalizerProvider = (type, factory) => + { + var assemblyName = new AssemblyName(typeof(SharedResources).GetTypeInfo().Assembly.FullName); + return factory.Create("SharedResources", assemblyName.Name); + }); + } + + public static IServiceCollection AddDistributedIdentityServices(this IServiceCollection services, GlobalSettings globalSettings) + { + services.AddOidcStateDataFormatterCache(); + services.AddSession(); + services.ConfigureApplicationCookie(configure => configure.CookieManager = new DistributedCacheCookieManager()); + services.ConfigureExternalCookie(configure => configure.CookieManager = new DistributedCacheCookieManager()); + services.AddSingleton>( + svcs => new ConfigureOpenIdConnectDistributedOptions( + svcs.GetRequiredService(), + globalSettings, + svcs.GetRequiredService()) + ); + + return services; + } + + public static void AddWebAuthn(this IServiceCollection services, GlobalSettings globalSettings) + { + services.AddFido2(options => + { + options.ServerDomain = new Uri(globalSettings.BaseServiceUri.Vault).Host; + options.ServerName = "Bitwarden"; + options.Origins = new HashSet { globalSettings.BaseServiceUri.Vault, }; + options.TimestampDriftTolerance = 300000; }); - }); + } + + /// + /// Adds either an in-memory or distributed IP rate limiter depending if a Redis connection string is available. + /// + /// + /// + public static void AddIpRateLimiting(this IServiceCollection services, + GlobalSettings globalSettings) + { + services.AddHostedService(); + services.AddSingleton(); + + if (string.IsNullOrEmpty(globalSettings.Redis.ConnectionString)) + { + services.AddInMemoryRateLimiting(); + } + else + { + services.AddRedisRateLimiting(); // Requires a registered IConnectionMultiplexer + } + } + + /// + /// Adds an implementation of to the service collection. Uses a memory + /// cache if self hosted or no Redis connection string is available in GlobalSettings. + /// + public static void AddDistributedCache( + this IServiceCollection services, + GlobalSettings globalSettings) + { + if (globalSettings.SelfHosted || string.IsNullOrEmpty(globalSettings.Redis.ConnectionString)) + { + services.AddDistributedMemoryCache(); + return; + } + + // Register the IConnectionMultiplexer explicitly so it can be accessed via DI + // (e.g. for the IP rate limiting store) + services.AddSingleton( + _ => ConnectionMultiplexer.Connect(globalSettings.Redis.ConnectionString)); + + // Explicitly register IDistributedCache to re-use existing IConnectionMultiplexer + // to reduce the number of redundant connections to the Redis instance + services.AddSingleton(s => + { + return new RedisCache(new RedisCacheOptions + { + // Use "ProjectName:" as an instance name to namespace keys and avoid conflicts between projects + InstanceName = $"{globalSettings.ProjectName}:", + ConnectionMultiplexerFactory = () => + Task.FromResult(s.GetRequiredService()) + }); + }); + } } } diff --git a/test/Api.IntegrationTest/Factories/ApiApplicationFactory.cs b/test/Api.IntegrationTest/Factories/ApiApplicationFactory.cs index 693e530831..8d678c3f48 100644 --- a/test/Api.IntegrationTest/Factories/ApiApplicationFactory.cs +++ b/test/Api.IntegrationTest/Factories/ApiApplicationFactory.cs @@ -3,41 +3,42 @@ using Bit.IntegrationTestCommon.Factories; using IdentityServer4.AccessTokenValidation; using Microsoft.AspNetCore.TestHost; -namespace Bit.Api.IntegrationTest.Factories; - -public class ApiApplicationFactory : WebApplicationFactoryBase +namespace Bit.Api.IntegrationTest.Factories { - private readonly IdentityApplicationFactory _identityApplicationFactory; - - public ApiApplicationFactory() + public class ApiApplicationFactory : WebApplicationFactoryBase { - _identityApplicationFactory = new IdentityApplicationFactory(); - } + private readonly IdentityApplicationFactory _identityApplicationFactory; - protected override void ConfigureWebHost(IWebHostBuilder builder) - { - base.ConfigureWebHost(builder); - - builder.ConfigureTestServices(services => + public ApiApplicationFactory() { - services.PostConfigure(IdentityServerAuthenticationDefaults.AuthenticationScheme, options => + _identityApplicationFactory = new IdentityApplicationFactory(); + } + + protected override void ConfigureWebHost(IWebHostBuilder builder) + { + base.ConfigureWebHost(builder); + + builder.ConfigureTestServices(services => { - options.JwtBackChannelHandler = _identityApplicationFactory.Server.CreateHandler(); + services.PostConfigure(IdentityServerAuthenticationDefaults.AuthenticationScheme, options => + { + options.JwtBackChannelHandler = _identityApplicationFactory.Server.CreateHandler(); + }); }); - }); - } + } - /// - /// Helper for registering and logging in to a new account - /// - public async Task<(string Token, string RefreshToken)> LoginWithNewAccount(string email = "integration-test@bitwarden.com", string masterPasswordHash = "master_password_hash") - { - await _identityApplicationFactory.RegisterAsync(new RegisterRequestModel + /// + /// Helper for registering and logging in to a new account + /// + public async Task<(string Token, string RefreshToken)> LoginWithNewAccount(string email = "integration-test@bitwarden.com", string masterPasswordHash = "master_password_hash") { - Email = email, - MasterPasswordHash = masterPasswordHash, - }); + await _identityApplicationFactory.RegisterAsync(new RegisterRequestModel + { + Email = email, + MasterPasswordHash = masterPasswordHash, + }); - return await _identityApplicationFactory.TokenFromPasswordAsync(email, masterPasswordHash); + return await _identityApplicationFactory.TokenFromPasswordAsync(email, masterPasswordHash); + } } } diff --git a/test/Api.Test/Controllers/AccountsControllerTests.cs b/test/Api.Test/Controllers/AccountsControllerTests.cs index 0b2747c386..cd33de4c8b 100644 --- a/test/Api.Test/Controllers/AccountsControllerTests.cs +++ b/test/Api.Test/Controllers/AccountsControllerTests.cs @@ -13,412 +13,413 @@ using Microsoft.AspNetCore.Identity; using NSubstitute; using Xunit; -namespace Bit.Api.Test.Controllers; - -public class AccountsControllerTests : IDisposable +namespace Bit.Api.Test.Controllers { - - private readonly AccountsController _sut; - private readonly GlobalSettings _globalSettings; - private readonly ICipherRepository _cipherRepository; - private readonly IFolderRepository _folderRepository; - private readonly IOrganizationService _organizationService; - private readonly IOrganizationUserRepository _organizationUserRepository; - private readonly IPaymentService _paymentService; - private readonly IUserRepository _userRepository; - private readonly IUserService _userService; - private readonly ISendRepository _sendRepository; - private readonly ISendService _sendService; - private readonly IProviderUserRepository _providerUserRepository; - - public AccountsControllerTests() + public class AccountsControllerTests : IDisposable { - _userService = Substitute.For(); - _userRepository = Substitute.For(); - _cipherRepository = Substitute.For(); - _folderRepository = Substitute.For(); - _organizationService = Substitute.For(); - _organizationUserRepository = Substitute.For(); - _providerUserRepository = Substitute.For(); - _paymentService = Substitute.For(); - _globalSettings = new GlobalSettings(); - _sendRepository = Substitute.For(); - _sendService = Substitute.For(); - _sut = new AccountsController( - _globalSettings, - _cipherRepository, - _folderRepository, - _organizationService, - _organizationUserRepository, - _providerUserRepository, - _paymentService, - _userRepository, - _userService, - _sendRepository, - _sendService - ); - } - public void Dispose() - { - _sut?.Dispose(); - } + private readonly AccountsController _sut; + private readonly GlobalSettings _globalSettings; + private readonly ICipherRepository _cipherRepository; + private readonly IFolderRepository _folderRepository; + private readonly IOrganizationService _organizationService; + private readonly IOrganizationUserRepository _organizationUserRepository; + private readonly IPaymentService _paymentService; + private readonly IUserRepository _userRepository; + private readonly IUserService _userService; + private readonly ISendRepository _sendRepository; + private readonly ISendService _sendService; + private readonly IProviderUserRepository _providerUserRepository; - [Fact] - public async Task PostPrelogin_WhenUserExists_ShouldReturnUserKdfInfo() - { - var userKdfInfo = new UserKdfInformation + public AccountsControllerTests() { - Kdf = KdfType.PBKDF2_SHA256, - KdfIterations = 5000 - }; - _userRepository.GetKdfInformationByEmailAsync(Arg.Any()).Returns(Task.FromResult(userKdfInfo)); + _userService = Substitute.For(); + _userRepository = Substitute.For(); + _cipherRepository = Substitute.For(); + _folderRepository = Substitute.For(); + _organizationService = Substitute.For(); + _organizationUserRepository = Substitute.For(); + _providerUserRepository = Substitute.For(); + _paymentService = Substitute.For(); + _globalSettings = new GlobalSettings(); + _sendRepository = Substitute.For(); + _sendService = Substitute.For(); + _sut = new AccountsController( + _globalSettings, + _cipherRepository, + _folderRepository, + _organizationService, + _organizationUserRepository, + _providerUserRepository, + _paymentService, + _userRepository, + _userService, + _sendRepository, + _sendService + ); + } - var response = await _sut.PostPrelogin(new PreloginRequestModel { Email = "user@example.com" }); - - Assert.Equal(userKdfInfo.Kdf, response.Kdf); - Assert.Equal(userKdfInfo.KdfIterations, response.KdfIterations); - } - - [Fact] - public async Task PostPrelogin_WhenUserDoesNotExist_ShouldDefaultToSha256And100000Iterations() - { - _userRepository.GetKdfInformationByEmailAsync(Arg.Any()).Returns(Task.FromResult((UserKdfInformation)null)); - - var response = await _sut.PostPrelogin(new PreloginRequestModel { Email = "user@example.com" }); - - Assert.Equal(KdfType.PBKDF2_SHA256, response.Kdf); - Assert.Equal(100000, response.KdfIterations); - } - - [Fact] - public async Task PostRegister_ShouldRegisterUser() - { - var passwordHash = "abcdef"; - var token = "123456"; - var userGuid = new Guid(); - _userService.RegisterUserAsync(Arg.Any(), passwordHash, token, userGuid) - .Returns(Task.FromResult(IdentityResult.Success)); - var request = new RegisterRequestModel + public void Dispose() { - Name = "Example User", - Email = "user@example.com", - MasterPasswordHash = passwordHash, - MasterPasswordHint = "example", - Token = token, - OrganizationUserId = userGuid - }; + _sut?.Dispose(); + } - await _sut.PostRegister(request); - - await _userService.Received(1).RegisterUserAsync(Arg.Any(), passwordHash, token, userGuid); - } - - [Fact] - public async Task PostRegister_WhenUserServiceFails_ShouldThrowBadRequestException() - { - var passwordHash = "abcdef"; - var token = "123456"; - var userGuid = new Guid(); - _userService.RegisterUserAsync(Arg.Any(), passwordHash, token, userGuid) - .Returns(Task.FromResult(IdentityResult.Failed())); - var request = new RegisterRequestModel + [Fact] + public async Task PostPrelogin_WhenUserExists_ShouldReturnUserKdfInfo() { - Name = "Example User", - Email = "user@example.com", - MasterPasswordHash = passwordHash, - MasterPasswordHint = "example", - Token = token, - OrganizationUserId = userGuid - }; + var userKdfInfo = new UserKdfInformation + { + Kdf = KdfType.PBKDF2_SHA256, + KdfIterations = 5000 + }; + _userRepository.GetKdfInformationByEmailAsync(Arg.Any()).Returns(Task.FromResult(userKdfInfo)); - await Assert.ThrowsAsync(() => _sut.PostRegister(request)); - } + var response = await _sut.PostPrelogin(new PreloginRequestModel { Email = "user@example.com" }); - [Fact] - public async Task PostPasswordHint_ShouldNotifyUserService() - { - var email = "user@example.com"; + Assert.Equal(userKdfInfo.Kdf, response.Kdf); + Assert.Equal(userKdfInfo.KdfIterations, response.KdfIterations); + } - await _sut.PostPasswordHint(new PasswordHintRequestModel { Email = email }); - - await _userService.Received(1).SendMasterPasswordHintAsync(email); - } - - [Fact] - public async Task PostEmailToken_ShouldInitiateEmailChange() - { - var user = GenerateExampleUser(); - ConfigureUserServiceToReturnValidPrincipalFor(user); - ConfigureUserServiceToAcceptPasswordFor(user); - var newEmail = "example@user.com"; - - await _sut.PostEmailToken(new EmailTokenRequestModel { NewEmail = newEmail }); - - await _userService.Received(1).InitiateEmailChangeAsync(user, newEmail); - } - - [Fact] - public async Task PostEmailToken_WhenNotAuthorized_ShouldThrowUnauthorizedAccessException() - { - ConfigureUserServiceToReturnNullPrincipal(); - - await Assert.ThrowsAsync( - () => _sut.PostEmailToken(new EmailTokenRequestModel()) - ); - } - - [Fact] - public async Task PostEmailToken_WhenInvalidPasssword_ShouldThrowBadRequestException() - { - var user = GenerateExampleUser(); - ConfigureUserServiceToReturnValidPrincipalFor(user); - ConfigureUserServiceToRejectPasswordFor(user); - - await Assert.ThrowsAsync( - () => _sut.PostEmailToken(new EmailTokenRequestModel()) - ); - } - - [Fact] - public async Task PostEmail_ShouldChangeUserEmail() - { - var user = GenerateExampleUser(); - ConfigureUserServiceToReturnValidPrincipalFor(user); - _userService.ChangeEmailAsync(user, default, default, default, default, default) - .Returns(Task.FromResult(IdentityResult.Success)); - - await _sut.PostEmail(new EmailRequestModel()); - - await _userService.Received(1).ChangeEmailAsync(user, default, default, default, default, default); - } - - [Fact] - public async Task PostEmail_WhenNotAuthorized_ShouldThrownUnauthorizedAccessException() - { - ConfigureUserServiceToReturnNullPrincipal(); - - await Assert.ThrowsAsync( - () => _sut.PostEmail(new EmailRequestModel()) - ); - } - - [Fact] - public async Task PostEmail_WhenEmailCannotBeChanged_ShouldThrowBadRequestException() - { - var user = GenerateExampleUser(); - ConfigureUserServiceToReturnValidPrincipalFor(user); - _userService.ChangeEmailAsync(user, default, default, default, default, default) - .Returns(Task.FromResult(IdentityResult.Failed())); - - await Assert.ThrowsAsync( - () => _sut.PostEmail(new EmailRequestModel()) - ); - } - - [Fact] - public async Task PostVerifyEmail_ShouldSendEmailVerification() - { - var user = GenerateExampleUser(); - ConfigureUserServiceToReturnValidPrincipalFor(user); - - await _sut.PostVerifyEmail(); - - await _userService.Received(1).SendEmailVerificationAsync(user); - } - - [Fact] - public async Task PostVerifyEmail_WhenNotAuthorized_ShouldThrownUnauthorizedAccessException() - { - ConfigureUserServiceToReturnNullPrincipal(); - - await Assert.ThrowsAsync( - () => _sut.PostVerifyEmail() - ); - } - - [Fact] - public async Task PostVerifyEmailToken_ShouldConfirmEmail() - { - var user = GenerateExampleUser(); - ConfigureUserServiceToReturnValidIdFor(user); - _userService.ConfirmEmailAsync(user, Arg.Any()) - .Returns(Task.FromResult(IdentityResult.Success)); - - await _sut.PostVerifyEmailToken(new VerifyEmailRequestModel { UserId = "12345678-1234-1234-1234-123456789012" }); - - await _userService.Received(1).ConfirmEmailAsync(user, Arg.Any()); - } - - [Fact] - public async Task PostVerifyEmailToken_WhenUserDoesNotExist_ShouldThrowUnauthorizedAccessException() - { - var user = GenerateExampleUser(); - ConfigureUserServiceToReturnNullUserId(); - - await Assert.ThrowsAsync( - () => _sut.PostVerifyEmailToken(new VerifyEmailRequestModel { UserId = "12345678-1234-1234-1234-123456789012" }) - ); - } - - [Fact] - public async Task PostVerifyEmailToken_WhenEmailConfirmationFails_ShouldThrowBadRequestException() - { - var user = GenerateExampleUser(); - ConfigureUserServiceToReturnValidIdFor(user); - _userService.ConfirmEmailAsync(user, Arg.Any()) - .Returns(Task.FromResult(IdentityResult.Failed())); - - await Assert.ThrowsAsync( - () => _sut.PostVerifyEmailToken(new VerifyEmailRequestModel { UserId = "12345678-1234-1234-1234-123456789012" }) - ); - } - - [Fact] - public async Task PostPassword_ShouldChangePassword() - { - var user = GenerateExampleUser(); - ConfigureUserServiceToReturnValidPrincipalFor(user); - _userService.ChangePasswordAsync(user, default, default, default, default) - .Returns(Task.FromResult(IdentityResult.Success)); - - await _sut.PostPassword(new PasswordRequestModel()); - - await _userService.Received(1).ChangePasswordAsync(user, default, default, default, default); - } - - [Fact] - public async Task PostPassword_WhenNotAuthorized_ShouldThrowUnauthorizedAccessException() - { - ConfigureUserServiceToReturnNullPrincipal(); - - await Assert.ThrowsAsync( - () => _sut.PostPassword(new PasswordRequestModel()) - ); - } - - [Fact] - public async Task PostPassword_WhenPasswordChangeFails_ShouldBadRequestException() - { - var user = GenerateExampleUser(); - ConfigureUserServiceToReturnValidPrincipalFor(user); - _userService.ChangePasswordAsync(user, default, default, default, default) - .Returns(Task.FromResult(IdentityResult.Failed())); - - await Assert.ThrowsAsync( - () => _sut.PostPassword(new PasswordRequestModel()) - ); - } - - [Fact] - public async Task GetApiKey_ShouldReturnApiKeyResponse() - { - var user = GenerateExampleUser(); - ConfigureUserServiceToReturnValidPrincipalFor(user); - ConfigureUserServiceToAcceptPasswordFor(user); - await _sut.ApiKey(new SecretVerificationRequestModel()); - } - - [Fact] - public async Task GetApiKey_WhenUserDoesNotExist_ShouldThrowUnauthorizedAccessException() - { - ConfigureUserServiceToReturnNullPrincipal(); - - await Assert.ThrowsAsync( - () => _sut.ApiKey(new SecretVerificationRequestModel()) - ); - } - - [Fact] - public async Task GetApiKey_WhenPasswordCheckFails_ShouldThrowBadRequestException() - { - var user = GenerateExampleUser(); - ConfigureUserServiceToReturnValidPrincipalFor(user); - ConfigureUserServiceToRejectPasswordFor(user); - await Assert.ThrowsAsync( - () => _sut.ApiKey(new SecretVerificationRequestModel()) - ); - } - - [Fact] - public async Task PostRotateApiKey_ShouldRotateApiKey() - { - var user = GenerateExampleUser(); - ConfigureUserServiceToReturnValidPrincipalFor(user); - ConfigureUserServiceToAcceptPasswordFor(user); - await _sut.RotateApiKey(new SecretVerificationRequestModel()); - } - - [Fact] - public async Task PostRotateApiKey_WhenUserDoesNotExist_ShouldThrowUnauthorizedAccessException() - { - ConfigureUserServiceToReturnNullPrincipal(); - - await Assert.ThrowsAsync( - () => _sut.ApiKey(new SecretVerificationRequestModel()) - ); - } - - [Fact] - public async Task PostRotateApiKey_WhenPasswordCheckFails_ShouldThrowBadRequestException() - { - var user = GenerateExampleUser(); - ConfigureUserServiceToReturnValidPrincipalFor(user); - ConfigureUserServiceToRejectPasswordFor(user); - await Assert.ThrowsAsync( - () => _sut.ApiKey(new SecretVerificationRequestModel()) - ); - } - - // Below are helper functions that currently belong to this - // test class, but ultimately may need to be split out into - // something greater in order to share common test steps with - // other test suites. They are included here for the time being - // until that day comes. - private User GenerateExampleUser() - { - return new User + [Fact] + public async Task PostPrelogin_WhenUserDoesNotExist_ShouldDefaultToSha256And100000Iterations() { - Email = "user@example.com" - }; - } + _userRepository.GetKdfInformationByEmailAsync(Arg.Any()).Returns(Task.FromResult((UserKdfInformation)null)); - private void ConfigureUserServiceToReturnNullPrincipal() - { - _userService.GetUserByPrincipalAsync(Arg.Any()) - .Returns(Task.FromResult((User)null)); - } + var response = await _sut.PostPrelogin(new PreloginRequestModel { Email = "user@example.com" }); - private void ConfigureUserServiceToReturnValidPrincipalFor(User user) - { - _userService.GetUserByPrincipalAsync(Arg.Any()) - .Returns(Task.FromResult(user)); - } + Assert.Equal(KdfType.PBKDF2_SHA256, response.Kdf); + Assert.Equal(100000, response.KdfIterations); + } - private void ConfigureUserServiceToRejectPasswordFor(User user) - { - _userService.CheckPasswordAsync(user, Arg.Any()) - .Returns(Task.FromResult(false)); - } + [Fact] + public async Task PostRegister_ShouldRegisterUser() + { + var passwordHash = "abcdef"; + var token = "123456"; + var userGuid = new Guid(); + _userService.RegisterUserAsync(Arg.Any(), passwordHash, token, userGuid) + .Returns(Task.FromResult(IdentityResult.Success)); + var request = new RegisterRequestModel + { + Name = "Example User", + Email = "user@example.com", + MasterPasswordHash = passwordHash, + MasterPasswordHint = "example", + Token = token, + OrganizationUserId = userGuid + }; - private void ConfigureUserServiceToAcceptPasswordFor(User user) - { - _userService.CheckPasswordAsync(user, Arg.Any()) - .Returns(Task.FromResult(true)); - _userService.VerifySecretAsync(user, Arg.Any()) - .Returns(Task.FromResult(true)); - } + await _sut.PostRegister(request); - private void ConfigureUserServiceToReturnValidIdFor(User user) - { - _userService.GetUserByIdAsync(Arg.Any()) - .Returns(Task.FromResult(user)); - } + await _userService.Received(1).RegisterUserAsync(Arg.Any(), passwordHash, token, userGuid); + } - private void ConfigureUserServiceToReturnNullUserId() - { - _userService.GetUserByIdAsync(Arg.Any()) - .Returns(Task.FromResult((User)null)); + [Fact] + public async Task PostRegister_WhenUserServiceFails_ShouldThrowBadRequestException() + { + var passwordHash = "abcdef"; + var token = "123456"; + var userGuid = new Guid(); + _userService.RegisterUserAsync(Arg.Any(), passwordHash, token, userGuid) + .Returns(Task.FromResult(IdentityResult.Failed())); + var request = new RegisterRequestModel + { + Name = "Example User", + Email = "user@example.com", + MasterPasswordHash = passwordHash, + MasterPasswordHint = "example", + Token = token, + OrganizationUserId = userGuid + }; + + await Assert.ThrowsAsync(() => _sut.PostRegister(request)); + } + + [Fact] + public async Task PostPasswordHint_ShouldNotifyUserService() + { + var email = "user@example.com"; + + await _sut.PostPasswordHint(new PasswordHintRequestModel { Email = email }); + + await _userService.Received(1).SendMasterPasswordHintAsync(email); + } + + [Fact] + public async Task PostEmailToken_ShouldInitiateEmailChange() + { + var user = GenerateExampleUser(); + ConfigureUserServiceToReturnValidPrincipalFor(user); + ConfigureUserServiceToAcceptPasswordFor(user); + var newEmail = "example@user.com"; + + await _sut.PostEmailToken(new EmailTokenRequestModel { NewEmail = newEmail }); + + await _userService.Received(1).InitiateEmailChangeAsync(user, newEmail); + } + + [Fact] + public async Task PostEmailToken_WhenNotAuthorized_ShouldThrowUnauthorizedAccessException() + { + ConfigureUserServiceToReturnNullPrincipal(); + + await Assert.ThrowsAsync( + () => _sut.PostEmailToken(new EmailTokenRequestModel()) + ); + } + + [Fact] + public async Task PostEmailToken_WhenInvalidPasssword_ShouldThrowBadRequestException() + { + var user = GenerateExampleUser(); + ConfigureUserServiceToReturnValidPrincipalFor(user); + ConfigureUserServiceToRejectPasswordFor(user); + + await Assert.ThrowsAsync( + () => _sut.PostEmailToken(new EmailTokenRequestModel()) + ); + } + + [Fact] + public async Task PostEmail_ShouldChangeUserEmail() + { + var user = GenerateExampleUser(); + ConfigureUserServiceToReturnValidPrincipalFor(user); + _userService.ChangeEmailAsync(user, default, default, default, default, default) + .Returns(Task.FromResult(IdentityResult.Success)); + + await _sut.PostEmail(new EmailRequestModel()); + + await _userService.Received(1).ChangeEmailAsync(user, default, default, default, default, default); + } + + [Fact] + public async Task PostEmail_WhenNotAuthorized_ShouldThrownUnauthorizedAccessException() + { + ConfigureUserServiceToReturnNullPrincipal(); + + await Assert.ThrowsAsync( + () => _sut.PostEmail(new EmailRequestModel()) + ); + } + + [Fact] + public async Task PostEmail_WhenEmailCannotBeChanged_ShouldThrowBadRequestException() + { + var user = GenerateExampleUser(); + ConfigureUserServiceToReturnValidPrincipalFor(user); + _userService.ChangeEmailAsync(user, default, default, default, default, default) + .Returns(Task.FromResult(IdentityResult.Failed())); + + await Assert.ThrowsAsync( + () => _sut.PostEmail(new EmailRequestModel()) + ); + } + + [Fact] + public async Task PostVerifyEmail_ShouldSendEmailVerification() + { + var user = GenerateExampleUser(); + ConfigureUserServiceToReturnValidPrincipalFor(user); + + await _sut.PostVerifyEmail(); + + await _userService.Received(1).SendEmailVerificationAsync(user); + } + + [Fact] + public async Task PostVerifyEmail_WhenNotAuthorized_ShouldThrownUnauthorizedAccessException() + { + ConfigureUserServiceToReturnNullPrincipal(); + + await Assert.ThrowsAsync( + () => _sut.PostVerifyEmail() + ); + } + + [Fact] + public async Task PostVerifyEmailToken_ShouldConfirmEmail() + { + var user = GenerateExampleUser(); + ConfigureUserServiceToReturnValidIdFor(user); + _userService.ConfirmEmailAsync(user, Arg.Any()) + .Returns(Task.FromResult(IdentityResult.Success)); + + await _sut.PostVerifyEmailToken(new VerifyEmailRequestModel { UserId = "12345678-1234-1234-1234-123456789012" }); + + await _userService.Received(1).ConfirmEmailAsync(user, Arg.Any()); + } + + [Fact] + public async Task PostVerifyEmailToken_WhenUserDoesNotExist_ShouldThrowUnauthorizedAccessException() + { + var user = GenerateExampleUser(); + ConfigureUserServiceToReturnNullUserId(); + + await Assert.ThrowsAsync( + () => _sut.PostVerifyEmailToken(new VerifyEmailRequestModel { UserId = "12345678-1234-1234-1234-123456789012" }) + ); + } + + [Fact] + public async Task PostVerifyEmailToken_WhenEmailConfirmationFails_ShouldThrowBadRequestException() + { + var user = GenerateExampleUser(); + ConfigureUserServiceToReturnValidIdFor(user); + _userService.ConfirmEmailAsync(user, Arg.Any()) + .Returns(Task.FromResult(IdentityResult.Failed())); + + await Assert.ThrowsAsync( + () => _sut.PostVerifyEmailToken(new VerifyEmailRequestModel { UserId = "12345678-1234-1234-1234-123456789012" }) + ); + } + + [Fact] + public async Task PostPassword_ShouldChangePassword() + { + var user = GenerateExampleUser(); + ConfigureUserServiceToReturnValidPrincipalFor(user); + _userService.ChangePasswordAsync(user, default, default, default, default) + .Returns(Task.FromResult(IdentityResult.Success)); + + await _sut.PostPassword(new PasswordRequestModel()); + + await _userService.Received(1).ChangePasswordAsync(user, default, default, default, default); + } + + [Fact] + public async Task PostPassword_WhenNotAuthorized_ShouldThrowUnauthorizedAccessException() + { + ConfigureUserServiceToReturnNullPrincipal(); + + await Assert.ThrowsAsync( + () => _sut.PostPassword(new PasswordRequestModel()) + ); + } + + [Fact] + public async Task PostPassword_WhenPasswordChangeFails_ShouldBadRequestException() + { + var user = GenerateExampleUser(); + ConfigureUserServiceToReturnValidPrincipalFor(user); + _userService.ChangePasswordAsync(user, default, default, default, default) + .Returns(Task.FromResult(IdentityResult.Failed())); + + await Assert.ThrowsAsync( + () => _sut.PostPassword(new PasswordRequestModel()) + ); + } + + [Fact] + public async Task GetApiKey_ShouldReturnApiKeyResponse() + { + var user = GenerateExampleUser(); + ConfigureUserServiceToReturnValidPrincipalFor(user); + ConfigureUserServiceToAcceptPasswordFor(user); + await _sut.ApiKey(new SecretVerificationRequestModel()); + } + + [Fact] + public async Task GetApiKey_WhenUserDoesNotExist_ShouldThrowUnauthorizedAccessException() + { + ConfigureUserServiceToReturnNullPrincipal(); + + await Assert.ThrowsAsync( + () => _sut.ApiKey(new SecretVerificationRequestModel()) + ); + } + + [Fact] + public async Task GetApiKey_WhenPasswordCheckFails_ShouldThrowBadRequestException() + { + var user = GenerateExampleUser(); + ConfigureUserServiceToReturnValidPrincipalFor(user); + ConfigureUserServiceToRejectPasswordFor(user); + await Assert.ThrowsAsync( + () => _sut.ApiKey(new SecretVerificationRequestModel()) + ); + } + + [Fact] + public async Task PostRotateApiKey_ShouldRotateApiKey() + { + var user = GenerateExampleUser(); + ConfigureUserServiceToReturnValidPrincipalFor(user); + ConfigureUserServiceToAcceptPasswordFor(user); + await _sut.RotateApiKey(new SecretVerificationRequestModel()); + } + + [Fact] + public async Task PostRotateApiKey_WhenUserDoesNotExist_ShouldThrowUnauthorizedAccessException() + { + ConfigureUserServiceToReturnNullPrincipal(); + + await Assert.ThrowsAsync( + () => _sut.ApiKey(new SecretVerificationRequestModel()) + ); + } + + [Fact] + public async Task PostRotateApiKey_WhenPasswordCheckFails_ShouldThrowBadRequestException() + { + var user = GenerateExampleUser(); + ConfigureUserServiceToReturnValidPrincipalFor(user); + ConfigureUserServiceToRejectPasswordFor(user); + await Assert.ThrowsAsync( + () => _sut.ApiKey(new SecretVerificationRequestModel()) + ); + } + + // Below are helper functions that currently belong to this + // test class, but ultimately may need to be split out into + // something greater in order to share common test steps with + // other test suites. They are included here for the time being + // until that day comes. + private User GenerateExampleUser() + { + return new User + { + Email = "user@example.com" + }; + } + + private void ConfigureUserServiceToReturnNullPrincipal() + { + _userService.GetUserByPrincipalAsync(Arg.Any()) + .Returns(Task.FromResult((User)null)); + } + + private void ConfigureUserServiceToReturnValidPrincipalFor(User user) + { + _userService.GetUserByPrincipalAsync(Arg.Any()) + .Returns(Task.FromResult(user)); + } + + private void ConfigureUserServiceToRejectPasswordFor(User user) + { + _userService.CheckPasswordAsync(user, Arg.Any()) + .Returns(Task.FromResult(false)); + } + + private void ConfigureUserServiceToAcceptPasswordFor(User user) + { + _userService.CheckPasswordAsync(user, Arg.Any()) + .Returns(Task.FromResult(true)); + _userService.VerifySecretAsync(user, Arg.Any()) + .Returns(Task.FromResult(true)); + } + + private void ConfigureUserServiceToReturnValidIdFor(User user) + { + _userService.GetUserByIdAsync(Arg.Any()) + .Returns(Task.FromResult(user)); + } + + private void ConfigureUserServiceToReturnNullUserId() + { + _userService.GetUserByIdAsync(Arg.Any()) + .Returns(Task.FromResult((User)null)); + } } } diff --git a/test/Api.Test/Controllers/CollectionsControllerTests.cs b/test/Api.Test/Controllers/CollectionsControllerTests.cs index b5d304f7a0..ba9620a007 100644 --- a/test/Api.Test/Controllers/CollectionsControllerTests.cs +++ b/test/Api.Test/Controllers/CollectionsControllerTests.cs @@ -11,78 +11,79 @@ using Bit.Test.Common.AutoFixture.Attributes; using NSubstitute; using Xunit; -namespace Bit.Api.Test.Controllers; - -[ControllerCustomize(typeof(CollectionsController))] -[SutProviderCustomize] -public class CollectionsControllerTests +namespace Bit.Api.Test.Controllers { - [Theory, BitAutoData] - public async Task Post_Success(Guid orgId, SutProvider sutProvider) + [ControllerCustomize(typeof(CollectionsController))] + [SutProviderCustomize] + public class CollectionsControllerTests { - sutProvider.GetDependency() - .CreateNewCollections(orgId) - .Returns(true); - - sutProvider.GetDependency() - .EditAnyCollection(orgId) - .Returns(false); - - var collectionRequest = new CollectionRequestModel + [Theory, BitAutoData] + public async Task Post_Success(Guid orgId, SutProvider sutProvider) { - Name = "encrypted_string", - ExternalId = "my_external_id" - }; + sutProvider.GetDependency() + .CreateNewCollections(orgId) + .Returns(true); - _ = await sutProvider.Sut.Post(orgId, collectionRequest); + sutProvider.GetDependency() + .EditAnyCollection(orgId) + .Returns(false); - await sutProvider.GetDependency() - .Received(1) - .SaveAsync(Arg.Any(), Arg.Any>(), null); - } - - [Theory, BitAutoData] - public async Task Put_Success(Guid orgId, Guid collectionId, Guid userId, CollectionRequestModel collectionRequest, - SutProvider sutProvider) - { - sutProvider.GetDependency() - .ViewAssignedCollections(orgId) - .Returns(true); - - sutProvider.GetDependency() - .EditAssignedCollections(orgId) - .Returns(true); - - sutProvider.GetDependency() - .UserId - .Returns(userId); - - sutProvider.GetDependency() - .GetByIdAsync(collectionId, userId) - .Returns(new CollectionDetails + var collectionRequest = new CollectionRequestModel { - OrganizationId = orgId, - }); + Name = "encrypted_string", + ExternalId = "my_external_id" + }; - _ = await sutProvider.Sut.Put(orgId, collectionId, collectionRequest); - } + _ = await sutProvider.Sut.Post(orgId, collectionRequest); - [Theory, BitAutoData] - public async Task Put_CanNotEditAssignedCollection_ThrowsNotFound(Guid orgId, Guid collectionId, Guid userId, CollectionRequestModel collectionRequest, - SutProvider sutProvider) - { - sutProvider.GetDependency() - .EditAssignedCollections(orgId) - .Returns(true); + await sutProvider.GetDependency() + .Received(1) + .SaveAsync(Arg.Any(), Arg.Any>(), null); + } - sutProvider.GetDependency() - .UserId - .Returns(userId); + [Theory, BitAutoData] + public async Task Put_Success(Guid orgId, Guid collectionId, Guid userId, CollectionRequestModel collectionRequest, + SutProvider sutProvider) + { + sutProvider.GetDependency() + .ViewAssignedCollections(orgId) + .Returns(true); - sutProvider.GetDependency() - .GetByIdAsync(collectionId, userId) - .Returns(Task.FromResult(null)); + sutProvider.GetDependency() + .EditAssignedCollections(orgId) + .Returns(true); - _ = await Assert.ThrowsAsync(async () => await sutProvider.Sut.Put(orgId, collectionId, collectionRequest)); + sutProvider.GetDependency() + .UserId + .Returns(userId); + + sutProvider.GetDependency() + .GetByIdAsync(collectionId, userId) + .Returns(new CollectionDetails + { + OrganizationId = orgId, + }); + + _ = await sutProvider.Sut.Put(orgId, collectionId, collectionRequest); + } + + [Theory, BitAutoData] + public async Task Put_CanNotEditAssignedCollection_ThrowsNotFound(Guid orgId, Guid collectionId, Guid userId, CollectionRequestModel collectionRequest, + SutProvider sutProvider) + { + sutProvider.GetDependency() + .EditAssignedCollections(orgId) + .Returns(true); + + sutProvider.GetDependency() + .UserId + .Returns(userId); + + sutProvider.GetDependency() + .GetByIdAsync(collectionId, userId) + .Returns(Task.FromResult(null)); + + _ = await Assert.ThrowsAsync(async () => await sutProvider.Sut.Put(orgId, collectionId, collectionRequest)); + } } } diff --git a/test/Api.Test/Controllers/OrganizationConnectionsControllerTests.cs b/test/Api.Test/Controllers/OrganizationConnectionsControllerTests.cs index 80bfcfe006..88834621ab 100644 --- a/test/Api.Test/Controllers/OrganizationConnectionsControllerTests.cs +++ b/test/Api.Test/Controllers/OrganizationConnectionsControllerTests.cs @@ -18,301 +18,302 @@ using Bit.Test.Common.Helpers; using NSubstitute; using Xunit; -namespace Bit.Api.Test.Controllers; - -[ControllerCustomize(typeof(OrganizationConnectionsController))] -[SutProviderCustomize] -[JsonDocumentCustomize] -public class OrganizationConnectionsControllerTests +namespace Bit.Api.Test.Controllers { - public static IEnumerable ConnectionTypes => - Enum.GetValues().Select(p => new object[] { p }); - - - [Theory] - [BitAutoData(true, true)] - [BitAutoData(false, true)] - [BitAutoData(true, false)] - [BitAutoData(false, false)] - public void ConnectionEnabled_RequiresBothSelfHostAndCommunications(bool selfHosted, bool enableCloudCommunication, SutProvider sutProvider) + [ControllerCustomize(typeof(OrganizationConnectionsController))] + [SutProviderCustomize] + [JsonDocumentCustomize] + public class OrganizationConnectionsControllerTests { - var globalSettingsMock = sutProvider.GetDependency(); - globalSettingsMock.SelfHosted.Returns(selfHosted); - globalSettingsMock.EnableCloudCommunication.Returns(enableCloudCommunication); + public static IEnumerable ConnectionTypes => + Enum.GetValues().Select(p => new object[] { p }); - Action assert = selfHosted && enableCloudCommunication ? Assert.True : Assert.False; - var result = sutProvider.Sut.ConnectionsEnabled(); - - assert(result); - } - - [Theory] - [BitAutoData] - public async Task CreateConnection_CloudBillingSync_RequiresOwnerPermissions(SutProvider sutProvider) - { - var model = new OrganizationConnectionRequestModel + [Theory] + [BitAutoData(true, true)] + [BitAutoData(false, true)] + [BitAutoData(true, false)] + [BitAutoData(false, false)] + public void ConnectionEnabled_RequiresBothSelfHostAndCommunications(bool selfHosted, bool enableCloudCommunication, SutProvider sutProvider) { - Type = OrganizationConnectionType.CloudBillingSync, - }; - var exception = await Assert.ThrowsAsync(() => sutProvider.Sut.CreateConnection(model)); + var globalSettingsMock = sutProvider.GetDependency(); + globalSettingsMock.SelfHosted.Returns(selfHosted); + globalSettingsMock.EnableCloudCommunication.Returns(enableCloudCommunication); - Assert.Contains($"You do not have permission to create a connection of type", exception.Message); - } + Action assert = selfHosted && enableCloudCommunication ? Assert.True : Assert.False; - [Theory] - [BitMemberAutoData(nameof(ConnectionTypes))] - public async Task CreateConnection_OnlyOneConnectionOfEachType(OrganizationConnectionType type, - OrganizationConnectionRequestModel model, BillingSyncConfig config, Guid existingEntityId, - SutProvider sutProvider) - { - model.Type = type; - model.Config = JsonDocumentFromObject(config); - var typedModel = new OrganizationConnectionRequestModel(model); - var existing = typedModel.ToData(existingEntityId).ToEntity(); + var result = sutProvider.Sut.ConnectionsEnabled(); - sutProvider.GetDependency().OrganizationOwner(model.OrganizationId).Returns(true); + assert(result); + } - sutProvider.GetDependency().GetByOrganizationIdTypeAsync(model.OrganizationId, type).Returns(new[] { existing }); - - var exception = await Assert.ThrowsAsync(() => sutProvider.Sut.CreateConnection(model)); - - Assert.Contains($"The requested organization already has a connection of type {model.Type}. Only one of each connection type may exist per organization.", exception.Message); - } - - [Theory] - [BitAutoData] - public async Task CreateConnection_BillingSyncType_InvalidLicense_Throws(OrganizationConnectionRequestModel model, - BillingSyncConfig config, Guid cloudOrgId, OrganizationLicense organizationLicense, - SutProvider sutProvider) - { - model.Type = OrganizationConnectionType.CloudBillingSync; - organizationLicense.Id = cloudOrgId; - - model.Config = JsonDocumentFromObject(config); - var typedModel = new OrganizationConnectionRequestModel(model); - typedModel.ParsedConfig.CloudOrganizationId = cloudOrgId; - - sutProvider.GetDependency() - .OrganizationOwner(model.OrganizationId) - .Returns(true); - - sutProvider.GetDependency() - .ReadOrganizationLicenseAsync(model.OrganizationId) - .Returns(organizationLicense); - - sutProvider.GetDependency() - .VerifyLicense(organizationLicense) - .Returns(false); - - await Assert.ThrowsAsync(async () => await sutProvider.Sut.CreateConnection(model)); - } - - [Theory] - [BitAutoData] - public async Task CreateConnection_Success(OrganizationConnectionRequestModel model, BillingSyncConfig config, - Guid cloudOrgId, OrganizationLicense organizationLicense, SutProvider sutProvider) - { - organizationLicense.Id = cloudOrgId; - - model.Config = JsonDocumentFromObject(config); - var typedModel = new OrganizationConnectionRequestModel(model); - typedModel.ParsedConfig.CloudOrganizationId = cloudOrgId; - - sutProvider.GetDependency().SelfHosted.Returns(true); - sutProvider.GetDependency().CreateAsync(default) - .ReturnsForAnyArgs(typedModel.ToData(Guid.NewGuid()).ToEntity()); - sutProvider.GetDependency().OrganizationOwner(model.OrganizationId).Returns(true); - sutProvider.GetDependency() - .ReadOrganizationLicenseAsync(Arg.Any()) - .Returns(organizationLicense); - - sutProvider.GetDependency() - .VerifyLicense(organizationLicense) - .Returns(true); - - await sutProvider.Sut.CreateConnection(model); - - await sutProvider.GetDependency().Received(1) - .CreateAsync(Arg.Is(AssertHelper.AssertPropertyEqual(typedModel.ToData()))); - } - - [Theory] - [BitAutoData] - public async Task UpdateConnection_RequiresOwnerPermissions(SutProvider sutProvider) - { - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .Returns(new OrganizationConnection()); - - var exception = await Assert.ThrowsAsync(() => sutProvider.Sut.UpdateConnection(default, null)); - - Assert.Contains("You do not have permission to update this connection.", exception.Message); - } - - [Theory] - [BitAutoData(OrganizationConnectionType.CloudBillingSync)] - public async Task UpdateConnection_BillingSync_OnlyOneConnectionOfEachType(OrganizationConnectionType type, - OrganizationConnection existing1, OrganizationConnection existing2, BillingSyncConfig config, - SutProvider sutProvider) - { - existing1.Type = existing2.Type = type; - existing1.Config = JsonSerializer.Serialize(config); - var typedModel = RequestModelFromEntity(existing1); - - sutProvider.GetDependency().OrganizationOwner(typedModel.OrganizationId).Returns(true); - - var orgConnectionRepository = sutProvider.GetDependency(); - orgConnectionRepository.GetByIdAsync(existing1.Id).Returns(existing1); - orgConnectionRepository.GetByIdAsync(existing2.Id).Returns(existing2); - orgConnectionRepository.GetByOrganizationIdTypeAsync(typedModel.OrganizationId, type).Returns(new[] { existing1, existing2 }); - - var exception = await Assert.ThrowsAsync(() => sutProvider.Sut.UpdateConnection(existing1.Id, typedModel)); - - Assert.Contains($"The requested organization already has a connection of type {typedModel.Type}. Only one of each connection type may exist per organization.", exception.Message); - } - - [Theory] - [BitAutoData(OrganizationConnectionType.Scim)] - public async Task UpdateConnection_Scim_OnlyOneConnectionOfEachType(OrganizationConnectionType type, - OrganizationConnection existing1, OrganizationConnection existing2, ScimConfig config, - SutProvider sutProvider) - { - existing1.Type = existing2.Type = type; - existing1.Config = JsonSerializer.Serialize(config); - var typedModel = RequestModelFromEntity(existing1); - - sutProvider.GetDependency().OrganizationOwner(typedModel.OrganizationId).Returns(true); - - sutProvider.GetDependency() - .GetByIdAsync(existing1.Id) - .Returns(existing1); - - sutProvider.GetDependency().ManageScim(typedModel.OrganizationId).Returns(true); - - sutProvider.GetDependency() - .GetByOrganizationIdTypeAsync(typedModel.OrganizationId, type) - .Returns(new[] { existing1, existing2 }); - - var exception = await Assert.ThrowsAsync(() => sutProvider.Sut.UpdateConnection(existing1.Id, typedModel)); - - Assert.Contains($"The requested organization already has a connection of type {typedModel.Type}. Only one of each connection type may exist per organization.", exception.Message); - } - - [Theory] - [BitAutoData] - public async Task UpdateConnection_Success(OrganizationConnection existing, BillingSyncConfig config, - OrganizationConnection updated, - SutProvider sutProvider) - { - existing.SetConfig(new BillingSyncConfig + [Theory] + [BitAutoData] + public async Task CreateConnection_CloudBillingSync_RequiresOwnerPermissions(SutProvider sutProvider) { - CloudOrganizationId = config.CloudOrganizationId, - }); - updated.Config = JsonSerializer.Serialize(config); - updated.Id = existing.Id; - updated.Type = OrganizationConnectionType.CloudBillingSync; - var model = RequestModelFromEntity(updated); + var model = new OrganizationConnectionRequestModel + { + Type = OrganizationConnectionType.CloudBillingSync, + }; + var exception = await Assert.ThrowsAsync(() => sutProvider.Sut.CreateConnection(model)); - sutProvider.GetDependency().OrganizationOwner(model.OrganizationId).Returns(true); - sutProvider.GetDependency() - .GetByOrganizationIdTypeAsync(model.OrganizationId, model.Type) - .Returns(new[] { existing }); - sutProvider.GetDependency() - .UpdateAsync(default) - .ReturnsForAnyArgs(updated); - sutProvider.GetDependency() - .GetByIdAsync(existing.Id) - .Returns(existing); + Assert.Contains($"You do not have permission to create a connection of type", exception.Message); + } - var expected = new OrganizationConnectionResponseModel(updated, typeof(BillingSyncConfig)); - var result = await sutProvider.Sut.UpdateConnection(existing.Id, model); - - AssertHelper.AssertPropertyEqual(expected, result); - await sutProvider.GetDependency().Received(1) - .UpdateAsync(Arg.Is(AssertHelper.AssertPropertyEqual(model.ToData(updated.Id)))); - } - - [Theory] - [BitAutoData] - public async Task UpdateConnection_DoesNotExist_ThrowsNotFound(SutProvider sutProvider) - { - await Assert.ThrowsAsync(() => sutProvider.Sut.UpdateConnection(Guid.NewGuid(), null)); - } - - [Theory] - [BitAutoData] - public async Task GetConnection_RequiresOwnerPermissions(Guid connectionId, SutProvider sutProvider) - { - var exception = await Assert.ThrowsAsync(() => - sutProvider.Sut.GetConnection(connectionId, OrganizationConnectionType.CloudBillingSync)); - - Assert.Contains("You do not have permission to retrieve a connection of type", exception.Message); - } - - [Theory] - [BitAutoData] - public async Task GetConnection_Success(OrganizationConnection connection, BillingSyncConfig config, - SutProvider sutProvider) - { - connection.Config = JsonSerializer.Serialize(config); - - sutProvider.GetDependency().SelfHosted.Returns(true); - sutProvider.GetDependency() - .GetByOrganizationIdTypeAsync(connection.OrganizationId, connection.Type) - .Returns(new[] { connection }); - sutProvider.GetDependency().OrganizationOwner(connection.OrganizationId).Returns(true); - - var expected = new OrganizationConnectionResponseModel(connection, typeof(BillingSyncConfig)); - var actual = await sutProvider.Sut.GetConnection(connection.OrganizationId, connection.Type); - - AssertHelper.AssertPropertyEqual(expected, actual); - } - - [Theory] - [BitAutoData] - public async Task DeleteConnection_NotFound(Guid connectionId, - SutProvider sutProvider) - { - await Assert.ThrowsAsync(() => sutProvider.Sut.DeleteConnection(connectionId)); - } - - [Theory] - [BitAutoData] - public async Task DeleteConnection_RequiresOwnerPermissions(OrganizationConnection connection, - SutProvider sutProvider) - { - sutProvider.GetDependency().GetByIdAsync(connection.Id).Returns(connection); - - var exception = await Assert.ThrowsAsync(() => sutProvider.Sut.DeleteConnection(connection.Id)); - - Assert.Contains("You do not have permission to remove this connection of type", exception.Message); - } - - [Theory] - [BitAutoData] - public async Task DeleteConnection_Success(OrganizationConnection connection, - SutProvider sutProvider) - { - sutProvider.GetDependency().GetByIdAsync(connection.Id).Returns(connection); - sutProvider.GetDependency().OrganizationOwner(connection.OrganizationId).Returns(true); - - await sutProvider.Sut.DeleteConnection(connection.Id); - - await sutProvider.GetDependency().DeleteAsync(connection); - } - - private static OrganizationConnectionRequestModel RequestModelFromEntity(OrganizationConnection entity) - where T : new() - { - return new(new OrganizationConnectionRequestModel() + [Theory] + [BitMemberAutoData(nameof(ConnectionTypes))] + public async Task CreateConnection_OnlyOneConnectionOfEachType(OrganizationConnectionType type, + OrganizationConnectionRequestModel model, BillingSyncConfig config, Guid existingEntityId, + SutProvider sutProvider) { - Type = entity.Type, - OrganizationId = entity.OrganizationId, - Enabled = entity.Enabled, - Config = JsonDocument.Parse(entity.Config), - }); - } + model.Type = type; + model.Config = JsonDocumentFromObject(config); + var typedModel = new OrganizationConnectionRequestModel(model); + var existing = typedModel.ToData(existingEntityId).ToEntity(); - private static JsonDocument JsonDocumentFromObject(T obj) => JsonDocument.Parse(JsonSerializer.Serialize(obj)); + sutProvider.GetDependency().OrganizationOwner(model.OrganizationId).Returns(true); + + sutProvider.GetDependency().GetByOrganizationIdTypeAsync(model.OrganizationId, type).Returns(new[] { existing }); + + var exception = await Assert.ThrowsAsync(() => sutProvider.Sut.CreateConnection(model)); + + Assert.Contains($"The requested organization already has a connection of type {model.Type}. Only one of each connection type may exist per organization.", exception.Message); + } + + [Theory] + [BitAutoData] + public async Task CreateConnection_BillingSyncType_InvalidLicense_Throws(OrganizationConnectionRequestModel model, + BillingSyncConfig config, Guid cloudOrgId, OrganizationLicense organizationLicense, + SutProvider sutProvider) + { + model.Type = OrganizationConnectionType.CloudBillingSync; + organizationLicense.Id = cloudOrgId; + + model.Config = JsonDocumentFromObject(config); + var typedModel = new OrganizationConnectionRequestModel(model); + typedModel.ParsedConfig.CloudOrganizationId = cloudOrgId; + + sutProvider.GetDependency() + .OrganizationOwner(model.OrganizationId) + .Returns(true); + + sutProvider.GetDependency() + .ReadOrganizationLicenseAsync(model.OrganizationId) + .Returns(organizationLicense); + + sutProvider.GetDependency() + .VerifyLicense(organizationLicense) + .Returns(false); + + await Assert.ThrowsAsync(async () => await sutProvider.Sut.CreateConnection(model)); + } + + [Theory] + [BitAutoData] + public async Task CreateConnection_Success(OrganizationConnectionRequestModel model, BillingSyncConfig config, + Guid cloudOrgId, OrganizationLicense organizationLicense, SutProvider sutProvider) + { + organizationLicense.Id = cloudOrgId; + + model.Config = JsonDocumentFromObject(config); + var typedModel = new OrganizationConnectionRequestModel(model); + typedModel.ParsedConfig.CloudOrganizationId = cloudOrgId; + + sutProvider.GetDependency().SelfHosted.Returns(true); + sutProvider.GetDependency().CreateAsync(default) + .ReturnsForAnyArgs(typedModel.ToData(Guid.NewGuid()).ToEntity()); + sutProvider.GetDependency().OrganizationOwner(model.OrganizationId).Returns(true); + sutProvider.GetDependency() + .ReadOrganizationLicenseAsync(Arg.Any()) + .Returns(organizationLicense); + + sutProvider.GetDependency() + .VerifyLicense(organizationLicense) + .Returns(true); + + await sutProvider.Sut.CreateConnection(model); + + await sutProvider.GetDependency().Received(1) + .CreateAsync(Arg.Is(AssertHelper.AssertPropertyEqual(typedModel.ToData()))); + } + + [Theory] + [BitAutoData] + public async Task UpdateConnection_RequiresOwnerPermissions(SutProvider sutProvider) + { + sutProvider.GetDependency() + .GetByIdAsync(Arg.Any()) + .Returns(new OrganizationConnection()); + + var exception = await Assert.ThrowsAsync(() => sutProvider.Sut.UpdateConnection(default, null)); + + Assert.Contains("You do not have permission to update this connection.", exception.Message); + } + + [Theory] + [BitAutoData(OrganizationConnectionType.CloudBillingSync)] + public async Task UpdateConnection_BillingSync_OnlyOneConnectionOfEachType(OrganizationConnectionType type, + OrganizationConnection existing1, OrganizationConnection existing2, BillingSyncConfig config, + SutProvider sutProvider) + { + existing1.Type = existing2.Type = type; + existing1.Config = JsonSerializer.Serialize(config); + var typedModel = RequestModelFromEntity(existing1); + + sutProvider.GetDependency().OrganizationOwner(typedModel.OrganizationId).Returns(true); + + var orgConnectionRepository = sutProvider.GetDependency(); + orgConnectionRepository.GetByIdAsync(existing1.Id).Returns(existing1); + orgConnectionRepository.GetByIdAsync(existing2.Id).Returns(existing2); + orgConnectionRepository.GetByOrganizationIdTypeAsync(typedModel.OrganizationId, type).Returns(new[] { existing1, existing2 }); + + var exception = await Assert.ThrowsAsync(() => sutProvider.Sut.UpdateConnection(existing1.Id, typedModel)); + + Assert.Contains($"The requested organization already has a connection of type {typedModel.Type}. Only one of each connection type may exist per organization.", exception.Message); + } + + [Theory] + [BitAutoData(OrganizationConnectionType.Scim)] + public async Task UpdateConnection_Scim_OnlyOneConnectionOfEachType(OrganizationConnectionType type, + OrganizationConnection existing1, OrganizationConnection existing2, ScimConfig config, + SutProvider sutProvider) + { + existing1.Type = existing2.Type = type; + existing1.Config = JsonSerializer.Serialize(config); + var typedModel = RequestModelFromEntity(existing1); + + sutProvider.GetDependency().OrganizationOwner(typedModel.OrganizationId).Returns(true); + + sutProvider.GetDependency() + .GetByIdAsync(existing1.Id) + .Returns(existing1); + + sutProvider.GetDependency().ManageScim(typedModel.OrganizationId).Returns(true); + + sutProvider.GetDependency() + .GetByOrganizationIdTypeAsync(typedModel.OrganizationId, type) + .Returns(new[] { existing1, existing2 }); + + var exception = await Assert.ThrowsAsync(() => sutProvider.Sut.UpdateConnection(existing1.Id, typedModel)); + + Assert.Contains($"The requested organization already has a connection of type {typedModel.Type}. Only one of each connection type may exist per organization.", exception.Message); + } + + [Theory] + [BitAutoData] + public async Task UpdateConnection_Success(OrganizationConnection existing, BillingSyncConfig config, + OrganizationConnection updated, + SutProvider sutProvider) + { + existing.SetConfig(new BillingSyncConfig + { + CloudOrganizationId = config.CloudOrganizationId, + }); + updated.Config = JsonSerializer.Serialize(config); + updated.Id = existing.Id; + updated.Type = OrganizationConnectionType.CloudBillingSync; + var model = RequestModelFromEntity(updated); + + sutProvider.GetDependency().OrganizationOwner(model.OrganizationId).Returns(true); + sutProvider.GetDependency() + .GetByOrganizationIdTypeAsync(model.OrganizationId, model.Type) + .Returns(new[] { existing }); + sutProvider.GetDependency() + .UpdateAsync(default) + .ReturnsForAnyArgs(updated); + sutProvider.GetDependency() + .GetByIdAsync(existing.Id) + .Returns(existing); + + var expected = new OrganizationConnectionResponseModel(updated, typeof(BillingSyncConfig)); + var result = await sutProvider.Sut.UpdateConnection(existing.Id, model); + + AssertHelper.AssertPropertyEqual(expected, result); + await sutProvider.GetDependency().Received(1) + .UpdateAsync(Arg.Is(AssertHelper.AssertPropertyEqual(model.ToData(updated.Id)))); + } + + [Theory] + [BitAutoData] + public async Task UpdateConnection_DoesNotExist_ThrowsNotFound(SutProvider sutProvider) + { + await Assert.ThrowsAsync(() => sutProvider.Sut.UpdateConnection(Guid.NewGuid(), null)); + } + + [Theory] + [BitAutoData] + public async Task GetConnection_RequiresOwnerPermissions(Guid connectionId, SutProvider sutProvider) + { + var exception = await Assert.ThrowsAsync(() => + sutProvider.Sut.GetConnection(connectionId, OrganizationConnectionType.CloudBillingSync)); + + Assert.Contains("You do not have permission to retrieve a connection of type", exception.Message); + } + + [Theory] + [BitAutoData] + public async Task GetConnection_Success(OrganizationConnection connection, BillingSyncConfig config, + SutProvider sutProvider) + { + connection.Config = JsonSerializer.Serialize(config); + + sutProvider.GetDependency().SelfHosted.Returns(true); + sutProvider.GetDependency() + .GetByOrganizationIdTypeAsync(connection.OrganizationId, connection.Type) + .Returns(new[] { connection }); + sutProvider.GetDependency().OrganizationOwner(connection.OrganizationId).Returns(true); + + var expected = new OrganizationConnectionResponseModel(connection, typeof(BillingSyncConfig)); + var actual = await sutProvider.Sut.GetConnection(connection.OrganizationId, connection.Type); + + AssertHelper.AssertPropertyEqual(expected, actual); + } + + [Theory] + [BitAutoData] + public async Task DeleteConnection_NotFound(Guid connectionId, + SutProvider sutProvider) + { + await Assert.ThrowsAsync(() => sutProvider.Sut.DeleteConnection(connectionId)); + } + + [Theory] + [BitAutoData] + public async Task DeleteConnection_RequiresOwnerPermissions(OrganizationConnection connection, + SutProvider sutProvider) + { + sutProvider.GetDependency().GetByIdAsync(connection.Id).Returns(connection); + + var exception = await Assert.ThrowsAsync(() => sutProvider.Sut.DeleteConnection(connection.Id)); + + Assert.Contains("You do not have permission to remove this connection of type", exception.Message); + } + + [Theory] + [BitAutoData] + public async Task DeleteConnection_Success(OrganizationConnection connection, + SutProvider sutProvider) + { + sutProvider.GetDependency().GetByIdAsync(connection.Id).Returns(connection); + sutProvider.GetDependency().OrganizationOwner(connection.OrganizationId).Returns(true); + + await sutProvider.Sut.DeleteConnection(connection.Id); + + await sutProvider.GetDependency().DeleteAsync(connection); + } + + private static OrganizationConnectionRequestModel RequestModelFromEntity(OrganizationConnection entity) + where T : new() + { + return new(new OrganizationConnectionRequestModel() + { + Type = entity.Type, + OrganizationId = entity.OrganizationId, + Enabled = entity.Enabled, + Config = JsonDocument.Parse(entity.Config), + }); + } + + private static JsonDocument JsonDocumentFromObject(T obj) => JsonDocument.Parse(JsonSerializer.Serialize(obj)); + } } diff --git a/test/Api.Test/Controllers/OrganizationSponsorshipsControllerTests.cs b/test/Api.Test/Controllers/OrganizationSponsorshipsControllerTests.cs index e58add5ef0..0cafdf9ff1 100644 --- a/test/Api.Test/Controllers/OrganizationSponsorshipsControllerTests.cs +++ b/test/Api.Test/Controllers/OrganizationSponsorshipsControllerTests.cs @@ -13,135 +13,136 @@ using Bit.Test.Common.AutoFixture.Attributes; using NSubstitute; using Xunit; -namespace Bit.Api.Test.Controllers; - -[ControllerCustomize(typeof(OrganizationSponsorshipsController))] -[SutProviderCustomize] -public class OrganizationSponsorshipsControllerTests +namespace Bit.Api.Test.Controllers { - public static IEnumerable EnterprisePlanTypes => - Enum.GetValues().Where(p => StaticStore.GetPlan(p).Product == ProductType.Enterprise).Select(p => new object[] { p }); - public static IEnumerable NonEnterprisePlanTypes => - Enum.GetValues().Where(p => StaticStore.GetPlan(p).Product != ProductType.Enterprise).Select(p => new object[] { p }); - public static IEnumerable NonFamiliesPlanTypes => - Enum.GetValues().Where(p => StaticStore.GetPlan(p).Product != ProductType.Families).Select(p => new object[] { p }); - - public static IEnumerable NonConfirmedOrganizationUsersStatuses => - Enum.GetValues() - .Where(s => s != OrganizationUserStatusType.Confirmed) - .Select(s => new object[] { s }); - - - [Theory] - [BitAutoData] - public async Task RedeemSponsorship_BadToken_ThrowsBadRequest(string sponsorshipToken, User user, - OrganizationSponsorshipRedeemRequestModel model, SutProvider sutProvider) + [ControllerCustomize(typeof(OrganizationSponsorshipsController))] + [SutProviderCustomize] + public class OrganizationSponsorshipsControllerTests { - sutProvider.GetDependency().UserId.Returns(user.Id); - sutProvider.GetDependency().GetUserByIdAsync(user.Id) - .Returns(user); - sutProvider.GetDependency().ValidateRedemptionTokenAsync(sponsorshipToken, - user.Email).Returns((false, null)); + public static IEnumerable EnterprisePlanTypes => + Enum.GetValues().Where(p => StaticStore.GetPlan(p).Product == ProductType.Enterprise).Select(p => new object[] { p }); + public static IEnumerable NonEnterprisePlanTypes => + Enum.GetValues().Where(p => StaticStore.GetPlan(p).Product != ProductType.Enterprise).Select(p => new object[] { p }); + public static IEnumerable NonFamiliesPlanTypes => + Enum.GetValues().Where(p => StaticStore.GetPlan(p).Product != ProductType.Families).Select(p => new object[] { p }); - var exception = await Assert.ThrowsAsync(() => - sutProvider.Sut.RedeemSponsorship(sponsorshipToken, model)); + public static IEnumerable NonConfirmedOrganizationUsersStatuses => + Enum.GetValues() + .Where(s => s != OrganizationUserStatusType.Confirmed) + .Select(s => new object[] { s }); - Assert.Contains("Failed to parse sponsorship token.", exception.Message); - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .SetUpSponsorshipAsync(default, default); - } - [Theory] - [BitAutoData] - public async Task RedeemSponsorship_NotSponsoredOrgOwner_ThrowsBadRequest(string sponsorshipToken, User user, - OrganizationSponsorship sponsorship, OrganizationSponsorshipRedeemRequestModel model, - SutProvider sutProvider) - { - sutProvider.GetDependency().UserId.Returns(user.Id); - sutProvider.GetDependency().GetUserByIdAsync(user.Id) - .Returns(user); - sutProvider.GetDependency().ValidateRedemptionTokenAsync(sponsorshipToken, - user.Email).Returns((true, sponsorship)); - sutProvider.GetDependency().OrganizationOwner(model.SponsoredOrganizationId).Returns(false); + [Theory] + [BitAutoData] + public async Task RedeemSponsorship_BadToken_ThrowsBadRequest(string sponsorshipToken, User user, + OrganizationSponsorshipRedeemRequestModel model, SutProvider sutProvider) + { + sutProvider.GetDependency().UserId.Returns(user.Id); + sutProvider.GetDependency().GetUserByIdAsync(user.Id) + .Returns(user); + sutProvider.GetDependency().ValidateRedemptionTokenAsync(sponsorshipToken, + user.Email).Returns((false, null)); - var exception = await Assert.ThrowsAsync(() => - sutProvider.Sut.RedeemSponsorship(sponsorshipToken, model)); + var exception = await Assert.ThrowsAsync(() => + sutProvider.Sut.RedeemSponsorship(sponsorshipToken, model)); - Assert.Contains("Can only redeem sponsorship for an organization you own.", exception.Message); - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .SetUpSponsorshipAsync(default, default); - } + Assert.Contains("Failed to parse sponsorship token.", exception.Message); + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .SetUpSponsorshipAsync(default, default); + } - [Theory] - [BitAutoData] - public async Task RedeemSponsorship_NotSponsoredOrgOwner_Success(string sponsorshipToken, User user, - OrganizationSponsorship sponsorship, Organization sponsoringOrganization, - OrganizationSponsorshipRedeemRequestModel model, SutProvider sutProvider) - { - sutProvider.GetDependency().UserId.Returns(user.Id); - sutProvider.GetDependency().GetUserByIdAsync(user.Id) - .Returns(user); - sutProvider.GetDependency().ValidateRedemptionTokenAsync(sponsorshipToken, - user.Email).Returns((true, sponsorship)); - sutProvider.GetDependency().OrganizationOwner(model.SponsoredOrganizationId).Returns(true); - sutProvider.GetDependency().GetByIdAsync(model.SponsoredOrganizationId).Returns(sponsoringOrganization); + [Theory] + [BitAutoData] + public async Task RedeemSponsorship_NotSponsoredOrgOwner_ThrowsBadRequest(string sponsorshipToken, User user, + OrganizationSponsorship sponsorship, OrganizationSponsorshipRedeemRequestModel model, + SutProvider sutProvider) + { + sutProvider.GetDependency().UserId.Returns(user.Id); + sutProvider.GetDependency().GetUserByIdAsync(user.Id) + .Returns(user); + sutProvider.GetDependency().ValidateRedemptionTokenAsync(sponsorshipToken, + user.Email).Returns((true, sponsorship)); + sutProvider.GetDependency().OrganizationOwner(model.SponsoredOrganizationId).Returns(false); - await sutProvider.Sut.RedeemSponsorship(sponsorshipToken, model); + var exception = await Assert.ThrowsAsync(() => + sutProvider.Sut.RedeemSponsorship(sponsorshipToken, model)); - await sutProvider.GetDependency().Received(1) - .SetUpSponsorshipAsync(sponsorship, sponsoringOrganization); - } + Assert.Contains("Can only redeem sponsorship for an organization you own.", exception.Message); + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .SetUpSponsorshipAsync(default, default); + } - [Theory] - [BitAutoData] - public async Task PreValidateSponsorshipToken_ValidatesToken_Success(string sponsorshipToken, User user, - OrganizationSponsorship sponsorship, SutProvider sutProvider) - { - sutProvider.GetDependency().UserId.Returns(user.Id); - sutProvider.GetDependency().GetUserByIdAsync(user.Id) - .Returns(user); - sutProvider.GetDependency() - .ValidateRedemptionTokenAsync(sponsorshipToken, user.Email).Returns((true, sponsorship)); + [Theory] + [BitAutoData] + public async Task RedeemSponsorship_NotSponsoredOrgOwner_Success(string sponsorshipToken, User user, + OrganizationSponsorship sponsorship, Organization sponsoringOrganization, + OrganizationSponsorshipRedeemRequestModel model, SutProvider sutProvider) + { + sutProvider.GetDependency().UserId.Returns(user.Id); + sutProvider.GetDependency().GetUserByIdAsync(user.Id) + .Returns(user); + sutProvider.GetDependency().ValidateRedemptionTokenAsync(sponsorshipToken, + user.Email).Returns((true, sponsorship)); + sutProvider.GetDependency().OrganizationOwner(model.SponsoredOrganizationId).Returns(true); + sutProvider.GetDependency().GetByIdAsync(model.SponsoredOrganizationId).Returns(sponsoringOrganization); - await sutProvider.Sut.PreValidateSponsorshipToken(sponsorshipToken); + await sutProvider.Sut.RedeemSponsorship(sponsorshipToken, model); - await sutProvider.GetDependency().Received(1) - .ValidateRedemptionTokenAsync(sponsorshipToken, user.Email); - } + await sutProvider.GetDependency().Received(1) + .SetUpSponsorshipAsync(sponsorship, sponsoringOrganization); + } - [Theory] - [BitAutoData] - public async Task RevokeSponsorship_WrongSponsoringUser_ThrowsBadRequest(OrganizationUser sponsoringOrgUser, - Guid currentUserId, SutProvider sutProvider) - { - sutProvider.GetDependency().UserId.Returns(currentUserId); - sutProvider.GetDependency().GetByIdAsync(sponsoringOrgUser.Id) - .Returns(sponsoringOrgUser); + [Theory] + [BitAutoData] + public async Task PreValidateSponsorshipToken_ValidatesToken_Success(string sponsorshipToken, User user, + OrganizationSponsorship sponsorship, SutProvider sutProvider) + { + sutProvider.GetDependency().UserId.Returns(user.Id); + sutProvider.GetDependency().GetUserByIdAsync(user.Id) + .Returns(user); + sutProvider.GetDependency() + .ValidateRedemptionTokenAsync(sponsorshipToken, user.Email).Returns((true, sponsorship)); - var exception = await Assert.ThrowsAsync(() => - sutProvider.Sut.RevokeSponsorship(sponsoringOrgUser.Id)); + await sutProvider.Sut.PreValidateSponsorshipToken(sponsorshipToken); - Assert.Contains("Can only revoke a sponsorship you granted.", exception.Message); - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .RemoveSponsorshipAsync(default); - } + await sutProvider.GetDependency().Received(1) + .ValidateRedemptionTokenAsync(sponsorshipToken, user.Email); + } - [Theory] - [BitAutoData] - public async Task RemoveSponsorship_WrongOrgUserType_ThrowsBadRequest(Organization sponsoredOrg, - SutProvider sutProvider) - { - sutProvider.GetDependency().OrganizationOwner(Arg.Any()).Returns(false); + [Theory] + [BitAutoData] + public async Task RevokeSponsorship_WrongSponsoringUser_ThrowsBadRequest(OrganizationUser sponsoringOrgUser, + Guid currentUserId, SutProvider sutProvider) + { + sutProvider.GetDependency().UserId.Returns(currentUserId); + sutProvider.GetDependency().GetByIdAsync(sponsoringOrgUser.Id) + .Returns(sponsoringOrgUser); - var exception = await Assert.ThrowsAsync(() => - sutProvider.Sut.RemoveSponsorship(sponsoredOrg.Id)); + var exception = await Assert.ThrowsAsync(() => + sutProvider.Sut.RevokeSponsorship(sponsoringOrgUser.Id)); - Assert.Contains("Only the owner of an organization can remove sponsorship.", exception.Message); - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .RemoveSponsorshipAsync(default); + Assert.Contains("Can only revoke a sponsorship you granted.", exception.Message); + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .RemoveSponsorshipAsync(default); + } + + [Theory] + [BitAutoData] + public async Task RemoveSponsorship_WrongOrgUserType_ThrowsBadRequest(Organization sponsoredOrg, + SutProvider sutProvider) + { + sutProvider.GetDependency().OrganizationOwner(Arg.Any()).Returns(false); + + var exception = await Assert.ThrowsAsync(() => + sutProvider.Sut.RemoveSponsorship(sponsoredOrg.Id)); + + Assert.Contains("Only the owner of an organization can remove sponsorship.", exception.Message); + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .RemoveSponsorshipAsync(default); + } } } diff --git a/test/Api.Test/Controllers/OrganizationUsersControllerTests.cs b/test/Api.Test/Controllers/OrganizationUsersControllerTests.cs index c5c1019df6..585508c663 100644 --- a/test/Api.Test/Controllers/OrganizationUsersControllerTests.cs +++ b/test/Api.Test/Controllers/OrganizationUsersControllerTests.cs @@ -10,56 +10,57 @@ using Bit.Test.Common.AutoFixture.Attributes; using NSubstitute; using Xunit; -namespace Bit.Api.Test.Controllers; - -[ControllerCustomize(typeof(OrganizationUsersController))] -[SutProviderCustomize] -public class OrganizationUsersControllerTests +namespace Bit.Api.Test.Controllers { - [Theory] - [BitAutoData] - public async Task Accept_RequiresKnownUser(Guid orgId, Guid orgUserId, OrganizationUserAcceptRequestModel model, - SutProvider sutProvider) + [ControllerCustomize(typeof(OrganizationUsersController))] + [SutProviderCustomize] + public class OrganizationUsersControllerTests { - sutProvider.GetDependency().GetUserByPrincipalAsync(default).ReturnsForAnyArgs((User)null); - - await Assert.ThrowsAsync(() => sutProvider.Sut.Accept(orgId, orgUserId, model)); - } - - [Theory] - [BitAutoData] - public async Task Accept_NoMasterPasswordReset(Guid orgId, Guid orgUserId, - OrganizationUserAcceptRequestModel model, User user, SutProvider sutProvider) - { - sutProvider.GetDependency().GetUserByPrincipalAsync(default).ReturnsForAnyArgs(user); - - await sutProvider.Sut.Accept(orgId, orgUserId, model); - - await sutProvider.GetDependency().Received(1) - .AcceptUserAsync(orgUserId, user, model.Token, sutProvider.GetDependency()); - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() - .UpdateUserResetPasswordEnrollmentAsync(default, default, default, default); - } - - [Theory] - [BitAutoData] - public async Task Accept_RequireMasterPasswordReset(Guid orgId, Guid orgUserId, - OrganizationUserAcceptRequestModel model, User user, SutProvider sutProvider) - { - var policy = new Policy + [Theory] + [BitAutoData] + public async Task Accept_RequiresKnownUser(Guid orgId, Guid orgUserId, OrganizationUserAcceptRequestModel model, + SutProvider sutProvider) { - Enabled = true, - Data = CoreHelpers.ClassToJsonData(new ResetPasswordDataModel { AutoEnrollEnabled = true, }), - }; - sutProvider.GetDependency().GetUserByPrincipalAsync(default).ReturnsForAnyArgs(user); - sutProvider.GetDependency().GetByOrganizationIdTypeAsync(orgId, - Core.Enums.PolicyType.ResetPassword).Returns(policy); + sutProvider.GetDependency().GetUserByPrincipalAsync(default).ReturnsForAnyArgs((User)null); - await sutProvider.Sut.Accept(orgId, orgUserId, model); + await Assert.ThrowsAsync(() => sutProvider.Sut.Accept(orgId, orgUserId, model)); + } - await sutProvider.GetDependency().Received(1) - .AcceptUserAsync(orgUserId, user, model.Token, sutProvider.GetDependency()); - await sutProvider.GetDependency().Received(1) - .UpdateUserResetPasswordEnrollmentAsync(orgId, user.Id, model.ResetPasswordKey, user.Id); + [Theory] + [BitAutoData] + public async Task Accept_NoMasterPasswordReset(Guid orgId, Guid orgUserId, + OrganizationUserAcceptRequestModel model, User user, SutProvider sutProvider) + { + sutProvider.GetDependency().GetUserByPrincipalAsync(default).ReturnsForAnyArgs(user); + + await sutProvider.Sut.Accept(orgId, orgUserId, model); + + await sutProvider.GetDependency().Received(1) + .AcceptUserAsync(orgUserId, user, model.Token, sutProvider.GetDependency()); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() + .UpdateUserResetPasswordEnrollmentAsync(default, default, default, default); + } + + [Theory] + [BitAutoData] + public async Task Accept_RequireMasterPasswordReset(Guid orgId, Guid orgUserId, + OrganizationUserAcceptRequestModel model, User user, SutProvider sutProvider) + { + var policy = new Policy + { + Enabled = true, + Data = CoreHelpers.ClassToJsonData(new ResetPasswordDataModel { AutoEnrollEnabled = true, }), + }; + sutProvider.GetDependency().GetUserByPrincipalAsync(default).ReturnsForAnyArgs(user); + sutProvider.GetDependency().GetByOrganizationIdTypeAsync(orgId, + Core.Enums.PolicyType.ResetPassword).Returns(policy); + + await sutProvider.Sut.Accept(orgId, orgUserId, model); + + await sutProvider.GetDependency().Received(1) + .AcceptUserAsync(orgUserId, user, model.Token, sutProvider.GetDependency()); + await sutProvider.GetDependency().Received(1) + .UpdateUserResetPasswordEnrollmentAsync(orgId, user.Id, model.ResetPasswordKey, user.Id); + } } } diff --git a/test/Api.Test/Controllers/OrganizationsControllerTests.cs b/test/Api.Test/Controllers/OrganizationsControllerTests.cs index dddd9c5f05..f31c319273 100644 --- a/test/Api.Test/Controllers/OrganizationsControllerTests.cs +++ b/test/Api.Test/Controllers/OrganizationsControllerTests.cs @@ -12,108 +12,109 @@ using Bit.Core.Settings; using NSubstitute; using Xunit; -namespace Bit.Api.Test.Controllers; - -public class OrganizationsControllerTests : IDisposable +namespace Bit.Api.Test.Controllers { - private readonly GlobalSettings _globalSettings; - private readonly ICurrentContext _currentContext; - private readonly IOrganizationRepository _organizationRepository; - private readonly IOrganizationService _organizationService; - private readonly IOrganizationUserRepository _organizationUserRepository; - private readonly IPaymentService _paymentService; - private readonly IPolicyRepository _policyRepository; - private readonly ISsoConfigRepository _ssoConfigRepository; - private readonly ISsoConfigService _ssoConfigService; - private readonly IUserService _userService; - private readonly IGetOrganizationApiKeyCommand _getOrganizationApiKeyCommand; - private readonly IRotateOrganizationApiKeyCommand _rotateOrganizationApiKeyCommand; - private readonly IOrganizationApiKeyRepository _organizationApiKeyRepository; - - private readonly OrganizationsController _sut; - - public OrganizationsControllerTests() + public class OrganizationsControllerTests : IDisposable { - _currentContext = Substitute.For(); - _globalSettings = Substitute.For(); - _organizationRepository = Substitute.For(); - _organizationService = Substitute.For(); - _organizationUserRepository = Substitute.For(); - _paymentService = Substitute.For(); - _policyRepository = Substitute.For(); - _ssoConfigRepository = Substitute.For(); - _ssoConfigService = Substitute.For(); - _getOrganizationApiKeyCommand = Substitute.For(); - _rotateOrganizationApiKeyCommand = Substitute.For(); - _organizationApiKeyRepository = Substitute.For(); - _userService = Substitute.For(); + private readonly GlobalSettings _globalSettings; + private readonly ICurrentContext _currentContext; + private readonly IOrganizationRepository _organizationRepository; + private readonly IOrganizationService _organizationService; + private readonly IOrganizationUserRepository _organizationUserRepository; + private readonly IPaymentService _paymentService; + private readonly IPolicyRepository _policyRepository; + private readonly ISsoConfigRepository _ssoConfigRepository; + private readonly ISsoConfigService _ssoConfigService; + private readonly IUserService _userService; + private readonly IGetOrganizationApiKeyCommand _getOrganizationApiKeyCommand; + private readonly IRotateOrganizationApiKeyCommand _rotateOrganizationApiKeyCommand; + private readonly IOrganizationApiKeyRepository _organizationApiKeyRepository; - _sut = new OrganizationsController(_organizationRepository, _organizationUserRepository, - _policyRepository, _organizationService, _userService, _paymentService, _currentContext, - _ssoConfigRepository, _ssoConfigService, _getOrganizationApiKeyCommand, _rotateOrganizationApiKeyCommand, - _organizationApiKeyRepository, _globalSettings); - } + private readonly OrganizationsController _sut; - public void Dispose() - { - _sut?.Dispose(); - } - - [Theory, AutoData] - public async Task OrganizationsController_UserCannotLeaveOrganizationThatProvidesKeyConnector( - Guid orgId, User user) - { - var ssoConfig = new SsoConfig + public OrganizationsControllerTests() { - Id = default, - Data = new SsoConfigurationData - { - KeyConnectorEnabled = true, - }.Serialize(), - Enabled = true, - OrganizationId = orgId, - }; + _currentContext = Substitute.For(); + _globalSettings = Substitute.For(); + _organizationRepository = Substitute.For(); + _organizationService = Substitute.For(); + _organizationUserRepository = Substitute.For(); + _paymentService = Substitute.For(); + _policyRepository = Substitute.For(); + _ssoConfigRepository = Substitute.For(); + _ssoConfigService = Substitute.For(); + _getOrganizationApiKeyCommand = Substitute.For(); + _rotateOrganizationApiKeyCommand = Substitute.For(); + _organizationApiKeyRepository = Substitute.For(); + _userService = Substitute.For(); - user.UsesKeyConnector = true; + _sut = new OrganizationsController(_organizationRepository, _organizationUserRepository, + _policyRepository, _organizationService, _userService, _paymentService, _currentContext, + _ssoConfigRepository, _ssoConfigService, _getOrganizationApiKeyCommand, _rotateOrganizationApiKeyCommand, + _organizationApiKeyRepository, _globalSettings); + } - _currentContext.OrganizationUser(orgId).Returns(true); - _ssoConfigRepository.GetByOrganizationIdAsync(orgId).Returns(ssoConfig); - _userService.GetUserByPrincipalAsync(Arg.Any()).Returns(user); - - var exception = await Assert.ThrowsAsync( - () => _sut.Leave(orgId.ToString())); - - Assert.Contains("Your organization's Single Sign-On settings prevent you from leaving.", - exception.Message); - - await _organizationService.DidNotReceiveWithAnyArgs().DeleteUserAsync(default, default); - } - - [Theory] - [InlineAutoData(true, false)] - [InlineAutoData(false, true)] - [InlineAutoData(false, false)] - public async Task OrganizationsController_UserCanLeaveOrganizationThatDoesntProvideKeyConnector( - bool keyConnectorEnabled, bool userUsesKeyConnector, Guid orgId, User user) - { - var ssoConfig = new SsoConfig + public void Dispose() { - Id = default, - Data = new SsoConfigurationData + _sut?.Dispose(); + } + + [Theory, AutoData] + public async Task OrganizationsController_UserCannotLeaveOrganizationThatProvidesKeyConnector( + Guid orgId, User user) + { + var ssoConfig = new SsoConfig { - KeyConnectorEnabled = keyConnectorEnabled, - }.Serialize(), - Enabled = true, - OrganizationId = orgId, - }; + Id = default, + Data = new SsoConfigurationData + { + KeyConnectorEnabled = true, + }.Serialize(), + Enabled = true, + OrganizationId = orgId, + }; - user.UsesKeyConnector = userUsesKeyConnector; + user.UsesKeyConnector = true; - _currentContext.OrganizationUser(orgId).Returns(true); - _ssoConfigRepository.GetByOrganizationIdAsync(orgId).Returns(ssoConfig); - _userService.GetUserByPrincipalAsync(Arg.Any()).Returns(user); + _currentContext.OrganizationUser(orgId).Returns(true); + _ssoConfigRepository.GetByOrganizationIdAsync(orgId).Returns(ssoConfig); + _userService.GetUserByPrincipalAsync(Arg.Any()).Returns(user); - await _organizationService.DeleteUserAsync(orgId, user.Id); - await _organizationService.Received(1).DeleteUserAsync(orgId, user.Id); + var exception = await Assert.ThrowsAsync( + () => _sut.Leave(orgId.ToString())); + + Assert.Contains("Your organization's Single Sign-On settings prevent you from leaving.", + exception.Message); + + await _organizationService.DidNotReceiveWithAnyArgs().DeleteUserAsync(default, default); + } + + [Theory] + [InlineAutoData(true, false)] + [InlineAutoData(false, true)] + [InlineAutoData(false, false)] + public async Task OrganizationsController_UserCanLeaveOrganizationThatDoesntProvideKeyConnector( + bool keyConnectorEnabled, bool userUsesKeyConnector, Guid orgId, User user) + { + var ssoConfig = new SsoConfig + { + Id = default, + Data = new SsoConfigurationData + { + KeyConnectorEnabled = keyConnectorEnabled, + }.Serialize(), + Enabled = true, + OrganizationId = orgId, + }; + + user.UsesKeyConnector = userUsesKeyConnector; + + _currentContext.OrganizationUser(orgId).Returns(true); + _ssoConfigRepository.GetByOrganizationIdAsync(orgId).Returns(ssoConfig); + _userService.GetUserByPrincipalAsync(Arg.Any()).Returns(user); + + await _organizationService.DeleteUserAsync(orgId, user.Id); + await _organizationService.Received(1).DeleteUserAsync(orgId, user.Id); + } } } diff --git a/test/Api.Test/Controllers/SendsControllerTests.cs b/test/Api.Test/Controllers/SendsControllerTests.cs index 07ca95a857..e86d95c550 100644 --- a/test/Api.Test/Controllers/SendsControllerTests.cs +++ b/test/Api.Test/Controllers/SendsControllerTests.cs @@ -15,65 +15,66 @@ using Microsoft.Extensions.Logging; using NSubstitute; using Xunit; -namespace Bit.Api.Test.Controllers; - -public class SendsControllerTests : IDisposable +namespace Bit.Api.Test.Controllers { - - private readonly SendsController _sut; - private readonly GlobalSettings _globalSettings; - private readonly IUserService _userService; - private readonly ISendRepository _sendRepository; - private readonly ISendService _sendService; - private readonly ISendFileStorageService _sendFileStorageService; - private readonly ILogger _logger; - private readonly ICurrentContext _currentContext; - - public SendsControllerTests() + public class SendsControllerTests : IDisposable { - _userService = Substitute.For(); - _sendRepository = Substitute.For(); - _sendService = Substitute.For(); - _sendFileStorageService = Substitute.For(); - _globalSettings = new GlobalSettings(); - _logger = Substitute.For>(); - _currentContext = Substitute.For(); - _sut = new SendsController( - _sendRepository, - _userService, - _sendService, - _sendFileStorageService, - _logger, - _globalSettings, - _currentContext - ); - } + private readonly SendsController _sut; + private readonly GlobalSettings _globalSettings; + private readonly IUserService _userService; + private readonly ISendRepository _sendRepository; + private readonly ISendService _sendService; + private readonly ISendFileStorageService _sendFileStorageService; + private readonly ILogger _logger; + private readonly ICurrentContext _currentContext; - public void Dispose() - { - _sut?.Dispose(); - } + public SendsControllerTests() + { + _userService = Substitute.For(); + _sendRepository = Substitute.For(); + _sendService = Substitute.For(); + _sendFileStorageService = Substitute.For(); + _globalSettings = new GlobalSettings(); + _logger = Substitute.For>(); + _currentContext = Substitute.For(); - [Theory, AutoData] - public async Task SendsController_WhenSendHidesEmail_CreatorIdentifierShouldBeNull( - Guid id, Send send, User user) - { - var accessId = CoreHelpers.Base64UrlEncode(id.ToByteArray()); + _sut = new SendsController( + _sendRepository, + _userService, + _sendService, + _sendFileStorageService, + _logger, + _globalSettings, + _currentContext + ); + } - send.Id = default; - send.Type = SendType.Text; - send.Data = JsonSerializer.Serialize(new Dictionary()); - send.HideEmail = true; + public void Dispose() + { + _sut?.Dispose(); + } - _sendService.AccessAsync(id, null).Returns((send, false, false)); - _userService.GetUserByIdAsync(Arg.Any()).Returns(user); + [Theory, AutoData] + public async Task SendsController_WhenSendHidesEmail_CreatorIdentifierShouldBeNull( + Guid id, Send send, User user) + { + var accessId = CoreHelpers.Base64UrlEncode(id.ToByteArray()); - var request = new SendAccessRequestModel(); - var actionResult = await _sut.Access(accessId, request); - var response = (actionResult as ObjectResult)?.Value as SendAccessResponseModel; + send.Id = default; + send.Type = SendType.Text; + send.Data = JsonSerializer.Serialize(new Dictionary()); + send.HideEmail = true; - Assert.NotNull(response); - Assert.Null(response.CreatorIdentifier); + _sendService.AccessAsync(id, null).Returns((send, false, false)); + _userService.GetUserByIdAsync(Arg.Any()).Returns(user); + + var request = new SendAccessRequestModel(); + var actionResult = await _sut.Access(accessId, request); + var response = (actionResult as ObjectResult)?.Value as SendAccessResponseModel; + + Assert.NotNull(response); + Assert.Null(response.CreatorIdentifier); + } } } diff --git a/test/Api.Test/Models/Request/Accounts/PremiumRequestModelTests.cs b/test/Api.Test/Models/Request/Accounts/PremiumRequestModelTests.cs index 9f08705317..6d01e0a157 100644 --- a/test/Api.Test/Models/Request/Accounts/PremiumRequestModelTests.cs +++ b/test/Api.Test/Models/Request/Accounts/PremiumRequestModelTests.cs @@ -3,62 +3,63 @@ using Bit.Core.Settings; using Microsoft.AspNetCore.Http; using Xunit; -namespace Bit.Api.Test.Models.Request.Accounts; - -public class PremiumRequestModelTests +namespace Bit.Api.Test.Models.Request.Accounts { - public static IEnumerable GetValidateData() + public class PremiumRequestModelTests { - // 1. selfHosted - // 2. formFile - // 3. country - // 4. expected + public static IEnumerable GetValidateData() + { + // 1. selfHosted + // 2. formFile + // 3. country + // 4. expected - yield return new object[] { true, null, null, false }; - yield return new object[] { true, null, "US", false }; - yield return new object[] { true, new NotImplementedFormFile(), null, false }; - yield return new object[] { true, new NotImplementedFormFile(), "US", false }; + yield return new object[] { true, null, null, false }; + yield return new object[] { true, null, "US", false }; + yield return new object[] { true, new NotImplementedFormFile(), null, false }; + yield return new object[] { true, new NotImplementedFormFile(), "US", false }; - yield return new object[] { false, null, null, false }; - yield return new object[] { false, null, "US", true }; // Only true, cloud with null license AND a Country - yield return new object[] { false, new NotImplementedFormFile(), null, false }; - yield return new object[] { false, new NotImplementedFormFile(), "US", false }; + yield return new object[] { false, null, null, false }; + yield return new object[] { false, null, "US", true }; // Only true, cloud with null license AND a Country + yield return new object[] { false, new NotImplementedFormFile(), null, false }; + yield return new object[] { false, new NotImplementedFormFile(), "US", false }; + } + + [Theory] + [MemberData(nameof(GetValidateData))] + public void Validate_Success(bool selfHosted, IFormFile formFile, string country, bool expected) + { + var gs = new GlobalSettings + { + SelfHosted = selfHosted + }; + + var sut = new PremiumRequestModel + { + License = formFile, + Country = country, + }; + + Assert.Equal(expected, sut.Validate(gs)); + } } - [Theory] - [MemberData(nameof(GetValidateData))] - public void Validate_Success(bool selfHosted, IFormFile formFile, string country, bool expected) + public class NotImplementedFormFile : IFormFile { - var gs = new GlobalSettings - { - SelfHosted = selfHosted - }; + public string ContentType => throw new NotImplementedException(); - var sut = new PremiumRequestModel - { - License = formFile, - Country = country, - }; + public string ContentDisposition => throw new NotImplementedException(); - Assert.Equal(expected, sut.Validate(gs)); + public IHeaderDictionary Headers => throw new NotImplementedException(); + + public long Length => throw new NotImplementedException(); + + public string Name => throw new NotImplementedException(); + + public string FileName => throw new NotImplementedException(); + + public void CopyTo(Stream target) => throw new NotImplementedException(); + public Task CopyToAsync(Stream target, CancellationToken cancellationToken = default) => throw new NotImplementedException(); + public Stream OpenReadStream() => throw new NotImplementedException(); } } - -public class NotImplementedFormFile : IFormFile -{ - public string ContentType => throw new NotImplementedException(); - - public string ContentDisposition => throw new NotImplementedException(); - - public IHeaderDictionary Headers => throw new NotImplementedException(); - - public long Length => throw new NotImplementedException(); - - public string Name => throw new NotImplementedException(); - - public string FileName => throw new NotImplementedException(); - - public void CopyTo(Stream target) => throw new NotImplementedException(); - public Task CopyToAsync(Stream target, CancellationToken cancellationToken = default) => throw new NotImplementedException(); - public Stream OpenReadStream() => throw new NotImplementedException(); -} diff --git a/test/Api.Test/Models/Request/SendRequestModelTests.cs b/test/Api.Test/Models/Request/SendRequestModelTests.cs index 7ad858d2e2..ffcf043bd0 100644 --- a/test/Api.Test/Models/Request/SendRequestModelTests.cs +++ b/test/Api.Test/Models/Request/SendRequestModelTests.cs @@ -7,53 +7,54 @@ using Bit.Test.Common.Helpers; using NSubstitute; using Xunit; -namespace Bit.Api.Test.Models.Request; - -public class SendRequestModelTests +namespace Bit.Api.Test.Models.Request { - [Fact] - public void ToSend_Text_Success() + public class SendRequestModelTests { - var deletionDate = DateTime.UtcNow.AddDays(5); - var sendRequest = new SendRequestModel + [Fact] + public void ToSend_Text_Success() { - DeletionDate = deletionDate, - Disabled = false, - ExpirationDate = null, - HideEmail = false, - Key = "encrypted_key", - MaxAccessCount = null, - Name = "encrypted_name", - Notes = null, - Password = "Password", - Text = new SendTextModel() + var deletionDate = DateTime.UtcNow.AddDays(5); + var sendRequest = new SendRequestModel { - Hidden = false, - Text = "encrypted_text" - }, - Type = SendType.Text, - }; + DeletionDate = deletionDate, + Disabled = false, + ExpirationDate = null, + HideEmail = false, + Key = "encrypted_key", + MaxAccessCount = null, + Name = "encrypted_name", + Notes = null, + Password = "Password", + Text = new SendTextModel() + { + Hidden = false, + Text = "encrypted_text" + }, + Type = SendType.Text, + }; - var sendService = Substitute.For(); - sendService.HashPassword(Arg.Any()) - .Returns((info) => $"hashed_{(string)info[0]}"); + var sendService = Substitute.For(); + sendService.HashPassword(Arg.Any()) + .Returns((info) => $"hashed_{(string)info[0]}"); - var send = sendRequest.ToSend(Guid.NewGuid(), sendService); + var send = sendRequest.ToSend(Guid.NewGuid(), sendService); - Assert.Equal(deletionDate, send.DeletionDate); - Assert.False(send.Disabled); - Assert.Null(send.ExpirationDate); - Assert.False(send.HideEmail); - Assert.Equal("encrypted_key", send.Key); - Assert.Equal("hashed_Password", send.Password); + Assert.Equal(deletionDate, send.DeletionDate); + Assert.False(send.Disabled); + Assert.Null(send.ExpirationDate); + Assert.False(send.HideEmail); + Assert.Equal("encrypted_key", send.Key); + Assert.Equal("hashed_Password", send.Password); - using var jsonDocument = JsonDocument.Parse(send.Data); - var root = jsonDocument.RootElement; - var text = AssertHelper.AssertJsonProperty(root, "Text", JsonValueKind.String).GetString(); - Assert.Equal("encrypted_text", text); - AssertHelper.AssertJsonProperty(root, "Hidden", JsonValueKind.False); - Assert.False(root.TryGetProperty("Notes", out var _)); - var name = AssertHelper.AssertJsonProperty(root, "Name", JsonValueKind.String).GetString(); - Assert.Equal("encrypted_name", name); + using var jsonDocument = JsonDocument.Parse(send.Data); + var root = jsonDocument.RootElement; + var text = AssertHelper.AssertJsonProperty(root, "Text", JsonValueKind.String).GetString(); + Assert.Equal("encrypted_text", text); + AssertHelper.AssertJsonProperty(root, "Hidden", JsonValueKind.False); + Assert.False(root.TryGetProperty("Notes", out var _)); + var name = AssertHelper.AssertJsonProperty(root, "Name", JsonValueKind.String).GetString(); + Assert.Equal("encrypted_name", name); + } } } diff --git a/test/Api.Test/Utilities/ApiHelpersTests.cs b/test/Api.Test/Utilities/ApiHelpersTests.cs index 4013a2222a..718ec2eeb4 100644 --- a/test/Api.Test/Utilities/ApiHelpersTests.cs +++ b/test/Api.Test/Utilities/ApiHelpersTests.cs @@ -5,22 +5,23 @@ using Microsoft.AspNetCore.Http; using NSubstitute; using Xunit; -namespace Bit.Api.Test.Utilities; - -public class ApiHelpersTests +namespace Bit.Api.Test.Utilities { - [Fact] - public async Task ReadJsonFileFromBody_Success() + public class ApiHelpersTests { - var context = Substitute.For(); - context.Request.ContentLength.Returns(200); - var bytes = Encoding.UTF8.GetBytes(testFile); - var formFile = new FormFile(new MemoryStream(bytes), 0, bytes.Length, "bitwarden_organization_license", "bitwarden_organization_license.json"); + [Fact] + public async Task ReadJsonFileFromBody_Success() + { + var context = Substitute.For(); + context.Request.ContentLength.Returns(200); + var bytes = Encoding.UTF8.GetBytes(testFile); + var formFile = new FormFile(new MemoryStream(bytes), 0, bytes.Length, "bitwarden_organization_license", "bitwarden_organization_license.json"); - var license = await ApiHelpers.ReadJsonFileFromBody(context, formFile); - Assert.Equal(8, license.Version); + var license = await ApiHelpers.ReadJsonFileFromBody(context, formFile); + Assert.Equal(8, license.Version); + } + + const string testFile = "{\"licenseKey\": \"licenseKey\", \"installationId\": \"6285f891-b2ec-4047-84c5-2eb7f7747e74\", \"id\": \"1065216d-5854-4326-838d-635487f30b43\",\"name\": \"Test Org\",\"billingEmail\": \"test@email.com\",\"businessName\": null,\"enabled\": true, \"plan\": \"Enterprise (Annually)\",\"planType\": 11,\"seats\": 6,\"maxCollections\": null,\"usePolicies\": true,\"useSso\": true,\"useKeyConnector\": false,\"useGroups\": true,\"useEvents\": true,\"useDirectory\": true,\"useTotp\": true,\"use2fa\": true,\"useApi\": true,\"useResetPassword\": true,\"maxStorageGb\": 1,\"selfHost\": true,\"usersGetPremium\": true,\"version\": 8,\"issued\": \"2022-01-25T21:58:38.9454581Z\",\"refresh\": \"2022-01-28T14:26:31Z\",\"expires\": \"2022-01-28T14:26:31Z\",\"trial\": true,\"hash\": \"testvalue\",\"signature\": \"signature\"}"; } - - const string testFile = "{\"licenseKey\": \"licenseKey\", \"installationId\": \"6285f891-b2ec-4047-84c5-2eb7f7747e74\", \"id\": \"1065216d-5854-4326-838d-635487f30b43\",\"name\": \"Test Org\",\"billingEmail\": \"test@email.com\",\"businessName\": null,\"enabled\": true, \"plan\": \"Enterprise (Annually)\",\"planType\": 11,\"seats\": 6,\"maxCollections\": null,\"usePolicies\": true,\"useSso\": true,\"useKeyConnector\": false,\"useGroups\": true,\"useEvents\": true,\"useDirectory\": true,\"useTotp\": true,\"use2fa\": true,\"useApi\": true,\"useResetPassword\": true,\"maxStorageGb\": 1,\"selfHost\": true,\"usersGetPremium\": true,\"version\": 8,\"issued\": \"2022-01-25T21:58:38.9454581Z\",\"refresh\": \"2022-01-28T14:26:31Z\",\"expires\": \"2022-01-28T14:26:31Z\",\"trial\": true,\"hash\": \"testvalue\",\"signature\": \"signature\"}"; } diff --git a/test/Billing.Test/Controllers/FreshdeskControllerTests.cs b/test/Billing.Test/Controllers/FreshdeskControllerTests.cs index 94f9e28490..3688960029 100644 --- a/test/Billing.Test/Controllers/FreshdeskControllerTests.cs +++ b/test/Billing.Test/Controllers/FreshdeskControllerTests.cs @@ -10,70 +10,71 @@ using Microsoft.Extensions.Options; using NSubstitute; using Xunit; -namespace Bit.Billing.Test.Controllers; - -[ControllerCustomize(typeof(FreshdeskController))] -[SutProviderCustomize] -public class FreshdeskControllerTests +namespace Bit.Billing.Test.Controllers { - private const string ApiKey = "TESTFRESHDESKAPIKEY"; - private const string WebhookKey = "TESTKEY"; - - [Theory] - [BitAutoData((string)null, null)] - [BitAutoData((string)null)] - [BitAutoData(WebhookKey, null)] - public async Task PostWebhook_NullRequiredParameters_BadRequest(string freshdeskWebhookKey, FreshdeskWebhookModel model, - BillingSettings billingSettings, SutProvider sutProvider) + [ControllerCustomize(typeof(FreshdeskController))] + [SutProviderCustomize] + public class FreshdeskControllerTests { - sutProvider.GetDependency>().Value.FreshdeskWebhookKey.Returns(billingSettings.FreshdeskWebhookKey); + private const string ApiKey = "TESTFRESHDESKAPIKEY"; + private const string WebhookKey = "TESTKEY"; - var response = await sutProvider.Sut.PostWebhook(freshdeskWebhookKey, model); - - var statusCodeResult = Assert.IsAssignableFrom(response); - Assert.Equal(StatusCodes.Status400BadRequest, statusCodeResult.StatusCode); - } - - [Theory] - [BitAutoData] - public async Task PostWebhook_Success(User user, FreshdeskWebhookModel model, - List organizations, SutProvider sutProvider) - { - model.TicketContactEmail = user.Email; - - sutProvider.GetDependency().GetByEmailAsync(user.Email).Returns(user); - sutProvider.GetDependency().GetManyByUserIdAsync(user.Id).Returns(organizations); - - var mockHttpMessageHandler = Substitute.ForPartsOf(); - var mockResponse = new HttpResponseMessage(System.Net.HttpStatusCode.OK); - mockHttpMessageHandler.Send(Arg.Any(), Arg.Any()) - .Returns(mockResponse); - var httpClient = new HttpClient(mockHttpMessageHandler); - - sutProvider.GetDependency().CreateClient("FreshdeskApi").Returns(httpClient); - - sutProvider.GetDependency>().Value.FreshdeskWebhookKey.Returns(WebhookKey); - sutProvider.GetDependency>().Value.FreshdeskApiKey.Returns(ApiKey); - - var response = await sutProvider.Sut.PostWebhook(WebhookKey, model); - - var statusCodeResult = Assert.IsAssignableFrom(response); - Assert.Equal(StatusCodes.Status200OK, statusCodeResult.StatusCode); - - _ = mockHttpMessageHandler.Received(1).Send(Arg.Is(m => m.Method == HttpMethod.Put && m.RequestUri.ToString().EndsWith(model.TicketId)), Arg.Any()); - _ = mockHttpMessageHandler.Received(1).Send(Arg.Is(m => m.Method == HttpMethod.Post && m.RequestUri.ToString().EndsWith($"{model.TicketId}/notes")), Arg.Any()); - } - - public class MockHttpMessageHandler : HttpMessageHandler - { - protected override Task SendAsync(HttpRequestMessage request, CancellationToken cancellationToken) + [Theory] + [BitAutoData((string)null, null)] + [BitAutoData((string)null)] + [BitAutoData(WebhookKey, null)] + public async Task PostWebhook_NullRequiredParameters_BadRequest(string freshdeskWebhookKey, FreshdeskWebhookModel model, + BillingSettings billingSettings, SutProvider sutProvider) { - return Send(request, cancellationToken); + sutProvider.GetDependency>().Value.FreshdeskWebhookKey.Returns(billingSettings.FreshdeskWebhookKey); + + var response = await sutProvider.Sut.PostWebhook(freshdeskWebhookKey, model); + + var statusCodeResult = Assert.IsAssignableFrom(response); + Assert.Equal(StatusCodes.Status400BadRequest, statusCodeResult.StatusCode); } - public virtual Task Send(HttpRequestMessage request, CancellationToken cancellationToken) + [Theory] + [BitAutoData] + public async Task PostWebhook_Success(User user, FreshdeskWebhookModel model, + List organizations, SutProvider sutProvider) { - throw new NotImplementedException(); + model.TicketContactEmail = user.Email; + + sutProvider.GetDependency().GetByEmailAsync(user.Email).Returns(user); + sutProvider.GetDependency().GetManyByUserIdAsync(user.Id).Returns(organizations); + + var mockHttpMessageHandler = Substitute.ForPartsOf(); + var mockResponse = new HttpResponseMessage(System.Net.HttpStatusCode.OK); + mockHttpMessageHandler.Send(Arg.Any(), Arg.Any()) + .Returns(mockResponse); + var httpClient = new HttpClient(mockHttpMessageHandler); + + sutProvider.GetDependency().CreateClient("FreshdeskApi").Returns(httpClient); + + sutProvider.GetDependency>().Value.FreshdeskWebhookKey.Returns(WebhookKey); + sutProvider.GetDependency>().Value.FreshdeskApiKey.Returns(ApiKey); + + var response = await sutProvider.Sut.PostWebhook(WebhookKey, model); + + var statusCodeResult = Assert.IsAssignableFrom(response); + Assert.Equal(StatusCodes.Status200OK, statusCodeResult.StatusCode); + + _ = mockHttpMessageHandler.Received(1).Send(Arg.Is(m => m.Method == HttpMethod.Put && m.RequestUri.ToString().EndsWith(model.TicketId)), Arg.Any()); + _ = mockHttpMessageHandler.Received(1).Send(Arg.Is(m => m.Method == HttpMethod.Post && m.RequestUri.ToString().EndsWith($"{model.TicketId}/notes")), Arg.Any()); + } + + public class MockHttpMessageHandler : HttpMessageHandler + { + protected override Task SendAsync(HttpRequestMessage request, CancellationToken cancellationToken) + { + return Send(request, cancellationToken); + } + + public virtual Task Send(HttpRequestMessage request, CancellationToken cancellationToken) + { + throw new NotImplementedException(); + } } } } diff --git a/test/Billing.Test/Controllers/FreshsalesControllerTests.cs b/test/Billing.Test/Controllers/FreshsalesControllerTests.cs index 3a5cf3bf17..490f4051a8 100644 --- a/test/Billing.Test/Controllers/FreshsalesControllerTests.cs +++ b/test/Billing.Test/Controllers/FreshsalesControllerTests.cs @@ -10,72 +10,73 @@ using Microsoft.Extensions.Options; using NSubstitute; using Xunit; -namespace Bit.Billing.Test.Controllers; - -public class FreshsalesControllerTests +namespace Bit.Billing.Test.Controllers { - private const string ApiKey = "TEST_FRESHSALES_APIKEY"; - private const string TestLead = "TEST_FRESHSALES_TESTLEAD"; - - private static (FreshsalesController, IUserRepository, IOrganizationRepository) CreateSut( - string freshsalesApiKey) + public class FreshsalesControllerTests { - var userRepository = Substitute.For(); - var organizationRepository = Substitute.For(); + private const string ApiKey = "TEST_FRESHSALES_APIKEY"; + private const string TestLead = "TEST_FRESHSALES_TESTLEAD"; - var billingSettings = Options.Create(new BillingSettings + private static (FreshsalesController, IUserRepository, IOrganizationRepository) CreateSut( + string freshsalesApiKey) { - FreshsalesApiKey = freshsalesApiKey, - }); - var globalSettings = new GlobalSettings(); - globalSettings.BaseServiceUri.Admin = "https://test.com"; + var userRepository = Substitute.For(); + var organizationRepository = Substitute.For(); - var sut = new FreshsalesController( - userRepository, - organizationRepository, - billingSettings, - Substitute.For>(), - globalSettings - ); - - return (sut, userRepository, organizationRepository); - } - - [RequiredEnvironmentTheory(ApiKey, TestLead), EnvironmentData(ApiKey, TestLead)] - public async Task PostWebhook_Success(string freshsalesApiKey, long leadId) - { - // This test is only for development to use: - // `export TEST_FRESHSALES_APIKEY=[apikey]` - // `export TEST_FRESHSALES_TESTLEAD=[lead id]` - // `dotnet test --filter "FullyQualifiedName~FreshsalesControllerTests.PostWebhook_Success"` - var (sut, userRepository, organizationRepository) = CreateSut(freshsalesApiKey); - - var user = new User - { - Id = Guid.NewGuid(), - Email = "test@email.com", - Premium = true, - }; - - userRepository.GetByEmailAsync(user.Email) - .Returns(user); - - organizationRepository.GetManyByUserIdAsync(user.Id) - .Returns(new List + var billingSettings = Options.Create(new BillingSettings { - new Organization - { - Id = Guid.NewGuid(), - Name = "Test Org", - } + FreshsalesApiKey = freshsalesApiKey, }); + var globalSettings = new GlobalSettings(); + globalSettings.BaseServiceUri.Admin = "https://test.com"; - var response = await sut.PostWebhook(freshsalesApiKey, new CustomWebhookRequestModel + var sut = new FreshsalesController( + userRepository, + organizationRepository, + billingSettings, + Substitute.For>(), + globalSettings + ); + + return (sut, userRepository, organizationRepository); + } + + [RequiredEnvironmentTheory(ApiKey, TestLead), EnvironmentData(ApiKey, TestLead)] + public async Task PostWebhook_Success(string freshsalesApiKey, long leadId) { - LeadId = leadId, - }, new CancellationToken(false)); + // This test is only for development to use: + // `export TEST_FRESHSALES_APIKEY=[apikey]` + // `export TEST_FRESHSALES_TESTLEAD=[lead id]` + // `dotnet test --filter "FullyQualifiedName~FreshsalesControllerTests.PostWebhook_Success"` + var (sut, userRepository, organizationRepository) = CreateSut(freshsalesApiKey); - var statusCodeResult = Assert.IsAssignableFrom(response); - Assert.Equal(StatusCodes.Status204NoContent, statusCodeResult.StatusCode); + var user = new User + { + Id = Guid.NewGuid(), + Email = "test@email.com", + Premium = true, + }; + + userRepository.GetByEmailAsync(user.Email) + .Returns(user); + + organizationRepository.GetManyByUserIdAsync(user.Id) + .Returns(new List + { + new Organization + { + Id = Guid.NewGuid(), + Name = "Test Org", + } + }); + + var response = await sut.PostWebhook(freshsalesApiKey, new CustomWebhookRequestModel + { + LeadId = leadId, + }, new CancellationToken(false)); + + var statusCodeResult = Assert.IsAssignableFrom(response); + Assert.Equal(StatusCodes.Status204NoContent, statusCodeResult.StatusCode); + } } } diff --git a/test/Common/AutoFixture/Attributes/BitAutoDataAttribute.cs b/test/Common/AutoFixture/Attributes/BitAutoDataAttribute.cs index d859f81fc9..7185468ea8 100644 --- a/test/Common/AutoFixture/Attributes/BitAutoDataAttribute.cs +++ b/test/Common/AutoFixture/Attributes/BitAutoDataAttribute.cs @@ -3,25 +3,26 @@ using AutoFixture; using Bit.Test.Common.Helpers; using Xunit.Sdk; -namespace Bit.Test.Common.AutoFixture.Attributes; - -[DataDiscoverer("AutoFixture.Xunit2.NoPreDiscoveryDataDiscoverer", "AutoFixture.Xunit2")] -public class BitAutoDataAttribute : DataAttribute +namespace Bit.Test.Common.AutoFixture.Attributes { - private readonly Func _createFixture; - private readonly object[] _fixedTestParameters; - - public BitAutoDataAttribute(params object[] fixedTestParameters) : - this(() => new Fixture(), fixedTestParameters) - { } - - public BitAutoDataAttribute(Func createFixture, params object[] fixedTestParameters) : - base() + [DataDiscoverer("AutoFixture.Xunit2.NoPreDiscoveryDataDiscoverer", "AutoFixture.Xunit2")] + public class BitAutoDataAttribute : DataAttribute { - _createFixture = createFixture; - _fixedTestParameters = fixedTestParameters; - } + private readonly Func _createFixture; + private readonly object[] _fixedTestParameters; - public override IEnumerable GetData(MethodInfo testMethod) - => BitAutoDataAttributeHelpers.GetData(testMethod, _createFixture(), _fixedTestParameters); + public BitAutoDataAttribute(params object[] fixedTestParameters) : + this(() => new Fixture(), fixedTestParameters) + { } + + public BitAutoDataAttribute(Func createFixture, params object[] fixedTestParameters) : + base() + { + _createFixture = createFixture; + _fixedTestParameters = fixedTestParameters; + } + + public override IEnumerable GetData(MethodInfo testMethod) + => BitAutoDataAttributeHelpers.GetData(testMethod, _createFixture(), _fixedTestParameters); + } } diff --git a/test/Common/AutoFixture/Attributes/BitCustomizeAttribute.cs b/test/Common/AutoFixture/Attributes/BitCustomizeAttribute.cs index 105a6632d8..9b9a5142d7 100644 --- a/test/Common/AutoFixture/Attributes/BitCustomizeAttribute.cs +++ b/test/Common/AutoFixture/Attributes/BitCustomizeAttribute.cs @@ -1,20 +1,21 @@ using AutoFixture; -namespace Bit.Test.Common.AutoFixture.Attributes; - -/// -/// -/// Base class for customizing parameters in methods decorated with the -/// Bit.Test.Common.AutoFixture.Attributes.MemberAutoDataAttribute. -/// -/// ⚠ Warning ⚠ Will not insert customizations into AutoFixture's AutoDataAttribute build chain -/// -[AttributeUsage(AttributeTargets.Class | AttributeTargets.Method | AttributeTargets.Parameter, AllowMultiple = true)] -public abstract class BitCustomizeAttribute : Attribute +namespace Bit.Test.Common.AutoFixture.Attributes { /// - /// /// Gets a customization for the method's parameters. + /// + /// Base class for customizing parameters in methods decorated with the + /// Bit.Test.Common.AutoFixture.Attributes.MemberAutoDataAttribute. + /// + /// ⚠ Warning ⚠ Will not insert customizations into AutoFixture's AutoDataAttribute build chain /// - /// A customization for the method's paramters. - public abstract ICustomization GetCustomization(); + [AttributeUsage(AttributeTargets.Class | AttributeTargets.Method | AttributeTargets.Parameter, AllowMultiple = true)] + public abstract class BitCustomizeAttribute : Attribute + { + /// + /// /// Gets a customization for the method's parameters. + /// + /// A customization for the method's paramters. + public abstract ICustomization GetCustomization(); + } } diff --git a/test/Common/AutoFixture/Attributes/BitMemberAutoDataAttribute.cs b/test/Common/AutoFixture/Attributes/BitMemberAutoDataAttribute.cs index 7e6f81c30a..e9604e1c94 100644 --- a/test/Common/AutoFixture/Attributes/BitMemberAutoDataAttribute.cs +++ b/test/Common/AutoFixture/Attributes/BitMemberAutoDataAttribute.cs @@ -3,22 +3,23 @@ using AutoFixture; using Bit.Test.Common.Helpers; using Xunit; -namespace Bit.Test.Common.AutoFixture.Attributes; - -public class BitMemberAutoDataAttribute : MemberDataAttributeBase +namespace Bit.Test.Common.AutoFixture.Attributes { - private readonly Func _createFixture; - - public BitMemberAutoDataAttribute(string memberName, params object[] parameters) : - this(() => new Fixture(), memberName, parameters) - { } - - public BitMemberAutoDataAttribute(Func createFixture, string memberName, params object[] parameters) : - base(memberName, parameters) + public class BitMemberAutoDataAttribute : MemberDataAttributeBase { - _createFixture = createFixture; - } + private readonly Func _createFixture; - protected override object[] ConvertDataItem(MethodInfo testMethod, object item) => - BitAutoDataAttributeHelpers.GetData(testMethod, _createFixture(), item as object[]).First(); + public BitMemberAutoDataAttribute(string memberName, params object[] parameters) : + this(() => new Fixture(), memberName, parameters) + { } + + public BitMemberAutoDataAttribute(Func createFixture, string memberName, params object[] parameters) : + base(memberName, parameters) + { + _createFixture = createFixture; + } + + protected override object[] ConvertDataItem(MethodInfo testMethod, object item) => + BitAutoDataAttributeHelpers.GetData(testMethod, _createFixture(), item as object[]).First(); + } } diff --git a/test/Common/AutoFixture/Attributes/ControllerCustomizeAttribute.cs b/test/Common/AutoFixture/Attributes/ControllerCustomizeAttribute.cs index 7627562b73..6cab60bae0 100644 --- a/test/Common/AutoFixture/Attributes/ControllerCustomizeAttribute.cs +++ b/test/Common/AutoFixture/Attributes/ControllerCustomizeAttribute.cs @@ -1,22 +1,23 @@ using AutoFixture; -namespace Bit.Test.Common.AutoFixture.Attributes; - -/// -/// Disables setting of Auto Properties on the Controller to avoid ASP.net initialization errors from a mock environment. Still sets constructor dependencies. -/// -public class ControllerCustomizeAttribute : BitCustomizeAttribute +namespace Bit.Test.Common.AutoFixture.Attributes { - private readonly Type _controllerType; - /// - /// Initialize an instance of the ControllerCustomizeAttribute class + /// Disables setting of Auto Properties on the Controller to avoid ASP.net initialization errors from a mock environment. Still sets constructor dependencies. /// - /// The Type of the controller to allow autofixture to create - public ControllerCustomizeAttribute(Type controllerType) + public class ControllerCustomizeAttribute : BitCustomizeAttribute { - _controllerType = controllerType; - } + private readonly Type _controllerType; - public override ICustomization GetCustomization() => new ControllerCustomization(_controllerType); + /// + /// Initialize an instance of the ControllerCustomizeAttribute class + /// + /// The Type of the controller to allow autofixture to create + public ControllerCustomizeAttribute(Type controllerType) + { + _controllerType = controllerType; + } + + public override ICustomization GetCustomization() => new ControllerCustomization(_controllerType); + } } diff --git a/test/Common/AutoFixture/Attributes/CustomAutoDataAttribute.cs b/test/Common/AutoFixture/Attributes/CustomAutoDataAttribute.cs index 6aac53ca34..75308e4487 100644 --- a/test/Common/AutoFixture/Attributes/CustomAutoDataAttribute.cs +++ b/test/Common/AutoFixture/Attributes/CustomAutoDataAttribute.cs @@ -1,22 +1,23 @@ using AutoFixture; using AutoFixture.Xunit2; -namespace Bit.Test.Common.AutoFixture.Attributes; - -public class CustomAutoDataAttribute : AutoDataAttribute +namespace Bit.Test.Common.AutoFixture.Attributes { - public CustomAutoDataAttribute(params Type[] iCustomizationTypes) : this(iCustomizationTypes - .Select(t => (ICustomization)Activator.CreateInstance(t)).ToArray()) - { } - - public CustomAutoDataAttribute(params ICustomization[] customizations) : base(() => + public class CustomAutoDataAttribute : AutoDataAttribute { - var fixture = new Fixture(); - foreach (var customization in customizations) + public CustomAutoDataAttribute(params Type[] iCustomizationTypes) : this(iCustomizationTypes + .Select(t => (ICustomization)Activator.CreateInstance(t)).ToArray()) + { } + + public CustomAutoDataAttribute(params ICustomization[] customizations) : base(() => { - fixture.Customize(customization); - } - return fixture; - }) - { } + var fixture = new Fixture(); + foreach (var customization in customizations) + { + fixture.Customize(customization); + } + return fixture; + }) + { } + } } diff --git a/test/Common/AutoFixture/Attributes/EnvironmentDataAttribute.cs b/test/Common/AutoFixture/Attributes/EnvironmentDataAttribute.cs index acdf737be8..5479d766d2 100644 --- a/test/Common/AutoFixture/Attributes/EnvironmentDataAttribute.cs +++ b/test/Common/AutoFixture/Attributes/EnvironmentDataAttribute.cs @@ -1,42 +1,43 @@ using System.Reflection; using Xunit.Sdk; -namespace Bit.Test.Common.AutoFixture.Attributes; - -/// -/// Used for collecting data from environment useful for when we want to test an integration with another service and -/// it might require an api key or other piece of sensitive data that we don't want slipping into the wrong hands. -/// -/// -/// It probably should be refactored to support fixtures and other customization so it can more easily be used in conjunction -/// with more parameters. Currently it attempt to match environment variable names to values of the parameter type in that positions. -/// It will start from the first parameter and go for each supplied name. -/// -public class EnvironmentDataAttribute : DataAttribute +namespace Bit.Test.Common.AutoFixture.Attributes { - private readonly string[] _environmentVariableNames; - - public EnvironmentDataAttribute(params string[] environmentVariableNames) + /// + /// Used for collecting data from environment useful for when we want to test an integration with another service and + /// it might require an api key or other piece of sensitive data that we don't want slipping into the wrong hands. + /// + /// + /// It probably should be refactored to support fixtures and other customization so it can more easily be used in conjunction + /// with more parameters. Currently it attempt to match environment variable names to values of the parameter type in that positions. + /// It will start from the first parameter and go for each supplied name. + /// + public class EnvironmentDataAttribute : DataAttribute { - _environmentVariableNames = environmentVariableNames; - } + private readonly string[] _environmentVariableNames; - public override IEnumerable GetData(MethodInfo testMethod) - { - var methodParameters = testMethod.GetParameters(); - - if (methodParameters.Length < _environmentVariableNames.Length) + public EnvironmentDataAttribute(params string[] environmentVariableNames) { - throw new ArgumentException($"The target test method only has {methodParameters.Length} arguments but you supplied {_environmentVariableNames.Length}"); + _environmentVariableNames = environmentVariableNames; } - var values = new object[_environmentVariableNames.Length]; - - for (var i = 0; i < _environmentVariableNames.Length; i++) + public override IEnumerable GetData(MethodInfo testMethod) { - values[i] = Convert.ChangeType(Environment.GetEnvironmentVariable(_environmentVariableNames[i]), methodParameters[i].ParameterType); - } + var methodParameters = testMethod.GetParameters(); - return new[] { values }; + if (methodParameters.Length < _environmentVariableNames.Length) + { + throw new ArgumentException($"The target test method only has {methodParameters.Length} arguments but you supplied {_environmentVariableNames.Length}"); + } + + var values = new object[_environmentVariableNames.Length]; + + for (var i = 0; i < _environmentVariableNames.Length; i++) + { + values[i] = Convert.ChangeType(Environment.GetEnvironmentVariable(_environmentVariableNames[i]), methodParameters[i].ParameterType); + } + + return new[] { values }; + } } } diff --git a/test/Common/AutoFixture/Attributes/InlineCustomAutoDataAttribute.cs b/test/Common/AutoFixture/Attributes/InlineCustomAutoDataAttribute.cs index fb16d2f908..b8c27f746e 100644 --- a/test/Common/AutoFixture/Attributes/InlineCustomAutoDataAttribute.cs +++ b/test/Common/AutoFixture/Attributes/InlineCustomAutoDataAttribute.cs @@ -3,19 +3,20 @@ using AutoFixture.Xunit2; using Xunit; using Xunit.Sdk; -namespace Bit.Test.Common.AutoFixture.Attributes; - -public class InlineCustomAutoDataAttribute : CompositeDataAttribute +namespace Bit.Test.Common.AutoFixture.Attributes { - public InlineCustomAutoDataAttribute(Type[] iCustomizationTypes, params object[] values) : base(new DataAttribute[] { - new InlineDataAttribute(values), - new CustomAutoDataAttribute(iCustomizationTypes) - }) - { } + public class InlineCustomAutoDataAttribute : CompositeDataAttribute + { + public InlineCustomAutoDataAttribute(Type[] iCustomizationTypes, params object[] values) : base(new DataAttribute[] { + new InlineDataAttribute(values), + new CustomAutoDataAttribute(iCustomizationTypes) + }) + { } - public InlineCustomAutoDataAttribute(ICustomization[] customizations, params object[] values) : base(new DataAttribute[] { - new InlineDataAttribute(values), - new CustomAutoDataAttribute(customizations) - }) - { } + public InlineCustomAutoDataAttribute(ICustomization[] customizations, params object[] values) : base(new DataAttribute[] { + new InlineDataAttribute(values), + new CustomAutoDataAttribute(customizations) + }) + { } + } } diff --git a/test/Common/AutoFixture/Attributes/InlineSutAutoDataAttribute.cs b/test/Common/AutoFixture/Attributes/InlineSutAutoDataAttribute.cs index ae32b476cb..b2709a3308 100644 --- a/test/Common/AutoFixture/Attributes/InlineSutAutoDataAttribute.cs +++ b/test/Common/AutoFixture/Attributes/InlineSutAutoDataAttribute.cs @@ -1,17 +1,18 @@ using AutoFixture; -namespace Bit.Test.Common.AutoFixture.Attributes; - -public class InlineSutAutoDataAttribute : InlineCustomAutoDataAttribute +namespace Bit.Test.Common.AutoFixture.Attributes { - public InlineSutAutoDataAttribute(params object[] values) : base( - new Type[] { typeof(SutProviderCustomization) }, values) - { } - public InlineSutAutoDataAttribute(Type[] iCustomizationTypes, params object[] values) : base( - iCustomizationTypes.Append(typeof(SutProviderCustomization)).ToArray(), values) - { } + public class InlineSutAutoDataAttribute : InlineCustomAutoDataAttribute + { + public InlineSutAutoDataAttribute(params object[] values) : base( + new Type[] { typeof(SutProviderCustomization) }, values) + { } + public InlineSutAutoDataAttribute(Type[] iCustomizationTypes, params object[] values) : base( + iCustomizationTypes.Append(typeof(SutProviderCustomization)).ToArray(), values) + { } - public InlineSutAutoDataAttribute(ICustomization[] customizations, params object[] values) : base( - customizations.Append(new SutProviderCustomization()).ToArray(), values) - { } + public InlineSutAutoDataAttribute(ICustomization[] customizations, params object[] values) : base( + customizations.Append(new SutProviderCustomization()).ToArray(), values) + { } + } } diff --git a/test/Common/AutoFixture/Attributes/JsonDocumentCustomizeAttribute.cs b/test/Common/AutoFixture/Attributes/JsonDocumentCustomizeAttribute.cs index 41b1dc63b4..d4df0599ab 100644 --- a/test/Common/AutoFixture/Attributes/JsonDocumentCustomizeAttribute.cs +++ b/test/Common/AutoFixture/Attributes/JsonDocumentCustomizeAttribute.cs @@ -1,10 +1,11 @@ using AutoFixture; using Bit.Test.Common.AutoFixture.JsonDocumentFixtures; -namespace Bit.Test.Common.AutoFixture.Attributes; - -public class JsonDocumentCustomizeAttribute : BitCustomizeAttribute +namespace Bit.Test.Common.AutoFixture.Attributes { - public string Json { get; set; } - public override ICustomization GetCustomization() => new JsonDocumentCustomization() { Json = Json }; + public class JsonDocumentCustomizeAttribute : BitCustomizeAttribute + { + public string Json { get; set; } + public override ICustomization GetCustomization() => new JsonDocumentCustomization() { Json = Json }; + } } diff --git a/test/Common/AutoFixture/Attributes/RequiredEnvironmentTheoryAttribute.cs b/test/Common/AutoFixture/Attributes/RequiredEnvironmentTheoryAttribute.cs index 5bb0c3485f..1830010633 100644 --- a/test/Common/AutoFixture/Attributes/RequiredEnvironmentTheoryAttribute.cs +++ b/test/Common/AutoFixture/Attributes/RequiredEnvironmentTheoryAttribute.cs @@ -1,37 +1,38 @@ using Xunit; -namespace Bit.Test.Common.AutoFixture.Attributes; - -/// -/// Used for requiring certain environment variables exist at the time. Mostly used for more edge unit tests that shouldn't -/// be run during CI builds or should only be ran in CI builds when pieces of information are available. -/// -public class RequiredEnvironmentTheoryAttribute : TheoryAttribute +namespace Bit.Test.Common.AutoFixture.Attributes { - private readonly string[] _environmentVariableNames; - - public RequiredEnvironmentTheoryAttribute(params string[] environmentVariableNames) + /// + /// Used for requiring certain environment variables exist at the time. Mostly used for more edge unit tests that shouldn't + /// be run during CI builds or should only be ran in CI builds when pieces of information are available. + /// + public class RequiredEnvironmentTheoryAttribute : TheoryAttribute { - _environmentVariableNames = environmentVariableNames; + private readonly string[] _environmentVariableNames; - if (!HasRequiredEnvironmentVariables()) + public RequiredEnvironmentTheoryAttribute(params string[] environmentVariableNames) { - Skip = $"Missing one or more required environment variables. ({string.Join(", ", _environmentVariableNames)})"; - } - } + _environmentVariableNames = environmentVariableNames; - private bool HasRequiredEnvironmentVariables() - { - foreach (var env in _environmentVariableNames) - { - var value = Environment.GetEnvironmentVariable(env); - - if (value == null) + if (!HasRequiredEnvironmentVariables()) { - return false; + Skip = $"Missing one or more required environment variables. ({string.Join(", ", _environmentVariableNames)})"; } } - return true; + private bool HasRequiredEnvironmentVariables() + { + foreach (var env in _environmentVariableNames) + { + var value = Environment.GetEnvironmentVariable(env); + + if (value == null) + { + return false; + } + } + + return true; + } } } diff --git a/test/Common/AutoFixture/Attributes/SutAutoDataAttribute.cs b/test/Common/AutoFixture/Attributes/SutAutoDataAttribute.cs index a84bc31187..3680f4a667 100644 --- a/test/Common/AutoFixture/Attributes/SutAutoDataAttribute.cs +++ b/test/Common/AutoFixture/Attributes/SutAutoDataAttribute.cs @@ -1,15 +1,16 @@ using AutoFixture; -namespace Bit.Test.Common.AutoFixture.Attributes; - -public class SutProviderCustomizeAttribute : BitCustomizeAttribute +namespace Bit.Test.Common.AutoFixture.Attributes { - public override ICustomization GetCustomization() => new SutProviderCustomization(); -} + public class SutProviderCustomizeAttribute : BitCustomizeAttribute + { + public override ICustomization GetCustomization() => new SutProviderCustomization(); + } -public class SutAutoDataAttribute : CustomAutoDataAttribute -{ - public SutAutoDataAttribute(params Type[] iCustomizationTypes) : base( - iCustomizationTypes.Append(typeof(SutProviderCustomization)).ToArray()) - { } + public class SutAutoDataAttribute : CustomAutoDataAttribute + { + public SutAutoDataAttribute(params Type[] iCustomizationTypes) : base( + iCustomizationTypes.Append(typeof(SutProviderCustomization)).ToArray()) + { } + } } diff --git a/test/Common/AutoFixture/BuilderWithoutAutoProperties.cs b/test/Common/AutoFixture/BuilderWithoutAutoProperties.cs index 039475fadc..b2bdae0d45 100644 --- a/test/Common/AutoFixture/BuilderWithoutAutoProperties.cs +++ b/test/Common/AutoFixture/BuilderWithoutAutoProperties.cs @@ -1,38 +1,39 @@ using AutoFixture; using AutoFixture.Kernel; -namespace Bit.Test.Common.AutoFixture; - -public class BuilderWithoutAutoProperties : ISpecimenBuilder +namespace Bit.Test.Common.AutoFixture { - private readonly Type _type; - public BuilderWithoutAutoProperties(Type type) + public class BuilderWithoutAutoProperties : ISpecimenBuilder { - _type = type; - } - - public object Create(object request, ISpecimenContext context) - { - if (context == null) + private readonly Type _type; + public BuilderWithoutAutoProperties(Type type) { - throw new ArgumentNullException(nameof(context)); + _type = type; } - var type = request as Type; - if (type == null || type != _type) + public object Create(object request, ISpecimenContext context) { - return new NoSpecimen(); - } + if (context == null) + { + throw new ArgumentNullException(nameof(context)); + } - var fixture = new Fixture(); - // This is the equivalent of _fixture.Build<_type>().OmitAutoProperties().Create(request, context), but no overload for - // Build(Type type) exists. - dynamic reflectedComposer = typeof(Fixture).GetMethod("Build").MakeGenericMethod(_type).Invoke(fixture, null); - return reflectedComposer.OmitAutoProperties().Create(request, context); + var type = request as Type; + if (type == null || type != _type) + { + return new NoSpecimen(); + } + + var fixture = new Fixture(); + // This is the equivalent of _fixture.Build<_type>().OmitAutoProperties().Create(request, context), but no overload for + // Build(Type type) exists. + dynamic reflectedComposer = typeof(Fixture).GetMethod("Build").MakeGenericMethod(_type).Invoke(fixture, null); + return reflectedComposer.OmitAutoProperties().Create(request, context); + } + } + public class BuilderWithoutAutoProperties : ISpecimenBuilder + { + public object Create(object request, ISpecimenContext context) => + new BuilderWithoutAutoProperties(typeof(T)).Create(request, context); } } -public class BuilderWithoutAutoProperties : ISpecimenBuilder -{ - public object Create(object request, ISpecimenContext context) => - new BuilderWithoutAutoProperties(typeof(T)).Create(request, context); -} diff --git a/test/Common/AutoFixture/ControllerCustomization.cs b/test/Common/AutoFixture/ControllerCustomization.cs index f695f86b55..9592466aa5 100644 --- a/test/Common/AutoFixture/ControllerCustomization.cs +++ b/test/Common/AutoFixture/ControllerCustomization.cs @@ -2,31 +2,32 @@ using Microsoft.AspNetCore.Mvc; using Org.BouncyCastle.Security; -namespace Bit.Test.Common.AutoFixture; - -/// -/// Disables setting of Auto Properties on the Controller to avoid ASP.net initialization errors. Still sets constructor dependencies. -/// -/// -public class ControllerCustomization : ICustomization +namespace Bit.Test.Common.AutoFixture { - private readonly Type _controllerType; - public ControllerCustomization(Type controllerType) + /// + /// Disables setting of Auto Properties on the Controller to avoid ASP.net initialization errors. Still sets constructor dependencies. + /// + /// + public class ControllerCustomization : ICustomization { - if (!controllerType.IsAssignableTo(typeof(Controller))) + private readonly Type _controllerType; + public ControllerCustomization(Type controllerType) { - throw new InvalidParameterException($"{nameof(controllerType)} must derive from {typeof(Controller).Name}"); + if (!controllerType.IsAssignableTo(typeof(Controller))) + { + throw new InvalidParameterException($"{nameof(controllerType)} must derive from {typeof(Controller).Name}"); + } + + _controllerType = controllerType; } - _controllerType = controllerType; + public void Customize(IFixture fixture) + { + fixture.Customizations.Add(new BuilderWithoutAutoProperties(_controllerType)); + } } - - public void Customize(IFixture fixture) + public class ControllerCustomization : ICustomization where T : Controller { - fixture.Customizations.Add(new BuilderWithoutAutoProperties(_controllerType)); + public void Customize(IFixture fixture) => new ControllerCustomization(typeof(T)).Customize(fixture); } } -public class ControllerCustomization : ICustomization where T : Controller -{ - public void Customize(IFixture fixture) => new ControllerCustomization(typeof(T)).Customize(fixture); -} diff --git a/test/Common/AutoFixture/FixtureExtensions.cs b/test/Common/AutoFixture/FixtureExtensions.cs index 300967666e..162784a356 100644 --- a/test/Common/AutoFixture/FixtureExtensions.cs +++ b/test/Common/AutoFixture/FixtureExtensions.cs @@ -1,13 +1,14 @@ using AutoFixture; using AutoFixture.AutoNSubstitute; -namespace Bit.Test.Common.AutoFixture; - -public static class FixtureExtensions +namespace Bit.Test.Common.AutoFixture { - public static IFixture WithAutoNSubstitutions(this IFixture fixture) - => fixture.Customize(new AutoNSubstituteCustomization()); + public static class FixtureExtensions + { + public static IFixture WithAutoNSubstitutions(this IFixture fixture) + => fixture.Customize(new AutoNSubstituteCustomization()); - public static IFixture WithAutoNSubstitutionsAutoPopulatedProperties(this IFixture fixture) - => fixture.Customize(new AutoNSubstituteCustomization { ConfigureMembers = true }); + public static IFixture WithAutoNSubstitutionsAutoPopulatedProperties(this IFixture fixture) + => fixture.Customize(new AutoNSubstituteCustomization { ConfigureMembers = true }); + } } diff --git a/test/Common/AutoFixture/GlobalSettingsFixtures.cs b/test/Common/AutoFixture/GlobalSettingsFixtures.cs index 3a2a319eec..86f460909c 100644 --- a/test/Common/AutoFixture/GlobalSettingsFixtures.cs +++ b/test/Common/AutoFixture/GlobalSettingsFixtures.cs @@ -1,15 +1,16 @@ using AutoFixture; -namespace Bit.Test.Common.AutoFixture; - -public class GlobalSettings : ICustomization +namespace Bit.Test.Common.AutoFixture { - public void Customize(IFixture fixture) + public class GlobalSettings : ICustomization { - fixture.Customize(composer => composer - .Without(s => s.BaseServiceUri) - .Without(s => s.Attachment) - .Without(s => s.Send) - .Without(s => s.DataProtection)); + public void Customize(IFixture fixture) + { + fixture.Customize(composer => composer + .Without(s => s.BaseServiceUri) + .Without(s => s.Attachment) + .Without(s => s.Send) + .Without(s => s.DataProtection)); + } } } diff --git a/test/Common/AutoFixture/ISutProvider.cs b/test/Common/AutoFixture/ISutProvider.cs index 1ce9c7a005..9f6b0b23a2 100644 --- a/test/Common/AutoFixture/ISutProvider.cs +++ b/test/Common/AutoFixture/ISutProvider.cs @@ -1,7 +1,8 @@ -namespace Bit.Test.Common.AutoFixture; - -public interface ISutProvider +namespace Bit.Test.Common.AutoFixture { - Type SutType { get; } - ISutProvider Create(); + public interface ISutProvider + { + Type SutType { get; } + ISutProvider Create(); + } } diff --git a/test/Common/AutoFixture/JsonDocumentFixtures.cs b/test/Common/AutoFixture/JsonDocumentFixtures.cs index df27aa8ce7..e39b7f9901 100644 --- a/test/Common/AutoFixture/JsonDocumentFixtures.cs +++ b/test/Common/AutoFixture/JsonDocumentFixtures.cs @@ -2,30 +2,31 @@ using AutoFixture; using AutoFixture.Kernel; -namespace Bit.Test.Common.AutoFixture.JsonDocumentFixtures; - -public class JsonDocumentCustomization : ICustomization, ISpecimenBuilder +namespace Bit.Test.Common.AutoFixture.JsonDocumentFixtures { - - public string Json { get; set; } - - public void Customize(IFixture fixture) + public class JsonDocumentCustomization : ICustomization, ISpecimenBuilder { - fixture.Customizations.Add(this); - } - public object Create(object request, ISpecimenContext context) - { - if (context == null) + public string Json { get; set; } + + public void Customize(IFixture fixture) { - throw new ArgumentNullException(nameof(context)); - } - var type = request as Type; - if (type == null || (type != typeof(JsonDocument))) - { - return new NoSpecimen(); + fixture.Customizations.Add(this); } - return JsonDocument.Parse(Json ?? "{}"); + public object Create(object request, ISpecimenContext context) + { + if (context == null) + { + throw new ArgumentNullException(nameof(context)); + } + var type = request as Type; + if (type == null || (type != typeof(JsonDocument))) + { + return new NoSpecimen(); + } + + return JsonDocument.Parse(Json ?? "{}"); + } } } diff --git a/test/Common/AutoFixture/SutProvider.cs b/test/Common/AutoFixture/SutProvider.cs index 3a3d6409ba..2b00ed0cfa 100644 --- a/test/Common/AutoFixture/SutProvider.cs +++ b/test/Common/AutoFixture/SutProvider.cs @@ -2,132 +2,133 @@ using AutoFixture; using AutoFixture.Kernel; -namespace Bit.Test.Common.AutoFixture; - -public class SutProvider : ISutProvider +namespace Bit.Test.Common.AutoFixture { - private Dictionary> _dependencies; - private readonly IFixture _fixture; - private readonly ConstructorParameterRelay _constructorParameterRelay; - - public TSut Sut { get; private set; } - public Type SutType => typeof(TSut); - - public SutProvider() : this(new Fixture()) { } - - public SutProvider(IFixture fixture) + public class SutProvider : ISutProvider { - _dependencies = new Dictionary>(); - _fixture = (fixture ?? new Fixture()).WithAutoNSubstitutions().Customize(new GlobalSettings()); - _constructorParameterRelay = new ConstructorParameterRelay(this, _fixture); - _fixture.Customizations.Add(_constructorParameterRelay); - } + private Dictionary> _dependencies; + private readonly IFixture _fixture; + private readonly ConstructorParameterRelay _constructorParameterRelay; - public SutProvider SetDependency(T dependency, string parameterName = "") - => SetDependency(typeof(T), dependency, parameterName); - public SutProvider SetDependency(Type dependencyType, object dependency, string parameterName = "") - { - if (_dependencies.ContainsKey(dependencyType)) + public TSut Sut { get; private set; } + public Type SutType => typeof(TSut); + + public SutProvider() : this(new Fixture()) { } + + public SutProvider(IFixture fixture) { - _dependencies[dependencyType][parameterName] = dependency; - } - else - { - _dependencies[dependencyType] = new Dictionary { { parameterName, dependency } }; + _dependencies = new Dictionary>(); + _fixture = (fixture ?? new Fixture()).WithAutoNSubstitutions().Customize(new GlobalSettings()); + _constructorParameterRelay = new ConstructorParameterRelay(this, _fixture); + _fixture.Customizations.Add(_constructorParameterRelay); } - return this; - } - - public T GetDependency(string parameterName = "") => (T)GetDependency(typeof(T), parameterName); - public object GetDependency(Type dependencyType, string parameterName = "") - { - if (DependencyIsSet(dependencyType, parameterName)) + public SutProvider SetDependency(T dependency, string parameterName = "") + => SetDependency(typeof(T), dependency, parameterName); + public SutProvider SetDependency(Type dependencyType, object dependency, string parameterName = "") { - return _dependencies[dependencyType][parameterName]; - } - else if (_dependencies.ContainsKey(dependencyType)) - { - var knownDependencies = _dependencies[dependencyType]; - if (knownDependencies.Values.Count == 1) + if (_dependencies.ContainsKey(dependencyType)) { - return _dependencies[dependencyType].Values.Single(); + _dependencies[dependencyType][parameterName] = dependency; } else { - throw new ArgumentException(string.Concat($"Dependency of type {dependencyType.Name} and name ", - $"{parameterName} does not exist. Available dependency names are: ", - string.Join(", ", knownDependencies.Keys))); + _dependencies[dependencyType] = new Dictionary { { parameterName, dependency } }; } - } - else - { - throw new ArgumentException($"Dependency of type {dependencyType.Name} and name {parameterName} has not been set."); - } - } - public void Reset() - { - _dependencies = new Dictionary>(); - Sut = default; - } - - ISutProvider ISutProvider.Create() => Create(); - public SutProvider Create() - { - Sut = _fixture.Create(); - return this; - } - - private bool DependencyIsSet(Type dependencyType, string parameterName = "") - => _dependencies.ContainsKey(dependencyType) && _dependencies[dependencyType].ContainsKey(parameterName); - - private object GetDefault(Type type) => type.IsValueType ? Activator.CreateInstance(type) : null; - - private class ConstructorParameterRelay : ISpecimenBuilder - { - private readonly SutProvider _sutProvider; - private readonly IFixture _fixture; - - public ConstructorParameterRelay(SutProvider sutProvider, IFixture fixture) - { - _sutProvider = sutProvider; - _fixture = fixture; + return this; } - public object Create(object request, ISpecimenContext context) + public T GetDependency(string parameterName = "") => (T)GetDependency(typeof(T), parameterName); + public object GetDependency(Type dependencyType, string parameterName = "") { - if (context == null) + if (DependencyIsSet(dependencyType, parameterName)) { - throw new ArgumentNullException(nameof(context)); + return _dependencies[dependencyType][parameterName]; } - if (!(request is ParameterInfo parameterInfo)) + else if (_dependencies.ContainsKey(dependencyType)) { - return new NoSpecimen(); + var knownDependencies = _dependencies[dependencyType]; + if (knownDependencies.Values.Count == 1) + { + return _dependencies[dependencyType].Values.Single(); + } + else + { + throw new ArgumentException(string.Concat($"Dependency of type {dependencyType.Name} and name ", + $"{parameterName} does not exist. Available dependency names are: ", + string.Join(", ", knownDependencies.Keys))); + } } - if (parameterInfo.Member.DeclaringType != typeof(T) || - parameterInfo.Member.MemberType != MemberTypes.Constructor) + else { - return new NoSpecimen(); + throw new ArgumentException($"Dependency of type {dependencyType.Name} and name {parameterName} has not been set."); + } + } + + public void Reset() + { + _dependencies = new Dictionary>(); + Sut = default; + } + + ISutProvider ISutProvider.Create() => Create(); + public SutProvider Create() + { + Sut = _fixture.Create(); + return this; + } + + private bool DependencyIsSet(Type dependencyType, string parameterName = "") + => _dependencies.ContainsKey(dependencyType) && _dependencies[dependencyType].ContainsKey(parameterName); + + private object GetDefault(Type type) => type.IsValueType ? Activator.CreateInstance(type) : null; + + private class ConstructorParameterRelay : ISpecimenBuilder + { + private readonly SutProvider _sutProvider; + private readonly IFixture _fixture; + + public ConstructorParameterRelay(SutProvider sutProvider, IFixture fixture) + { + _sutProvider = sutProvider; + _fixture = fixture; } - if (_sutProvider.DependencyIsSet(parameterInfo.ParameterType, parameterInfo.Name)) + public object Create(object request, ISpecimenContext context) { - return _sutProvider.GetDependency(parameterInfo.ParameterType, parameterInfo.Name); - } - // Return default type if set - else if (_sutProvider.DependencyIsSet(parameterInfo.ParameterType, "")) - { - return _sutProvider.GetDependency(parameterInfo.ParameterType, ""); - } + if (context == null) + { + throw new ArgumentNullException(nameof(context)); + } + if (!(request is ParameterInfo parameterInfo)) + { + return new NoSpecimen(); + } + if (parameterInfo.Member.DeclaringType != typeof(T) || + parameterInfo.Member.MemberType != MemberTypes.Constructor) + { + return new NoSpecimen(); + } + + if (_sutProvider.DependencyIsSet(parameterInfo.ParameterType, parameterInfo.Name)) + { + return _sutProvider.GetDependency(parameterInfo.ParameterType, parameterInfo.Name); + } + // Return default type if set + else if (_sutProvider.DependencyIsSet(parameterInfo.ParameterType, "")) + { + return _sutProvider.GetDependency(parameterInfo.ParameterType, ""); + } - // This is the equivalent of _fixture.Create, but no overload for - // Create(Type type) exists. - var dependency = new SpecimenContext(_fixture).Resolve(new SeededRequest(parameterInfo.ParameterType, - _sutProvider.GetDefault(parameterInfo.ParameterType))); - _sutProvider.SetDependency(parameterInfo.ParameterType, dependency, parameterInfo.Name); - return dependency; + // This is the equivalent of _fixture.Create, but no overload for + // Create(Type type) exists. + var dependency = new SpecimenContext(_fixture).Resolve(new SeededRequest(parameterInfo.ParameterType, + _sutProvider.GetDefault(parameterInfo.ParameterType))); + _sutProvider.SetDependency(parameterInfo.ParameterType, dependency, parameterInfo.Name); + return dependency; + } } } } diff --git a/test/Common/AutoFixture/SutProviderCustomization.cs b/test/Common/AutoFixture/SutProviderCustomization.cs index 5cbff6a718..1485923945 100644 --- a/test/Common/AutoFixture/SutProviderCustomization.cs +++ b/test/Common/AutoFixture/SutProviderCustomization.cs @@ -1,33 +1,34 @@ using AutoFixture; using AutoFixture.Kernel; -namespace Bit.Test.Common.AutoFixture.Attributes; - -public class SutProviderCustomization : ICustomization, ISpecimenBuilder +namespace Bit.Test.Common.AutoFixture.Attributes { - private IFixture _fixture = null; - - public object Create(object request, ISpecimenContext context) + public class SutProviderCustomization : ICustomization, ISpecimenBuilder { - if (context == null) + private IFixture _fixture = null; + + public object Create(object request, ISpecimenContext context) { - throw new ArgumentNullException(nameof(context)); - } - if (!(request is Type typeRequest)) - { - return new NoSpecimen(); - } - if (!typeof(ISutProvider).IsAssignableFrom(typeRequest)) - { - return new NoSpecimen(); + if (context == null) + { + throw new ArgumentNullException(nameof(context)); + } + if (!(request is Type typeRequest)) + { + return new NoSpecimen(); + } + if (!typeof(ISutProvider).IsAssignableFrom(typeRequest)) + { + return new NoSpecimen(); + } + + return ((ISutProvider)Activator.CreateInstance(typeRequest, _fixture)).Create(); } - return ((ISutProvider)Activator.CreateInstance(typeRequest, _fixture)).Create(); - } - - public void Customize(IFixture fixture) - { - _fixture = fixture; - fixture.Customizations.Add(this); + public void Customize(IFixture fixture) + { + _fixture = fixture; + fixture.Customizations.Add(this); + } } } diff --git a/test/Common/Helpers/AssertHelper.cs b/test/Common/Helpers/AssertHelper.cs index d690837ef0..7239cb6dba 100644 --- a/test/Common/Helpers/AssertHelper.cs +++ b/test/Common/Helpers/AssertHelper.cs @@ -7,222 +7,223 @@ using Microsoft.AspNetCore.Http; using Xunit; using Xunit.Sdk; -namespace Bit.Test.Common.Helpers; - -public static class AssertHelper +namespace Bit.Test.Common.Helpers { - public static void AssertPropertyEqual(object expected, object actual, params string[] excludedPropertyStrings) + public static class AssertHelper { - var relevantExcludedProperties = excludedPropertyStrings.Where(name => !name.Contains('.')).ToList(); - if (expected == null) + public static void AssertPropertyEqual(object expected, object actual, params string[] excludedPropertyStrings) { - Assert.Null(actual); - return; + var relevantExcludedProperties = excludedPropertyStrings.Where(name => !name.Contains('.')).ToList(); + if (expected == null) + { + Assert.Null(actual); + return; + } + + if (actual == null) + { + throw new Exception("Actual object is null but expected is not"); + } + + foreach (var expectedPropInfo in expected.GetType().GetProperties().Where(pi => !relevantExcludedProperties.Contains(pi.Name))) + { + var actualPropInfo = actual.GetType().GetProperty(expectedPropInfo.Name); + + if (actualPropInfo == null) + { + throw new Exception(string.Concat($"Expected actual object to contain a property named {expectedPropInfo.Name}, but it does not\n", + $"Expected:\n{JsonSerializer.Serialize(expected, JsonHelpers.Indented)}\n", + $"Actual:\n{JsonSerializer.Serialize(actual, JsonHelpers.Indented)}")); + } + + if (expectedPropInfo.PropertyType == typeof(string) || expectedPropInfo.PropertyType.IsValueType) + { + Assert.Equal(expectedPropInfo.GetValue(expected), actualPropInfo.GetValue(actual)); + } + else if (expectedPropInfo.PropertyType == typeof(JsonDocument) && actualPropInfo.PropertyType == typeof(JsonDocument)) + { + static string JsonDocString(PropertyInfo info, object obj) => JsonSerializer.Serialize(info.GetValue(obj)); + Assert.Equal(JsonDocString(expectedPropInfo, expected), JsonDocString(actualPropInfo, actual)); + } + else + { + var prefix = $"{expectedPropInfo.PropertyType.Name}."; + var nextExcludedProperties = excludedPropertyStrings.Where(name => name.StartsWith(prefix)) + .Select(name => name[prefix.Length..]).ToArray(); + AssertPropertyEqual(expectedPropInfo.GetValue(expected), actualPropInfo.GetValue(actual), nextExcludedProperties); + } + } } - if (actual == null) + private static Predicate AssertPropertyEqualPredicate(T expected, params string[] excludedPropertyStrings) => (actual) => { - throw new Exception("Actual object is null but expected is not"); + AssertPropertyEqual(expected, actual, excludedPropertyStrings); + return true; + }; + + public static Expression> AssertPropertyEqual(T expected, params string[] excludedPropertyStrings) => + (T actual) => AssertPropertyEqualPredicate(expected, excludedPropertyStrings)(actual); + + private static Predicate> AssertPropertyEqualPredicate(IEnumerable expected, params string[] excludedPropertyStrings) => (actual) => + { + // IEnumerable.Zip doesn't account for different lengths, we need to check this ourselves + if (actual.Count() != expected.Count()) + { + throw new Exception(string.Concat($"Actual IEnumerable does not have the expected length.\n", + $"Expected: {expected.Count()}\n", + $"Actual: {actual.Count()}")); + } + + var elements = expected.Zip(actual); + foreach (var (expectedEl, actualEl) in elements) + { + AssertPropertyEqual(expectedEl, actualEl, excludedPropertyStrings); + } + + return true; + }; + + public static Expression>> AssertPropertyEqual(IEnumerable expected, params string[] excludedPropertyStrings) => + (actual) => AssertPropertyEqualPredicate(expected, excludedPropertyStrings)(actual); + + private static Predicate AssertEqualExpectedPredicate(T expected) => (actual) => + { + Assert.Equal(expected, actual); + return true; + }; + + public static Expression> AssertEqualExpected(T expected) => + (T actual) => AssertEqualExpectedPredicate(expected)(actual); + + public static JsonElement AssertJsonProperty(JsonElement element, string propertyName, JsonValueKind jsonValueKind) + { + if (!element.TryGetProperty(propertyName, out var subElement)) + { + throw new XunitException($"Could not find property by name '{propertyName}'"); + } + + Assert.Equal(jsonValueKind, subElement.ValueKind); + return subElement; } - foreach (var expectedPropInfo in expected.GetType().GetProperties().Where(pi => !relevantExcludedProperties.Contains(pi.Name))) + public static void AssertEqualJson(JsonElement a, JsonElement b) { - var actualPropInfo = actual.GetType().GetProperty(expectedPropInfo.Name); - - if (actualPropInfo == null) + switch (a.ValueKind) { - throw new Exception(string.Concat($"Expected actual object to contain a property named {expectedPropInfo.Name}, but it does not\n", - $"Expected:\n{JsonSerializer.Serialize(expected, JsonHelpers.Indented)}\n", - $"Actual:\n{JsonSerializer.Serialize(actual, JsonHelpers.Indented)}")); - } - - if (expectedPropInfo.PropertyType == typeof(string) || expectedPropInfo.PropertyType.IsValueType) - { - Assert.Equal(expectedPropInfo.GetValue(expected), actualPropInfo.GetValue(actual)); - } - else if (expectedPropInfo.PropertyType == typeof(JsonDocument) && actualPropInfo.PropertyType == typeof(JsonDocument)) - { - static string JsonDocString(PropertyInfo info, object obj) => JsonSerializer.Serialize(info.GetValue(obj)); - Assert.Equal(JsonDocString(expectedPropInfo, expected), JsonDocString(actualPropInfo, actual)); - } - else - { - var prefix = $"{expectedPropInfo.PropertyType.Name}."; - var nextExcludedProperties = excludedPropertyStrings.Where(name => name.StartsWith(prefix)) - .Select(name => name[prefix.Length..]).ToArray(); - AssertPropertyEqual(expectedPropInfo.GetValue(expected), actualPropInfo.GetValue(actual), nextExcludedProperties); + case JsonValueKind.Array: + Assert.Equal(JsonValueKind.Array, b.ValueKind); + AssertEqualJsonArray(a, b); + break; + case JsonValueKind.Object: + Assert.Equal(JsonValueKind.Object, b.ValueKind); + AssertEqualJsonObject(a, b); + break; + case JsonValueKind.False: + Assert.Equal(JsonValueKind.False, b.ValueKind); + break; + case JsonValueKind.True: + Assert.Equal(JsonValueKind.True, b.ValueKind); + break; + case JsonValueKind.Number: + Assert.Equal(JsonValueKind.Number, b.ValueKind); + Assert.Equal(a.GetDouble(), b.GetDouble()); + break; + case JsonValueKind.String: + Assert.Equal(JsonValueKind.String, b.ValueKind); + Assert.Equal(a.GetString(), b.GetString()); + break; + case JsonValueKind.Null: + Assert.Equal(JsonValueKind.Null, b.ValueKind); + break; + default: + throw new XunitException($"Bad JsonValueKind '{a.ValueKind}'"); } } - } - private static Predicate AssertPropertyEqualPredicate(T expected, params string[] excludedPropertyStrings) => (actual) => - { - AssertPropertyEqual(expected, actual, excludedPropertyStrings); - return true; - }; - - public static Expression> AssertPropertyEqual(T expected, params string[] excludedPropertyStrings) => - (T actual) => AssertPropertyEqualPredicate(expected, excludedPropertyStrings)(actual); - - private static Predicate> AssertPropertyEqualPredicate(IEnumerable expected, params string[] excludedPropertyStrings) => (actual) => - { - // IEnumerable.Zip doesn't account for different lengths, we need to check this ourselves - if (actual.Count() != expected.Count()) + private static void AssertEqualJsonObject(JsonElement a, JsonElement b) { - throw new Exception(string.Concat($"Actual IEnumerable does not have the expected length.\n", - $"Expected: {expected.Count()}\n", - $"Actual: {actual.Count()}")); + Debug.Assert(a.ValueKind == JsonValueKind.Object && b.ValueKind == JsonValueKind.Object); + + var aObjectEnumerator = a.EnumerateObject(); + var bObjectEnumerator = b.EnumerateObject(); + + while (true) + { + var aCanMove = aObjectEnumerator.MoveNext(); + var bCanMove = bObjectEnumerator.MoveNext(); + + if (aCanMove) + { + Assert.True(bCanMove, $"a was able to enumerate over object '{a}' but b was NOT able to '{b}'"); + } + else + { + Assert.False(bCanMove, $"a was NOT able to enumerate over object '{a}' but b was able to '{b}'"); + } + + if (aCanMove == false && bCanMove == false) + { + // They both can't continue to enumerate at the same time, that is valid + break; + } + + var aProp = aObjectEnumerator.Current; + var bProp = bObjectEnumerator.Current; + + Assert.Equal(aProp.Name, bProp.Name); + // Recursion! + AssertEqualJson(aProp.Value, bProp.Value); + } } - var elements = expected.Zip(actual); - foreach (var (expectedEl, actualEl) in elements) + private static void AssertEqualJsonArray(JsonElement a, JsonElement b) { - AssertPropertyEqual(expectedEl, actualEl, excludedPropertyStrings); + Debug.Assert(a.ValueKind == JsonValueKind.Array && b.ValueKind == JsonValueKind.Array); + + var aArrayEnumerator = a.EnumerateArray(); + var bArrayEnumerator = b.EnumerateArray(); + + while (true) + { + var aCanMove = aArrayEnumerator.MoveNext(); + var bCanMove = bArrayEnumerator.MoveNext(); + + if (aCanMove) + { + Assert.True(bCanMove, $"a was able to enumerate over array '{a}' but b was NOT able to '{b}'"); + } + else + { + Assert.False(bCanMove, $"a was NOT able to enumerate over array '{a}' but b was able to '{b}'"); + } + + if (aCanMove == false && bCanMove == false) + { + // They both can't continue to enumerate at the same time, that is valid + break; + } + + var aElement = aArrayEnumerator.Current; + var bElement = bArrayEnumerator.Current; + + // Recursion! + AssertEqualJson(aElement, bElement); + } } - return true; - }; - - public static Expression>> AssertPropertyEqual(IEnumerable expected, params string[] excludedPropertyStrings) => - (actual) => AssertPropertyEqualPredicate(expected, excludedPropertyStrings)(actual); - - private static Predicate AssertEqualExpectedPredicate(T expected) => (actual) => - { - Assert.Equal(expected, actual); - return true; - }; - - public static Expression> AssertEqualExpected(T expected) => - (T actual) => AssertEqualExpectedPredicate(expected)(actual); - - public static JsonElement AssertJsonProperty(JsonElement element, string propertyName, JsonValueKind jsonValueKind) - { - if (!element.TryGetProperty(propertyName, out var subElement)) + public async static Task AssertResponseTypeIs(HttpContext context) { - throw new XunitException($"Could not find property by name '{propertyName}'"); + return await JsonSerializer.DeserializeAsync(context.Response.Body); } - Assert.Equal(jsonValueKind, subElement.ValueKind); - return subElement; - } + public static TimeSpan AssertRecent(DateTime dateTime, int skewSeconds = 2) + => AssertRecent(dateTime, TimeSpan.FromSeconds(skewSeconds)); - public static void AssertEqualJson(JsonElement a, JsonElement b) - { - switch (a.ValueKind) + public static TimeSpan AssertRecent(DateTime dateTime, TimeSpan skew) { - case JsonValueKind.Array: - Assert.Equal(JsonValueKind.Array, b.ValueKind); - AssertEqualJsonArray(a, b); - break; - case JsonValueKind.Object: - Assert.Equal(JsonValueKind.Object, b.ValueKind); - AssertEqualJsonObject(a, b); - break; - case JsonValueKind.False: - Assert.Equal(JsonValueKind.False, b.ValueKind); - break; - case JsonValueKind.True: - Assert.Equal(JsonValueKind.True, b.ValueKind); - break; - case JsonValueKind.Number: - Assert.Equal(JsonValueKind.Number, b.ValueKind); - Assert.Equal(a.GetDouble(), b.GetDouble()); - break; - case JsonValueKind.String: - Assert.Equal(JsonValueKind.String, b.ValueKind); - Assert.Equal(a.GetString(), b.GetString()); - break; - case JsonValueKind.Null: - Assert.Equal(JsonValueKind.Null, b.ValueKind); - break; - default: - throw new XunitException($"Bad JsonValueKind '{a.ValueKind}'"); + var difference = DateTime.UtcNow - dateTime; + Assert.True(difference < skew); + return difference; } } - - private static void AssertEqualJsonObject(JsonElement a, JsonElement b) - { - Debug.Assert(a.ValueKind == JsonValueKind.Object && b.ValueKind == JsonValueKind.Object); - - var aObjectEnumerator = a.EnumerateObject(); - var bObjectEnumerator = b.EnumerateObject(); - - while (true) - { - var aCanMove = aObjectEnumerator.MoveNext(); - var bCanMove = bObjectEnumerator.MoveNext(); - - if (aCanMove) - { - Assert.True(bCanMove, $"a was able to enumerate over object '{a}' but b was NOT able to '{b}'"); - } - else - { - Assert.False(bCanMove, $"a was NOT able to enumerate over object '{a}' but b was able to '{b}'"); - } - - if (aCanMove == false && bCanMove == false) - { - // They both can't continue to enumerate at the same time, that is valid - break; - } - - var aProp = aObjectEnumerator.Current; - var bProp = bObjectEnumerator.Current; - - Assert.Equal(aProp.Name, bProp.Name); - // Recursion! - AssertEqualJson(aProp.Value, bProp.Value); - } - } - - private static void AssertEqualJsonArray(JsonElement a, JsonElement b) - { - Debug.Assert(a.ValueKind == JsonValueKind.Array && b.ValueKind == JsonValueKind.Array); - - var aArrayEnumerator = a.EnumerateArray(); - var bArrayEnumerator = b.EnumerateArray(); - - while (true) - { - var aCanMove = aArrayEnumerator.MoveNext(); - var bCanMove = bArrayEnumerator.MoveNext(); - - if (aCanMove) - { - Assert.True(bCanMove, $"a was able to enumerate over array '{a}' but b was NOT able to '{b}'"); - } - else - { - Assert.False(bCanMove, $"a was NOT able to enumerate over array '{a}' but b was able to '{b}'"); - } - - if (aCanMove == false && bCanMove == false) - { - // They both can't continue to enumerate at the same time, that is valid - break; - } - - var aElement = aArrayEnumerator.Current; - var bElement = bArrayEnumerator.Current; - - // Recursion! - AssertEqualJson(aElement, bElement); - } - } - - public async static Task AssertResponseTypeIs(HttpContext context) - { - return await JsonSerializer.DeserializeAsync(context.Response.Body); - } - - public static TimeSpan AssertRecent(DateTime dateTime, int skewSeconds = 2) - => AssertRecent(dateTime, TimeSpan.FromSeconds(skewSeconds)); - - public static TimeSpan AssertRecent(DateTime dateTime, TimeSpan skew) - { - var difference = DateTime.UtcNow - dateTime; - Assert.True(difference < skew); - return difference; - } } diff --git a/test/Common/Helpers/BitAutoDataAttributeHelpers.cs b/test/Common/Helpers/BitAutoDataAttributeHelpers.cs index 32cacc49dc..aae8d72dcc 100644 --- a/test/Common/Helpers/BitAutoDataAttributeHelpers.cs +++ b/test/Common/Helpers/BitAutoDataAttributeHelpers.cs @@ -4,48 +4,49 @@ using AutoFixture.Kernel; using AutoFixture.Xunit2; using Bit.Test.Common.AutoFixture.Attributes; -namespace Bit.Test.Common.Helpers; - -public static class BitAutoDataAttributeHelpers +namespace Bit.Test.Common.Helpers { - public static IEnumerable GetData(MethodInfo testMethod, IFixture fixture, object[] fixedTestParameters) + public static class BitAutoDataAttributeHelpers { - var methodParameters = testMethod.GetParameters(); - var classCustomizations = testMethod.DeclaringType.GetCustomAttributes().Select(attr => attr.GetCustomization()); - var methodCustomizations = testMethod.GetCustomAttributes().Select(attr => attr.GetCustomization()); - - fixedTestParameters = fixedTestParameters ?? Array.Empty(); - - fixture = ApplyCustomizations(ApplyCustomizations(fixture, classCustomizations), methodCustomizations); - var missingParameters = methodParameters.Skip(fixedTestParameters.Length).Select(p => CustomizeAndCreate(p, fixture)); - - return new object[1][] { fixedTestParameters.Concat(missingParameters).ToArray() }; - } - - public static object CustomizeAndCreate(ParameterInfo p, IFixture fixture) - { - var customizations = p.GetCustomAttributes(typeof(CustomizeAttribute), false) - .OfType() - .Select(attr => attr.GetCustomization(p)); - - var context = new SpecimenContext(ApplyCustomizations(fixture, customizations)); - return context.Resolve(p); - } - - public static IFixture ApplyCustomizations(IFixture fixture, IEnumerable customizations) - { - var newFixture = new Fixture(); - - foreach (var customization in fixture.Customizations.Reverse().Select(b => b.ToCustomization())) + public static IEnumerable GetData(MethodInfo testMethod, IFixture fixture, object[] fixedTestParameters) { - newFixture.Customize(customization); + var methodParameters = testMethod.GetParameters(); + var classCustomizations = testMethod.DeclaringType.GetCustomAttributes().Select(attr => attr.GetCustomization()); + var methodCustomizations = testMethod.GetCustomAttributes().Select(attr => attr.GetCustomization()); + + fixedTestParameters = fixedTestParameters ?? Array.Empty(); + + fixture = ApplyCustomizations(ApplyCustomizations(fixture, classCustomizations), methodCustomizations); + var missingParameters = methodParameters.Skip(fixedTestParameters.Length).Select(p => CustomizeAndCreate(p, fixture)); + + return new object[1][] { fixedTestParameters.Concat(missingParameters).ToArray() }; } - foreach (var customization in customizations) + public static object CustomizeAndCreate(ParameterInfo p, IFixture fixture) { - newFixture.Customize(customization); + var customizations = p.GetCustomAttributes(typeof(CustomizeAttribute), false) + .OfType() + .Select(attr => attr.GetCustomization(p)); + + var context = new SpecimenContext(ApplyCustomizations(fixture, customizations)); + return context.Resolve(p); } - return newFixture; + public static IFixture ApplyCustomizations(IFixture fixture, IEnumerable customizations) + { + var newFixture = new Fixture(); + + foreach (var customization in fixture.Customizations.Reverse().Select(b => b.ToCustomization())) + { + newFixture.Customize(customization); + } + + foreach (var customization in customizations) + { + newFixture.Customize(customization); + } + + return newFixture; + } } } diff --git a/test/Common/Helpers/TestCaseHelper.cs b/test/Common/Helpers/TestCaseHelper.cs index 279229fc57..c31d66e17a 100644 --- a/test/Common/Helpers/TestCaseHelper.cs +++ b/test/Common/Helpers/TestCaseHelper.cs @@ -1,44 +1,45 @@ -namespace Bit.Test.Common.Helpers; - -public static class TestCaseHelper +namespace Bit.Test.Common.Helpers { - public static IEnumerable> GetCombinations(params T[] items) + public static class TestCaseHelper { - var count = Math.Pow(2, items.Length); - for (var i = 0; i < count; i++) + public static IEnumerable> GetCombinations(params T[] items) { - var str = Convert.ToString(i, 2).PadLeft(items.Length, '0'); - List combination = new(); - for (var j = 0; j < str.Length; j++) + var count = Math.Pow(2, items.Length); + for (var i = 0; i < count; i++) { - if (str[j] == '1') + var str = Convert.ToString(i, 2).PadLeft(items.Length, '0'); + List combination = new(); + for (var j = 0; j < str.Length; j++) { - combination.Add(items[j]); + if (str[j] == '1') + { + combination.Add(items[j]); + } } + yield return combination; } - yield return combination; - } - } - - public static IEnumerable> GetCombinationsOfMultipleLists(params IEnumerable[] optionLists) - { - if (!optionLists.Any()) - { - yield break; } - foreach (var item in optionLists.First()) + public static IEnumerable> GetCombinationsOfMultipleLists(params IEnumerable[] optionLists) { - var itemArray = new[] { item }; - - if (optionLists.Length == 1) + if (!optionLists.Any()) { - yield return itemArray; + yield break; } - foreach (var nextCombination in GetCombinationsOfMultipleLists(optionLists.Skip(1).ToArray())) + foreach (var item in optionLists.First()) { - yield return itemArray.Concat(nextCombination); + var itemArray = new[] { item }; + + if (optionLists.Length == 1) + { + yield return itemArray; + } + + foreach (var nextCombination in GetCombinationsOfMultipleLists(optionLists.Skip(1).ToArray())) + { + yield return itemArray.Concat(nextCombination); + } } } } diff --git a/test/Common/Test/TestCaseHelperTests.cs b/test/Common/Test/TestCaseHelperTests.cs index 4d18aa76e9..697899813c 100644 --- a/test/Common/Test/TestCaseHelperTests.cs +++ b/test/Common/Test/TestCaseHelperTests.cs @@ -1,50 +1,51 @@ using Bit.Test.Common.Helpers; using Xunit; -namespace Bit.Test.Common.Test; - -public class TestCaseHelperTests +namespace Bit.Test.Common.Test { - [Fact] - public void GetCombinations_EmptyList() + public class TestCaseHelperTests { - Assert.Equal(new[] { Array.Empty() }, TestCaseHelper.GetCombinations(Array.Empty()).ToArray()); - } + [Fact] + public void GetCombinations_EmptyList() + { + Assert.Equal(new[] { Array.Empty() }, TestCaseHelper.GetCombinations(Array.Empty()).ToArray()); + } - [Fact] - public void GetCombinations_OneItemList() - { - Assert.Equal(new[] { Array.Empty(), new[] { 1 } }, TestCaseHelper.GetCombinations(1)); - } + [Fact] + public void GetCombinations_OneItemList() + { + Assert.Equal(new[] { Array.Empty(), new[] { 1 } }, TestCaseHelper.GetCombinations(1)); + } - [Fact] - public void GetCombinations_TwoItemList() - { - Assert.Equal(new[] { Array.Empty(), new[] { 2 }, new[] { 1 }, new[] { 1, 2 } }, TestCaseHelper.GetCombinations(1, 2)); - } + [Fact] + public void GetCombinations_TwoItemList() + { + Assert.Equal(new[] { Array.Empty(), new[] { 2 }, new[] { 1 }, new[] { 1, 2 } }, TestCaseHelper.GetCombinations(1, 2)); + } - [Fact] - public void GetCombinationsOfMultipleLists_OneOne() - { - Assert.Equal(new[] { new object[] { 1, "1" } }, TestCaseHelper.GetCombinationsOfMultipleLists(new object[] { 1 }, new object[] { "1" })); - } + [Fact] + public void GetCombinationsOfMultipleLists_OneOne() + { + Assert.Equal(new[] { new object[] { 1, "1" } }, TestCaseHelper.GetCombinationsOfMultipleLists(new object[] { 1 }, new object[] { "1" })); + } - [Fact] - public void GetCombinationsOfMultipleLists_OneTwo() - { - Assert.Equal(new[] { new object[] { 1, "1" }, new object[] { 1, "2" } }, TestCaseHelper.GetCombinationsOfMultipleLists(new object[] { 1 }, new object[] { "1", "2" })); - } + [Fact] + public void GetCombinationsOfMultipleLists_OneTwo() + { + Assert.Equal(new[] { new object[] { 1, "1" }, new object[] { 1, "2" } }, TestCaseHelper.GetCombinationsOfMultipleLists(new object[] { 1 }, new object[] { "1", "2" })); + } - [Fact] - public void GetCombinationsOfMultipleLists_TwoOne() - { - Assert.Equal(new[] { new object[] { 1, "1" }, new object[] { 2, "1" } }, TestCaseHelper.GetCombinationsOfMultipleLists(new object[] { 1, 2 }, new object[] { "1" })); - } + [Fact] + public void GetCombinationsOfMultipleLists_TwoOne() + { + Assert.Equal(new[] { new object[] { 1, "1" }, new object[] { 2, "1" } }, TestCaseHelper.GetCombinationsOfMultipleLists(new object[] { 1, 2 }, new object[] { "1" })); + } - [Fact] - public void GetCombinationsOfMultipleLists_TwoTwo() - { - Assert.Equal(new[] { new object[] { 1, "1" }, new object[] { 1, "2" }, new object[] { 2, "1" }, new object[] { 2, "2" } }, TestCaseHelper.GetCombinationsOfMultipleLists(new object[] { 1, 2 }, new object[] { "1", "2" })); + [Fact] + public void GetCombinationsOfMultipleLists_TwoTwo() + { + Assert.Equal(new[] { new object[] { 1, "1" }, new object[] { 1, "2" }, new object[] { 2, "1" }, new object[] { 2, "2" } }, TestCaseHelper.GetCombinationsOfMultipleLists(new object[] { 1, 2 }, new object[] { "1", "2" })); + } } } diff --git a/test/Core.Test/AutoFixture/Attributes/CiSkippedTheory.cs b/test/Core.Test/AutoFixture/Attributes/CiSkippedTheory.cs index 269988272b..90981704f7 100644 --- a/test/Core.Test/AutoFixture/Attributes/CiSkippedTheory.cs +++ b/test/Core.Test/AutoFixture/Attributes/CiSkippedTheory.cs @@ -1,13 +1,14 @@ -namespace Bit.Core.Test.AutoFixture.Attributes; - -public sealed class CiSkippedTheory : Xunit.TheoryAttribute +namespace Bit.Core.Test.AutoFixture.Attributes { - private static bool IsGithubActions() => Environment.GetEnvironmentVariable("CI") != null; - public CiSkippedTheory() + public sealed class CiSkippedTheory : Xunit.TheoryAttribute { - if (IsGithubActions()) + private static bool IsGithubActions() => Environment.GetEnvironmentVariable("CI") != null; + public CiSkippedTheory() { - Skip = "Ignore during CI builds"; + if (IsGithubActions()) + { + Skip = "Ignore during CI builds"; + } } } } diff --git a/test/Core.Test/AutoFixture/CipherAttachmentMetaDataFixtures.cs b/test/Core.Test/AutoFixture/CipherAttachmentMetaDataFixtures.cs index ef18dcd5f9..7b41f76bea 100644 --- a/test/Core.Test/AutoFixture/CipherAttachmentMetaDataFixtures.cs +++ b/test/Core.Test/AutoFixture/CipherAttachmentMetaDataFixtures.cs @@ -2,31 +2,32 @@ using AutoFixture.Dsl; using Bit.Core.Models.Data; -namespace Bit.Core.Test.AutoFixture.CipherAttachmentMetaData; - -public class MetaData : ICustomization +namespace Bit.Core.Test.AutoFixture.CipherAttachmentMetaData { - protected virtual IPostprocessComposer ComposerAction(IFixture fixture, - ICustomizationComposer composer) + public class MetaData : ICustomization { - return composer.With(d => d.Size, fixture.Create()); + protected virtual IPostprocessComposer ComposerAction(IFixture fixture, + ICustomizationComposer composer) + { + return composer.With(d => d.Size, fixture.Create()); + } + public void Customize(IFixture fixture) + { + fixture.Customize(composer => ComposerAction(fixture, composer)); + } } - public void Customize(IFixture fixture) + + public class MetaDataWithoutContainer : MetaData { - fixture.Customize(composer => ComposerAction(fixture, composer)); + protected override IPostprocessComposer ComposerAction(IFixture fixture, + ICustomizationComposer composer) => + base.ComposerAction(fixture, composer).With(d => d.ContainerName, (string)null); + } + + public class MetaDataWithoutKey : MetaDataWithoutContainer + { + protected override IPostprocessComposer ComposerAction(IFixture fixture, + ICustomizationComposer composer) => + base.ComposerAction(fixture, composer).Without(d => d.Key); } } - -public class MetaDataWithoutContainer : MetaData -{ - protected override IPostprocessComposer ComposerAction(IFixture fixture, - ICustomizationComposer composer) => - base.ComposerAction(fixture, composer).With(d => d.ContainerName, (string)null); -} - -public class MetaDataWithoutKey : MetaDataWithoutContainer -{ - protected override IPostprocessComposer ComposerAction(IFixture fixture, - ICustomizationComposer composer) => - base.ComposerAction(fixture, composer).Without(d => d.Key); -} diff --git a/test/Core.Test/AutoFixture/CipherFixtures.cs b/test/Core.Test/AutoFixture/CipherFixtures.cs index b4c87ef8a7..5ef2976fbc 100644 --- a/test/Core.Test/AutoFixture/CipherFixtures.cs +++ b/test/Core.Test/AutoFixture/CipherFixtures.cs @@ -3,66 +3,67 @@ using Bit.Core.Entities; using Bit.Test.Common.AutoFixture.Attributes; using Core.Models.Data; -namespace Bit.Core.Test.AutoFixture.CipherFixtures; - -internal class OrganizationCipher : ICustomization +namespace Bit.Core.Test.AutoFixture.CipherFixtures { - public Guid? OrganizationId { get; set; } - public void Customize(IFixture fixture) + internal class OrganizationCipher : ICustomization { - fixture.Customize(composer => composer - .With(c => c.OrganizationId, OrganizationId ?? Guid.NewGuid()) - .Without(c => c.UserId)); - fixture.Customize(composer => composer - .With(c => c.OrganizationId, Guid.NewGuid()) - .Without(c => c.UserId)); + public Guid? OrganizationId { get; set; } + public void Customize(IFixture fixture) + { + fixture.Customize(composer => composer + .With(c => c.OrganizationId, OrganizationId ?? Guid.NewGuid()) + .Without(c => c.UserId)); + fixture.Customize(composer => composer + .With(c => c.OrganizationId, Guid.NewGuid()) + .Without(c => c.UserId)); + } + } + + internal class UserCipher : ICustomization + { + public Guid? UserId { get; set; } + public void Customize(IFixture fixture) + { + fixture.Customize(composer => composer + .With(c => c.UserId, UserId ?? Guid.NewGuid()) + .Without(c => c.OrganizationId)); + fixture.Customize(composer => composer + .With(c => c.UserId, Guid.NewGuid()) + .Without(c => c.OrganizationId)); + } + } + + internal class UserCipherAutoDataAttribute : CustomAutoDataAttribute + { + public UserCipherAutoDataAttribute(string userId = null) : base(new SutProviderCustomization(), + new UserCipher { UserId = userId == null ? (Guid?)null : new Guid(userId) }) + { } + } + internal class InlineUserCipherAutoDataAttribute : InlineCustomAutoDataAttribute + { + public InlineUserCipherAutoDataAttribute(params object[] values) : base(new[] { typeof(SutProviderCustomization), + typeof(UserCipher) }, values) + { } + } + + internal class InlineKnownUserCipherAutoDataAttribute : InlineCustomAutoDataAttribute + { + public InlineKnownUserCipherAutoDataAttribute(string userId, params object[] values) : base(new ICustomization[] + { new SutProviderCustomization(), new UserCipher { UserId = new Guid(userId) } }, values) + { } + } + + internal class OrganizationCipherAutoDataAttribute : CustomAutoDataAttribute + { + public OrganizationCipherAutoDataAttribute(string organizationId = null) : base(new SutProviderCustomization(), + new OrganizationCipher { OrganizationId = organizationId == null ? (Guid?)null : new Guid(organizationId) }) + { } + } + + internal class InlineOrganizationCipherAutoDataAttribute : InlineCustomAutoDataAttribute + { + public InlineOrganizationCipherAutoDataAttribute(params object[] values) : base(new[] { typeof(SutProviderCustomization), + typeof(OrganizationCipher) }, values) + { } } } - -internal class UserCipher : ICustomization -{ - public Guid? UserId { get; set; } - public void Customize(IFixture fixture) - { - fixture.Customize(composer => composer - .With(c => c.UserId, UserId ?? Guid.NewGuid()) - .Without(c => c.OrganizationId)); - fixture.Customize(composer => composer - .With(c => c.UserId, Guid.NewGuid()) - .Without(c => c.OrganizationId)); - } -} - -internal class UserCipherAutoDataAttribute : CustomAutoDataAttribute -{ - public UserCipherAutoDataAttribute(string userId = null) : base(new SutProviderCustomization(), - new UserCipher { UserId = userId == null ? (Guid?)null : new Guid(userId) }) - { } -} -internal class InlineUserCipherAutoDataAttribute : InlineCustomAutoDataAttribute -{ - public InlineUserCipherAutoDataAttribute(params object[] values) : base(new[] { typeof(SutProviderCustomization), - typeof(UserCipher) }, values) - { } -} - -internal class InlineKnownUserCipherAutoDataAttribute : InlineCustomAutoDataAttribute -{ - public InlineKnownUserCipherAutoDataAttribute(string userId, params object[] values) : base(new ICustomization[] - { new SutProviderCustomization(), new UserCipher { UserId = new Guid(userId) } }, values) - { } -} - -internal class OrganizationCipherAutoDataAttribute : CustomAutoDataAttribute -{ - public OrganizationCipherAutoDataAttribute(string organizationId = null) : base(new SutProviderCustomization(), - new OrganizationCipher { OrganizationId = organizationId == null ? (Guid?)null : new Guid(organizationId) }) - { } -} - -internal class InlineOrganizationCipherAutoDataAttribute : InlineCustomAutoDataAttribute -{ - public InlineOrganizationCipherAutoDataAttribute(params object[] values) : base(new[] { typeof(SutProviderCustomization), - typeof(OrganizationCipher) }, values) - { } -} diff --git a/test/Core.Test/AutoFixture/CollectionFixtures.cs b/test/Core.Test/AutoFixture/CollectionFixtures.cs index 26c169a443..38517f5c01 100644 --- a/test/Core.Test/AutoFixture/CollectionFixtures.cs +++ b/test/Core.Test/AutoFixture/CollectionFixtures.cs @@ -1,10 +1,11 @@ using Bit.Core.Test.AutoFixture.OrganizationFixtures; using Bit.Test.Common.AutoFixture.Attributes; -namespace Bit.Core.Test.AutoFixture.CollectionFixtures; - -internal class CollectionAutoDataAttribute : CustomAutoDataAttribute +namespace Bit.Core.Test.AutoFixture.CollectionFixtures { - public CollectionAutoDataAttribute() : base(new SutProviderCustomization(), new OrganizationCustomization()) - { } + internal class CollectionAutoDataAttribute : CustomAutoDataAttribute + { + public CollectionAutoDataAttribute() : base(new SutProviderCustomization(), new OrganizationCustomization()) + { } + } } diff --git a/test/Core.Test/AutoFixture/CurrentContextFixtures.cs b/test/Core.Test/AutoFixture/CurrentContextFixtures.cs index 1949dedd71..90187cf6e7 100644 --- a/test/Core.Test/AutoFixture/CurrentContextFixtures.cs +++ b/test/Core.Test/AutoFixture/CurrentContextFixtures.cs @@ -3,35 +3,36 @@ using AutoFixture.Kernel; using Bit.Core.Context; using Bit.Test.Common.AutoFixture; -namespace Bit.Core.Test.AutoFixture.CurrentContextFixtures; - -internal class CurrentContext : ICustomization +namespace Bit.Core.Test.AutoFixture.CurrentContextFixtures { - public void Customize(IFixture fixture) + internal class CurrentContext : ICustomization { - fixture.Customizations.Add(new CurrentContextBuilder()); - } -} - -internal class CurrentContextBuilder : ISpecimenBuilder -{ - public object Create(object request, ISpecimenContext context) - { - if (context == null) - { - throw new ArgumentNullException(nameof(context)); - } - if (!(request is Type typeRequest)) - { - return new NoSpecimen(); - } - if (typeof(ICurrentContext) != typeRequest) - { - return new NoSpecimen(); - } - - var obj = new Fixture().WithAutoNSubstitutions().Create(); - obj.Organizations = context.Create>(); - return obj; + public void Customize(IFixture fixture) + { + fixture.Customizations.Add(new CurrentContextBuilder()); + } + } + + internal class CurrentContextBuilder : ISpecimenBuilder + { + public object Create(object request, ISpecimenContext context) + { + if (context == null) + { + throw new ArgumentNullException(nameof(context)); + } + if (!(request is Type typeRequest)) + { + return new NoSpecimen(); + } + if (typeof(ICurrentContext) != typeRequest) + { + return new NoSpecimen(); + } + + var obj = new Fixture().WithAutoNSubstitutions().Create(); + obj.Organizations = context.Create>(); + return obj; + } } } diff --git a/test/Core.Test/AutoFixture/GlobalSettingsFixtures.cs b/test/Core.Test/AutoFixture/GlobalSettingsFixtures.cs index b9c053c290..a893840a4d 100644 --- a/test/Core.Test/AutoFixture/GlobalSettingsFixtures.cs +++ b/test/Core.Test/AutoFixture/GlobalSettingsFixtures.cs @@ -4,28 +4,29 @@ using AutoFixture.Kernel; using AutoFixture.Xunit2; using Bit.Core.Test.Helpers.Factories; -namespace Bit.Test.Common.AutoFixture; - -public class GlobalSettingsBuilder : ISpecimenBuilder +namespace Bit.Test.Common.AutoFixture { - public object Create(object request, ISpecimenContext context) + public class GlobalSettingsBuilder : ISpecimenBuilder { - if (context == null) + public object Create(object request, ISpecimenContext context) { - throw new ArgumentNullException(nameof(context)); + if (context == null) + { + throw new ArgumentNullException(nameof(context)); + } + + var pi = request as ParameterInfo; + var fixture = new Fixture(); + + if (pi == null || pi.ParameterType != typeof(Bit.Core.Settings.GlobalSettings)) + return new NoSpecimen(); + + return GlobalSettingsFactory.GlobalSettings; } + } - var pi = request as ParameterInfo; - var fixture = new Fixture(); - - if (pi == null || pi.ParameterType != typeof(Bit.Core.Settings.GlobalSettings)) - return new NoSpecimen(); - - return GlobalSettingsFactory.GlobalSettings; + public class GlobalSettingsCustomizeAttribute : CustomizeAttribute + { + public override ICustomization GetCustomization(ParameterInfo parameter) => new GlobalSettings(); } } - -public class GlobalSettingsCustomizeAttribute : CustomizeAttribute -{ - public override ICustomization GetCustomization(ParameterInfo parameter) => new GlobalSettings(); -} diff --git a/test/Core.Test/AutoFixture/GroupFixtures.cs b/test/Core.Test/AutoFixture/GroupFixtures.cs index 2501bbfc39..07b8f6a678 100644 --- a/test/Core.Test/AutoFixture/GroupFixtures.cs +++ b/test/Core.Test/AutoFixture/GroupFixtures.cs @@ -1,18 +1,19 @@ using Bit.Core.Test.AutoFixture.OrganizationFixtures; using Bit.Test.Common.AutoFixture.Attributes; -namespace Bit.Core.Test.AutoFixture.GroupFixtures; - -internal class GroupOrganizationAutoDataAttribute : CustomAutoDataAttribute +namespace Bit.Core.Test.AutoFixture.GroupFixtures { - public GroupOrganizationAutoDataAttribute() : base( - new SutProviderCustomization(), new OrganizationCustomization { UseGroups = true }) - { } -} + internal class GroupOrganizationAutoDataAttribute : CustomAutoDataAttribute + { + public GroupOrganizationAutoDataAttribute() : base( + new SutProviderCustomization(), new OrganizationCustomization { UseGroups = true }) + { } + } -internal class GroupOrganizationNotUseGroupsAutoDataAttribute : CustomAutoDataAttribute -{ - public GroupOrganizationNotUseGroupsAutoDataAttribute() : base( - new SutProviderCustomization(), new OrganizationCustomization { UseGroups = false }) - { } + internal class GroupOrganizationNotUseGroupsAutoDataAttribute : CustomAutoDataAttribute + { + public GroupOrganizationNotUseGroupsAutoDataAttribute() : base( + new SutProviderCustomization(), new OrganizationCustomization { UseGroups = false }) + { } + } } diff --git a/test/Core.Test/AutoFixture/OrganizationFixtures.cs b/test/Core.Test/AutoFixture/OrganizationFixtures.cs index 0641cb29e5..c496471c1a 100644 --- a/test/Core.Test/AutoFixture/OrganizationFixtures.cs +++ b/test/Core.Test/AutoFixture/OrganizationFixtures.cs @@ -10,176 +10,177 @@ using Bit.Core.Utilities; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; -namespace Bit.Core.Test.AutoFixture.OrganizationFixtures; - -public class OrganizationCustomization : ICustomization +namespace Bit.Core.Test.AutoFixture.OrganizationFixtures { - public bool UseGroups { get; set; } - - public void Customize(IFixture fixture) + public class OrganizationCustomization : ICustomization { - var organizationId = Guid.NewGuid(); - var maxConnections = (short)new Random().Next(10, short.MaxValue); + public bool UseGroups { get; set; } - fixture.Customize(composer => composer - .With(o => o.Id, organizationId) - .With(o => o.MaxCollections, maxConnections) - .With(o => o.UseGroups, UseGroups)); - - fixture.Customize(composer => - composer - .With(c => c.OrganizationId, organizationId) - .Without(o => o.CreationDate) - .Without(o => o.RevisionDate)); - - fixture.Customize(composer => composer.With(g => g.OrganizationId, organizationId)); - } -} - -internal class OrganizationBuilder : ISpecimenBuilder -{ - public object Create(object request, ISpecimenContext context) - { - if (context == null) + public void Customize(IFixture fixture) { - throw new ArgumentNullException(nameof(context)); + var organizationId = Guid.NewGuid(); + var maxConnections = (short)new Random().Next(10, short.MaxValue); + + fixture.Customize(composer => composer + .With(o => o.Id, organizationId) + .With(o => o.MaxCollections, maxConnections) + .With(o => o.UseGroups, UseGroups)); + + fixture.Customize(composer => + composer + .With(c => c.OrganizationId, organizationId) + .Without(o => o.CreationDate) + .Without(o => o.RevisionDate)); + + fixture.Customize(composer => composer.With(g => g.OrganizationId, organizationId)); } + } - var type = request as Type; - if (type == null || type != typeof(Organization)) + internal class OrganizationBuilder : ISpecimenBuilder + { + public object Create(object request, ISpecimenContext context) { - return new NoSpecimen(); + if (context == null) + { + throw new ArgumentNullException(nameof(context)); + } + + var type = request as Type; + if (type == null || type != typeof(Organization)) + { + return new NoSpecimen(); + } + + var fixture = new Fixture(); + var providers = fixture.Create>(); + var organization = new Fixture().WithAutoNSubstitutions().Create(); + organization.SetTwoFactorProviders(providers); + return organization; } - - var fixture = new Fixture(); - var providers = fixture.Create>(); - var organization = new Fixture().WithAutoNSubstitutions().Create(); - organization.SetTwoFactorProviders(providers); - return organization; } -} -internal class PaidOrganization : ICustomization -{ - public PlanType CheckedPlanType { get; set; } - public void Customize(IFixture fixture) + internal class PaidOrganization : ICustomization { - var validUpgradePlans = StaticStore.Plans.Where(p => p.Type != PlanType.Free && !p.Disabled).Select(p => p.Type).ToList(); - var lowestActivePaidPlan = validUpgradePlans.First(); - CheckedPlanType = CheckedPlanType.Equals(PlanType.Free) ? lowestActivePaidPlan : CheckedPlanType; - validUpgradePlans.Remove(lowestActivePaidPlan); - fixture.Customize(composer => composer - .With(o => o.PlanType, CheckedPlanType)); - fixture.Customize(composer => composer - .With(ou => ou.Plan, validUpgradePlans.First())); - } -} - -internal class FreeOrganization : ICustomization -{ - public void Customize(IFixture fixture) - { - fixture.Customize(composer => composer - .With(o => o.PlanType, PlanType.Free)); - } -} - -internal class FreeOrganizationUpgrade : ICustomization -{ - public void Customize(IFixture fixture) - { - fixture.Customize(composer => composer - .With(o => o.PlanType, PlanType.Free)); - - var plansToIgnore = new List { PlanType.Free, PlanType.Custom }; - var selectedPlan = StaticStore.Plans.Last(p => !plansToIgnore.Contains(p.Type) && !p.Disabled); - - fixture.Customize(composer => composer - .With(ou => ou.Plan, selectedPlan.Type) - .With(ou => ou.PremiumAccessAddon, selectedPlan.HasPremiumAccessOption)); - fixture.Customize(composer => composer - .Without(o => o.GatewaySubscriptionId)); - } -} - -internal class OrganizationInvite : ICustomization -{ - public OrganizationUserType InviteeUserType { get; set; } - public OrganizationUserType InvitorUserType { get; set; } - public string PermissionsBlob { get; set; } - public void Customize(IFixture fixture) - { - var organizationId = new Guid(); - PermissionsBlob = PermissionsBlob ?? JsonSerializer.Serialize(new Permissions(), new JsonSerializerOptions + public PlanType CheckedPlanType { get; set; } + public void Customize(IFixture fixture) { - PropertyNamingPolicy = JsonNamingPolicy.CamelCase, - }); - fixture.Customize(composer => composer - .With(o => o.Id, organizationId) - .With(o => o.Seats, (short)100)); - fixture.Customize(composer => composer - .With(ou => ou.OrganizationId, organizationId) - .With(ou => ou.Type, InvitorUserType) - .With(ou => ou.Permissions, PermissionsBlob)); - fixture.Customize(composer => composer - .With(oi => oi.Type, InviteeUserType)); + var validUpgradePlans = StaticStore.Plans.Where(p => p.Type != PlanType.Free && !p.Disabled).Select(p => p.Type).ToList(); + var lowestActivePaidPlan = validUpgradePlans.First(); + CheckedPlanType = CheckedPlanType.Equals(PlanType.Free) ? lowestActivePaidPlan : CheckedPlanType; + validUpgradePlans.Remove(lowestActivePaidPlan); + fixture.Customize(composer => composer + .With(o => o.PlanType, CheckedPlanType)); + fixture.Customize(composer => composer + .With(ou => ou.Plan, validUpgradePlans.First())); + } + } + + internal class FreeOrganization : ICustomization + { + public void Customize(IFixture fixture) + { + fixture.Customize(composer => composer + .With(o => o.PlanType, PlanType.Free)); + } + } + + internal class FreeOrganizationUpgrade : ICustomization + { + public void Customize(IFixture fixture) + { + fixture.Customize(composer => composer + .With(o => o.PlanType, PlanType.Free)); + + var plansToIgnore = new List { PlanType.Free, PlanType.Custom }; + var selectedPlan = StaticStore.Plans.Last(p => !plansToIgnore.Contains(p.Type) && !p.Disabled); + + fixture.Customize(composer => composer + .With(ou => ou.Plan, selectedPlan.Type) + .With(ou => ou.PremiumAccessAddon, selectedPlan.HasPremiumAccessOption)); + fixture.Customize(composer => composer + .Without(o => o.GatewaySubscriptionId)); + } + } + + internal class OrganizationInvite : ICustomization + { + public OrganizationUserType InviteeUserType { get; set; } + public OrganizationUserType InvitorUserType { get; set; } + public string PermissionsBlob { get; set; } + public void Customize(IFixture fixture) + { + var organizationId = new Guid(); + PermissionsBlob = PermissionsBlob ?? JsonSerializer.Serialize(new Permissions(), new JsonSerializerOptions + { + PropertyNamingPolicy = JsonNamingPolicy.CamelCase, + }); + fixture.Customize(composer => composer + .With(o => o.Id, organizationId) + .With(o => o.Seats, (short)100)); + fixture.Customize(composer => composer + .With(ou => ou.OrganizationId, organizationId) + .With(ou => ou.Type, InvitorUserType) + .With(ou => ou.Permissions, PermissionsBlob)); + fixture.Customize(composer => composer + .With(oi => oi.Type, InviteeUserType)); + } + } + + internal class PaidOrganizationAutoDataAttribute : CustomAutoDataAttribute + { + public PaidOrganizationAutoDataAttribute(PlanType planType) : base(new SutProviderCustomization(), + new PaidOrganization { CheckedPlanType = planType }) + { } + public PaidOrganizationAutoDataAttribute(int planType = 0) : this((PlanType)planType) { } + } + + internal class InlinePaidOrganizationAutoDataAttribute : InlineCustomAutoDataAttribute + { + public InlinePaidOrganizationAutoDataAttribute(PlanType planType, object[] values) : base( + new ICustomization[] { new SutProviderCustomization(), new PaidOrganization { CheckedPlanType = planType } }, values) + { } + + public InlinePaidOrganizationAutoDataAttribute(params object[] values) : base(new[] { typeof(SutProviderCustomization), + typeof(PaidOrganization) }, values) + { } + } + + internal class InlineFreeOrganizationAutoDataAttribute : InlineCustomAutoDataAttribute + { + public InlineFreeOrganizationAutoDataAttribute(params object[] values) : base(new[] { typeof(SutProviderCustomization), + typeof(FreeOrganization) }, values) + { } + } + + internal class FreeOrganizationUpgradeAutoDataAttribute : CustomAutoDataAttribute + { + public FreeOrganizationUpgradeAutoDataAttribute() : base(new SutProviderCustomization(), new FreeOrganizationUpgrade()) + { } + } + + internal class InlineFreeOrganizationUpgradeAutoDataAttribute : InlineCustomAutoDataAttribute + { + public InlineFreeOrganizationUpgradeAutoDataAttribute(params object[] values) : base(new[] { typeof(SutProviderCustomization), + typeof(FreeOrganizationUpgrade) }, values) + { } + } + + internal class OrganizationInviteAutoDataAttribute : CustomAutoDataAttribute + { + public OrganizationInviteAutoDataAttribute(int inviteeUserType = 0, int invitorUserType = 0, string permissionsBlob = null) : base(new SutProviderCustomization(), + new OrganizationInvite + { + InviteeUserType = (OrganizationUserType)inviteeUserType, + InvitorUserType = (OrganizationUserType)invitorUserType, + PermissionsBlob = permissionsBlob, + }) + { } + } + + internal class InlineOrganizationInviteAutoDataAttribute : InlineCustomAutoDataAttribute + { + public InlineOrganizationInviteAutoDataAttribute(params object[] values) : base(new[] { typeof(SutProviderCustomization), + typeof(OrganizationInvite) }, values) + { } } } - -internal class PaidOrganizationAutoDataAttribute : CustomAutoDataAttribute -{ - public PaidOrganizationAutoDataAttribute(PlanType planType) : base(new SutProviderCustomization(), - new PaidOrganization { CheckedPlanType = planType }) - { } - public PaidOrganizationAutoDataAttribute(int planType = 0) : this((PlanType)planType) { } -} - -internal class InlinePaidOrganizationAutoDataAttribute : InlineCustomAutoDataAttribute -{ - public InlinePaidOrganizationAutoDataAttribute(PlanType planType, object[] values) : base( - new ICustomization[] { new SutProviderCustomization(), new PaidOrganization { CheckedPlanType = planType } }, values) - { } - - public InlinePaidOrganizationAutoDataAttribute(params object[] values) : base(new[] { typeof(SutProviderCustomization), - typeof(PaidOrganization) }, values) - { } -} - -internal class InlineFreeOrganizationAutoDataAttribute : InlineCustomAutoDataAttribute -{ - public InlineFreeOrganizationAutoDataAttribute(params object[] values) : base(new[] { typeof(SutProviderCustomization), - typeof(FreeOrganization) }, values) - { } -} - -internal class FreeOrganizationUpgradeAutoDataAttribute : CustomAutoDataAttribute -{ - public FreeOrganizationUpgradeAutoDataAttribute() : base(new SutProviderCustomization(), new FreeOrganizationUpgrade()) - { } -} - -internal class InlineFreeOrganizationUpgradeAutoDataAttribute : InlineCustomAutoDataAttribute -{ - public InlineFreeOrganizationUpgradeAutoDataAttribute(params object[] values) : base(new[] { typeof(SutProviderCustomization), - typeof(FreeOrganizationUpgrade) }, values) - { } -} - -internal class OrganizationInviteAutoDataAttribute : CustomAutoDataAttribute -{ - public OrganizationInviteAutoDataAttribute(int inviteeUserType = 0, int invitorUserType = 0, string permissionsBlob = null) : base(new SutProviderCustomization(), - new OrganizationInvite - { - InviteeUserType = (OrganizationUserType)inviteeUserType, - InvitorUserType = (OrganizationUserType)invitorUserType, - PermissionsBlob = permissionsBlob, - }) - { } -} - -internal class InlineOrganizationInviteAutoDataAttribute : InlineCustomAutoDataAttribute -{ - public InlineOrganizationInviteAutoDataAttribute(params object[] values) : base(new[] { typeof(SutProviderCustomization), - typeof(OrganizationInvite) }, values) - { } -} diff --git a/test/Core.Test/AutoFixture/OrganizationLicenseCustomization.cs b/test/Core.Test/AutoFixture/OrganizationLicenseCustomization.cs index 66a7f52249..11c8cd8cb0 100644 --- a/test/Core.Test/AutoFixture/OrganizationLicenseCustomization.cs +++ b/test/Core.Test/AutoFixture/OrganizationLicenseCustomization.cs @@ -2,17 +2,18 @@ using Bit.Core.Models.Business; using Bit.Test.Common.AutoFixture.Attributes; -namespace Bit.Core.Test.AutoFixture; - -public class OrganizationLicenseCustomizeAttribute : BitCustomizeAttribute +namespace Bit.Core.Test.AutoFixture { - public override ICustomization GetCustomization() => new OrganizationLicenseCustomization(); -} -public class OrganizationLicenseCustomization : ICustomization -{ - public void Customize(IFixture fixture) + public class OrganizationLicenseCustomizeAttribute : BitCustomizeAttribute { - fixture.Customize(composer => composer - .With(o => o.Signature, Guid.NewGuid().ToString().Replace('-', '+'))); + public override ICustomization GetCustomization() => new OrganizationLicenseCustomization(); + } + public class OrganizationLicenseCustomization : ICustomization + { + public void Customize(IFixture fixture) + { + fixture.Customize(composer => composer + .With(o => o.Signature, Guid.NewGuid().ToString().Replace('-', '+'))); + } } } diff --git a/test/Core.Test/AutoFixture/OrganizationSponsorshipFixtures.cs b/test/Core.Test/AutoFixture/OrganizationSponsorshipFixtures.cs index b9172ae707..a40c15917e 100644 --- a/test/Core.Test/AutoFixture/OrganizationSponsorshipFixtures.cs +++ b/test/Core.Test/AutoFixture/OrganizationSponsorshipFixtures.cs @@ -2,31 +2,32 @@ using Bit.Core.Entities; using Bit.Test.Common.AutoFixture.Attributes; -namespace Bit.Core.Test.AutoFixture.OrganizationSponsorshipFixtures; - -public class OrganizationSponsorshipCustomizeAttribute : BitCustomizeAttribute +namespace Bit.Core.Test.AutoFixture.OrganizationSponsorshipFixtures { - public bool ToDelete = false; - public override ICustomization GetCustomization() => ToDelete ? - new ToDeleteOrganizationSponsorship() : - new ValidOrganizationSponsorship(); -} - -public class ValidOrganizationSponsorship : ICustomization -{ - public void Customize(IFixture fixture) + public class OrganizationSponsorshipCustomizeAttribute : BitCustomizeAttribute { - fixture.Customize(composer => composer - .With(s => s.ToDelete, false) - .With(s => s.LastSyncDate, DateTime.UtcNow.AddDays(new Random().Next(-90, 0)))); - } -} - -public class ToDeleteOrganizationSponsorship : ICustomization -{ - public void Customize(IFixture fixture) - { - fixture.Customize(composer => composer - .With(s => s.ToDelete, true)); + public bool ToDelete = false; + public override ICustomization GetCustomization() => ToDelete ? + new ToDeleteOrganizationSponsorship() : + new ValidOrganizationSponsorship(); + } + + public class ValidOrganizationSponsorship : ICustomization + { + public void Customize(IFixture fixture) + { + fixture.Customize(composer => composer + .With(s => s.ToDelete, false) + .With(s => s.LastSyncDate, DateTime.UtcNow.AddDays(new Random().Next(-90, 0)))); + } + } + + public class ToDeleteOrganizationSponsorship : ICustomization + { + public void Customize(IFixture fixture) + { + fixture.Customize(composer => composer + .With(s => s.ToDelete, true)); + } } } diff --git a/test/Core.Test/AutoFixture/OrganizationUserFixtures.cs b/test/Core.Test/AutoFixture/OrganizationUserFixtures.cs index 74bdbfc519..975b45313a 100644 --- a/test/Core.Test/AutoFixture/OrganizationUserFixtures.cs +++ b/test/Core.Test/AutoFixture/OrganizationUserFixtures.cs @@ -4,42 +4,43 @@ using AutoFixture.Xunit2; using Bit.Core.Entities; using Bit.Core.Enums; -namespace Bit.Core.Test.AutoFixture.OrganizationUserFixtures; - -public class OrganizationUserCustomization : ICustomization +namespace Bit.Core.Test.AutoFixture.OrganizationUserFixtures { - public OrganizationUserStatusType Status { get; set; } - public OrganizationUserType Type { get; set; } - - public OrganizationUserCustomization(OrganizationUserStatusType status, OrganizationUserType type) + public class OrganizationUserCustomization : ICustomization { - Status = status; - Type = type; + public OrganizationUserStatusType Status { get; set; } + public OrganizationUserType Type { get; set; } + + public OrganizationUserCustomization(OrganizationUserStatusType status, OrganizationUserType type) + { + Status = status; + Type = type; + } + + public void Customize(IFixture fixture) + { + fixture.Customize(composer => composer + .With(o => o.Type, Type) + .With(o => o.Status, Status)); + } } - public void Customize(IFixture fixture) + public class OrganizationUserAttribute : CustomizeAttribute { - fixture.Customize(composer => composer - .With(o => o.Type, Type) - .With(o => o.Status, Status)); - } -} - -public class OrganizationUserAttribute : CustomizeAttribute -{ - private readonly OrganizationUserStatusType _status; - private readonly OrganizationUserType _type; - - public OrganizationUserAttribute( - OrganizationUserStatusType status = OrganizationUserStatusType.Confirmed, - OrganizationUserType type = OrganizationUserType.User) - { - _status = status; - _type = type; - } - - public override ICustomization GetCustomization(ParameterInfo parameter) - { - return new OrganizationUserCustomization(_status, _type); + private readonly OrganizationUserStatusType _status; + private readonly OrganizationUserType _type; + + public OrganizationUserAttribute( + OrganizationUserStatusType status = OrganizationUserStatusType.Confirmed, + OrganizationUserType type = OrganizationUserType.User) + { + _status = status; + _type = type; + } + + public override ICustomization GetCustomization(ParameterInfo parameter) + { + return new OrganizationUserCustomization(_status, _type); + } } } diff --git a/test/Core.Test/AutoFixture/PolicyFixtures.cs b/test/Core.Test/AutoFixture/PolicyFixtures.cs index fb8109baf9..b3da0e6982 100644 --- a/test/Core.Test/AutoFixture/PolicyFixtures.cs +++ b/test/Core.Test/AutoFixture/PolicyFixtures.cs @@ -4,37 +4,38 @@ using AutoFixture.Xunit2; using Bit.Core.Entities; using Bit.Core.Enums; -namespace Bit.Core.Test.AutoFixture.PolicyFixtures; - -internal class PolicyCustomization : ICustomization +namespace Bit.Core.Test.AutoFixture.PolicyFixtures { - public PolicyType Type { get; set; } - - public PolicyCustomization(PolicyType type) + internal class PolicyCustomization : ICustomization { - Type = type; + public PolicyType Type { get; set; } + + public PolicyCustomization(PolicyType type) + { + Type = type; + } + + public void Customize(IFixture fixture) + { + fixture.Customize(composer => composer + .With(o => o.OrganizationId, Guid.NewGuid()) + .With(o => o.Type, Type) + .With(o => o.Enabled, true)); + } } - public void Customize(IFixture fixture) + public class PolicyAttribute : CustomizeAttribute { - fixture.Customize(composer => composer - .With(o => o.OrganizationId, Guid.NewGuid()) - .With(o => o.Type, Type) - .With(o => o.Enabled, true)); - } -} - -public class PolicyAttribute : CustomizeAttribute -{ - private readonly PolicyType _type; - - public PolicyAttribute(PolicyType type) - { - _type = type; - } - - public override ICustomization GetCustomization(ParameterInfo parameter) - { - return new PolicyCustomization(_type); + private readonly PolicyType _type; + + public PolicyAttribute(PolicyType type) + { + _type = type; + } + + public override ICustomization GetCustomization(ParameterInfo parameter) + { + return new PolicyCustomization(_type); + } } } diff --git a/test/Core.Test/AutoFixture/SendFixtures.cs b/test/Core.Test/AutoFixture/SendFixtures.cs index b7cdeeafd4..573f32288a 100644 --- a/test/Core.Test/AutoFixture/SendFixtures.cs +++ b/test/Core.Test/AutoFixture/SendFixtures.cs @@ -2,62 +2,63 @@ using Bit.Core.Entities; using Bit.Test.Common.AutoFixture.Attributes; -namespace Bit.Core.Test.AutoFixture.SendFixtures; - -internal class OrganizationSend : ICustomization +namespace Bit.Core.Test.AutoFixture.SendFixtures { - public Guid? OrganizationId { get; set; } - public void Customize(IFixture fixture) + internal class OrganizationSend : ICustomization { - fixture.Customize(composer => composer - .With(s => s.OrganizationId, OrganizationId ?? Guid.NewGuid()) - .Without(s => s.UserId)); + public Guid? OrganizationId { get; set; } + public void Customize(IFixture fixture) + { + fixture.Customize(composer => composer + .With(s => s.OrganizationId, OrganizationId ?? Guid.NewGuid()) + .Without(s => s.UserId)); + } + } + + internal class UserSend : ICustomization + { + public Guid? UserId { get; set; } + public void Customize(IFixture fixture) + { + fixture.Customize(composer => composer + .With(s => s.UserId, UserId ?? Guid.NewGuid()) + .Without(s => s.OrganizationId)); + } + } + + internal class UserSendAutoDataAttribute : CustomAutoDataAttribute + { + public UserSendAutoDataAttribute(string userId = null) : base(new SutProviderCustomization(), + new UserSend { UserId = userId == null ? (Guid?)null : new Guid(userId) }) + { } + } + internal class InlineUserSendAutoDataAttribute : InlineCustomAutoDataAttribute + { + public InlineUserSendAutoDataAttribute(params object[] values) : base(new[] { typeof(CurrentContextFixtures.CurrentContext), + typeof(SutProviderCustomization), typeof(UserSend) }, values) + { } + } + + internal class InlineKnownUserSendAutoDataAttribute : InlineCustomAutoDataAttribute + { + public InlineKnownUserSendAutoDataAttribute(string userId, params object[] values) : base(new ICustomization[] + { new CurrentContextFixtures.CurrentContext(), new SutProviderCustomization(), + new UserSend { UserId = new Guid(userId) } }, values) + { } + } + + internal class OrganizationSendAutoDataAttribute : CustomAutoDataAttribute + { + public OrganizationSendAutoDataAttribute(string organizationId = null) : base(new CurrentContextFixtures.CurrentContext(), + new SutProviderCustomization(), + new OrganizationSend { OrganizationId = organizationId == null ? (Guid?)null : new Guid(organizationId) }) + { } + } + + internal class InlineOrganizationSendAutoDataAttribute : InlineCustomAutoDataAttribute + { + public InlineOrganizationSendAutoDataAttribute(params object[] values) : base(new[] { typeof(CurrentContextFixtures.CurrentContext), + typeof(SutProviderCustomization), typeof(OrganizationSend) }, values) + { } } } - -internal class UserSend : ICustomization -{ - public Guid? UserId { get; set; } - public void Customize(IFixture fixture) - { - fixture.Customize(composer => composer - .With(s => s.UserId, UserId ?? Guid.NewGuid()) - .Without(s => s.OrganizationId)); - } -} - -internal class UserSendAutoDataAttribute : CustomAutoDataAttribute -{ - public UserSendAutoDataAttribute(string userId = null) : base(new SutProviderCustomization(), - new UserSend { UserId = userId == null ? (Guid?)null : new Guid(userId) }) - { } -} -internal class InlineUserSendAutoDataAttribute : InlineCustomAutoDataAttribute -{ - public InlineUserSendAutoDataAttribute(params object[] values) : base(new[] { typeof(CurrentContextFixtures.CurrentContext), - typeof(SutProviderCustomization), typeof(UserSend) }, values) - { } -} - -internal class InlineKnownUserSendAutoDataAttribute : InlineCustomAutoDataAttribute -{ - public InlineKnownUserSendAutoDataAttribute(string userId, params object[] values) : base(new ICustomization[] - { new CurrentContextFixtures.CurrentContext(), new SutProviderCustomization(), - new UserSend { UserId = new Guid(userId) } }, values) - { } -} - -internal class OrganizationSendAutoDataAttribute : CustomAutoDataAttribute -{ - public OrganizationSendAutoDataAttribute(string organizationId = null) : base(new CurrentContextFixtures.CurrentContext(), - new SutProviderCustomization(), - new OrganizationSend { OrganizationId = organizationId == null ? (Guid?)null : new Guid(organizationId) }) - { } -} - -internal class InlineOrganizationSendAutoDataAttribute : InlineCustomAutoDataAttribute -{ - public InlineOrganizationSendAutoDataAttribute(params object[] values) : base(new[] { typeof(CurrentContextFixtures.CurrentContext), - typeof(SutProviderCustomization), typeof(OrganizationSend) }, values) - { } -} diff --git a/test/Core.Test/AutoFixture/UserFixtures.cs b/test/Core.Test/AutoFixture/UserFixtures.cs index 39221aafc5..98707938a2 100644 --- a/test/Core.Test/AutoFixture/UserFixtures.cs +++ b/test/Core.Test/AutoFixture/UserFixtures.cs @@ -6,48 +6,49 @@ using Bit.Core.Models; using Bit.Core.Test.AutoFixture.OrganizationFixtures; using Bit.Test.Common.AutoFixture; -namespace Bit.Core.Test.AutoFixture.UserFixtures; - -public class UserBuilder : ISpecimenBuilder +namespace Bit.Core.Test.AutoFixture.UserFixtures { - public object Create(object request, ISpecimenContext context) + public class UserBuilder : ISpecimenBuilder { - if (context == null) + public object Create(object request, ISpecimenContext context) { - throw new ArgumentNullException(nameof(context)); - } - - var type = request as Type; - if (type == typeof(User)) - { - var fixture = new Fixture(); - var providers = fixture.Create>(); - var user = fixture.WithAutoNSubstitutions().Create(); - user.SetTwoFactorProviders(providers); - return user; - } - else if (type == typeof(List)) - { - var fixture = new Fixture(); - var users = fixture.WithAutoNSubstitutions().CreateMany(2); - foreach (var user in users) + if (context == null) { - var providers = fixture.Create>(); - user.SetTwoFactorProviders(providers); + throw new ArgumentNullException(nameof(context)); } - return users; + + var type = request as Type; + if (type == typeof(User)) + { + var fixture = new Fixture(); + var providers = fixture.Create>(); + var user = fixture.WithAutoNSubstitutions().Create(); + user.SetTwoFactorProviders(providers); + return user; + } + else if (type == typeof(List)) + { + var fixture = new Fixture(); + var users = fixture.WithAutoNSubstitutions().CreateMany(2); + foreach (var user in users) + { + var providers = fixture.Create>(); + user.SetTwoFactorProviders(providers); + } + return users; + } + + return new NoSpecimen(); } - - return new NoSpecimen(); } -} -public class UserFixture : ICustomization -{ - public virtual void Customize(IFixture fixture) + public class UserFixture : ICustomization { - fixture.Customizations.Add(new GlobalSettingsBuilder()); - fixture.Customizations.Add(new UserBuilder()); - fixture.Customizations.Add(new OrganizationBuilder()); + public virtual void Customize(IFixture fixture) + { + fixture.Customizations.Add(new GlobalSettingsBuilder()); + fixture.Customizations.Add(new UserBuilder()); + fixture.Customizations.Add(new OrganizationBuilder()); + } } } diff --git a/test/Core.Test/Entities/OrganizationTests.cs b/test/Core.Test/Entities/OrganizationTests.cs index 5a86c3fd01..c24d6effc0 100644 --- a/test/Core.Test/Entities/OrganizationTests.cs +++ b/test/Core.Test/Entities/OrganizationTests.cs @@ -5,95 +5,96 @@ using Bit.Core.Models; using Bit.Test.Common.Helpers; using Xunit; -namespace Bit.Core.Test.Entities; - -public class OrganizationTests +namespace Bit.Core.Test.Entities { - private static readonly Dictionary _testConfig = new Dictionary() + public class OrganizationTests { - [TwoFactorProviderType.OrganizationDuo] = new TwoFactorProvider + private static readonly Dictionary _testConfig = new Dictionary() { - Enabled = true, - MetaData = new Dictionary + [TwoFactorProviderType.OrganizationDuo] = new TwoFactorProvider { - ["IKey"] = "IKey_value", - ["SKey"] = "SKey_value", - ["Host"] = "Host_value", - }, - } - }; - - - [Fact] - public void SetTwoFactorProviders_Success() - { - var organization = new Organization(); - organization.SetTwoFactorProviders(_testConfig); - - using var jsonDocument = JsonDocument.Parse(organization.TwoFactorProviders); - var root = jsonDocument.RootElement; - - var duo = AssertHelper.AssertJsonProperty(root, "6", JsonValueKind.Object); - AssertHelper.AssertJsonProperty(duo, "Enabled", JsonValueKind.True); - var duoMetaData = AssertHelper.AssertJsonProperty(duo, "MetaData", JsonValueKind.Object); - var iKey = AssertHelper.AssertJsonProperty(duoMetaData, "IKey", JsonValueKind.String).GetString(); - Assert.Equal("IKey_value", iKey); - var sKey = AssertHelper.AssertJsonProperty(duoMetaData, "SKey", JsonValueKind.String).GetString(); - Assert.Equal("SKey_value", sKey); - var host = AssertHelper.AssertJsonProperty(duoMetaData, "Host", JsonValueKind.String).GetString(); - Assert.Equal("Host_value", host); - } - - [Fact] - public void GetTwoFactorProviders_Success() - { - // This is to get rid of the cached dictionary the SetTwoFactorProviders keeps so we can fully test the JSON reading - // It intent is to mimic a storing of the entity in the database and it being read later - var tempOrganization = new Organization(); - tempOrganization.SetTwoFactorProviders(_testConfig); - var organization = new Organization - { - TwoFactorProviders = tempOrganization.TwoFactorProviders, + Enabled = true, + MetaData = new Dictionary + { + ["IKey"] = "IKey_value", + ["SKey"] = "SKey_value", + ["Host"] = "Host_value", + }, + } }; - var twoFactorProviders = organization.GetTwoFactorProviders(); - var duo = Assert.Contains(TwoFactorProviderType.OrganizationDuo, (IDictionary)twoFactorProviders); - Assert.True(duo.Enabled); - Assert.NotNull(duo.MetaData); - var iKey = Assert.Contains("IKey", (IDictionary)duo.MetaData); - Assert.Equal("IKey_value", iKey); - var sKey = Assert.Contains("SKey", (IDictionary)duo.MetaData); - Assert.Equal("SKey_value", sKey); - var host = Assert.Contains("Host", (IDictionary)duo.MetaData); - Assert.Equal("Host_value", host); - } + [Fact] + public void SetTwoFactorProviders_Success() + { + var organization = new Organization(); + organization.SetTwoFactorProviders(_testConfig); - [Fact] - public void GetTwoFactorProviders_SavedWithName_Success() - { - var organization = new Organization(); - // This should save items with the string name of the enum and we will validate that we can read - // from that just incase some organizations have it saved that way. - organization.TwoFactorProviders = JsonSerializer.Serialize(_testConfig); + using var jsonDocument = JsonDocument.Parse(organization.TwoFactorProviders); + var root = jsonDocument.RootElement; - // Preliminary Asserts to make sure we are testing what we want to be testing - using var jsonDocument = JsonDocument.Parse(organization.TwoFactorProviders); - var root = jsonDocument.RootElement; - // This means it saved the enum as its string name - AssertHelper.AssertJsonProperty(root, "OrganizationDuo", JsonValueKind.Object); + var duo = AssertHelper.AssertJsonProperty(root, "6", JsonValueKind.Object); + AssertHelper.AssertJsonProperty(duo, "Enabled", JsonValueKind.True); + var duoMetaData = AssertHelper.AssertJsonProperty(duo, "MetaData", JsonValueKind.Object); + var iKey = AssertHelper.AssertJsonProperty(duoMetaData, "IKey", JsonValueKind.String).GetString(); + Assert.Equal("IKey_value", iKey); + var sKey = AssertHelper.AssertJsonProperty(duoMetaData, "SKey", JsonValueKind.String).GetString(); + Assert.Equal("SKey_value", sKey); + var host = AssertHelper.AssertJsonProperty(duoMetaData, "Host", JsonValueKind.String).GetString(); + Assert.Equal("Host_value", host); + } - // Actual checks - var twoFactorProviders = organization.GetTwoFactorProviders(); + [Fact] + public void GetTwoFactorProviders_Success() + { + // This is to get rid of the cached dictionary the SetTwoFactorProviders keeps so we can fully test the JSON reading + // It intent is to mimic a storing of the entity in the database and it being read later + var tempOrganization = new Organization(); + tempOrganization.SetTwoFactorProviders(_testConfig); + var organization = new Organization + { + TwoFactorProviders = tempOrganization.TwoFactorProviders, + }; - var duo = Assert.Contains(TwoFactorProviderType.OrganizationDuo, (IDictionary)twoFactorProviders); - Assert.True(duo.Enabled); - Assert.NotNull(duo.MetaData); - var iKey = Assert.Contains("IKey", (IDictionary)duo.MetaData); - Assert.Equal("IKey_value", iKey); - var sKey = Assert.Contains("SKey", (IDictionary)duo.MetaData); - Assert.Equal("SKey_value", sKey); - var host = Assert.Contains("Host", (IDictionary)duo.MetaData); - Assert.Equal("Host_value", host); + var twoFactorProviders = organization.GetTwoFactorProviders(); + + var duo = Assert.Contains(TwoFactorProviderType.OrganizationDuo, (IDictionary)twoFactorProviders); + Assert.True(duo.Enabled); + Assert.NotNull(duo.MetaData); + var iKey = Assert.Contains("IKey", (IDictionary)duo.MetaData); + Assert.Equal("IKey_value", iKey); + var sKey = Assert.Contains("SKey", (IDictionary)duo.MetaData); + Assert.Equal("SKey_value", sKey); + var host = Assert.Contains("Host", (IDictionary)duo.MetaData); + Assert.Equal("Host_value", host); + } + + [Fact] + public void GetTwoFactorProviders_SavedWithName_Success() + { + var organization = new Organization(); + // This should save items with the string name of the enum and we will validate that we can read + // from that just incase some organizations have it saved that way. + organization.TwoFactorProviders = JsonSerializer.Serialize(_testConfig); + + // Preliminary Asserts to make sure we are testing what we want to be testing + using var jsonDocument = JsonDocument.Parse(organization.TwoFactorProviders); + var root = jsonDocument.RootElement; + // This means it saved the enum as its string name + AssertHelper.AssertJsonProperty(root, "OrganizationDuo", JsonValueKind.Object); + + // Actual checks + var twoFactorProviders = organization.GetTwoFactorProviders(); + + var duo = Assert.Contains(TwoFactorProviderType.OrganizationDuo, (IDictionary)twoFactorProviders); + Assert.True(duo.Enabled); + Assert.NotNull(duo.MetaData); + var iKey = Assert.Contains("IKey", (IDictionary)duo.MetaData); + Assert.Equal("IKey_value", iKey); + var sKey = Assert.Contains("SKey", (IDictionary)duo.MetaData); + Assert.Equal("SKey_value", sKey); + var host = Assert.Contains("Host", (IDictionary)duo.MetaData); + Assert.Equal("Host_value", host); + } } } diff --git a/test/Core.Test/Entities/UserTests.cs b/test/Core.Test/Entities/UserTests.cs index 8a1986cd94..c60b031da3 100644 --- a/test/Core.Test/Entities/UserTests.cs +++ b/test/Core.Test/Entities/UserTests.cs @@ -5,140 +5,141 @@ using Bit.Core.Models; using Bit.Test.Common.Helpers; using Xunit; -namespace Bit.Core.Test.Entities; - -public class UserTests +namespace Bit.Core.Test.Entities { - // KB MB GB - public const long Multiplier = 1024 * 1024 * 1024; - - [Fact] - public void StorageBytesRemaining_HasMax_DoesNotHaveStorage_ReturnsMaxAsBytes() + public class UserTests { - short maxStorageGb = 1; + // KB MB GB + public const long Multiplier = 1024 * 1024 * 1024; - var user = new User + [Fact] + public void StorageBytesRemaining_HasMax_DoesNotHaveStorage_ReturnsMaxAsBytes() { - MaxStorageGb = maxStorageGb, - Storage = null, - }; + short maxStorageGb = 1; - var bytesRemaining = user.StorageBytesRemaining(); - - Assert.Equal(bytesRemaining, maxStorageGb * Multiplier); - } - - [Theory] - [InlineData(2, 1 * Multiplier, 1 * Multiplier)] - - public void StorageBytesRemaining_HasMax_HasStorage_ReturnRemainingStorage(short maxStorageGb, long storageBytes, long expectedRemainingBytes) - { - var user = new User - { - MaxStorageGb = maxStorageGb, - Storage = storageBytes, - }; - - var bytesRemaining = user.StorageBytesRemaining(); - - Assert.Equal(expectedRemainingBytes, bytesRemaining); - } - - private static readonly Dictionary _testTwoFactorConfig = new Dictionary - { - [TwoFactorProviderType.WebAuthn] = new TwoFactorProvider - { - Enabled = true, - MetaData = new Dictionary + var user = new User { - ["Item"] = "thing", - }, - }, - [TwoFactorProviderType.Email] = new TwoFactorProvider + MaxStorageGb = maxStorageGb, + Storage = null, + }; + + var bytesRemaining = user.StorageBytesRemaining(); + + Assert.Equal(bytesRemaining, maxStorageGb * Multiplier); + } + + [Theory] + [InlineData(2, 1 * Multiplier, 1 * Multiplier)] + + public void StorageBytesRemaining_HasMax_HasStorage_ReturnRemainingStorage(short maxStorageGb, long storageBytes, long expectedRemainingBytes) { - Enabled = false, - MetaData = new Dictionary + var user = new User { - ["Email"] = "test@email.com", - }, - }, - }; + MaxStorageGb = maxStorageGb, + Storage = storageBytes, + }; - [Fact] - public void SetTwoFactorProviders_Success() - { - var user = new User(); - user.SetTwoFactorProviders(_testTwoFactorConfig); + var bytesRemaining = user.StorageBytesRemaining(); - using var jsonDocument = JsonDocument.Parse(user.TwoFactorProviders); - var root = jsonDocument.RootElement; + Assert.Equal(expectedRemainingBytes, bytesRemaining); + } - var webAuthn = AssertHelper.AssertJsonProperty(root, "7", JsonValueKind.Object); - AssertHelper.AssertJsonProperty(webAuthn, "Enabled", JsonValueKind.True); - var webMetaData = AssertHelper.AssertJsonProperty(webAuthn, "MetaData", JsonValueKind.Object); - AssertHelper.AssertJsonProperty(webMetaData, "Item", JsonValueKind.String); - - var email = AssertHelper.AssertJsonProperty(root, "1", JsonValueKind.Object); - AssertHelper.AssertJsonProperty(email, "Enabled", JsonValueKind.False); - var emailMetaData = AssertHelper.AssertJsonProperty(email, "MetaData", JsonValueKind.Object); - AssertHelper.AssertJsonProperty(emailMetaData, "Email", JsonValueKind.String); - } - - [Fact] - public void GetTwoFactorProviders_Success() - { - // This is to get rid of the cached dictionary the SetTwoFactorProviders keeps so we can fully test the JSON reading - // It intent is to mimic a storing of the entity in the database and it being read later - var tempUser = new User(); - tempUser.SetTwoFactorProviders(_testTwoFactorConfig); - var user = new User + private static readonly Dictionary _testTwoFactorConfig = new Dictionary { - TwoFactorProviders = tempUser.TwoFactorProviders, + [TwoFactorProviderType.WebAuthn] = new TwoFactorProvider + { + Enabled = true, + MetaData = new Dictionary + { + ["Item"] = "thing", + }, + }, + [TwoFactorProviderType.Email] = new TwoFactorProvider + { + Enabled = false, + MetaData = new Dictionary + { + ["Email"] = "test@email.com", + }, + }, }; - var twoFactorProviders = user.GetTwoFactorProviders(); + [Fact] + public void SetTwoFactorProviders_Success() + { + var user = new User(); + user.SetTwoFactorProviders(_testTwoFactorConfig); - var webAuthn = Assert.Contains(TwoFactorProviderType.WebAuthn, (IDictionary)twoFactorProviders); - Assert.True(webAuthn.Enabled); - Assert.NotNull(webAuthn.MetaData); - var webAuthnMetaDataItem = Assert.Contains("Item", (IDictionary)webAuthn.MetaData); - Assert.Equal("thing", webAuthnMetaDataItem); + using var jsonDocument = JsonDocument.Parse(user.TwoFactorProviders); + var root = jsonDocument.RootElement; - var email = Assert.Contains(TwoFactorProviderType.Email, (IDictionary)twoFactorProviders); - Assert.False(email.Enabled); - Assert.NotNull(email.MetaData); - var emailMetaDataEmail = Assert.Contains("Email", (IDictionary)email.MetaData); - Assert.Equal("test@email.com", emailMetaDataEmail); - } + var webAuthn = AssertHelper.AssertJsonProperty(root, "7", JsonValueKind.Object); + AssertHelper.AssertJsonProperty(webAuthn, "Enabled", JsonValueKind.True); + var webMetaData = AssertHelper.AssertJsonProperty(webAuthn, "MetaData", JsonValueKind.Object); + AssertHelper.AssertJsonProperty(webMetaData, "Item", JsonValueKind.String); - [Fact] - public void GetTwoFactorProviders_SavedWithName_Success() - { - var user = new User(); - // This should save items with the string name of the enum and we will validate that we can read - // from that just incase some users have it saved that way. - user.TwoFactorProviders = JsonSerializer.Serialize(_testTwoFactorConfig); + var email = AssertHelper.AssertJsonProperty(root, "1", JsonValueKind.Object); + AssertHelper.AssertJsonProperty(email, "Enabled", JsonValueKind.False); + var emailMetaData = AssertHelper.AssertJsonProperty(email, "MetaData", JsonValueKind.Object); + AssertHelper.AssertJsonProperty(emailMetaData, "Email", JsonValueKind.String); + } - // Preliminary Asserts to make sure we are testing what we want to be testing - using var jsonDocument = JsonDocument.Parse(user.TwoFactorProviders); - var root = jsonDocument.RootElement; - // This means it saved the enum as its string name - AssertHelper.AssertJsonProperty(root, "WebAuthn", JsonValueKind.Object); - AssertHelper.AssertJsonProperty(root, "Email", JsonValueKind.Object); + [Fact] + public void GetTwoFactorProviders_Success() + { + // This is to get rid of the cached dictionary the SetTwoFactorProviders keeps so we can fully test the JSON reading + // It intent is to mimic a storing of the entity in the database and it being read later + var tempUser = new User(); + tempUser.SetTwoFactorProviders(_testTwoFactorConfig); + var user = new User + { + TwoFactorProviders = tempUser.TwoFactorProviders, + }; - // Actual checks - var twoFactorProviders = user.GetTwoFactorProviders(); + var twoFactorProviders = user.GetTwoFactorProviders(); - var webAuthn = Assert.Contains(TwoFactorProviderType.WebAuthn, (IDictionary)twoFactorProviders); - Assert.True(webAuthn.Enabled); - Assert.NotNull(webAuthn.MetaData); - var webAuthnMetaDataItem = Assert.Contains("Item", (IDictionary)webAuthn.MetaData); - Assert.Equal("thing", webAuthnMetaDataItem); + var webAuthn = Assert.Contains(TwoFactorProviderType.WebAuthn, (IDictionary)twoFactorProviders); + Assert.True(webAuthn.Enabled); + Assert.NotNull(webAuthn.MetaData); + var webAuthnMetaDataItem = Assert.Contains("Item", (IDictionary)webAuthn.MetaData); + Assert.Equal("thing", webAuthnMetaDataItem); - var email = Assert.Contains(TwoFactorProviderType.Email, (IDictionary)twoFactorProviders); - Assert.False(email.Enabled); - Assert.NotNull(email.MetaData); - var emailMetaDataEmail = Assert.Contains("Email", (IDictionary)email.MetaData); - Assert.Equal("test@email.com", emailMetaDataEmail); + var email = Assert.Contains(TwoFactorProviderType.Email, (IDictionary)twoFactorProviders); + Assert.False(email.Enabled); + Assert.NotNull(email.MetaData); + var emailMetaDataEmail = Assert.Contains("Email", (IDictionary)email.MetaData); + Assert.Equal("test@email.com", emailMetaDataEmail); + } + + [Fact] + public void GetTwoFactorProviders_SavedWithName_Success() + { + var user = new User(); + // This should save items with the string name of the enum and we will validate that we can read + // from that just incase some users have it saved that way. + user.TwoFactorProviders = JsonSerializer.Serialize(_testTwoFactorConfig); + + // Preliminary Asserts to make sure we are testing what we want to be testing + using var jsonDocument = JsonDocument.Parse(user.TwoFactorProviders); + var root = jsonDocument.RootElement; + // This means it saved the enum as its string name + AssertHelper.AssertJsonProperty(root, "WebAuthn", JsonValueKind.Object); + AssertHelper.AssertJsonProperty(root, "Email", JsonValueKind.Object); + + // Actual checks + var twoFactorProviders = user.GetTwoFactorProviders(); + + var webAuthn = Assert.Contains(TwoFactorProviderType.WebAuthn, (IDictionary)twoFactorProviders); + Assert.True(webAuthn.Enabled); + Assert.NotNull(webAuthn.MetaData); + var webAuthnMetaDataItem = Assert.Contains("Item", (IDictionary)webAuthn.MetaData); + Assert.Equal("thing", webAuthnMetaDataItem); + + var email = Assert.Contains(TwoFactorProviderType.Email, (IDictionary)twoFactorProviders); + Assert.False(email.Enabled); + Assert.NotNull(email.MetaData); + var emailMetaDataEmail = Assert.Contains("Email", (IDictionary)email.MetaData); + Assert.Equal("test@email.com", emailMetaDataEmail); + } } } diff --git a/test/Core.Test/Helpers/Factories.cs b/test/Core.Test/Helpers/Factories.cs index 7761d5cb15..3d6523bc92 100644 --- a/test/Core.Test/Helpers/Factories.cs +++ b/test/Core.Test/Helpers/Factories.cs @@ -1,15 +1,16 @@ using Bit.Core.Settings; using Microsoft.Extensions.Configuration; -namespace Bit.Core.Test.Helpers.Factories; - -public static class GlobalSettingsFactory +namespace Bit.Core.Test.Helpers.Factories { - public static GlobalSettings GlobalSettings { get; } = new(); - static GlobalSettingsFactory() + public static class GlobalSettingsFactory { - var configBuilder = new ConfigurationBuilder().AddUserSecrets("bitwarden-Api"); - var Configuration = configBuilder.Build(); - ConfigurationBinder.Bind(Configuration.GetSection("GlobalSettings"), GlobalSettings); + public static GlobalSettings GlobalSettings { get; } = new(); + static GlobalSettingsFactory() + { + var configBuilder = new ConfigurationBuilder().AddUserSecrets("bitwarden-Api"); + var Configuration = configBuilder.Build(); + ConfigurationBinder.Bind(Configuration.GetSection("GlobalSettings"), GlobalSettings); + } } } diff --git a/test/Core.Test/Identity/AuthenticationTokenProviderTests.cs b/test/Core.Test/Identity/AuthenticationTokenProviderTests.cs index 7b1ad3892a..8a5de6898a 100644 --- a/test/Core.Test/Identity/AuthenticationTokenProviderTests.cs +++ b/test/Core.Test/Identity/AuthenticationTokenProviderTests.cs @@ -5,34 +5,35 @@ using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; using Xunit; -namespace Bit.Core.Test.Identity; - -public class AuthenticationTokenProviderTests : BaseTokenProviderTests +namespace Bit.Core.Test.Identity { - public override TwoFactorProviderType TwoFactorProviderType => TwoFactorProviderType.Authenticator; - - public static IEnumerable CanGenerateTwoFactorTokenAsyncData - => SetupCanGenerateData( - ( - new Dictionary - { - ["Key"] = "stuff", - }, - true - ), - ( - new Dictionary - { - ["Key"] = "" - }, - false - ) - ); - - [Theory, BitMemberAutoData(nameof(CanGenerateTwoFactorTokenAsyncData))] - public override async Task RunCanGenerateTwoFactorTokenAsync(Dictionary metaData, bool expectedResponse, - User user, SutProvider sutProvider) + public class AuthenticationTokenProviderTests : BaseTokenProviderTests { - await base.RunCanGenerateTwoFactorTokenAsync(metaData, expectedResponse, user, sutProvider); + public override TwoFactorProviderType TwoFactorProviderType => TwoFactorProviderType.Authenticator; + + public static IEnumerable CanGenerateTwoFactorTokenAsyncData + => SetupCanGenerateData( + ( + new Dictionary + { + ["Key"] = "stuff", + }, + true + ), + ( + new Dictionary + { + ["Key"] = "" + }, + false + ) + ); + + [Theory, BitMemberAutoData(nameof(CanGenerateTwoFactorTokenAsyncData))] + public override async Task RunCanGenerateTwoFactorTokenAsync(Dictionary metaData, bool expectedResponse, + User user, SutProvider sutProvider) + { + await base.RunCanGenerateTwoFactorTokenAsync(metaData, expectedResponse, user, sutProvider); + } } } diff --git a/test/Core.Test/Identity/BaseTokenProviderTests.cs b/test/Core.Test/Identity/BaseTokenProviderTests.cs index 5a9e0316ed..9de8abbe55 100644 --- a/test/Core.Test/Identity/BaseTokenProviderTests.cs +++ b/test/Core.Test/Identity/BaseTokenProviderTests.cs @@ -11,82 +11,83 @@ using Microsoft.Extensions.Options; using NSubstitute; using Xunit; -namespace Bit.Core.Test.Identity; - -[SutProviderCustomize] -public abstract class BaseTokenProviderTests - where T : IUserTwoFactorTokenProvider +namespace Bit.Core.Test.Identity { - public abstract TwoFactorProviderType TwoFactorProviderType { get; } - - #region Helpers - protected static IEnumerable SetupCanGenerateData(params (Dictionary MetaData, bool ExpectedResponse)[] data) + [SutProviderCustomize] + public abstract class BaseTokenProviderTests + where T : IUserTwoFactorTokenProvider { - return data.Select(d => - new object[] - { - d.MetaData, - d.ExpectedResponse, - }); - } + public abstract TwoFactorProviderType TwoFactorProviderType { get; } - protected virtual IUserService AdditionalSetup(SutProvider sutProvider, User user) - { - var userService = Substitute.For(); - - sutProvider.GetDependency() - .GetService(typeof(IUserService)) - .Returns(userService); - - SetupUserService(userService, user); - - return userService; - } - - protected virtual void SetupUserService(IUserService userService, User user) - { - userService - .TwoFactorProviderIsEnabledAsync(TwoFactorProviderType, user) - .Returns(true); - } - - protected static UserManager SubstituteUserManager() - { - return new UserManager(Substitute.For>(), - Substitute.For>(), - Substitute.For>(), - Enumerable.Empty>(), - Enumerable.Empty>(), - Substitute.For(), - Substitute.For(), - Substitute.For(), - Substitute.For>>()); - } - - protected void MockDatabase(User user, Dictionary metaData) - { - var providers = new Dictionary + #region Helpers + protected static IEnumerable SetupCanGenerateData(params (Dictionary MetaData, bool ExpectedResponse)[] data) { - [TwoFactorProviderType] = new TwoFactorProvider + return data.Select(d => + new object[] + { + d.MetaData, + d.ExpectedResponse, + }); + } + + protected virtual IUserService AdditionalSetup(SutProvider sutProvider, User user) + { + var userService = Substitute.For(); + + sutProvider.GetDependency() + .GetService(typeof(IUserService)) + .Returns(userService); + + SetupUserService(userService, user); + + return userService; + } + + protected virtual void SetupUserService(IUserService userService, User user) + { + userService + .TwoFactorProviderIsEnabledAsync(TwoFactorProviderType, user) + .Returns(true); + } + + protected static UserManager SubstituteUserManager() + { + return new UserManager(Substitute.For>(), + Substitute.For>(), + Substitute.For>(), + Enumerable.Empty>(), + Enumerable.Empty>(), + Substitute.For(), + Substitute.For(), + Substitute.For(), + Substitute.For>>()); + } + + protected void MockDatabase(User user, Dictionary metaData) + { + var providers = new Dictionary { - Enabled = true, - MetaData = metaData, - }, - }; + [TwoFactorProviderType] = new TwoFactorProvider + { + Enabled = true, + MetaData = metaData, + }, + }; - user.TwoFactorProviders = JsonHelpers.LegacySerialize(providers); - } - #endregion + user.TwoFactorProviders = JsonHelpers.LegacySerialize(providers); + } + #endregion - public virtual async Task RunCanGenerateTwoFactorTokenAsync(Dictionary metaData, bool expectedResponse, - User user, SutProvider sutProvider) - { - var userManager = SubstituteUserManager(); - MockDatabase(user, metaData); + public virtual async Task RunCanGenerateTwoFactorTokenAsync(Dictionary metaData, bool expectedResponse, + User user, SutProvider sutProvider) + { + var userManager = SubstituteUserManager(); + MockDatabase(user, metaData); - AdditionalSetup(sutProvider, user); + AdditionalSetup(sutProvider, user); - var response = await sutProvider.Sut.CanGenerateTwoFactorTokenAsync(userManager, user); - Assert.Equal(expectedResponse, response); + var response = await sutProvider.Sut.CanGenerateTwoFactorTokenAsync(userManager, user); + Assert.Equal(expectedResponse, response); + } } } diff --git a/test/Core.Test/Identity/EmailTokenProviderTests.cs b/test/Core.Test/Identity/EmailTokenProviderTests.cs index 707ed798df..b1b4712015 100644 --- a/test/Core.Test/Identity/EmailTokenProviderTests.cs +++ b/test/Core.Test/Identity/EmailTokenProviderTests.cs @@ -5,41 +5,42 @@ using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; using Xunit; -namespace Bit.Core.Test.Identity; - -public class EmailTokenProviderTests : BaseTokenProviderTests +namespace Bit.Core.Test.Identity { - public override TwoFactorProviderType TwoFactorProviderType => TwoFactorProviderType.Email; - - public static IEnumerable CanGenerateTwoFactorTokenAsyncData - => SetupCanGenerateData( - ( - new Dictionary - { - ["Email"] = "test@email.com", - }, - true - ), - ( - new Dictionary - { - ["NotEmail"] = "value", - }, - false - ), - ( - new Dictionary - { - ["Email"] = "", - }, - false - ) - ); - - [Theory, BitMemberAutoData(nameof(CanGenerateTwoFactorTokenAsyncData))] - public override async Task RunCanGenerateTwoFactorTokenAsync(Dictionary metaData, bool expectedResponse, - User user, SutProvider sutProvider) + public class EmailTokenProviderTests : BaseTokenProviderTests { - await base.RunCanGenerateTwoFactorTokenAsync(metaData, expectedResponse, user, sutProvider); + public override TwoFactorProviderType TwoFactorProviderType => TwoFactorProviderType.Email; + + public static IEnumerable CanGenerateTwoFactorTokenAsyncData + => SetupCanGenerateData( + ( + new Dictionary + { + ["Email"] = "test@email.com", + }, + true + ), + ( + new Dictionary + { + ["NotEmail"] = "value", + }, + false + ), + ( + new Dictionary + { + ["Email"] = "", + }, + false + ) + ); + + [Theory, BitMemberAutoData(nameof(CanGenerateTwoFactorTokenAsyncData))] + public override async Task RunCanGenerateTwoFactorTokenAsync(Dictionary metaData, bool expectedResponse, + User user, SutProvider sutProvider) + { + await base.RunCanGenerateTwoFactorTokenAsync(metaData, expectedResponse, user, sutProvider); + } } } diff --git a/test/Core.Test/IdentityServer/TokenRetrievalTests.cs b/test/Core.Test/IdentityServer/TokenRetrievalTests.cs index 591427da94..071f4d9140 100644 --- a/test/Core.Test/IdentityServer/TokenRetrievalTests.cs +++ b/test/Core.Test/IdentityServer/TokenRetrievalTests.cs @@ -4,90 +4,91 @@ using Microsoft.Extensions.Primitives; using NSubstitute; using Xunit; -namespace Bit.Core.Test.IdentityServer; - -public class TokenRetrievalTests +namespace Bit.Core.Test.IdentityServer { - private readonly Func _sut = TokenRetrieval.FromAuthorizationHeaderOrQueryString(); - - [Fact] - public void RetrieveToken_FromHeader_ReturnsToken() + public class TokenRetrievalTests { - // Arrange - var headers = new HeaderDictionary + private readonly Func _sut = TokenRetrieval.FromAuthorizationHeaderOrQueryString(); + + [Fact] + public void RetrieveToken_FromHeader_ReturnsToken() { - { "Authorization", "Bearer test_value" }, - { "X-Test-Header", "random_value" } - }; + // Arrange + var headers = new HeaderDictionary + { + { "Authorization", "Bearer test_value" }, + { "X-Test-Header", "random_value" } + }; - var request = Substitute.For(); + var request = Substitute.For(); - request.Headers.Returns(headers); + request.Headers.Returns(headers); - // Act - var token = _sut(request); + // Act + var token = _sut(request); - // Assert - Assert.Equal("test_value", token); - } + // Assert + Assert.Equal("test_value", token); + } - [Fact] - public void RetrieveToken_FromQueryString_ReturnsToken() - { - // Arrange - var queryString = new Dictionary + [Fact] + public void RetrieveToken_FromQueryString_ReturnsToken() { - { "access_token", "test_value" }, - { "test-query", "random_value" } - }; + // Arrange + var queryString = new Dictionary + { + { "access_token", "test_value" }, + { "test-query", "random_value" } + }; - var request = Substitute.For(); - request.Query.Returns(new QueryCollection(queryString)); + var request = Substitute.For(); + request.Query.Returns(new QueryCollection(queryString)); - // Act - var token = _sut(request); + // Act + var token = _sut(request); - // Assert - Assert.Equal("test_value", token); - } + // Assert + Assert.Equal("test_value", token); + } - [Fact] - public void RetrieveToken_HasBoth_ReturnsHeaderToken() - { - // Arrange - var queryString = new Dictionary + [Fact] + public void RetrieveToken_HasBoth_ReturnsHeaderToken() { - { "access_token", "query_string_token" }, - { "test-query", "random_value" } - }; + // Arrange + var queryString = new Dictionary + { + { "access_token", "query_string_token" }, + { "test-query", "random_value" } + }; - var headers = new HeaderDictionary + var headers = new HeaderDictionary + { + { "Authorization", "Bearer header_token" }, + { "X-Test-Header", "random_value" } + }; + + var request = Substitute.For(); + request.Headers.Returns(headers); + request.Query.Returns(new QueryCollection(queryString)); + + // Act + var token = _sut(request); + + // Assert + Assert.Equal("header_token", token); + } + + [Fact] + public void RetrieveToken_NoToken_ReturnsNull() { - { "Authorization", "Bearer header_token" }, - { "X-Test-Header", "random_value" } - }; + // Arrange + var request = Substitute.For(); - var request = Substitute.For(); - request.Headers.Returns(headers); - request.Query.Returns(new QueryCollection(queryString)); + // Act + var token = _sut(request); - // Act - var token = _sut(request); - - // Assert - Assert.Equal("header_token", token); - } - - [Fact] - public void RetrieveToken_NoToken_ReturnsNull() - { - // Arrange - var request = Substitute.For(); - - // Act - var token = _sut(request); - - // Assert - Assert.Null(token); + // Assert + Assert.Null(token); + } } } diff --git a/test/Core.Test/Models/Business/BillingInfo.cs b/test/Core.Test/Models/Business/BillingInfo.cs index c6c1ae56fd..0023b4669b 100644 --- a/test/Core.Test/Models/Business/BillingInfo.cs +++ b/test/Core.Test/Models/Business/BillingInfo.cs @@ -1,22 +1,23 @@ using Bit.Core.Models.Business; using Xunit; -namespace Bit.Core.Test.Models.Business; - -public class BillingInfoTests +namespace Bit.Core.Test.Models.Business { - [Fact] - public void BillingInvoice_Amount_ShouldComeFrom_InvoiceTotal() + public class BillingInfoTests { - var invoice = new Stripe.Invoice + [Fact] + public void BillingInvoice_Amount_ShouldComeFrom_InvoiceTotal() { - AmountDue = 1000, - Total = 2000, - }; + var invoice = new Stripe.Invoice + { + AmountDue = 1000, + Total = 2000, + }; - var billingInvoice = new BillingInfo.BillingInvoice(invoice); + var billingInvoice = new BillingInfo.BillingInvoice(invoice); - // Should have been set from Total - Assert.Equal(20M, billingInvoice.Amount); + // Should have been set from Total + Assert.Equal(20M, billingInvoice.Amount); + } } } diff --git a/test/Core.Test/Models/Business/TaxInfoTests.cs b/test/Core.Test/Models/Business/TaxInfoTests.cs index 197948006e..124201b62f 100644 --- a/test/Core.Test/Models/Business/TaxInfoTests.cs +++ b/test/Core.Test/Models/Business/TaxInfoTests.cs @@ -1,114 +1,115 @@ using Bit.Core.Models.Business; using Xunit; -namespace Bit.Core.Test.Models.Business; - -public class TaxInfoTests +namespace Bit.Core.Test.Models.Business { - // PH = Placeholder - [Theory] - [InlineData(null, null, null, null)] - [InlineData("", "", null, null)] - [InlineData("PH", "", null, null)] - [InlineData("", "PH", null, null)] - [InlineData("AE", "PH", null, "ae_trn")] - [InlineData("AU", "PH", null, "au_abn")] - [InlineData("BR", "PH", null, "br_cnpj")] - [InlineData("CA", "PH", "bec", "ca_qst")] - [InlineData("CA", "PH", null, "ca_bn")] - [InlineData("CL", "PH", null, "cl_tin")] - [InlineData("AT", "PH", null, "eu_vat")] - [InlineData("BE", "PH", null, "eu_vat")] - [InlineData("BG", "PH", null, "eu_vat")] - [InlineData("CY", "PH", null, "eu_vat")] - [InlineData("CZ", "PH", null, "eu_vat")] - [InlineData("DE", "PH", null, "eu_vat")] - [InlineData("DK", "PH", null, "eu_vat")] - [InlineData("EE", "PH", null, "eu_vat")] - [InlineData("ES", "PH", null, "eu_vat")] - [InlineData("FI", "PH", null, "eu_vat")] - [InlineData("FR", "PH", null, "eu_vat")] - [InlineData("GB", "PH", null, "eu_vat")] - [InlineData("GR", "PH", null, "eu_vat")] - [InlineData("HR", "PH", null, "eu_vat")] - [InlineData("HU", "PH", null, "eu_vat")] - [InlineData("IE", "PH", null, "eu_vat")] - [InlineData("IT", "PH", null, "eu_vat")] - [InlineData("LT", "PH", null, "eu_vat")] - [InlineData("LU", "PH", null, "eu_vat")] - [InlineData("LV", "PH", null, "eu_vat")] - [InlineData("MT", "PH", null, "eu_vat")] - [InlineData("NL", "PH", null, "eu_vat")] - [InlineData("PL", "PH", null, "eu_vat")] - [InlineData("PT", "PH", null, "eu_vat")] - [InlineData("RO", "PH", null, "eu_vat")] - [InlineData("SE", "PH", null, "eu_vat")] - [InlineData("SI", "PH", null, "eu_vat")] - [InlineData("SK", "PH", null, "eu_vat")] - [InlineData("HK", "PH", null, "hk_br")] - [InlineData("IN", "PH", null, "in_gst")] - [InlineData("JP", "PH", null, "jp_cn")] - [InlineData("KR", "PH", null, "kr_brn")] - [InlineData("LI", "PH", null, "li_uid")] - [InlineData("MX", "PH", null, "mx_rfc")] - [InlineData("MY", "PH", null, "my_sst")] - [InlineData("NO", "PH", null, "no_vat")] - [InlineData("NZ", "PH", null, "nz_gst")] - [InlineData("RU", "PH", null, "ru_inn")] - [InlineData("SA", "PH", null, "sa_vat")] - [InlineData("SG", "PH", null, "sg_gst")] - [InlineData("TH", "PH", null, "th_vat")] - [InlineData("TW", "PH", null, "tw_vat")] - [InlineData("US", "PH", null, "us_ein")] - [InlineData("ZA", "PH", null, "za_vat")] - [InlineData("ABCDEF", "PH", null, null)] - public void GetTaxIdType_Success(string billingAddressCountry, - string taxIdNumber, - string billingAddressState, - string expectedTaxIdType) + public class TaxInfoTests { - var taxInfo = new TaxInfo + // PH = Placeholder + [Theory] + [InlineData(null, null, null, null)] + [InlineData("", "", null, null)] + [InlineData("PH", "", null, null)] + [InlineData("", "PH", null, null)] + [InlineData("AE", "PH", null, "ae_trn")] + [InlineData("AU", "PH", null, "au_abn")] + [InlineData("BR", "PH", null, "br_cnpj")] + [InlineData("CA", "PH", "bec", "ca_qst")] + [InlineData("CA", "PH", null, "ca_bn")] + [InlineData("CL", "PH", null, "cl_tin")] + [InlineData("AT", "PH", null, "eu_vat")] + [InlineData("BE", "PH", null, "eu_vat")] + [InlineData("BG", "PH", null, "eu_vat")] + [InlineData("CY", "PH", null, "eu_vat")] + [InlineData("CZ", "PH", null, "eu_vat")] + [InlineData("DE", "PH", null, "eu_vat")] + [InlineData("DK", "PH", null, "eu_vat")] + [InlineData("EE", "PH", null, "eu_vat")] + [InlineData("ES", "PH", null, "eu_vat")] + [InlineData("FI", "PH", null, "eu_vat")] + [InlineData("FR", "PH", null, "eu_vat")] + [InlineData("GB", "PH", null, "eu_vat")] + [InlineData("GR", "PH", null, "eu_vat")] + [InlineData("HR", "PH", null, "eu_vat")] + [InlineData("HU", "PH", null, "eu_vat")] + [InlineData("IE", "PH", null, "eu_vat")] + [InlineData("IT", "PH", null, "eu_vat")] + [InlineData("LT", "PH", null, "eu_vat")] + [InlineData("LU", "PH", null, "eu_vat")] + [InlineData("LV", "PH", null, "eu_vat")] + [InlineData("MT", "PH", null, "eu_vat")] + [InlineData("NL", "PH", null, "eu_vat")] + [InlineData("PL", "PH", null, "eu_vat")] + [InlineData("PT", "PH", null, "eu_vat")] + [InlineData("RO", "PH", null, "eu_vat")] + [InlineData("SE", "PH", null, "eu_vat")] + [InlineData("SI", "PH", null, "eu_vat")] + [InlineData("SK", "PH", null, "eu_vat")] + [InlineData("HK", "PH", null, "hk_br")] + [InlineData("IN", "PH", null, "in_gst")] + [InlineData("JP", "PH", null, "jp_cn")] + [InlineData("KR", "PH", null, "kr_brn")] + [InlineData("LI", "PH", null, "li_uid")] + [InlineData("MX", "PH", null, "mx_rfc")] + [InlineData("MY", "PH", null, "my_sst")] + [InlineData("NO", "PH", null, "no_vat")] + [InlineData("NZ", "PH", null, "nz_gst")] + [InlineData("RU", "PH", null, "ru_inn")] + [InlineData("SA", "PH", null, "sa_vat")] + [InlineData("SG", "PH", null, "sg_gst")] + [InlineData("TH", "PH", null, "th_vat")] + [InlineData("TW", "PH", null, "tw_vat")] + [InlineData("US", "PH", null, "us_ein")] + [InlineData("ZA", "PH", null, "za_vat")] + [InlineData("ABCDEF", "PH", null, null)] + public void GetTaxIdType_Success(string billingAddressCountry, + string taxIdNumber, + string billingAddressState, + string expectedTaxIdType) { - BillingAddressCountry = billingAddressCountry, - TaxIdNumber = taxIdNumber, - BillingAddressState = billingAddressState, - }; + var taxInfo = new TaxInfo + { + BillingAddressCountry = billingAddressCountry, + TaxIdNumber = taxIdNumber, + BillingAddressState = billingAddressState, + }; - Assert.Equal(expectedTaxIdType, taxInfo.TaxIdType); - } + Assert.Equal(expectedTaxIdType, taxInfo.TaxIdType); + } - [Fact] - public void GetTaxIdType_CreateOnce_ReturnCacheSecondTime() - { - var taxInfo = new TaxInfo + [Fact] + public void GetTaxIdType_CreateOnce_ReturnCacheSecondTime() { - BillingAddressCountry = "US", - TaxIdNumber = "PH", - BillingAddressState = null, - }; + var taxInfo = new TaxInfo + { + BillingAddressCountry = "US", + TaxIdNumber = "PH", + BillingAddressState = null, + }; - Assert.Equal("us_ein", taxInfo.TaxIdType); + Assert.Equal("us_ein", taxInfo.TaxIdType); - // Per the current spec even if the values change to something other than null it - // will return the cached version of TaxIdType. - taxInfo.BillingAddressCountry = "ZA"; + // Per the current spec even if the values change to something other than null it + // will return the cached version of TaxIdType. + taxInfo.BillingAddressCountry = "ZA"; - Assert.Equal("us_ein", taxInfo.TaxIdType); - } + Assert.Equal("us_ein", taxInfo.TaxIdType); + } - [Theory] - [InlineData(null, null, false)] - [InlineData("123", "US", true)] - [InlineData("123", "ZQ12", false)] - [InlineData(" ", "US", false)] - public void HasTaxId_ReturnsExpected(string taxIdNumber, string billingAddressCountry, bool expected) - { - var taxInfo = new TaxInfo + [Theory] + [InlineData(null, null, false)] + [InlineData("123", "US", true)] + [InlineData("123", "ZQ12", false)] + [InlineData(" ", "US", false)] + public void HasTaxId_ReturnsExpected(string taxIdNumber, string billingAddressCountry, bool expected) { - TaxIdNumber = taxIdNumber, - BillingAddressCountry = billingAddressCountry, - }; + var taxInfo = new TaxInfo + { + TaxIdNumber = taxIdNumber, + BillingAddressCountry = billingAddressCountry, + }; - Assert.Equal(expected, taxInfo.HasTaxId); + Assert.Equal(expected, taxInfo.HasTaxId); + } } } diff --git a/test/Core.Test/Models/Business/Tokenables/EmergencyAccessInviteTokenableTests.cs b/test/Core.Test/Models/Business/Tokenables/EmergencyAccessInviteTokenableTests.cs index 40e390c7d0..d334c7dfa9 100644 --- a/test/Core.Test/Models/Business/Tokenables/EmergencyAccessInviteTokenableTests.cs +++ b/test/Core.Test/Models/Business/Tokenables/EmergencyAccessInviteTokenableTests.cs @@ -4,29 +4,30 @@ using Bit.Core.Models.Business.Tokenables; using Bit.Core.Tokens; using Xunit; -namespace Bit.Core.Test.Models.Business.Tokenables; - -public class EmergencyAccessInviteTokenableTests +namespace Bit.Core.Test.Models.Business.Tokenables { - [Theory, AutoData] - public void SerializationSetsCorrectDateTime(EmergencyAccess emergencyAccess) + public class EmergencyAccessInviteTokenableTests { - var token = new EmergencyAccessInviteTokenable(emergencyAccess, 2); - Assert.Equal(Tokenable.FromToken(token.ToToken().ToString()).ExpirationDate, - token.ExpirationDate, - TimeSpan.FromMilliseconds(10)); - } - - [Fact] - public void IsInvalidIfIdentifierIsWrong() - { - var token = new EmergencyAccessInviteTokenable(DateTime.MaxValue) + [Theory, AutoData] + public void SerializationSetsCorrectDateTime(EmergencyAccess emergencyAccess) { - Email = "email", - Id = Guid.NewGuid(), - Identifier = "not correct" - }; + var token = new EmergencyAccessInviteTokenable(emergencyAccess, 2); + Assert.Equal(Tokenable.FromToken(token.ToToken().ToString()).ExpirationDate, + token.ExpirationDate, + TimeSpan.FromMilliseconds(10)); + } - Assert.False(token.Valid); + [Fact] + public void IsInvalidIfIdentifierIsWrong() + { + var token = new EmergencyAccessInviteTokenable(DateTime.MaxValue) + { + Email = "email", + Id = Guid.NewGuid(), + Identifier = "not correct" + }; + + Assert.False(token.Valid); + } } } diff --git a/test/Core.Test/Models/Business/Tokenables/HCaptchaTokenableTests.cs b/test/Core.Test/Models/Business/Tokenables/HCaptchaTokenableTests.cs index ce97cb8b3b..ce72fa8dcb 100644 --- a/test/Core.Test/Models/Business/Tokenables/HCaptchaTokenableTests.cs +++ b/test/Core.Test/Models/Business/Tokenables/HCaptchaTokenableTests.cs @@ -5,83 +5,84 @@ using Bit.Core.Tokens; using Bit.Test.Common.AutoFixture.Attributes; using Xunit; -namespace Bit.Core.Test.Models.Business.Tokenables; - -public class HCaptchaTokenableTests +namespace Bit.Core.Test.Models.Business.Tokenables { - [Fact] - public void CanHandleNullUser() + public class HCaptchaTokenableTests { - var token = new HCaptchaTokenable(null); - - Assert.Equal(default, token.Id); - Assert.Equal(default, token.Email); - } - - [Fact] - public void TokenWithNullUserIsInvalid() - { - var token = new HCaptchaTokenable(null) + [Fact] + public void CanHandleNullUser() { - ExpirationDate = DateTime.UtcNow + TimeSpan.FromDays(1) - }; + var token = new HCaptchaTokenable(null); - Assert.False(token.Valid); - } + Assert.Equal(default, token.Id); + Assert.Equal(default, token.Email); + } - [Theory, BitAutoData] - public void TokenValidityCheckNullUserIdIsInvalid(User user) - { - var token = new HCaptchaTokenable(user) + [Fact] + public void TokenWithNullUserIsInvalid() { - ExpirationDate = DateTime.UtcNow + TimeSpan.FromDays(1) - }; + var token = new HCaptchaTokenable(null) + { + ExpirationDate = DateTime.UtcNow + TimeSpan.FromDays(1) + }; - Assert.False(token.TokenIsValid(null)); - } + Assert.False(token.Valid); + } - [Theory, AutoData] - public void CanUpdateExpirationToNonStandard(User user) - { - var token = new HCaptchaTokenable(user) + [Theory, BitAutoData] + public void TokenValidityCheckNullUserIdIsInvalid(User user) { - ExpirationDate = DateTime.MinValue - }; + var token = new HCaptchaTokenable(user) + { + ExpirationDate = DateTime.UtcNow + TimeSpan.FromDays(1) + }; - Assert.Equal(DateTime.MinValue, token.ExpirationDate, TimeSpan.FromMilliseconds(10)); - } + Assert.False(token.TokenIsValid(null)); + } - [Theory, AutoData] - public void SetsDataFromUser(User user) - { - var token = new HCaptchaTokenable(user); - - Assert.Equal(user.Id, token.Id); - Assert.Equal(user.Email, token.Email); - } - - [Theory, AutoData] - public void SerializationSetsCorrectDateTime(User user) - { - var expectedDateTime = DateTime.UtcNow.AddHours(-5); - var token = new HCaptchaTokenable(user) + [Theory, AutoData] + public void CanUpdateExpirationToNonStandard(User user) { - ExpirationDate = expectedDateTime - }; + var token = new HCaptchaTokenable(user) + { + ExpirationDate = DateTime.MinValue + }; - var result = Tokenable.FromToken(token.ToToken()); + Assert.Equal(DateTime.MinValue, token.ExpirationDate, TimeSpan.FromMilliseconds(10)); + } - Assert.Equal(expectedDateTime, result.ExpirationDate, TimeSpan.FromMilliseconds(10)); - } - - [Theory, AutoData] - public void IsInvalidIfIdentifierIsWrong(User user) - { - var token = new HCaptchaTokenable(user) + [Theory, AutoData] + public void SetsDataFromUser(User user) { - Identifier = "not correct" - }; + var token = new HCaptchaTokenable(user); - Assert.False(token.Valid); + Assert.Equal(user.Id, token.Id); + Assert.Equal(user.Email, token.Email); + } + + [Theory, AutoData] + public void SerializationSetsCorrectDateTime(User user) + { + var expectedDateTime = DateTime.UtcNow.AddHours(-5); + var token = new HCaptchaTokenable(user) + { + ExpirationDate = expectedDateTime + }; + + var result = Tokenable.FromToken(token.ToToken()); + + Assert.Equal(expectedDateTime, result.ExpirationDate, TimeSpan.FromMilliseconds(10)); + } + + [Theory, AutoData] + public void IsInvalidIfIdentifierIsWrong(User user) + { + var token = new HCaptchaTokenable(user) + { + Identifier = "not correct" + }; + + Assert.False(token.Valid); + } } } diff --git a/test/Core.Test/Models/Business/Tokenables/OrganizationSponsorshipOfferTokenableTests.cs b/test/Core.Test/Models/Business/Tokenables/OrganizationSponsorshipOfferTokenableTests.cs index 172d4c911b..fd39c196b6 100644 --- a/test/Core.Test/Models/Business/Tokenables/OrganizationSponsorshipOfferTokenableTests.cs +++ b/test/Core.Test/Models/Business/Tokenables/OrganizationSponsorshipOfferTokenableTests.cs @@ -4,152 +4,153 @@ using Bit.Core.Models.Business.Tokenables; using Bit.Test.Common.AutoFixture.Attributes; using Xunit; -namespace Bit.Core.Test.Models.Business.Tokenables; - -public class OrganizationSponsorshipOfferTokenableTests +namespace Bit.Core.Test.Models.Business.Tokenables { - public static IEnumerable PlanSponsorshipTypes() => Enum.GetValues().Select(x => new object[] { x }); - - [Fact] - public void IsInvalidIfIdentifierIsWrong() + public class OrganizationSponsorshipOfferTokenableTests { - var token = new OrganizationSponsorshipOfferTokenable() + public static IEnumerable PlanSponsorshipTypes() => Enum.GetValues().Select(x => new object[] { x }); + + [Fact] + public void IsInvalidIfIdentifierIsWrong() { - Email = "email", - Id = Guid.NewGuid(), - Identifier = "not correct", - SponsorshipType = PlanSponsorshipType.FamiliesForEnterprise, - }; + var token = new OrganizationSponsorshipOfferTokenable() + { + Email = "email", + Id = Guid.NewGuid(), + Identifier = "not correct", + SponsorshipType = PlanSponsorshipType.FamiliesForEnterprise, + }; - Assert.False(token.Valid); - } + Assert.False(token.Valid); + } - [Fact] - public void IsInvalidIfIdIsDefault() - { - var token = new OrganizationSponsorshipOfferTokenable() + [Fact] + public void IsInvalidIfIdIsDefault() { - Email = "email", - Id = default, - SponsorshipType = PlanSponsorshipType.FamiliesForEnterprise, - }; + var token = new OrganizationSponsorshipOfferTokenable() + { + Email = "email", + Id = default, + SponsorshipType = PlanSponsorshipType.FamiliesForEnterprise, + }; - Assert.False(token.Valid); - } + Assert.False(token.Valid); + } - [Fact] - public void IsInvalidIfEmailIsEmpty() - { - var token = new OrganizationSponsorshipOfferTokenable() + [Fact] + public void IsInvalidIfEmailIsEmpty() { - Email = "", - Id = Guid.NewGuid(), - SponsorshipType = PlanSponsorshipType.FamiliesForEnterprise, - }; + var token = new OrganizationSponsorshipOfferTokenable() + { + Email = "", + Id = Guid.NewGuid(), + SponsorshipType = PlanSponsorshipType.FamiliesForEnterprise, + }; - Assert.False(token.Valid); - } + Assert.False(token.Valid); + } - [Theory, BitAutoData] - public void IsValid_Success(OrganizationSponsorship sponsorship) - { - var token = new OrganizationSponsorshipOfferTokenable(sponsorship); + [Theory, BitAutoData] + public void IsValid_Success(OrganizationSponsorship sponsorship) + { + var token = new OrganizationSponsorshipOfferTokenable(sponsorship); - Assert.True(token.IsValid(sponsorship, sponsorship.OfferedToEmail)); - } + Assert.True(token.IsValid(sponsorship, sponsorship.OfferedToEmail)); + } - [Theory, BitAutoData] - public void IsValid_RequiresNonNullSponsorship(OrganizationSponsorship sponsorship) - { - var token = new OrganizationSponsorshipOfferTokenable(sponsorship); + [Theory, BitAutoData] + public void IsValid_RequiresNonNullSponsorship(OrganizationSponsorship sponsorship) + { + var token = new OrganizationSponsorshipOfferTokenable(sponsorship); - Assert.False(token.IsValid(null, sponsorship.OfferedToEmail)); - } + Assert.False(token.IsValid(null, sponsorship.OfferedToEmail)); + } - [Theory, BitAutoData] - public void IsValid_RequiresCurrentEmailToBeSameAsOfferedToEmail(OrganizationSponsorship sponsorship, string currentEmail) - { - var token = new OrganizationSponsorshipOfferTokenable(sponsorship); + [Theory, BitAutoData] + public void IsValid_RequiresCurrentEmailToBeSameAsOfferedToEmail(OrganizationSponsorship sponsorship, string currentEmail) + { + var token = new OrganizationSponsorshipOfferTokenable(sponsorship); - Assert.False(token.IsValid(sponsorship, currentEmail)); - } + Assert.False(token.IsValid(sponsorship, currentEmail)); + } - [Theory, BitAutoData] - public void IsValid_RequiresSameSponsorshipId(OrganizationSponsorship sponsorship1, OrganizationSponsorship sponsorship2) - { - sponsorship1.Id = sponsorship2.Id; + [Theory, BitAutoData] + public void IsValid_RequiresSameSponsorshipId(OrganizationSponsorship sponsorship1, OrganizationSponsorship sponsorship2) + { + sponsorship1.Id = sponsorship2.Id; - var token = new OrganizationSponsorshipOfferTokenable(sponsorship1); + var token = new OrganizationSponsorshipOfferTokenable(sponsorship1); - Assert.False(token.IsValid(sponsorship2, sponsorship1.OfferedToEmail)); - } + Assert.False(token.IsValid(sponsorship2, sponsorship1.OfferedToEmail)); + } - [Theory, BitAutoData] - public void IsValid_RequiresSameEmail(OrganizationSponsorship sponsorship1, OrganizationSponsorship sponsorship2) - { - sponsorship1.OfferedToEmail = sponsorship2.OfferedToEmail; + [Theory, BitAutoData] + public void IsValid_RequiresSameEmail(OrganizationSponsorship sponsorship1, OrganizationSponsorship sponsorship2) + { + sponsorship1.OfferedToEmail = sponsorship2.OfferedToEmail; - var token = new OrganizationSponsorshipOfferTokenable(sponsorship1); + var token = new OrganizationSponsorshipOfferTokenable(sponsorship1); - Assert.False(token.IsValid(sponsorship2, sponsorship1.OfferedToEmail)); - } + Assert.False(token.IsValid(sponsorship2, sponsorship1.OfferedToEmail)); + } - [Theory, BitAutoData] - public void Constructor_GrabsIdFromSponsorship(OrganizationSponsorship sponsorship) - { - var token = new OrganizationSponsorshipOfferTokenable(sponsorship); + [Theory, BitAutoData] + public void Constructor_GrabsIdFromSponsorship(OrganizationSponsorship sponsorship) + { + var token = new OrganizationSponsorshipOfferTokenable(sponsorship); - Assert.Equal(sponsorship.Id, token.Id); - } + Assert.Equal(sponsorship.Id, token.Id); + } - [Theory, BitAutoData] - public void Constructor_GrabsEmailFromSponsorshipOfferedToEmail(OrganizationSponsorship sponsorship) - { - var token = new OrganizationSponsorshipOfferTokenable(sponsorship); + [Theory, BitAutoData] + public void Constructor_GrabsEmailFromSponsorshipOfferedToEmail(OrganizationSponsorship sponsorship) + { + var token = new OrganizationSponsorshipOfferTokenable(sponsorship); - Assert.Equal(sponsorship.OfferedToEmail, token.Email); - } + Assert.Equal(sponsorship.OfferedToEmail, token.Email); + } - [Theory, BitMemberAutoData(nameof(PlanSponsorshipTypes))] - public void Constructor_GrabsSponsorshipType(PlanSponsorshipType planSponsorshipType, - OrganizationSponsorship sponsorship) - { - sponsorship.PlanSponsorshipType = planSponsorshipType; - var token = new OrganizationSponsorshipOfferTokenable(sponsorship); + [Theory, BitMemberAutoData(nameof(PlanSponsorshipTypes))] + public void Constructor_GrabsSponsorshipType(PlanSponsorshipType planSponsorshipType, + OrganizationSponsorship sponsorship) + { + sponsorship.PlanSponsorshipType = planSponsorshipType; + var token = new OrganizationSponsorshipOfferTokenable(sponsorship); - Assert.Equal(sponsorship.PlanSponsorshipType, token.SponsorshipType); - } + Assert.Equal(sponsorship.PlanSponsorshipType, token.SponsorshipType); + } - [Theory, BitAutoData] - public void Constructor_DefaultId_Throws(OrganizationSponsorship sponsorship) - { - sponsorship.Id = default; + [Theory, BitAutoData] + public void Constructor_DefaultId_Throws(OrganizationSponsorship sponsorship) + { + sponsorship.Id = default; - Assert.Throws(() => new OrganizationSponsorshipOfferTokenable(sponsorship)); - } + Assert.Throws(() => new OrganizationSponsorshipOfferTokenable(sponsorship)); + } - [Theory, BitAutoData] - public void Constructor_NoOfferedToEmail_Throws(OrganizationSponsorship sponsorship) - { - sponsorship.OfferedToEmail = null; + [Theory, BitAutoData] + public void Constructor_NoOfferedToEmail_Throws(OrganizationSponsorship sponsorship) + { + sponsorship.OfferedToEmail = null; - Assert.Throws(() => new OrganizationSponsorshipOfferTokenable(sponsorship)); - } + Assert.Throws(() => new OrganizationSponsorshipOfferTokenable(sponsorship)); + } - [Theory, BitAutoData] - public void Constructor_EmptyOfferedToEmail_Throws(OrganizationSponsorship sponsorship) - { - sponsorship.OfferedToEmail = ""; + [Theory, BitAutoData] + public void Constructor_EmptyOfferedToEmail_Throws(OrganizationSponsorship sponsorship) + { + sponsorship.OfferedToEmail = ""; - Assert.Throws(() => new OrganizationSponsorshipOfferTokenable(sponsorship)); - } + Assert.Throws(() => new OrganizationSponsorshipOfferTokenable(sponsorship)); + } - [Theory, BitAutoData] - public void Constructor_NoPlanSponsorshipType_Throws(OrganizationSponsorship sponsorship) - { - sponsorship.PlanSponsorshipType = null; + [Theory, BitAutoData] + public void Constructor_NoPlanSponsorshipType_Throws(OrganizationSponsorship sponsorship) + { + sponsorship.PlanSponsorshipType = null; - Assert.Throws(() => new OrganizationSponsorshipOfferTokenable(sponsorship)); + Assert.Throws(() => new OrganizationSponsorshipOfferTokenable(sponsorship)); + } } } diff --git a/test/Core.Test/Models/Business/Tokenables/SsoTokenableTests.cs b/test/Core.Test/Models/Business/Tokenables/SsoTokenableTests.cs index 0ec4b5a353..aef71e5bae 100644 --- a/test/Core.Test/Models/Business/Tokenables/SsoTokenableTests.cs +++ b/test/Core.Test/Models/Business/Tokenables/SsoTokenableTests.cs @@ -5,84 +5,85 @@ using Bit.Core.Tokens; using Bit.Test.Common.AutoFixture.Attributes; using Xunit; -namespace Bit.Core.Test.Models.Business.Tokenables; - -public class SsoTokenableTests +namespace Bit.Core.Test.Models.Business.Tokenables { - [Fact] - public void CanHandleNullOrganization() + public class SsoTokenableTests { - var token = new SsoTokenable(null, default); - - Assert.Equal(default, token.OrganizationId); - Assert.Equal(default, token.DomainHint); - } - - [Fact] - public void TokenWithNullOrganizationIsInvalid() - { - var token = new SsoTokenable(null, 500) + [Fact] + public void CanHandleNullOrganization() { - ExpirationDate = DateTime.UtcNow + TimeSpan.FromDays(1) - }; + var token = new SsoTokenable(null, default); - Assert.False(token.Valid); - } + Assert.Equal(default, token.OrganizationId); + Assert.Equal(default, token.DomainHint); + } - [Theory, BitAutoData] - public void TokenValidityCheckNullOrganizationIsInvalid(Organization organization) - { - var token = new SsoTokenable(organization, 500) + [Fact] + public void TokenWithNullOrganizationIsInvalid() { - ExpirationDate = DateTime.UtcNow + TimeSpan.FromDays(1) - }; + var token = new SsoTokenable(null, 500) + { + ExpirationDate = DateTime.UtcNow + TimeSpan.FromDays(1) + }; - Assert.False(token.TokenIsValid(null)); - } + Assert.False(token.Valid); + } - [Theory, AutoData] - public void SetsDataFromOrganization(Organization organization) - { - var token = new SsoTokenable(organization, default); - - Assert.Equal(organization.Id, token.OrganizationId); - Assert.Equal(organization.Identifier, token.DomainHint); - } - - [Fact] - public void SetsExpirationFromConstructor() - { - var expectedDateTime = DateTime.UtcNow.AddSeconds(500); - var token = new SsoTokenable(null, 500); - - Assert.Equal(expectedDateTime, token.ExpirationDate, TimeSpan.FromMilliseconds(10)); - } - - [Theory, AutoData] - public void SerializationSetsCorrectDateTime(Organization organization) - { - var expectedDateTime = DateTime.UtcNow.AddHours(-5); - var token = new SsoTokenable(organization, default) + [Theory, BitAutoData] + public void TokenValidityCheckNullOrganizationIsInvalid(Organization organization) { - ExpirationDate = expectedDateTime - }; + var token = new SsoTokenable(organization, 500) + { + ExpirationDate = DateTime.UtcNow + TimeSpan.FromDays(1) + }; - var result = Tokenable.FromToken(token.ToToken()); + Assert.False(token.TokenIsValid(null)); + } - Assert.Equal(expectedDateTime, result.ExpirationDate, TimeSpan.FromMilliseconds(10)); - } - - [Theory, AutoData] - public void TokenIsValidFailsWhenExpired(Organization organization) - { - var expectedDateTime = DateTime.UtcNow.AddHours(-5); - var token = new SsoTokenable(organization, default) + [Theory, AutoData] + public void SetsDataFromOrganization(Organization organization) { - ExpirationDate = expectedDateTime - }; + var token = new SsoTokenable(organization, default); - var result = token.TokenIsValid(organization); + Assert.Equal(organization.Id, token.OrganizationId); + Assert.Equal(organization.Identifier, token.DomainHint); + } - Assert.False(result); + [Fact] + public void SetsExpirationFromConstructor() + { + var expectedDateTime = DateTime.UtcNow.AddSeconds(500); + var token = new SsoTokenable(null, 500); + + Assert.Equal(expectedDateTime, token.ExpirationDate, TimeSpan.FromMilliseconds(10)); + } + + [Theory, AutoData] + public void SerializationSetsCorrectDateTime(Organization organization) + { + var expectedDateTime = DateTime.UtcNow.AddHours(-5); + var token = new SsoTokenable(organization, default) + { + ExpirationDate = expectedDateTime + }; + + var result = Tokenable.FromToken(token.ToToken()); + + Assert.Equal(expectedDateTime, result.ExpirationDate, TimeSpan.FromMilliseconds(10)); + } + + [Theory, AutoData] + public void TokenIsValidFailsWhenExpired(Organization organization) + { + var expectedDateTime = DateTime.UtcNow.AddHours(-5); + var token = new SsoTokenable(organization, default) + { + ExpirationDate = expectedDateTime + }; + + var result = token.TokenIsValid(organization); + + Assert.False(result); + } } } diff --git a/test/Core.Test/Models/CipherTests.cs b/test/Core.Test/Models/CipherTests.cs index af7a0b6e3c..3993f4caf6 100644 --- a/test/Core.Test/Models/CipherTests.cs +++ b/test/Core.Test/Models/CipherTests.cs @@ -3,15 +3,16 @@ using Bit.Core.Entities; using Bit.Core.Test.AutoFixture.CipherFixtures; using Xunit; -namespace Bit.Core.Test.Models; - -public class CipherTests +namespace Bit.Core.Test.Models { - [Theory] - [InlineUserCipherAutoData] - [InlineOrganizationCipherAutoData] - public void Clone_CreatesExactCopy(Cipher cipher) + public class CipherTests { - Assert.Equal(JsonSerializer.Serialize(cipher), JsonSerializer.Serialize(cipher.Clone())); + [Theory] + [InlineUserCipherAutoData] + [InlineOrganizationCipherAutoData] + public void Clone_CreatesExactCopy(Cipher cipher) + { + Assert.Equal(JsonSerializer.Serialize(cipher), JsonSerializer.Serialize(cipher.Clone())); + } } } diff --git a/test/Core.Test/Models/Data/SendFileDataTests.cs b/test/Core.Test/Models/Data/SendFileDataTests.cs index 6f2afe7483..7a7dc9bc5c 100644 --- a/test/Core.Test/Models/Data/SendFileDataTests.cs +++ b/test/Core.Test/Models/Data/SendFileDataTests.cs @@ -3,25 +3,26 @@ using Bit.Core.Models.Data; using Bit.Test.Common.Helpers; using Xunit; -namespace Bit.Core.Test.Models.Data; - -public class SendFileDataTests +namespace Bit.Core.Test.Models.Data { - [Fact] - public void Serialize_Success() + public class SendFileDataTests { - var sut = new SendFileData + [Fact] + public void Serialize_Success() { - Id = "test", - Size = 100, - FileName = "thing.pdf", - Validated = true, - }; + var sut = new SendFileData + { + Id = "test", + Size = 100, + FileName = "thing.pdf", + Validated = true, + }; - var json = JsonSerializer.Serialize(sut); - var document = JsonDocument.Parse(json); - var root = document.RootElement; - AssertHelper.AssertJsonProperty(root, "Size", JsonValueKind.String); - Assert.False(root.TryGetProperty("SizeString", out _)); + var json = JsonSerializer.Serialize(sut); + var document = JsonDocument.Parse(json); + var root = document.RootElement; + AssertHelper.AssertJsonProperty(root, "Size", JsonValueKind.String); + Assert.False(root.TryGetProperty("SizeString", out _)); + } } } diff --git a/test/Core.Test/Models/PermissionsTests.cs b/test/Core.Test/Models/PermissionsTests.cs index 76b88f6ff9..c8522eaa2a 100644 --- a/test/Core.Test/Models/PermissionsTests.cs +++ b/test/Core.Test/Models/PermissionsTests.cs @@ -3,58 +3,59 @@ using Bit.Core.Models.Data; using Bit.Core.Utilities; using Xunit; -namespace Bit.Core.Test.Models; - -public class PermissionsTests +namespace Bit.Core.Test.Models { - private static readonly string _exampleSerializedPermissions = string.Concat( - "{", - "\"accessEventLogs\": false,", - "\"accessImportExport\": false,", - "\"accessReports\": false,", - "\"manageAllCollections\": true,", // exists for backwards compatibility - "\"createNewCollections\": true,", - "\"editAnyCollection\": true,", - "\"deleteAnyCollection\": true,", - "\"manageAssignedCollections\": false,", // exists for backwards compatibility - "\"editAssignedCollections\": false,", - "\"deleteAssignedCollections\": false,", - "\"manageGroups\": false,", - "\"managePolicies\": false,", - "\"manageSso\": false,", - "\"manageUsers\": false,", - "\"manageResetPassword\": false,", - "\"manageScim\": false", - "}"); - - [Fact] - public void Serialization_Success() + public class PermissionsTests { - var permissions = new Permissions + private static readonly string _exampleSerializedPermissions = string.Concat( + "{", + "\"accessEventLogs\": false,", + "\"accessImportExport\": false,", + "\"accessReports\": false,", + "\"manageAllCollections\": true,", // exists for backwards compatibility + "\"createNewCollections\": true,", + "\"editAnyCollection\": true,", + "\"deleteAnyCollection\": true,", + "\"manageAssignedCollections\": false,", // exists for backwards compatibility + "\"editAssignedCollections\": false,", + "\"deleteAssignedCollections\": false,", + "\"manageGroups\": false,", + "\"managePolicies\": false,", + "\"manageSso\": false,", + "\"manageUsers\": false,", + "\"manageResetPassword\": false,", + "\"manageScim\": false", + "}"); + + [Fact] + public void Serialization_Success() { - AccessEventLogs = false, - AccessImportExport = false, - AccessReports = false, - CreateNewCollections = true, - EditAnyCollection = true, - DeleteAnyCollection = true, - EditAssignedCollections = false, - DeleteAssignedCollections = false, - ManageGroups = false, - ManagePolicies = false, - ManageSso = false, - ManageUsers = false, - ManageResetPassword = false, - ManageScim = false, - }; + var permissions = new Permissions + { + AccessEventLogs = false, + AccessImportExport = false, + AccessReports = false, + CreateNewCollections = true, + EditAnyCollection = true, + DeleteAnyCollection = true, + EditAssignedCollections = false, + DeleteAssignedCollections = false, + ManageGroups = false, + ManagePolicies = false, + ManageSso = false, + ManageUsers = false, + ManageResetPassword = false, + ManageScim = false, + }; - // minify expected json - var expected = JsonSerializer.Serialize(permissions, JsonHelpers.CamelCase); + // minify expected json + var expected = JsonSerializer.Serialize(permissions, JsonHelpers.CamelCase); - var actual = JsonSerializer.Serialize( - JsonHelpers.DeserializeOrNew(_exampleSerializedPermissions, JsonHelpers.CamelCase), - JsonHelpers.CamelCase); + var actual = JsonSerializer.Serialize( + JsonHelpers.DeserializeOrNew(_exampleSerializedPermissions, JsonHelpers.CamelCase), + JsonHelpers.CamelCase); - Assert.Equal(expected, actual); + Assert.Equal(expected, actual); + } } } diff --git a/test/Core.Test/OrganizationFeatures/OrganizationApiKeys/GetOrganizationApiKeyCommandTests.cs b/test/Core.Test/OrganizationFeatures/OrganizationApiKeys/GetOrganizationApiKeyCommandTests.cs index de568f2656..e81d2bcc8a 100644 --- a/test/Core.Test/OrganizationFeatures/OrganizationApiKeys/GetOrganizationApiKeyCommandTests.cs +++ b/test/Core.Test/OrganizationFeatures/OrganizationApiKeys/GetOrganizationApiKeyCommandTests.cs @@ -7,93 +7,94 @@ using Bit.Test.Common.AutoFixture.Attributes; using NSubstitute; using Xunit; -namespace Bit.Core.Test.OrganizationFeatures.OrganizationApiKeys; - -[SutProviderCustomize] -public class GetOrganizationApiKeyCommandTests +namespace Bit.Core.Test.OrganizationFeatures.OrganizationApiKeys { - [Theory] - [BitAutoData] - public async Task GetOrganizationApiKey_HasOne_Returns(SutProvider sutProvider, - Guid id, Guid organizationId, OrganizationApiKeyType keyType) + [SutProviderCustomize] + public class GetOrganizationApiKeyCommandTests { - sutProvider.GetDependency() - .GetManyByOrganizationIdTypeAsync(organizationId, keyType) - .Returns(new List - { - new OrganizationApiKey + [Theory] + [BitAutoData] + public async Task GetOrganizationApiKey_HasOne_Returns(SutProvider sutProvider, + Guid id, Guid organizationId, OrganizationApiKeyType keyType) + { + sutProvider.GetDependency() + .GetManyByOrganizationIdTypeAsync(organizationId, keyType) + .Returns(new List { - Id = id, - OrganizationId = organizationId, - ApiKey = "test", - Type = keyType, - RevisionDate = DateTime.Now.AddDays(-1), - }, - }); + new OrganizationApiKey + { + Id = id, + OrganizationId = organizationId, + ApiKey = "test", + Type = keyType, + RevisionDate = DateTime.Now.AddDays(-1), + }, + }); - var apiKey = await sutProvider.Sut.GetOrganizationApiKeyAsync(organizationId, keyType); - Assert.NotNull(apiKey); - Assert.Equal(id, apiKey.Id); - } + var apiKey = await sutProvider.Sut.GetOrganizationApiKeyAsync(organizationId, keyType); + Assert.NotNull(apiKey); + Assert.Equal(id, apiKey.Id); + } - [Theory] - [BitAutoData] - public async Task GetOrganizationApiKey_HasTwo_Throws(SutProvider sutProvider, - Guid organizationId, OrganizationApiKeyType keyType) - { - sutProvider.GetDependency() - .GetManyByOrganizationIdTypeAsync(organizationId, keyType) - .Returns(new List - { - new OrganizationApiKey + [Theory] + [BitAutoData] + public async Task GetOrganizationApiKey_HasTwo_Throws(SutProvider sutProvider, + Guid organizationId, OrganizationApiKeyType keyType) + { + sutProvider.GetDependency() + .GetManyByOrganizationIdTypeAsync(organizationId, keyType) + .Returns(new List { - Id = Guid.NewGuid(), - OrganizationId = organizationId, - ApiKey = "test", - Type = keyType, - RevisionDate = DateTime.Now.AddDays(-1), - }, - new OrganizationApiKey - { - Id = Guid.NewGuid(), - OrganizationId = organizationId, - ApiKey = "test_other", - Type = keyType, - RevisionDate = DateTime.Now.AddDays(-1), - }, - }); + new OrganizationApiKey + { + Id = Guid.NewGuid(), + OrganizationId = organizationId, + ApiKey = "test", + Type = keyType, + RevisionDate = DateTime.Now.AddDays(-1), + }, + new OrganizationApiKey + { + Id = Guid.NewGuid(), + OrganizationId = organizationId, + ApiKey = "test_other", + Type = keyType, + RevisionDate = DateTime.Now.AddDays(-1), + }, + }); - await Assert.ThrowsAsync( - async () => await sutProvider.Sut.GetOrganizationApiKeyAsync(organizationId, keyType)); - } + await Assert.ThrowsAsync( + async () => await sutProvider.Sut.GetOrganizationApiKeyAsync(organizationId, keyType)); + } - [Theory] - [BitAutoData] - public async Task GetOrganizationApiKey_HasNone_CreatesAndReturns(SutProvider sutProvider, - Guid organizationId, OrganizationApiKeyType keyType) - { - sutProvider.GetDependency() - .GetManyByOrganizationIdTypeAsync(organizationId, keyType) - .Returns(Enumerable.Empty()); + [Theory] + [BitAutoData] + public async Task GetOrganizationApiKey_HasNone_CreatesAndReturns(SutProvider sutProvider, + Guid organizationId, OrganizationApiKeyType keyType) + { + sutProvider.GetDependency() + .GetManyByOrganizationIdTypeAsync(organizationId, keyType) + .Returns(Enumerable.Empty()); - var apiKey = await sutProvider.Sut.GetOrganizationApiKeyAsync(organizationId, keyType); + var apiKey = await sutProvider.Sut.GetOrganizationApiKeyAsync(organizationId, keyType); - Assert.NotNull(apiKey); - Assert.Equal(organizationId, apiKey.OrganizationId); - Assert.Equal(keyType, apiKey.Type); - await sutProvider.GetDependency() - .Received(1) - .CreateAsync(Arg.Any()); - } + Assert.NotNull(apiKey); + Assert.Equal(organizationId, apiKey.OrganizationId); + Assert.Equal(keyType, apiKey.Type); + await sutProvider.GetDependency() + .Received(1) + .CreateAsync(Arg.Any()); + } - [Theory] - [BitAutoData] - public async Task GetOrganizationApiKey_BadType_Throws(SutProvider sutProvider, - Guid organizationId, OrganizationApiKeyType keyType) - { - keyType = (OrganizationApiKeyType)byte.MaxValue; + [Theory] + [BitAutoData] + public async Task GetOrganizationApiKey_BadType_Throws(SutProvider sutProvider, + Guid organizationId, OrganizationApiKeyType keyType) + { + keyType = (OrganizationApiKeyType)byte.MaxValue; - await Assert.ThrowsAsync( - async () => await sutProvider.Sut.GetOrganizationApiKeyAsync(organizationId, keyType)); + await Assert.ThrowsAsync( + async () => await sutProvider.Sut.GetOrganizationApiKeyAsync(organizationId, keyType)); + } } } diff --git a/test/Core.Test/OrganizationFeatures/OrganizationApiKeys/RotateOrganizationApiKeyCommandTests.cs b/test/Core.Test/OrganizationFeatures/OrganizationApiKeys/RotateOrganizationApiKeyCommandTests.cs index dc2ec10c2d..5bea4b8d21 100644 --- a/test/Core.Test/OrganizationFeatures/OrganizationApiKeys/RotateOrganizationApiKeyCommandTests.cs +++ b/test/Core.Test/OrganizationFeatures/OrganizationApiKeys/RotateOrganizationApiKeyCommandTests.cs @@ -5,18 +5,19 @@ using Bit.Test.Common.AutoFixture.Attributes; using Bit.Test.Common.Helpers; using Xunit; -namespace Bit.Core.Test.OrganizationFeatures.OrganizationApiKeys; - -[SutProviderCustomize] -public class RotateOrganizationApiKeyCommandTests +namespace Bit.Core.Test.OrganizationFeatures.OrganizationApiKeys { - [Theory, BitAutoData] - public async Task RotateApiKeyAsync_RotatesKey(SutProvider sutProvider, - OrganizationApiKey organizationApiKey) + [SutProviderCustomize] + public class RotateOrganizationApiKeyCommandTests { - var existingKey = organizationApiKey.ApiKey; - organizationApiKey = await sutProvider.Sut.RotateApiKeyAsync(organizationApiKey); - Assert.NotEqual(existingKey, organizationApiKey.ApiKey); - AssertHelper.AssertRecent(organizationApiKey.RevisionDate); + [Theory, BitAutoData] + public async Task RotateApiKeyAsync_RotatesKey(SutProvider sutProvider, + OrganizationApiKey organizationApiKey) + { + var existingKey = organizationApiKey.ApiKey; + organizationApiKey = await sutProvider.Sut.RotateApiKeyAsync(organizationApiKey); + Assert.NotEqual(existingKey, organizationApiKey.ApiKey); + AssertHelper.AssertRecent(organizationApiKey.RevisionDate); + } } } diff --git a/test/Core.Test/OrganizationFeatures/OrganizationCollections/CreateOrganizationConnectionCommandTests.cs b/test/Core.Test/OrganizationFeatures/OrganizationCollections/CreateOrganizationConnectionCommandTests.cs index bfcb532d8c..c46a7e7065 100644 --- a/test/Core.Test/OrganizationFeatures/OrganizationCollections/CreateOrganizationConnectionCommandTests.cs +++ b/test/Core.Test/OrganizationFeatures/OrganizationCollections/CreateOrganizationConnectionCommandTests.cs @@ -8,19 +8,20 @@ using Bit.Test.Common.Helpers; using NSubstitute; using Xunit; -namespace Bit.Core.Test.OrganizationFeatures.OrganizationConnections; - -[SutProviderCustomize] -public class CreateOrganizationConnectionCommandTests +namespace Bit.Core.Test.OrganizationFeatures.OrganizationConnections { - [Theory] - [BitAutoData] - public async Task CreateAsync_CallsCreate(OrganizationConnectionData data, - SutProvider sutProvider) + [SutProviderCustomize] + public class CreateOrganizationConnectionCommandTests { - await sutProvider.Sut.CreateAsync(data); + [Theory] + [BitAutoData] + public async Task CreateAsync_CallsCreate(OrganizationConnectionData data, + SutProvider sutProvider) + { + await sutProvider.Sut.CreateAsync(data); - await sutProvider.GetDependency().Received(1) - .CreateAsync(Arg.Is(AssertHelper.AssertPropertyEqual(data.ToEntity()))); + await sutProvider.GetDependency().Received(1) + .CreateAsync(Arg.Is(AssertHelper.AssertPropertyEqual(data.ToEntity()))); + } } } diff --git a/test/Core.Test/OrganizationFeatures/OrganizationCollections/DeleteOrganizationConnectionCommandTests.cs b/test/Core.Test/OrganizationFeatures/OrganizationCollections/DeleteOrganizationConnectionCommandTests.cs index 9432968f57..5a6690cd72 100644 --- a/test/Core.Test/OrganizationFeatures/OrganizationCollections/DeleteOrganizationConnectionCommandTests.cs +++ b/test/Core.Test/OrganizationFeatures/OrganizationCollections/DeleteOrganizationConnectionCommandTests.cs @@ -6,19 +6,20 @@ using Bit.Test.Common.AutoFixture.Attributes; using NSubstitute; using Xunit; -namespace Bit.Core.Test.OrganizationFeatures.OrganizationConnections; - -[SutProviderCustomize] -public class DeleteOrganizationConnectionCommandTests +namespace Bit.Core.Test.OrganizationFeatures.OrganizationConnections { - [Theory] - [BitAutoData] - public async Task DeleteAsync_CallsDelete(OrganizationConnection connection, - SutProvider sutProvider) + [SutProviderCustomize] + public class DeleteOrganizationConnectionCommandTests { - await sutProvider.Sut.DeleteAsync(connection); + [Theory] + [BitAutoData] + public async Task DeleteAsync_CallsDelete(OrganizationConnection connection, + SutProvider sutProvider) + { + await sutProvider.Sut.DeleteAsync(connection); - await sutProvider.GetDependency().Received(1) - .DeleteAsync(connection); + await sutProvider.GetDependency().Received(1) + .DeleteAsync(connection); + } } } diff --git a/test/Core.Test/OrganizationFeatures/OrganizationCollections/UpdateOrganizationConnectionCommandTests.cs b/test/Core.Test/OrganizationFeatures/OrganizationCollections/UpdateOrganizationConnectionCommandTests.cs index f46d799d16..dba643214e 100644 --- a/test/Core.Test/OrganizationFeatures/OrganizationCollections/UpdateOrganizationConnectionCommandTests.cs +++ b/test/Core.Test/OrganizationFeatures/OrganizationCollections/UpdateOrganizationConnectionCommandTests.cs @@ -10,49 +10,50 @@ using Bit.Test.Common.Helpers; using NSubstitute; using Xunit; -namespace Bit.Core.Test.OrganizationFeatures.OrganizationConnections; - -[SutProviderCustomize] -public class UpdateOrganizationConnectionCommandTests +namespace Bit.Core.Test.OrganizationFeatures.OrganizationConnections { - [Theory] - [BitAutoData] - public async Task UpdateAsync_NoId_Fails(OrganizationConnectionData data, - SutProvider sutProvider) + [SutProviderCustomize] + public class UpdateOrganizationConnectionCommandTests { - data.Id = null; + [Theory] + [BitAutoData] + public async Task UpdateAsync_NoId_Fails(OrganizationConnectionData data, + SutProvider sutProvider) + { + data.Id = null; - var exception = await Assert.ThrowsAsync(() => sutProvider.Sut.UpdateAsync(data)); + var exception = await Assert.ThrowsAsync(() => sutProvider.Sut.UpdateAsync(data)); - Assert.Contains("Cannot update connection, Connection does not exist.", exception.Message); - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() - .UpsertAsync(default); - } + Assert.Contains("Cannot update connection, Connection does not exist.", exception.Message); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() + .UpsertAsync(default); + } - [Theory] - [BitAutoData] - public async Task UpdateAsync_ConnectionDoesNotExist_ThrowsNotFound( - OrganizationConnectionData data, - SutProvider sutProvider) - { - var exception = await Assert.ThrowsAsync(() => sutProvider.Sut.UpdateAsync(data)); + [Theory] + [BitAutoData] + public async Task UpdateAsync_ConnectionDoesNotExist_ThrowsNotFound( + OrganizationConnectionData data, + SutProvider sutProvider) + { + var exception = await Assert.ThrowsAsync(() => sutProvider.Sut.UpdateAsync(data)); - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() - .UpsertAsync(default); - } + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() + .UpsertAsync(default); + } - [Theory] - [BitAutoData] - public async Task UpdateAsync_CallsUpsert(OrganizationConnectionData data, - OrganizationConnection existing, - SutProvider sutProvider) - { - data.Id = existing.Id; + [Theory] + [BitAutoData] + public async Task UpdateAsync_CallsUpsert(OrganizationConnectionData data, + OrganizationConnection existing, + SutProvider sutProvider) + { + data.Id = existing.Id; - sutProvider.GetDependency().GetByIdAsync(data.Id.Value).Returns(existing); - await sutProvider.Sut.UpdateAsync(data); + sutProvider.GetDependency().GetByIdAsync(data.Id.Value).Returns(existing); + await sutProvider.Sut.UpdateAsync(data); - await sutProvider.GetDependency().Received(1) - .UpsertAsync(Arg.Is(AssertHelper.AssertPropertyEqual(data.ToEntity()))); + await sutProvider.GetDependency().Received(1) + .UpsertAsync(Arg.Is(AssertHelper.AssertPropertyEqual(data.ToEntity()))); + } } } diff --git a/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/CancelSponsorshipCommandTestsBase.cs b/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/CancelSponsorshipCommandTestsBase.cs index ca684a30ce..8823957218 100644 --- a/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/CancelSponsorshipCommandTestsBase.cs +++ b/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/CancelSponsorshipCommandTestsBase.cs @@ -4,70 +4,71 @@ using Bit.Core.Services; using Bit.Test.Common.AutoFixture; using NSubstitute; -namespace Bit.Core.Test.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise; - -public abstract class CancelSponsorshipCommandTestsBase : FamiliesForEnterpriseTestsBase +namespace Bit.Core.Test.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise { - protected async Task AssertRemovedSponsoredPaymentAsync(Organization sponsoredOrg, -OrganizationSponsorship sponsorship, SutProvider sutProvider) + public abstract class CancelSponsorshipCommandTestsBase : FamiliesForEnterpriseTestsBase { - await sutProvider.GetDependency().Received(1) - .RemoveOrganizationSponsorshipAsync(sponsoredOrg, sponsorship); - await sutProvider.GetDependency().Received(1).UpsertAsync(sponsoredOrg); - if (sponsorship != null) + protected async Task AssertRemovedSponsoredPaymentAsync(Organization sponsoredOrg, + OrganizationSponsorship sponsorship, SutProvider sutProvider) { - await sutProvider.GetDependency().Received(1) - .SendFamiliesForEnterpriseSponsorshipRevertingEmailAsync(sponsoredOrg.BillingEmailAddress(), sponsorship.ValidUntil.GetValueOrDefault()); + await sutProvider.GetDependency().Received(1) + .RemoveOrganizationSponsorshipAsync(sponsoredOrg, sponsorship); + await sutProvider.GetDependency().Received(1).UpsertAsync(sponsoredOrg); + if (sponsorship != null) + { + await sutProvider.GetDependency().Received(1) + .SendFamiliesForEnterpriseSponsorshipRevertingEmailAsync(sponsoredOrg.BillingEmailAddress(), sponsorship.ValidUntil.GetValueOrDefault()); + } + } + + protected async Task AssertDeletedSponsorshipAsync(OrganizationSponsorship sponsorship, + SutProvider sutProvider) + { + await sutProvider.GetDependency().Received(1) + .DeleteAsync(sponsorship); + } + + protected static async Task AssertDidNotRemoveSponsorshipAsync(SutProvider sutProvider) + { + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() + .DeleteAsync(default); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() + .UpsertAsync(default); + } + + protected async Task AssertRemovedSponsorshipAsync(OrganizationSponsorship sponsorship, + SutProvider sutProvider) + { + await sutProvider.GetDependency().Received(1) + .DeleteAsync(sponsorship); + } + + protected static async Task AssertDidNotRemoveSponsoredPaymentAsync(SutProvider sutProvider) + { + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() + .RemoveOrganizationSponsorshipAsync(default, default); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() + .UpsertAsync(default); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() + .SendFamiliesForEnterpriseSponsorshipRevertingEmailAsync(default, default); + } + + protected static async Task AssertDidNotDeleteSponsorshipAsync(SutProvider sutProvider) + { + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() + .DeleteAsync(default); + } + + protected static async Task AssertDidNotUpdateSponsorshipAsync(SutProvider sutProvider) + { + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() + .UpsertAsync(default); + } + + protected static async Task AssertUpdatedSponsorshipAsync(OrganizationSponsorship sponsorship, + SutProvider sutProvider) + { + await sutProvider.GetDependency().Received(1).UpsertAsync(sponsorship); } } - - protected async Task AssertDeletedSponsorshipAsync(OrganizationSponsorship sponsorship, - SutProvider sutProvider) - { - await sutProvider.GetDependency().Received(1) - .DeleteAsync(sponsorship); - } - - protected static async Task AssertDidNotRemoveSponsorshipAsync(SutProvider sutProvider) - { - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() - .DeleteAsync(default); - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() - .UpsertAsync(default); - } - - protected async Task AssertRemovedSponsorshipAsync(OrganizationSponsorship sponsorship, - SutProvider sutProvider) - { - await sutProvider.GetDependency().Received(1) - .DeleteAsync(sponsorship); - } - - protected static async Task AssertDidNotRemoveSponsoredPaymentAsync(SutProvider sutProvider) - { - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() - .RemoveOrganizationSponsorshipAsync(default, default); - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() - .UpsertAsync(default); - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() - .SendFamiliesForEnterpriseSponsorshipRevertingEmailAsync(default, default); - } - - protected static async Task AssertDidNotDeleteSponsorshipAsync(SutProvider sutProvider) - { - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() - .DeleteAsync(default); - } - - protected static async Task AssertDidNotUpdateSponsorshipAsync(SutProvider sutProvider) - { - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() - .UpsertAsync(default); - } - - protected static async Task AssertUpdatedSponsorshipAsync(OrganizationSponsorship sponsorship, - SutProvider sutProvider) - { - await sutProvider.GetDependency().Received(1).UpsertAsync(sponsorship); - } } diff --git a/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/CloudRevokeSponsorshipCommandTests.cs b/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/CloudRevokeSponsorshipCommandTests.cs index 2b9a27c1ab..f0c7e976ad 100644 --- a/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/CloudRevokeSponsorshipCommandTests.cs +++ b/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/CloudRevokeSponsorshipCommandTests.cs @@ -6,45 +6,46 @@ using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; using Xunit; -namespace Bit.Core.Test.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Cloud; - -[SutProviderCustomize] -[OrganizationSponsorshipCustomize] -public class CloudRevokeSponsorshipCommandTests : CancelSponsorshipCommandTestsBase +namespace Bit.Core.Test.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Cloud { - [Theory] - [BitAutoData] - public async Task RevokeSponsorship_NoExistingSponsorship_ThrowsBadRequest( - SutProvider sutProvider) + [SutProviderCustomize] + [OrganizationSponsorshipCustomize] + public class CloudRevokeSponsorshipCommandTests : CancelSponsorshipCommandTestsBase { - var exception = await Assert.ThrowsAsync(() => - sutProvider.Sut.RevokeSponsorshipAsync(null)); + [Theory] + [BitAutoData] + public async Task RevokeSponsorship_NoExistingSponsorship_ThrowsBadRequest( + SutProvider sutProvider) + { + var exception = await Assert.ThrowsAsync(() => + sutProvider.Sut.RevokeSponsorshipAsync(null)); - Assert.Contains("You are not currently sponsoring an organization.", exception.Message); - await AssertDidNotDeleteSponsorshipAsync(sutProvider); - await AssertDidNotUpdateSponsorshipAsync(sutProvider); - } + Assert.Contains("You are not currently sponsoring an organization.", exception.Message); + await AssertDidNotDeleteSponsorshipAsync(sutProvider); + await AssertDidNotUpdateSponsorshipAsync(sutProvider); + } - [Theory] - [BitAutoData] - public async Task RevokeSponsorship_SponsorshipNotRedeemed_DeletesSponsorship(OrganizationSponsorship sponsorship, - SutProvider sutProvider) - { - sponsorship.SponsoredOrganizationId = null; + [Theory] + [BitAutoData] + public async Task RevokeSponsorship_SponsorshipNotRedeemed_DeletesSponsorship(OrganizationSponsorship sponsorship, + SutProvider sutProvider) + { + sponsorship.SponsoredOrganizationId = null; - await sutProvider.Sut.RevokeSponsorshipAsync(sponsorship); - await AssertDeletedSponsorshipAsync(sponsorship, sutProvider); - } + await sutProvider.Sut.RevokeSponsorshipAsync(sponsorship); + await AssertDeletedSponsorshipAsync(sponsorship, sutProvider); + } - [Theory] - [BitAutoData] - public async Task RevokeSponsorship_SponsorshipRedeemed_MarksForDelete(OrganizationSponsorship sponsorship, - SutProvider sutProvider) - { - await sutProvider.Sut.RevokeSponsorshipAsync(sponsorship); + [Theory] + [BitAutoData] + public async Task RevokeSponsorship_SponsorshipRedeemed_MarksForDelete(OrganizationSponsorship sponsorship, + SutProvider sutProvider) + { + await sutProvider.Sut.RevokeSponsorshipAsync(sponsorship); - Assert.True(sponsorship.ToDelete); - await AssertUpdatedSponsorshipAsync(sponsorship, sutProvider); - await AssertDidNotDeleteSponsorshipAsync(sutProvider); + Assert.True(sponsorship.ToDelete); + await AssertUpdatedSponsorshipAsync(sponsorship, sutProvider); + await AssertDidNotDeleteSponsorshipAsync(sutProvider); + } } } diff --git a/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/CloudSyncSponsorshipsCommandTests.cs b/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/CloudSyncSponsorshipsCommandTests.cs index f7534d8a73..3a55178148 100644 --- a/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/CloudSyncSponsorshipsCommandTests.cs +++ b/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/CloudSyncSponsorshipsCommandTests.cs @@ -10,216 +10,218 @@ using Bit.Test.Common.AutoFixture.Attributes; using NSubstitute; using Xunit; -namespace Bit.Core.Test.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Cloud; - -[SutProviderCustomize] -public class CloudSyncSponsorshipsCommandTests : FamiliesForEnterpriseTestsBase +namespace Bit.Core.Test.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Cloud { - [Theory] - [BitAutoData] - public async Task SyncOrganization_SponsoringOrgNotFound_ThrowsBadRequest( - IEnumerable sponsorshipsData, - SutProvider sutProvider) + [SutProviderCustomize] + public class CloudSyncSponsorshipsCommandTests : FamiliesForEnterpriseTestsBase { - var exception = await Assert.ThrowsAsync(() => - sutProvider.Sut.SyncOrganization(null, sponsorshipsData)); - Assert.Contains("Failed to sync sponsorship - missing organization.", exception.Message); + [Theory] + [BitAutoData] + public async Task SyncOrganization_SponsoringOrgNotFound_ThrowsBadRequest( + IEnumerable sponsorshipsData, + SutProvider sutProvider) + { + var exception = await Assert.ThrowsAsync(() => + sutProvider.Sut.SyncOrganization(null, sponsorshipsData)); - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .UpsertManyAsync(default); - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .DeleteManyAsync(default); - } + Assert.Contains("Failed to sync sponsorship - missing organization.", exception.Message); - [Theory] - [BitAutoData] - public async Task SyncOrganization_NoSponsorships_EarlyReturn( - Organization organization, - SutProvider sutProvider) - { - var result = await sutProvider.Sut.SyncOrganization(organization, Enumerable.Empty()); + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .UpsertManyAsync(default); + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .DeleteManyAsync(default); + } - Assert.Empty(result.Item1.SponsorshipsBatch); - Assert.Empty(result.Item2); + [Theory] + [BitAutoData] + public async Task SyncOrganization_NoSponsorships_EarlyReturn( + Organization organization, + SutProvider sutProvider) + { + var result = await sutProvider.Sut.SyncOrganization(organization, Enumerable.Empty()); - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .UpsertManyAsync(default); - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .DeleteManyAsync(default); - } + Assert.Empty(result.Item1.SponsorshipsBatch); + Assert.Empty(result.Item2); - [Theory] - [BitMemberAutoData(nameof(NonEnterprisePlanTypes))] - public async Task SyncOrganization_BadSponsoringOrgPlan_NoSync( - PlanType planType, - Organization organization, IEnumerable sponsorshipsData, - SutProvider sutProvider) - { - organization.PlanType = planType; + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .UpsertManyAsync(default); + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .DeleteManyAsync(default); + } - await sutProvider.Sut.SyncOrganization(organization, sponsorshipsData); + [Theory] + [BitMemberAutoData(nameof(NonEnterprisePlanTypes))] + public async Task SyncOrganization_BadSponsoringOrgPlan_NoSync( + PlanType planType, + Organization organization, IEnumerable sponsorshipsData, + SutProvider sutProvider) + { + organization.PlanType = planType; - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .UpsertManyAsync(default); - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .DeleteManyAsync(default); - } + await sutProvider.Sut.SyncOrganization(organization, sponsorshipsData); - [Theory] - [BitAutoData] - public async Task SyncOrganization_Success_RecordsEvent(Organization organization, - SutProvider sutProvider) - { - await sutProvider.Sut.SyncOrganization(organization, Array.Empty()); + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .UpsertManyAsync(default); + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .DeleteManyAsync(default); + } - await sutProvider.GetDependency().Received(1).LogOrganizationEventAsync(organization, EventType.Organization_SponsorshipsSynced, Arg.Any()); - } + [Theory] + [BitAutoData] + public async Task SyncOrganization_Success_RecordsEvent(Organization organization, + SutProvider sutProvider) + { + await sutProvider.Sut.SyncOrganization(organization, Array.Empty()); - [Theory] - [BitAutoData] - public async Task SyncOrganization_OneExisting_OneNew_Success(SutProvider sutProvider, - Organization sponsoringOrganization, OrganizationSponsorship existingSponsorship, OrganizationSponsorship newSponsorship) - { - // Arrange - sponsoringOrganization.Enabled = true; - sponsoringOrganization.PlanType = PlanType.EnterpriseAnnually; + await sutProvider.GetDependency().Received(1).LogOrganizationEventAsync(organization, EventType.Organization_SponsorshipsSynced, Arg.Any()); + } - existingSponsorship.ToDelete = false; - newSponsorship.ToDelete = false; + [Theory] + [BitAutoData] + public async Task SyncOrganization_OneExisting_OneNew_Success(SutProvider sutProvider, + Organization sponsoringOrganization, OrganizationSponsorship existingSponsorship, OrganizationSponsorship newSponsorship) + { + // Arrange + sponsoringOrganization.Enabled = true; + sponsoringOrganization.PlanType = PlanType.EnterpriseAnnually; - sutProvider.GetDependency() - .GetManyBySponsoringOrganizationAsync(sponsoringOrganization.Id) - .Returns(new List + existingSponsorship.ToDelete = false; + newSponsorship.ToDelete = false; + + sutProvider.GetDependency() + .GetManyBySponsoringOrganizationAsync(sponsoringOrganization.Id) + .Returns(new List + { + existingSponsorship, + }); + + // Act + var (syncData, toEmailSponsorships) = await sutProvider.Sut.SyncOrganization(sponsoringOrganization, new[] { - existingSponsorship, + new OrganizationSponsorshipData(existingSponsorship), + new OrganizationSponsorshipData(newSponsorship), }); - // Act - var (syncData, toEmailSponsorships) = await sutProvider.Sut.SyncOrganization(sponsoringOrganization, new[] + // Assert + // Should have updated the cloud copy for each item given + await sutProvider.GetDependency() + .Received(1) + .UpsertManyAsync(Arg.Is>(sponsorships => sponsorships.Count() == 2)); + + // Neither were marked as delete, should not have deleted + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .DeleteManyAsync(default); + + // Only one sponsorship was new so it should only send one + Assert.Single(toEmailSponsorships); + } + + [Theory] + [BitAutoData] + public async Task SyncOrganization_TwoToDelete_OneCanDelete_Success(SutProvider sutProvider, + Organization sponsoringOrganization, OrganizationSponsorship canDeleteSponsorship, OrganizationSponsorship cannotDeleteSponsorship) { - new OrganizationSponsorshipData(existingSponsorship), - new OrganizationSponsorshipData(newSponsorship), - }); + // Arrange + sponsoringOrganization.PlanType = PlanType.EnterpriseAnnually; - // Assert - // Should have updated the cloud copy for each item given - await sutProvider.GetDependency() - .Received(1) - .UpsertManyAsync(Arg.Is>(sponsorships => sponsorships.Count() == 2)); + canDeleteSponsorship.ToDelete = true; + canDeleteSponsorship.SponsoredOrganizationId = null; - // Neither were marked as delete, should not have deleted - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .DeleteManyAsync(default); + cannotDeleteSponsorship.ToDelete = true; + cannotDeleteSponsorship.SponsoredOrganizationId = Guid.NewGuid(); - // Only one sponsorship was new so it should only send one - Assert.Single(toEmailSponsorships); - } + sutProvider.GetDependency() + .GetManyBySponsoringOrganizationAsync(sponsoringOrganization.Id) + .Returns(new List + { + canDeleteSponsorship, + cannotDeleteSponsorship, + }); - [Theory] - [BitAutoData] - public async Task SyncOrganization_TwoToDelete_OneCanDelete_Success(SutProvider sutProvider, - Organization sponsoringOrganization, OrganizationSponsorship canDeleteSponsorship, OrganizationSponsorship cannotDeleteSponsorship) - { - // Arrange - sponsoringOrganization.PlanType = PlanType.EnterpriseAnnually; - - canDeleteSponsorship.ToDelete = true; - canDeleteSponsorship.SponsoredOrganizationId = null; - - cannotDeleteSponsorship.ToDelete = true; - cannotDeleteSponsorship.SponsoredOrganizationId = Guid.NewGuid(); - - sutProvider.GetDependency() - .GetManyBySponsoringOrganizationAsync(sponsoringOrganization.Id) - .Returns(new List + // Act + var (syncData, toEmailSponsorships) = await sutProvider.Sut.SyncOrganization(sponsoringOrganization, new[] { - canDeleteSponsorship, - cannotDeleteSponsorship, + new OrganizationSponsorshipData(canDeleteSponsorship), + new OrganizationSponsorshipData(cannotDeleteSponsorship), }); - // Act - var (syncData, toEmailSponsorships) = await sutProvider.Sut.SyncOrganization(sponsoringOrganization, new[] + // Assert + + await sutProvider.GetDependency() + .Received(1) + .UpsertManyAsync(Arg.Is>(sponsorships => sponsorships.Count() == 2)); + + // Deletes the sponsorship that had delete requested and is not sponsoring an org + await sutProvider.GetDependency() + .Received(1) + .DeleteManyAsync(Arg.Is>(toDeleteIds => + toDeleteIds.Count() == 1 && toDeleteIds.ElementAt(0) == canDeleteSponsorship.Id)); + } + + [Theory] + [BitAutoData] + public async Task SyncOrganization_BadData_DoesNotSave(SutProvider sutProvider, + Organization sponsoringOrganization, OrganizationSponsorship badOrganizationSponsorship) { - new OrganizationSponsorshipData(canDeleteSponsorship), - new OrganizationSponsorshipData(cannotDeleteSponsorship), - }); + sponsoringOrganization.PlanType = PlanType.EnterpriseAnnually; - // Assert + badOrganizationSponsorship.ToDelete = true; + badOrganizationSponsorship.LastSyncDate = null; - await sutProvider.GetDependency() - .Received(1) - .UpsertManyAsync(Arg.Is>(sponsorships => sponsorships.Count() == 2)); + sutProvider.GetDependency() + .GetManyBySponsoringOrganizationAsync(sponsoringOrganization.Id) + .Returns(new List()); - // Deletes the sponsorship that had delete requested and is not sponsoring an org - await sutProvider.GetDependency() - .Received(1) - .DeleteManyAsync(Arg.Is>(toDeleteIds => - toDeleteIds.Count() == 1 && toDeleteIds.ElementAt(0) == canDeleteSponsorship.Id)); - } + var (syncData, toEmailSponsorships) = await sutProvider.Sut.SyncOrganization(sponsoringOrganization, new[] + { + new OrganizationSponsorshipData(badOrganizationSponsorship), + }); - [Theory] - [BitAutoData] - public async Task SyncOrganization_BadData_DoesNotSave(SutProvider sutProvider, - Organization sponsoringOrganization, OrganizationSponsorship badOrganizationSponsorship) - { - sponsoringOrganization.PlanType = PlanType.EnterpriseAnnually; + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .UpsertManyAsync(default); - badOrganizationSponsorship.ToDelete = true; - badOrganizationSponsorship.LastSyncDate = null; + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .DeleteManyAsync(default); + } - sutProvider.GetDependency() - .GetManyBySponsoringOrganizationAsync(sponsoringOrganization.Id) - .Returns(new List()); - - var (syncData, toEmailSponsorships) = await sutProvider.Sut.SyncOrganization(sponsoringOrganization, new[] + [Theory] + [BitAutoData] + public async Task SyncOrganization_OrgDisabledForFourMonths_DoesNotSave(SutProvider sutProvider, + Organization sponsoringOrganization, OrganizationSponsorship organizationSponsorship) { - new OrganizationSponsorshipData(badOrganizationSponsorship), - }); + sponsoringOrganization.PlanType = PlanType.EnterpriseAnnually; + sponsoringOrganization.Enabled = false; + sponsoringOrganization.ExpirationDate = DateTime.UtcNow.AddDays(-120); - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .UpsertManyAsync(default); + organizationSponsorship.ToDelete = false; - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .DeleteManyAsync(default); - } + sutProvider.GetDependency() + .GetManyBySponsoringOrganizationAsync(sponsoringOrganization.Id) + .Returns(new List()); - [Theory] - [BitAutoData] - public async Task SyncOrganization_OrgDisabledForFourMonths_DoesNotSave(SutProvider sutProvider, - Organization sponsoringOrganization, OrganizationSponsorship organizationSponsorship) - { - sponsoringOrganization.PlanType = PlanType.EnterpriseAnnually; - sponsoringOrganization.Enabled = false; - sponsoringOrganization.ExpirationDate = DateTime.UtcNow.AddDays(-120); + var (syncData, toEmailSponsorships) = await sutProvider.Sut.SyncOrganization(sponsoringOrganization, new[] + { + new OrganizationSponsorshipData(organizationSponsorship), + }); - organizationSponsorship.ToDelete = false; + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .UpsertManyAsync(default); - sutProvider.GetDependency() - .GetManyBySponsoringOrganizationAsync(sponsoringOrganization.Id) - .Returns(new List()); - - var (syncData, toEmailSponsorships) = await sutProvider.Sut.SyncOrganization(sponsoringOrganization, new[] - { - new OrganizationSponsorshipData(organizationSponsorship), - }); - - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .UpsertManyAsync(default); - - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .DeleteManyAsync(default); + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .DeleteManyAsync(default); + } } } diff --git a/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/OrganizationSponsorshipRenewCommandTests.cs b/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/OrganizationSponsorshipRenewCommandTests.cs index b85a3f2342..ca89199a40 100644 --- a/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/OrganizationSponsorshipRenewCommandTests.cs +++ b/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/OrganizationSponsorshipRenewCommandTests.cs @@ -6,21 +6,22 @@ using Bit.Test.Common.AutoFixture.Attributes; using NSubstitute; using Xunit; -namespace Bit.Core.Test.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Cloud; - -[SutProviderCustomize] -public class OrganizationSponsorshipRenewCommandTests +namespace Bit.Core.Test.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Cloud { - [Theory] - [BitAutoData] - public async Task UpdateExpirationDate_UpdatesValidUntil(OrganizationSponsorship sponsorship, DateTime expireDate, - SutProvider sutProvider) + [SutProviderCustomize] + public class OrganizationSponsorshipRenewCommandTests { - sutProvider.GetDependency().GetBySponsoredOrganizationIdAsync(sponsorship.SponsoredOrganizationId.Value).Returns(sponsorship); + [Theory] + [BitAutoData] + public async Task UpdateExpirationDate_UpdatesValidUntil(OrganizationSponsorship sponsorship, DateTime expireDate, + SutProvider sutProvider) + { + sutProvider.GetDependency().GetBySponsoredOrganizationIdAsync(sponsorship.SponsoredOrganizationId.Value).Returns(sponsorship); - await sutProvider.Sut.UpdateExpirationDateAsync(sponsorship.SponsoredOrganizationId.Value, expireDate); + await sutProvider.Sut.UpdateExpirationDateAsync(sponsorship.SponsoredOrganizationId.Value, expireDate); - await sutProvider.GetDependency().Received(1) - .UpsertAsync(sponsorship); + await sutProvider.GetDependency().Received(1) + .UpsertAsync(sponsorship); + } } } diff --git a/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/RemoveSponsorshipCommandTests.cs b/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/RemoveSponsorshipCommandTests.cs index 29adcb486e..a3ee0a7cd1 100644 --- a/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/RemoveSponsorshipCommandTests.cs +++ b/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/RemoveSponsorshipCommandTests.cs @@ -6,37 +6,38 @@ using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; using Xunit; -namespace Bit.Core.Test.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise; - -[SutProviderCustomize] -[OrganizationSponsorshipCustomize] -public class RemoveSponsorshipCommandTests : CancelSponsorshipCommandTestsBase +namespace Bit.Core.Test.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise { - [Theory] - [BitAutoData] - public async Task RemoveSponsorship_SponsoredOrgNull_ThrowsBadRequest(OrganizationSponsorship sponsorship, - SutProvider sutProvider) + [SutProviderCustomize] + [OrganizationSponsorshipCustomize] + public class RemoveSponsorshipCommandTests : CancelSponsorshipCommandTestsBase { - sponsorship.SponsoredOrganizationId = null; + [Theory] + [BitAutoData] + public async Task RemoveSponsorship_SponsoredOrgNull_ThrowsBadRequest(OrganizationSponsorship sponsorship, + SutProvider sutProvider) + { + sponsorship.SponsoredOrganizationId = null; - var exception = await Assert.ThrowsAsync(() => - sutProvider.Sut.RemoveSponsorshipAsync(sponsorship)); + var exception = await Assert.ThrowsAsync(() => + sutProvider.Sut.RemoveSponsorshipAsync(sponsorship)); - Assert.Contains("The requested organization is not currently being sponsored.", exception.Message); - Assert.False(sponsorship.ToDelete); - await AssertDidNotDeleteSponsorshipAsync(sutProvider); - await AssertDidNotUpdateSponsorshipAsync(sutProvider); - } + Assert.Contains("The requested organization is not currently being sponsored.", exception.Message); + Assert.False(sponsorship.ToDelete); + await AssertDidNotDeleteSponsorshipAsync(sutProvider); + await AssertDidNotUpdateSponsorshipAsync(sutProvider); + } - [Theory] - [BitAutoData] - public async Task RemoveSponsorship_SponsorshipNotFound_ThrowsBadRequest(SutProvider sutProvider) - { - var exception = await Assert.ThrowsAsync(() => - sutProvider.Sut.RemoveSponsorshipAsync(null)); + [Theory] + [BitAutoData] + public async Task RemoveSponsorship_SponsorshipNotFound_ThrowsBadRequest(SutProvider sutProvider) + { + var exception = await Assert.ThrowsAsync(() => + sutProvider.Sut.RemoveSponsorshipAsync(null)); - Assert.Contains("The requested organization is not currently being sponsored.", exception.Message); - await AssertDidNotDeleteSponsorshipAsync(sutProvider); - await AssertDidNotUpdateSponsorshipAsync(sutProvider); + Assert.Contains("The requested organization is not currently being sponsored.", exception.Message); + await AssertDidNotDeleteSponsorshipAsync(sutProvider); + await AssertDidNotUpdateSponsorshipAsync(sutProvider); + } } } diff --git a/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/SendSponsorshipOfferCommandTests.cs b/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/SendSponsorshipOfferCommandTests.cs index f4f8a2cf4a..15377d7fe5 100644 --- a/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/SendSponsorshipOfferCommandTests.cs +++ b/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/SendSponsorshipOfferCommandTests.cs @@ -10,114 +10,115 @@ using Bit.Test.Common.AutoFixture.Attributes; using NSubstitute; using Xunit; -namespace Bit.Core.Test.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise; - -[SutProviderCustomize] -[OrganizationSponsorshipCustomize] -public class SendSponsorshipOfferCommandTests : FamiliesForEnterpriseTestsBase +namespace Bit.Core.Test.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise { - [Theory] - [BitAutoData] - public async Task SendSponsorshipOffer_SendSponsorshipOfferAsync_ExistingAccount_Success(OrganizationSponsorship sponsorship, string sponsoringOrgName, User user, SutProvider sutProvider) + [SutProviderCustomize] + [OrganizationSponsorshipCustomize] + public class SendSponsorshipOfferCommandTests : FamiliesForEnterpriseTestsBase { - sutProvider.GetDependency().GetByEmailAsync(sponsorship.OfferedToEmail).Returns(user); + [Theory] + [BitAutoData] + public async Task SendSponsorshipOffer_SendSponsorshipOfferAsync_ExistingAccount_Success(OrganizationSponsorship sponsorship, string sponsoringOrgName, User user, SutProvider sutProvider) + { + sutProvider.GetDependency().GetByEmailAsync(sponsorship.OfferedToEmail).Returns(user); - await sutProvider.Sut.SendSponsorshipOfferAsync(sponsorship, sponsoringOrgName); + await sutProvider.Sut.SendSponsorshipOfferAsync(sponsorship, sponsoringOrgName); - await sutProvider.GetDependency().Received(1).SendFamiliesForEnterpriseOfferEmailAsync(sponsoringOrgName, sponsorship.OfferedToEmail, true, Arg.Any()); - } + await sutProvider.GetDependency().Received(1).SendFamiliesForEnterpriseOfferEmailAsync(sponsoringOrgName, sponsorship.OfferedToEmail, true, Arg.Any()); + } - [Theory] - [BitAutoData] - public async Task SendSponsorshipOffer_SendSponsorshipOfferAsync_NewAccount_Success(OrganizationSponsorship sponsorship, string sponsoringOrgName, SutProvider sutProvider) - { - sutProvider.GetDependency().GetByEmailAsync(sponsorship.OfferedToEmail).Returns((User)null); + [Theory] + [BitAutoData] + public async Task SendSponsorshipOffer_SendSponsorshipOfferAsync_NewAccount_Success(OrganizationSponsorship sponsorship, string sponsoringOrgName, SutProvider sutProvider) + { + sutProvider.GetDependency().GetByEmailAsync(sponsorship.OfferedToEmail).Returns((User)null); - await sutProvider.Sut.SendSponsorshipOfferAsync(sponsorship, sponsoringOrgName); + await sutProvider.Sut.SendSponsorshipOfferAsync(sponsorship, sponsoringOrgName); - await sutProvider.GetDependency().Received(1).SendFamiliesForEnterpriseOfferEmailAsync(sponsoringOrgName, sponsorship.OfferedToEmail, false, Arg.Any()); - } + await sutProvider.GetDependency().Received(1).SendFamiliesForEnterpriseOfferEmailAsync(sponsoringOrgName, sponsorship.OfferedToEmail, false, Arg.Any()); + } - [Theory] - [BitAutoData] - public async Task ResendSponsorshipOffer_SponsoringOrgNotFound_ThrowsBadRequest( - OrganizationUser orgUser, OrganizationSponsorship sponsorship, - SutProvider sutProvider) - { - var exception = await Assert.ThrowsAsync(() => - sutProvider.Sut.SendSponsorshipOfferAsync(null, orgUser, sponsorship)); + [Theory] + [BitAutoData] + public async Task ResendSponsorshipOffer_SponsoringOrgNotFound_ThrowsBadRequest( + OrganizationUser orgUser, OrganizationSponsorship sponsorship, + SutProvider sutProvider) + { + var exception = await Assert.ThrowsAsync(() => + sutProvider.Sut.SendSponsorshipOfferAsync(null, orgUser, sponsorship)); - Assert.Contains("Cannot find the requested sponsoring organization.", exception.Message); - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .SendFamiliesForEnterpriseOfferEmailAsync(default, default, default, default); - } + Assert.Contains("Cannot find the requested sponsoring organization.", exception.Message); + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .SendFamiliesForEnterpriseOfferEmailAsync(default, default, default, default); + } - [Theory] - [BitAutoData] - public async Task ResendSponsorshipOffer_SponsoringOrgUserNotFound_ThrowsBadRequest(Organization org, - OrganizationSponsorship sponsorship, SutProvider sutProvider) - { - var exception = await Assert.ThrowsAsync(() => - sutProvider.Sut.SendSponsorshipOfferAsync(org, null, sponsorship)); + [Theory] + [BitAutoData] + public async Task ResendSponsorshipOffer_SponsoringOrgUserNotFound_ThrowsBadRequest(Organization org, + OrganizationSponsorship sponsorship, SutProvider sutProvider) + { + var exception = await Assert.ThrowsAsync(() => + sutProvider.Sut.SendSponsorshipOfferAsync(org, null, sponsorship)); - Assert.Contains("Only confirmed users can sponsor other organizations.", exception.Message); - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .SendFamiliesForEnterpriseOfferEmailAsync(default, default, default, default); - } + Assert.Contains("Only confirmed users can sponsor other organizations.", exception.Message); + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .SendFamiliesForEnterpriseOfferEmailAsync(default, default, default, default); + } - [Theory] - [BitAutoData] - [BitMemberAutoData(nameof(NonConfirmedOrganizationUsersStatuses))] - public async Task ResendSponsorshipOffer_SponsoringOrgUserNotConfirmed_ThrowsBadRequest(OrganizationUserStatusType status, - Organization org, OrganizationUser orgUser, OrganizationSponsorship sponsorship, - SutProvider sutProvider) - { - orgUser.Status = status; + [Theory] + [BitAutoData] + [BitMemberAutoData(nameof(NonConfirmedOrganizationUsersStatuses))] + public async Task ResendSponsorshipOffer_SponsoringOrgUserNotConfirmed_ThrowsBadRequest(OrganizationUserStatusType status, + Organization org, OrganizationUser orgUser, OrganizationSponsorship sponsorship, + SutProvider sutProvider) + { + orgUser.Status = status; - var exception = await Assert.ThrowsAsync(() => - sutProvider.Sut.SendSponsorshipOfferAsync(org, orgUser, sponsorship)); + var exception = await Assert.ThrowsAsync(() => + sutProvider.Sut.SendSponsorshipOfferAsync(org, orgUser, sponsorship)); - Assert.Contains("Only confirmed users can sponsor other organizations.", exception.Message); - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .SendFamiliesForEnterpriseOfferEmailAsync(default, default, default, default); - } + Assert.Contains("Only confirmed users can sponsor other organizations.", exception.Message); + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .SendFamiliesForEnterpriseOfferEmailAsync(default, default, default, default); + } - [Theory] - [BitAutoData] - public async Task ResendSponsorshipOffer_SponsorshipNotFound_ThrowsBadRequest(Organization org, - OrganizationUser orgUser, - SutProvider sutProvider) - { - orgUser.Status = OrganizationUserStatusType.Confirmed; + [Theory] + [BitAutoData] + public async Task ResendSponsorshipOffer_SponsorshipNotFound_ThrowsBadRequest(Organization org, + OrganizationUser orgUser, + SutProvider sutProvider) + { + orgUser.Status = OrganizationUserStatusType.Confirmed; - var exception = await Assert.ThrowsAsync(() => - sutProvider.Sut.SendSponsorshipOfferAsync(org, orgUser, null)); + var exception = await Assert.ThrowsAsync(() => + sutProvider.Sut.SendSponsorshipOfferAsync(org, orgUser, null)); - Assert.Contains("Cannot find an outstanding sponsorship offer for this organization.", exception.Message); - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .SendFamiliesForEnterpriseOfferEmailAsync(default, default, default, default); - } + Assert.Contains("Cannot find an outstanding sponsorship offer for this organization.", exception.Message); + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .SendFamiliesForEnterpriseOfferEmailAsync(default, default, default, default); + } - [Theory] - [BitAutoData] - public async Task ResendSponsorshipOffer_NoOfferToEmail_ThrowsBadRequest(Organization org, - OrganizationUser orgUser, OrganizationSponsorship sponsorship, - SutProvider sutProvider) - { - orgUser.Status = OrganizationUserStatusType.Confirmed; - sponsorship.OfferedToEmail = null; + [Theory] + [BitAutoData] + public async Task ResendSponsorshipOffer_NoOfferToEmail_ThrowsBadRequest(Organization org, + OrganizationUser orgUser, OrganizationSponsorship sponsorship, + SutProvider sutProvider) + { + orgUser.Status = OrganizationUserStatusType.Confirmed; + sponsorship.OfferedToEmail = null; - var exception = await Assert.ThrowsAsync(() => - sutProvider.Sut.SendSponsorshipOfferAsync(org, orgUser, sponsorship)); + var exception = await Assert.ThrowsAsync(() => + sutProvider.Sut.SendSponsorshipOfferAsync(org, orgUser, sponsorship)); - Assert.Contains("Cannot find an outstanding sponsorship offer for this organization.", exception.Message); - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .SendFamiliesForEnterpriseOfferEmailAsync(default, default, default, default); + Assert.Contains("Cannot find an outstanding sponsorship offer for this organization.", exception.Message); + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .SendFamiliesForEnterpriseOfferEmailAsync(default, default, default, default); + } } } diff --git a/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/SetUpSponsorshipCommandTests.cs b/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/SetUpSponsorshipCommandTests.cs index 5776e3e849..358e4f0072 100644 --- a/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/SetUpSponsorshipCommandTests.cs +++ b/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/SetUpSponsorshipCommandTests.cs @@ -10,85 +10,86 @@ using Bit.Test.Common.AutoFixture.Attributes; using NSubstitute; using Xunit; -namespace Bit.Core.Test.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Cloud; - -[SutProviderCustomize] -[OrganizationSponsorshipCustomize] -public class SetUpSponsorshipCommandTests : FamiliesForEnterpriseTestsBase +namespace Bit.Core.Test.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Cloud { - [Theory] - [BitAutoData] - public async Task SetUpSponsorship_SponsorshipNotFound_ThrowsBadRequest(Organization org, - SutProvider sutProvider) + [SutProviderCustomize] + [OrganizationSponsorshipCustomize] + public class SetUpSponsorshipCommandTests : FamiliesForEnterpriseTestsBase { - var exception = await Assert.ThrowsAsync(() => - sutProvider.Sut.SetUpSponsorshipAsync(null, org)); + [Theory] + [BitAutoData] + public async Task SetUpSponsorship_SponsorshipNotFound_ThrowsBadRequest(Organization org, + SutProvider sutProvider) + { + var exception = await Assert.ThrowsAsync(() => + sutProvider.Sut.SetUpSponsorshipAsync(null, org)); - Assert.Contains("No unredeemed sponsorship offer exists for you.", exception.Message); - await AssertDidNotSetUpAsync(sutProvider); - } + Assert.Contains("No unredeemed sponsorship offer exists for you.", exception.Message); + await AssertDidNotSetUpAsync(sutProvider); + } - [Theory] - [BitAutoData] - public async Task SetUpSponsorship_OrgAlreadySponsored_ThrowsBadRequest(Organization org, - OrganizationSponsorship sponsorship, OrganizationSponsorship existingSponsorship, - SutProvider sutProvider) - { - sutProvider.GetDependency() - .GetBySponsoredOrganizationIdAsync(org.Id).Returns(existingSponsorship); + [Theory] + [BitAutoData] + public async Task SetUpSponsorship_OrgAlreadySponsored_ThrowsBadRequest(Organization org, + OrganizationSponsorship sponsorship, OrganizationSponsorship existingSponsorship, + SutProvider sutProvider) + { + sutProvider.GetDependency() + .GetBySponsoredOrganizationIdAsync(org.Id).Returns(existingSponsorship); - var exception = await Assert.ThrowsAsync(() => - sutProvider.Sut.SetUpSponsorshipAsync(sponsorship, org)); + var exception = await Assert.ThrowsAsync(() => + sutProvider.Sut.SetUpSponsorshipAsync(sponsorship, org)); - Assert.Contains("Cannot redeem a sponsorship offer for an organization that is already sponsored. Revoke existing sponsorship first.", exception.Message); - await AssertDidNotSetUpAsync(sutProvider); - } + Assert.Contains("Cannot redeem a sponsorship offer for an organization that is already sponsored. Revoke existing sponsorship first.", exception.Message); + await AssertDidNotSetUpAsync(sutProvider); + } - [Theory] - [BitMemberAutoData(nameof(FamiliesPlanTypes))] - public async Task SetUpSponsorship_TooLongSinceLastSync_ThrowsBadRequest(PlanType planType, Organization org, - OrganizationSponsorship sponsorship, - SutProvider sutProvider) - { - org.PlanType = planType; - sponsorship.LastSyncDate = DateTime.UtcNow.AddDays(-365); + [Theory] + [BitMemberAutoData(nameof(FamiliesPlanTypes))] + public async Task SetUpSponsorship_TooLongSinceLastSync_ThrowsBadRequest(PlanType planType, Organization org, + OrganizationSponsorship sponsorship, + SutProvider sutProvider) + { + org.PlanType = planType; + sponsorship.LastSyncDate = DateTime.UtcNow.AddDays(-365); - var exception = await Assert.ThrowsAsync(() => - sutProvider.Sut.SetUpSponsorshipAsync(sponsorship, org)); + var exception = await Assert.ThrowsAsync(() => + sutProvider.Sut.SetUpSponsorshipAsync(sponsorship, org)); - Assert.Contains("This sponsorship offer is more than 6 months old and has expired.", exception.Message); - await sutProvider.GetDependency() - .Received(1) - .DeleteAsync(sponsorship); - await AssertDidNotSetUpAsync(sutProvider); - } + Assert.Contains("This sponsorship offer is more than 6 months old and has expired.", exception.Message); + await sutProvider.GetDependency() + .Received(1) + .DeleteAsync(sponsorship); + await AssertDidNotSetUpAsync(sutProvider); + } - [Theory] - [BitMemberAutoData(nameof(NonFamiliesPlanTypes))] - public async Task SetUpSponsorship_OrgNotFamiles_ThrowsBadRequest(PlanType planType, - OrganizationSponsorship sponsorship, Organization org, - SutProvider sutProvider) - { - org.PlanType = planType; - sponsorship.LastSyncDate = DateTime.UtcNow; + [Theory] + [BitMemberAutoData(nameof(NonFamiliesPlanTypes))] + public async Task SetUpSponsorship_OrgNotFamiles_ThrowsBadRequest(PlanType planType, + OrganizationSponsorship sponsorship, Organization org, + SutProvider sutProvider) + { + org.PlanType = planType; + sponsorship.LastSyncDate = DateTime.UtcNow; - var exception = await Assert.ThrowsAsync(() => - sutProvider.Sut.SetUpSponsorshipAsync(sponsorship, org)); + var exception = await Assert.ThrowsAsync(() => + sutProvider.Sut.SetUpSponsorshipAsync(sponsorship, org)); - Assert.Contains("Can only redeem sponsorship offer on families organizations.", exception.Message); - await AssertDidNotSetUpAsync(sutProvider); - } + Assert.Contains("Can only redeem sponsorship offer on families organizations.", exception.Message); + await AssertDidNotSetUpAsync(sutProvider); + } - private static async Task AssertDidNotSetUpAsync(SutProvider sutProvider) - { - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .SponsorOrganizationAsync(default, default); - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .UpsertAsync(default); - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .UpsertAsync(default); + private static async Task AssertDidNotSetUpAsync(SutProvider sutProvider) + { + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .SponsorOrganizationAsync(default, default); + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .UpsertAsync(default); + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .UpsertAsync(default); + } } } diff --git a/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/ValidateBillingSyncKeyCommandTests.cs b/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/ValidateBillingSyncKeyCommandTests.cs index 9b01e3035f..4b3426a53a 100644 --- a/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/ValidateBillingSyncKeyCommandTests.cs +++ b/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/ValidateBillingSyncKeyCommandTests.cs @@ -8,50 +8,51 @@ using Bit.Test.Common.AutoFixture.Attributes; using NSubstitute; using Xunit; -namespace Bit.Core.Test.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Cloud; - -[SutProviderCustomize] -public class ValidateBillingSyncKeyCommandTests +namespace Bit.Core.Test.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Cloud { - [Theory] - [BitAutoData] - public async Task ValidateBillingSyncKeyAsync_NullOrganization_Throws(SutProvider sutProvider) + [SutProviderCustomize] + public class ValidateBillingSyncKeyCommandTests { - await Assert.ThrowsAsync(() => sutProvider.Sut.ValidateBillingSyncKeyAsync(null, null)); - } + [Theory] + [BitAutoData] + public async Task ValidateBillingSyncKeyAsync_NullOrganization_Throws(SutProvider sutProvider) + { + await Assert.ThrowsAsync(() => sutProvider.Sut.ValidateBillingSyncKeyAsync(null, null)); + } - [Theory] - [BitAutoData((string)null)] - [BitAutoData("")] - [BitAutoData(" ")] - public async Task ValidateBillingSyncKeyAsync_BadString_ReturnsFalse(string billingSyncKey, SutProvider sutProvider) - { - Assert.False(await sutProvider.Sut.ValidateBillingSyncKeyAsync(new Organization(), billingSyncKey)); - } + [Theory] + [BitAutoData((string)null)] + [BitAutoData("")] + [BitAutoData(" ")] + public async Task ValidateBillingSyncKeyAsync_BadString_ReturnsFalse(string billingSyncKey, SutProvider sutProvider) + { + Assert.False(await sutProvider.Sut.ValidateBillingSyncKeyAsync(new Organization(), billingSyncKey)); + } - [Theory] - [BitAutoData] - public async Task ValidateBillingSyncKeyAsync_KeyEquals_ReturnsTrue(SutProvider sutProvider, - Organization organization, OrganizationApiKey orgApiKey, string billingSyncKey) - { - orgApiKey.ApiKey = billingSyncKey; + [Theory] + [BitAutoData] + public async Task ValidateBillingSyncKeyAsync_KeyEquals_ReturnsTrue(SutProvider sutProvider, + Organization organization, OrganizationApiKey orgApiKey, string billingSyncKey) + { + orgApiKey.ApiKey = billingSyncKey; - sutProvider.GetDependency() - .GetManyByOrganizationIdTypeAsync(organization.Id, OrganizationApiKeyType.BillingSync) - .Returns(new[] { orgApiKey }); + sutProvider.GetDependency() + .GetManyByOrganizationIdTypeAsync(organization.Id, OrganizationApiKeyType.BillingSync) + .Returns(new[] { orgApiKey }); - Assert.True(await sutProvider.Sut.ValidateBillingSyncKeyAsync(organization, billingSyncKey)); - } + Assert.True(await sutProvider.Sut.ValidateBillingSyncKeyAsync(organization, billingSyncKey)); + } - [Theory] - [BitAutoData] - public async Task ValidateBillingSyncKeyAsync_KeyDoesNotEqual_ReturnsFalse(SutProvider sutProvider, - Organization organization, OrganizationApiKey orgApiKey, string billingSyncKey) - { - sutProvider.GetDependency() - .GetManyByOrganizationIdTypeAsync(organization.Id, OrganizationApiKeyType.BillingSync) - .Returns(new[] { orgApiKey }); + [Theory] + [BitAutoData] + public async Task ValidateBillingSyncKeyAsync_KeyDoesNotEqual_ReturnsFalse(SutProvider sutProvider, + Organization organization, OrganizationApiKey orgApiKey, string billingSyncKey) + { + sutProvider.GetDependency() + .GetManyByOrganizationIdTypeAsync(organization.Id, OrganizationApiKeyType.BillingSync) + .Returns(new[] { orgApiKey }); - Assert.False(await sutProvider.Sut.ValidateBillingSyncKeyAsync(organization, billingSyncKey)); + Assert.False(await sutProvider.Sut.ValidateBillingSyncKeyAsync(organization, billingSyncKey)); + } } } diff --git a/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/ValidateRedemptionTokenCommandTests.cs b/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/ValidateRedemptionTokenCommandTests.cs index 65aa4cfb2f..9bbaaed1dc 100644 --- a/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/ValidateRedemptionTokenCommandTests.cs +++ b/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/ValidateRedemptionTokenCommandTests.cs @@ -9,78 +9,79 @@ using Bit.Test.Common.AutoFixture.Attributes; using NSubstitute; using Xunit; -namespace Bit.Core.Test.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Cloud; - -[SutProviderCustomize] -public class ValidateRedemptionTokenCommandTests +namespace Bit.Core.Test.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Cloud { - [Theory] - [BitAutoData] - public async Task ValidateRedemptionTokenAsync_CannotUnprotect_ReturnsFalse(SutProvider sutProvider, - string encryptedString) + [SutProviderCustomize] + public class ValidateRedemptionTokenCommandTests { - sutProvider - .GetDependency>() - .TryUnprotect(encryptedString, out _) - .Returns(call => - { - call[1] = null; - return false; - }); + [Theory] + [BitAutoData] + public async Task ValidateRedemptionTokenAsync_CannotUnprotect_ReturnsFalse(SutProvider sutProvider, + string encryptedString) + { + sutProvider + .GetDependency>() + .TryUnprotect(encryptedString, out _) + .Returns(call => + { + call[1] = null; + return false; + }); - var (valid, sponsorship) = await sutProvider.Sut.ValidateRedemptionTokenAsync(encryptedString, null); - Assert.False(valid); - Assert.Null(sponsorship); - } + var (valid, sponsorship) = await sutProvider.Sut.ValidateRedemptionTokenAsync(encryptedString, null); + Assert.False(valid); + Assert.Null(sponsorship); + } - [Theory] - [BitAutoData] - public async Task ValidateRedemptionTokenAsync_NoSponsorship_ReturnsFalse(SutProvider sutProvider, - string encryptedString, OrganizationSponsorshipOfferTokenable tokenable) - { - sutProvider - .GetDependency>() - .TryUnprotect(encryptedString, out _) - .Returns(call => - { - call[1] = tokenable; - return true; - }); + [Theory] + [BitAutoData] + public async Task ValidateRedemptionTokenAsync_NoSponsorship_ReturnsFalse(SutProvider sutProvider, + string encryptedString, OrganizationSponsorshipOfferTokenable tokenable) + { + sutProvider + .GetDependency>() + .TryUnprotect(encryptedString, out _) + .Returns(call => + { + call[1] = tokenable; + return true; + }); - var (valid, sponsorship) = await sutProvider.Sut.ValidateRedemptionTokenAsync(encryptedString, "test@email.com"); - Assert.False(valid); - Assert.Null(sponsorship); - } + var (valid, sponsorship) = await sutProvider.Sut.ValidateRedemptionTokenAsync(encryptedString, "test@email.com"); + Assert.False(valid); + Assert.Null(sponsorship); + } - [Theory] - [BitAutoData] - public async Task ValidateRedemptionTokenAsync_ValidSponsorship_ReturnsFalse(SutProvider sutProvider, - string encryptedString, string email, OrganizationSponsorshipOfferTokenable tokenable) - { - tokenable.Email = email; + [Theory] + [BitAutoData] + public async Task ValidateRedemptionTokenAsync_ValidSponsorship_ReturnsFalse(SutProvider sutProvider, + string encryptedString, string email, OrganizationSponsorshipOfferTokenable tokenable) + { + tokenable.Email = email; - sutProvider - .GetDependency>() - .TryUnprotect(encryptedString, out _) - .Returns(call => - { - call[1] = tokenable; - return true; - }); + sutProvider + .GetDependency>() + .TryUnprotect(encryptedString, out _) + .Returns(call => + { + call[1] = tokenable; + return true; + }); - sutProvider.GetDependency() - .GetByIdAsync(tokenable.Id) - .Returns(new OrganizationSponsorship - { - Id = tokenable.Id, - PlanSponsorshipType = PlanSponsorshipType.FamiliesForEnterprise, - OfferedToEmail = email - }); + sutProvider.GetDependency() + .GetByIdAsync(tokenable.Id) + .Returns(new OrganizationSponsorship + { + Id = tokenable.Id, + PlanSponsorshipType = PlanSponsorshipType.FamiliesForEnterprise, + OfferedToEmail = email + }); - var (valid, sponsorship) = await sutProvider.Sut - .ValidateRedemptionTokenAsync(encryptedString, email); + var (valid, sponsorship) = await sutProvider.Sut + .ValidateRedemptionTokenAsync(encryptedString, email); - Assert.True(valid); - Assert.NotNull(sponsorship); + Assert.True(valid); + Assert.NotNull(sponsorship); + } } } diff --git a/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/ValidateSponsorshipCommandTests.cs b/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/ValidateSponsorshipCommandTests.cs index a187f5b296..f1beb64f37 100644 --- a/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/ValidateSponsorshipCommandTests.cs +++ b/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/ValidateSponsorshipCommandTests.cs @@ -8,246 +8,247 @@ using Bit.Test.Common.AutoFixture.Attributes; using NSubstitute; using Xunit; -namespace Bit.Core.Test.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Cloud; - -[SutProviderCustomize] -[OrganizationSponsorshipCustomize] -public class ValidateSponsorshipCommandTests : CancelSponsorshipCommandTestsBase +namespace Bit.Core.Test.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Cloud { - [Theory] - [BitAutoData] - public async Task ValidateSponsorshipAsync_NoSponsoredOrg_EarlyReturn(Guid sponsoredOrgId, - SutProvider sutProvider) + [SutProviderCustomize] + [OrganizationSponsorshipCustomize] + public class ValidateSponsorshipCommandTests : CancelSponsorshipCommandTestsBase { - sutProvider.GetDependency().GetByIdAsync(sponsoredOrgId).Returns((Organization)null); + [Theory] + [BitAutoData] + public async Task ValidateSponsorshipAsync_NoSponsoredOrg_EarlyReturn(Guid sponsoredOrgId, + SutProvider sutProvider) + { + sutProvider.GetDependency().GetByIdAsync(sponsoredOrgId).Returns((Organization)null); - var result = await sutProvider.Sut.ValidateSponsorshipAsync(sponsoredOrgId); + var result = await sutProvider.Sut.ValidateSponsorshipAsync(sponsoredOrgId); - Assert.False(result); - await AssertDidNotRemoveSponsoredPaymentAsync(sutProvider); - await AssertDidNotDeleteSponsorshipAsync(sutProvider); - } + Assert.False(result); + await AssertDidNotRemoveSponsoredPaymentAsync(sutProvider); + await AssertDidNotDeleteSponsorshipAsync(sutProvider); + } - [Theory] - [BitAutoData] - public async Task ValidateSponsorshipAsync_NoExistingSponsorship_UpdatesStripePlan(Organization sponsoredOrg, - SutProvider sutProvider) - { - sutProvider.GetDependency().GetByIdAsync(sponsoredOrg.Id).Returns(sponsoredOrg); + [Theory] + [BitAutoData] + public async Task ValidateSponsorshipAsync_NoExistingSponsorship_UpdatesStripePlan(Organization sponsoredOrg, + SutProvider sutProvider) + { + sutProvider.GetDependency().GetByIdAsync(sponsoredOrg.Id).Returns(sponsoredOrg); - var result = await sutProvider.Sut.ValidateSponsorshipAsync(sponsoredOrg.Id); + var result = await sutProvider.Sut.ValidateSponsorshipAsync(sponsoredOrg.Id); - Assert.False(result); - await AssertRemovedSponsoredPaymentAsync(sponsoredOrg, null, sutProvider); - } + Assert.False(result); + await AssertRemovedSponsoredPaymentAsync(sponsoredOrg, null, sutProvider); + } - [Theory] - [BitAutoData] - public async Task ValidateSponsorshipAsync_SponsoringOrgDefault_UpdatesStripePlan(Organization sponsoredOrg, - OrganizationSponsorship existingSponsorship, SutProvider sutProvider) - { - existingSponsorship.SponsoringOrganizationId = default; + [Theory] + [BitAutoData] + public async Task ValidateSponsorshipAsync_SponsoringOrgDefault_UpdatesStripePlan(Organization sponsoredOrg, + OrganizationSponsorship existingSponsorship, SutProvider sutProvider) + { + existingSponsorship.SponsoringOrganizationId = default; - sutProvider.GetDependency() - .GetBySponsoredOrganizationIdAsync(sponsoredOrg.Id).Returns(existingSponsorship); - sutProvider.GetDependency().GetByIdAsync(sponsoredOrg.Id).Returns(sponsoredOrg); + sutProvider.GetDependency() + .GetBySponsoredOrganizationIdAsync(sponsoredOrg.Id).Returns(existingSponsorship); + sutProvider.GetDependency().GetByIdAsync(sponsoredOrg.Id).Returns(sponsoredOrg); - var result = await sutProvider.Sut.ValidateSponsorshipAsync(sponsoredOrg.Id); + var result = await sutProvider.Sut.ValidateSponsorshipAsync(sponsoredOrg.Id); - Assert.False(result); - await AssertRemovedSponsoredPaymentAsync(sponsoredOrg, existingSponsorship, sutProvider); - await AssertDeletedSponsorshipAsync(existingSponsorship, sutProvider); - } + Assert.False(result); + await AssertRemovedSponsoredPaymentAsync(sponsoredOrg, existingSponsorship, sutProvider); + await AssertDeletedSponsorshipAsync(existingSponsorship, sutProvider); + } - [Theory] - [BitAutoData] - public async Task ValidateSponsorshipAsync_SponsoringOrgUserDefault_UpdatesStripePlan(Organization sponsoredOrg, - OrganizationSponsorship existingSponsorship, SutProvider sutProvider) - { - existingSponsorship.SponsoringOrganizationUserId = default; + [Theory] + [BitAutoData] + public async Task ValidateSponsorshipAsync_SponsoringOrgUserDefault_UpdatesStripePlan(Organization sponsoredOrg, + OrganizationSponsorship existingSponsorship, SutProvider sutProvider) + { + existingSponsorship.SponsoringOrganizationUserId = default; - sutProvider.GetDependency() - .GetBySponsoredOrganizationIdAsync(sponsoredOrg.Id).Returns(existingSponsorship); - sutProvider.GetDependency().GetByIdAsync(sponsoredOrg.Id).Returns(sponsoredOrg); + sutProvider.GetDependency() + .GetBySponsoredOrganizationIdAsync(sponsoredOrg.Id).Returns(existingSponsorship); + sutProvider.GetDependency().GetByIdAsync(sponsoredOrg.Id).Returns(sponsoredOrg); - var result = await sutProvider.Sut.ValidateSponsorshipAsync(sponsoredOrg.Id); + var result = await sutProvider.Sut.ValidateSponsorshipAsync(sponsoredOrg.Id); - Assert.False(result); - await AssertRemovedSponsoredPaymentAsync(sponsoredOrg, existingSponsorship, sutProvider); - await AssertDeletedSponsorshipAsync(existingSponsorship, sutProvider); - } + Assert.False(result); + await AssertRemovedSponsoredPaymentAsync(sponsoredOrg, existingSponsorship, sutProvider); + await AssertDeletedSponsorshipAsync(existingSponsorship, sutProvider); + } - [Theory] - [BitAutoData] - public async Task ValidateSponsorshipAsync_SponsorshipTypeNull_UpdatesStripePlan(Organization sponsoredOrg, - OrganizationSponsorship existingSponsorship, SutProvider sutProvider) - { - existingSponsorship.PlanSponsorshipType = null; + [Theory] + [BitAutoData] + public async Task ValidateSponsorshipAsync_SponsorshipTypeNull_UpdatesStripePlan(Organization sponsoredOrg, + OrganizationSponsorship existingSponsorship, SutProvider sutProvider) + { + existingSponsorship.PlanSponsorshipType = null; - sutProvider.GetDependency() - .GetBySponsoredOrganizationIdAsync(sponsoredOrg.Id).Returns(existingSponsorship); - sutProvider.GetDependency().GetByIdAsync(sponsoredOrg.Id).Returns(sponsoredOrg); + sutProvider.GetDependency() + .GetBySponsoredOrganizationIdAsync(sponsoredOrg.Id).Returns(existingSponsorship); + sutProvider.GetDependency().GetByIdAsync(sponsoredOrg.Id).Returns(sponsoredOrg); - var result = await sutProvider.Sut.ValidateSponsorshipAsync(sponsoredOrg.Id); + var result = await sutProvider.Sut.ValidateSponsorshipAsync(sponsoredOrg.Id); - Assert.False(result); - await AssertRemovedSponsoredPaymentAsync(sponsoredOrg, existingSponsorship, sutProvider); - await AssertDeletedSponsorshipAsync(existingSponsorship, sutProvider); - } + Assert.False(result); + await AssertRemovedSponsoredPaymentAsync(sponsoredOrg, existingSponsorship, sutProvider); + await AssertDeletedSponsorshipAsync(existingSponsorship, sutProvider); + } - [Theory] - [BitAutoData] - public async Task ValidateSponsorshipAsync_SponsoringOrgNotFound_UpdatesStripePlan(Organization sponsoredOrg, - OrganizationSponsorship existingSponsorship, SutProvider sutProvider) - { - sutProvider.GetDependency() - .GetBySponsoredOrganizationIdAsync(sponsoredOrg.Id).Returns(existingSponsorship); - sutProvider.GetDependency().GetByIdAsync(sponsoredOrg.Id).Returns(sponsoredOrg); + [Theory] + [BitAutoData] + public async Task ValidateSponsorshipAsync_SponsoringOrgNotFound_UpdatesStripePlan(Organization sponsoredOrg, + OrganizationSponsorship existingSponsorship, SutProvider sutProvider) + { + sutProvider.GetDependency() + .GetBySponsoredOrganizationIdAsync(sponsoredOrg.Id).Returns(existingSponsorship); + sutProvider.GetDependency().GetByIdAsync(sponsoredOrg.Id).Returns(sponsoredOrg); - var result = await sutProvider.Sut.ValidateSponsorshipAsync(sponsoredOrg.Id); + var result = await sutProvider.Sut.ValidateSponsorshipAsync(sponsoredOrg.Id); - Assert.False(result); - await AssertRemovedSponsoredPaymentAsync(sponsoredOrg, existingSponsorship, sutProvider); - await AssertDeletedSponsorshipAsync(existingSponsorship, sutProvider); - } + Assert.False(result); + await AssertRemovedSponsoredPaymentAsync(sponsoredOrg, existingSponsorship, sutProvider); + await AssertDeletedSponsorshipAsync(existingSponsorship, sutProvider); + } - [Theory] - [BitMemberAutoData(nameof(NonEnterprisePlanTypes))] - public async Task ValidateSponsorshipAsync_SponsoringOrgNotEnterprise_UpdatesStripePlan(PlanType planType, - Organization sponsoredOrg, OrganizationSponsorship existingSponsorship, Organization sponsoringOrg, - SutProvider sutProvider) - { - sponsoringOrg.PlanType = planType; - existingSponsorship.SponsoringOrganizationId = sponsoringOrg.Id; + [Theory] + [BitMemberAutoData(nameof(NonEnterprisePlanTypes))] + public async Task ValidateSponsorshipAsync_SponsoringOrgNotEnterprise_UpdatesStripePlan(PlanType planType, + Organization sponsoredOrg, OrganizationSponsorship existingSponsorship, Organization sponsoringOrg, + SutProvider sutProvider) + { + sponsoringOrg.PlanType = planType; + existingSponsorship.SponsoringOrganizationId = sponsoringOrg.Id; - sutProvider.GetDependency() - .GetBySponsoredOrganizationIdAsync(sponsoredOrg.Id).Returns(existingSponsorship); - sutProvider.GetDependency().GetByIdAsync(sponsoredOrg.Id).Returns(sponsoredOrg); - sutProvider.GetDependency().GetByIdAsync(sponsoringOrg.Id).Returns(sponsoringOrg); + sutProvider.GetDependency() + .GetBySponsoredOrganizationIdAsync(sponsoredOrg.Id).Returns(existingSponsorship); + sutProvider.GetDependency().GetByIdAsync(sponsoredOrg.Id).Returns(sponsoredOrg); + sutProvider.GetDependency().GetByIdAsync(sponsoringOrg.Id).Returns(sponsoringOrg); - var result = await sutProvider.Sut.ValidateSponsorshipAsync(sponsoredOrg.Id); + var result = await sutProvider.Sut.ValidateSponsorshipAsync(sponsoredOrg.Id); - Assert.False(result); - await AssertRemovedSponsoredPaymentAsync(sponsoredOrg, existingSponsorship, sutProvider); - await AssertDeletedSponsorshipAsync(existingSponsorship, sutProvider); - } + Assert.False(result); + await AssertRemovedSponsoredPaymentAsync(sponsoredOrg, existingSponsorship, sutProvider); + await AssertDeletedSponsorshipAsync(existingSponsorship, sutProvider); + } - [Theory] - [BitMemberAutoData(nameof(EnterprisePlanTypes))] - public async Task ValidateSponsorshipAsync_SponsoringOrgDisabledLongerThanGrace_UpdatesStripePlan(PlanType planType, - Organization sponsoredOrg, OrganizationSponsorship existingSponsorship, Organization sponsoringOrg, - SutProvider sutProvider) - { - sponsoringOrg.PlanType = planType; - sponsoringOrg.Enabled = false; - sponsoringOrg.ExpirationDate = DateTime.UtcNow.AddDays(-100); - existingSponsorship.SponsoringOrganizationId = sponsoringOrg.Id; + [Theory] + [BitMemberAutoData(nameof(EnterprisePlanTypes))] + public async Task ValidateSponsorshipAsync_SponsoringOrgDisabledLongerThanGrace_UpdatesStripePlan(PlanType planType, + Organization sponsoredOrg, OrganizationSponsorship existingSponsorship, Organization sponsoringOrg, + SutProvider sutProvider) + { + sponsoringOrg.PlanType = planType; + sponsoringOrg.Enabled = false; + sponsoringOrg.ExpirationDate = DateTime.UtcNow.AddDays(-100); + existingSponsorship.SponsoringOrganizationId = sponsoringOrg.Id; - sutProvider.GetDependency() - .GetBySponsoredOrganizationIdAsync(sponsoredOrg.Id).Returns(existingSponsorship); - sutProvider.GetDependency().GetByIdAsync(sponsoredOrg.Id).Returns(sponsoredOrg); - sutProvider.GetDependency().GetByIdAsync(sponsoringOrg.Id).Returns(sponsoringOrg); + sutProvider.GetDependency() + .GetBySponsoredOrganizationIdAsync(sponsoredOrg.Id).Returns(existingSponsorship); + sutProvider.GetDependency().GetByIdAsync(sponsoredOrg.Id).Returns(sponsoredOrg); + sutProvider.GetDependency().GetByIdAsync(sponsoringOrg.Id).Returns(sponsoringOrg); - var result = await sutProvider.Sut.ValidateSponsorshipAsync(sponsoredOrg.Id); + var result = await sutProvider.Sut.ValidateSponsorshipAsync(sponsoredOrg.Id); - Assert.False(result); - await AssertRemovedSponsoredPaymentAsync(sponsoredOrg, existingSponsorship, sutProvider); - await AssertDeletedSponsorshipAsync(existingSponsorship, sutProvider); - } + Assert.False(result); + await AssertRemovedSponsoredPaymentAsync(sponsoredOrg, existingSponsorship, sutProvider); + await AssertDeletedSponsorshipAsync(existingSponsorship, sutProvider); + } - [Theory] - [OrganizationSponsorshipCustomize(ToDelete = true)] - [BitMemberAutoData(nameof(EnterprisePlanTypes))] - public async Task ValidateSponsorshipAsync_ToDeleteSponsorship_IsInvalid(PlanType planType, - Organization sponsoredOrg, OrganizationSponsorship sponsorship, Organization sponsoringOrg, - SutProvider sutProvider) - { - sponsoringOrg.PlanType = planType; - sponsoringOrg.Enabled = true; - sponsorship.SponsoringOrganizationId = sponsoringOrg.Id; + [Theory] + [OrganizationSponsorshipCustomize(ToDelete = true)] + [BitMemberAutoData(nameof(EnterprisePlanTypes))] + public async Task ValidateSponsorshipAsync_ToDeleteSponsorship_IsInvalid(PlanType planType, + Organization sponsoredOrg, OrganizationSponsorship sponsorship, Organization sponsoringOrg, + SutProvider sutProvider) + { + sponsoringOrg.PlanType = planType; + sponsoringOrg.Enabled = true; + sponsorship.SponsoringOrganizationId = sponsoringOrg.Id; - sutProvider.GetDependency() - .GetBySponsoredOrganizationIdAsync(sponsoredOrg.Id).Returns(sponsorship); - sutProvider.GetDependency().GetByIdAsync(sponsoredOrg.Id).Returns(sponsoredOrg); - sutProvider.GetDependency().GetByIdAsync(sponsoringOrg.Id).Returns(sponsoringOrg); + sutProvider.GetDependency() + .GetBySponsoredOrganizationIdAsync(sponsoredOrg.Id).Returns(sponsorship); + sutProvider.GetDependency().GetByIdAsync(sponsoredOrg.Id).Returns(sponsoredOrg); + sutProvider.GetDependency().GetByIdAsync(sponsoringOrg.Id).Returns(sponsoringOrg); - var result = await sutProvider.Sut.ValidateSponsorshipAsync(sponsoredOrg.Id); + var result = await sutProvider.Sut.ValidateSponsorshipAsync(sponsoredOrg.Id); - Assert.False(result); + Assert.False(result); - await AssertRemovedSponsoredPaymentAsync(sponsoredOrg, sponsorship, sutProvider); - await AssertDeletedSponsorshipAsync(sponsorship, sutProvider); - } + await AssertRemovedSponsoredPaymentAsync(sponsoredOrg, sponsorship, sutProvider); + await AssertDeletedSponsorshipAsync(sponsorship, sutProvider); + } - [Theory] - [BitMemberAutoData(nameof(EnterprisePlanTypes))] - public async Task ValidateSponsorshipAsync_SponsoringOrgDisabledUnknownTime_UpdatesStripePlan(PlanType planType, - Organization sponsoredOrg, OrganizationSponsorship existingSponsorship, Organization sponsoringOrg, - SutProvider sutProvider) - { - sponsoringOrg.PlanType = planType; - sponsoringOrg.Enabled = false; - sponsoringOrg.ExpirationDate = null; - existingSponsorship.SponsoringOrganizationId = sponsoringOrg.Id; + [Theory] + [BitMemberAutoData(nameof(EnterprisePlanTypes))] + public async Task ValidateSponsorshipAsync_SponsoringOrgDisabledUnknownTime_UpdatesStripePlan(PlanType planType, + Organization sponsoredOrg, OrganizationSponsorship existingSponsorship, Organization sponsoringOrg, + SutProvider sutProvider) + { + sponsoringOrg.PlanType = planType; + sponsoringOrg.Enabled = false; + sponsoringOrg.ExpirationDate = null; + existingSponsorship.SponsoringOrganizationId = sponsoringOrg.Id; - sutProvider.GetDependency() - .GetBySponsoredOrganizationIdAsync(sponsoredOrg.Id).Returns(existingSponsorship); - sutProvider.GetDependency().GetByIdAsync(sponsoredOrg.Id).Returns(sponsoredOrg); - sutProvider.GetDependency().GetByIdAsync(sponsoringOrg.Id).Returns(sponsoringOrg); + sutProvider.GetDependency() + .GetBySponsoredOrganizationIdAsync(sponsoredOrg.Id).Returns(existingSponsorship); + sutProvider.GetDependency().GetByIdAsync(sponsoredOrg.Id).Returns(sponsoredOrg); + sutProvider.GetDependency().GetByIdAsync(sponsoringOrg.Id).Returns(sponsoringOrg); - var result = await sutProvider.Sut.ValidateSponsorshipAsync(sponsoredOrg.Id); + var result = await sutProvider.Sut.ValidateSponsorshipAsync(sponsoredOrg.Id); - Assert.False(result); - await AssertRemovedSponsoredPaymentAsync(sponsoredOrg, existingSponsorship, sutProvider); - await AssertRemovedSponsorshipAsync(existingSponsorship, sutProvider); - } + Assert.False(result); + await AssertRemovedSponsoredPaymentAsync(sponsoredOrg, existingSponsorship, sutProvider); + await AssertRemovedSponsorshipAsync(existingSponsorship, sutProvider); + } - [Theory] - [BitMemberAutoData(nameof(EnterprisePlanTypes))] - public async Task ValidateSponsorshipAsync_SponsoringOrgDisabledLessThanGrace_Valid(PlanType planType, - Organization sponsoredOrg, OrganizationSponsorship existingSponsorship, Organization sponsoringOrg, - SutProvider sutProvider) - { - sponsoringOrg.PlanType = planType; - sponsoringOrg.Enabled = true; - sponsoringOrg.ExpirationDate = DateTime.UtcNow.AddDays(-1); - existingSponsorship.SponsoringOrganizationId = sponsoringOrg.Id; + [Theory] + [BitMemberAutoData(nameof(EnterprisePlanTypes))] + public async Task ValidateSponsorshipAsync_SponsoringOrgDisabledLessThanGrace_Valid(PlanType planType, + Organization sponsoredOrg, OrganizationSponsorship existingSponsorship, Organization sponsoringOrg, + SutProvider sutProvider) + { + sponsoringOrg.PlanType = planType; + sponsoringOrg.Enabled = true; + sponsoringOrg.ExpirationDate = DateTime.UtcNow.AddDays(-1); + existingSponsorship.SponsoringOrganizationId = sponsoringOrg.Id; - sutProvider.GetDependency() - .GetBySponsoredOrganizationIdAsync(sponsoredOrg.Id).Returns(existingSponsorship); - sutProvider.GetDependency().GetByIdAsync(sponsoredOrg.Id).Returns(sponsoredOrg); - sutProvider.GetDependency().GetByIdAsync(sponsoringOrg.Id).Returns(sponsoringOrg); + sutProvider.GetDependency() + .GetBySponsoredOrganizationIdAsync(sponsoredOrg.Id).Returns(existingSponsorship); + sutProvider.GetDependency().GetByIdAsync(sponsoredOrg.Id).Returns(sponsoredOrg); + sutProvider.GetDependency().GetByIdAsync(sponsoringOrg.Id).Returns(sponsoringOrg); - var result = await sutProvider.Sut.ValidateSponsorshipAsync(sponsoredOrg.Id); + var result = await sutProvider.Sut.ValidateSponsorshipAsync(sponsoredOrg.Id); - Assert.True(result); + Assert.True(result); - await AssertDidNotRemoveSponsoredPaymentAsync(sutProvider); - await AssertDidNotRemoveSponsorshipAsync(sutProvider); - } + await AssertDidNotRemoveSponsoredPaymentAsync(sutProvider); + await AssertDidNotRemoveSponsorshipAsync(sutProvider); + } - [Theory] - [BitMemberAutoData(nameof(EnterprisePlanTypes))] - public async Task ValidateSponsorshipAsync_Valid(PlanType planType, - Organization sponsoredOrg, OrganizationSponsorship existingSponsorship, Organization sponsoringOrg, - SutProvider sutProvider) - { - sponsoringOrg.PlanType = planType; - sponsoringOrg.Enabled = true; - existingSponsorship.SponsoringOrganizationId = sponsoringOrg.Id; + [Theory] + [BitMemberAutoData(nameof(EnterprisePlanTypes))] + public async Task ValidateSponsorshipAsync_Valid(PlanType planType, + Organization sponsoredOrg, OrganizationSponsorship existingSponsorship, Organization sponsoringOrg, + SutProvider sutProvider) + { + sponsoringOrg.PlanType = planType; + sponsoringOrg.Enabled = true; + existingSponsorship.SponsoringOrganizationId = sponsoringOrg.Id; - sutProvider.GetDependency() - .GetBySponsoredOrganizationIdAsync(sponsoredOrg.Id).Returns(existingSponsorship); - sutProvider.GetDependency().GetByIdAsync(sponsoredOrg.Id).Returns(sponsoredOrg); - sutProvider.GetDependency().GetByIdAsync(sponsoringOrg.Id).Returns(sponsoringOrg); + sutProvider.GetDependency() + .GetBySponsoredOrganizationIdAsync(sponsoredOrg.Id).Returns(existingSponsorship); + sutProvider.GetDependency().GetByIdAsync(sponsoredOrg.Id).Returns(sponsoredOrg); + sutProvider.GetDependency().GetByIdAsync(sponsoringOrg.Id).Returns(sponsoringOrg); - var result = await sutProvider.Sut.ValidateSponsorshipAsync(sponsoredOrg.Id); + var result = await sutProvider.Sut.ValidateSponsorshipAsync(sponsoredOrg.Id); - Assert.True(result); + Assert.True(result); - await AssertDidNotRemoveSponsoredPaymentAsync(sutProvider); - await AssertDidNotDeleteSponsorshipAsync(sutProvider); + await AssertDidNotRemoveSponsoredPaymentAsync(sutProvider); + await AssertDidNotDeleteSponsorshipAsync(sutProvider); + } } } diff --git a/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/CreateSponsorshipCommandTests.cs b/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/CreateSponsorshipCommandTests.cs index 4eb2779d92..b4e014d061 100644 --- a/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/CreateSponsorshipCommandTests.cs +++ b/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/CreateSponsorshipCommandTests.cs @@ -13,166 +13,167 @@ using NSubstitute.ExceptionExtensions; using NSubstitute.ReturnsExtensions; using Xunit; -namespace Bit.Core.Test.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise; - -[SutProviderCustomize] -public class CreateSponsorshipCommandTests : FamiliesForEnterpriseTestsBase +namespace Bit.Core.Test.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise { - private bool SponsorshipValidator(OrganizationSponsorship sponsorship, OrganizationSponsorship expectedSponsorship) + [SutProviderCustomize] + public class CreateSponsorshipCommandTests : FamiliesForEnterpriseTestsBase { - try + private bool SponsorshipValidator(OrganizationSponsorship sponsorship, OrganizationSponsorship expectedSponsorship) { - AssertHelper.AssertPropertyEqual(sponsorship, expectedSponsorship, nameof(OrganizationSponsorship.Id)); - return true; + try + { + AssertHelper.AssertPropertyEqual(sponsorship, expectedSponsorship, nameof(OrganizationSponsorship.Id)); + return true; + } + catch + { + return false; + } } - catch + + [Theory, BitAutoData] + public async Task CreateSponsorship_OfferedToNotFound_ThrowsBadRequest(OrganizationUser orgUser, SutProvider sutProvider) { - return false; + sutProvider.GetDependency().GetUserByIdAsync(orgUser.UserId.Value).ReturnsNull(); + + var exception = await Assert.ThrowsAsync(() => + sutProvider.Sut.CreateSponsorshipAsync(null, orgUser, PlanSponsorshipType.FamiliesForEnterprise, default, default)); + + Assert.Contains("Cannot offer a Families Organization Sponsorship to yourself. Choose a different email.", exception.Message); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() + .CreateAsync(default); } - } - [Theory, BitAutoData] - public async Task CreateSponsorship_OfferedToNotFound_ThrowsBadRequest(OrganizationUser orgUser, SutProvider sutProvider) - { - sutProvider.GetDependency().GetUserByIdAsync(orgUser.UserId.Value).ReturnsNull(); - - var exception = await Assert.ThrowsAsync(() => - sutProvider.Sut.CreateSponsorshipAsync(null, orgUser, PlanSponsorshipType.FamiliesForEnterprise, default, default)); - - Assert.Contains("Cannot offer a Families Organization Sponsorship to yourself. Choose a different email.", exception.Message); - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() - .CreateAsync(default); - } - - [Theory, BitAutoData] - public async Task CreateSponsorship_OfferedToSelf_ThrowsBadRequest(OrganizationUser orgUser, string sponsoredEmail, User user, SutProvider sutProvider) - { - user.Email = sponsoredEmail; - sutProvider.GetDependency().GetUserByIdAsync(orgUser.UserId.Value).Returns(user); - - var exception = await Assert.ThrowsAsync(() => - sutProvider.Sut.CreateSponsorshipAsync(null, orgUser, PlanSponsorshipType.FamiliesForEnterprise, sponsoredEmail, default)); - - Assert.Contains("Cannot offer a Families Organization Sponsorship to yourself. Choose a different email.", exception.Message); - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() - .CreateAsync(default); - } - - [Theory, BitMemberAutoData(nameof(NonEnterprisePlanTypes))] - public async Task CreateSponsorship_BadSponsoringOrgPlan_ThrowsBadRequest(PlanType sponsoringOrgPlan, - Organization org, OrganizationUser orgUser, User user, SutProvider sutProvider) - { - org.PlanType = sponsoringOrgPlan; - orgUser.Status = OrganizationUserStatusType.Confirmed; - - sutProvider.GetDependency().GetUserByIdAsync(orgUser.UserId.Value).Returns(user); - - var exception = await Assert.ThrowsAsync(() => - sutProvider.Sut.CreateSponsorshipAsync(org, orgUser, PlanSponsorshipType.FamiliesForEnterprise, default, default)); - - Assert.Contains("Specified Organization cannot sponsor other organizations.", exception.Message); - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() - .CreateAsync(default); - } - - [Theory] - [BitMemberAutoData(nameof(NonConfirmedOrganizationUsersStatuses))] - public async Task CreateSponsorship_BadSponsoringUserStatus_ThrowsBadRequest( - OrganizationUserStatusType statusType, Organization org, OrganizationUser orgUser, User user, - SutProvider sutProvider) - { - org.PlanType = PlanType.EnterpriseAnnually; - orgUser.Status = statusType; - - sutProvider.GetDependency().GetUserByIdAsync(orgUser.UserId.Value).Returns(user); - - var exception = await Assert.ThrowsAsync(() => - sutProvider.Sut.CreateSponsorshipAsync(org, orgUser, PlanSponsorshipType.FamiliesForEnterprise, default, default)); - - Assert.Contains("Only confirmed users can sponsor other organizations.", exception.Message); - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() - .CreateAsync(default); - } - - [Theory] - [OrganizationSponsorshipCustomize] - [BitAutoData] - public async Task CreateSponsorship_AlreadySponsoring_Throws(Organization org, - OrganizationUser orgUser, User user, OrganizationSponsorship sponsorship, - SutProvider sutProvider) - { - org.PlanType = PlanType.EnterpriseAnnually; - orgUser.Status = OrganizationUserStatusType.Confirmed; - - sutProvider.GetDependency().GetUserByIdAsync(orgUser.UserId.Value).Returns(user); - sutProvider.GetDependency() - .GetBySponsoringOrganizationUserIdAsync(orgUser.Id).Returns(sponsorship); - - var exception = await Assert.ThrowsAsync(() => - sutProvider.Sut.CreateSponsorshipAsync(org, orgUser, sponsorship.PlanSponsorshipType.Value, default, default)); - - Assert.Contains("Can only sponsor one organization per Organization User.", exception.Message); - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() - .CreateAsync(default); - } - - [Theory] - [BitAutoData] - public async Task CreateSponsorship_CreatesSponsorship(Organization sponsoringOrg, OrganizationUser sponsoringOrgUser, User user, - string sponsoredEmail, string friendlyName, Guid sponsorshipId, SutProvider sutProvider) - { - sponsoringOrg.PlanType = PlanType.EnterpriseAnnually; - sponsoringOrgUser.Status = OrganizationUserStatusType.Confirmed; - - sutProvider.GetDependency().GetUserByIdAsync(sponsoringOrgUser.UserId.Value).Returns(user); - sutProvider.GetDependency().WhenForAnyArgs(x => x.UpsertAsync(default)).Do(callInfo => + [Theory, BitAutoData] + public async Task CreateSponsorship_OfferedToSelf_ThrowsBadRequest(OrganizationUser orgUser, string sponsoredEmail, User user, SutProvider sutProvider) { - var sponsorship = callInfo.Arg(); - sponsorship.Id = sponsorshipId; - }); + user.Email = sponsoredEmail; + sutProvider.GetDependency().GetUserByIdAsync(orgUser.UserId.Value).Returns(user); + var exception = await Assert.ThrowsAsync(() => + sutProvider.Sut.CreateSponsorshipAsync(null, orgUser, PlanSponsorshipType.FamiliesForEnterprise, sponsoredEmail, default)); - await sutProvider.Sut.CreateSponsorshipAsync(sponsoringOrg, sponsoringOrgUser, - PlanSponsorshipType.FamiliesForEnterprise, sponsoredEmail, friendlyName); + Assert.Contains("Cannot offer a Families Organization Sponsorship to yourself. Choose a different email.", exception.Message); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() + .CreateAsync(default); + } - var expectedSponsorship = new OrganizationSponsorship + [Theory, BitMemberAutoData(nameof(NonEnterprisePlanTypes))] + public async Task CreateSponsorship_BadSponsoringOrgPlan_ThrowsBadRequest(PlanType sponsoringOrgPlan, + Organization org, OrganizationUser orgUser, User user, SutProvider sutProvider) { - Id = sponsorshipId, - SponsoringOrganizationId = sponsoringOrg.Id, - SponsoringOrganizationUserId = sponsoringOrgUser.Id, - FriendlyName = friendlyName, - OfferedToEmail = sponsoredEmail, - PlanSponsorshipType = PlanSponsorshipType.FamiliesForEnterprise, - }; + org.PlanType = sponsoringOrgPlan; + orgUser.Status = OrganizationUserStatusType.Confirmed; - await sutProvider.GetDependency().Received(1) - .UpsertAsync(Arg.Is(s => SponsorshipValidator(s, expectedSponsorship))); - } + sutProvider.GetDependency().GetUserByIdAsync(orgUser.UserId.Value).Returns(user); - [Theory] - [BitAutoData] - public async Task CreateSponsorship_CreateSponsorshipThrows_RevertsDatabase(Organization sponsoringOrg, OrganizationUser sponsoringOrgUser, User user, - string sponsoredEmail, string friendlyName, SutProvider sutProvider) - { - sponsoringOrg.PlanType = PlanType.EnterpriseAnnually; - sponsoringOrgUser.Status = OrganizationUserStatusType.Confirmed; + var exception = await Assert.ThrowsAsync(() => + sutProvider.Sut.CreateSponsorshipAsync(org, orgUser, PlanSponsorshipType.FamiliesForEnterprise, default, default)); - var expectedException = new Exception(); - OrganizationSponsorship createdSponsorship = null; - sutProvider.GetDependency().GetUserByIdAsync(sponsoringOrgUser.UserId.Value).Returns(user); - sutProvider.GetDependency().UpsertAsync(default).ThrowsForAnyArgs(callInfo => + Assert.Contains("Specified Organization cannot sponsor other organizations.", exception.Message); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() + .CreateAsync(default); + } + + [Theory] + [BitMemberAutoData(nameof(NonConfirmedOrganizationUsersStatuses))] + public async Task CreateSponsorship_BadSponsoringUserStatus_ThrowsBadRequest( + OrganizationUserStatusType statusType, Organization org, OrganizationUser orgUser, User user, + SutProvider sutProvider) { - createdSponsorship = callInfo.ArgAt(0); - createdSponsorship.Id = Guid.NewGuid(); - return expectedException; - }); + org.PlanType = PlanType.EnterpriseAnnually; + orgUser.Status = statusType; - var actualException = await Assert.ThrowsAsync(() => - sutProvider.Sut.CreateSponsorshipAsync(sponsoringOrg, sponsoringOrgUser, - PlanSponsorshipType.FamiliesForEnterprise, sponsoredEmail, friendlyName)); - Assert.Same(expectedException, actualException); + sutProvider.GetDependency().GetUserByIdAsync(orgUser.UserId.Value).Returns(user); - await sutProvider.GetDependency().Received(1) - .DeleteAsync(createdSponsorship); + var exception = await Assert.ThrowsAsync(() => + sutProvider.Sut.CreateSponsorshipAsync(org, orgUser, PlanSponsorshipType.FamiliesForEnterprise, default, default)); + + Assert.Contains("Only confirmed users can sponsor other organizations.", exception.Message); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() + .CreateAsync(default); + } + + [Theory] + [OrganizationSponsorshipCustomize] + [BitAutoData] + public async Task CreateSponsorship_AlreadySponsoring_Throws(Organization org, + OrganizationUser orgUser, User user, OrganizationSponsorship sponsorship, + SutProvider sutProvider) + { + org.PlanType = PlanType.EnterpriseAnnually; + orgUser.Status = OrganizationUserStatusType.Confirmed; + + sutProvider.GetDependency().GetUserByIdAsync(orgUser.UserId.Value).Returns(user); + sutProvider.GetDependency() + .GetBySponsoringOrganizationUserIdAsync(orgUser.Id).Returns(sponsorship); + + var exception = await Assert.ThrowsAsync(() => + sutProvider.Sut.CreateSponsorshipAsync(org, orgUser, sponsorship.PlanSponsorshipType.Value, default, default)); + + Assert.Contains("Can only sponsor one organization per Organization User.", exception.Message); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() + .CreateAsync(default); + } + + [Theory] + [BitAutoData] + public async Task CreateSponsorship_CreatesSponsorship(Organization sponsoringOrg, OrganizationUser sponsoringOrgUser, User user, + string sponsoredEmail, string friendlyName, Guid sponsorshipId, SutProvider sutProvider) + { + sponsoringOrg.PlanType = PlanType.EnterpriseAnnually; + sponsoringOrgUser.Status = OrganizationUserStatusType.Confirmed; + + sutProvider.GetDependency().GetUserByIdAsync(sponsoringOrgUser.UserId.Value).Returns(user); + sutProvider.GetDependency().WhenForAnyArgs(x => x.UpsertAsync(default)).Do(callInfo => + { + var sponsorship = callInfo.Arg(); + sponsorship.Id = sponsorshipId; + }); + + + await sutProvider.Sut.CreateSponsorshipAsync(sponsoringOrg, sponsoringOrgUser, + PlanSponsorshipType.FamiliesForEnterprise, sponsoredEmail, friendlyName); + + var expectedSponsorship = new OrganizationSponsorship + { + Id = sponsorshipId, + SponsoringOrganizationId = sponsoringOrg.Id, + SponsoringOrganizationUserId = sponsoringOrgUser.Id, + FriendlyName = friendlyName, + OfferedToEmail = sponsoredEmail, + PlanSponsorshipType = PlanSponsorshipType.FamiliesForEnterprise, + }; + + await sutProvider.GetDependency().Received(1) + .UpsertAsync(Arg.Is(s => SponsorshipValidator(s, expectedSponsorship))); + } + + [Theory] + [BitAutoData] + public async Task CreateSponsorship_CreateSponsorshipThrows_RevertsDatabase(Organization sponsoringOrg, OrganizationUser sponsoringOrgUser, User user, + string sponsoredEmail, string friendlyName, SutProvider sutProvider) + { + sponsoringOrg.PlanType = PlanType.EnterpriseAnnually; + sponsoringOrgUser.Status = OrganizationUserStatusType.Confirmed; + + var expectedException = new Exception(); + OrganizationSponsorship createdSponsorship = null; + sutProvider.GetDependency().GetUserByIdAsync(sponsoringOrgUser.UserId.Value).Returns(user); + sutProvider.GetDependency().UpsertAsync(default).ThrowsForAnyArgs(callInfo => + { + createdSponsorship = callInfo.ArgAt(0); + createdSponsorship.Id = Guid.NewGuid(); + return expectedException; + }); + + var actualException = await Assert.ThrowsAsync(() => + sutProvider.Sut.CreateSponsorshipAsync(sponsoringOrg, sponsoringOrgUser, + PlanSponsorshipType.FamiliesForEnterprise, sponsoredEmail, friendlyName)); + Assert.Same(expectedException, actualException); + + await sutProvider.GetDependency().Received(1) + .DeleteAsync(createdSponsorship); + } } } diff --git a/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/FamiliesForEnterpriseTestsBase.cs b/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/FamiliesForEnterpriseTestsBase.cs index e49b095d76..862ae6e806 100644 --- a/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/FamiliesForEnterpriseTestsBase.cs +++ b/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/FamiliesForEnterpriseTestsBase.cs @@ -1,24 +1,25 @@ using Bit.Core.Enums; using Bit.Core.Utilities; -namespace Bit.Core.Test.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise; - -public abstract class FamiliesForEnterpriseTestsBase +namespace Bit.Core.Test.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise { - public static IEnumerable EnterprisePlanTypes => - Enum.GetValues().Where(p => StaticStore.GetPlan(p).Product == ProductType.Enterprise).Select(p => new object[] { p }); + public abstract class FamiliesForEnterpriseTestsBase + { + public static IEnumerable EnterprisePlanTypes => + Enum.GetValues().Where(p => StaticStore.GetPlan(p).Product == ProductType.Enterprise).Select(p => new object[] { p }); - public static IEnumerable NonEnterprisePlanTypes => - Enum.GetValues().Where(p => StaticStore.GetPlan(p).Product != ProductType.Enterprise).Select(p => new object[] { p }); + public static IEnumerable NonEnterprisePlanTypes => + Enum.GetValues().Where(p => StaticStore.GetPlan(p).Product != ProductType.Enterprise).Select(p => new object[] { p }); - public static IEnumerable FamiliesPlanTypes => - Enum.GetValues().Where(p => StaticStore.GetPlan(p).Product == ProductType.Families).Select(p => new object[] { p }); + public static IEnumerable FamiliesPlanTypes => + Enum.GetValues().Where(p => StaticStore.GetPlan(p).Product == ProductType.Families).Select(p => new object[] { p }); - public static IEnumerable NonFamiliesPlanTypes => - Enum.GetValues().Where(p => StaticStore.GetPlan(p).Product != ProductType.Families).Select(p => new object[] { p }); + public static IEnumerable NonFamiliesPlanTypes => + Enum.GetValues().Where(p => StaticStore.GetPlan(p).Product != ProductType.Families).Select(p => new object[] { p }); - public static IEnumerable NonConfirmedOrganizationUsersStatuses => - Enum.GetValues() - .Where(s => s != OrganizationUserStatusType.Confirmed) - .Select(s => new object[] { s }); + public static IEnumerable NonConfirmedOrganizationUsersStatuses => + Enum.GetValues() + .Where(s => s != OrganizationUserStatusType.Confirmed) + .Select(s => new object[] { s }); + } } diff --git a/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/SelfHosted/SelfHostedRevokeSponsorshipCommandTests.cs b/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/SelfHosted/SelfHostedRevokeSponsorshipCommandTests.cs index 7ac1c71281..6dd913383b 100644 --- a/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/SelfHosted/SelfHostedRevokeSponsorshipCommandTests.cs +++ b/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/SelfHosted/SelfHostedRevokeSponsorshipCommandTests.cs @@ -6,47 +6,48 @@ using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; using Xunit; -namespace Bit.Core.Test.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.SelfHosted; - -[SutProviderCustomize] -[OrganizationSponsorshipCustomize] -public class SelfHostedRevokeSponsorshipCommandTests : CancelSponsorshipCommandTestsBase +namespace Bit.Core.Test.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.SelfHosted { - [Theory] - [BitAutoData] - public async Task RevokeSponsorship_NoExistingSponsorship_ThrowsBadRequest( - SutProvider sutProvider) + [SutProviderCustomize] + [OrganizationSponsorshipCustomize] + public class SelfHostedRevokeSponsorshipCommandTests : CancelSponsorshipCommandTestsBase { - var exception = await Assert.ThrowsAsync(() => - sutProvider.Sut.RevokeSponsorshipAsync(null)); + [Theory] + [BitAutoData] + public async Task RevokeSponsorship_NoExistingSponsorship_ThrowsBadRequest( + SutProvider sutProvider) + { + var exception = await Assert.ThrowsAsync(() => + sutProvider.Sut.RevokeSponsorshipAsync(null)); - Assert.Contains("You are not currently sponsoring an organization.", exception.Message); - await AssertDidNotDeleteSponsorshipAsync(sutProvider); - await AssertDidNotUpdateSponsorshipAsync(sutProvider); - } + Assert.Contains("You are not currently sponsoring an organization.", exception.Message); + await AssertDidNotDeleteSponsorshipAsync(sutProvider); + await AssertDidNotUpdateSponsorshipAsync(sutProvider); + } - [Theory] - [BitAutoData] - public async Task RevokeSponsorship_SponsorshipNotSynced_DeletesSponsorship(OrganizationSponsorship sponsorship, - SutProvider sutProvider) - { - sponsorship.LastSyncDate = null; + [Theory] + [BitAutoData] + public async Task RevokeSponsorship_SponsorshipNotSynced_DeletesSponsorship(OrganizationSponsorship sponsorship, + SutProvider sutProvider) + { + sponsorship.LastSyncDate = null; - await sutProvider.Sut.RevokeSponsorshipAsync(sponsorship); - await AssertDeletedSponsorshipAsync(sponsorship, sutProvider); - } + await sutProvider.Sut.RevokeSponsorshipAsync(sponsorship); + await AssertDeletedSponsorshipAsync(sponsorship, sutProvider); + } - [Theory] - [BitAutoData] - public async Task RevokeSponsorship_SponsorshipSynced_MarksForDeletion(OrganizationSponsorship sponsorship, - SutProvider sutProvider) - { - sponsorship.LastSyncDate = DateTime.UtcNow; + [Theory] + [BitAutoData] + public async Task RevokeSponsorship_SponsorshipSynced_MarksForDeletion(OrganizationSponsorship sponsorship, + SutProvider sutProvider) + { + sponsorship.LastSyncDate = DateTime.UtcNow; - await sutProvider.Sut.RevokeSponsorshipAsync(sponsorship); + await sutProvider.Sut.RevokeSponsorshipAsync(sponsorship); - Assert.True(sponsorship.ToDelete); - await AssertUpdatedSponsorshipAsync(sponsorship, sutProvider); - await AssertDidNotDeleteSponsorshipAsync(sutProvider); + Assert.True(sponsorship.ToDelete); + await AssertUpdatedSponsorshipAsync(sponsorship, sutProvider); + await AssertDidNotDeleteSponsorshipAsync(sutProvider); + } } } diff --git a/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/SelfHosted/SelfHostedSyncSponsorshipsCommandTests.cs b/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/SelfHosted/SelfHostedSyncSponsorshipsCommandTests.cs index 5ec93a976b..5c9741f35d 100644 --- a/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/SelfHosted/SelfHostedSyncSponsorshipsCommandTests.cs +++ b/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/SelfHosted/SelfHostedSyncSponsorshipsCommandTests.cs @@ -15,172 +15,174 @@ using NSubstitute; using RichardSzalay.MockHttp; using Xunit; -namespace Bit.Core.Test.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.SelfHosted; - -public class SelfHostedSyncSponsorshipsCommandTests : FamiliesForEnterpriseTestsBase +namespace Bit.Core.Test.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.SelfHosted { - public static SutProvider GetSutProvider(bool enableCloudCommunication = true, string identityResponse = null, string apiResponse = null) + public class SelfHostedSyncSponsorshipsCommandTests : FamiliesForEnterpriseTestsBase { - var fixture = new Fixture().WithAutoNSubstitutionsAutoPopulatedProperties(); - fixture.AddMockHttp(); - var settings = fixture.Create(); - settings.SelfHosted = true; - settings.EnableCloudCommunication = enableCloudCommunication; - - var apiUri = fixture.Create(); - var identityUri = fixture.Create(); - settings.Installation.ApiUri.Returns(apiUri.ToString()); - settings.Installation.IdentityUri.Returns(identityUri.ToString()); - - var apiHandler = new MockHttpMessageHandler(); - var identityHandler = new MockHttpMessageHandler(); - var syncUri = string.Concat(apiUri, "organization/sponsorship/sync"); - var tokenUri = string.Concat(identityUri, "connect/token"); - - apiHandler.When(HttpMethod.Post, syncUri) - .Respond("application/json", apiResponse); - identityHandler.When(HttpMethod.Post, tokenUri) - .Respond("application/json", identityResponse ?? "{\"access_token\":\"string\",\"expires_in\":3600,\"token_type\":\"Bearer\",\"scope\":\"string\"}"); - - - var apiHttp = apiHandler.ToHttpClient(); - var identityHttp = identityHandler.ToHttpClient(); - - var mockHttpClientFactory = Substitute.For(); - mockHttpClientFactory.CreateClient(Arg.Is("client")).Returns(apiHttp); - mockHttpClientFactory.CreateClient(Arg.Is("identity")).Returns(identityHttp); - - return new SutProvider(fixture) - .SetDependency(settings) - .SetDependency(mockHttpClientFactory) - .Create(); - } - - [Theory] - [BitAutoData] - public async Task SyncOrganization_BillingSyncKeyDisabled_ThrowsBadRequest( - Guid cloudOrganizationId, OrganizationConnection billingSyncConnection) - { - var sutProvider = GetSutProvider(); - billingSyncConnection.Enabled = false; - billingSyncConnection.SetConfig(new BillingSyncConfig + public static SutProvider GetSutProvider(bool enableCloudCommunication = true, string identityResponse = null, string apiResponse = null) { - BillingSyncKey = "okslkcslkjf" - }); + var fixture = new Fixture().WithAutoNSubstitutionsAutoPopulatedProperties(); + fixture.AddMockHttp(); - var exception = await Assert.ThrowsAsync(() => - sutProvider.Sut.SyncOrganization(billingSyncConnection.OrganizationId, cloudOrganizationId, billingSyncConnection)); + var settings = fixture.Create(); + settings.SelfHosted = true; + settings.EnableCloudCommunication = enableCloudCommunication; - Assert.Contains($"Billing Sync Key disabled", exception.Message); + var apiUri = fixture.Create(); + var identityUri = fixture.Create(); + settings.Installation.ApiUri.Returns(apiUri.ToString()); + settings.Installation.IdentityUri.Returns(identityUri.ToString()); - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .DeleteManyAsync(default); - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .UpsertManyAsync(default); - } + var apiHandler = new MockHttpMessageHandler(); + var identityHandler = new MockHttpMessageHandler(); + var syncUri = string.Concat(apiUri, "organization/sponsorship/sync"); + var tokenUri = string.Concat(identityUri, "connect/token"); - [Theory] - [BitAutoData] - public async Task SyncOrganization_BillingSyncKeyEmpty_ThrowsBadRequest( - Guid cloudOrganizationId, OrganizationConnection billingSyncConnection) - { - var sutProvider = GetSutProvider(); - billingSyncConnection.Config = ""; + apiHandler.When(HttpMethod.Post, syncUri) + .Respond("application/json", apiResponse); + identityHandler.When(HttpMethod.Post, tokenUri) + .Respond("application/json", identityResponse ?? "{\"access_token\":\"string\",\"expires_in\":3600,\"token_type\":\"Bearer\",\"scope\":\"string\"}"); - var exception = await Assert.ThrowsAsync(() => - sutProvider.Sut.SyncOrganization(billingSyncConnection.OrganizationId, cloudOrganizationId, billingSyncConnection)); - Assert.Contains($"No Billing Sync Key known", exception.Message); + var apiHttp = apiHandler.ToHttpClient(); + var identityHttp = identityHandler.ToHttpClient(); - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .DeleteManyAsync(default); - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .UpsertManyAsync(default); - } + var mockHttpClientFactory = Substitute.For(); + mockHttpClientFactory.CreateClient(Arg.Is("client")).Returns(apiHttp); + mockHttpClientFactory.CreateClient(Arg.Is("identity")).Returns(identityHttp); - [Theory] - [BitAutoData] - public async Task SyncOrganization_CloudCommunicationDisabled_EarlyReturn( - Guid cloudOrganizationId, OrganizationConnection billingSyncConnection) - { - var sutProvider = GetSutProvider(false); + return new SutProvider(fixture) + .SetDependency(settings) + .SetDependency(mockHttpClientFactory) + .Create(); + } - var exception = await Assert.ThrowsAsync(() => - sutProvider.Sut.SyncOrganization(billingSyncConnection.OrganizationId, cloudOrganizationId, billingSyncConnection)); - - Assert.Contains($"Cloud communication is disabled", exception.Message); - - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .DeleteManyAsync(default); - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .UpsertManyAsync(default); - } - - [Theory] - [OrganizationSponsorshipCustomize] - [BitAutoData] - public async Task SyncOrganization_SyncsSponsorships( - Guid cloudOrganizationId, OrganizationConnection billingSyncConnection, IEnumerable sponsorships) - { - var syncJsonResponse = JsonSerializer.Serialize(new OrganizationSponsorshipSyncResponseModel( - new OrganizationSponsorshipSyncData + [Theory] + [BitAutoData] + public async Task SyncOrganization_BillingSyncKeyDisabled_ThrowsBadRequest( + Guid cloudOrganizationId, OrganizationConnection billingSyncConnection) + { + var sutProvider = GetSutProvider(); + billingSyncConnection.Enabled = false; + billingSyncConnection.SetConfig(new BillingSyncConfig { - SponsorshipsBatch = sponsorships.Select(o => new OrganizationSponsorshipData(o)) - })); + BillingSyncKey = "okslkcslkjf" + }); - var sutProvider = GetSutProvider(apiResponse: syncJsonResponse); - billingSyncConnection.SetConfig(new BillingSyncConfig + var exception = await Assert.ThrowsAsync(() => + sutProvider.Sut.SyncOrganization(billingSyncConnection.OrganizationId, cloudOrganizationId, billingSyncConnection)); + + Assert.Contains($"Billing Sync Key disabled", exception.Message); + + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .DeleteManyAsync(default); + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .UpsertManyAsync(default); + } + + [Theory] + [BitAutoData] + public async Task SyncOrganization_BillingSyncKeyEmpty_ThrowsBadRequest( + Guid cloudOrganizationId, OrganizationConnection billingSyncConnection) { - BillingSyncKey = "okslkcslkjf" - }); - sutProvider.GetDependency() - .GetManyBySponsoringOrganizationAsync(Arg.Any()).Returns(sponsorships.ToList()); + var sutProvider = GetSutProvider(); + billingSyncConnection.Config = ""; - await sutProvider.Sut.SyncOrganization(billingSyncConnection.OrganizationId, cloudOrganizationId, billingSyncConnection); + var exception = await Assert.ThrowsAsync(() => + sutProvider.Sut.SyncOrganization(billingSyncConnection.OrganizationId, cloudOrganizationId, billingSyncConnection)); - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .DeleteManyAsync(default); - await sutProvider.GetDependency() - .Received(1) - .UpsertManyAsync(Arg.Any>()); - } + Assert.Contains($"No Billing Sync Key known", exception.Message); - [Theory] - [OrganizationSponsorshipCustomize(ToDelete = true)] - [BitAutoData] - public async Task SyncOrganization_DeletesSponsorships( - Guid cloudOrganizationId, OrganizationConnection billingSyncConnection, IEnumerable sponsorships) - { - var syncJsonResponse = JsonSerializer.Serialize(new OrganizationSponsorshipSyncResponseModel( - new OrganizationSponsorshipSyncData + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .DeleteManyAsync(default); + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .UpsertManyAsync(default); + } + + [Theory] + [BitAutoData] + public async Task SyncOrganization_CloudCommunicationDisabled_EarlyReturn( + Guid cloudOrganizationId, OrganizationConnection billingSyncConnection) + { + var sutProvider = GetSutProvider(false); + + var exception = await Assert.ThrowsAsync(() => + sutProvider.Sut.SyncOrganization(billingSyncConnection.OrganizationId, cloudOrganizationId, billingSyncConnection)); + + Assert.Contains($"Cloud communication is disabled", exception.Message); + + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .DeleteManyAsync(default); + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .UpsertManyAsync(default); + } + + [Theory] + [OrganizationSponsorshipCustomize] + [BitAutoData] + public async Task SyncOrganization_SyncsSponsorships( + Guid cloudOrganizationId, OrganizationConnection billingSyncConnection, IEnumerable sponsorships) + { + var syncJsonResponse = JsonSerializer.Serialize(new OrganizationSponsorshipSyncResponseModel( + new OrganizationSponsorshipSyncData + { + SponsorshipsBatch = sponsorships.Select(o => new OrganizationSponsorshipData(o)) + })); + + var sutProvider = GetSutProvider(apiResponse: syncJsonResponse); + billingSyncConnection.SetConfig(new BillingSyncConfig { - SponsorshipsBatch = sponsorships.Select(o => new OrganizationSponsorshipData(o) { CloudSponsorshipRemoved = true }) - })); + BillingSyncKey = "okslkcslkjf" + }); + sutProvider.GetDependency() + .GetManyBySponsoringOrganizationAsync(Arg.Any()).Returns(sponsorships.ToList()); - var sutProvider = GetSutProvider(apiResponse: syncJsonResponse); - billingSyncConnection.SetConfig(new BillingSyncConfig + await sutProvider.Sut.SyncOrganization(billingSyncConnection.OrganizationId, cloudOrganizationId, billingSyncConnection); + + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .DeleteManyAsync(default); + await sutProvider.GetDependency() + .Received(1) + .UpsertManyAsync(Arg.Any>()); + } + + [Theory] + [OrganizationSponsorshipCustomize(ToDelete = true)] + [BitAutoData] + public async Task SyncOrganization_DeletesSponsorships( + Guid cloudOrganizationId, OrganizationConnection billingSyncConnection, IEnumerable sponsorships) { - BillingSyncKey = "okslkcslkjf" - }); - sutProvider.GetDependency() - .GetManyBySponsoringOrganizationAsync(Arg.Any()).Returns(sponsorships.ToList()); + var syncJsonResponse = JsonSerializer.Serialize(new OrganizationSponsorshipSyncResponseModel( + new OrganizationSponsorshipSyncData + { + SponsorshipsBatch = sponsorships.Select(o => new OrganizationSponsorshipData(o) { CloudSponsorshipRemoved = true }) + })); - await sutProvider.Sut.SyncOrganization(billingSyncConnection.OrganizationId, cloudOrganizationId, billingSyncConnection); + var sutProvider = GetSutProvider(apiResponse: syncJsonResponse); + billingSyncConnection.SetConfig(new BillingSyncConfig + { + BillingSyncKey = "okslkcslkjf" + }); + sutProvider.GetDependency() + .GetManyBySponsoringOrganizationAsync(Arg.Any()).Returns(sponsorships.ToList()); - await sutProvider.GetDependency() - .Received(1) - .DeleteManyAsync(Arg.Any>()); - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .UpsertManyAsync(default); + await sutProvider.Sut.SyncOrganization(billingSyncConnection.OrganizationId, cloudOrganizationId, billingSyncConnection); + + await sutProvider.GetDependency() + .Received(1) + .DeleteManyAsync(Arg.Any>()); + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .UpsertManyAsync(default); + } } } diff --git a/test/Core.Test/Resources/VerifyResources.cs b/test/Core.Test/Resources/VerifyResources.cs index 028ac3e9e2..821eb87e28 100644 --- a/test/Core.Test/Resources/VerifyResources.cs +++ b/test/Core.Test/Resources/VerifyResources.cs @@ -1,27 +1,28 @@ using Bit.Core.Utilities; using Xunit; -namespace Bit.Core.Test.Resources; - -public class VerifyResources +namespace Bit.Core.Test.Resources { - [Theory] - [MemberData(nameof(GetResources))] - public void Resource_FoundAndReadable(string resourceName) + public class VerifyResources { - var assembly = typeof(CoreHelpers).Assembly; - - using (var resource = assembly.GetManifestResourceStream(resourceName)) + [Theory] + [MemberData(nameof(GetResources))] + public void Resource_FoundAndReadable(string resourceName) { - Assert.NotNull(resource); - Assert.True(resource.CanRead); + var assembly = typeof(CoreHelpers).Assembly; + + using (var resource = assembly.GetManifestResourceStream(resourceName)) + { + Assert.NotNull(resource); + Assert.True(resource.CanRead); + } + } + + public static IEnumerable GetResources() + { + yield return new[] { "Bit.Core.licensing.cer" }; + yield return new[] { "Bit.Core.MailTemplates.Handlebars.AddedCredit.html.hbs" }; + yield return new[] { "Bit.Core.MailTemplates.Handlebars.Layouts.Basic.html.hbs" }; } } - - public static IEnumerable GetResources() - { - yield return new[] { "Bit.Core.licensing.cer" }; - yield return new[] { "Bit.Core.MailTemplates.Handlebars.AddedCredit.html.hbs" }; - yield return new[] { "Bit.Core.MailTemplates.Handlebars.Layouts.Basic.html.hbs" }; - } } diff --git a/test/Core.Test/Services/AmazonSesMailDeliveryServiceTests.cs b/test/Core.Test/Services/AmazonSesMailDeliveryServiceTests.cs index 71bbc9f13e..6c07c897de 100644 --- a/test/Core.Test/Services/AmazonSesMailDeliveryServiceTests.cs +++ b/test/Core.Test/Services/AmazonSesMailDeliveryServiceTests.cs @@ -8,79 +8,80 @@ using Microsoft.Extensions.Logging; using NSubstitute; using Xunit; -namespace Bit.Core.Test.Services; - -public class AmazonSesMailDeliveryServiceTests : IDisposable +namespace Bit.Core.Test.Services { - private readonly AmazonSesMailDeliveryService _sut; - - private readonly GlobalSettings _globalSettings; - private readonly IWebHostEnvironment _hostingEnvironment; - private readonly ILogger _logger; - private readonly IAmazonSimpleEmailService _amazonSimpleEmailService; - - public AmazonSesMailDeliveryServiceTests() + public class AmazonSesMailDeliveryServiceTests : IDisposable { - _globalSettings = new GlobalSettings + private readonly AmazonSesMailDeliveryService _sut; + + private readonly GlobalSettings _globalSettings; + private readonly IWebHostEnvironment _hostingEnvironment; + private readonly ILogger _logger; + private readonly IAmazonSimpleEmailService _amazonSimpleEmailService; + + public AmazonSesMailDeliveryServiceTests() { - Amazon = - { - AccessKeyId = "AccessKeyId-AmazonSesMailDeliveryServiceTests", - AccessKeySecret = "AccessKeySecret-AmazonSesMailDeliveryServiceTests", - Region = "Region-AmazonSesMailDeliveryServiceTests" - } - }; - - _hostingEnvironment = Substitute.For(); - _logger = Substitute.For>(); - _amazonSimpleEmailService = Substitute.For(); - - _sut = new AmazonSesMailDeliveryService( - _globalSettings, - _hostingEnvironment, - _logger, - _amazonSimpleEmailService - ); - } - - public void Dispose() - { - _sut?.Dispose(); - } - - [Fact] - public async Task SendEmailAsync_CallsSendEmailAsync_WhenMessageIsValid() - { - var mailMessage = new MailMessage - { - ToEmails = new List { "ToEmails" }, - BccEmails = new List { "BccEmails" }, - Subject = "Subject", - HtmlContent = "HtmlContent", - TextContent = "TextContent", - Category = "Category" - }; - - await _sut.SendEmailAsync(mailMessage); - - await _amazonSimpleEmailService.Received(1).SendEmailAsync( - Arg.Do(request => + _globalSettings = new GlobalSettings { - Assert.False(string.IsNullOrEmpty(request.Source)); + Amazon = + { + AccessKeyId = "AccessKeyId-AmazonSesMailDeliveryServiceTests", + AccessKeySecret = "AccessKeySecret-AmazonSesMailDeliveryServiceTests", + Region = "Region-AmazonSesMailDeliveryServiceTests" + } + }; - Assert.Single(request.Destination.ToAddresses); - Assert.Equal(mailMessage.ToEmails.First(), request.Destination.ToAddresses.First()); + _hostingEnvironment = Substitute.For(); + _logger = Substitute.For>(); + _amazonSimpleEmailService = Substitute.For(); - Assert.Equal(mailMessage.Subject, request.Message.Subject.Data); - Assert.Equal(mailMessage.HtmlContent, request.Message.Body.Html.Data); - Assert.Equal(mailMessage.TextContent, request.Message.Body.Text.Data); + _sut = new AmazonSesMailDeliveryService( + _globalSettings, + _hostingEnvironment, + _logger, + _amazonSimpleEmailService + ); + } - Assert.Single(request.Destination.BccAddresses); - Assert.Equal(mailMessage.BccEmails.First(), request.Destination.BccAddresses.First()); + public void Dispose() + { + _sut?.Dispose(); + } - Assert.Contains(request.Tags, x => x.Name == "Environment"); - Assert.Contains(request.Tags, x => x.Name == "Sender"); - Assert.Contains(request.Tags, x => x.Name == "Category"); - })); + [Fact] + public async Task SendEmailAsync_CallsSendEmailAsync_WhenMessageIsValid() + { + var mailMessage = new MailMessage + { + ToEmails = new List { "ToEmails" }, + BccEmails = new List { "BccEmails" }, + Subject = "Subject", + HtmlContent = "HtmlContent", + TextContent = "TextContent", + Category = "Category" + }; + + await _sut.SendEmailAsync(mailMessage); + + await _amazonSimpleEmailService.Received(1).SendEmailAsync( + Arg.Do(request => + { + Assert.False(string.IsNullOrEmpty(request.Source)); + + Assert.Single(request.Destination.ToAddresses); + Assert.Equal(mailMessage.ToEmails.First(), request.Destination.ToAddresses.First()); + + Assert.Equal(mailMessage.Subject, request.Message.Subject.Data); + Assert.Equal(mailMessage.HtmlContent, request.Message.Body.Html.Data); + Assert.Equal(mailMessage.TextContent, request.Message.Body.Text.Data); + + Assert.Single(request.Destination.BccAddresses); + Assert.Equal(mailMessage.BccEmails.First(), request.Destination.BccAddresses.First()); + + Assert.Contains(request.Tags, x => x.Name == "Environment"); + Assert.Contains(request.Tags, x => x.Name == "Sender"); + Assert.Contains(request.Tags, x => x.Name == "Category"); + })); + } } } diff --git a/test/Core.Test/Services/AmazonSqsBlockIpServiceTests.cs b/test/Core.Test/Services/AmazonSqsBlockIpServiceTests.cs index 5c74386edb..cf24d6293c 100644 --- a/test/Core.Test/Services/AmazonSqsBlockIpServiceTests.cs +++ b/test/Core.Test/Services/AmazonSqsBlockIpServiceTests.cs @@ -4,73 +4,74 @@ using Bit.Core.Settings; using NSubstitute; using Xunit; -namespace Bit.Core.Test.Services; - -public class AmazonSqsBlockIpServiceTests : IDisposable +namespace Bit.Core.Test.Services { - private readonly AmazonSqsBlockIpService _sut; - - private readonly GlobalSettings _globalSettings; - private readonly IAmazonSQS _amazonSqs; - - public AmazonSqsBlockIpServiceTests() + public class AmazonSqsBlockIpServiceTests : IDisposable { - _globalSettings = new GlobalSettings + private readonly AmazonSqsBlockIpService _sut; + + private readonly GlobalSettings _globalSettings; + private readonly IAmazonSQS _amazonSqs; + + public AmazonSqsBlockIpServiceTests() { - Amazon = + _globalSettings = new GlobalSettings { - AccessKeyId = "AccessKeyId-AmazonSesMailDeliveryServiceTests", - AccessKeySecret = "AccessKeySecret-AmazonSesMailDeliveryServiceTests", - Region = "Region-AmazonSesMailDeliveryServiceTests" - } - }; + Amazon = + { + AccessKeyId = "AccessKeyId-AmazonSesMailDeliveryServiceTests", + AccessKeySecret = "AccessKeySecret-AmazonSesMailDeliveryServiceTests", + Region = "Region-AmazonSesMailDeliveryServiceTests" + } + }; - _amazonSqs = Substitute.For(); + _amazonSqs = Substitute.For(); - _sut = new AmazonSqsBlockIpService(_globalSettings, _amazonSqs); - } + _sut = new AmazonSqsBlockIpService(_globalSettings, _amazonSqs); + } - public void Dispose() - { - _sut?.Dispose(); - } + public void Dispose() + { + _sut?.Dispose(); + } - [Fact] - public async Task BlockIpAsync_UnblockCalled_WhenNotPermanent() - { - const string expectedIp = "ip"; + [Fact] + public async Task BlockIpAsync_UnblockCalled_WhenNotPermanent() + { + const string expectedIp = "ip"; - await _sut.BlockIpAsync(expectedIp, false); + await _sut.BlockIpAsync(expectedIp, false); - await _amazonSqs.Received(2).SendMessageAsync( - Arg.Any(), - Arg.Is(expectedIp)); - } + await _amazonSqs.Received(2).SendMessageAsync( + Arg.Any(), + Arg.Is(expectedIp)); + } - [Fact] - public async Task BlockIpAsync_UnblockNotCalled_WhenPermanent() - { - const string expectedIp = "ip"; + [Fact] + public async Task BlockIpAsync_UnblockNotCalled_WhenPermanent() + { + const string expectedIp = "ip"; - await _sut.BlockIpAsync(expectedIp, true); + await _sut.BlockIpAsync(expectedIp, true); - await _amazonSqs.Received(1).SendMessageAsync( - Arg.Any(), - Arg.Is(expectedIp)); - } + await _amazonSqs.Received(1).SendMessageAsync( + Arg.Any(), + Arg.Is(expectedIp)); + } - [Fact] - public async Task BlockIpAsync_NotBlocked_WhenAlreadyBlockedRecently() - { - const string expectedIp = "ip"; + [Fact] + public async Task BlockIpAsync_NotBlocked_WhenAlreadyBlockedRecently() + { + const string expectedIp = "ip"; - await _sut.BlockIpAsync(expectedIp, true); + await _sut.BlockIpAsync(expectedIp, true); - // The second call should hit the already blocked guard clause - await _sut.BlockIpAsync(expectedIp, true); + // The second call should hit the already blocked guard clause + await _sut.BlockIpAsync(expectedIp, true); - await _amazonSqs.Received(1).SendMessageAsync( - Arg.Any(), - Arg.Is(expectedIp)); + await _amazonSqs.Received(1).SendMessageAsync( + Arg.Any(), + Arg.Is(expectedIp)); + } } } diff --git a/test/Core.Test/Services/AppleIapServiceTests.cs b/test/Core.Test/Services/AppleIapServiceTests.cs index c376af2886..ff14e52e82 100644 --- a/test/Core.Test/Services/AppleIapServiceTests.cs +++ b/test/Core.Test/Services/AppleIapServiceTests.cs @@ -6,35 +6,36 @@ using NSubstitute; using NSubstitute.Core; using Xunit; -namespace Bit.Core.Test.Services; - -[SutProviderCustomize] -public class AppleIapServiceTests +namespace Bit.Core.Test.Services { - [Theory, BitAutoData] - public async Task GetReceiptStatusAsync_MoreThanFourAttempts_Throws(SutProvider sutProvider) + [SutProviderCustomize] + public class AppleIapServiceTests { - var result = await sutProvider.Sut.GetReceiptStatusAsync("test", false, 5, null); - Assert.Null(result); - - var errorLog = sutProvider.GetDependency>() - .ReceivedCalls() - .SingleOrDefault(LogOneWarning); - - Assert.True(errorLog != null, "Must contain one error log of warning level containing 'null'"); - - static bool LogOneWarning(ICall call) + [Theory, BitAutoData] + public async Task GetReceiptStatusAsync_MoreThanFourAttempts_Throws(SutProvider sutProvider) { - if (call.GetMethodInfo().Name != "Log") + var result = await sutProvider.Sut.GetReceiptStatusAsync("test", false, 5, null); + Assert.Null(result); + + var errorLog = sutProvider.GetDependency>() + .ReceivedCalls() + .SingleOrDefault(LogOneWarning); + + Assert.True(errorLog != null, "Must contain one error log of warning level containing 'null'"); + + static bool LogOneWarning(ICall call) { - return false; + if (call.GetMethodInfo().Name != "Log") + { + return false; + } + + var args = call.GetArguments(); + var logLevel = (LogLevel)args[0]; + var exception = (Exception)args[3]; + + return logLevel == LogLevel.Warning && exception.Message.Contains("null"); } - - var args = call.GetArguments(); - var logLevel = (LogLevel)args[0]; - var exception = (Exception)args[3]; - - return logLevel == LogLevel.Warning && exception.Message.Contains("null"); } } } diff --git a/test/Core.Test/Services/AzureAttachmentStorageServiceTests.cs b/test/Core.Test/Services/AzureAttachmentStorageServiceTests.cs index 21a5fa3f8e..75cda61aad 100644 --- a/test/Core.Test/Services/AzureAttachmentStorageServiceTests.cs +++ b/test/Core.Test/Services/AzureAttachmentStorageServiceTests.cs @@ -4,28 +4,29 @@ using Microsoft.Extensions.Logging; using NSubstitute; using Xunit; -namespace Bit.Core.Test.Services; - -public class AzureAttachmentStorageServiceTests +namespace Bit.Core.Test.Services { - private readonly AzureAttachmentStorageService _sut; - - private readonly GlobalSettings _globalSettings; - private readonly ILogger _logger; - - public AzureAttachmentStorageServiceTests() + public class AzureAttachmentStorageServiceTests { - _globalSettings = new GlobalSettings(); - _logger = Substitute.For>(); + private readonly AzureAttachmentStorageService _sut; - _sut = new AzureAttachmentStorageService(_globalSettings, _logger); - } + private readonly GlobalSettings _globalSettings; + private readonly ILogger _logger; - // Remove this test when we add actual tests. It only proves that - // we've properly constructed the system under test. - [Fact(Skip = "Needs additional work")] - public void ServiceExists() - { - Assert.NotNull(_sut); + public AzureAttachmentStorageServiceTests() + { + _globalSettings = new GlobalSettings(); + _logger = Substitute.For>(); + + _sut = new AzureAttachmentStorageService(_globalSettings, _logger); + } + + // Remove this test when we add actual tests. It only proves that + // we've properly constructed the system under test. + [Fact(Skip = "Needs additional work")] + public void ServiceExists() + { + Assert.NotNull(_sut); + } } } diff --git a/test/Core.Test/Services/AzureQueueBlockIpServiceTests.cs b/test/Core.Test/Services/AzureQueueBlockIpServiceTests.cs index 9efbe7180f..e4ad8bab9a 100644 --- a/test/Core.Test/Services/AzureQueueBlockIpServiceTests.cs +++ b/test/Core.Test/Services/AzureQueueBlockIpServiceTests.cs @@ -2,26 +2,27 @@ using Bit.Core.Settings; using Xunit; -namespace Bit.Core.Test.Services; - -public class AzureQueueBlockIpServiceTests +namespace Bit.Core.Test.Services { - private readonly AzureQueueBlockIpService _sut; - - private readonly GlobalSettings _globalSettings; - - public AzureQueueBlockIpServiceTests() + public class AzureQueueBlockIpServiceTests { - _globalSettings = new GlobalSettings(); + private readonly AzureQueueBlockIpService _sut; - _sut = new AzureQueueBlockIpService(_globalSettings); - } + private readonly GlobalSettings _globalSettings; - // Remove this test when we add actual tests. It only proves that - // we've properly constructed the system under test. - [Fact(Skip = "Needs additional work")] - public void ServiceExists() - { - Assert.NotNull(_sut); + public AzureQueueBlockIpServiceTests() + { + _globalSettings = new GlobalSettings(); + + _sut = new AzureQueueBlockIpService(_globalSettings); + } + + // Remove this test when we add actual tests. It only proves that + // we've properly constructed the system under test. + [Fact(Skip = "Needs additional work")] + public void ServiceExists() + { + Assert.NotNull(_sut); + } } } diff --git a/test/Core.Test/Services/AzureQueueEventWriteServiceTests.cs b/test/Core.Test/Services/AzureQueueEventWriteServiceTests.cs index 2c4916dc6c..ce44b5f300 100644 --- a/test/Core.Test/Services/AzureQueueEventWriteServiceTests.cs +++ b/test/Core.Test/Services/AzureQueueEventWriteServiceTests.cs @@ -4,30 +4,31 @@ using Bit.Core.Settings; using NSubstitute; using Xunit; -namespace Bit.Core.Test.Services; - -public class AzureQueueEventWriteServiceTests +namespace Bit.Core.Test.Services { - private readonly AzureQueueEventWriteService _sut; - - private readonly GlobalSettings _globalSettings; - private readonly IEventRepository _eventRepository; - - public AzureQueueEventWriteServiceTests() + public class AzureQueueEventWriteServiceTests { - _globalSettings = new GlobalSettings(); - _eventRepository = Substitute.For(); + private readonly AzureQueueEventWriteService _sut; - _sut = new AzureQueueEventWriteService( - _globalSettings - ); - } + private readonly GlobalSettings _globalSettings; + private readonly IEventRepository _eventRepository; - // Remove this test when we add actual tests. It only proves that - // we've properly constructed the system under test. - [Fact(Skip = "Needs additional work")] - public void ServiceExists() - { - Assert.NotNull(_sut); + public AzureQueueEventWriteServiceTests() + { + _globalSettings = new GlobalSettings(); + _eventRepository = Substitute.For(); + + _sut = new AzureQueueEventWriteService( + _globalSettings + ); + } + + // Remove this test when we add actual tests. It only proves that + // we've properly constructed the system under test. + [Fact(Skip = "Needs additional work")] + public void ServiceExists() + { + Assert.NotNull(_sut); + } } } diff --git a/test/Core.Test/Services/AzureQueuePushNotificationServiceTests.cs b/test/Core.Test/Services/AzureQueuePushNotificationServiceTests.cs index 7f9cb750aa..abb6ad31aa 100644 --- a/test/Core.Test/Services/AzureQueuePushNotificationServiceTests.cs +++ b/test/Core.Test/Services/AzureQueuePushNotificationServiceTests.cs @@ -4,31 +4,32 @@ using Microsoft.AspNetCore.Http; using NSubstitute; using Xunit; -namespace Bit.Core.Test.Services; - -public class AzureQueuePushNotificationServiceTests +namespace Bit.Core.Test.Services { - private readonly AzureQueuePushNotificationService _sut; - - private readonly GlobalSettings _globalSettings; - private readonly IHttpContextAccessor _httpContextAccessor; - - public AzureQueuePushNotificationServiceTests() + public class AzureQueuePushNotificationServiceTests { - _globalSettings = new GlobalSettings(); - _httpContextAccessor = Substitute.For(); + private readonly AzureQueuePushNotificationService _sut; - _sut = new AzureQueuePushNotificationService( - _globalSettings, - _httpContextAccessor - ); - } + private readonly GlobalSettings _globalSettings; + private readonly IHttpContextAccessor _httpContextAccessor; - // Remove this test when we add actual tests. It only proves that - // we've properly constructed the system under test. - [Fact(Skip = "Needs additional work")] - public void ServiceExists() - { - Assert.NotNull(_sut); + public AzureQueuePushNotificationServiceTests() + { + _globalSettings = new GlobalSettings(); + _httpContextAccessor = Substitute.For(); + + _sut = new AzureQueuePushNotificationService( + _globalSettings, + _httpContextAccessor + ); + } + + // Remove this test when we add actual tests. It only proves that + // we've properly constructed the system under test. + [Fact(Skip = "Needs additional work")] + public void ServiceExists() + { + Assert.NotNull(_sut); + } } } diff --git a/test/Core.Test/Services/CipherServiceTests.cs b/test/Core.Test/Services/CipherServiceTests.cs index f036b973f9..1e34444810 100644 --- a/test/Core.Test/Services/CipherServiceTests.cs +++ b/test/Core.Test/Services/CipherServiceTests.cs @@ -9,208 +9,209 @@ using Core.Models.Data; using NSubstitute; using Xunit; -namespace Bit.Core.Test.Services; - -public class CipherServiceTests +namespace Bit.Core.Test.Services { - [Theory, UserCipherAutoData] - public async Task SaveAsync_WrongRevisionDate_Throws(SutProvider sutProvider, Cipher cipher) + public class CipherServiceTests { - var lastKnownRevisionDate = cipher.RevisionDate.AddDays(-1); - - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.SaveAsync(cipher, cipher.UserId.Value, lastKnownRevisionDate)); - Assert.Contains("out of date", exception.Message); - } - - [Theory, UserCipherAutoData] - public async Task SaveDetailsAsync_WrongRevisionDate_Throws(SutProvider sutProvider, - CipherDetails cipherDetails) - { - var lastKnownRevisionDate = cipherDetails.RevisionDate.AddDays(-1); - - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.SaveDetailsAsync(cipherDetails, cipherDetails.UserId.Value, lastKnownRevisionDate)); - Assert.Contains("out of date", exception.Message); - } - - [Theory, UserCipherAutoData] - public async Task ShareAsync_WrongRevisionDate_Throws(SutProvider sutProvider, Cipher cipher, - Organization organization, List collectionIds) - { - sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); - var lastKnownRevisionDate = cipher.RevisionDate.AddDays(-1); - - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.ShareAsync(cipher, cipher, organization.Id, collectionIds, cipher.UserId.Value, - lastKnownRevisionDate)); - Assert.Contains("out of date", exception.Message); - } - - [Theory, UserCipherAutoData("99ab4f6c-44f8-4ff5-be7a-75c37c33c69e")] - public async Task ShareManyAsync_WrongRevisionDate_Throws(SutProvider sutProvider, - IEnumerable ciphers, Guid organizationId, List collectionIds) - { - sutProvider.GetDependency().GetByIdAsync(organizationId) - .Returns(new Organization - { - PlanType = Enums.PlanType.EnterpriseAnnually, - MaxStorageGb = 100 - }); - - var cipherInfos = ciphers.Select(c => (c, (DateTime?)c.RevisionDate.AddDays(-1))); - - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.ShareManyAsync(cipherInfos, organizationId, collectionIds, ciphers.First().UserId.Value)); - Assert.Contains("out of date", exception.Message); - } - - [Theory] - [InlineUserCipherAutoData("")] - [InlineUserCipherAutoData("Correct Time")] - public async Task SaveAsync_CorrectRevisionDate_Passes(string revisionDateString, - SutProvider sutProvider, Cipher cipher) - { - var lastKnownRevisionDate = string.IsNullOrEmpty(revisionDateString) ? (DateTime?)null : cipher.RevisionDate; - - await sutProvider.Sut.SaveAsync(cipher, cipher.UserId.Value, lastKnownRevisionDate); - await sutProvider.GetDependency().Received(1).ReplaceAsync(cipher); - } - - [Theory] - [InlineUserCipherAutoData("")] - [InlineUserCipherAutoData("Correct Time")] - public async Task SaveDetailsAsync_CorrectRevisionDate_Passes(string revisionDateString, - SutProvider sutProvider, CipherDetails cipherDetails) - { - var lastKnownRevisionDate = string.IsNullOrEmpty(revisionDateString) ? (DateTime?)null : cipherDetails.RevisionDate; - - await sutProvider.Sut.SaveDetailsAsync(cipherDetails, cipherDetails.UserId.Value, lastKnownRevisionDate); - await sutProvider.GetDependency().Received(1).ReplaceAsync(cipherDetails); - } - - [Theory] - [InlineUserCipherAutoData("")] - [InlineUserCipherAutoData("Correct Time")] - public async Task ShareAsync_CorrectRevisionDate_Passes(string revisionDateString, - SutProvider sutProvider, Cipher cipher, Organization organization, List collectionIds) - { - var lastKnownRevisionDate = string.IsNullOrEmpty(revisionDateString) ? (DateTime?)null : cipher.RevisionDate; - var cipherRepository = sutProvider.GetDependency(); - cipherRepository.ReplaceAsync(cipher, collectionIds).Returns(true); - sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); - - await sutProvider.Sut.ShareAsync(cipher, cipher, organization.Id, collectionIds, cipher.UserId.Value, - lastKnownRevisionDate); - await cipherRepository.Received(1).ReplaceAsync(cipher, collectionIds); - } - - [Theory] - [InlineKnownUserCipherAutoData(userId: "99ab4f6c-44f8-4ff5-be7a-75c37c33c69e", "")] - [InlineKnownUserCipherAutoData(userId: "99ab4f6c-44f8-4ff5-be7a-75c37c33c69e", "CorrectTime")] - public async Task ShareManyAsync_CorrectRevisionDate_Passes(string revisionDateString, - SutProvider sutProvider, IEnumerable ciphers, Organization organization, List collectionIds) - { - sutProvider.GetDependency().GetByIdAsync(organization.Id) - .Returns(new Organization - { - PlanType = Enums.PlanType.EnterpriseAnnually, - MaxStorageGb = 100 - }); - - var cipherInfos = ciphers.Select(c => (c, - string.IsNullOrEmpty(revisionDateString) ? null : (DateTime?)c.RevisionDate)); - var sharingUserId = ciphers.First().UserId.Value; - - await sutProvider.Sut.ShareManyAsync(cipherInfos, organization.Id, collectionIds, sharingUserId); - await sutProvider.GetDependency().Received(1).UpdateCiphersAsync(sharingUserId, - Arg.Is>(arg => arg.Except(ciphers).IsNullOrEmpty())); - } - - [Theory] - [InlineKnownUserCipherAutoData("c64d8a15-606e-41d6-9c7e-174d4d8f3b2e", "c64d8a15-606e-41d6-9c7e-174d4d8f3b2e")] - [InlineOrganizationCipherAutoData("c64d8a15-606e-41d6-9c7e-174d4d8f3b2e")] - public async Task RestoreAsync_UpdatesCipher(Guid restoringUserId, Cipher cipher, SutProvider sutProvider) - { - sutProvider.GetDependency().GetCanEditByIdAsync(restoringUserId, cipher.Id).Returns(true); - - var initialRevisionDate = new DateTime(1970, 1, 1, 0, 0, 0); - cipher.DeletedDate = initialRevisionDate; - cipher.RevisionDate = initialRevisionDate; - - await sutProvider.Sut.RestoreAsync(cipher, restoringUserId, cipher.OrganizationId.HasValue); - - Assert.Null(cipher.DeletedDate); - Assert.NotEqual(initialRevisionDate, cipher.RevisionDate); - } - - [Theory] - [InlineKnownUserCipherAutoData("c64d8a15-606e-41d6-9c7e-174d4d8f3b2e", "c64d8a15-606e-41d6-9c7e-174d4d8f3b2e")] - public async Task RestoreManyAsync_UpdatesCiphers(Guid restoringUserId, IEnumerable ciphers, - SutProvider sutProvider) - { - var previousRevisionDate = DateTime.UtcNow; - foreach (var cipher in ciphers) + [Theory, UserCipherAutoData] + public async Task SaveAsync_WrongRevisionDate_Throws(SutProvider sutProvider, Cipher cipher) { - cipher.RevisionDate = previousRevisionDate; + var lastKnownRevisionDate = cipher.RevisionDate.AddDays(-1); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.SaveAsync(cipher, cipher.UserId.Value, lastKnownRevisionDate)); + Assert.Contains("out of date", exception.Message); } - var revisionDate = previousRevisionDate + TimeSpan.FromMinutes(1); - sutProvider.GetDependency().RestoreAsync(Arg.Any>(), restoringUserId) - .Returns(revisionDate); - - await sutProvider.Sut.RestoreManyAsync(ciphers, restoringUserId); - - foreach (var cipher in ciphers) + [Theory, UserCipherAutoData] + public async Task SaveDetailsAsync_WrongRevisionDate_Throws(SutProvider sutProvider, + CipherDetails cipherDetails) { + var lastKnownRevisionDate = cipherDetails.RevisionDate.AddDays(-1); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.SaveDetailsAsync(cipherDetails, cipherDetails.UserId.Value, lastKnownRevisionDate)); + Assert.Contains("out of date", exception.Message); + } + + [Theory, UserCipherAutoData] + public async Task ShareAsync_WrongRevisionDate_Throws(SutProvider sutProvider, Cipher cipher, + Organization organization, List collectionIds) + { + sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); + var lastKnownRevisionDate = cipher.RevisionDate.AddDays(-1); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.ShareAsync(cipher, cipher, organization.Id, collectionIds, cipher.UserId.Value, + lastKnownRevisionDate)); + Assert.Contains("out of date", exception.Message); + } + + [Theory, UserCipherAutoData("99ab4f6c-44f8-4ff5-be7a-75c37c33c69e")] + public async Task ShareManyAsync_WrongRevisionDate_Throws(SutProvider sutProvider, + IEnumerable ciphers, Guid organizationId, List collectionIds) + { + sutProvider.GetDependency().GetByIdAsync(organizationId) + .Returns(new Organization + { + PlanType = Enums.PlanType.EnterpriseAnnually, + MaxStorageGb = 100 + }); + + var cipherInfos = ciphers.Select(c => (c, (DateTime?)c.RevisionDate.AddDays(-1))); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.ShareManyAsync(cipherInfos, organizationId, collectionIds, ciphers.First().UserId.Value)); + Assert.Contains("out of date", exception.Message); + } + + [Theory] + [InlineUserCipherAutoData("")] + [InlineUserCipherAutoData("Correct Time")] + public async Task SaveAsync_CorrectRevisionDate_Passes(string revisionDateString, + SutProvider sutProvider, Cipher cipher) + { + var lastKnownRevisionDate = string.IsNullOrEmpty(revisionDateString) ? (DateTime?)null : cipher.RevisionDate; + + await sutProvider.Sut.SaveAsync(cipher, cipher.UserId.Value, lastKnownRevisionDate); + await sutProvider.GetDependency().Received(1).ReplaceAsync(cipher); + } + + [Theory] + [InlineUserCipherAutoData("")] + [InlineUserCipherAutoData("Correct Time")] + public async Task SaveDetailsAsync_CorrectRevisionDate_Passes(string revisionDateString, + SutProvider sutProvider, CipherDetails cipherDetails) + { + var lastKnownRevisionDate = string.IsNullOrEmpty(revisionDateString) ? (DateTime?)null : cipherDetails.RevisionDate; + + await sutProvider.Sut.SaveDetailsAsync(cipherDetails, cipherDetails.UserId.Value, lastKnownRevisionDate); + await sutProvider.GetDependency().Received(1).ReplaceAsync(cipherDetails); + } + + [Theory] + [InlineUserCipherAutoData("")] + [InlineUserCipherAutoData("Correct Time")] + public async Task ShareAsync_CorrectRevisionDate_Passes(string revisionDateString, + SutProvider sutProvider, Cipher cipher, Organization organization, List collectionIds) + { + var lastKnownRevisionDate = string.IsNullOrEmpty(revisionDateString) ? (DateTime?)null : cipher.RevisionDate; + var cipherRepository = sutProvider.GetDependency(); + cipherRepository.ReplaceAsync(cipher, collectionIds).Returns(true); + sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); + + await sutProvider.Sut.ShareAsync(cipher, cipher, organization.Id, collectionIds, cipher.UserId.Value, + lastKnownRevisionDate); + await cipherRepository.Received(1).ReplaceAsync(cipher, collectionIds); + } + + [Theory] + [InlineKnownUserCipherAutoData(userId: "99ab4f6c-44f8-4ff5-be7a-75c37c33c69e", "")] + [InlineKnownUserCipherAutoData(userId: "99ab4f6c-44f8-4ff5-be7a-75c37c33c69e", "CorrectTime")] + public async Task ShareManyAsync_CorrectRevisionDate_Passes(string revisionDateString, + SutProvider sutProvider, IEnumerable ciphers, Organization organization, List collectionIds) + { + sutProvider.GetDependency().GetByIdAsync(organization.Id) + .Returns(new Organization + { + PlanType = Enums.PlanType.EnterpriseAnnually, + MaxStorageGb = 100 + }); + + var cipherInfos = ciphers.Select(c => (c, + string.IsNullOrEmpty(revisionDateString) ? null : (DateTime?)c.RevisionDate)); + var sharingUserId = ciphers.First().UserId.Value; + + await sutProvider.Sut.ShareManyAsync(cipherInfos, organization.Id, collectionIds, sharingUserId); + await sutProvider.GetDependency().Received(1).UpdateCiphersAsync(sharingUserId, + Arg.Is>(arg => arg.Except(ciphers).IsNullOrEmpty())); + } + + [Theory] + [InlineKnownUserCipherAutoData("c64d8a15-606e-41d6-9c7e-174d4d8f3b2e", "c64d8a15-606e-41d6-9c7e-174d4d8f3b2e")] + [InlineOrganizationCipherAutoData("c64d8a15-606e-41d6-9c7e-174d4d8f3b2e")] + public async Task RestoreAsync_UpdatesCipher(Guid restoringUserId, Cipher cipher, SutProvider sutProvider) + { + sutProvider.GetDependency().GetCanEditByIdAsync(restoringUserId, cipher.Id).Returns(true); + + var initialRevisionDate = new DateTime(1970, 1, 1, 0, 0, 0); + cipher.DeletedDate = initialRevisionDate; + cipher.RevisionDate = initialRevisionDate; + + await sutProvider.Sut.RestoreAsync(cipher, restoringUserId, cipher.OrganizationId.HasValue); + Assert.Null(cipher.DeletedDate); - Assert.Equal(revisionDate, cipher.RevisionDate); + Assert.NotEqual(initialRevisionDate, cipher.RevisionDate); } - } - [Theory] - [InlineUserCipherAutoData] - public async Task ShareManyAsync_FreeOrgWithAttachment_Throws(SutProvider sutProvider, - IEnumerable ciphers, Guid organizationId, List collectionIds) - { - sutProvider.GetDependency().GetByIdAsync(organizationId).Returns(new Organization + [Theory] + [InlineKnownUserCipherAutoData("c64d8a15-606e-41d6-9c7e-174d4d8f3b2e", "c64d8a15-606e-41d6-9c7e-174d4d8f3b2e")] + public async Task RestoreManyAsync_UpdatesCiphers(Guid restoringUserId, IEnumerable ciphers, + SutProvider sutProvider) { - PlanType = Enums.PlanType.Free - }); - ciphers.FirstOrDefault().Attachments = - "{\"attachment1\":{\"Size\":\"250\",\"FileName\":\"superCoolFile\"," - + "\"Key\":\"superCoolFile\",\"ContainerName\":\"testContainer\",\"Validated\":false}}"; - - var cipherInfos = ciphers.Select(c => (c, - (DateTime?)c.RevisionDate)); - var sharingUserId = ciphers.First().UserId.Value; - - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.ShareManyAsync(cipherInfos, organizationId, collectionIds, sharingUserId)); - Assert.Contains("This organization cannot use attachments", exception.Message); - } - - [Theory] - [InlineUserCipherAutoData] - public async Task ShareManyAsync_PaidOrgWithAttachment_Passes(SutProvider sutProvider, - IEnumerable ciphers, Guid organizationId, List collectionIds) - { - sutProvider.GetDependency().GetByIdAsync(organizationId) - .Returns(new Organization + var previousRevisionDate = DateTime.UtcNow; + foreach (var cipher in ciphers) { - PlanType = Enums.PlanType.EnterpriseAnnually, - MaxStorageGb = 100 + cipher.RevisionDate = previousRevisionDate; + } + + var revisionDate = previousRevisionDate + TimeSpan.FromMinutes(1); + sutProvider.GetDependency().RestoreAsync(Arg.Any>(), restoringUserId) + .Returns(revisionDate); + + await sutProvider.Sut.RestoreManyAsync(ciphers, restoringUserId); + + foreach (var cipher in ciphers) + { + Assert.Null(cipher.DeletedDate); + Assert.Equal(revisionDate, cipher.RevisionDate); + } + } + + [Theory] + [InlineUserCipherAutoData] + public async Task ShareManyAsync_FreeOrgWithAttachment_Throws(SutProvider sutProvider, + IEnumerable ciphers, Guid organizationId, List collectionIds) + { + sutProvider.GetDependency().GetByIdAsync(organizationId).Returns(new Organization + { + PlanType = Enums.PlanType.Free }); - ciphers.FirstOrDefault().Attachments = - "{\"attachment1\":{\"Size\":\"250\",\"FileName\":\"superCoolFile\"," - + "\"Key\":\"superCoolFile\",\"ContainerName\":\"testContainer\",\"Validated\":false}}"; + ciphers.FirstOrDefault().Attachments = + "{\"attachment1\":{\"Size\":\"250\",\"FileName\":\"superCoolFile\"," + + "\"Key\":\"superCoolFile\",\"ContainerName\":\"testContainer\",\"Validated\":false}}"; - var cipherInfos = ciphers.Select(c => (c, - (DateTime?)c.RevisionDate)); - var sharingUserId = ciphers.First().UserId.Value; + var cipherInfos = ciphers.Select(c => (c, + (DateTime?)c.RevisionDate)); + var sharingUserId = ciphers.First().UserId.Value; - await sutProvider.Sut.ShareManyAsync(cipherInfos, organizationId, collectionIds, sharingUserId); - await sutProvider.GetDependency().Received(1).UpdateCiphersAsync(sharingUserId, - Arg.Is>(arg => arg.Except(ciphers).IsNullOrEmpty())); + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.ShareManyAsync(cipherInfos, organizationId, collectionIds, sharingUserId)); + Assert.Contains("This organization cannot use attachments", exception.Message); + } + + [Theory] + [InlineUserCipherAutoData] + public async Task ShareManyAsync_PaidOrgWithAttachment_Passes(SutProvider sutProvider, + IEnumerable ciphers, Guid organizationId, List collectionIds) + { + sutProvider.GetDependency().GetByIdAsync(organizationId) + .Returns(new Organization + { + PlanType = Enums.PlanType.EnterpriseAnnually, + MaxStorageGb = 100 + }); + ciphers.FirstOrDefault().Attachments = + "{\"attachment1\":{\"Size\":\"250\",\"FileName\":\"superCoolFile\"," + + "\"Key\":\"superCoolFile\",\"ContainerName\":\"testContainer\",\"Validated\":false}}"; + + var cipherInfos = ciphers.Select(c => (c, + (DateTime?)c.RevisionDate)); + var sharingUserId = ciphers.First().UserId.Value; + + await sutProvider.Sut.ShareManyAsync(cipherInfos, organizationId, collectionIds, sharingUserId); + await sutProvider.GetDependency().Received(1).UpdateCiphersAsync(sharingUserId, + Arg.Is>(arg => arg.Except(ciphers).IsNullOrEmpty())); + } } } diff --git a/test/Core.Test/Services/CollectionServiceTests.cs b/test/Core.Test/Services/CollectionServiceTests.cs index b6a68b58e1..cf4228b3f2 100644 --- a/test/Core.Test/Services/CollectionServiceTests.cs +++ b/test/Core.Test/Services/CollectionServiceTests.cs @@ -10,160 +10,161 @@ using Bit.Test.Common.AutoFixture.Attributes; using NSubstitute; using Xunit; -namespace Bit.Core.Test.Services; - -public class CollectionServiceTest +namespace Bit.Core.Test.Services { - [Theory, CollectionAutoData] - public async Task SaveAsync_DefaultId_CreatesCollectionInTheRepository(Collection collection, Organization organization, SutProvider sutProvider) + public class CollectionServiceTest { - collection.Id = default; - sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); - var utcNow = DateTime.UtcNow; + [Theory, CollectionAutoData] + public async Task SaveAsync_DefaultId_CreatesCollectionInTheRepository(Collection collection, Organization organization, SutProvider sutProvider) + { + collection.Id = default; + sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); + var utcNow = DateTime.UtcNow; - await sutProvider.Sut.SaveAsync(collection); + await sutProvider.Sut.SaveAsync(collection); - await sutProvider.GetDependency().Received().CreateAsync(collection); - await sutProvider.GetDependency().Received() - .LogCollectionEventAsync(collection, EventType.Collection_Created); - Assert.True(collection.CreationDate - utcNow < TimeSpan.FromSeconds(1)); - Assert.True(collection.RevisionDate - utcNow < TimeSpan.FromSeconds(1)); - } + await sutProvider.GetDependency().Received().CreateAsync(collection); + await sutProvider.GetDependency().Received() + .LogCollectionEventAsync(collection, EventType.Collection_Created); + Assert.True(collection.CreationDate - utcNow < TimeSpan.FromSeconds(1)); + Assert.True(collection.RevisionDate - utcNow < TimeSpan.FromSeconds(1)); + } - [Theory, CollectionAutoData] - public async Task SaveAsync_DefaultIdWithGroups_CreateCollectionWithGroupsInRepository(Collection collection, - IEnumerable groups, Organization organization, SutProvider sutProvider) - { - collection.Id = default; - organization.UseGroups = true; - sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); - var utcNow = DateTime.UtcNow; + [Theory, CollectionAutoData] + public async Task SaveAsync_DefaultIdWithGroups_CreateCollectionWithGroupsInRepository(Collection collection, + IEnumerable groups, Organization organization, SutProvider sutProvider) + { + collection.Id = default; + organization.UseGroups = true; + sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); + var utcNow = DateTime.UtcNow; - await sutProvider.Sut.SaveAsync(collection, groups); + await sutProvider.Sut.SaveAsync(collection, groups); - await sutProvider.GetDependency().Received().CreateAsync(collection, groups); - await sutProvider.GetDependency().Received() - .LogCollectionEventAsync(collection, EventType.Collection_Created); - Assert.True(collection.CreationDate - utcNow < TimeSpan.FromSeconds(1)); - Assert.True(collection.RevisionDate - utcNow < TimeSpan.FromSeconds(1)); - } + await sutProvider.GetDependency().Received().CreateAsync(collection, groups); + await sutProvider.GetDependency().Received() + .LogCollectionEventAsync(collection, EventType.Collection_Created); + Assert.True(collection.CreationDate - utcNow < TimeSpan.FromSeconds(1)); + Assert.True(collection.RevisionDate - utcNow < TimeSpan.FromSeconds(1)); + } - [Theory, CollectionAutoData] - public async Task SaveAsync_NonDefaultId_ReplacesCollectionInRepository(Collection collection, Organization organization, SutProvider sutProvider) - { - var creationDate = collection.CreationDate; - sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); - var utcNow = DateTime.UtcNow; + [Theory, CollectionAutoData] + public async Task SaveAsync_NonDefaultId_ReplacesCollectionInRepository(Collection collection, Organization organization, SutProvider sutProvider) + { + var creationDate = collection.CreationDate; + sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); + var utcNow = DateTime.UtcNow; - await sutProvider.Sut.SaveAsync(collection); + await sutProvider.Sut.SaveAsync(collection); - await sutProvider.GetDependency().Received().ReplaceAsync(collection); - await sutProvider.GetDependency().Received() - .LogCollectionEventAsync(collection, EventType.Collection_Updated); - Assert.Equal(collection.CreationDate, creationDate); - Assert.True(collection.RevisionDate - utcNow < TimeSpan.FromSeconds(1)); - } + await sutProvider.GetDependency().Received().ReplaceAsync(collection); + await sutProvider.GetDependency().Received() + .LogCollectionEventAsync(collection, EventType.Collection_Updated); + Assert.Equal(collection.CreationDate, creationDate); + Assert.True(collection.RevisionDate - utcNow < TimeSpan.FromSeconds(1)); + } - [Theory, CollectionAutoData] - public async Task SaveAsync_OrganizationNotUseGroup_CreateCollectionWithoutGroupsInRepository(Collection collection, IEnumerable groups, - Organization organization, SutProvider sutProvider) - { - collection.Id = default; - sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); - var utcNow = DateTime.UtcNow; + [Theory, CollectionAutoData] + public async Task SaveAsync_OrganizationNotUseGroup_CreateCollectionWithoutGroupsInRepository(Collection collection, IEnumerable groups, + Organization organization, SutProvider sutProvider) + { + collection.Id = default; + sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); + var utcNow = DateTime.UtcNow; - await sutProvider.Sut.SaveAsync(collection, groups); + await sutProvider.Sut.SaveAsync(collection, groups); - await sutProvider.GetDependency().Received().CreateAsync(collection); - await sutProvider.GetDependency().Received() - .LogCollectionEventAsync(collection, EventType.Collection_Created); - Assert.True(collection.CreationDate - utcNow < TimeSpan.FromSeconds(1)); - Assert.True(collection.RevisionDate - utcNow < TimeSpan.FromSeconds(1)); - } + await sutProvider.GetDependency().Received().CreateAsync(collection); + await sutProvider.GetDependency().Received() + .LogCollectionEventAsync(collection, EventType.Collection_Created); + Assert.True(collection.CreationDate - utcNow < TimeSpan.FromSeconds(1)); + Assert.True(collection.RevisionDate - utcNow < TimeSpan.FromSeconds(1)); + } - [Theory, CollectionAutoData] - public async Task SaveAsync_DefaultIdWithUserId_UpdateUserInCollectionRepository(Collection collection, - Organization organization, OrganizationUser organizationUser, SutProvider sutProvider) - { - collection.Id = default; - organizationUser.Status = OrganizationUserStatusType.Confirmed; - sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); - sutProvider.GetDependency().GetByOrganizationAsync(organization.Id, organizationUser.Id) - .Returns(organizationUser); - var utcNow = DateTime.UtcNow; + [Theory, CollectionAutoData] + public async Task SaveAsync_DefaultIdWithUserId_UpdateUserInCollectionRepository(Collection collection, + Organization organization, OrganizationUser organizationUser, SutProvider sutProvider) + { + collection.Id = default; + organizationUser.Status = OrganizationUserStatusType.Confirmed; + sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); + sutProvider.GetDependency().GetByOrganizationAsync(organization.Id, organizationUser.Id) + .Returns(organizationUser); + var utcNow = DateTime.UtcNow; - await sutProvider.Sut.SaveAsync(collection, null, organizationUser.Id); + await sutProvider.Sut.SaveAsync(collection, null, organizationUser.Id); - await sutProvider.GetDependency().Received().CreateAsync(collection); - await sutProvider.GetDependency().Received() - .GetByOrganizationAsync(organization.Id, organizationUser.Id); - await sutProvider.GetDependency().Received().UpdateUsersAsync(collection.Id, Arg.Any>()); - await sutProvider.GetDependency().Received() - .LogCollectionEventAsync(collection, EventType.Collection_Created); - Assert.True(collection.CreationDate - utcNow < TimeSpan.FromSeconds(1)); - Assert.True(collection.RevisionDate - utcNow < TimeSpan.FromSeconds(1)); - } + await sutProvider.GetDependency().Received().CreateAsync(collection); + await sutProvider.GetDependency().Received() + .GetByOrganizationAsync(organization.Id, organizationUser.Id); + await sutProvider.GetDependency().Received().UpdateUsersAsync(collection.Id, Arg.Any>()); + await sutProvider.GetDependency().Received() + .LogCollectionEventAsync(collection, EventType.Collection_Created); + Assert.True(collection.CreationDate - utcNow < TimeSpan.FromSeconds(1)); + Assert.True(collection.RevisionDate - utcNow < TimeSpan.FromSeconds(1)); + } - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task SaveAsync_NonExistingOrganizationId_ThrowsBadRequest(Collection collection, SutProvider sutProvider) - { - var ex = await Assert.ThrowsAsync(() => sutProvider.Sut.SaveAsync(collection)); - Assert.Contains("Organization not found", ex.Message); - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().CreateAsync(default); - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().CreateAsync(default, default); - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().ReplaceAsync(default); - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().LogCollectionEventAsync(default, default); - } + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task SaveAsync_NonExistingOrganizationId_ThrowsBadRequest(Collection collection, SutProvider sutProvider) + { + var ex = await Assert.ThrowsAsync(() => sutProvider.Sut.SaveAsync(collection)); + Assert.Contains("Organization not found", ex.Message); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().CreateAsync(default); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().CreateAsync(default, default); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().ReplaceAsync(default); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().LogCollectionEventAsync(default, default); + } - [Theory, CollectionAutoData] - public async Task SaveAsync_ExceedsOrganizationMaxCollections_ThrowsBadRequest(Collection collection, Organization organization, SutProvider sutProvider) - { - collection.Id = default; - sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); - sutProvider.GetDependency().GetCountByOrganizationIdAsync(organization.Id) - .Returns(organization.MaxCollections.Value); + [Theory, CollectionAutoData] + public async Task SaveAsync_ExceedsOrganizationMaxCollections_ThrowsBadRequest(Collection collection, Organization organization, SutProvider sutProvider) + { + collection.Id = default; + sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); + sutProvider.GetDependency().GetCountByOrganizationIdAsync(organization.Id) + .Returns(organization.MaxCollections.Value); - var ex = await Assert.ThrowsAsync(() => sutProvider.Sut.SaveAsync(collection)); - Assert.Equal($@"You have reached the maximum number of collections ({organization.MaxCollections.Value}) for this organization.", ex.Message); - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().CreateAsync(default); - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().CreateAsync(default, default); - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().ReplaceAsync(default); - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().LogCollectionEventAsync(default, default); - } + var ex = await Assert.ThrowsAsync(() => sutProvider.Sut.SaveAsync(collection)); + Assert.Equal($@"You have reached the maximum number of collections ({organization.MaxCollections.Value}) for this organization.", ex.Message); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().CreateAsync(default); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().CreateAsync(default, default); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().ReplaceAsync(default); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().LogCollectionEventAsync(default, default); + } - [Theory, CollectionAutoData] - public async Task DeleteUserAsync_DeletesValidUserWhoBelongsToCollection(Collection collection, - Organization organization, OrganizationUser organizationUser, SutProvider sutProvider) - { - collection.OrganizationId = organization.Id; - organizationUser.OrganizationId = organization.Id; - sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); - sutProvider.GetDependency().GetByIdAsync(organizationUser.Id) - .Returns(organizationUser); + [Theory, CollectionAutoData] + public async Task DeleteUserAsync_DeletesValidUserWhoBelongsToCollection(Collection collection, + Organization organization, OrganizationUser organizationUser, SutProvider sutProvider) + { + collection.OrganizationId = organization.Id; + organizationUser.OrganizationId = organization.Id; + sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); + sutProvider.GetDependency().GetByIdAsync(organizationUser.Id) + .Returns(organizationUser); - await sutProvider.Sut.DeleteUserAsync(collection, organizationUser.Id); + await sutProvider.Sut.DeleteUserAsync(collection, organizationUser.Id); - await sutProvider.GetDependency().Received() - .DeleteUserAsync(collection.Id, organizationUser.Id); - await sutProvider.GetDependency().Received().LogOrganizationUserEventAsync(organizationUser, EventType.OrganizationUser_Updated); - } + await sutProvider.GetDependency().Received() + .DeleteUserAsync(collection.Id, organizationUser.Id); + await sutProvider.GetDependency().Received().LogOrganizationUserEventAsync(organizationUser, EventType.OrganizationUser_Updated); + } - [Theory, CollectionAutoData] - public async Task DeleteUserAsync_InvalidUser_ThrowsNotFound(Collection collection, Organization organization, - OrganizationUser organizationUser, SutProvider sutProvider) - { - collection.OrganizationId = organization.Id; - sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); - sutProvider.GetDependency().GetByIdAsync(organizationUser.Id) - .Returns(organizationUser); + [Theory, CollectionAutoData] + public async Task DeleteUserAsync_InvalidUser_ThrowsNotFound(Collection collection, Organization organization, + OrganizationUser organizationUser, SutProvider sutProvider) + { + collection.OrganizationId = organization.Id; + sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); + sutProvider.GetDependency().GetByIdAsync(organizationUser.Id) + .Returns(organizationUser); - // user not in organization - await Assert.ThrowsAsync(() => - sutProvider.Sut.DeleteUserAsync(collection, organizationUser.Id)); - // invalid user - await Assert.ThrowsAsync(() => sutProvider.Sut.DeleteUserAsync(collection, Guid.NewGuid())); - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().DeleteUserAsync(default, default); - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() - .LogOrganizationUserEventAsync(default, default); + // user not in organization + await Assert.ThrowsAsync(() => + sutProvider.Sut.DeleteUserAsync(collection, organizationUser.Id)); + // invalid user + await Assert.ThrowsAsync(() => sutProvider.Sut.DeleteUserAsync(collection, Guid.NewGuid())); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().DeleteUserAsync(default, default); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() + .LogOrganizationUserEventAsync(default, default); + } } } diff --git a/test/Core.Test/Services/DeviceServiceTests.cs b/test/Core.Test/Services/DeviceServiceTests.cs index 8bc9212839..f3a50d4d0c 100644 --- a/test/Core.Test/Services/DeviceServiceTests.cs +++ b/test/Core.Test/Services/DeviceServiceTests.cs @@ -5,32 +5,33 @@ using Bit.Core.Services; using NSubstitute; using Xunit; -namespace Bit.Core.Test.Services; - -public class DeviceServiceTests +namespace Bit.Core.Test.Services { - [Fact] - public async Task DeviceSaveShouldUpdateRevisionDateAndPushRegistration() + public class DeviceServiceTests { - var deviceRepo = Substitute.For(); - var pushRepo = Substitute.For(); - var deviceService = new DeviceService(deviceRepo, pushRepo); - - var id = Guid.NewGuid(); - var userId = Guid.NewGuid(); - var device = new Device + [Fact] + public async Task DeviceSaveShouldUpdateRevisionDateAndPushRegistration() { - Id = id, - Name = "test device", - Type = DeviceType.Android, - UserId = userId, - PushToken = "testtoken", - Identifier = "testid" - }; - await deviceService.SaveAsync(device); + var deviceRepo = Substitute.For(); + var pushRepo = Substitute.For(); + var deviceService = new DeviceService(deviceRepo, pushRepo); - Assert.True(device.RevisionDate - DateTime.UtcNow < TimeSpan.FromSeconds(1)); - await pushRepo.Received().CreateOrUpdateRegistrationAsync("testtoken", id.ToString(), - userId.ToString(), "testid", DeviceType.Android); + var id = Guid.NewGuid(); + var userId = Guid.NewGuid(); + var device = new Device + { + Id = id, + Name = "test device", + Type = DeviceType.Android, + UserId = userId, + PushToken = "testtoken", + Identifier = "testid" + }; + await deviceService.SaveAsync(device); + + Assert.True(device.RevisionDate - DateTime.UtcNow < TimeSpan.FromSeconds(1)); + await pushRepo.Received().CreateOrUpdateRegistrationAsync("testtoken", id.ToString(), + userId.ToString(), "testid", DeviceType.Android); + } } } diff --git a/test/Core.Test/Services/EmergencyAccessServiceTests.cs b/test/Core.Test/Services/EmergencyAccessServiceTests.cs index 6f8576d8e4..bdbb6953b2 100644 --- a/test/Core.Test/Services/EmergencyAccessServiceTests.cs +++ b/test/Core.Test/Services/EmergencyAccessServiceTests.cs @@ -9,162 +9,163 @@ using Bit.Test.Common.AutoFixture.Attributes; using NSubstitute; using Xunit; -namespace Bit.Core.Test.Services; - -public class EmergencyAccessServiceTests +namespace Bit.Core.Test.Services { - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task SaveAsync_PremiumCannotUpdate( - SutProvider sutProvider, User savingUser) + public class EmergencyAccessServiceTests { - savingUser.Premium = false; - var emergencyAccess = new EmergencyAccess + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task SaveAsync_PremiumCannotUpdate( + SutProvider sutProvider, User savingUser) { - Type = Enums.EmergencyAccessType.Takeover, - GrantorId = savingUser.Id, - }; - - sutProvider.GetDependency().GetUserByIdAsync(savingUser.Id).Returns(savingUser); - - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.SaveAsync(emergencyAccess, savingUser)); - - Assert.Contains("Not a premium user.", exception.Message); - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().ReplaceAsync(default); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task InviteAsync_UserWithKeyConnectorCannotUseTakeover( - SutProvider sutProvider, User invitingUser, string email, int waitTime) - { - invitingUser.UsesKeyConnector = true; - sutProvider.GetDependency().CanAccessPremium(invitingUser).Returns(true); - - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.InviteAsync(invitingUser, email, Enums.EmergencyAccessType.Takeover, waitTime)); - - Assert.Contains("You cannot use Emergency Access Takeover because you are using Key Connector", exception.Message); - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().CreateAsync(default); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task ConfirmUserAsync_UserWithKeyConnectorCannotUseTakeover( - SutProvider sutProvider, User confirmingUser, string key) - { - confirmingUser.UsesKeyConnector = true; - var emergencyAccess = new EmergencyAccess - { - Status = Enums.EmergencyAccessStatusType.Accepted, - GrantorId = confirmingUser.Id, - Type = Enums.EmergencyAccessType.Takeover, - }; - - sutProvider.GetDependency().GetByIdAsync(confirmingUser.Id).Returns(confirmingUser); - sutProvider.GetDependency().GetByIdAsync(Arg.Any()).Returns(emergencyAccess); - - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.ConfirmUserAsync(new Guid(), key, confirmingUser.Id)); - - Assert.Contains("You cannot use Emergency Access Takeover because you are using Key Connector", exception.Message); - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().ReplaceAsync(default); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task SaveAsync_UserWithKeyConnectorCannotUseTakeover( - SutProvider sutProvider, User savingUser) - { - savingUser.UsesKeyConnector = true; - var emergencyAccess = new EmergencyAccess - { - Type = Enums.EmergencyAccessType.Takeover, - GrantorId = savingUser.Id, - }; - - var userService = sutProvider.GetDependency(); - userService.GetUserByIdAsync(savingUser.Id).Returns(savingUser); - userService.CanAccessPremium(savingUser).Returns(true); - - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.SaveAsync(emergencyAccess, savingUser)); - - Assert.Contains("You cannot use Emergency Access Takeover because you are using Key Connector", exception.Message); - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().ReplaceAsync(default); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task InitiateAsync_UserWithKeyConnectorCannotUseTakeover( - SutProvider sutProvider, User initiatingUser, User grantor) - { - grantor.UsesKeyConnector = true; - var emergencyAccess = new EmergencyAccess - { - Status = Enums.EmergencyAccessStatusType.Confirmed, - GranteeId = initiatingUser.Id, - GrantorId = grantor.Id, - Type = Enums.EmergencyAccessType.Takeover, - }; - - sutProvider.GetDependency().GetByIdAsync(Arg.Any()).Returns(emergencyAccess); - sutProvider.GetDependency().GetByIdAsync(grantor.Id).Returns(grantor); - - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.InitiateAsync(new Guid(), initiatingUser)); - - Assert.Contains("You cannot takeover an account that is using Key Connector", exception.Message); - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().ReplaceAsync(default); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task TakeoverAsync_UserWithKeyConnectorCannotUseTakeover( - SutProvider sutProvider, User requestingUser, User grantor) - { - grantor.UsesKeyConnector = true; - var emergencyAccess = new EmergencyAccess - { - GrantorId = grantor.Id, - GranteeId = requestingUser.Id, - Status = Enums.EmergencyAccessStatusType.RecoveryApproved, - Type = Enums.EmergencyAccessType.Takeover, - }; - - sutProvider.GetDependency().GetByIdAsync(Arg.Any()).Returns(emergencyAccess); - sutProvider.GetDependency().GetByIdAsync(grantor.Id).Returns(grantor); - - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.TakeoverAsync(new Guid(), requestingUser)); - - Assert.Contains("You cannot takeover an account that is using Key Connector", exception.Message); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task PasswordAsync_Disables_2FA_Providers_And_Unknown_Device_Verification_On_The_Grantor( - SutProvider sutProvider, User requestingUser, User grantor) - { - grantor.UsesKeyConnector = true; - grantor.UnknownDeviceVerificationEnabled = true; - grantor.SetTwoFactorProviders(new Dictionary - { - [TwoFactorProviderType.Email] = new TwoFactorProvider + savingUser.Premium = false; + var emergencyAccess = new EmergencyAccess { - MetaData = new Dictionary { ["Email"] = "asdfasf" }, - Enabled = true - } - }); - var emergencyAccess = new EmergencyAccess + Type = Enums.EmergencyAccessType.Takeover, + GrantorId = savingUser.Id, + }; + + sutProvider.GetDependency().GetUserByIdAsync(savingUser.Id).Returns(savingUser); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.SaveAsync(emergencyAccess, savingUser)); + + Assert.Contains("Not a premium user.", exception.Message); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().ReplaceAsync(default); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task InviteAsync_UserWithKeyConnectorCannotUseTakeover( + SutProvider sutProvider, User invitingUser, string email, int waitTime) { - GrantorId = grantor.Id, - GranteeId = requestingUser.Id, - Status = Enums.EmergencyAccessStatusType.RecoveryApproved, - Type = Enums.EmergencyAccessType.Takeover, - }; + invitingUser.UsesKeyConnector = true; + sutProvider.GetDependency().CanAccessPremium(invitingUser).Returns(true); - sutProvider.GetDependency().GetByIdAsync(Arg.Any()).Returns(emergencyAccess); - sutProvider.GetDependency().GetByIdAsync(grantor.Id).Returns(grantor); + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.InviteAsync(invitingUser, email, Enums.EmergencyAccessType.Takeover, waitTime)); - await sutProvider.Sut.PasswordAsync(Guid.NewGuid(), requestingUser, "blablahash", "blablakey"); + Assert.Contains("You cannot use Emergency Access Takeover because you are using Key Connector", exception.Message); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().CreateAsync(default); + } - Assert.False(grantor.UnknownDeviceVerificationEnabled); - Assert.Empty(grantor.GetTwoFactorProviders()); - await sutProvider.GetDependency().Received().ReplaceAsync(grantor); + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task ConfirmUserAsync_UserWithKeyConnectorCannotUseTakeover( + SutProvider sutProvider, User confirmingUser, string key) + { + confirmingUser.UsesKeyConnector = true; + var emergencyAccess = new EmergencyAccess + { + Status = Enums.EmergencyAccessStatusType.Accepted, + GrantorId = confirmingUser.Id, + Type = Enums.EmergencyAccessType.Takeover, + }; + + sutProvider.GetDependency().GetByIdAsync(confirmingUser.Id).Returns(confirmingUser); + sutProvider.GetDependency().GetByIdAsync(Arg.Any()).Returns(emergencyAccess); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.ConfirmUserAsync(new Guid(), key, confirmingUser.Id)); + + Assert.Contains("You cannot use Emergency Access Takeover because you are using Key Connector", exception.Message); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().ReplaceAsync(default); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task SaveAsync_UserWithKeyConnectorCannotUseTakeover( + SutProvider sutProvider, User savingUser) + { + savingUser.UsesKeyConnector = true; + var emergencyAccess = new EmergencyAccess + { + Type = Enums.EmergencyAccessType.Takeover, + GrantorId = savingUser.Id, + }; + + var userService = sutProvider.GetDependency(); + userService.GetUserByIdAsync(savingUser.Id).Returns(savingUser); + userService.CanAccessPremium(savingUser).Returns(true); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.SaveAsync(emergencyAccess, savingUser)); + + Assert.Contains("You cannot use Emergency Access Takeover because you are using Key Connector", exception.Message); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().ReplaceAsync(default); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task InitiateAsync_UserWithKeyConnectorCannotUseTakeover( + SutProvider sutProvider, User initiatingUser, User grantor) + { + grantor.UsesKeyConnector = true; + var emergencyAccess = new EmergencyAccess + { + Status = Enums.EmergencyAccessStatusType.Confirmed, + GranteeId = initiatingUser.Id, + GrantorId = grantor.Id, + Type = Enums.EmergencyAccessType.Takeover, + }; + + sutProvider.GetDependency().GetByIdAsync(Arg.Any()).Returns(emergencyAccess); + sutProvider.GetDependency().GetByIdAsync(grantor.Id).Returns(grantor); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.InitiateAsync(new Guid(), initiatingUser)); + + Assert.Contains("You cannot takeover an account that is using Key Connector", exception.Message); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().ReplaceAsync(default); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task TakeoverAsync_UserWithKeyConnectorCannotUseTakeover( + SutProvider sutProvider, User requestingUser, User grantor) + { + grantor.UsesKeyConnector = true; + var emergencyAccess = new EmergencyAccess + { + GrantorId = grantor.Id, + GranteeId = requestingUser.Id, + Status = Enums.EmergencyAccessStatusType.RecoveryApproved, + Type = Enums.EmergencyAccessType.Takeover, + }; + + sutProvider.GetDependency().GetByIdAsync(Arg.Any()).Returns(emergencyAccess); + sutProvider.GetDependency().GetByIdAsync(grantor.Id).Returns(grantor); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.TakeoverAsync(new Guid(), requestingUser)); + + Assert.Contains("You cannot takeover an account that is using Key Connector", exception.Message); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task PasswordAsync_Disables_2FA_Providers_And_Unknown_Device_Verification_On_The_Grantor( + SutProvider sutProvider, User requestingUser, User grantor) + { + grantor.UsesKeyConnector = true; + grantor.UnknownDeviceVerificationEnabled = true; + grantor.SetTwoFactorProviders(new Dictionary + { + [TwoFactorProviderType.Email] = new TwoFactorProvider + { + MetaData = new Dictionary { ["Email"] = "asdfasf" }, + Enabled = true + } + }); + var emergencyAccess = new EmergencyAccess + { + GrantorId = grantor.Id, + GranteeId = requestingUser.Id, + Status = Enums.EmergencyAccessStatusType.RecoveryApproved, + Type = Enums.EmergencyAccessType.Takeover, + }; + + sutProvider.GetDependency().GetByIdAsync(Arg.Any()).Returns(emergencyAccess); + sutProvider.GetDependency().GetByIdAsync(grantor.Id).Returns(grantor); + + await sutProvider.Sut.PasswordAsync(Guid.NewGuid(), requestingUser, "blablahash", "blablakey"); + + Assert.False(grantor.UnknownDeviceVerificationEnabled); + Assert.Empty(grantor.GetTwoFactorProviders()); + await sutProvider.GetDependency().Received().ReplaceAsync(grantor); + } } } diff --git a/test/Core.Test/Services/EventServiceTests.cs b/test/Core.Test/Services/EventServiceTests.cs index 214f120b88..988d84e131 100644 --- a/test/Core.Test/Services/EventServiceTests.cs +++ b/test/Core.Test/Services/EventServiceTests.cs @@ -11,98 +11,99 @@ using Bit.Test.Common.Helpers; using NSubstitute; using Xunit; -namespace Bit.Core.Test.Services; - -[SutProviderCustomize] -public class EventServiceTests +namespace Bit.Core.Test.Services { - public static IEnumerable InstallationIdTestCases => TestCaseHelper.GetCombinationsOfMultipleLists( - new object[] { Guid.NewGuid(), null }, - Enum.GetValues().Select(e => (object)e) - ).Select(p => p.ToArray()); - - [Theory] - [BitMemberAutoData(nameof(InstallationIdTestCases))] - public async Task LogOrganizationEvent_ProvidesInstallationId(Guid? installationId, EventType eventType, - Organization organization, SutProvider sutProvider) + [SutProviderCustomize] + public class EventServiceTests { - organization.Enabled = true; - organization.UseEvents = true; + public static IEnumerable InstallationIdTestCases => TestCaseHelper.GetCombinationsOfMultipleLists( + new object[] { Guid.NewGuid(), null }, + Enum.GetValues().Select(e => (object)e) + ).Select(p => p.ToArray()); - sutProvider.GetDependency().InstallationId.Returns(installationId); - - await sutProvider.Sut.LogOrganizationEventAsync(organization, eventType); - - await sutProvider.GetDependency().Received(1).CreateAsync(Arg.Is(e => - e.OrganizationId == organization.Id && - e.Type == eventType && - e.InstallationId == installationId)); - } - - [Theory, BitAutoData] - public async Task LogOrganizationUserEvent_LogsRequiredInfo(OrganizationUser orgUser, EventType eventType, DateTime date, - Guid actingUserId, Guid providerId, string ipAddress, DeviceType deviceType, SutProvider sutProvider) - { - var orgAbilities = new Dictionary() + [Theory] + [BitMemberAutoData(nameof(InstallationIdTestCases))] + public async Task LogOrganizationEvent_ProvidesInstallationId(Guid? installationId, EventType eventType, + Organization organization, SutProvider sutProvider) { - {orgUser.OrganizationId, new OrganizationAbility() { UseEvents = true, Enabled = true } } - }; - sutProvider.GetDependency().GetOrganizationAbilitiesAsync().Returns(orgAbilities); - sutProvider.GetDependency().UserId.Returns(actingUserId); - sutProvider.GetDependency().IpAddress.Returns(ipAddress); - sutProvider.GetDependency().DeviceType.Returns(deviceType); - sutProvider.GetDependency().ProviderIdForOrg(Arg.Any()).Returns(providerId); + organization.Enabled = true; + organization.UseEvents = true; - await sutProvider.Sut.LogOrganizationUserEventAsync(orgUser, eventType, date); + sutProvider.GetDependency().InstallationId.Returns(installationId); - var expected = new List() { - new EventMessage() - { - IpAddress = ipAddress, - DeviceType = deviceType, - OrganizationId = orgUser.OrganizationId, - UserId = orgUser.UserId, - OrganizationUserId = orgUser.Id, - ProviderId = providerId, - Type = eventType, - ActingUserId = actingUserId, - Date = date - } - }; + await sutProvider.Sut.LogOrganizationEventAsync(organization, eventType); - await sutProvider.GetDependency().Received(1).CreateManyAsync(Arg.Is(AssertHelper.AssertPropertyEqual(expected, new[] { "IdempotencyId" }))); - } + await sutProvider.GetDependency().Received(1).CreateAsync(Arg.Is(e => + e.OrganizationId == organization.Id && + e.Type == eventType && + e.InstallationId == installationId)); + } - [Theory, BitAutoData] - public async Task LogProviderUserEvent_LogsRequiredInfo(ProviderUser providerUser, EventType eventType, DateTime date, - Guid actingUserId, Guid providerId, string ipAddress, DeviceType deviceType, SutProvider sutProvider) - { - var providerAbilities = new Dictionary() + [Theory, BitAutoData] + public async Task LogOrganizationUserEvent_LogsRequiredInfo(OrganizationUser orgUser, EventType eventType, DateTime date, + Guid actingUserId, Guid providerId, string ipAddress, DeviceType deviceType, SutProvider sutProvider) { - {providerUser.ProviderId, new ProviderAbility() { UseEvents = true, Enabled = true } } - }; - sutProvider.GetDependency().GetProviderAbilitiesAsync().Returns(providerAbilities); - sutProvider.GetDependency().UserId.Returns(actingUserId); - sutProvider.GetDependency().IpAddress.Returns(ipAddress); - sutProvider.GetDependency().DeviceType.Returns(deviceType); - sutProvider.GetDependency().ProviderIdForOrg(Arg.Any()).Returns(providerId); - - await sutProvider.Sut.LogProviderUserEventAsync(providerUser, eventType, date); - - var expected = new List() { - new EventMessage() + var orgAbilities = new Dictionary() { - IpAddress = ipAddress, - DeviceType = deviceType, - ProviderId = providerUser.ProviderId, - UserId = providerUser.UserId, - ProviderUserId = providerUser.Id, - Type = eventType, - ActingUserId = actingUserId, - Date = date - } - }; + {orgUser.OrganizationId, new OrganizationAbility() { UseEvents = true, Enabled = true } } + }; + sutProvider.GetDependency().GetOrganizationAbilitiesAsync().Returns(orgAbilities); + sutProvider.GetDependency().UserId.Returns(actingUserId); + sutProvider.GetDependency().IpAddress.Returns(ipAddress); + sutProvider.GetDependency().DeviceType.Returns(deviceType); + sutProvider.GetDependency().ProviderIdForOrg(Arg.Any()).Returns(providerId); - await sutProvider.GetDependency().Received(1).CreateManyAsync(Arg.Is(AssertHelper.AssertPropertyEqual(expected, new[] { "IdempotencyId" }))); + await sutProvider.Sut.LogOrganizationUserEventAsync(orgUser, eventType, date); + + var expected = new List() { + new EventMessage() + { + IpAddress = ipAddress, + DeviceType = deviceType, + OrganizationId = orgUser.OrganizationId, + UserId = orgUser.UserId, + OrganizationUserId = orgUser.Id, + ProviderId = providerId, + Type = eventType, + ActingUserId = actingUserId, + Date = date + } + }; + + await sutProvider.GetDependency().Received(1).CreateManyAsync(Arg.Is(AssertHelper.AssertPropertyEqual(expected, new[] { "IdempotencyId" }))); + } + + [Theory, BitAutoData] + public async Task LogProviderUserEvent_LogsRequiredInfo(ProviderUser providerUser, EventType eventType, DateTime date, + Guid actingUserId, Guid providerId, string ipAddress, DeviceType deviceType, SutProvider sutProvider) + { + var providerAbilities = new Dictionary() + { + {providerUser.ProviderId, new ProviderAbility() { UseEvents = true, Enabled = true } } + }; + sutProvider.GetDependency().GetProviderAbilitiesAsync().Returns(providerAbilities); + sutProvider.GetDependency().UserId.Returns(actingUserId); + sutProvider.GetDependency().IpAddress.Returns(ipAddress); + sutProvider.GetDependency().DeviceType.Returns(deviceType); + sutProvider.GetDependency().ProviderIdForOrg(Arg.Any()).Returns(providerId); + + await sutProvider.Sut.LogProviderUserEventAsync(providerUser, eventType, date); + + var expected = new List() { + new EventMessage() + { + IpAddress = ipAddress, + DeviceType = deviceType, + ProviderId = providerUser.ProviderId, + UserId = providerUser.UserId, + ProviderUserId = providerUser.Id, + Type = eventType, + ActingUserId = actingUserId, + Date = date + } + }; + + await sutProvider.GetDependency().Received(1).CreateManyAsync(Arg.Is(AssertHelper.AssertPropertyEqual(expected, new[] { "IdempotencyId" }))); + } } } diff --git a/test/Core.Test/Services/GroupServiceTests.cs b/test/Core.Test/Services/GroupServiceTests.cs index 04aad97264..84f11cbb1a 100644 --- a/test/Core.Test/Services/GroupServiceTests.cs +++ b/test/Core.Test/Services/GroupServiceTests.cs @@ -10,127 +10,128 @@ using Bit.Test.Common.AutoFixture.Attributes; using NSubstitute; using Xunit; -namespace Bit.Core.Test.Services; - -public class GroupServiceTests +namespace Bit.Core.Test.Services { - [Theory, GroupOrganizationAutoData] - public async Task SaveAsync_DefaultGroupId_CreatesGroupInRepository(Group group, Organization organization, SutProvider sutProvider) + public class GroupServiceTests { - group.Id = default(Guid); - sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); - organization.UseGroups = true; - var utcNow = DateTime.UtcNow; + [Theory, GroupOrganizationAutoData] + public async Task SaveAsync_DefaultGroupId_CreatesGroupInRepository(Group group, Organization organization, SutProvider sutProvider) + { + group.Id = default(Guid); + sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); + organization.UseGroups = true; + var utcNow = DateTime.UtcNow; - await sutProvider.Sut.SaveAsync(group); + await sutProvider.Sut.SaveAsync(group); - await sutProvider.GetDependency().Received().CreateAsync(group); - await sutProvider.GetDependency().Received() - .LogGroupEventAsync(group, EventType.Group_Created); - Assert.True(group.CreationDate - utcNow < TimeSpan.FromSeconds(1)); - Assert.True(group.RevisionDate - utcNow < TimeSpan.FromSeconds(1)); - } + await sutProvider.GetDependency().Received().CreateAsync(group); + await sutProvider.GetDependency().Received() + .LogGroupEventAsync(group, EventType.Group_Created); + Assert.True(group.CreationDate - utcNow < TimeSpan.FromSeconds(1)); + Assert.True(group.RevisionDate - utcNow < TimeSpan.FromSeconds(1)); + } - [Theory, GroupOrganizationAutoData] - public async Task SaveAsync_DefaultGroupIdAndCollections_CreatesGroupInRepository(Group group, Organization organization, List collections, SutProvider sutProvider) - { - group.Id = default(Guid); - sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); - organization.UseGroups = true; - var utcNow = DateTime.UtcNow; + [Theory, GroupOrganizationAutoData] + public async Task SaveAsync_DefaultGroupIdAndCollections_CreatesGroupInRepository(Group group, Organization organization, List collections, SutProvider sutProvider) + { + group.Id = default(Guid); + sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); + organization.UseGroups = true; + var utcNow = DateTime.UtcNow; - await sutProvider.Sut.SaveAsync(group, collections); + await sutProvider.Sut.SaveAsync(group, collections); - await sutProvider.GetDependency().Received().CreateAsync(group, collections); - await sutProvider.GetDependency().Received() - .LogGroupEventAsync(group, EventType.Group_Created); - Assert.True(group.CreationDate - utcNow < TimeSpan.FromSeconds(1)); - Assert.True(group.RevisionDate - utcNow < TimeSpan.FromSeconds(1)); - } + await sutProvider.GetDependency().Received().CreateAsync(group, collections); + await sutProvider.GetDependency().Received() + .LogGroupEventAsync(group, EventType.Group_Created); + Assert.True(group.CreationDate - utcNow < TimeSpan.FromSeconds(1)); + Assert.True(group.RevisionDate - utcNow < TimeSpan.FromSeconds(1)); + } - [Theory, GroupOrganizationAutoData] - public async Task SaveAsync_NonDefaultGroupId_ReplaceGroupInRepository(Group group, Organization organization, List collections, SutProvider sutProvider) - { - organization.UseGroups = true; - sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); + [Theory, GroupOrganizationAutoData] + public async Task SaveAsync_NonDefaultGroupId_ReplaceGroupInRepository(Group group, Organization organization, List collections, SutProvider sutProvider) + { + organization.UseGroups = true; + sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); - await sutProvider.Sut.SaveAsync(group, collections); + await sutProvider.Sut.SaveAsync(group, collections); - await sutProvider.GetDependency().Received().ReplaceAsync(group, collections); - await sutProvider.GetDependency().Received() - .LogGroupEventAsync(group, EventType.Group_Updated); - Assert.True(group.RevisionDate - DateTime.UtcNow < TimeSpan.FromSeconds(1)); - } + await sutProvider.GetDependency().Received().ReplaceAsync(group, collections); + await sutProvider.GetDependency().Received() + .LogGroupEventAsync(group, EventType.Group_Updated); + Assert.True(group.RevisionDate - DateTime.UtcNow < TimeSpan.FromSeconds(1)); + } - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task SaveAsync_NonExistingOrganizationId_ThrowsBadRequest(Group group, SutProvider sutProvider) - { - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.SaveAsync(group)); - Assert.Contains("Organization not found", exception.Message); - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().CreateAsync(default); - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().ReplaceAsync(default); - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().LogGroupEventAsync(default, default, default); - } + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task SaveAsync_NonExistingOrganizationId_ThrowsBadRequest(Group group, SutProvider sutProvider) + { + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.SaveAsync(group)); + Assert.Contains("Organization not found", exception.Message); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().CreateAsync(default); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().ReplaceAsync(default); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().LogGroupEventAsync(default, default, default); + } - [Theory, GroupOrganizationNotUseGroupsAutoData] - public async Task SaveAsync_OrganizationDoesNotUseGroups_ThrowsBadRequest(Group group, Organization organization, SutProvider sutProvider) - { - sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); + [Theory, GroupOrganizationNotUseGroupsAutoData] + public async Task SaveAsync_OrganizationDoesNotUseGroups_ThrowsBadRequest(Group group, Organization organization, SutProvider sutProvider) + { + sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.SaveAsync(group)); + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.SaveAsync(group)); - Assert.Contains("This organization cannot use groups", exception.Message); - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().CreateAsync(default); - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().ReplaceAsync(default); - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().LogGroupEventAsync(default, default, default); - } + Assert.Contains("This organization cannot use groups", exception.Message); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().CreateAsync(default); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().ReplaceAsync(default); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().LogGroupEventAsync(default, default, default); + } - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task DeleteAsync_ValidData_DeletesGroup(Group group, SutProvider sutProvider) - { - await sutProvider.Sut.DeleteAsync(group); + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task DeleteAsync_ValidData_DeletesGroup(Group group, SutProvider sutProvider) + { + await sutProvider.Sut.DeleteAsync(group); - await sutProvider.GetDependency().Received().DeleteAsync(group); - await sutProvider.GetDependency().Received() - .LogGroupEventAsync(group, EventType.Group_Deleted); - } + await sutProvider.GetDependency().Received().DeleteAsync(group); + await sutProvider.GetDependency().Received() + .LogGroupEventAsync(group, EventType.Group_Deleted); + } - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task DeleteUserAsync_ValidData_DeletesUserInGroupRepository(Group group, Organization organization, OrganizationUser organizationUser, SutProvider sutProvider) - { - group.OrganizationId = organization.Id; - organization.UseGroups = true; - sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); - organizationUser.OrganizationId = organization.Id; - sutProvider.GetDependency().GetByIdAsync(organizationUser.Id) - .Returns(organizationUser); + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task DeleteUserAsync_ValidData_DeletesUserInGroupRepository(Group group, Organization organization, OrganizationUser organizationUser, SutProvider sutProvider) + { + group.OrganizationId = organization.Id; + organization.UseGroups = true; + sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); + organizationUser.OrganizationId = organization.Id; + sutProvider.GetDependency().GetByIdAsync(organizationUser.Id) + .Returns(organizationUser); - await sutProvider.Sut.DeleteUserAsync(group, organizationUser.Id); + await sutProvider.Sut.DeleteUserAsync(group, organizationUser.Id); - await sutProvider.GetDependency().Received().DeleteUserAsync(group.Id, organizationUser.Id); - await sutProvider.GetDependency().Received() - .LogOrganizationUserEventAsync(organizationUser, EventType.OrganizationUser_UpdatedGroups); - } + await sutProvider.GetDependency().Received().DeleteUserAsync(group.Id, organizationUser.Id); + await sutProvider.GetDependency().Received() + .LogOrganizationUserEventAsync(organizationUser, EventType.OrganizationUser_UpdatedGroups); + } - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task DeleteUserAsync_InvalidUser_ThrowsNotFound(Group group, Organization organization, OrganizationUser organizationUser, SutProvider sutProvider) - { - group.OrganizationId = organization.Id; - organization.UseGroups = true; - sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); - // organizationUser.OrganizationId = organization.Id; - sutProvider.GetDependency().GetByIdAsync(organizationUser.Id) - .Returns(organizationUser); + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task DeleteUserAsync_InvalidUser_ThrowsNotFound(Group group, Organization organization, OrganizationUser organizationUser, SutProvider sutProvider) + { + group.OrganizationId = organization.Id; + organization.UseGroups = true; + sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); + // organizationUser.OrganizationId = organization.Id; + sutProvider.GetDependency().GetByIdAsync(organizationUser.Id) + .Returns(organizationUser); - // user not in organization - await Assert.ThrowsAsync(() => sutProvider.Sut.DeleteUserAsync(group, organizationUser.Id)); - // invalid user - await Assert.ThrowsAsync(() => sutProvider.Sut.DeleteUserAsync(group, Guid.NewGuid())); - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() - .DeleteUserAsync(default, default); - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() - .LogOrganizationUserEventAsync(default, default); + // user not in organization + await Assert.ThrowsAsync(() => sutProvider.Sut.DeleteUserAsync(group, organizationUser.Id)); + // invalid user + await Assert.ThrowsAsync(() => sutProvider.Sut.DeleteUserAsync(group, Guid.NewGuid())); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() + .DeleteUserAsync(default, default); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() + .LogOrganizationUserEventAsync(default, default); + } } } diff --git a/test/Core.Test/Services/HandlebarsMailServiceTests.cs b/test/Core.Test/Services/HandlebarsMailServiceTests.cs index 5127eb2b43..39348ad8e6 100644 --- a/test/Core.Test/Services/HandlebarsMailServiceTests.cs +++ b/test/Core.Test/Services/HandlebarsMailServiceTests.cs @@ -8,168 +8,169 @@ using Microsoft.Extensions.Logging; using NSubstitute; using Xunit; -namespace Bit.Core.Test.Services; - -public class HandlebarsMailServiceTests +namespace Bit.Core.Test.Services { - private readonly HandlebarsMailService _sut; - - private readonly GlobalSettings _globalSettings; - private readonly IMailDeliveryService _mailDeliveryService; - private readonly IMailEnqueuingService _mailEnqueuingService; - - public HandlebarsMailServiceTests() + public class HandlebarsMailServiceTests { - _globalSettings = new GlobalSettings(); - _mailDeliveryService = Substitute.For(); - _mailEnqueuingService = Substitute.For(); + private readonly HandlebarsMailService _sut; - _sut = new HandlebarsMailService( - _globalSettings, - _mailDeliveryService, - _mailEnqueuingService - ); - } + private readonly GlobalSettings _globalSettings; + private readonly IMailDeliveryService _mailDeliveryService; + private readonly IMailEnqueuingService _mailEnqueuingService; - [Fact(Skip = "For local development")] - public async Task SendAllEmails() - { - // This test is only opt in and is more for development purposes. - // This will send all emails to the test email address so that they can be viewed. - var namedParameters = new Dictionary<(string, Type), object> + public HandlebarsMailServiceTests() { - // TODO: Swith to use env variable - { ("email", typeof(string)), "test@bitwarden.com" }, - { ("user", typeof(User)), new User - { - Id = Guid.NewGuid(), - Email = "test@bitwarden.com", - }}, - { ("userId", typeof(Guid)), Guid.NewGuid() }, - { ("token", typeof(string)), "test_token" }, - { ("fromEmail", typeof(string)), "test@bitwarden.com" }, - { ("toEmail", typeof(string)), "test@bitwarden.com" }, - { ("newEmailAddress", typeof(string)), "test@bitwarden.com" }, - { ("hint", typeof(string)), "Test Hint" }, - { ("organizationName", typeof(string)), "Test Organization Name" }, - { ("orgUser", typeof(OrganizationUser)), new OrganizationUser - { - Id = Guid.NewGuid(), - Email = "test@bitwarden.com", - OrganizationId = Guid.NewGuid(), + _globalSettings = new GlobalSettings(); + _mailDeliveryService = Substitute.For(); + _mailEnqueuingService = Substitute.For(); - }}, - { ("token", typeof(ExpiringToken)), new ExpiringToken("test_token", DateTime.UtcNow.AddDays(1))}, - { ("organization", typeof(Organization)), new Organization - { - Id = Guid.NewGuid(), - Name = "Test Organization Name", - Seats = 5 - }}, - { ("initialSeatCount", typeof(int)), 5}, - { ("ownerEmails", typeof(IEnumerable)), new [] { "test@bitwarden.com" }}, - { ("maxSeatCount", typeof(int)), 5 }, - { ("userIdentifier", typeof(string)), "test_user" }, - { ("adminEmails", typeof(IEnumerable)), new [] { "test@bitwarden.com" }}, - { ("returnUrl", typeof(string)), "https://bitwarden.com/" }, - { ("amount", typeof(decimal)), 1.00M }, - { ("dueDate", typeof(DateTime)), DateTime.UtcNow.AddDays(1) }, - { ("items", typeof(List)), new List { "test@bitwarden.com" }}, - { ("mentionInvoices", typeof(bool)), true }, - { ("emails", typeof(IEnumerable)), new [] { "test@bitwarden.com" }}, - { ("deviceType", typeof(string)), "Mobile" }, - { ("timestamp", typeof(DateTime)), DateTime.UtcNow.AddDays(1)}, - { ("ip", typeof(string)), "127.0.0.1" }, - { ("emergencyAccess", typeof(EmergencyAccess)), new EmergencyAccess - { - Id = Guid.NewGuid(), - Email = "test@bitwarden.com", - }}, - { ("granteeEmail", typeof(string)), "test@bitwarden.com" }, - { ("grantorName", typeof(string)), "Test User" }, - { ("initiatingName", typeof(string)), "Test" }, - { ("approvingName", typeof(string)), "Test Name" }, - { ("rejectingName", typeof(string)), "Test Name" }, - { ("provider", typeof(Provider)), new Provider - { - Id = Guid.NewGuid(), - }}, - { ("name", typeof(string)), "Test Name" }, - { ("ea", typeof(EmergencyAccess)), new EmergencyAccess - { - Id = Guid.NewGuid(), - Email = "test@bitwarden.com", - }}, - { ("userName", typeof(string)), "testUser" }, - { ("orgName", typeof(string)), "Test Org Name" }, - { ("providerName", typeof(string)), "testProvider" }, - { ("providerUser", typeof(ProviderUser)), new ProviderUser - { - ProviderId = Guid.NewGuid(), - Id = Guid.NewGuid(), - }}, - { ("familyUserEmail", typeof(string)), "test@bitwarden.com" }, - { ("sponsorEmail", typeof(string)), "test@bitwarden.com" }, - { ("familyOrgName", typeof(string)), "Test Org Name" }, - // Swap existingAccount to true or false to generate different versions of the SendFamiliesForEnterpriseOfferEmailAsync emails. - { ("existingAccount", typeof(bool)), false }, - { ("sponsorshipEndDate", typeof(DateTime)), DateTime.UtcNow.AddDays(1)}, - { ("sponsorOrgName", typeof(string)), "Sponsor Test Org Name" }, - { ("expirationDate", typeof(DateTime)), DateTime.Now.AddDays(3) }, - { ("utcNow", typeof(DateTime)), DateTime.UtcNow }, - }; - - var globalSettings = new GlobalSettings - { - Mail = new GlobalSettings.MailSettings - { - Smtp = new GlobalSettings.MailSettings.SmtpSettings - { - Host = "localhost", - TrustServer = true, - Port = 10250, - }, - ReplyToEmail = "noreply@bitwarden.com", - }, - SiteName = "Bitwarden", - }; - - var mailDeliveryService = new MailKitSmtpMailDeliveryService(globalSettings, Substitute.For>()); - - var handlebarsService = new HandlebarsMailService(globalSettings, mailDeliveryService, new BlockingMailEnqueuingService()); - - var sendMethods = typeof(IMailService).GetMethods(BindingFlags.Public | BindingFlags.Instance) - .Where(m => m.Name.StartsWith("Send") && m.Name != "SendEnqueuedMailMessageAsync"); - - foreach (var sendMethod in sendMethods) - { - await InvokeMethod(sendMethod); + _sut = new HandlebarsMailService( + _globalSettings, + _mailDeliveryService, + _mailEnqueuingService + ); } - async Task InvokeMethod(MethodInfo method) + [Fact(Skip = "For local development")] + public async Task SendAllEmails() { - var parameters = method.GetParameters(); - var args = new object[parameters.Length]; - - for (var i = 0; i < parameters.Length; i++) + // This test is only opt in and is more for development purposes. + // This will send all emails to the test email address so that they can be viewed. + var namedParameters = new Dictionary<(string, Type), object> { - if (!namedParameters.TryGetValue((parameters[i].Name, parameters[i].ParameterType), out var value)) + // TODO: Swith to use env variable + { ("email", typeof(string)), "test@bitwarden.com" }, + { ("user", typeof(User)), new User { - throw new InvalidOperationException($"Couldn't find a parameter for name '{parameters[i].Name}' and type '{parameters[i].ParameterType.FullName}'"); - } + Id = Guid.NewGuid(), + Email = "test@bitwarden.com", + }}, + { ("userId", typeof(Guid)), Guid.NewGuid() }, + { ("token", typeof(string)), "test_token" }, + { ("fromEmail", typeof(string)), "test@bitwarden.com" }, + { ("toEmail", typeof(string)), "test@bitwarden.com" }, + { ("newEmailAddress", typeof(string)), "test@bitwarden.com" }, + { ("hint", typeof(string)), "Test Hint" }, + { ("organizationName", typeof(string)), "Test Organization Name" }, + { ("orgUser", typeof(OrganizationUser)), new OrganizationUser + { + Id = Guid.NewGuid(), + Email = "test@bitwarden.com", + OrganizationId = Guid.NewGuid(), - args[i] = value; + }}, + { ("token", typeof(ExpiringToken)), new ExpiringToken("test_token", DateTime.UtcNow.AddDays(1))}, + { ("organization", typeof(Organization)), new Organization + { + Id = Guid.NewGuid(), + Name = "Test Organization Name", + Seats = 5 + }}, + { ("initialSeatCount", typeof(int)), 5}, + { ("ownerEmails", typeof(IEnumerable)), new [] { "test@bitwarden.com" }}, + { ("maxSeatCount", typeof(int)), 5 }, + { ("userIdentifier", typeof(string)), "test_user" }, + { ("adminEmails", typeof(IEnumerable)), new [] { "test@bitwarden.com" }}, + { ("returnUrl", typeof(string)), "https://bitwarden.com/" }, + { ("amount", typeof(decimal)), 1.00M }, + { ("dueDate", typeof(DateTime)), DateTime.UtcNow.AddDays(1) }, + { ("items", typeof(List)), new List { "test@bitwarden.com" }}, + { ("mentionInvoices", typeof(bool)), true }, + { ("emails", typeof(IEnumerable)), new [] { "test@bitwarden.com" }}, + { ("deviceType", typeof(string)), "Mobile" }, + { ("timestamp", typeof(DateTime)), DateTime.UtcNow.AddDays(1)}, + { ("ip", typeof(string)), "127.0.0.1" }, + { ("emergencyAccess", typeof(EmergencyAccess)), new EmergencyAccess + { + Id = Guid.NewGuid(), + Email = "test@bitwarden.com", + }}, + { ("granteeEmail", typeof(string)), "test@bitwarden.com" }, + { ("grantorName", typeof(string)), "Test User" }, + { ("initiatingName", typeof(string)), "Test" }, + { ("approvingName", typeof(string)), "Test Name" }, + { ("rejectingName", typeof(string)), "Test Name" }, + { ("provider", typeof(Provider)), new Provider + { + Id = Guid.NewGuid(), + }}, + { ("name", typeof(string)), "Test Name" }, + { ("ea", typeof(EmergencyAccess)), new EmergencyAccess + { + Id = Guid.NewGuid(), + Email = "test@bitwarden.com", + }}, + { ("userName", typeof(string)), "testUser" }, + { ("orgName", typeof(string)), "Test Org Name" }, + { ("providerName", typeof(string)), "testProvider" }, + { ("providerUser", typeof(ProviderUser)), new ProviderUser + { + ProviderId = Guid.NewGuid(), + Id = Guid.NewGuid(), + }}, + { ("familyUserEmail", typeof(string)), "test@bitwarden.com" }, + { ("sponsorEmail", typeof(string)), "test@bitwarden.com" }, + { ("familyOrgName", typeof(string)), "Test Org Name" }, + // Swap existingAccount to true or false to generate different versions of the SendFamiliesForEnterpriseOfferEmailAsync emails. + { ("existingAccount", typeof(bool)), false }, + { ("sponsorshipEndDate", typeof(DateTime)), DateTime.UtcNow.AddDays(1)}, + { ("sponsorOrgName", typeof(string)), "Sponsor Test Org Name" }, + { ("expirationDate", typeof(DateTime)), DateTime.Now.AddDays(3) }, + { ("utcNow", typeof(DateTime)), DateTime.UtcNow }, + }; + + var globalSettings = new GlobalSettings + { + Mail = new GlobalSettings.MailSettings + { + Smtp = new GlobalSettings.MailSettings.SmtpSettings + { + Host = "localhost", + TrustServer = true, + Port = 10250, + }, + ReplyToEmail = "noreply@bitwarden.com", + }, + SiteName = "Bitwarden", + }; + + var mailDeliveryService = new MailKitSmtpMailDeliveryService(globalSettings, Substitute.For>()); + + var handlebarsService = new HandlebarsMailService(globalSettings, mailDeliveryService, new BlockingMailEnqueuingService()); + + var sendMethods = typeof(IMailService).GetMethods(BindingFlags.Public | BindingFlags.Instance) + .Where(m => m.Name.StartsWith("Send") && m.Name != "SendEnqueuedMailMessageAsync"); + + foreach (var sendMethod in sendMethods) + { + await InvokeMethod(sendMethod); } - await (Task)method.Invoke(handlebarsService, args); + async Task InvokeMethod(MethodInfo method) + { + var parameters = method.GetParameters(); + var args = new object[parameters.Length]; + + for (var i = 0; i < parameters.Length; i++) + { + if (!namedParameters.TryGetValue((parameters[i].Name, parameters[i].ParameterType), out var value)) + { + throw new InvalidOperationException($"Couldn't find a parameter for name '{parameters[i].Name}' and type '{parameters[i].ParameterType.FullName}'"); + } + + args[i] = value; + } + + await (Task)method.Invoke(handlebarsService, args); + } + } + + // Remove this test when we add actual tests. It only proves that + // we've properly constructed the system under test. + [Fact] + public void ServiceExists() + { + Assert.NotNull(_sut); } } - - // Remove this test when we add actual tests. It only proves that - // we've properly constructed the system under test. - [Fact] - public void ServiceExists() - { - Assert.NotNull(_sut); - } } diff --git a/test/Core.Test/Services/InMemoryApplicationCacheServiceTests.cs b/test/Core.Test/Services/InMemoryApplicationCacheServiceTests.cs index ff8e734b32..8deae63641 100644 --- a/test/Core.Test/Services/InMemoryApplicationCacheServiceTests.cs +++ b/test/Core.Test/Services/InMemoryApplicationCacheServiceTests.cs @@ -3,28 +3,29 @@ using Bit.Core.Services; using NSubstitute; using Xunit; -namespace Bit.Core.Test.Services; - -public class InMemoryApplicationCacheServiceTests +namespace Bit.Core.Test.Services { - private readonly InMemoryApplicationCacheService _sut; - - private readonly IOrganizationRepository _organizationRepository; - private readonly IProviderRepository _providerRepository; - - public InMemoryApplicationCacheServiceTests() + public class InMemoryApplicationCacheServiceTests { - _organizationRepository = Substitute.For(); - _providerRepository = Substitute.For(); + private readonly InMemoryApplicationCacheService _sut; - _sut = new InMemoryApplicationCacheService(_organizationRepository, _providerRepository); - } + private readonly IOrganizationRepository _organizationRepository; + private readonly IProviderRepository _providerRepository; - // Remove this test when we add actual tests. It only proves that - // we've properly constructed the system under test. - [Fact] - public void ServiceExists() - { - Assert.NotNull(_sut); + public InMemoryApplicationCacheServiceTests() + { + _organizationRepository = Substitute.For(); + _providerRepository = Substitute.For(); + + _sut = new InMemoryApplicationCacheService(_organizationRepository, _providerRepository); + } + + // Remove this test when we add actual tests. It only proves that + // we've properly constructed the system under test. + [Fact] + public void ServiceExists() + { + Assert.NotNull(_sut); + } } } diff --git a/test/Core.Test/Services/InMemoryServiceBusApplicationCacheServiceTests.cs b/test/Core.Test/Services/InMemoryServiceBusApplicationCacheServiceTests.cs index f74aa6f505..33f23ea189 100644 --- a/test/Core.Test/Services/InMemoryServiceBusApplicationCacheServiceTests.cs +++ b/test/Core.Test/Services/InMemoryServiceBusApplicationCacheServiceTests.cs @@ -4,34 +4,35 @@ using Bit.Core.Settings; using NSubstitute; using Xunit; -namespace Bit.Core.Test.Services; - -public class InMemoryServiceBusApplicationCacheServiceTests +namespace Bit.Core.Test.Services { - private readonly InMemoryServiceBusApplicationCacheService _sut; - - private readonly IOrganizationRepository _organizationRepository; - private readonly IProviderRepository _providerRepository; - private readonly GlobalSettings _globalSettings; - - public InMemoryServiceBusApplicationCacheServiceTests() + public class InMemoryServiceBusApplicationCacheServiceTests { - _organizationRepository = Substitute.For(); - _providerRepository = Substitute.For(); - _globalSettings = new GlobalSettings(); + private readonly InMemoryServiceBusApplicationCacheService _sut; - _sut = new InMemoryServiceBusApplicationCacheService( - _organizationRepository, - _providerRepository, - _globalSettings - ); - } + private readonly IOrganizationRepository _organizationRepository; + private readonly IProviderRepository _providerRepository; + private readonly GlobalSettings _globalSettings; - // Remove this test when we add actual tests. It only proves that - // we've properly constructed the system under test. - [Fact(Skip = "Needs additional work")] - public void ServiceExists() - { - Assert.NotNull(_sut); + public InMemoryServiceBusApplicationCacheServiceTests() + { + _organizationRepository = Substitute.For(); + _providerRepository = Substitute.For(); + _globalSettings = new GlobalSettings(); + + _sut = new InMemoryServiceBusApplicationCacheService( + _organizationRepository, + _providerRepository, + _globalSettings + ); + } + + // Remove this test when we add actual tests. It only proves that + // we've properly constructed the system under test. + [Fact(Skip = "Needs additional work")] + public void ServiceExists() + { + Assert.NotNull(_sut); + } } } diff --git a/test/Core.Test/Services/LicensingServiceTests.cs b/test/Core.Test/Services/LicensingServiceTests.cs index 4a8ba0255f..2e94ef2b51 100644 --- a/test/Core.Test/Services/LicensingServiceTests.cs +++ b/test/Core.Test/Services/LicensingServiceTests.cs @@ -9,52 +9,53 @@ using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; using Xunit; -namespace Bit.Core.Test.Services; - -[SutProviderCustomize] -public class LicensingServiceTests +namespace Bit.Core.Test.Services { - private static string licenseFilePath(Guid orgId) => - Path.Combine(OrganizationLicenseDirectory.Value, $"{orgId}.json"); - private static string LicenseDirectory => Path.GetDirectoryName(OrganizationLicenseDirectory.Value); - private static Lazy OrganizationLicenseDirectory => new(() => + [SutProviderCustomize] + public class LicensingServiceTests { - var directory = Path.Combine(Path.GetTempPath(), "organization"); - if (!Directory.Exists(directory)) + private static string licenseFilePath(Guid orgId) => + Path.Combine(OrganizationLicenseDirectory.Value, $"{orgId}.json"); + private static string LicenseDirectory => Path.GetDirectoryName(OrganizationLicenseDirectory.Value); + private static Lazy OrganizationLicenseDirectory => new(() => { - Directory.CreateDirectory(directory); + var directory = Path.Combine(Path.GetTempPath(), "organization"); + if (!Directory.Exists(directory)) + { + Directory.CreateDirectory(directory); + } + return directory; + }); + + public static SutProvider GetSutProvider() + { + var fixture = new Fixture().WithAutoNSubstitutions(); + + var settings = fixture.Create(); + settings.LicenseDirectory = LicenseDirectory; + settings.SelfHosted = true; + + return new SutProvider(fixture) + .SetDependency(settings) + .Create(); } - return directory; - }); - public static SutProvider GetSutProvider() - { - var fixture = new Fixture().WithAutoNSubstitutions(); - - var settings = fixture.Create(); - settings.LicenseDirectory = LicenseDirectory; - settings.SelfHosted = true; - - return new SutProvider(fixture) - .SetDependency(settings) - .Create(); - } - - [Theory, BitAutoData, OrganizationLicenseCustomize] - public async Task ReadOrganizationLicense(Organization organization, OrganizationLicense license) - { - var sutProvider = GetSutProvider(); - - File.WriteAllText(licenseFilePath(organization.Id), JsonSerializer.Serialize(license)); - - var actual = await sutProvider.Sut.ReadOrganizationLicenseAsync(organization); - try + [Theory, BitAutoData, OrganizationLicenseCustomize] + public async Task ReadOrganizationLicense(Organization organization, OrganizationLicense license) { - Assert.Equal(JsonSerializer.Serialize(license), JsonSerializer.Serialize(actual)); - } - finally - { - Directory.Delete(OrganizationLicenseDirectory.Value, true); + var sutProvider = GetSutProvider(); + + File.WriteAllText(licenseFilePath(organization.Id), JsonSerializer.Serialize(license)); + + var actual = await sutProvider.Sut.ReadOrganizationLicenseAsync(organization); + try + { + Assert.Equal(JsonSerializer.Serialize(license), JsonSerializer.Serialize(actual)); + } + finally + { + Directory.Delete(OrganizationLicenseDirectory.Value, true); + } } } } diff --git a/test/Core.Test/Services/LocalAttachmentStorageServiceTests.cs b/test/Core.Test/Services/LocalAttachmentStorageServiceTests.cs index cf05933f42..63a3e8bc8f 100644 --- a/test/Core.Test/Services/LocalAttachmentStorageServiceTests.cs +++ b/test/Core.Test/Services/LocalAttachmentStorageServiceTests.cs @@ -11,75 +11,194 @@ using Bit.Test.Common.AutoFixture.Attributes; using NSubstitute; using Xunit; -namespace Bit.Core.Test.Services; - -public class LocalAttachmentStorageServiceTests +namespace Bit.Core.Test.Services { - - private void AssertFileCreation(string expectedPath, string expectedFileContents) + public class LocalAttachmentStorageServiceTests { - Assert.True(File.Exists(expectedPath)); - Assert.Equal(expectedFileContents, File.ReadAllText(expectedPath)); - } - [Theory] - [InlineCustomAutoData(new[] { typeof(UserCipher), typeof(MetaData) })] - [InlineCustomAutoData(new[] { typeof(UserCipher), typeof(MetaDataWithoutContainer) })] - [InlineCustomAutoData(new[] { typeof(UserCipher), typeof(MetaDataWithoutKey) })] - public async Task UploadNewAttachmentAsync_Success(string stream, Cipher cipher, CipherAttachment.MetaData attachmentData) - { - using (var tempDirectory = new TempDirectory()) + private void AssertFileCreation(string expectedPath, string expectedFileContents) { - var sutProvider = GetSutProvider(tempDirectory); - - await sutProvider.Sut.UploadNewAttachmentAsync(new MemoryStream(Encoding.UTF8.GetBytes(stream)), - cipher, attachmentData); - - AssertFileCreation($"{tempDirectory}/{cipher.Id}/{attachmentData.AttachmentId}", stream); + Assert.True(File.Exists(expectedPath)); + Assert.Equal(expectedFileContents, File.ReadAllText(expectedPath)); } - } - [Theory] - [InlineCustomAutoData(new[] { typeof(OrganizationCipher), typeof(MetaData) })] - [InlineCustomAutoData(new[] { typeof(OrganizationCipher), typeof(MetaDataWithoutContainer) })] - [InlineCustomAutoData(new[] { typeof(OrganizationCipher), typeof(MetaDataWithoutKey) })] - public async Task UploadShareAttachmentAsync_Success(string stream, Cipher cipher, CipherAttachment.MetaData attachmentData) - { - using (var tempDirectory = new TempDirectory()) + [Theory] + [InlineCustomAutoData(new[] { typeof(UserCipher), typeof(MetaData) })] + [InlineCustomAutoData(new[] { typeof(UserCipher), typeof(MetaDataWithoutContainer) })] + [InlineCustomAutoData(new[] { typeof(UserCipher), typeof(MetaDataWithoutKey) })] + public async Task UploadNewAttachmentAsync_Success(string stream, Cipher cipher, CipherAttachment.MetaData attachmentData) { - var sutProvider = GetSutProvider(tempDirectory); + using (var tempDirectory = new TempDirectory()) + { + var sutProvider = GetSutProvider(tempDirectory); - await sutProvider.Sut.UploadShareAttachmentAsync(new MemoryStream(Encoding.UTF8.GetBytes(stream)), - cipher.Id, cipher.OrganizationId.Value, attachmentData); + await sutProvider.Sut.UploadNewAttachmentAsync(new MemoryStream(Encoding.UTF8.GetBytes(stream)), + cipher, attachmentData); - AssertFileCreation($"{tempDirectory}/temp/{cipher.Id}/{cipher.OrganizationId}/{attachmentData.AttachmentId}", stream); + AssertFileCreation($"{tempDirectory}/{cipher.Id}/{attachmentData.AttachmentId}", stream); + } } - } - [Theory] - [InlineCustomAutoData(new[] { typeof(OrganizationCipher), typeof(MetaData) })] - [InlineCustomAutoData(new[] { typeof(OrganizationCipher), typeof(MetaDataWithoutContainer) })] - [InlineCustomAutoData(new[] { typeof(OrganizationCipher), typeof(MetaDataWithoutKey) })] - public async Task StartShareAttachmentAsync_NoSource_NoWork(Cipher cipher, CipherAttachment.MetaData attachmentData) - { - using (var tempDirectory = new TempDirectory()) + [Theory] + [InlineCustomAutoData(new[] { typeof(OrganizationCipher), typeof(MetaData) })] + [InlineCustomAutoData(new[] { typeof(OrganizationCipher), typeof(MetaDataWithoutContainer) })] + [InlineCustomAutoData(new[] { typeof(OrganizationCipher), typeof(MetaDataWithoutKey) })] + public async Task UploadShareAttachmentAsync_Success(string stream, Cipher cipher, CipherAttachment.MetaData attachmentData) { - var sutProvider = GetSutProvider(tempDirectory); + using (var tempDirectory = new TempDirectory()) + { + var sutProvider = GetSutProvider(tempDirectory); - await sutProvider.Sut.StartShareAttachmentAsync(cipher.Id, cipher.OrganizationId.Value, attachmentData); + await sutProvider.Sut.UploadShareAttachmentAsync(new MemoryStream(Encoding.UTF8.GetBytes(stream)), + cipher.Id, cipher.OrganizationId.Value, attachmentData); - Assert.False(File.Exists($"{tempDirectory}/{cipher.Id}/{attachmentData.AttachmentId}")); - Assert.False(File.Exists($"{tempDirectory}/{cipher.Id}/{attachmentData.AttachmentId}")); + AssertFileCreation($"{tempDirectory}/temp/{cipher.Id}/{cipher.OrganizationId}/{attachmentData.AttachmentId}", stream); + } } - } - [Theory] - [InlineCustomAutoData(new[] { typeof(OrganizationCipher), typeof(MetaData) })] - [InlineCustomAutoData(new[] { typeof(OrganizationCipher), typeof(MetaDataWithoutContainer) })] - [InlineCustomAutoData(new[] { typeof(OrganizationCipher), typeof(MetaDataWithoutKey) })] - public async Task StartShareAttachmentAsync_NoDest_NoWork(string source, Cipher cipher, CipherAttachment.MetaData attachmentData) - { - using (var tempDirectory = new TempDirectory()) + [Theory] + [InlineCustomAutoData(new[] { typeof(OrganizationCipher), typeof(MetaData) })] + [InlineCustomAutoData(new[] { typeof(OrganizationCipher), typeof(MetaDataWithoutContainer) })] + [InlineCustomAutoData(new[] { typeof(OrganizationCipher), typeof(MetaDataWithoutKey) })] + public async Task StartShareAttachmentAsync_NoSource_NoWork(Cipher cipher, CipherAttachment.MetaData attachmentData) + { + using (var tempDirectory = new TempDirectory()) + { + var sutProvider = GetSutProvider(tempDirectory); + + await sutProvider.Sut.StartShareAttachmentAsync(cipher.Id, cipher.OrganizationId.Value, attachmentData); + + Assert.False(File.Exists($"{tempDirectory}/{cipher.Id}/{attachmentData.AttachmentId}")); + Assert.False(File.Exists($"{tempDirectory}/{cipher.Id}/{attachmentData.AttachmentId}")); + } + } + + [Theory] + [InlineCustomAutoData(new[] { typeof(OrganizationCipher), typeof(MetaData) })] + [InlineCustomAutoData(new[] { typeof(OrganizationCipher), typeof(MetaDataWithoutContainer) })] + [InlineCustomAutoData(new[] { typeof(OrganizationCipher), typeof(MetaDataWithoutKey) })] + public async Task StartShareAttachmentAsync_NoDest_NoWork(string source, Cipher cipher, CipherAttachment.MetaData attachmentData) + { + using (var tempDirectory = new TempDirectory()) + { + var sutProvider = GetSutProvider(tempDirectory); + + var sourcePath = $"{tempDirectory}/temp/{cipher.Id}/{cipher.OrganizationId}/{attachmentData.AttachmentId}"; + var destPath = $"{tempDirectory}/{cipher.Id}/{attachmentData.AttachmentId}"; + var rollBackPath = $"{tempDirectory}/temp/{cipher.Id}/{attachmentData.AttachmentId}"; + Directory.CreateDirectory(Path.GetDirectoryName(sourcePath)); + File.WriteAllText(sourcePath, source); + + await sutProvider.Sut.StartShareAttachmentAsync(cipher.Id, cipher.OrganizationId.Value, attachmentData); + + Assert.True(File.Exists(sourcePath)); + Assert.Equal(source, File.ReadAllText(sourcePath)); + Assert.False(File.Exists(destPath)); + Assert.False(File.Exists(rollBackPath)); + } + } + + + [Theory] + [InlineCustomAutoData(new[] { typeof(OrganizationCipher), typeof(MetaData) })] + [InlineCustomAutoData(new[] { typeof(OrganizationCipher), typeof(MetaDataWithoutContainer) })] + [InlineCustomAutoData(new[] { typeof(OrganizationCipher), typeof(MetaDataWithoutKey) })] + public async Task StartShareAttachmentAsync_Success(string source, string destOriginal, Cipher cipher, CipherAttachment.MetaData attachmentData) + { + using (var tempDirectory = new TempDirectory()) + { + await StartShareAttachmentAsync(source, destOriginal, cipher, attachmentData, tempDirectory); + } + } + + [Theory] + [InlineCustomAutoData(new[] { typeof(OrganizationCipher), typeof(MetaData) })] + [InlineCustomAutoData(new[] { typeof(OrganizationCipher), typeof(MetaDataWithoutContainer) })] + [InlineCustomAutoData(new[] { typeof(OrganizationCipher), typeof(MetaDataWithoutKey) })] + public async Task RollbackShareAttachmentAsync_Success(string source, string destOriginal, Cipher cipher, CipherAttachment.MetaData attachmentData) + { + using (var tempDirectory = new TempDirectory()) + { + var sutProvider = GetSutProvider(tempDirectory); + + var sourcePath = $"{tempDirectory}/temp/{cipher.Id}/{cipher.OrganizationId}/{attachmentData.AttachmentId}"; + var destPath = $"{tempDirectory}/{cipher.Id}/{attachmentData.AttachmentId}"; + var rollBackPath = $"{tempDirectory}/temp/{cipher.Id}/{attachmentData.AttachmentId}"; + + await StartShareAttachmentAsync(source, destOriginal, cipher, attachmentData, tempDirectory); + await sutProvider.Sut.RollbackShareAttachmentAsync(cipher.Id, cipher.OrganizationId.Value, attachmentData, "Not Used Here"); + + Assert.True(File.Exists(destPath)); + Assert.Equal(destOriginal, File.ReadAllText(destPath)); + Assert.False(File.Exists(sourcePath)); + Assert.False(File.Exists(rollBackPath)); + } + } + + [Theory] + [InlineCustomAutoData(new[] { typeof(UserCipher), typeof(MetaData) })] + [InlineCustomAutoData(new[] { typeof(UserCipher), typeof(MetaDataWithoutContainer) })] + [InlineCustomAutoData(new[] { typeof(UserCipher), typeof(MetaDataWithoutKey) })] + public async Task DeleteAttachmentAsync_Success(Cipher cipher, CipherAttachment.MetaData attachmentData) + { + using (var tempDirectory = new TempDirectory()) + { + var sutProvider = GetSutProvider(tempDirectory); + + var expectedPath = $"{tempDirectory}/{cipher.Id}/{attachmentData.AttachmentId}"; + Directory.CreateDirectory(Path.GetDirectoryName(expectedPath)); + File.Create(expectedPath).Close(); + + await sutProvider.Sut.DeleteAttachmentAsync(cipher.Id, attachmentData); + + Assert.False(File.Exists(expectedPath)); + } + } + + [Theory] + [InlineUserCipherAutoData] + [InlineOrganizationCipherAutoData] + public async Task CleanupAsync_Succes(Cipher cipher) + { + using (var tempDirectory = new TempDirectory()) + { + var sutProvider = GetSutProvider(tempDirectory); + + var tempPath = $"{tempDirectory}/temp/{cipher.Id}"; + var permPath = $"{tempDirectory}/{cipher.Id}"; + Directory.CreateDirectory(tempPath); + Directory.CreateDirectory(permPath); + + await sutProvider.Sut.CleanupAsync(cipher.Id); + + Assert.False(Directory.Exists(tempPath)); + Assert.True(Directory.Exists(permPath)); + } + } + + [Theory] + [InlineUserCipherAutoData] + [InlineOrganizationCipherAutoData] + public async Task DeleteAttachmentsForCipherAsync_Succes(Cipher cipher) + { + using (var tempDirectory = new TempDirectory()) + { + var sutProvider = GetSutProvider(tempDirectory); + + var tempPath = $"{tempDirectory}/temp/{cipher.Id}"; + var permPath = $"{tempDirectory}/{cipher.Id}"; + Directory.CreateDirectory(tempPath); + Directory.CreateDirectory(permPath); + + await sutProvider.Sut.DeleteAttachmentsForCipherAsync(cipher.Id); + + Assert.True(Directory.Exists(tempPath)); + Assert.False(Directory.Exists(permPath)); + } + } + + private async Task StartShareAttachmentAsync(string source, string destOriginal, Cipher cipher, + CipherAttachment.MetaData attachmentData, TempDirectory tempDirectory) { var sutProvider = GetSutProvider(tempDirectory); @@ -87,144 +206,26 @@ public class LocalAttachmentStorageServiceTests var destPath = $"{tempDirectory}/{cipher.Id}/{attachmentData.AttachmentId}"; var rollBackPath = $"{tempDirectory}/temp/{cipher.Id}/{attachmentData.AttachmentId}"; Directory.CreateDirectory(Path.GetDirectoryName(sourcePath)); + Directory.CreateDirectory(Path.GetDirectoryName(destPath)); File.WriteAllText(sourcePath, source); + File.WriteAllText(destPath, destOriginal); await sutProvider.Sut.StartShareAttachmentAsync(cipher.Id, cipher.OrganizationId.Value, attachmentData); - Assert.True(File.Exists(sourcePath)); - Assert.Equal(source, File.ReadAllText(sourcePath)); - Assert.False(File.Exists(destPath)); - Assert.False(File.Exists(rollBackPath)); - } - } - - - [Theory] - [InlineCustomAutoData(new[] { typeof(OrganizationCipher), typeof(MetaData) })] - [InlineCustomAutoData(new[] { typeof(OrganizationCipher), typeof(MetaDataWithoutContainer) })] - [InlineCustomAutoData(new[] { typeof(OrganizationCipher), typeof(MetaDataWithoutKey) })] - public async Task StartShareAttachmentAsync_Success(string source, string destOriginal, Cipher cipher, CipherAttachment.MetaData attachmentData) - { - using (var tempDirectory = new TempDirectory()) - { - await StartShareAttachmentAsync(source, destOriginal, cipher, attachmentData, tempDirectory); - } - } - - [Theory] - [InlineCustomAutoData(new[] { typeof(OrganizationCipher), typeof(MetaData) })] - [InlineCustomAutoData(new[] { typeof(OrganizationCipher), typeof(MetaDataWithoutContainer) })] - [InlineCustomAutoData(new[] { typeof(OrganizationCipher), typeof(MetaDataWithoutKey) })] - public async Task RollbackShareAttachmentAsync_Success(string source, string destOriginal, Cipher cipher, CipherAttachment.MetaData attachmentData) - { - using (var tempDirectory = new TempDirectory()) - { - var sutProvider = GetSutProvider(tempDirectory); - - var sourcePath = $"{tempDirectory}/temp/{cipher.Id}/{cipher.OrganizationId}/{attachmentData.AttachmentId}"; - var destPath = $"{tempDirectory}/{cipher.Id}/{attachmentData.AttachmentId}"; - var rollBackPath = $"{tempDirectory}/temp/{cipher.Id}/{attachmentData.AttachmentId}"; - - await StartShareAttachmentAsync(source, destOriginal, cipher, attachmentData, tempDirectory); - await sutProvider.Sut.RollbackShareAttachmentAsync(cipher.Id, cipher.OrganizationId.Value, attachmentData, "Not Used Here"); - - Assert.True(File.Exists(destPath)); - Assert.Equal(destOriginal, File.ReadAllText(destPath)); Assert.False(File.Exists(sourcePath)); - Assert.False(File.Exists(rollBackPath)); + Assert.True(File.Exists(destPath)); + Assert.Equal(source, File.ReadAllText(destPath)); + Assert.True(File.Exists(rollBackPath)); + Assert.Equal(destOriginal, File.ReadAllText(rollBackPath)); } - } - [Theory] - [InlineCustomAutoData(new[] { typeof(UserCipher), typeof(MetaData) })] - [InlineCustomAutoData(new[] { typeof(UserCipher), typeof(MetaDataWithoutContainer) })] - [InlineCustomAutoData(new[] { typeof(UserCipher), typeof(MetaDataWithoutKey) })] - public async Task DeleteAttachmentAsync_Success(Cipher cipher, CipherAttachment.MetaData attachmentData) - { - using (var tempDirectory = new TempDirectory()) + private SutProvider GetSutProvider(TempDirectory tempDirectory) { - var sutProvider = GetSutProvider(tempDirectory); + var fixture = new Fixture().WithAutoNSubstitutions(); + fixture.Freeze().Attachment.BaseDirectory.Returns(tempDirectory.Directory); + fixture.Freeze().Attachment.BaseUrl.Returns(Guid.NewGuid().ToString()); - var expectedPath = $"{tempDirectory}/{cipher.Id}/{attachmentData.AttachmentId}"; - Directory.CreateDirectory(Path.GetDirectoryName(expectedPath)); - File.Create(expectedPath).Close(); - - await sutProvider.Sut.DeleteAttachmentAsync(cipher.Id, attachmentData); - - Assert.False(File.Exists(expectedPath)); + return new SutProvider(fixture).Create(); } } - - [Theory] - [InlineUserCipherAutoData] - [InlineOrganizationCipherAutoData] - public async Task CleanupAsync_Succes(Cipher cipher) - { - using (var tempDirectory = new TempDirectory()) - { - var sutProvider = GetSutProvider(tempDirectory); - - var tempPath = $"{tempDirectory}/temp/{cipher.Id}"; - var permPath = $"{tempDirectory}/{cipher.Id}"; - Directory.CreateDirectory(tempPath); - Directory.CreateDirectory(permPath); - - await sutProvider.Sut.CleanupAsync(cipher.Id); - - Assert.False(Directory.Exists(tempPath)); - Assert.True(Directory.Exists(permPath)); - } - } - - [Theory] - [InlineUserCipherAutoData] - [InlineOrganizationCipherAutoData] - public async Task DeleteAttachmentsForCipherAsync_Succes(Cipher cipher) - { - using (var tempDirectory = new TempDirectory()) - { - var sutProvider = GetSutProvider(tempDirectory); - - var tempPath = $"{tempDirectory}/temp/{cipher.Id}"; - var permPath = $"{tempDirectory}/{cipher.Id}"; - Directory.CreateDirectory(tempPath); - Directory.CreateDirectory(permPath); - - await sutProvider.Sut.DeleteAttachmentsForCipherAsync(cipher.Id); - - Assert.True(Directory.Exists(tempPath)); - Assert.False(Directory.Exists(permPath)); - } - } - - private async Task StartShareAttachmentAsync(string source, string destOriginal, Cipher cipher, - CipherAttachment.MetaData attachmentData, TempDirectory tempDirectory) - { - var sutProvider = GetSutProvider(tempDirectory); - - var sourcePath = $"{tempDirectory}/temp/{cipher.Id}/{cipher.OrganizationId}/{attachmentData.AttachmentId}"; - var destPath = $"{tempDirectory}/{cipher.Id}/{attachmentData.AttachmentId}"; - var rollBackPath = $"{tempDirectory}/temp/{cipher.Id}/{attachmentData.AttachmentId}"; - Directory.CreateDirectory(Path.GetDirectoryName(sourcePath)); - Directory.CreateDirectory(Path.GetDirectoryName(destPath)); - File.WriteAllText(sourcePath, source); - File.WriteAllText(destPath, destOriginal); - - await sutProvider.Sut.StartShareAttachmentAsync(cipher.Id, cipher.OrganizationId.Value, attachmentData); - - Assert.False(File.Exists(sourcePath)); - Assert.True(File.Exists(destPath)); - Assert.Equal(source, File.ReadAllText(destPath)); - Assert.True(File.Exists(rollBackPath)); - Assert.Equal(destOriginal, File.ReadAllText(rollBackPath)); - } - - private SutProvider GetSutProvider(TempDirectory tempDirectory) - { - var fixture = new Fixture().WithAutoNSubstitutions(); - fixture.Freeze().Attachment.BaseDirectory.Returns(tempDirectory.Directory); - fixture.Freeze().Attachment.BaseUrl.Returns(Guid.NewGuid().ToString()); - - return new SutProvider(fixture).Create(); - } } diff --git a/test/Core.Test/Services/MailKitSmtpMailDeliveryServiceTests.cs b/test/Core.Test/Services/MailKitSmtpMailDeliveryServiceTests.cs index 4e7e36fe02..d4c5208a27 100644 --- a/test/Core.Test/Services/MailKitSmtpMailDeliveryServiceTests.cs +++ b/test/Core.Test/Services/MailKitSmtpMailDeliveryServiceTests.cs @@ -4,34 +4,35 @@ using Microsoft.Extensions.Logging; using NSubstitute; using Xunit; -namespace Bit.Core.Test.Services; - -public class MailKitSmtpMailDeliveryServiceTests +namespace Bit.Core.Test.Services { - private readonly MailKitSmtpMailDeliveryService _sut; - - private readonly GlobalSettings _globalSettings; - private readonly ILogger _logger; - - public MailKitSmtpMailDeliveryServiceTests() + public class MailKitSmtpMailDeliveryServiceTests { - _globalSettings = new GlobalSettings(); - _logger = Substitute.For>(); + private readonly MailKitSmtpMailDeliveryService _sut; - _globalSettings.Mail.Smtp.Host = "unittests.example.com"; - _globalSettings.Mail.ReplyToEmail = "noreply@unittests.example.com"; + private readonly GlobalSettings _globalSettings; + private readonly ILogger _logger; - _sut = new MailKitSmtpMailDeliveryService( - _globalSettings, - _logger - ); - } + public MailKitSmtpMailDeliveryServiceTests() + { + _globalSettings = new GlobalSettings(); + _logger = Substitute.For>(); - // Remove this test when we add actual tests. It only proves that - // we've properly constructed the system under test. - [Fact] - public void ServiceExists() - { - Assert.NotNull(_sut); + _globalSettings.Mail.Smtp.Host = "unittests.example.com"; + _globalSettings.Mail.ReplyToEmail = "noreply@unittests.example.com"; + + _sut = new MailKitSmtpMailDeliveryService( + _globalSettings, + _logger + ); + } + + // Remove this test when we add actual tests. It only proves that + // we've properly constructed the system under test. + [Fact] + public void ServiceExists() + { + Assert.NotNull(_sut); + } } } diff --git a/test/Core.Test/Services/MultiServicePushNotificationServiceTests.cs b/test/Core.Test/Services/MultiServicePushNotificationServiceTests.cs index b1876f1dda..925456619e 100644 --- a/test/Core.Test/Services/MultiServicePushNotificationServiceTests.cs +++ b/test/Core.Test/Services/MultiServicePushNotificationServiceTests.cs @@ -6,49 +6,50 @@ using Microsoft.Extensions.Logging; using NSubstitute; using Xunit; -namespace Bit.Core.Test.Services; - -public class MultiServicePushNotificationServiceTests +namespace Bit.Core.Test.Services { - private readonly MultiServicePushNotificationService _sut; - - private readonly IHttpClientFactory _httpFactory; - private readonly IDeviceRepository _deviceRepository; - private readonly IInstallationDeviceRepository _installationDeviceRepository; - private readonly GlobalSettings _globalSettings; - private readonly IHttpContextAccessor _httpContextAccessor; - private readonly ILogger _logger; - private readonly ILogger _relayLogger; - private readonly ILogger _hubLogger; - - public MultiServicePushNotificationServiceTests() + public class MultiServicePushNotificationServiceTests { - _httpFactory = Substitute.For(); - _deviceRepository = Substitute.For(); - _installationDeviceRepository = Substitute.For(); - _globalSettings = new GlobalSettings(); - _httpContextAccessor = Substitute.For(); - _logger = Substitute.For>(); - _relayLogger = Substitute.For>(); - _hubLogger = Substitute.For>(); + private readonly MultiServicePushNotificationService _sut; - _sut = new MultiServicePushNotificationService( - _httpFactory, - _deviceRepository, - _installationDeviceRepository, - _globalSettings, - _httpContextAccessor, - _logger, - _relayLogger, - _hubLogger - ); - } + private readonly IHttpClientFactory _httpFactory; + private readonly IDeviceRepository _deviceRepository; + private readonly IInstallationDeviceRepository _installationDeviceRepository; + private readonly GlobalSettings _globalSettings; + private readonly IHttpContextAccessor _httpContextAccessor; + private readonly ILogger _logger; + private readonly ILogger _relayLogger; + private readonly ILogger _hubLogger; - // Remove this test when we add actual tests. It only proves that - // we've properly constructed the system under test. - [Fact] - public void ServiceExists() - { - Assert.NotNull(_sut); + public MultiServicePushNotificationServiceTests() + { + _httpFactory = Substitute.For(); + _deviceRepository = Substitute.For(); + _installationDeviceRepository = Substitute.For(); + _globalSettings = new GlobalSettings(); + _httpContextAccessor = Substitute.For(); + _logger = Substitute.For>(); + _relayLogger = Substitute.For>(); + _hubLogger = Substitute.For>(); + + _sut = new MultiServicePushNotificationService( + _httpFactory, + _deviceRepository, + _installationDeviceRepository, + _globalSettings, + _httpContextAccessor, + _logger, + _relayLogger, + _hubLogger + ); + } + + // Remove this test when we add actual tests. It only proves that + // we've properly constructed the system under test. + [Fact] + public void ServiceExists() + { + Assert.NotNull(_sut); + } } } diff --git a/test/Core.Test/Services/NotificationHubPushNotificationServiceTests.cs b/test/Core.Test/Services/NotificationHubPushNotificationServiceTests.cs index a066eee8b8..ea59da3ed7 100644 --- a/test/Core.Test/Services/NotificationHubPushNotificationServiceTests.cs +++ b/test/Core.Test/Services/NotificationHubPushNotificationServiceTests.cs @@ -5,34 +5,35 @@ using Microsoft.AspNetCore.Http; using NSubstitute; using Xunit; -namespace Bit.Core.Test.Services; - -public class NotificationHubPushNotificationServiceTests +namespace Bit.Core.Test.Services { - private readonly NotificationHubPushNotificationService _sut; - - private readonly IInstallationDeviceRepository _installationDeviceRepository; - private readonly GlobalSettings _globalSettings; - private readonly IHttpContextAccessor _httpContextAccessor; - - public NotificationHubPushNotificationServiceTests() + public class NotificationHubPushNotificationServiceTests { - _installationDeviceRepository = Substitute.For(); - _globalSettings = new GlobalSettings(); - _httpContextAccessor = Substitute.For(); + private readonly NotificationHubPushNotificationService _sut; - _sut = new NotificationHubPushNotificationService( - _installationDeviceRepository, - _globalSettings, - _httpContextAccessor - ); - } + private readonly IInstallationDeviceRepository _installationDeviceRepository; + private readonly GlobalSettings _globalSettings; + private readonly IHttpContextAccessor _httpContextAccessor; - // Remove this test when we add actual tests. It only proves that - // we've properly constructed the system under test. - [Fact(Skip = "Needs additional work")] - public void ServiceExists() - { - Assert.NotNull(_sut); + public NotificationHubPushNotificationServiceTests() + { + _installationDeviceRepository = Substitute.For(); + _globalSettings = new GlobalSettings(); + _httpContextAccessor = Substitute.For(); + + _sut = new NotificationHubPushNotificationService( + _installationDeviceRepository, + _globalSettings, + _httpContextAccessor + ); + } + + // Remove this test when we add actual tests. It only proves that + // we've properly constructed the system under test. + [Fact(Skip = "Needs additional work")] + public void ServiceExists() + { + Assert.NotNull(_sut); + } } } diff --git a/test/Core.Test/Services/NotificationHubPushRegistrationServiceTests.cs b/test/Core.Test/Services/NotificationHubPushRegistrationServiceTests.cs index 8e2a19d7b9..432a796865 100644 --- a/test/Core.Test/Services/NotificationHubPushRegistrationServiceTests.cs +++ b/test/Core.Test/Services/NotificationHubPushRegistrationServiceTests.cs @@ -4,31 +4,32 @@ using Bit.Core.Settings; using NSubstitute; using Xunit; -namespace Bit.Core.Test.Services; - -public class NotificationHubPushRegistrationServiceTests +namespace Bit.Core.Test.Services { - private readonly NotificationHubPushRegistrationService _sut; - - private readonly IInstallationDeviceRepository _installationDeviceRepository; - private readonly GlobalSettings _globalSettings; - - public NotificationHubPushRegistrationServiceTests() + public class NotificationHubPushRegistrationServiceTests { - _installationDeviceRepository = Substitute.For(); - _globalSettings = new GlobalSettings(); + private readonly NotificationHubPushRegistrationService _sut; - _sut = new NotificationHubPushRegistrationService( - _installationDeviceRepository, - _globalSettings - ); - } + private readonly IInstallationDeviceRepository _installationDeviceRepository; + private readonly GlobalSettings _globalSettings; - // Remove this test when we add actual tests. It only proves that - // we've properly constructed the system under test. - [Fact(Skip = "Needs additional work")] - public void ServiceExists() - { - Assert.NotNull(_sut); + public NotificationHubPushRegistrationServiceTests() + { + _installationDeviceRepository = Substitute.For(); + _globalSettings = new GlobalSettings(); + + _sut = new NotificationHubPushRegistrationService( + _installationDeviceRepository, + _globalSettings + ); + } + + // Remove this test when we add actual tests. It only proves that + // we've properly constructed the system under test. + [Fact(Skip = "Needs additional work")] + public void ServiceExists() + { + Assert.NotNull(_sut); + } } } diff --git a/test/Core.Test/Services/NotificationsApiPushNotificationServiceTests.cs b/test/Core.Test/Services/NotificationsApiPushNotificationServiceTests.cs index d1ba15d6a5..59976f57d4 100644 --- a/test/Core.Test/Services/NotificationsApiPushNotificationServiceTests.cs +++ b/test/Core.Test/Services/NotificationsApiPushNotificationServiceTests.cs @@ -5,37 +5,38 @@ using Microsoft.Extensions.Logging; using NSubstitute; using Xunit; -namespace Bit.Core.Test.Services; - -public class NotificationsApiPushNotificationServiceTests +namespace Bit.Core.Test.Services { - private readonly NotificationsApiPushNotificationService _sut; - - private readonly IHttpClientFactory _httpFactory; - private readonly GlobalSettings _globalSettings; - private readonly IHttpContextAccessor _httpContextAccessor; - private readonly ILogger _logger; - - public NotificationsApiPushNotificationServiceTests() + public class NotificationsApiPushNotificationServiceTests { - _httpFactory = Substitute.For(); - _globalSettings = new GlobalSettings(); - _httpContextAccessor = Substitute.For(); - _logger = Substitute.For>(); + private readonly NotificationsApiPushNotificationService _sut; - _sut = new NotificationsApiPushNotificationService( - _httpFactory, - _globalSettings, - _httpContextAccessor, - _logger - ); - } + private readonly IHttpClientFactory _httpFactory; + private readonly GlobalSettings _globalSettings; + private readonly IHttpContextAccessor _httpContextAccessor; + private readonly ILogger _logger; - // Remove this test when we add actual tests. It only proves that - // we've properly constructed the system under test. - [Fact(Skip = "Needs additional work")] - public void ServiceExists() - { - Assert.NotNull(_sut); + public NotificationsApiPushNotificationServiceTests() + { + _httpFactory = Substitute.For(); + _globalSettings = new GlobalSettings(); + _httpContextAccessor = Substitute.For(); + _logger = Substitute.For>(); + + _sut = new NotificationsApiPushNotificationService( + _httpFactory, + _globalSettings, + _httpContextAccessor, + _logger + ); + } + + // Remove this test when we add actual tests. It only proves that + // we've properly constructed the system under test. + [Fact(Skip = "Needs additional work")] + public void ServiceExists() + { + Assert.NotNull(_sut); + } } } diff --git a/test/Core.Test/Services/OrganizationServiceTests.cs b/test/Core.Test/Services/OrganizationServiceTests.cs index 73f0b8cfeb..acf2b7f170 100644 --- a/test/Core.Test/Services/OrganizationServiceTests.cs +++ b/test/Core.Test/Services/OrganizationServiceTests.cs @@ -20,935 +20,936 @@ using Organization = Bit.Core.Entities.Organization; using OrganizationUser = Bit.Core.Entities.OrganizationUser; using Policy = Bit.Core.Entities.Policy; -namespace Bit.Core.Test.Services; - -public class OrganizationServiceTests +namespace Bit.Core.Test.Services { - // [Fact] - [Theory, PaidOrganizationAutoData] - public async Task OrgImportCreateNewUsers(SutProvider sutProvider, Guid userId, - Organization org, List existingUsers, List newUsers) + public class OrganizationServiceTests { - org.UseDirectory = true; - org.Seats = 10; - newUsers.Add(new ImportedOrganizationUser + // [Fact] + [Theory, PaidOrganizationAutoData] + public async Task OrgImportCreateNewUsers(SutProvider sutProvider, Guid userId, + Organization org, List existingUsers, List newUsers) { - Email = existingUsers.First().Email, - ExternalId = existingUsers.First().ExternalId - }); - var expectedNewUsersCount = newUsers.Count - 1; - - existingUsers.First().Type = OrganizationUserType.Owner; - - sutProvider.GetDependency().GetByIdAsync(org.Id).Returns(org); - sutProvider.GetDependency().GetManyDetailsByOrganizationAsync(org.Id) - .Returns(existingUsers); - sutProvider.GetDependency().GetCountByOrganizationIdAsync(org.Id) - .Returns(existingUsers.Count); - sutProvider.GetDependency().GetManyByOrganizationAsync(org.Id, OrganizationUserType.Owner) - .Returns(existingUsers.Select(u => new OrganizationUser { Status = OrganizationUserStatusType.Confirmed, Type = OrganizationUserType.Owner, Id = u.Id }).ToList()); - sutProvider.GetDependency().ManageUsers(org.Id).Returns(true); - - await sutProvider.Sut.ImportAsync(org.Id, userId, null, newUsers, null, false); - - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() - .UpsertAsync(default); - await sutProvider.GetDependency().Received(1) - .UpsertManyAsync(Arg.Is>(users => users.Count() == 0)); - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() - .CreateAsync(default); - - // Create new users - await sutProvider.GetDependency().Received(1) - .CreateManyAsync(Arg.Is>(users => users.Count() == expectedNewUsersCount)); - await sutProvider.GetDependency().Received(1) - .BulkSendOrganizationInviteEmailAsync(org.Name, - Arg.Is>(messages => messages.Count() == expectedNewUsersCount)); - - // Send events - await sutProvider.GetDependency().Received(1) - .LogOrganizationUserEventsAsync(Arg.Is>(events => - events.Count() == expectedNewUsersCount)); - await sutProvider.GetDependency().Received(1) - .RaiseEventAsync(Arg.Is(referenceEvent => - referenceEvent.Type == ReferenceEventType.InvitedUsers && referenceEvent.Id == org.Id && - referenceEvent.Users == expectedNewUsersCount)); - } - - [Theory, PaidOrganizationAutoData] - public async Task OrgImportCreateNewUsersAndMarryExistingUser(SutProvider sutProvider, - Guid userId, Organization org, List existingUsers, - List newUsers) - { - org.UseDirectory = true; - org.Seats = newUsers.Count + existingUsers.Count + 1; - var reInvitedUser = existingUsers.First(); - reInvitedUser.ExternalId = null; - newUsers.Add(new ImportedOrganizationUser - { - Email = reInvitedUser.Email, - ExternalId = reInvitedUser.Email, - }); - var expectedNewUsersCount = newUsers.Count - 1; - - sutProvider.GetDependency().GetByIdAsync(org.Id).Returns(org); - sutProvider.GetDependency().GetManyDetailsByOrganizationAsync(org.Id) - .Returns(existingUsers); - sutProvider.GetDependency().GetCountByOrganizationIdAsync(org.Id) - .Returns(existingUsers.Count); - sutProvider.GetDependency().GetByIdAsync(reInvitedUser.Id) - .Returns(new OrganizationUser { Id = reInvitedUser.Id }); - sutProvider.GetDependency().GetManyByOrganizationAsync(org.Id, OrganizationUserType.Owner) - .Returns(existingUsers.Select(u => new OrganizationUser { Status = OrganizationUserStatusType.Confirmed, Type = OrganizationUserType.Owner, Id = u.Id }).ToList()); - var currentContext = sutProvider.GetDependency(); - currentContext.ManageUsers(org.Id).Returns(true); - - await sutProvider.Sut.ImportAsync(org.Id, userId, null, newUsers, null, false); - - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() - .UpsertAsync(default); - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() - .CreateAsync(default); - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() - .CreateAsync(default, default); - - // Upserted existing user - await sutProvider.GetDependency().Received(1) - .UpsertManyAsync(Arg.Is>(users => users.Count() == 1)); - - // Created and invited new users - await sutProvider.GetDependency().Received(1) - .CreateManyAsync(Arg.Is>(users => users.Count() == expectedNewUsersCount)); - await sutProvider.GetDependency().Received(1) - .BulkSendOrganizationInviteEmailAsync(org.Name, - Arg.Is>(messages => messages.Count() == expectedNewUsersCount)); - - // Sent events - await sutProvider.GetDependency().Received(1) - .LogOrganizationUserEventsAsync(Arg.Is>(events => - events.Where(e => e.Item2 == EventType.OrganizationUser_Invited).Count() == expectedNewUsersCount)); - await sutProvider.GetDependency().Received(1) - .RaiseEventAsync(Arg.Is(referenceEvent => - referenceEvent.Type == ReferenceEventType.InvitedUsers && referenceEvent.Id == org.Id && - referenceEvent.Users == expectedNewUsersCount)); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task UpgradePlan_OrganizationIsNull_Throws(Guid organizationId, OrganizationUpgrade upgrade, - SutProvider sutProvider) - { - sutProvider.GetDependency().GetByIdAsync(organizationId).Returns(Task.FromResult(null)); - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.UpgradePlanAsync(organizationId, upgrade)); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task UpgradePlan_GatewayCustomIdIsNull_Throws(Organization organization, OrganizationUpgrade upgrade, - SutProvider sutProvider) - { - organization.GatewayCustomerId = string.Empty; - sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.UpgradePlanAsync(organization.Id, upgrade)); - Assert.Contains("no payment method", exception.Message); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task UpgradePlan_AlreadyInPlan_Throws(Organization organization, OrganizationUpgrade upgrade, - SutProvider sutProvider) - { - upgrade.Plan = organization.PlanType; - sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.UpgradePlanAsync(organization.Id, upgrade)); - Assert.Contains("already on this plan", exception.Message); - } - - [Theory, PaidOrganizationAutoData] - public async Task UpgradePlan_UpgradeFromPaidPlan_Throws(Organization organization, OrganizationUpgrade upgrade, - SutProvider sutProvider) - { - sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.UpgradePlanAsync(organization.Id, upgrade)); - Assert.Contains("can only upgrade", exception.Message); - } - - [Theory] - [FreeOrganizationUpgradeAutoData] - public async Task UpgradePlan_Passes(Organization organization, OrganizationUpgrade upgrade, - SutProvider sutProvider) - { - sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); - await sutProvider.Sut.UpgradePlanAsync(organization.Id, upgrade); - await sutProvider.GetDependency().Received(1).ReplaceAsync(organization); - } - - [Theory] - [OrganizationInviteAutoData] - public async Task InviteUser_NoEmails_Throws(Organization organization, OrganizationUser invitor, - OrganizationUserInvite invite, SutProvider sutProvider) - { - invite.Emails = null; - sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); - await Assert.ThrowsAsync( - () => sutProvider.Sut.InviteUsersAsync(organization.Id, invitor.UserId, new (OrganizationUserInvite, string)[] { (invite, null) })); - } - - [Theory] - [OrganizationInviteAutoData] - public async Task InviteUser_DuplicateEmails_PassesWithoutDuplicates(Organization organization, OrganizationUser invitor, - [OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.Owner)] OrganizationUser owner, - OrganizationUserInvite invite, SutProvider sutProvider) - { - invite.Emails = invite.Emails.Append(invite.Emails.First()); - - sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); - sutProvider.GetDependency().OrganizationOwner(organization.Id).Returns(true); - sutProvider.GetDependency().ManageUsers(organization.Id).Returns(true); - var organizationUserRepository = sutProvider.GetDependency(); - organizationUserRepository.GetManyByOrganizationAsync(organization.Id, OrganizationUserType.Owner) - .Returns(new[] { owner }); - - await sutProvider.Sut.InviteUsersAsync(organization.Id, invitor.UserId, new (OrganizationUserInvite, string)[] { (invite, null) }); - - await sutProvider.GetDependency().Received(1) - .BulkSendOrganizationInviteEmailAsync(organization.Name, - Arg.Is>(v => v.Count() == invite.Emails.Distinct().Count())); - } - - [Theory] - [OrganizationInviteAutoData( - inviteeUserType: (int)OrganizationUserType.Admin, - invitorUserType: (int)OrganizationUserType.Owner - )] - public async Task InviteUser_NoOwner_Throws(Organization organization, OrganizationUser invitor, - OrganizationUserInvite invite, SutProvider sutProvider) - { - sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); - sutProvider.GetDependency().OrganizationOwner(organization.Id).Returns(true); - sutProvider.GetDependency().ManageUsers(organization.Id).Returns(true); - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.InviteUsersAsync(organization.Id, invitor.UserId, new (OrganizationUserInvite, string)[] { (invite, null) })); - Assert.Contains("Organization must have at least one confirmed owner.", exception.Message); - } - - [Theory] - [OrganizationInviteAutoData( - inviteeUserType: (int)OrganizationUserType.Owner, - invitorUserType: (int)OrganizationUserType.Admin - )] - public async Task InviteUser_NonOwnerConfiguringOwner_Throws(Organization organization, OrganizationUserInvite invite, - OrganizationUser invitor, SutProvider sutProvider) - { - var organizationRepository = sutProvider.GetDependency(); - var currentContext = sutProvider.GetDependency(); - - organizationRepository.GetByIdAsync(organization.Id).Returns(organization); - currentContext.OrganizationAdmin(organization.Id).Returns(true); - - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.InviteUsersAsync(organization.Id, invitor.UserId, new (OrganizationUserInvite, string)[] { (invite, null) })); - Assert.Contains("only an owner", exception.Message.ToLowerInvariant()); - } - - [Theory] - [OrganizationInviteAutoData( - inviteeUserType: (int)OrganizationUserType.Custom, - invitorUserType: (int)OrganizationUserType.User - )] - public async Task InviteUser_NonAdminConfiguringAdmin_Throws(Organization organization, OrganizationUserInvite invite, - OrganizationUser invitor, SutProvider sutProvider) - { - var organizationRepository = sutProvider.GetDependency(); - var currentContext = sutProvider.GetDependency(); - - organizationRepository.GetByIdAsync(organization.Id).Returns(organization); - currentContext.OrganizationUser(organization.Id).Returns(true); - - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.InviteUsersAsync(organization.Id, invitor.UserId, new (OrganizationUserInvite, string)[] { (invite, null) })); - Assert.Contains("only owners and admins", exception.Message.ToLowerInvariant()); - } - - [Theory] - [OrganizationInviteAutoData( - inviteeUserType: (int)OrganizationUserType.Manager, - invitorUserType: (int)OrganizationUserType.Custom - )] - public async Task InviteUser_CustomUserWithoutManageUsersConfiguringUser_Throws(Organization organization, OrganizationUserInvite invite, - OrganizationUser invitor, SutProvider sutProvider) - { - invitor.Permissions = JsonSerializer.Serialize(new Permissions() { ManageUsers = false }, - new JsonSerializerOptions + org.UseDirectory = true; + org.Seats = 10; + newUsers.Add(new ImportedOrganizationUser { - PropertyNamingPolicy = JsonNamingPolicy.CamelCase, + Email = existingUsers.First().Email, + ExternalId = existingUsers.First().ExternalId }); - - var organizationRepository = sutProvider.GetDependency(); - var currentContext = sutProvider.GetDependency(); - - organizationRepository.GetByIdAsync(organization.Id).Returns(organization); - currentContext.OrganizationCustom(organization.Id).Returns(true); - currentContext.ManageUsers(organization.Id).Returns(false); - - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.InviteUsersAsync(organization.Id, invitor.UserId, new (OrganizationUserInvite, string)[] { (invite, null) })); - Assert.Contains("account does not have permission", exception.Message.ToLowerInvariant()); - } - - [Theory] - [OrganizationInviteAutoData( - inviteeUserType: (int)OrganizationUserType.Admin, - invitorUserType: (int)OrganizationUserType.Custom - )] - public async Task InviteUser_CustomUserConfiguringAdmin_Throws(Organization organization, OrganizationUserInvite invite, - OrganizationUser invitor, SutProvider sutProvider) - { - invitor.Permissions = JsonSerializer.Serialize(new Permissions() { ManageUsers = true }, - new JsonSerializerOptions - { - PropertyNamingPolicy = JsonNamingPolicy.CamelCase, - }); - - var organizationRepository = sutProvider.GetDependency(); - var currentContext = sutProvider.GetDependency(); - - organizationRepository.GetByIdAsync(organization.Id).Returns(organization); - currentContext.OrganizationCustom(organization.Id).Returns(true); - currentContext.ManageUsers(organization.Id).Returns(true); - - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.InviteUsersAsync(organization.Id, invitor.UserId, new (OrganizationUserInvite, string)[] { (invite, null) })); - Assert.Contains("can not manage admins", exception.Message.ToLowerInvariant()); - } - - [Theory] - [OrganizationInviteAutoData( - inviteeUserType: (int)OrganizationUserType.User, - invitorUserType: (int)OrganizationUserType.Owner - )] - public async Task InviteUser_NoPermissionsObject_Passes(Organization organization, OrganizationUserInvite invite, - OrganizationUser invitor, SutProvider sutProvider) - { - invite.Permissions = null; - invitor.Status = OrganizationUserStatusType.Confirmed; - var organizationRepository = sutProvider.GetDependency(); - var organizationUserRepository = sutProvider.GetDependency(); - var currentContext = sutProvider.GetDependency(); - - organizationRepository.GetByIdAsync(organization.Id).Returns(organization); - organizationUserRepository.GetManyByOrganizationAsync(organization.Id, OrganizationUserType.Owner) - .Returns(new[] { invitor }); - currentContext.OrganizationOwner(organization.Id).Returns(true); - currentContext.ManageUsers(organization.Id).Returns(true); - - await sutProvider.Sut.InviteUsersAsync(organization.Id, invitor.UserId, new (OrganizationUserInvite, string)[] { (invite, null) }); - } - - [Theory] - [OrganizationInviteAutoData( - inviteeUserType: (int)OrganizationUserType.User, - invitorUserType: (int)OrganizationUserType.Custom - )] - public async Task InviteUser_Passes(Organization organization, IEnumerable<(OrganizationUserInvite invite, string externalId)> invites, - OrganizationUser invitor, - [OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.Owner)] OrganizationUser owner, - SutProvider sutProvider) - { - // Autofixture will add collections for all of the invites, remove the first and for all the rest set all access false - invites.First().invite.AccessAll = true; - invites.First().invite.Collections = null; - invites.Skip(1).ToList().ForEach(i => i.invite.AccessAll = false); - - invitor.Permissions = JsonSerializer.Serialize(new Permissions() { ManageUsers = true }, - new JsonSerializerOptions - { - PropertyNamingPolicy = JsonNamingPolicy.CamelCase, - }); - - var organizationRepository = sutProvider.GetDependency(); - var organizationUserRepository = sutProvider.GetDependency(); - var currentContext = sutProvider.GetDependency(); - - organizationRepository.GetByIdAsync(organization.Id).Returns(organization); - organizationUserRepository.GetManyByOrganizationAsync(organization.Id, OrganizationUserType.Owner) - .Returns(new[] { owner }); - currentContext.ManageUsers(organization.Id).Returns(true); - - await sutProvider.Sut.InviteUsersAsync(organization.Id, invitor.UserId, invites); - - await sutProvider.GetDependency().Received(1) - .BulkSendOrganizationInviteEmailAsync(organization.Name, - Arg.Is>(v => v.Count() == invites.SelectMany(i => i.invite.Emails).Count())); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task SaveUser_NoUserId_Throws(OrganizationUser user, Guid? savingUserId, - IEnumerable collections, SutProvider sutProvider) - { - user.Id = default(Guid); - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.SaveUserAsync(user, savingUserId, collections)); - Assert.Contains("invite the user first", exception.Message.ToLowerInvariant()); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task SaveUser_NoChangeToData_Throws(OrganizationUser user, Guid? savingUserId, - IEnumerable collections, SutProvider sutProvider) - { - var organizationUserRepository = sutProvider.GetDependency(); - organizationUserRepository.GetByIdAsync(user.Id).Returns(user); - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.SaveUserAsync(user, savingUserId, collections)); - Assert.Contains("make changes before saving", exception.Message.ToLowerInvariant()); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task SaveUser_Passes( - OrganizationUser oldUserData, - OrganizationUser newUserData, - IEnumerable collections, - [OrganizationUser(type: OrganizationUserType.Owner)] OrganizationUser savingUser, - SutProvider sutProvider) - { - var organizationUserRepository = sutProvider.GetDependency(); - var currentContext = sutProvider.GetDependency(); - - newUserData.Id = oldUserData.Id; - newUserData.UserId = oldUserData.UserId; - newUserData.OrganizationId = savingUser.OrganizationId = oldUserData.OrganizationId; - organizationUserRepository.GetByIdAsync(oldUserData.Id).Returns(oldUserData); - organizationUserRepository.GetManyByOrganizationAsync(savingUser.OrganizationId, OrganizationUserType.Owner) - .Returns(new List { savingUser }); - currentContext.OrganizationOwner(savingUser.OrganizationId).Returns(true); - - await sutProvider.Sut.SaveUserAsync(newUserData, savingUser.UserId, collections); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task DeleteUser_InvalidUser(OrganizationUser organizationUser, OrganizationUser deletingUser, - SutProvider sutProvider) - { - var organizationUserRepository = sutProvider.GetDependency(); - - organizationUserRepository.GetByIdAsync(organizationUser.Id).Returns(organizationUser); - - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.DeleteUserAsync(Guid.NewGuid(), organizationUser.Id, deletingUser.UserId)); - Assert.Contains("User not valid.", exception.Message); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task DeleteUser_RemoveYourself(OrganizationUser deletingUser, SutProvider sutProvider) - { - var organizationUserRepository = sutProvider.GetDependency(); - - organizationUserRepository.GetByIdAsync(deletingUser.Id).Returns(deletingUser); - - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.DeleteUserAsync(deletingUser.OrganizationId, deletingUser.Id, deletingUser.UserId)); - Assert.Contains("You cannot remove yourself.", exception.Message); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task DeleteUser_NonOwnerRemoveOwner( - [OrganizationUser(type: OrganizationUserType.Owner)] OrganizationUser organizationUser, - [OrganizationUser(type: OrganizationUserType.Admin)] OrganizationUser deletingUser, - SutProvider sutProvider) - { - var organizationUserRepository = sutProvider.GetDependency(); - var currentContext = sutProvider.GetDependency(); - - organizationUser.OrganizationId = deletingUser.OrganizationId; - organizationUserRepository.GetByIdAsync(organizationUser.Id).Returns(organizationUser); - currentContext.OrganizationAdmin(deletingUser.OrganizationId).Returns(true); - - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.DeleteUserAsync(deletingUser.OrganizationId, organizationUser.Id, deletingUser.UserId)); - Assert.Contains("Only owners can delete other owners.", exception.Message); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task DeleteUser_LastOwner( - [OrganizationUser(type: OrganizationUserType.Owner)] OrganizationUser organizationUser, - OrganizationUser deletingUser, - SutProvider sutProvider) - { - var organizationUserRepository = sutProvider.GetDependency(); - - organizationUser.OrganizationId = deletingUser.OrganizationId; - organizationUserRepository.GetByIdAsync(organizationUser.Id).Returns(organizationUser); - organizationUserRepository.GetManyByOrganizationAsync(deletingUser.OrganizationId, OrganizationUserType.Owner) - .Returns(new[] { organizationUser }); - - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.DeleteUserAsync(deletingUser.OrganizationId, organizationUser.Id, null)); - Assert.Contains("Organization must have at least one confirmed owner.", exception.Message); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task DeleteUser_Success( - OrganizationUser organizationUser, - [OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.Owner)] OrganizationUser deletingUser, - SutProvider sutProvider) - { - var organizationUserRepository = sutProvider.GetDependency(); - var currentContext = sutProvider.GetDependency(); - - organizationUser.OrganizationId = deletingUser.OrganizationId; - organizationUserRepository.GetByIdAsync(organizationUser.Id).Returns(organizationUser); - organizationUserRepository.GetByIdAsync(deletingUser.Id).Returns(deletingUser); - organizationUserRepository.GetManyByOrganizationAsync(deletingUser.OrganizationId, OrganizationUserType.Owner) - .Returns(new[] { deletingUser, organizationUser }); - currentContext.OrganizationOwner(deletingUser.OrganizationId).Returns(true); - - await sutProvider.Sut.DeleteUserAsync(deletingUser.OrganizationId, organizationUser.Id, deletingUser.UserId); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task DeleteUsers_FilterInvalid(OrganizationUser organizationUser, OrganizationUser deletingUser, - SutProvider sutProvider) - { - var organizationUserRepository = sutProvider.GetDependency(); - var organizationUsers = new[] { organizationUser }; - var organizationUserIds = organizationUsers.Select(u => u.Id); - organizationUserRepository.GetManyAsync(default).ReturnsForAnyArgs(organizationUsers); - - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.DeleteUsersAsync(deletingUser.OrganizationId, organizationUserIds, deletingUser.UserId)); - Assert.Contains("Users invalid.", exception.Message); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task DeleteUsers_RemoveYourself( - [OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.Owner)] OrganizationUser orgUser, - OrganizationUser deletingUser, - SutProvider sutProvider) - { - var organizationUserRepository = sutProvider.GetDependency(); - var organizationUsers = new[] { deletingUser }; - var organizationUserIds = organizationUsers.Select(u => u.Id); - organizationUserRepository.GetManyAsync(default).ReturnsForAnyArgs(organizationUsers); - organizationUserRepository.GetManyByOrganizationAsync(default, default).ReturnsForAnyArgs(new[] { orgUser }); - - var result = await sutProvider.Sut.DeleteUsersAsync(deletingUser.OrganizationId, organizationUserIds, deletingUser.UserId); - Assert.Contains("You cannot remove yourself.", result[0].Item2); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task DeleteUsers_NonOwnerRemoveOwner( - [OrganizationUser(type: OrganizationUserType.Admin)] OrganizationUser deletingUser, - [OrganizationUser(type: OrganizationUserType.Owner)] OrganizationUser orgUser1, - [OrganizationUser(OrganizationUserStatusType.Confirmed)] OrganizationUser orgUser2, - SutProvider sutProvider) - { - var organizationUserRepository = sutProvider.GetDependency(); - - orgUser1.OrganizationId = orgUser2.OrganizationId = deletingUser.OrganizationId; - var organizationUsers = new[] { orgUser1 }; - var organizationUserIds = organizationUsers.Select(u => u.Id); - organizationUserRepository.GetManyAsync(default).ReturnsForAnyArgs(organizationUsers); - organizationUserRepository.GetManyByOrganizationAsync(default, default).ReturnsForAnyArgs(new[] { orgUser2 }); - - var result = await sutProvider.Sut.DeleteUsersAsync(deletingUser.OrganizationId, organizationUserIds, deletingUser.UserId); - Assert.Contains("Only owners can delete other owners.", result[0].Item2); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task DeleteUsers_LastOwner( - [OrganizationUser(status: OrganizationUserStatusType.Confirmed, OrganizationUserType.Owner)] OrganizationUser orgUser, - SutProvider sutProvider) - { - var organizationUserRepository = sutProvider.GetDependency(); - - var organizationUsers = new[] { orgUser }; - var organizationUserIds = organizationUsers.Select(u => u.Id); - organizationUserRepository.GetManyAsync(default).ReturnsForAnyArgs(organizationUsers); - organizationUserRepository.GetManyByOrganizationAsync(orgUser.OrganizationId, OrganizationUserType.Owner).Returns(organizationUsers); - - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.DeleteUsersAsync(orgUser.OrganizationId, organizationUserIds, null)); - Assert.Contains("Organization must have at least one confirmed owner.", exception.Message); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task DeleteUsers_Success( - [OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.Owner)] OrganizationUser deletingUser, - [OrganizationUser(type: OrganizationUserType.Owner)] OrganizationUser orgUser1, OrganizationUser orgUser2, - SutProvider sutProvider) - { - var organizationUserRepository = sutProvider.GetDependency(); - var currentContext = sutProvider.GetDependency(); - - orgUser1.OrganizationId = orgUser2.OrganizationId = deletingUser.OrganizationId; - var organizationUsers = new[] { orgUser1, orgUser2 }; - var organizationUserIds = organizationUsers.Select(u => u.Id); - organizationUserRepository.GetManyAsync(default).ReturnsForAnyArgs(organizationUsers); - organizationUserRepository.GetByIdAsync(deletingUser.Id).Returns(deletingUser); - organizationUserRepository.GetManyByOrganizationAsync(deletingUser.OrganizationId, OrganizationUserType.Owner) - .Returns(new[] { deletingUser, orgUser1 }); - currentContext.OrganizationOwner(deletingUser.OrganizationId).Returns(true); - - await sutProvider.Sut.DeleteUsersAsync(deletingUser.OrganizationId, organizationUserIds, deletingUser.UserId); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task ConfirmUser_InvalidStatus(OrganizationUser confirmingUser, - [OrganizationUser(OrganizationUserStatusType.Invited)] OrganizationUser orgUser, string key, - SutProvider sutProvider) - { - var organizationUserRepository = sutProvider.GetDependency(); - var userService = Substitute.For(); - - organizationUserRepository.GetByIdAsync(orgUser.Id).Returns(orgUser); - - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.ConfirmUserAsync(orgUser.OrganizationId, orgUser.Id, key, confirmingUser.Id, userService)); - Assert.Contains("User not valid.", exception.Message); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task ConfirmUser_WrongOrganization(OrganizationUser confirmingUser, - [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser orgUser, string key, - SutProvider sutProvider) - { - var organizationUserRepository = sutProvider.GetDependency(); - var userService = Substitute.For(); - - organizationUserRepository.GetByIdAsync(orgUser.Id).Returns(orgUser); - - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.ConfirmUserAsync(confirmingUser.OrganizationId, orgUser.Id, key, confirmingUser.Id, userService)); - Assert.Contains("User not valid.", exception.Message); - } - - [Theory] - [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, OrganizationUserType.Admin)] - [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, OrganizationUserType.Owner)] - public async Task ConfirmUserToFree_AlreadyFreeAdminOrOwner_Throws(OrganizationUserType userType, Organization org, OrganizationUser confirmingUser, - [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser orgUser, User user, - string key, SutProvider sutProvider) - { - var organizationUserRepository = sutProvider.GetDependency(); - var organizationRepository = sutProvider.GetDependency(); - var userService = Substitute.For(); - var userRepository = sutProvider.GetDependency(); - - org.PlanType = PlanType.Free; - orgUser.OrganizationId = confirmingUser.OrganizationId = org.Id; - orgUser.UserId = user.Id; - orgUser.Type = userType; - organizationUserRepository.GetManyAsync(default).ReturnsForAnyArgs(new[] { orgUser }); - organizationUserRepository.GetCountByFreeOrganizationAdminUserAsync(orgUser.UserId.Value).Returns(1); - organizationRepository.GetByIdAsync(org.Id).Returns(org); - userRepository.GetManyAsync(default).ReturnsForAnyArgs(new[] { user }); - - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.ConfirmUserAsync(orgUser.OrganizationId, orgUser.Id, key, confirmingUser.Id, userService)); - Assert.Contains("User can only be an admin of one free organization.", exception.Message); - } - - [Theory] - [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PlanType.Custom, OrganizationUserType.Admin)] - [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PlanType.Custom, OrganizationUserType.Owner)] - [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PlanType.EnterpriseAnnually, OrganizationUserType.Admin)] - [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PlanType.EnterpriseAnnually, OrganizationUserType.Owner)] - [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PlanType.EnterpriseAnnually2019, OrganizationUserType.Admin)] - [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PlanType.EnterpriseAnnually2019, OrganizationUserType.Owner)] - [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PlanType.EnterpriseMonthly, OrganizationUserType.Admin)] - [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PlanType.EnterpriseMonthly, OrganizationUserType.Owner)] - [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PlanType.EnterpriseMonthly2019, OrganizationUserType.Admin)] - [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PlanType.EnterpriseMonthly2019, OrganizationUserType.Owner)] - [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PlanType.FamiliesAnnually, OrganizationUserType.Admin)] - [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PlanType.FamiliesAnnually, OrganizationUserType.Owner)] - [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PlanType.FamiliesAnnually2019, OrganizationUserType.Admin)] - [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PlanType.FamiliesAnnually2019, OrganizationUserType.Owner)] - [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PlanType.TeamsAnnually, OrganizationUserType.Admin)] - [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PlanType.TeamsAnnually, OrganizationUserType.Owner)] - [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PlanType.TeamsAnnually2019, OrganizationUserType.Admin)] - [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PlanType.TeamsAnnually2019, OrganizationUserType.Owner)] - [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PlanType.TeamsMonthly, OrganizationUserType.Admin)] - [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PlanType.TeamsMonthly, OrganizationUserType.Owner)] - [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PlanType.TeamsMonthly2019, OrganizationUserType.Admin)] - [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PlanType.TeamsMonthly2019, OrganizationUserType.Owner)] - public async Task ConfirmUserToNonFree_AlreadyFreeAdminOrOwner_DoesNotThrow(PlanType planType, OrganizationUserType orgUserType, Organization org, OrganizationUser confirmingUser, - [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser orgUser, User user, - string key, SutProvider sutProvider) - { - var organizationUserRepository = sutProvider.GetDependency(); - var organizationRepository = sutProvider.GetDependency(); - var userService = Substitute.For(); - var userRepository = sutProvider.GetDependency(); - - org.PlanType = planType; - orgUser.OrganizationId = confirmingUser.OrganizationId = org.Id; - orgUser.UserId = user.Id; - orgUser.Type = orgUserType; - organizationUserRepository.GetManyAsync(default).ReturnsForAnyArgs(new[] { orgUser }); - organizationUserRepository.GetCountByFreeOrganizationAdminUserAsync(orgUser.UserId.Value).Returns(1); - organizationRepository.GetByIdAsync(org.Id).Returns(org); - userRepository.GetManyAsync(default).ReturnsForAnyArgs(new[] { user }); - - await sutProvider.Sut.ConfirmUserAsync(orgUser.OrganizationId, orgUser.Id, key, confirmingUser.Id, userService); - - await sutProvider.GetDependency().Received(1).LogOrganizationUserEventAsync(orgUser, EventType.OrganizationUser_Confirmed); - await sutProvider.GetDependency().Received(1).SendOrganizationConfirmedEmailAsync(org.Name, user.Email); - await organizationUserRepository.Received(1).ReplaceManyAsync(Arg.Is>(users => users.Contains(orgUser) && users.Count == 1)); - } - - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task ConfirmUser_SingleOrgPolicy(Organization org, OrganizationUser confirmingUser, - [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser orgUser, User user, - OrganizationUser orgUserAnotherOrg, [Policy(PolicyType.SingleOrg)] Policy singleOrgPolicy, - string key, SutProvider sutProvider) - { - var organizationUserRepository = sutProvider.GetDependency(); - var organizationRepository = sutProvider.GetDependency(); - var policyRepository = sutProvider.GetDependency(); - var userRepository = sutProvider.GetDependency(); - var userService = Substitute.For(); - - org.PlanType = PlanType.EnterpriseAnnually; - orgUser.Status = OrganizationUserStatusType.Accepted; - orgUser.OrganizationId = confirmingUser.OrganizationId = org.Id; - orgUser.UserId = orgUserAnotherOrg.UserId = user.Id; - organizationUserRepository.GetManyAsync(default).ReturnsForAnyArgs(new[] { orgUser }); - organizationUserRepository.GetManyByManyUsersAsync(default).ReturnsForAnyArgs(new[] { orgUserAnotherOrg }); - organizationRepository.GetByIdAsync(org.Id).Returns(org); - userRepository.GetManyAsync(default).ReturnsForAnyArgs(new[] { user }); - policyRepository.GetManyByOrganizationIdAsync(org.Id).Returns(new[] { singleOrgPolicy }); - - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.ConfirmUserAsync(orgUser.OrganizationId, orgUser.Id, key, confirmingUser.Id, userService)); - Assert.Contains("User is a member of another organization.", exception.Message); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task ConfirmUser_TwoFactorPolicy(Organization org, OrganizationUser confirmingUser, - [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser orgUser, User user, - OrganizationUser orgUserAnotherOrg, [Policy(PolicyType.TwoFactorAuthentication)] Policy twoFactorPolicy, - string key, SutProvider sutProvider) - { - var organizationUserRepository = sutProvider.GetDependency(); - var organizationRepository = sutProvider.GetDependency(); - var policyRepository = sutProvider.GetDependency(); - var userRepository = sutProvider.GetDependency(); - var userService = Substitute.For(); - - org.PlanType = PlanType.EnterpriseAnnually; - orgUser.OrganizationId = confirmingUser.OrganizationId = org.Id; - orgUser.UserId = orgUserAnotherOrg.UserId = user.Id; - organizationUserRepository.GetManyAsync(default).ReturnsForAnyArgs(new[] { orgUser }); - organizationUserRepository.GetManyByManyUsersAsync(default).ReturnsForAnyArgs(new[] { orgUserAnotherOrg }); - organizationRepository.GetByIdAsync(org.Id).Returns(org); - userRepository.GetManyAsync(default).ReturnsForAnyArgs(new[] { user }); - policyRepository.GetManyByOrganizationIdAsync(org.Id).Returns(new[] { twoFactorPolicy }); - - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.ConfirmUserAsync(orgUser.OrganizationId, orgUser.Id, key, confirmingUser.Id, userService)); - Assert.Contains("User does not have two-step login enabled.", exception.Message); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task ConfirmUser_Success(Organization org, OrganizationUser confirmingUser, - [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser orgUser, User user, - [Policy(PolicyType.TwoFactorAuthentication)] Policy twoFactorPolicy, - [Policy(PolicyType.SingleOrg)] Policy singleOrgPolicy, string key, SutProvider sutProvider) - { - var organizationUserRepository = sutProvider.GetDependency(); - var organizationRepository = sutProvider.GetDependency(); - var policyRepository = sutProvider.GetDependency(); - var userRepository = sutProvider.GetDependency(); - var userService = Substitute.For(); - - org.PlanType = PlanType.EnterpriseAnnually; - orgUser.OrganizationId = confirmingUser.OrganizationId = org.Id; - orgUser.UserId = user.Id; - organizationUserRepository.GetManyAsync(default).ReturnsForAnyArgs(new[] { orgUser }); - organizationRepository.GetByIdAsync(org.Id).Returns(org); - userRepository.GetManyAsync(default).ReturnsForAnyArgs(new[] { user }); - policyRepository.GetManyByOrganizationIdAsync(org.Id).Returns(new[] { twoFactorPolicy, singleOrgPolicy }); - userService.TwoFactorIsEnabledAsync(user).Returns(true); - - await sutProvider.Sut.ConfirmUserAsync(orgUser.OrganizationId, orgUser.Id, key, confirmingUser.Id, userService); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task ConfirmUsers_Success(Organization org, - OrganizationUser confirmingUser, - [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser orgUser1, - [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser orgUser2, - [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser orgUser3, - OrganizationUser anotherOrgUser, User user1, User user2, User user3, - [Policy(PolicyType.TwoFactorAuthentication)] Policy twoFactorPolicy, - [Policy(PolicyType.SingleOrg)] Policy singleOrgPolicy, string key, SutProvider sutProvider) - { - var organizationUserRepository = sutProvider.GetDependency(); - var organizationRepository = sutProvider.GetDependency(); - var policyRepository = sutProvider.GetDependency(); - var userRepository = sutProvider.GetDependency(); - var userService = Substitute.For(); - - org.PlanType = PlanType.EnterpriseAnnually; - orgUser1.OrganizationId = orgUser2.OrganizationId = orgUser3.OrganizationId = confirmingUser.OrganizationId = org.Id; - orgUser1.UserId = user1.Id; - orgUser2.UserId = user2.Id; - orgUser3.UserId = user3.Id; - anotherOrgUser.UserId = user3.Id; - var orgUsers = new[] { orgUser1, orgUser2, orgUser3 }; - organizationUserRepository.GetManyAsync(default).ReturnsForAnyArgs(orgUsers); - organizationRepository.GetByIdAsync(org.Id).Returns(org); - userRepository.GetManyAsync(default).ReturnsForAnyArgs(new[] { user1, user2, user3 }); - policyRepository.GetManyByOrganizationIdAsync(org.Id).Returns(new[] { twoFactorPolicy, singleOrgPolicy }); - userService.TwoFactorIsEnabledAsync(user1).Returns(true); - userService.TwoFactorIsEnabledAsync(user2).Returns(false); - userService.TwoFactorIsEnabledAsync(user3).Returns(true); - organizationUserRepository.GetManyByManyUsersAsync(default) - .ReturnsForAnyArgs(new[] { orgUser1, orgUser2, orgUser3, anotherOrgUser }); - - var keys = orgUsers.ToDictionary(ou => ou.Id, _ => key); - var result = await sutProvider.Sut.ConfirmUsersAsync(confirmingUser.OrganizationId, keys, confirmingUser.Id, userService); - Assert.Contains("", result[0].Item2); - Assert.Contains("User does not have two-step login enabled.", result[1].Item2); - Assert.Contains("User is a member of another organization.", result[2].Item2); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task UpdateOrganizationKeysAsync_WithoutManageResetPassword_Throws(Guid orgId, string publicKey, - string privateKey, SutProvider sutProvider) - { - var currentContext = Substitute.For(); - currentContext.ManageResetPassword(orgId).Returns(false); - - await Assert.ThrowsAsync( - () => sutProvider.Sut.UpdateOrganizationKeysAsync(orgId, publicKey, privateKey)); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task UpdateOrganizationKeysAsync_KeysAlreadySet_Throws(Organization org, string publicKey, - string privateKey, SutProvider sutProvider) - { - var currentContext = sutProvider.GetDependency(); - currentContext.ManageResetPassword(org.Id).Returns(true); - - var organizationRepository = sutProvider.GetDependency(); - organizationRepository.GetByIdAsync(org.Id).Returns(org); - - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.UpdateOrganizationKeysAsync(org.Id, publicKey, privateKey)); - Assert.Contains("Organization Keys already exist", exception.Message); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task UpdateOrganizationKeysAsync_KeysAlreadySet_Success(Organization org, string publicKey, - string privateKey, SutProvider sutProvider) - { - org.PublicKey = null; - org.PrivateKey = null; - - var currentContext = sutProvider.GetDependency(); - currentContext.ManageResetPassword(org.Id).Returns(true); - - var organizationRepository = sutProvider.GetDependency(); - organizationRepository.GetByIdAsync(org.Id).Returns(org); - - await sutProvider.Sut.UpdateOrganizationKeysAsync(org.Id, publicKey, privateKey); - } - - [Theory] - [InlinePaidOrganizationAutoData(PlanType.EnterpriseAnnually, new object[] { "Cannot set max seat autoscaling below seat count", 1, 0, 2 })] - [InlinePaidOrganizationAutoData(PlanType.EnterpriseAnnually, new object[] { "Cannot set max seat autoscaling below seat count", 4, -1, 6 })] - [InlineFreeOrganizationAutoData("Your plan does not allow seat autoscaling", 10, 0, null)] - public async Task UpdateSubscription_BadInputThrows(string expectedMessage, - int? maxAutoscaleSeats, int seatAdjustment, int? currentSeats, Organization organization, SutProvider sutProvider) - { - organization.Seats = currentSeats; - sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); - - var exception = await Assert.ThrowsAsync(() => sutProvider.Sut.UpdateSubscription(organization.Id, - seatAdjustment, maxAutoscaleSeats)); - - Assert.Contains(expectedMessage, exception.Message); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task UpdateSubscription_NoOrganization_Throws(Guid organizationId, SutProvider sutProvider) - { - sutProvider.GetDependency().GetByIdAsync(organizationId).Returns((Organization)null); - - await Assert.ThrowsAsync(() => sutProvider.Sut.UpdateSubscription(organizationId, 0, null)); - } - - [Theory] - [InlinePaidOrganizationAutoData(0, 100, null, true, "")] - [InlinePaidOrganizationAutoData(0, 100, 100, true, "")] - [InlinePaidOrganizationAutoData(0, null, 100, true, "")] - [InlinePaidOrganizationAutoData(1, 100, null, true, "")] - [InlinePaidOrganizationAutoData(1, 100, 100, false, "Cannot invite new users. Seat limit has been reached")] - public void CanScale(int seatsToAdd, int? currentSeats, int? maxAutoscaleSeats, - bool expectedResult, string expectedFailureMessage, Organization organization, - SutProvider sutProvider) - { - organization.Seats = currentSeats; - organization.MaxAutoscaleSeats = maxAutoscaleSeats; - sutProvider.GetDependency().ManageUsers(organization.Id).Returns(true); - - var (result, failureMessage) = sutProvider.Sut.CanScale(organization, seatsToAdd); - - if (expectedFailureMessage == string.Empty) - { - Assert.Empty(failureMessage); + var expectedNewUsersCount = newUsers.Count - 1; + + existingUsers.First().Type = OrganizationUserType.Owner; + + sutProvider.GetDependency().GetByIdAsync(org.Id).Returns(org); + sutProvider.GetDependency().GetManyDetailsByOrganizationAsync(org.Id) + .Returns(existingUsers); + sutProvider.GetDependency().GetCountByOrganizationIdAsync(org.Id) + .Returns(existingUsers.Count); + sutProvider.GetDependency().GetManyByOrganizationAsync(org.Id, OrganizationUserType.Owner) + .Returns(existingUsers.Select(u => new OrganizationUser { Status = OrganizationUserStatusType.Confirmed, Type = OrganizationUserType.Owner, Id = u.Id }).ToList()); + sutProvider.GetDependency().ManageUsers(org.Id).Returns(true); + + await sutProvider.Sut.ImportAsync(org.Id, userId, null, newUsers, null, false); + + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() + .UpsertAsync(default); + await sutProvider.GetDependency().Received(1) + .UpsertManyAsync(Arg.Is>(users => users.Count() == 0)); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() + .CreateAsync(default); + + // Create new users + await sutProvider.GetDependency().Received(1) + .CreateManyAsync(Arg.Is>(users => users.Count() == expectedNewUsersCount)); + await sutProvider.GetDependency().Received(1) + .BulkSendOrganizationInviteEmailAsync(org.Name, + Arg.Is>(messages => messages.Count() == expectedNewUsersCount)); + + // Send events + await sutProvider.GetDependency().Received(1) + .LogOrganizationUserEventsAsync(Arg.Is>(events => + events.Count() == expectedNewUsersCount)); + await sutProvider.GetDependency().Received(1) + .RaiseEventAsync(Arg.Is(referenceEvent => + referenceEvent.Type == ReferenceEventType.InvitedUsers && referenceEvent.Id == org.Id && + referenceEvent.Users == expectedNewUsersCount)); } - else + + [Theory, PaidOrganizationAutoData] + public async Task OrgImportCreateNewUsersAndMarryExistingUser(SutProvider sutProvider, + Guid userId, Organization org, List existingUsers, + List newUsers) { - Assert.Contains(expectedFailureMessage, failureMessage); + org.UseDirectory = true; + org.Seats = newUsers.Count + existingUsers.Count + 1; + var reInvitedUser = existingUsers.First(); + reInvitedUser.ExternalId = null; + newUsers.Add(new ImportedOrganizationUser + { + Email = reInvitedUser.Email, + ExternalId = reInvitedUser.Email, + }); + var expectedNewUsersCount = newUsers.Count - 1; + + sutProvider.GetDependency().GetByIdAsync(org.Id).Returns(org); + sutProvider.GetDependency().GetManyDetailsByOrganizationAsync(org.Id) + .Returns(existingUsers); + sutProvider.GetDependency().GetCountByOrganizationIdAsync(org.Id) + .Returns(existingUsers.Count); + sutProvider.GetDependency().GetByIdAsync(reInvitedUser.Id) + .Returns(new OrganizationUser { Id = reInvitedUser.Id }); + sutProvider.GetDependency().GetManyByOrganizationAsync(org.Id, OrganizationUserType.Owner) + .Returns(existingUsers.Select(u => new OrganizationUser { Status = OrganizationUserStatusType.Confirmed, Type = OrganizationUserType.Owner, Id = u.Id }).ToList()); + var currentContext = sutProvider.GetDependency(); + currentContext.ManageUsers(org.Id).Returns(true); + + await sutProvider.Sut.ImportAsync(org.Id, userId, null, newUsers, null, false); + + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() + .UpsertAsync(default); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() + .CreateAsync(default); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() + .CreateAsync(default, default); + + // Upserted existing user + await sutProvider.GetDependency().Received(1) + .UpsertManyAsync(Arg.Is>(users => users.Count() == 1)); + + // Created and invited new users + await sutProvider.GetDependency().Received(1) + .CreateManyAsync(Arg.Is>(users => users.Count() == expectedNewUsersCount)); + await sutProvider.GetDependency().Received(1) + .BulkSendOrganizationInviteEmailAsync(org.Name, + Arg.Is>(messages => messages.Count() == expectedNewUsersCount)); + + // Sent events + await sutProvider.GetDependency().Received(1) + .LogOrganizationUserEventsAsync(Arg.Is>(events => + events.Where(e => e.Item2 == EventType.OrganizationUser_Invited).Count() == expectedNewUsersCount)); + await sutProvider.GetDependency().Received(1) + .RaiseEventAsync(Arg.Is(referenceEvent => + referenceEvent.Type == ReferenceEventType.InvitedUsers && referenceEvent.Id == org.Id && + referenceEvent.Users == expectedNewUsersCount)); } - Assert.Equal(expectedResult, result); - } - [Theory, PaidOrganizationAutoData] - public void CanScale_FailsOnSelfHosted(Organization organization, - SutProvider sutProvider) - { - sutProvider.GetDependency().SelfHosted.Returns(true); - var (result, failureMessage) = sutProvider.Sut.CanScale(organization, 10); + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task UpgradePlan_OrganizationIsNull_Throws(Guid organizationId, OrganizationUpgrade upgrade, + SutProvider sutProvider) + { + sutProvider.GetDependency().GetByIdAsync(organizationId).Returns(Task.FromResult(null)); + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.UpgradePlanAsync(organizationId, upgrade)); + } - Assert.False(result); - Assert.Contains("Cannot autoscale on self-hosted instance", failureMessage); - } + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task UpgradePlan_GatewayCustomIdIsNull_Throws(Organization organization, OrganizationUpgrade upgrade, + SutProvider sutProvider) + { + organization.GatewayCustomerId = string.Empty; + sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.UpgradePlanAsync(organization.Id, upgrade)); + Assert.Contains("no payment method", exception.Message); + } - [Theory, PaidOrganizationAutoData] - public async Task Delete_Success(Organization organization, SutProvider sutProvider) - { - var organizationRepository = sutProvider.GetDependency(); - var applicationCacheService = sutProvider.GetDependency(); + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task UpgradePlan_AlreadyInPlan_Throws(Organization organization, OrganizationUpgrade upgrade, + SutProvider sutProvider) + { + upgrade.Plan = organization.PlanType; + sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.UpgradePlanAsync(organization.Id, upgrade)); + Assert.Contains("already on this plan", exception.Message); + } - await sutProvider.Sut.DeleteAsync(organization); + [Theory, PaidOrganizationAutoData] + public async Task UpgradePlan_UpgradeFromPaidPlan_Throws(Organization organization, OrganizationUpgrade upgrade, + SutProvider sutProvider) + { + sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.UpgradePlanAsync(organization.Id, upgrade)); + Assert.Contains("can only upgrade", exception.Message); + } - await organizationRepository.Received().DeleteAsync(organization); - await applicationCacheService.Received().DeleteOrganizationAbilityAsync(organization.Id); - } + [Theory] + [FreeOrganizationUpgradeAutoData] + public async Task UpgradePlan_Passes(Organization organization, OrganizationUpgrade upgrade, + SutProvider sutProvider) + { + sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); + await sutProvider.Sut.UpgradePlanAsync(organization.Id, upgrade); + await sutProvider.GetDependency().Received(1).ReplaceAsync(organization); + } - [Theory, PaidOrganizationAutoData] - public async Task Delete_Fails_KeyConnector(Organization organization, SutProvider sutProvider, - SsoConfig ssoConfig) - { - ssoConfig.Enabled = true; - ssoConfig.SetData(new SsoConfigurationData { KeyConnectorEnabled = true }); - var ssoConfigRepository = sutProvider.GetDependency(); - var organizationRepository = sutProvider.GetDependency(); - var applicationCacheService = sutProvider.GetDependency(); + [Theory] + [OrganizationInviteAutoData] + public async Task InviteUser_NoEmails_Throws(Organization organization, OrganizationUser invitor, + OrganizationUserInvite invite, SutProvider sutProvider) + { + invite.Emails = null; + sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); + await Assert.ThrowsAsync( + () => sutProvider.Sut.InviteUsersAsync(organization.Id, invitor.UserId, new (OrganizationUserInvite, string)[] { (invite, null) })); + } - ssoConfigRepository.GetByOrganizationIdAsync(organization.Id).Returns(ssoConfig); + [Theory] + [OrganizationInviteAutoData] + public async Task InviteUser_DuplicateEmails_PassesWithoutDuplicates(Organization organization, OrganizationUser invitor, + [OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.Owner)] OrganizationUser owner, + OrganizationUserInvite invite, SutProvider sutProvider) + { + invite.Emails = invite.Emails.Append(invite.Emails.First()); - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.DeleteAsync(organization)); + sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); + sutProvider.GetDependency().OrganizationOwner(organization.Id).Returns(true); + sutProvider.GetDependency().ManageUsers(organization.Id).Returns(true); + var organizationUserRepository = sutProvider.GetDependency(); + organizationUserRepository.GetManyByOrganizationAsync(organization.Id, OrganizationUserType.Owner) + .Returns(new[] { owner }); - Assert.Contains("You cannot delete an Organization that is using Key Connector.", exception.Message); + await sutProvider.Sut.InviteUsersAsync(organization.Id, invitor.UserId, new (OrganizationUserInvite, string)[] { (invite, null) }); - await organizationRepository.DidNotReceiveWithAnyArgs().DeleteAsync(default); - await applicationCacheService.DidNotReceiveWithAnyArgs().DeleteOrganizationAbilityAsync(default); + await sutProvider.GetDependency().Received(1) + .BulkSendOrganizationInviteEmailAsync(organization.Name, + Arg.Is>(v => v.Count() == invite.Emails.Distinct().Count())); + } + + [Theory] + [OrganizationInviteAutoData( + inviteeUserType: (int)OrganizationUserType.Admin, + invitorUserType: (int)OrganizationUserType.Owner + )] + public async Task InviteUser_NoOwner_Throws(Organization organization, OrganizationUser invitor, + OrganizationUserInvite invite, SutProvider sutProvider) + { + sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); + sutProvider.GetDependency().OrganizationOwner(organization.Id).Returns(true); + sutProvider.GetDependency().ManageUsers(organization.Id).Returns(true); + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.InviteUsersAsync(organization.Id, invitor.UserId, new (OrganizationUserInvite, string)[] { (invite, null) })); + Assert.Contains("Organization must have at least one confirmed owner.", exception.Message); + } + + [Theory] + [OrganizationInviteAutoData( + inviteeUserType: (int)OrganizationUserType.Owner, + invitorUserType: (int)OrganizationUserType.Admin + )] + public async Task InviteUser_NonOwnerConfiguringOwner_Throws(Organization organization, OrganizationUserInvite invite, + OrganizationUser invitor, SutProvider sutProvider) + { + var organizationRepository = sutProvider.GetDependency(); + var currentContext = sutProvider.GetDependency(); + + organizationRepository.GetByIdAsync(organization.Id).Returns(organization); + currentContext.OrganizationAdmin(organization.Id).Returns(true); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.InviteUsersAsync(organization.Id, invitor.UserId, new (OrganizationUserInvite, string)[] { (invite, null) })); + Assert.Contains("only an owner", exception.Message.ToLowerInvariant()); + } + + [Theory] + [OrganizationInviteAutoData( + inviteeUserType: (int)OrganizationUserType.Custom, + invitorUserType: (int)OrganizationUserType.User + )] + public async Task InviteUser_NonAdminConfiguringAdmin_Throws(Organization organization, OrganizationUserInvite invite, + OrganizationUser invitor, SutProvider sutProvider) + { + var organizationRepository = sutProvider.GetDependency(); + var currentContext = sutProvider.GetDependency(); + + organizationRepository.GetByIdAsync(organization.Id).Returns(organization); + currentContext.OrganizationUser(organization.Id).Returns(true); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.InviteUsersAsync(organization.Id, invitor.UserId, new (OrganizationUserInvite, string)[] { (invite, null) })); + Assert.Contains("only owners and admins", exception.Message.ToLowerInvariant()); + } + + [Theory] + [OrganizationInviteAutoData( + inviteeUserType: (int)OrganizationUserType.Manager, + invitorUserType: (int)OrganizationUserType.Custom + )] + public async Task InviteUser_CustomUserWithoutManageUsersConfiguringUser_Throws(Organization organization, OrganizationUserInvite invite, + OrganizationUser invitor, SutProvider sutProvider) + { + invitor.Permissions = JsonSerializer.Serialize(new Permissions() { ManageUsers = false }, + new JsonSerializerOptions + { + PropertyNamingPolicy = JsonNamingPolicy.CamelCase, + }); + + var organizationRepository = sutProvider.GetDependency(); + var currentContext = sutProvider.GetDependency(); + + organizationRepository.GetByIdAsync(organization.Id).Returns(organization); + currentContext.OrganizationCustom(organization.Id).Returns(true); + currentContext.ManageUsers(organization.Id).Returns(false); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.InviteUsersAsync(organization.Id, invitor.UserId, new (OrganizationUserInvite, string)[] { (invite, null) })); + Assert.Contains("account does not have permission", exception.Message.ToLowerInvariant()); + } + + [Theory] + [OrganizationInviteAutoData( + inviteeUserType: (int)OrganizationUserType.Admin, + invitorUserType: (int)OrganizationUserType.Custom + )] + public async Task InviteUser_CustomUserConfiguringAdmin_Throws(Organization organization, OrganizationUserInvite invite, + OrganizationUser invitor, SutProvider sutProvider) + { + invitor.Permissions = JsonSerializer.Serialize(new Permissions() { ManageUsers = true }, + new JsonSerializerOptions + { + PropertyNamingPolicy = JsonNamingPolicy.CamelCase, + }); + + var organizationRepository = sutProvider.GetDependency(); + var currentContext = sutProvider.GetDependency(); + + organizationRepository.GetByIdAsync(organization.Id).Returns(organization); + currentContext.OrganizationCustom(organization.Id).Returns(true); + currentContext.ManageUsers(organization.Id).Returns(true); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.InviteUsersAsync(organization.Id, invitor.UserId, new (OrganizationUserInvite, string)[] { (invite, null) })); + Assert.Contains("can not manage admins", exception.Message.ToLowerInvariant()); + } + + [Theory] + [OrganizationInviteAutoData( + inviteeUserType: (int)OrganizationUserType.User, + invitorUserType: (int)OrganizationUserType.Owner + )] + public async Task InviteUser_NoPermissionsObject_Passes(Organization organization, OrganizationUserInvite invite, + OrganizationUser invitor, SutProvider sutProvider) + { + invite.Permissions = null; + invitor.Status = OrganizationUserStatusType.Confirmed; + var organizationRepository = sutProvider.GetDependency(); + var organizationUserRepository = sutProvider.GetDependency(); + var currentContext = sutProvider.GetDependency(); + + organizationRepository.GetByIdAsync(organization.Id).Returns(organization); + organizationUserRepository.GetManyByOrganizationAsync(organization.Id, OrganizationUserType.Owner) + .Returns(new[] { invitor }); + currentContext.OrganizationOwner(organization.Id).Returns(true); + currentContext.ManageUsers(organization.Id).Returns(true); + + await sutProvider.Sut.InviteUsersAsync(organization.Id, invitor.UserId, new (OrganizationUserInvite, string)[] { (invite, null) }); + } + + [Theory] + [OrganizationInviteAutoData( + inviteeUserType: (int)OrganizationUserType.User, + invitorUserType: (int)OrganizationUserType.Custom + )] + public async Task InviteUser_Passes(Organization organization, IEnumerable<(OrganizationUserInvite invite, string externalId)> invites, + OrganizationUser invitor, + [OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.Owner)] OrganizationUser owner, + SutProvider sutProvider) + { + // Autofixture will add collections for all of the invites, remove the first and for all the rest set all access false + invites.First().invite.AccessAll = true; + invites.First().invite.Collections = null; + invites.Skip(1).ToList().ForEach(i => i.invite.AccessAll = false); + + invitor.Permissions = JsonSerializer.Serialize(new Permissions() { ManageUsers = true }, + new JsonSerializerOptions + { + PropertyNamingPolicy = JsonNamingPolicy.CamelCase, + }); + + var organizationRepository = sutProvider.GetDependency(); + var organizationUserRepository = sutProvider.GetDependency(); + var currentContext = sutProvider.GetDependency(); + + organizationRepository.GetByIdAsync(organization.Id).Returns(organization); + organizationUserRepository.GetManyByOrganizationAsync(organization.Id, OrganizationUserType.Owner) + .Returns(new[] { owner }); + currentContext.ManageUsers(organization.Id).Returns(true); + + await sutProvider.Sut.InviteUsersAsync(organization.Id, invitor.UserId, invites); + + await sutProvider.GetDependency().Received(1) + .BulkSendOrganizationInviteEmailAsync(organization.Name, + Arg.Is>(v => v.Count() == invites.SelectMany(i => i.invite.Emails).Count())); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task SaveUser_NoUserId_Throws(OrganizationUser user, Guid? savingUserId, + IEnumerable collections, SutProvider sutProvider) + { + user.Id = default(Guid); + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.SaveUserAsync(user, savingUserId, collections)); + Assert.Contains("invite the user first", exception.Message.ToLowerInvariant()); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task SaveUser_NoChangeToData_Throws(OrganizationUser user, Guid? savingUserId, + IEnumerable collections, SutProvider sutProvider) + { + var organizationUserRepository = sutProvider.GetDependency(); + organizationUserRepository.GetByIdAsync(user.Id).Returns(user); + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.SaveUserAsync(user, savingUserId, collections)); + Assert.Contains("make changes before saving", exception.Message.ToLowerInvariant()); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task SaveUser_Passes( + OrganizationUser oldUserData, + OrganizationUser newUserData, + IEnumerable collections, + [OrganizationUser(type: OrganizationUserType.Owner)] OrganizationUser savingUser, + SutProvider sutProvider) + { + var organizationUserRepository = sutProvider.GetDependency(); + var currentContext = sutProvider.GetDependency(); + + newUserData.Id = oldUserData.Id; + newUserData.UserId = oldUserData.UserId; + newUserData.OrganizationId = savingUser.OrganizationId = oldUserData.OrganizationId; + organizationUserRepository.GetByIdAsync(oldUserData.Id).Returns(oldUserData); + organizationUserRepository.GetManyByOrganizationAsync(savingUser.OrganizationId, OrganizationUserType.Owner) + .Returns(new List { savingUser }); + currentContext.OrganizationOwner(savingUser.OrganizationId).Returns(true); + + await sutProvider.Sut.SaveUserAsync(newUserData, savingUser.UserId, collections); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task DeleteUser_InvalidUser(OrganizationUser organizationUser, OrganizationUser deletingUser, + SutProvider sutProvider) + { + var organizationUserRepository = sutProvider.GetDependency(); + + organizationUserRepository.GetByIdAsync(organizationUser.Id).Returns(organizationUser); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.DeleteUserAsync(Guid.NewGuid(), organizationUser.Id, deletingUser.UserId)); + Assert.Contains("User not valid.", exception.Message); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task DeleteUser_RemoveYourself(OrganizationUser deletingUser, SutProvider sutProvider) + { + var organizationUserRepository = sutProvider.GetDependency(); + + organizationUserRepository.GetByIdAsync(deletingUser.Id).Returns(deletingUser); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.DeleteUserAsync(deletingUser.OrganizationId, deletingUser.Id, deletingUser.UserId)); + Assert.Contains("You cannot remove yourself.", exception.Message); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task DeleteUser_NonOwnerRemoveOwner( + [OrganizationUser(type: OrganizationUserType.Owner)] OrganizationUser organizationUser, + [OrganizationUser(type: OrganizationUserType.Admin)] OrganizationUser deletingUser, + SutProvider sutProvider) + { + var organizationUserRepository = sutProvider.GetDependency(); + var currentContext = sutProvider.GetDependency(); + + organizationUser.OrganizationId = deletingUser.OrganizationId; + organizationUserRepository.GetByIdAsync(organizationUser.Id).Returns(organizationUser); + currentContext.OrganizationAdmin(deletingUser.OrganizationId).Returns(true); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.DeleteUserAsync(deletingUser.OrganizationId, organizationUser.Id, deletingUser.UserId)); + Assert.Contains("Only owners can delete other owners.", exception.Message); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task DeleteUser_LastOwner( + [OrganizationUser(type: OrganizationUserType.Owner)] OrganizationUser organizationUser, + OrganizationUser deletingUser, + SutProvider sutProvider) + { + var organizationUserRepository = sutProvider.GetDependency(); + + organizationUser.OrganizationId = deletingUser.OrganizationId; + organizationUserRepository.GetByIdAsync(organizationUser.Id).Returns(organizationUser); + organizationUserRepository.GetManyByOrganizationAsync(deletingUser.OrganizationId, OrganizationUserType.Owner) + .Returns(new[] { organizationUser }); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.DeleteUserAsync(deletingUser.OrganizationId, organizationUser.Id, null)); + Assert.Contains("Organization must have at least one confirmed owner.", exception.Message); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task DeleteUser_Success( + OrganizationUser organizationUser, + [OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.Owner)] OrganizationUser deletingUser, + SutProvider sutProvider) + { + var organizationUserRepository = sutProvider.GetDependency(); + var currentContext = sutProvider.GetDependency(); + + organizationUser.OrganizationId = deletingUser.OrganizationId; + organizationUserRepository.GetByIdAsync(organizationUser.Id).Returns(organizationUser); + organizationUserRepository.GetByIdAsync(deletingUser.Id).Returns(deletingUser); + organizationUserRepository.GetManyByOrganizationAsync(deletingUser.OrganizationId, OrganizationUserType.Owner) + .Returns(new[] { deletingUser, organizationUser }); + currentContext.OrganizationOwner(deletingUser.OrganizationId).Returns(true); + + await sutProvider.Sut.DeleteUserAsync(deletingUser.OrganizationId, organizationUser.Id, deletingUser.UserId); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task DeleteUsers_FilterInvalid(OrganizationUser organizationUser, OrganizationUser deletingUser, + SutProvider sutProvider) + { + var organizationUserRepository = sutProvider.GetDependency(); + var organizationUsers = new[] { organizationUser }; + var organizationUserIds = organizationUsers.Select(u => u.Id); + organizationUserRepository.GetManyAsync(default).ReturnsForAnyArgs(organizationUsers); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.DeleteUsersAsync(deletingUser.OrganizationId, organizationUserIds, deletingUser.UserId)); + Assert.Contains("Users invalid.", exception.Message); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task DeleteUsers_RemoveYourself( + [OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.Owner)] OrganizationUser orgUser, + OrganizationUser deletingUser, + SutProvider sutProvider) + { + var organizationUserRepository = sutProvider.GetDependency(); + var organizationUsers = new[] { deletingUser }; + var organizationUserIds = organizationUsers.Select(u => u.Id); + organizationUserRepository.GetManyAsync(default).ReturnsForAnyArgs(organizationUsers); + organizationUserRepository.GetManyByOrganizationAsync(default, default).ReturnsForAnyArgs(new[] { orgUser }); + + var result = await sutProvider.Sut.DeleteUsersAsync(deletingUser.OrganizationId, organizationUserIds, deletingUser.UserId); + Assert.Contains("You cannot remove yourself.", result[0].Item2); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task DeleteUsers_NonOwnerRemoveOwner( + [OrganizationUser(type: OrganizationUserType.Admin)] OrganizationUser deletingUser, + [OrganizationUser(type: OrganizationUserType.Owner)] OrganizationUser orgUser1, + [OrganizationUser(OrganizationUserStatusType.Confirmed)] OrganizationUser orgUser2, + SutProvider sutProvider) + { + var organizationUserRepository = sutProvider.GetDependency(); + + orgUser1.OrganizationId = orgUser2.OrganizationId = deletingUser.OrganizationId; + var organizationUsers = new[] { orgUser1 }; + var organizationUserIds = organizationUsers.Select(u => u.Id); + organizationUserRepository.GetManyAsync(default).ReturnsForAnyArgs(organizationUsers); + organizationUserRepository.GetManyByOrganizationAsync(default, default).ReturnsForAnyArgs(new[] { orgUser2 }); + + var result = await sutProvider.Sut.DeleteUsersAsync(deletingUser.OrganizationId, organizationUserIds, deletingUser.UserId); + Assert.Contains("Only owners can delete other owners.", result[0].Item2); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task DeleteUsers_LastOwner( + [OrganizationUser(status: OrganizationUserStatusType.Confirmed, OrganizationUserType.Owner)] OrganizationUser orgUser, + SutProvider sutProvider) + { + var organizationUserRepository = sutProvider.GetDependency(); + + var organizationUsers = new[] { orgUser }; + var organizationUserIds = organizationUsers.Select(u => u.Id); + organizationUserRepository.GetManyAsync(default).ReturnsForAnyArgs(organizationUsers); + organizationUserRepository.GetManyByOrganizationAsync(orgUser.OrganizationId, OrganizationUserType.Owner).Returns(organizationUsers); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.DeleteUsersAsync(orgUser.OrganizationId, organizationUserIds, null)); + Assert.Contains("Organization must have at least one confirmed owner.", exception.Message); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task DeleteUsers_Success( + [OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.Owner)] OrganizationUser deletingUser, + [OrganizationUser(type: OrganizationUserType.Owner)] OrganizationUser orgUser1, OrganizationUser orgUser2, + SutProvider sutProvider) + { + var organizationUserRepository = sutProvider.GetDependency(); + var currentContext = sutProvider.GetDependency(); + + orgUser1.OrganizationId = orgUser2.OrganizationId = deletingUser.OrganizationId; + var organizationUsers = new[] { orgUser1, orgUser2 }; + var organizationUserIds = organizationUsers.Select(u => u.Id); + organizationUserRepository.GetManyAsync(default).ReturnsForAnyArgs(organizationUsers); + organizationUserRepository.GetByIdAsync(deletingUser.Id).Returns(deletingUser); + organizationUserRepository.GetManyByOrganizationAsync(deletingUser.OrganizationId, OrganizationUserType.Owner) + .Returns(new[] { deletingUser, orgUser1 }); + currentContext.OrganizationOwner(deletingUser.OrganizationId).Returns(true); + + await sutProvider.Sut.DeleteUsersAsync(deletingUser.OrganizationId, organizationUserIds, deletingUser.UserId); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task ConfirmUser_InvalidStatus(OrganizationUser confirmingUser, + [OrganizationUser(OrganizationUserStatusType.Invited)] OrganizationUser orgUser, string key, + SutProvider sutProvider) + { + var organizationUserRepository = sutProvider.GetDependency(); + var userService = Substitute.For(); + + organizationUserRepository.GetByIdAsync(orgUser.Id).Returns(orgUser); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.ConfirmUserAsync(orgUser.OrganizationId, orgUser.Id, key, confirmingUser.Id, userService)); + Assert.Contains("User not valid.", exception.Message); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task ConfirmUser_WrongOrganization(OrganizationUser confirmingUser, + [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser orgUser, string key, + SutProvider sutProvider) + { + var organizationUserRepository = sutProvider.GetDependency(); + var userService = Substitute.For(); + + organizationUserRepository.GetByIdAsync(orgUser.Id).Returns(orgUser); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.ConfirmUserAsync(confirmingUser.OrganizationId, orgUser.Id, key, confirmingUser.Id, userService)); + Assert.Contains("User not valid.", exception.Message); + } + + [Theory] + [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, OrganizationUserType.Admin)] + [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, OrganizationUserType.Owner)] + public async Task ConfirmUserToFree_AlreadyFreeAdminOrOwner_Throws(OrganizationUserType userType, Organization org, OrganizationUser confirmingUser, + [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser orgUser, User user, + string key, SutProvider sutProvider) + { + var organizationUserRepository = sutProvider.GetDependency(); + var organizationRepository = sutProvider.GetDependency(); + var userService = Substitute.For(); + var userRepository = sutProvider.GetDependency(); + + org.PlanType = PlanType.Free; + orgUser.OrganizationId = confirmingUser.OrganizationId = org.Id; + orgUser.UserId = user.Id; + orgUser.Type = userType; + organizationUserRepository.GetManyAsync(default).ReturnsForAnyArgs(new[] { orgUser }); + organizationUserRepository.GetCountByFreeOrganizationAdminUserAsync(orgUser.UserId.Value).Returns(1); + organizationRepository.GetByIdAsync(org.Id).Returns(org); + userRepository.GetManyAsync(default).ReturnsForAnyArgs(new[] { user }); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.ConfirmUserAsync(orgUser.OrganizationId, orgUser.Id, key, confirmingUser.Id, userService)); + Assert.Contains("User can only be an admin of one free organization.", exception.Message); + } + + [Theory] + [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PlanType.Custom, OrganizationUserType.Admin)] + [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PlanType.Custom, OrganizationUserType.Owner)] + [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PlanType.EnterpriseAnnually, OrganizationUserType.Admin)] + [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PlanType.EnterpriseAnnually, OrganizationUserType.Owner)] + [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PlanType.EnterpriseAnnually2019, OrganizationUserType.Admin)] + [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PlanType.EnterpriseAnnually2019, OrganizationUserType.Owner)] + [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PlanType.EnterpriseMonthly, OrganizationUserType.Admin)] + [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PlanType.EnterpriseMonthly, OrganizationUserType.Owner)] + [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PlanType.EnterpriseMonthly2019, OrganizationUserType.Admin)] + [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PlanType.EnterpriseMonthly2019, OrganizationUserType.Owner)] + [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PlanType.FamiliesAnnually, OrganizationUserType.Admin)] + [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PlanType.FamiliesAnnually, OrganizationUserType.Owner)] + [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PlanType.FamiliesAnnually2019, OrganizationUserType.Admin)] + [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PlanType.FamiliesAnnually2019, OrganizationUserType.Owner)] + [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PlanType.TeamsAnnually, OrganizationUserType.Admin)] + [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PlanType.TeamsAnnually, OrganizationUserType.Owner)] + [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PlanType.TeamsAnnually2019, OrganizationUserType.Admin)] + [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PlanType.TeamsAnnually2019, OrganizationUserType.Owner)] + [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PlanType.TeamsMonthly, OrganizationUserType.Admin)] + [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PlanType.TeamsMonthly, OrganizationUserType.Owner)] + [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PlanType.TeamsMonthly2019, OrganizationUserType.Admin)] + [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PlanType.TeamsMonthly2019, OrganizationUserType.Owner)] + public async Task ConfirmUserToNonFree_AlreadyFreeAdminOrOwner_DoesNotThrow(PlanType planType, OrganizationUserType orgUserType, Organization org, OrganizationUser confirmingUser, + [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser orgUser, User user, + string key, SutProvider sutProvider) + { + var organizationUserRepository = sutProvider.GetDependency(); + var organizationRepository = sutProvider.GetDependency(); + var userService = Substitute.For(); + var userRepository = sutProvider.GetDependency(); + + org.PlanType = planType; + orgUser.OrganizationId = confirmingUser.OrganizationId = org.Id; + orgUser.UserId = user.Id; + orgUser.Type = orgUserType; + organizationUserRepository.GetManyAsync(default).ReturnsForAnyArgs(new[] { orgUser }); + organizationUserRepository.GetCountByFreeOrganizationAdminUserAsync(orgUser.UserId.Value).Returns(1); + organizationRepository.GetByIdAsync(org.Id).Returns(org); + userRepository.GetManyAsync(default).ReturnsForAnyArgs(new[] { user }); + + await sutProvider.Sut.ConfirmUserAsync(orgUser.OrganizationId, orgUser.Id, key, confirmingUser.Id, userService); + + await sutProvider.GetDependency().Received(1).LogOrganizationUserEventAsync(orgUser, EventType.OrganizationUser_Confirmed); + await sutProvider.GetDependency().Received(1).SendOrganizationConfirmedEmailAsync(org.Name, user.Email); + await organizationUserRepository.Received(1).ReplaceManyAsync(Arg.Is>(users => users.Contains(orgUser) && users.Count == 1)); + } + + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task ConfirmUser_SingleOrgPolicy(Organization org, OrganizationUser confirmingUser, + [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser orgUser, User user, + OrganizationUser orgUserAnotherOrg, [Policy(PolicyType.SingleOrg)] Policy singleOrgPolicy, + string key, SutProvider sutProvider) + { + var organizationUserRepository = sutProvider.GetDependency(); + var organizationRepository = sutProvider.GetDependency(); + var policyRepository = sutProvider.GetDependency(); + var userRepository = sutProvider.GetDependency(); + var userService = Substitute.For(); + + org.PlanType = PlanType.EnterpriseAnnually; + orgUser.Status = OrganizationUserStatusType.Accepted; + orgUser.OrganizationId = confirmingUser.OrganizationId = org.Id; + orgUser.UserId = orgUserAnotherOrg.UserId = user.Id; + organizationUserRepository.GetManyAsync(default).ReturnsForAnyArgs(new[] { orgUser }); + organizationUserRepository.GetManyByManyUsersAsync(default).ReturnsForAnyArgs(new[] { orgUserAnotherOrg }); + organizationRepository.GetByIdAsync(org.Id).Returns(org); + userRepository.GetManyAsync(default).ReturnsForAnyArgs(new[] { user }); + policyRepository.GetManyByOrganizationIdAsync(org.Id).Returns(new[] { singleOrgPolicy }); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.ConfirmUserAsync(orgUser.OrganizationId, orgUser.Id, key, confirmingUser.Id, userService)); + Assert.Contains("User is a member of another organization.", exception.Message); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task ConfirmUser_TwoFactorPolicy(Organization org, OrganizationUser confirmingUser, + [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser orgUser, User user, + OrganizationUser orgUserAnotherOrg, [Policy(PolicyType.TwoFactorAuthentication)] Policy twoFactorPolicy, + string key, SutProvider sutProvider) + { + var organizationUserRepository = sutProvider.GetDependency(); + var organizationRepository = sutProvider.GetDependency(); + var policyRepository = sutProvider.GetDependency(); + var userRepository = sutProvider.GetDependency(); + var userService = Substitute.For(); + + org.PlanType = PlanType.EnterpriseAnnually; + orgUser.OrganizationId = confirmingUser.OrganizationId = org.Id; + orgUser.UserId = orgUserAnotherOrg.UserId = user.Id; + organizationUserRepository.GetManyAsync(default).ReturnsForAnyArgs(new[] { orgUser }); + organizationUserRepository.GetManyByManyUsersAsync(default).ReturnsForAnyArgs(new[] { orgUserAnotherOrg }); + organizationRepository.GetByIdAsync(org.Id).Returns(org); + userRepository.GetManyAsync(default).ReturnsForAnyArgs(new[] { user }); + policyRepository.GetManyByOrganizationIdAsync(org.Id).Returns(new[] { twoFactorPolicy }); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.ConfirmUserAsync(orgUser.OrganizationId, orgUser.Id, key, confirmingUser.Id, userService)); + Assert.Contains("User does not have two-step login enabled.", exception.Message); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task ConfirmUser_Success(Organization org, OrganizationUser confirmingUser, + [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser orgUser, User user, + [Policy(PolicyType.TwoFactorAuthentication)] Policy twoFactorPolicy, + [Policy(PolicyType.SingleOrg)] Policy singleOrgPolicy, string key, SutProvider sutProvider) + { + var organizationUserRepository = sutProvider.GetDependency(); + var organizationRepository = sutProvider.GetDependency(); + var policyRepository = sutProvider.GetDependency(); + var userRepository = sutProvider.GetDependency(); + var userService = Substitute.For(); + + org.PlanType = PlanType.EnterpriseAnnually; + orgUser.OrganizationId = confirmingUser.OrganizationId = org.Id; + orgUser.UserId = user.Id; + organizationUserRepository.GetManyAsync(default).ReturnsForAnyArgs(new[] { orgUser }); + organizationRepository.GetByIdAsync(org.Id).Returns(org); + userRepository.GetManyAsync(default).ReturnsForAnyArgs(new[] { user }); + policyRepository.GetManyByOrganizationIdAsync(org.Id).Returns(new[] { twoFactorPolicy, singleOrgPolicy }); + userService.TwoFactorIsEnabledAsync(user).Returns(true); + + await sutProvider.Sut.ConfirmUserAsync(orgUser.OrganizationId, orgUser.Id, key, confirmingUser.Id, userService); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task ConfirmUsers_Success(Organization org, + OrganizationUser confirmingUser, + [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser orgUser1, + [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser orgUser2, + [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser orgUser3, + OrganizationUser anotherOrgUser, User user1, User user2, User user3, + [Policy(PolicyType.TwoFactorAuthentication)] Policy twoFactorPolicy, + [Policy(PolicyType.SingleOrg)] Policy singleOrgPolicy, string key, SutProvider sutProvider) + { + var organizationUserRepository = sutProvider.GetDependency(); + var organizationRepository = sutProvider.GetDependency(); + var policyRepository = sutProvider.GetDependency(); + var userRepository = sutProvider.GetDependency(); + var userService = Substitute.For(); + + org.PlanType = PlanType.EnterpriseAnnually; + orgUser1.OrganizationId = orgUser2.OrganizationId = orgUser3.OrganizationId = confirmingUser.OrganizationId = org.Id; + orgUser1.UserId = user1.Id; + orgUser2.UserId = user2.Id; + orgUser3.UserId = user3.Id; + anotherOrgUser.UserId = user3.Id; + var orgUsers = new[] { orgUser1, orgUser2, orgUser3 }; + organizationUserRepository.GetManyAsync(default).ReturnsForAnyArgs(orgUsers); + organizationRepository.GetByIdAsync(org.Id).Returns(org); + userRepository.GetManyAsync(default).ReturnsForAnyArgs(new[] { user1, user2, user3 }); + policyRepository.GetManyByOrganizationIdAsync(org.Id).Returns(new[] { twoFactorPolicy, singleOrgPolicy }); + userService.TwoFactorIsEnabledAsync(user1).Returns(true); + userService.TwoFactorIsEnabledAsync(user2).Returns(false); + userService.TwoFactorIsEnabledAsync(user3).Returns(true); + organizationUserRepository.GetManyByManyUsersAsync(default) + .ReturnsForAnyArgs(new[] { orgUser1, orgUser2, orgUser3, anotherOrgUser }); + + var keys = orgUsers.ToDictionary(ou => ou.Id, _ => key); + var result = await sutProvider.Sut.ConfirmUsersAsync(confirmingUser.OrganizationId, keys, confirmingUser.Id, userService); + Assert.Contains("", result[0].Item2); + Assert.Contains("User does not have two-step login enabled.", result[1].Item2); + Assert.Contains("User is a member of another organization.", result[2].Item2); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task UpdateOrganizationKeysAsync_WithoutManageResetPassword_Throws(Guid orgId, string publicKey, + string privateKey, SutProvider sutProvider) + { + var currentContext = Substitute.For(); + currentContext.ManageResetPassword(orgId).Returns(false); + + await Assert.ThrowsAsync( + () => sutProvider.Sut.UpdateOrganizationKeysAsync(orgId, publicKey, privateKey)); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task UpdateOrganizationKeysAsync_KeysAlreadySet_Throws(Organization org, string publicKey, + string privateKey, SutProvider sutProvider) + { + var currentContext = sutProvider.GetDependency(); + currentContext.ManageResetPassword(org.Id).Returns(true); + + var organizationRepository = sutProvider.GetDependency(); + organizationRepository.GetByIdAsync(org.Id).Returns(org); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.UpdateOrganizationKeysAsync(org.Id, publicKey, privateKey)); + Assert.Contains("Organization Keys already exist", exception.Message); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task UpdateOrganizationKeysAsync_KeysAlreadySet_Success(Organization org, string publicKey, + string privateKey, SutProvider sutProvider) + { + org.PublicKey = null; + org.PrivateKey = null; + + var currentContext = sutProvider.GetDependency(); + currentContext.ManageResetPassword(org.Id).Returns(true); + + var organizationRepository = sutProvider.GetDependency(); + organizationRepository.GetByIdAsync(org.Id).Returns(org); + + await sutProvider.Sut.UpdateOrganizationKeysAsync(org.Id, publicKey, privateKey); + } + + [Theory] + [InlinePaidOrganizationAutoData(PlanType.EnterpriseAnnually, new object[] { "Cannot set max seat autoscaling below seat count", 1, 0, 2 })] + [InlinePaidOrganizationAutoData(PlanType.EnterpriseAnnually, new object[] { "Cannot set max seat autoscaling below seat count", 4, -1, 6 })] + [InlineFreeOrganizationAutoData("Your plan does not allow seat autoscaling", 10, 0, null)] + public async Task UpdateSubscription_BadInputThrows(string expectedMessage, + int? maxAutoscaleSeats, int seatAdjustment, int? currentSeats, Organization organization, SutProvider sutProvider) + { + organization.Seats = currentSeats; + sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); + + var exception = await Assert.ThrowsAsync(() => sutProvider.Sut.UpdateSubscription(organization.Id, + seatAdjustment, maxAutoscaleSeats)); + + Assert.Contains(expectedMessage, exception.Message); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task UpdateSubscription_NoOrganization_Throws(Guid organizationId, SutProvider sutProvider) + { + sutProvider.GetDependency().GetByIdAsync(organizationId).Returns((Organization)null); + + await Assert.ThrowsAsync(() => sutProvider.Sut.UpdateSubscription(organizationId, 0, null)); + } + + [Theory] + [InlinePaidOrganizationAutoData(0, 100, null, true, "")] + [InlinePaidOrganizationAutoData(0, 100, 100, true, "")] + [InlinePaidOrganizationAutoData(0, null, 100, true, "")] + [InlinePaidOrganizationAutoData(1, 100, null, true, "")] + [InlinePaidOrganizationAutoData(1, 100, 100, false, "Cannot invite new users. Seat limit has been reached")] + public void CanScale(int seatsToAdd, int? currentSeats, int? maxAutoscaleSeats, + bool expectedResult, string expectedFailureMessage, Organization organization, + SutProvider sutProvider) + { + organization.Seats = currentSeats; + organization.MaxAutoscaleSeats = maxAutoscaleSeats; + sutProvider.GetDependency().ManageUsers(organization.Id).Returns(true); + + var (result, failureMessage) = sutProvider.Sut.CanScale(organization, seatsToAdd); + + if (expectedFailureMessage == string.Empty) + { + Assert.Empty(failureMessage); + } + else + { + Assert.Contains(expectedFailureMessage, failureMessage); + } + Assert.Equal(expectedResult, result); + } + + [Theory, PaidOrganizationAutoData] + public void CanScale_FailsOnSelfHosted(Organization organization, + SutProvider sutProvider) + { + sutProvider.GetDependency().SelfHosted.Returns(true); + var (result, failureMessage) = sutProvider.Sut.CanScale(organization, 10); + + Assert.False(result); + Assert.Contains("Cannot autoscale on self-hosted instance", failureMessage); + } + + [Theory, PaidOrganizationAutoData] + public async Task Delete_Success(Organization organization, SutProvider sutProvider) + { + var organizationRepository = sutProvider.GetDependency(); + var applicationCacheService = sutProvider.GetDependency(); + + await sutProvider.Sut.DeleteAsync(organization); + + await organizationRepository.Received().DeleteAsync(organization); + await applicationCacheService.Received().DeleteOrganizationAbilityAsync(organization.Id); + } + + [Theory, PaidOrganizationAutoData] + public async Task Delete_Fails_KeyConnector(Organization organization, SutProvider sutProvider, + SsoConfig ssoConfig) + { + ssoConfig.Enabled = true; + ssoConfig.SetData(new SsoConfigurationData { KeyConnectorEnabled = true }); + var ssoConfigRepository = sutProvider.GetDependency(); + var organizationRepository = sutProvider.GetDependency(); + var applicationCacheService = sutProvider.GetDependency(); + + ssoConfigRepository.GetByOrganizationIdAsync(organization.Id).Returns(ssoConfig); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.DeleteAsync(organization)); + + Assert.Contains("You cannot delete an Organization that is using Key Connector.", exception.Message); + + await organizationRepository.DidNotReceiveWithAnyArgs().DeleteAsync(default); + await applicationCacheService.DidNotReceiveWithAnyArgs().DeleteOrganizationAbilityAsync(default); + } } } diff --git a/test/Core.Test/Services/PolicyServiceTests.cs b/test/Core.Test/Services/PolicyServiceTests.cs index 29b4285a19..8f99b816c8 100644 --- a/test/Core.Test/Services/PolicyServiceTests.cs +++ b/test/Core.Test/Services/PolicyServiceTests.cs @@ -10,391 +10,392 @@ using NSubstitute; using Xunit; using PolicyFixtures = Bit.Core.Test.AutoFixture.PolicyFixtures; -namespace Bit.Core.Test.Services; - -[SutProviderCustomize] -public class PolicyServiceTests +namespace Bit.Core.Test.Services { - [Theory, BitAutoData] - public async Task SaveAsync_OrganizationDoesNotExist_ThrowsBadRequest( - [PolicyFixtures.Policy(PolicyType.DisableSend)] Policy policy, SutProvider sutProvider) + [SutProviderCustomize] + public class PolicyServiceTests { - SetupOrg(sutProvider, policy.OrganizationId, null); - - var badRequestException = await Assert.ThrowsAsync( - () => sutProvider.Sut.SaveAsync(policy, - Substitute.For(), - Substitute.For(), - Guid.NewGuid())); - - Assert.Contains("Organization not found", badRequestException.Message, StringComparison.OrdinalIgnoreCase); - - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .UpsertAsync(default); - - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .LogPolicyEventAsync(default, default, default); - } - - [Theory, BitAutoData] - public async Task SaveAsync_OrganizationCannotUsePolicies_ThrowsBadRequest( - [PolicyFixtures.Policy(PolicyType.DisableSend)] Policy policy, SutProvider sutProvider) - { - var orgId = Guid.NewGuid(); - - SetupOrg(sutProvider, policy.OrganizationId, new Organization + [Theory, BitAutoData] + public async Task SaveAsync_OrganizationDoesNotExist_ThrowsBadRequest( + [PolicyFixtures.Policy(PolicyType.DisableSend)] Policy policy, SutProvider sutProvider) { - UsePolicies = false, - }); + SetupOrg(sutProvider, policy.OrganizationId, null); - var badRequestException = await Assert.ThrowsAsync( - () => sutProvider.Sut.SaveAsync(policy, - Substitute.For(), - Substitute.For(), - Guid.NewGuid())); + var badRequestException = await Assert.ThrowsAsync( + () => sutProvider.Sut.SaveAsync(policy, + Substitute.For(), + Substitute.For(), + Guid.NewGuid())); - Assert.Contains("cannot use policies", badRequestException.Message, StringComparison.OrdinalIgnoreCase); + Assert.Contains("Organization not found", badRequestException.Message, StringComparison.OrdinalIgnoreCase); - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .UpsertAsync(default); + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .UpsertAsync(default); - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .LogPolicyEventAsync(default, default, default); - } + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .LogPolicyEventAsync(default, default, default); + } - [Theory, BitAutoData] - public async Task SaveAsync_SingleOrg_RequireSsoEnabled_ThrowsBadRequest( - [PolicyFixtures.Policy(PolicyType.SingleOrg)] Policy policy, SutProvider sutProvider) - { - policy.Enabled = false; - - SetupOrg(sutProvider, policy.OrganizationId, new Organization + [Theory, BitAutoData] + public async Task SaveAsync_OrganizationCannotUsePolicies_ThrowsBadRequest( + [PolicyFixtures.Policy(PolicyType.DisableSend)] Policy policy, SutProvider sutProvider) { - Id = policy.OrganizationId, - UsePolicies = true, - }); + var orgId = Guid.NewGuid(); - sutProvider.GetDependency() - .GetByOrganizationIdTypeAsync(policy.OrganizationId, PolicyType.RequireSso) - .Returns(Task.FromResult(new Policy { Enabled = true })); - - var badRequestException = await Assert.ThrowsAsync( - () => sutProvider.Sut.SaveAsync(policy, - Substitute.For(), - Substitute.For(), - Guid.NewGuid())); - - Assert.Contains("Single Sign-On Authentication policy is enabled.", badRequestException.Message, StringComparison.OrdinalIgnoreCase); - - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .UpsertAsync(default); - - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .LogPolicyEventAsync(default, default, default); - } - - [Theory, BitAutoData] - public async Task SaveAsync_SingleOrg_VaultTimeoutEnabled_ThrowsBadRequest([PolicyFixtures.Policy(Enums.PolicyType.SingleOrg)] Policy policy, SutProvider sutProvider) - { - policy.Enabled = false; - - SetupOrg(sutProvider, policy.OrganizationId, new Organization - { - Id = policy.OrganizationId, - UsePolicies = true, - }); - - sutProvider.GetDependency() - .GetByOrganizationIdTypeAsync(policy.OrganizationId, Enums.PolicyType.MaximumVaultTimeout) - .Returns(new Policy { Enabled = true }); - - var badRequestException = await Assert.ThrowsAsync( - () => sutProvider.Sut.SaveAsync(policy, - Substitute.For(), - Substitute.For(), - Guid.NewGuid())); - - Assert.Contains("Maximum Vault Timeout policy is enabled.", badRequestException.Message, StringComparison.OrdinalIgnoreCase); - - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .UpsertAsync(default); - } - - [Theory] - [BitAutoData(PolicyType.SingleOrg)] - [BitAutoData(PolicyType.RequireSso)] - public async Task SaveAsync_PolicyRequiredByKeyConnector_DisablePolicy_ThrowsBadRequest( - Enums.PolicyType policyType, - Policy policy, - SutProvider sutProvider) - { - policy.Enabled = false; - policy.Type = policyType; - - SetupOrg(sutProvider, policy.OrganizationId, new Organization - { - Id = policy.OrganizationId, - UsePolicies = true, - }); - - var ssoConfig = new SsoConfig { Enabled = true }; - var data = new SsoConfigurationData { KeyConnectorEnabled = true }; - ssoConfig.SetData(data); - - sutProvider.GetDependency() - .GetByOrganizationIdAsync(policy.OrganizationId) - .Returns(ssoConfig); - - var badRequestException = await Assert.ThrowsAsync( - () => sutProvider.Sut.SaveAsync(policy, - Substitute.For(), - Substitute.For(), - Guid.NewGuid())); - - Assert.Contains("Key Connector is enabled.", badRequestException.Message, StringComparison.OrdinalIgnoreCase); - - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .UpsertAsync(default); - } - - [Theory, BitAutoData] - public async Task SaveAsync_RequireSsoPolicy_NotEnabled_ThrowsBadRequestAsync( - [PolicyFixtures.Policy(Enums.PolicyType.RequireSso)] Policy policy, SutProvider sutProvider) - { - policy.Enabled = true; - - SetupOrg(sutProvider, policy.OrganizationId, new Organization - { - Id = policy.OrganizationId, - UsePolicies = true, - }); - - sutProvider.GetDependency() - .GetByOrganizationIdTypeAsync(policy.OrganizationId, PolicyType.SingleOrg) - .Returns(Task.FromResult(new Policy { Enabled = false })); - - var badRequestException = await Assert.ThrowsAsync( - () => sutProvider.Sut.SaveAsync(policy, - Substitute.For(), - Substitute.For(), - Guid.NewGuid())); - - Assert.Contains("Single Organization policy not enabled.", badRequestException.Message, StringComparison.OrdinalIgnoreCase); - - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .UpsertAsync(default); - - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .LogPolicyEventAsync(default, default, default); - } - - [Theory, BitAutoData] - public async Task SaveAsync_NewPolicy_Created( - [PolicyFixtures.Policy(PolicyType.ResetPassword)] Policy policy, SutProvider sutProvider) - { - policy.Id = default; - - SetupOrg(sutProvider, policy.OrganizationId, new Organization - { - Id = policy.OrganizationId, - UsePolicies = true, - }); - - var utcNow = DateTime.UtcNow; - - await sutProvider.Sut.SaveAsync(policy, Substitute.For(), Substitute.For(), Guid.NewGuid()); - - await sutProvider.GetDependency().Received() - .LogPolicyEventAsync(policy, EventType.Policy_Updated); - - await sutProvider.GetDependency().Received() - .UpsertAsync(policy); - - Assert.True(policy.CreationDate - utcNow < TimeSpan.FromSeconds(1)); - Assert.True(policy.RevisionDate - utcNow < TimeSpan.FromSeconds(1)); - } - - [Theory, BitAutoData] - public async Task SaveAsync_VaultTimeoutPolicy_NotEnabled_ThrowsBadRequestAsync( - [PolicyFixtures.Policy(PolicyType.MaximumVaultTimeout)] Policy policy, SutProvider sutProvider) - { - policy.Enabled = true; - - SetupOrg(sutProvider, policy.OrganizationId, new Organization - { - Id = policy.OrganizationId, - UsePolicies = true, - }); - - sutProvider.GetDependency() - .GetByOrganizationIdTypeAsync(policy.OrganizationId, Enums.PolicyType.SingleOrg) - .Returns(Task.FromResult(new Policy { Enabled = false })); - - var badRequestException = await Assert.ThrowsAsync( - () => sutProvider.Sut.SaveAsync(policy, - Substitute.For(), - Substitute.For(), - Guid.NewGuid())); - - Assert.Contains("Single Organization policy not enabled.", badRequestException.Message, StringComparison.OrdinalIgnoreCase); - - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .UpsertAsync(default); - - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .LogPolicyEventAsync(default, default, default); - } - - [Theory, BitAutoData] - public async Task SaveAsync_ExistingPolicy_UpdateTwoFactor( - [PolicyFixtures.Policy(PolicyType.TwoFactorAuthentication)] Policy policy, SutProvider sutProvider) - { - // If the policy that this is updating isn't enabled then do some work now that the current one is enabled - - var org = new Organization - { - Id = policy.OrganizationId, - UsePolicies = true, - Name = "TEST", - }; - - SetupOrg(sutProvider, policy.OrganizationId, org); - - sutProvider.GetDependency() - .GetByIdAsync(policy.Id) - .Returns(new Policy + SetupOrg(sutProvider, policy.OrganizationId, new Organization { - Id = policy.Id, - Type = PolicyType.TwoFactorAuthentication, - Enabled = false, + UsePolicies = false, }); - var orgUserDetail = new Core.Models.Data.Organizations.OrganizationUsers.OrganizationUserUserDetails - { - Id = Guid.NewGuid(), - Status = OrganizationUserStatusType.Accepted, - Type = OrganizationUserType.User, - // Needs to be different from what is passed in as the savingUserId to Sut.SaveAsync - Email = "test@bitwarden.com", - Name = "TEST", - UserId = Guid.NewGuid(), - }; + var badRequestException = await Assert.ThrowsAsync( + () => sutProvider.Sut.SaveAsync(policy, + Substitute.For(), + Substitute.For(), + Guid.NewGuid())); - sutProvider.GetDependency() - .GetManyDetailsByOrganizationAsync(policy.OrganizationId) - .Returns(new List + Assert.Contains("cannot use policies", badRequestException.Message, StringComparison.OrdinalIgnoreCase); + + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .UpsertAsync(default); + + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .LogPolicyEventAsync(default, default, default); + } + + [Theory, BitAutoData] + public async Task SaveAsync_SingleOrg_RequireSsoEnabled_ThrowsBadRequest( + [PolicyFixtures.Policy(PolicyType.SingleOrg)] Policy policy, SutProvider sutProvider) + { + policy.Enabled = false; + + SetupOrg(sutProvider, policy.OrganizationId, new Organization { - orgUserDetail, + Id = policy.OrganizationId, + UsePolicies = true, }); - var userService = Substitute.For(); - var organizationService = Substitute.For(); + sutProvider.GetDependency() + .GetByOrganizationIdTypeAsync(policy.OrganizationId, PolicyType.RequireSso) + .Returns(Task.FromResult(new Policy { Enabled = true })); - userService.TwoFactorIsEnabledAsync(orgUserDetail) - .Returns(false); + var badRequestException = await Assert.ThrowsAsync( + () => sutProvider.Sut.SaveAsync(policy, + Substitute.For(), + Substitute.For(), + Guid.NewGuid())); - var utcNow = DateTime.UtcNow; + Assert.Contains("Single Sign-On Authentication policy is enabled.", badRequestException.Message, StringComparison.OrdinalIgnoreCase); - var savingUserId = Guid.NewGuid(); + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .UpsertAsync(default); - await sutProvider.Sut.SaveAsync(policy, userService, organizationService, savingUserId); + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .LogPolicyEventAsync(default, default, default); + } - await organizationService.Received() - .DeleteUserAsync(policy.OrganizationId, orgUserDetail.Id, savingUserId); - - await sutProvider.GetDependency().Received() - .SendOrganizationUserRemovedForPolicyTwoStepEmailAsync(org.Name, orgUserDetail.Email); - - await sutProvider.GetDependency().Received() - .LogPolicyEventAsync(policy, EventType.Policy_Updated); - - await sutProvider.GetDependency().Received() - .UpsertAsync(policy); - - Assert.True(policy.CreationDate - utcNow < TimeSpan.FromSeconds(1)); - Assert.True(policy.RevisionDate - utcNow < TimeSpan.FromSeconds(1)); - } - - [Theory, BitAutoData] - public async Task SaveAsync_ExistingPolicy_UpdateSingleOrg( - [PolicyFixtures.Policy(PolicyType.TwoFactorAuthentication)] Policy policy, SutProvider sutProvider) - { - // If the policy that this is updating isn't enabled then do some work now that the current one is enabled - - var org = new Organization + [Theory, BitAutoData] + public async Task SaveAsync_SingleOrg_VaultTimeoutEnabled_ThrowsBadRequest([PolicyFixtures.Policy(Enums.PolicyType.SingleOrg)] Policy policy, SutProvider sutProvider) { - Id = policy.OrganizationId, - UsePolicies = true, - Name = "TEST", - }; + policy.Enabled = false; - SetupOrg(sutProvider, policy.OrganizationId, org); - - sutProvider.GetDependency() - .GetByIdAsync(policy.Id) - .Returns(new Policy + SetupOrg(sutProvider, policy.OrganizationId, new Organization { - Id = policy.Id, - Type = PolicyType.SingleOrg, - Enabled = false, + Id = policy.OrganizationId, + UsePolicies = true, }); - var orgUserDetail = new Core.Models.Data.Organizations.OrganizationUsers.OrganizationUserUserDetails - { - Id = Guid.NewGuid(), - Status = OrganizationUserStatusType.Accepted, - Type = OrganizationUserType.User, - // Needs to be different from what is passed in as the savingUserId to Sut.SaveAsync - Email = "test@bitwarden.com", - Name = "TEST", - UserId = Guid.NewGuid(), - }; + sutProvider.GetDependency() + .GetByOrganizationIdTypeAsync(policy.OrganizationId, Enums.PolicyType.MaximumVaultTimeout) + .Returns(new Policy { Enabled = true }); - sutProvider.GetDependency() - .GetManyDetailsByOrganizationAsync(policy.OrganizationId) - .Returns(new List + var badRequestException = await Assert.ThrowsAsync( + () => sutProvider.Sut.SaveAsync(policy, + Substitute.For(), + Substitute.For(), + Guid.NewGuid())); + + Assert.Contains("Maximum Vault Timeout policy is enabled.", badRequestException.Message, StringComparison.OrdinalIgnoreCase); + + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .UpsertAsync(default); + } + + [Theory] + [BitAutoData(PolicyType.SingleOrg)] + [BitAutoData(PolicyType.RequireSso)] + public async Task SaveAsync_PolicyRequiredByKeyConnector_DisablePolicy_ThrowsBadRequest( + Enums.PolicyType policyType, + Policy policy, + SutProvider sutProvider) + { + policy.Enabled = false; + policy.Type = policyType; + + SetupOrg(sutProvider, policy.OrganizationId, new Organization { - orgUserDetail, + Id = policy.OrganizationId, + UsePolicies = true, }); - var userService = Substitute.For(); - var organizationService = Substitute.For(); + var ssoConfig = new SsoConfig { Enabled = true }; + var data = new SsoConfigurationData { KeyConnectorEnabled = true }; + ssoConfig.SetData(data); - userService.TwoFactorIsEnabledAsync(orgUserDetail) - .Returns(false); + sutProvider.GetDependency() + .GetByOrganizationIdAsync(policy.OrganizationId) + .Returns(ssoConfig); - var utcNow = DateTime.UtcNow; + var badRequestException = await Assert.ThrowsAsync( + () => sutProvider.Sut.SaveAsync(policy, + Substitute.For(), + Substitute.For(), + Guid.NewGuid())); - var savingUserId = Guid.NewGuid(); + Assert.Contains("Key Connector is enabled.", badRequestException.Message, StringComparison.OrdinalIgnoreCase); - await sutProvider.Sut.SaveAsync(policy, userService, organizationService, savingUserId); + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .UpsertAsync(default); + } - await sutProvider.GetDependency().Received() - .LogPolicyEventAsync(policy, EventType.Policy_Updated); + [Theory, BitAutoData] + public async Task SaveAsync_RequireSsoPolicy_NotEnabled_ThrowsBadRequestAsync( + [PolicyFixtures.Policy(Enums.PolicyType.RequireSso)] Policy policy, SutProvider sutProvider) + { + policy.Enabled = true; - await sutProvider.GetDependency().Received() - .UpsertAsync(policy); + SetupOrg(sutProvider, policy.OrganizationId, new Organization + { + Id = policy.OrganizationId, + UsePolicies = true, + }); - Assert.True(policy.CreationDate - utcNow < TimeSpan.FromSeconds(1)); - Assert.True(policy.RevisionDate - utcNow < TimeSpan.FromSeconds(1)); - } + sutProvider.GetDependency() + .GetByOrganizationIdTypeAsync(policy.OrganizationId, PolicyType.SingleOrg) + .Returns(Task.FromResult(new Policy { Enabled = false })); - private static void SetupOrg(SutProvider sutProvider, Guid organizationId, Organization organization) - { - sutProvider.GetDependency() - .GetByIdAsync(organizationId) - .Returns(Task.FromResult(organization)); + var badRequestException = await Assert.ThrowsAsync( + () => sutProvider.Sut.SaveAsync(policy, + Substitute.For(), + Substitute.For(), + Guid.NewGuid())); + + Assert.Contains("Single Organization policy not enabled.", badRequestException.Message, StringComparison.OrdinalIgnoreCase); + + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .UpsertAsync(default); + + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .LogPolicyEventAsync(default, default, default); + } + + [Theory, BitAutoData] + public async Task SaveAsync_NewPolicy_Created( + [PolicyFixtures.Policy(PolicyType.ResetPassword)] Policy policy, SutProvider sutProvider) + { + policy.Id = default; + + SetupOrg(sutProvider, policy.OrganizationId, new Organization + { + Id = policy.OrganizationId, + UsePolicies = true, + }); + + var utcNow = DateTime.UtcNow; + + await sutProvider.Sut.SaveAsync(policy, Substitute.For(), Substitute.For(), Guid.NewGuid()); + + await sutProvider.GetDependency().Received() + .LogPolicyEventAsync(policy, EventType.Policy_Updated); + + await sutProvider.GetDependency().Received() + .UpsertAsync(policy); + + Assert.True(policy.CreationDate - utcNow < TimeSpan.FromSeconds(1)); + Assert.True(policy.RevisionDate - utcNow < TimeSpan.FromSeconds(1)); + } + + [Theory, BitAutoData] + public async Task SaveAsync_VaultTimeoutPolicy_NotEnabled_ThrowsBadRequestAsync( + [PolicyFixtures.Policy(PolicyType.MaximumVaultTimeout)] Policy policy, SutProvider sutProvider) + { + policy.Enabled = true; + + SetupOrg(sutProvider, policy.OrganizationId, new Organization + { + Id = policy.OrganizationId, + UsePolicies = true, + }); + + sutProvider.GetDependency() + .GetByOrganizationIdTypeAsync(policy.OrganizationId, Enums.PolicyType.SingleOrg) + .Returns(Task.FromResult(new Policy { Enabled = false })); + + var badRequestException = await Assert.ThrowsAsync( + () => sutProvider.Sut.SaveAsync(policy, + Substitute.For(), + Substitute.For(), + Guid.NewGuid())); + + Assert.Contains("Single Organization policy not enabled.", badRequestException.Message, StringComparison.OrdinalIgnoreCase); + + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .UpsertAsync(default); + + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .LogPolicyEventAsync(default, default, default); + } + + [Theory, BitAutoData] + public async Task SaveAsync_ExistingPolicy_UpdateTwoFactor( + [PolicyFixtures.Policy(PolicyType.TwoFactorAuthentication)] Policy policy, SutProvider sutProvider) + { + // If the policy that this is updating isn't enabled then do some work now that the current one is enabled + + var org = new Organization + { + Id = policy.OrganizationId, + UsePolicies = true, + Name = "TEST", + }; + + SetupOrg(sutProvider, policy.OrganizationId, org); + + sutProvider.GetDependency() + .GetByIdAsync(policy.Id) + .Returns(new Policy + { + Id = policy.Id, + Type = PolicyType.TwoFactorAuthentication, + Enabled = false, + }); + + var orgUserDetail = new Core.Models.Data.Organizations.OrganizationUsers.OrganizationUserUserDetails + { + Id = Guid.NewGuid(), + Status = OrganizationUserStatusType.Accepted, + Type = OrganizationUserType.User, + // Needs to be different from what is passed in as the savingUserId to Sut.SaveAsync + Email = "test@bitwarden.com", + Name = "TEST", + UserId = Guid.NewGuid(), + }; + + sutProvider.GetDependency() + .GetManyDetailsByOrganizationAsync(policy.OrganizationId) + .Returns(new List + { + orgUserDetail, + }); + + var userService = Substitute.For(); + var organizationService = Substitute.For(); + + userService.TwoFactorIsEnabledAsync(orgUserDetail) + .Returns(false); + + var utcNow = DateTime.UtcNow; + + var savingUserId = Guid.NewGuid(); + + await sutProvider.Sut.SaveAsync(policy, userService, organizationService, savingUserId); + + await organizationService.Received() + .DeleteUserAsync(policy.OrganizationId, orgUserDetail.Id, savingUserId); + + await sutProvider.GetDependency().Received() + .SendOrganizationUserRemovedForPolicyTwoStepEmailAsync(org.Name, orgUserDetail.Email); + + await sutProvider.GetDependency().Received() + .LogPolicyEventAsync(policy, EventType.Policy_Updated); + + await sutProvider.GetDependency().Received() + .UpsertAsync(policy); + + Assert.True(policy.CreationDate - utcNow < TimeSpan.FromSeconds(1)); + Assert.True(policy.RevisionDate - utcNow < TimeSpan.FromSeconds(1)); + } + + [Theory, BitAutoData] + public async Task SaveAsync_ExistingPolicy_UpdateSingleOrg( + [PolicyFixtures.Policy(PolicyType.TwoFactorAuthentication)] Policy policy, SutProvider sutProvider) + { + // If the policy that this is updating isn't enabled then do some work now that the current one is enabled + + var org = new Organization + { + Id = policy.OrganizationId, + UsePolicies = true, + Name = "TEST", + }; + + SetupOrg(sutProvider, policy.OrganizationId, org); + + sutProvider.GetDependency() + .GetByIdAsync(policy.Id) + .Returns(new Policy + { + Id = policy.Id, + Type = PolicyType.SingleOrg, + Enabled = false, + }); + + var orgUserDetail = new Core.Models.Data.Organizations.OrganizationUsers.OrganizationUserUserDetails + { + Id = Guid.NewGuid(), + Status = OrganizationUserStatusType.Accepted, + Type = OrganizationUserType.User, + // Needs to be different from what is passed in as the savingUserId to Sut.SaveAsync + Email = "test@bitwarden.com", + Name = "TEST", + UserId = Guid.NewGuid(), + }; + + sutProvider.GetDependency() + .GetManyDetailsByOrganizationAsync(policy.OrganizationId) + .Returns(new List + { + orgUserDetail, + }); + + var userService = Substitute.For(); + var organizationService = Substitute.For(); + + userService.TwoFactorIsEnabledAsync(orgUserDetail) + .Returns(false); + + var utcNow = DateTime.UtcNow; + + var savingUserId = Guid.NewGuid(); + + await sutProvider.Sut.SaveAsync(policy, userService, organizationService, savingUserId); + + await sutProvider.GetDependency().Received() + .LogPolicyEventAsync(policy, EventType.Policy_Updated); + + await sutProvider.GetDependency().Received() + .UpsertAsync(policy); + + Assert.True(policy.CreationDate - utcNow < TimeSpan.FromSeconds(1)); + Assert.True(policy.RevisionDate - utcNow < TimeSpan.FromSeconds(1)); + } + + private static void SetupOrg(SutProvider sutProvider, Guid organizationId, Organization organization) + { + sutProvider.GetDependency() + .GetByIdAsync(organizationId) + .Returns(Task.FromResult(organization)); + } } } diff --git a/test/Core.Test/Services/RelayPushNotificationServiceTests.cs b/test/Core.Test/Services/RelayPushNotificationServiceTests.cs index ccf5e3d4bb..68b8633e2e 100644 --- a/test/Core.Test/Services/RelayPushNotificationServiceTests.cs +++ b/test/Core.Test/Services/RelayPushNotificationServiceTests.cs @@ -6,40 +6,41 @@ using Microsoft.Extensions.Logging; using NSubstitute; using Xunit; -namespace Bit.Core.Test.Services; - -public class RelayPushNotificationServiceTests +namespace Bit.Core.Test.Services { - private readonly RelayPushNotificationService _sut; - - private readonly IHttpClientFactory _httpFactory; - private readonly IDeviceRepository _deviceRepository; - private readonly GlobalSettings _globalSettings; - private readonly IHttpContextAccessor _httpContextAccessor; - private readonly ILogger _logger; - - public RelayPushNotificationServiceTests() + public class RelayPushNotificationServiceTests { - _httpFactory = Substitute.For(); - _deviceRepository = Substitute.For(); - _globalSettings = new GlobalSettings(); - _httpContextAccessor = Substitute.For(); - _logger = Substitute.For>(); + private readonly RelayPushNotificationService _sut; - _sut = new RelayPushNotificationService( - _httpFactory, - _deviceRepository, - _globalSettings, - _httpContextAccessor, - _logger - ); - } + private readonly IHttpClientFactory _httpFactory; + private readonly IDeviceRepository _deviceRepository; + private readonly GlobalSettings _globalSettings; + private readonly IHttpContextAccessor _httpContextAccessor; + private readonly ILogger _logger; - // Remove this test when we add actual tests. It only proves that - // we've properly constructed the system under test. - [Fact(Skip = "Needs additional work")] - public void ServiceExists() - { - Assert.NotNull(_sut); + public RelayPushNotificationServiceTests() + { + _httpFactory = Substitute.For(); + _deviceRepository = Substitute.For(); + _globalSettings = new GlobalSettings(); + _httpContextAccessor = Substitute.For(); + _logger = Substitute.For>(); + + _sut = new RelayPushNotificationService( + _httpFactory, + _deviceRepository, + _globalSettings, + _httpContextAccessor, + _logger + ); + } + + // Remove this test when we add actual tests. It only proves that + // we've properly constructed the system under test. + [Fact(Skip = "Needs additional work")] + public void ServiceExists() + { + Assert.NotNull(_sut); + } } } diff --git a/test/Core.Test/Services/RelayPushRegistrationServiceTests.cs b/test/Core.Test/Services/RelayPushRegistrationServiceTests.cs index 926a19bc00..371d50168b 100644 --- a/test/Core.Test/Services/RelayPushRegistrationServiceTests.cs +++ b/test/Core.Test/Services/RelayPushRegistrationServiceTests.cs @@ -4,34 +4,35 @@ using Microsoft.Extensions.Logging; using NSubstitute; using Xunit; -namespace Bit.Core.Test.Services; - -public class RelayPushRegistrationServiceTests +namespace Bit.Core.Test.Services { - private readonly RelayPushRegistrationService _sut; - - private readonly IHttpClientFactory _httpFactory; - private readonly GlobalSettings _globalSettings; - private readonly ILogger _logger; - - public RelayPushRegistrationServiceTests() + public class RelayPushRegistrationServiceTests { - _globalSettings = new GlobalSettings(); - _httpFactory = Substitute.For(); - _logger = Substitute.For>(); + private readonly RelayPushRegistrationService _sut; - _sut = new RelayPushRegistrationService( - _httpFactory, - _globalSettings, - _logger - ); - } + private readonly IHttpClientFactory _httpFactory; + private readonly GlobalSettings _globalSettings; + private readonly ILogger _logger; - // Remove this test when we add actual tests. It only proves that - // we've properly constructed the system under test. - [Fact(Skip = "Needs additional work")] - public void ServiceExists() - { - Assert.NotNull(_sut); + public RelayPushRegistrationServiceTests() + { + _globalSettings = new GlobalSettings(); + _httpFactory = Substitute.For(); + _logger = Substitute.For>(); + + _sut = new RelayPushRegistrationService( + _httpFactory, + _globalSettings, + _logger + ); + } + + // Remove this test when we add actual tests. It only proves that + // we've properly constructed the system under test. + [Fact(Skip = "Needs additional work")] + public void ServiceExists() + { + Assert.NotNull(_sut); + } } } diff --git a/test/Core.Test/Services/RepositoryEventWriteServiceTests.cs b/test/Core.Test/Services/RepositoryEventWriteServiceTests.cs index 9cfe2c9e84..4ee3460abf 100644 --- a/test/Core.Test/Services/RepositoryEventWriteServiceTests.cs +++ b/test/Core.Test/Services/RepositoryEventWriteServiceTests.cs @@ -3,26 +3,27 @@ using Bit.Core.Services; using NSubstitute; using Xunit; -namespace Bit.Core.Test.Services; - -public class RepositoryEventWriteServiceTests +namespace Bit.Core.Test.Services { - private readonly RepositoryEventWriteService _sut; - - private readonly IEventRepository _eventRepository; - - public RepositoryEventWriteServiceTests() + public class RepositoryEventWriteServiceTests { - _eventRepository = Substitute.For(); + private readonly RepositoryEventWriteService _sut; - _sut = new RepositoryEventWriteService(_eventRepository); - } + private readonly IEventRepository _eventRepository; - // Remove this test when we add actual tests. It only proves that - // we've properly constructed the system under test. - [Fact] - public void ServiceExists() - { - Assert.NotNull(_sut); + public RepositoryEventWriteServiceTests() + { + _eventRepository = Substitute.For(); + + _sut = new RepositoryEventWriteService(_eventRepository); + } + + // Remove this test when we add actual tests. It only proves that + // we've properly constructed the system under test. + [Fact] + public void ServiceExists() + { + Assert.NotNull(_sut); + } } } diff --git a/test/Core.Test/Services/SendGridMailDeliveryServiceTests.cs b/test/Core.Test/Services/SendGridMailDeliveryServiceTests.cs index 3c64e5c406..8366cc266a 100644 --- a/test/Core.Test/Services/SendGridMailDeliveryServiceTests.cs +++ b/test/Core.Test/Services/SendGridMailDeliveryServiceTests.cs @@ -8,77 +8,78 @@ using SendGrid; using SendGrid.Helpers.Mail; using Xunit; -namespace Bit.Core.Test.Services; - -public class SendGridMailDeliveryServiceTests : IDisposable +namespace Bit.Core.Test.Services { - private readonly SendGridMailDeliveryService _sut; - - private readonly GlobalSettings _globalSettings; - private readonly IWebHostEnvironment _hostingEnvironment; - private readonly ILogger _logger; - private readonly ISendGridClient _sendGridClient; - - public SendGridMailDeliveryServiceTests() + public class SendGridMailDeliveryServiceTests : IDisposable { - _globalSettings = new GlobalSettings + private readonly SendGridMailDeliveryService _sut; + + private readonly GlobalSettings _globalSettings; + private readonly IWebHostEnvironment _hostingEnvironment; + private readonly ILogger _logger; + private readonly ISendGridClient _sendGridClient; + + public SendGridMailDeliveryServiceTests() { - Mail = + _globalSettings = new GlobalSettings { - SendGridApiKey = "SendGridApiKey" - } - }; + Mail = + { + SendGridApiKey = "SendGridApiKey" + } + }; - _hostingEnvironment = Substitute.For(); - _logger = Substitute.For>(); - _sendGridClient = Substitute.For(); + _hostingEnvironment = Substitute.For(); + _logger = Substitute.For>(); + _sendGridClient = Substitute.For(); - _sut = new SendGridMailDeliveryService( - _sendGridClient, - _globalSettings, - _hostingEnvironment, - _logger - ); - } + _sut = new SendGridMailDeliveryService( + _sendGridClient, + _globalSettings, + _hostingEnvironment, + _logger + ); + } - public void Dispose() - { - _sut?.Dispose(); - } - - [Fact] - public async Task SendEmailAsync_CallsSendEmailAsync_WhenMessageIsValid() - { - var mailMessage = new MailMessage + public void Dispose() { - ToEmails = new List { "ToEmails" }, - BccEmails = new List { "BccEmails" }, - Subject = "Subject", - HtmlContent = "HtmlContent", - TextContent = "TextContent", - Category = "Category" - }; + _sut?.Dispose(); + } - _sendGridClient.SendEmailAsync(Arg.Any()).Returns( - new Response(System.Net.HttpStatusCode.OK, null, null)); - await _sut.SendEmailAsync(mailMessage); - - await _sendGridClient.Received(1).SendEmailAsync( - Arg.Do(msg => + [Fact] + public async Task SendEmailAsync_CallsSendEmailAsync_WhenMessageIsValid() + { + var mailMessage = new MailMessage { - msg.Received(1).AddTos(new List { new EmailAddress(mailMessage.ToEmails.First()) }); - msg.Received(1).AddBccs(new List { new EmailAddress(mailMessage.ToEmails.First()) }); + ToEmails = new List { "ToEmails" }, + BccEmails = new List { "BccEmails" }, + Subject = "Subject", + HtmlContent = "HtmlContent", + TextContent = "TextContent", + Category = "Category" + }; - Assert.Equal(mailMessage.Subject, msg.Subject); - Assert.Equal(mailMessage.HtmlContent, msg.HtmlContent); - Assert.Equal(mailMessage.TextContent, msg.PlainTextContent); + _sendGridClient.SendEmailAsync(Arg.Any()).Returns( + new Response(System.Net.HttpStatusCode.OK, null, null)); + await _sut.SendEmailAsync(mailMessage); - Assert.Contains("type:Cateogry", msg.Categories); - Assert.Contains(msg.Categories, x => x.StartsWith("env:")); - Assert.Contains(msg.Categories, x => x.StartsWith("sender:")); + await _sendGridClient.Received(1).SendEmailAsync( + Arg.Do(msg => + { + msg.Received(1).AddTos(new List { new EmailAddress(mailMessage.ToEmails.First()) }); + msg.Received(1).AddBccs(new List { new EmailAddress(mailMessage.ToEmails.First()) }); - msg.Received(1).SetClickTracking(false, false); - msg.Received(1).SetOpenTracking(false); - })); + Assert.Equal(mailMessage.Subject, msg.Subject); + Assert.Equal(mailMessage.HtmlContent, msg.HtmlContent); + Assert.Equal(mailMessage.TextContent, msg.PlainTextContent); + + Assert.Contains("type:Cateogry", msg.Categories); + Assert.Contains(msg.Categories, x => x.StartsWith("env:")); + Assert.Contains(msg.Categories, x => x.StartsWith("sender:")); + + msg.Received(1).SetClickTracking(false, false); + msg.Received(1).SetOpenTracking(false); + })); + } } } diff --git a/test/Core.Test/Services/SendServiceTests.cs b/test/Core.Test/Services/SendServiceTests.cs index 1468bd0b05..aed7d2f042 100644 --- a/test/Core.Test/Services/SendServiceTests.cs +++ b/test/Core.Test/Services/SendServiceTests.cs @@ -14,748 +14,749 @@ using Microsoft.AspNetCore.Identity; using NSubstitute; using Xunit; -namespace Bit.Core.Test.Services; - -public class SendServiceTests +namespace Bit.Core.Test.Services { - private void SaveSendAsync_Setup(SendType sendType, bool disableSendPolicyAppliesToUser, - SutProvider sutProvider, Send send) + public class SendServiceTests { - send.Id = default; - send.Type = sendType; - - sutProvider.GetDependency().GetCountByTypeApplicableToUserIdAsync( - Arg.Any(), PolicyType.DisableSend).Returns(disableSendPolicyAppliesToUser ? 1 : 0); - } - - // Disable Send policy check - - [Theory] - [InlineUserSendAutoData(SendType.File)] - [InlineUserSendAutoData(SendType.Text)] - public async void SaveSendAsync_DisableSend_Applies_throws(SendType sendType, - SutProvider sutProvider, Send send) - { - SaveSendAsync_Setup(sendType, disableSendPolicyAppliesToUser: true, sutProvider, send); - - await Assert.ThrowsAsync(() => sutProvider.Sut.SaveSendAsync(send)); - } - - [Theory] - [InlineUserSendAutoData(SendType.File)] - [InlineUserSendAutoData(SendType.Text)] - public async void SaveSendAsync_DisableSend_DoesntApply_success(SendType sendType, - SutProvider sutProvider, Send send) - { - SaveSendAsync_Setup(sendType, disableSendPolicyAppliesToUser: false, sutProvider, send); - - await sutProvider.Sut.SaveSendAsync(send); - - await sutProvider.GetDependency().Received(1).CreateAsync(send); - } - - // Send Options Policy - Disable Hide Email check - - private void SaveSendAsync_HideEmail_Setup(bool disableHideEmailAppliesToUser, - SutProvider sutProvider, Send send, Policy policy) - { - send.HideEmail = true; - - var sendOptions = new SendOptionsPolicyData + private void SaveSendAsync_Setup(SendType sendType, bool disableSendPolicyAppliesToUser, + SutProvider sutProvider, Send send) { - DisableHideEmail = disableHideEmailAppliesToUser - }; - policy.Data = JsonSerializer.Serialize(sendOptions, new JsonSerializerOptions - { - PropertyNamingPolicy = JsonNamingPolicy.CamelCase, - }); + send.Id = default; + send.Type = sendType; - sutProvider.GetDependency().GetManyByTypeApplicableToUserIdAsync( - Arg.Any(), PolicyType.SendOptions).Returns(new List + sutProvider.GetDependency().GetCountByTypeApplicableToUserIdAsync( + Arg.Any(), PolicyType.DisableSend).Returns(disableSendPolicyAppliesToUser ? 1 : 0); + } + + // Disable Send policy check + + [Theory] + [InlineUserSendAutoData(SendType.File)] + [InlineUserSendAutoData(SendType.Text)] + public async void SaveSendAsync_DisableSend_Applies_throws(SendType sendType, + SutProvider sutProvider, Send send) + { + SaveSendAsync_Setup(sendType, disableSendPolicyAppliesToUser: true, sutProvider, send); + + await Assert.ThrowsAsync(() => sutProvider.Sut.SaveSendAsync(send)); + } + + [Theory] + [InlineUserSendAutoData(SendType.File)] + [InlineUserSendAutoData(SendType.Text)] + public async void SaveSendAsync_DisableSend_DoesntApply_success(SendType sendType, + SutProvider sutProvider, Send send) + { + SaveSendAsync_Setup(sendType, disableSendPolicyAppliesToUser: false, sutProvider, send); + + await sutProvider.Sut.SaveSendAsync(send); + + await sutProvider.GetDependency().Received(1).CreateAsync(send); + } + + // Send Options Policy - Disable Hide Email check + + private void SaveSendAsync_HideEmail_Setup(bool disableHideEmailAppliesToUser, + SutProvider sutProvider, Send send, Policy policy) + { + send.HideEmail = true; + + var sendOptions = new SendOptionsPolicyData { - policy, + DisableHideEmail = disableHideEmailAppliesToUser + }; + policy.Data = JsonSerializer.Serialize(sendOptions, new JsonSerializerOptions + { + PropertyNamingPolicy = JsonNamingPolicy.CamelCase, }); - } - [Theory] - [InlineUserSendAutoData(SendType.File)] - [InlineUserSendAutoData(SendType.Text)] - public async void SaveSendAsync_DisableHideEmail_Applies_throws(SendType sendType, - SutProvider sutProvider, Send send, Policy policy) - { - SaveSendAsync_Setup(sendType, false, sutProvider, send); - SaveSendAsync_HideEmail_Setup(true, sutProvider, send, policy); + sutProvider.GetDependency().GetManyByTypeApplicableToUserIdAsync( + Arg.Any(), PolicyType.SendOptions).Returns(new List + { + policy, + }); + } - await Assert.ThrowsAsync(() => sutProvider.Sut.SaveSendAsync(send)); - } - - [Theory] - [InlineUserSendAutoData(SendType.File)] - [InlineUserSendAutoData(SendType.Text)] - public async void SaveSendAsync_DisableHideEmail_DoesntApply_success(SendType sendType, - SutProvider sutProvider, Send send, Policy policy) - { - SaveSendAsync_Setup(sendType, false, sutProvider, send); - SaveSendAsync_HideEmail_Setup(false, sutProvider, send, policy); - - await sutProvider.Sut.SaveSendAsync(send); - - await sutProvider.GetDependency().Received(1).CreateAsync(send); - } - - [Theory] - [InlineUserSendAutoData] - [InlineUserSendAutoData] - public async void SaveSendAsync_ExistingSend_Updates(SutProvider sutProvider, - Send send) - { - send.Id = Guid.NewGuid(); - - var now = DateTime.UtcNow; - await sutProvider.Sut.SaveSendAsync(send); - - Assert.True(send.RevisionDate - now < TimeSpan.FromSeconds(1)); - - await sutProvider.GetDependency() - .Received(1) - .UpsertAsync(send); - - await sutProvider.GetDependency() - .Received(1) - .PushSyncSendUpdateAsync(send); - } - - [Theory] - [InlineUserSendAutoData] - public async void SaveFileSendAsync_TextType_ThrowsBadRequest(SutProvider sutProvider, - Send send) - { - send.Type = SendType.Text; - - var badRequest = await Assert.ThrowsAsync(() => - sutProvider.Sut.SaveFileSendAsync(send, null, 0) - ); - - Assert.Contains("not of type \"file\"", badRequest.Message, StringComparison.InvariantCultureIgnoreCase); - } - - [Theory] - [InlineUserSendAutoData] - public async void SaveFileSendAsync_EmptyFile_ThrowsBadRequest(SutProvider sutProvider, - Send send) - { - send.Type = SendType.File; - - var badRequest = await Assert.ThrowsAsync(() => - sutProvider.Sut.SaveFileSendAsync(send, null, 0) - ); - - Assert.Contains("no file data", badRequest.Message, StringComparison.InvariantCultureIgnoreCase); - } - - [Theory] - [InlineUserSendAutoData] - public async void SaveFileSendAsync_UserCannotAccessPremium_ThrowsBadRequest(SutProvider sutProvider, - Send send) - { - var user = new User + [Theory] + [InlineUserSendAutoData(SendType.File)] + [InlineUserSendAutoData(SendType.Text)] + public async void SaveSendAsync_DisableHideEmail_Applies_throws(SendType sendType, + SutProvider sutProvider, Send send, Policy policy) { - Id = Guid.NewGuid(), - }; + SaveSendAsync_Setup(sendType, false, sutProvider, send); + SaveSendAsync_HideEmail_Setup(true, sutProvider, send, policy); - send.UserId = user.Id; - send.Type = SendType.File; + await Assert.ThrowsAsync(() => sutProvider.Sut.SaveSendAsync(send)); + } - sutProvider.GetDependency() - .GetByIdAsync(user.Id) - .Returns(user); - - sutProvider.GetDependency() - .CanAccessPremium(user) - .Returns(false); - - var badRequest = await Assert.ThrowsAsync(() => - sutProvider.Sut.SaveFileSendAsync(send, null, 1) - ); - - Assert.Contains("must have premium", badRequest.Message, StringComparison.InvariantCultureIgnoreCase); - } - - [Theory] - [InlineUserSendAutoData] - public async void SaveFileSendAsync_UserHasUnconfirmedEmail_ThrowsBadRequest(SutProvider sutProvider, - Send send) - { - var user = new User + [Theory] + [InlineUserSendAutoData(SendType.File)] + [InlineUserSendAutoData(SendType.Text)] + public async void SaveSendAsync_DisableHideEmail_DoesntApply_success(SendType sendType, + SutProvider sutProvider, Send send, Policy policy) { - Id = Guid.NewGuid(), - EmailVerified = false, - }; + SaveSendAsync_Setup(sendType, false, sutProvider, send); + SaveSendAsync_HideEmail_Setup(false, sutProvider, send, policy); - send.UserId = user.Id; - send.Type = SendType.File; + await sutProvider.Sut.SaveSendAsync(send); - sutProvider.GetDependency() - .GetByIdAsync(user.Id) - .Returns(user); + await sutProvider.GetDependency().Received(1).CreateAsync(send); + } - sutProvider.GetDependency() - .CanAccessPremium(user) - .Returns(true); - - var badRequest = await Assert.ThrowsAsync(() => - sutProvider.Sut.SaveFileSendAsync(send, null, 1) - ); - - Assert.Contains("must confirm your email", badRequest.Message, StringComparison.InvariantCultureIgnoreCase); - } - - [Theory] - [InlineUserSendAutoData] - public async void SaveFileSendAsync_UserCanAccessPremium_HasNoStorage_ThrowsBadRequest(SutProvider sutProvider, - Send send) - { - var user = new User + [Theory] + [InlineUserSendAutoData] + [InlineUserSendAutoData] + public async void SaveSendAsync_ExistingSend_Updates(SutProvider sutProvider, + Send send) { - Id = Guid.NewGuid(), - EmailVerified = true, - Premium = true, - MaxStorageGb = null, - Storage = 0, - }; + send.Id = Guid.NewGuid(); - send.UserId = user.Id; - send.Type = SendType.File; + var now = DateTime.UtcNow; + await sutProvider.Sut.SaveSendAsync(send); - sutProvider.GetDependency() - .GetByIdAsync(user.Id) - .Returns(user); + Assert.True(send.RevisionDate - now < TimeSpan.FromSeconds(1)); - sutProvider.GetDependency() - .CanAccessPremium(user) - .Returns(true); + await sutProvider.GetDependency() + .Received(1) + .UpsertAsync(send); - var badRequest = await Assert.ThrowsAsync(() => - sutProvider.Sut.SaveFileSendAsync(send, null, 1) - ); + await sutProvider.GetDependency() + .Received(1) + .PushSyncSendUpdateAsync(send); + } - Assert.Contains("not enough storage", badRequest.Message, StringComparison.InvariantCultureIgnoreCase); - } - - [Theory] - [InlineUserSendAutoData] - public async void SaveFileSendAsync_UserCanAccessPremium_StorageFull_ThrowsBadRequest(SutProvider sutProvider, - Send send) - { - var user = new User + [Theory] + [InlineUserSendAutoData] + public async void SaveFileSendAsync_TextType_ThrowsBadRequest(SutProvider sutProvider, + Send send) { - Id = Guid.NewGuid(), - EmailVerified = true, - Premium = true, - MaxStorageGb = 2, - Storage = 2 * UserTests.Multiplier, - }; + send.Type = SendType.Text; - send.UserId = user.Id; - send.Type = SendType.File; + var badRequest = await Assert.ThrowsAsync(() => + sutProvider.Sut.SaveFileSendAsync(send, null, 0) + ); - sutProvider.GetDependency() - .GetByIdAsync(user.Id) - .Returns(user); + Assert.Contains("not of type \"file\"", badRequest.Message, StringComparison.InvariantCultureIgnoreCase); + } - sutProvider.GetDependency() - .CanAccessPremium(user) - .Returns(true); - - var badRequest = await Assert.ThrowsAsync(() => - sutProvider.Sut.SaveFileSendAsync(send, null, 1) - ); - - Assert.Contains("not enough storage", badRequest.Message, StringComparison.InvariantCultureIgnoreCase); - } - - [Theory] - [InlineUserSendAutoData] - public async void SaveFileSendAsync_UserCanAccessPremium_IsNotPremium_IsSelfHosted_GiantFile_ThrowsBadRequest(SutProvider sutProvider, - Send send) - { - var user = new User + [Theory] + [InlineUserSendAutoData] + public async void SaveFileSendAsync_EmptyFile_ThrowsBadRequest(SutProvider sutProvider, + Send send) { - Id = Guid.NewGuid(), - EmailVerified = true, - Premium = false, - }; + send.Type = SendType.File; - send.UserId = user.Id; - send.Type = SendType.File; + var badRequest = await Assert.ThrowsAsync(() => + sutProvider.Sut.SaveFileSendAsync(send, null, 0) + ); - sutProvider.GetDependency() - .GetByIdAsync(user.Id) - .Returns(user); + Assert.Contains("no file data", badRequest.Message, StringComparison.InvariantCultureIgnoreCase); + } - sutProvider.GetDependency() - .CanAccessPremium(user) - .Returns(true); - - sutProvider.GetDependency() - .SelfHosted = true; - - var badRequest = await Assert.ThrowsAsync(() => - sutProvider.Sut.SaveFileSendAsync(send, null, 11000 * UserTests.Multiplier) - ); - - Assert.Contains("not enough storage", badRequest.Message, StringComparison.InvariantCultureIgnoreCase); - } - - [Theory] - [InlineUserSendAutoData] - public async void SaveFileSendAsync_UserCanAccessPremium_IsNotPremium_IsNotSelfHosted_TwoGigabyteFile_ThrowsBadRequest(SutProvider sutProvider, - Send send) - { - var user = new User + [Theory] + [InlineUserSendAutoData] + public async void SaveFileSendAsync_UserCannotAccessPremium_ThrowsBadRequest(SutProvider sutProvider, + Send send) { - Id = Guid.NewGuid(), - EmailVerified = true, - Premium = false, - }; + var user = new User + { + Id = Guid.NewGuid(), + }; - send.UserId = user.Id; - send.Type = SendType.File; + send.UserId = user.Id; + send.Type = SendType.File; - sutProvider.GetDependency() - .GetByIdAsync(user.Id) - .Returns(user); + sutProvider.GetDependency() + .GetByIdAsync(user.Id) + .Returns(user); - sutProvider.GetDependency() - .CanAccessPremium(user) - .Returns(true); + sutProvider.GetDependency() + .CanAccessPremium(user) + .Returns(false); - sutProvider.GetDependency() - .SelfHosted = false; + var badRequest = await Assert.ThrowsAsync(() => + sutProvider.Sut.SaveFileSendAsync(send, null, 1) + ); - var badRequest = await Assert.ThrowsAsync(() => - sutProvider.Sut.SaveFileSendAsync(send, null, 2 * UserTests.Multiplier) - ); + Assert.Contains("must have premium", badRequest.Message, StringComparison.InvariantCultureIgnoreCase); + } - Assert.Contains("not enough storage", badRequest.Message, StringComparison.InvariantCultureIgnoreCase); - } - - [Theory] - [InlineUserSendAutoData] - public async void SaveFileSendAsync_ThroughOrg_MaxStorageIsNull_ThrowsBadRequest(SutProvider sutProvider, - Send send) - { - var org = new Organization + [Theory] + [InlineUserSendAutoData] + public async void SaveFileSendAsync_UserHasUnconfirmedEmail_ThrowsBadRequest(SutProvider sutProvider, + Send send) { - Id = Guid.NewGuid(), - MaxStorageGb = null, - }; + var user = new User + { + Id = Guid.NewGuid(), + EmailVerified = false, + }; - send.UserId = null; - send.OrganizationId = org.Id; - send.Type = SendType.File; + send.UserId = user.Id; + send.Type = SendType.File; - sutProvider.GetDependency() - .GetByIdAsync(org.Id) - .Returns(org); + sutProvider.GetDependency() + .GetByIdAsync(user.Id) + .Returns(user); - var badRequest = await Assert.ThrowsAsync(() => - sutProvider.Sut.SaveFileSendAsync(send, null, 1) - ); + sutProvider.GetDependency() + .CanAccessPremium(user) + .Returns(true); - Assert.Contains("organization cannot use file sends", badRequest.Message, StringComparison.InvariantCultureIgnoreCase); - } + var badRequest = await Assert.ThrowsAsync(() => + sutProvider.Sut.SaveFileSendAsync(send, null, 1) + ); - [Theory] - [InlineUserSendAutoData] - public async void SaveFileSendAsync_ThroughOrg_MaxStorageIsNull_TwoGBFile_ThrowsBadRequest(SutProvider sutProvider, - Send send) - { - var org = new Organization + Assert.Contains("must confirm your email", badRequest.Message, StringComparison.InvariantCultureIgnoreCase); + } + + [Theory] + [InlineUserSendAutoData] + public async void SaveFileSendAsync_UserCanAccessPremium_HasNoStorage_ThrowsBadRequest(SutProvider sutProvider, + Send send) { - Id = Guid.NewGuid(), - MaxStorageGb = null, - }; + var user = new User + { + Id = Guid.NewGuid(), + EmailVerified = true, + Premium = true, + MaxStorageGb = null, + Storage = 0, + }; - send.UserId = null; - send.OrganizationId = org.Id; - send.Type = SendType.File; + send.UserId = user.Id; + send.Type = SendType.File; - sutProvider.GetDependency() - .GetByIdAsync(org.Id) - .Returns(org); + sutProvider.GetDependency() + .GetByIdAsync(user.Id) + .Returns(user); - var badRequest = await Assert.ThrowsAsync(() => - sutProvider.Sut.SaveFileSendAsync(send, null, 1) - ); + sutProvider.GetDependency() + .CanAccessPremium(user) + .Returns(true); - Assert.Contains("organization cannot use file sends", badRequest.Message, StringComparison.InvariantCultureIgnoreCase); - } + var badRequest = await Assert.ThrowsAsync(() => + sutProvider.Sut.SaveFileSendAsync(send, null, 1) + ); - [Theory] - [InlineUserSendAutoData] - public async void SaveFileSendAsync_ThroughOrg_MaxStorageIsOneGB_TwoGBFile_ThrowsBadRequest(SutProvider sutProvider, - Send send) - { - var org = new Organization + Assert.Contains("not enough storage", badRequest.Message, StringComparison.InvariantCultureIgnoreCase); + } + + [Theory] + [InlineUserSendAutoData] + public async void SaveFileSendAsync_UserCanAccessPremium_StorageFull_ThrowsBadRequest(SutProvider sutProvider, + Send send) { - Id = Guid.NewGuid(), - MaxStorageGb = 1, - }; + var user = new User + { + Id = Guid.NewGuid(), + EmailVerified = true, + Premium = true, + MaxStorageGb = 2, + Storage = 2 * UserTests.Multiplier, + }; - send.UserId = null; - send.OrganizationId = org.Id; - send.Type = SendType.File; + send.UserId = user.Id; + send.Type = SendType.File; - sutProvider.GetDependency() - .GetByIdAsync(org.Id) - .Returns(org); + sutProvider.GetDependency() + .GetByIdAsync(user.Id) + .Returns(user); - var badRequest = await Assert.ThrowsAsync(() => - sutProvider.Sut.SaveFileSendAsync(send, null, 2 * UserTests.Multiplier) - ); + sutProvider.GetDependency() + .CanAccessPremium(user) + .Returns(true); - Assert.Contains("not enough storage", badRequest.Message, StringComparison.InvariantCultureIgnoreCase); - } + var badRequest = await Assert.ThrowsAsync(() => + sutProvider.Sut.SaveFileSendAsync(send, null, 1) + ); - [Theory] - [InlineUserSendAutoData] - public async void SaveFileSendAsync_HasEnouphStorage_Success(SutProvider sutProvider, - Send send) - { - var user = new User + Assert.Contains("not enough storage", badRequest.Message, StringComparison.InvariantCultureIgnoreCase); + } + + [Theory] + [InlineUserSendAutoData] + public async void SaveFileSendAsync_UserCanAccessPremium_IsNotPremium_IsSelfHosted_GiantFile_ThrowsBadRequest(SutProvider sutProvider, + Send send) { - Id = Guid.NewGuid(), - EmailVerified = true, - MaxStorageGb = 10, - }; + var user = new User + { + Id = Guid.NewGuid(), + EmailVerified = true, + Premium = false, + }; - var data = new SendFileData + send.UserId = user.Id; + send.Type = SendType.File; + + sutProvider.GetDependency() + .GetByIdAsync(user.Id) + .Returns(user); + + sutProvider.GetDependency() + .CanAccessPremium(user) + .Returns(true); + + sutProvider.GetDependency() + .SelfHosted = true; + + var badRequest = await Assert.ThrowsAsync(() => + sutProvider.Sut.SaveFileSendAsync(send, null, 11000 * UserTests.Multiplier) + ); + + Assert.Contains("not enough storage", badRequest.Message, StringComparison.InvariantCultureIgnoreCase); + } + + [Theory] + [InlineUserSendAutoData] + public async void SaveFileSendAsync_UserCanAccessPremium_IsNotPremium_IsNotSelfHosted_TwoGigabyteFile_ThrowsBadRequest(SutProvider sutProvider, + Send send) + { + var user = new User + { + Id = Guid.NewGuid(), + EmailVerified = true, + Premium = false, + }; + + send.UserId = user.Id; + send.Type = SendType.File; + + sutProvider.GetDependency() + .GetByIdAsync(user.Id) + .Returns(user); + + sutProvider.GetDependency() + .CanAccessPremium(user) + .Returns(true); + + sutProvider.GetDependency() + .SelfHosted = false; + + var badRequest = await Assert.ThrowsAsync(() => + sutProvider.Sut.SaveFileSendAsync(send, null, 2 * UserTests.Multiplier) + ); + + Assert.Contains("not enough storage", badRequest.Message, StringComparison.InvariantCultureIgnoreCase); + } + + [Theory] + [InlineUserSendAutoData] + public async void SaveFileSendAsync_ThroughOrg_MaxStorageIsNull_ThrowsBadRequest(SutProvider sutProvider, + Send send) + { + var org = new Organization + { + Id = Guid.NewGuid(), + MaxStorageGb = null, + }; + + send.UserId = null; + send.OrganizationId = org.Id; + send.Type = SendType.File; + + sutProvider.GetDependency() + .GetByIdAsync(org.Id) + .Returns(org); + + var badRequest = await Assert.ThrowsAsync(() => + sutProvider.Sut.SaveFileSendAsync(send, null, 1) + ); + + Assert.Contains("organization cannot use file sends", badRequest.Message, StringComparison.InvariantCultureIgnoreCase); + } + + [Theory] + [InlineUserSendAutoData] + public async void SaveFileSendAsync_ThroughOrg_MaxStorageIsNull_TwoGBFile_ThrowsBadRequest(SutProvider sutProvider, + Send send) + { + var org = new Organization + { + Id = Guid.NewGuid(), + MaxStorageGb = null, + }; + + send.UserId = null; + send.OrganizationId = org.Id; + send.Type = SendType.File; + + sutProvider.GetDependency() + .GetByIdAsync(org.Id) + .Returns(org); + + var badRequest = await Assert.ThrowsAsync(() => + sutProvider.Sut.SaveFileSendAsync(send, null, 1) + ); + + Assert.Contains("organization cannot use file sends", badRequest.Message, StringComparison.InvariantCultureIgnoreCase); + } + + [Theory] + [InlineUserSendAutoData] + public async void SaveFileSendAsync_ThroughOrg_MaxStorageIsOneGB_TwoGBFile_ThrowsBadRequest(SutProvider sutProvider, + Send send) + { + var org = new Organization + { + Id = Guid.NewGuid(), + MaxStorageGb = 1, + }; + + send.UserId = null; + send.OrganizationId = org.Id; + send.Type = SendType.File; + + sutProvider.GetDependency() + .GetByIdAsync(org.Id) + .Returns(org); + + var badRequest = await Assert.ThrowsAsync(() => + sutProvider.Sut.SaveFileSendAsync(send, null, 2 * UserTests.Multiplier) + ); + + Assert.Contains("not enough storage", badRequest.Message, StringComparison.InvariantCultureIgnoreCase); + } + + [Theory] + [InlineUserSendAutoData] + public async void SaveFileSendAsync_HasEnouphStorage_Success(SutProvider sutProvider, + Send send) + { + var user = new User + { + Id = Guid.NewGuid(), + EmailVerified = true, + MaxStorageGb = 10, + }; + + var data = new SendFileData + { + + }; + + send.UserId = user.Id; + send.Type = SendType.File; + + var testUrl = "https://test.com/"; + + sutProvider.GetDependency() + .GetByIdAsync(user.Id) + .Returns(user); + + sutProvider.GetDependency() + .CanAccessPremium(user) + .Returns(true); + + sutProvider.GetDependency() + .GetSendFileUploadUrlAsync(send, Arg.Any()) + .Returns(testUrl); + + var utcNow = DateTime.UtcNow; + + var url = await sutProvider.Sut.SaveFileSendAsync(send, data, 1 * UserTests.Multiplier); + + Assert.Equal(testUrl, url); + Assert.True(send.RevisionDate - utcNow < TimeSpan.FromSeconds(1)); + + await sutProvider.GetDependency() + .Received(1) + .GetSendFileUploadUrlAsync(send, Arg.Any()); + + await sutProvider.GetDependency() + .Received(1) + .UpsertAsync(send); + + await sutProvider.GetDependency() + .Received(1) + .PushSyncSendUpdateAsync(send); + } + + [Theory] + [InlineUserSendAutoData] + public async void SaveFileSendAsync_HasEnouphStorage_SendFileThrows_CleansUp(SutProvider sutProvider, + Send send) + { + var user = new User + { + Id = Guid.NewGuid(), + EmailVerified = true, + MaxStorageGb = 10, + }; + + var data = new SendFileData + { + + }; + + send.UserId = user.Id; + send.Type = SendType.File; + + sutProvider.GetDependency() + .GetByIdAsync(user.Id) + .Returns(user); + + sutProvider.GetDependency() + .CanAccessPremium(user) + .Returns(true); + + sutProvider.GetDependency() + .GetSendFileUploadUrlAsync(send, Arg.Any()) + .Returns(callInfo => throw new Exception("Problem")); + + var utcNow = DateTime.UtcNow; + + var exception = await Assert.ThrowsAsync(() => + sutProvider.Sut.SaveFileSendAsync(send, data, 1 * UserTests.Multiplier) + ); + + Assert.True(send.RevisionDate - utcNow < TimeSpan.FromSeconds(1)); + Assert.Equal("Problem", exception.Message); + + await sutProvider.GetDependency() + .Received(1) + .GetSendFileUploadUrlAsync(send, Arg.Any()); + + await sutProvider.GetDependency() + .Received(1) + .UpsertAsync(send); + + await sutProvider.GetDependency() + .Received(1) + .PushSyncSendUpdateAsync(send); + + await sutProvider.GetDependency() + .Received(1) + .DeleteFileAsync(send, Arg.Any()); + } + + [Theory] + [InlineUserSendAutoData] + public async void UpdateFileToExistingSendAsync_SendNull_ThrowsBadRequest(SutProvider sutProvider) { - }; + var badRequest = await Assert.ThrowsAsync(() => + sutProvider.Sut.UploadFileToExistingSendAsync(new MemoryStream(), null) + ); - send.UserId = user.Id; - send.Type = SendType.File; + Assert.Contains("does not have file data", badRequest.Message, StringComparison.InvariantCultureIgnoreCase); + } - var testUrl = "https://test.com/"; - - sutProvider.GetDependency() - .GetByIdAsync(user.Id) - .Returns(user); - - sutProvider.GetDependency() - .CanAccessPremium(user) - .Returns(true); - - sutProvider.GetDependency() - .GetSendFileUploadUrlAsync(send, Arg.Any()) - .Returns(testUrl); - - var utcNow = DateTime.UtcNow; - - var url = await sutProvider.Sut.SaveFileSendAsync(send, data, 1 * UserTests.Multiplier); - - Assert.Equal(testUrl, url); - Assert.True(send.RevisionDate - utcNow < TimeSpan.FromSeconds(1)); - - await sutProvider.GetDependency() - .Received(1) - .GetSendFileUploadUrlAsync(send, Arg.Any()); - - await sutProvider.GetDependency() - .Received(1) - .UpsertAsync(send); - - await sutProvider.GetDependency() - .Received(1) - .PushSyncSendUpdateAsync(send); - } - - [Theory] - [InlineUserSendAutoData] - public async void SaveFileSendAsync_HasEnouphStorage_SendFileThrows_CleansUp(SutProvider sutProvider, - Send send) - { - var user = new User + [Theory] + [InlineUserSendAutoData] + public async void UpdateFileToExistingSendAsync_SendDataNull_ThrowsBadRequest(SutProvider sutProvider, + Send send) { - Id = Guid.NewGuid(), - EmailVerified = true, - MaxStorageGb = 10, - }; + send.Data = null; - var data = new SendFileData + var badRequest = await Assert.ThrowsAsync(() => + sutProvider.Sut.UploadFileToExistingSendAsync(new MemoryStream(), send) + ); + + Assert.Contains("does not have file data", badRequest.Message, StringComparison.InvariantCultureIgnoreCase); + } + + [Theory] + [InlineUserSendAutoData] + public async void UpdateFileToExistingSendAsync_NotFileType_ThrowsBadRequest(SutProvider sutProvider, + Send send) { + var badRequest = await Assert.ThrowsAsync(() => + sutProvider.Sut.UploadFileToExistingSendAsync(new MemoryStream(), send) + ); - }; + Assert.Contains("not a file type send", badRequest.Message, StringComparison.InvariantCultureIgnoreCase); + } - send.UserId = user.Id; - send.Type = SendType.File; - - sutProvider.GetDependency() - .GetByIdAsync(user.Id) - .Returns(user); - - sutProvider.GetDependency() - .CanAccessPremium(user) - .Returns(true); - - sutProvider.GetDependency() - .GetSendFileUploadUrlAsync(send, Arg.Any()) - .Returns(callInfo => throw new Exception("Problem")); - - var utcNow = DateTime.UtcNow; - - var exception = await Assert.ThrowsAsync(() => - sutProvider.Sut.SaveFileSendAsync(send, data, 1 * UserTests.Multiplier) - ); - - Assert.True(send.RevisionDate - utcNow < TimeSpan.FromSeconds(1)); - Assert.Equal("Problem", exception.Message); - - await sutProvider.GetDependency() - .Received(1) - .GetSendFileUploadUrlAsync(send, Arg.Any()); - - await sutProvider.GetDependency() - .Received(1) - .UpsertAsync(send); - - await sutProvider.GetDependency() - .Received(1) - .PushSyncSendUpdateAsync(send); - - await sutProvider.GetDependency() - .Received(1) - .DeleteFileAsync(send, Arg.Any()); - } - - [Theory] - [InlineUserSendAutoData] - public async void UpdateFileToExistingSendAsync_SendNull_ThrowsBadRequest(SutProvider sutProvider) - { - - var badRequest = await Assert.ThrowsAsync(() => - sutProvider.Sut.UploadFileToExistingSendAsync(new MemoryStream(), null) - ); - - Assert.Contains("does not have file data", badRequest.Message, StringComparison.InvariantCultureIgnoreCase); - } - - [Theory] - [InlineUserSendAutoData] - public async void UpdateFileToExistingSendAsync_SendDataNull_ThrowsBadRequest(SutProvider sutProvider, - Send send) - { - send.Data = null; - - var badRequest = await Assert.ThrowsAsync(() => - sutProvider.Sut.UploadFileToExistingSendAsync(new MemoryStream(), send) - ); - - Assert.Contains("does not have file data", badRequest.Message, StringComparison.InvariantCultureIgnoreCase); - } - - [Theory] - [InlineUserSendAutoData] - public async void UpdateFileToExistingSendAsync_NotFileType_ThrowsBadRequest(SutProvider sutProvider, - Send send) - { - var badRequest = await Assert.ThrowsAsync(() => - sutProvider.Sut.UploadFileToExistingSendAsync(new MemoryStream(), send) - ); - - Assert.Contains("not a file type send", badRequest.Message, StringComparison.InvariantCultureIgnoreCase); - } - - [Theory] - [InlineUserSendAutoData] - public async void UpdateFileToExistingSendAsync_Success(SutProvider sutProvider, - Send send) - { - var fileContents = "Test file content"; - - var sendFileData = new SendFileData + [Theory] + [InlineUserSendAutoData] + public async void UpdateFileToExistingSendAsync_Success(SutProvider sutProvider, + Send send) { - Id = "TEST", - Size = fileContents.Length, - Validated = false, - }; + var fileContents = "Test file content"; - send.Type = SendType.File; - send.Data = JsonSerializer.Serialize(sendFileData); + var sendFileData = new SendFileData + { + Id = "TEST", + Size = fileContents.Length, + Validated = false, + }; - sutProvider.GetDependency() - .ValidateFileAsync(send, sendFileData.Id, sendFileData.Size, Arg.Any()) - .Returns((true, sendFileData.Size)); + send.Type = SendType.File; + send.Data = JsonSerializer.Serialize(sendFileData); - await sutProvider.Sut.UploadFileToExistingSendAsync(new MemoryStream(Encoding.UTF8.GetBytes(fileContents)), send); - } + sutProvider.GetDependency() + .ValidateFileAsync(send, sendFileData.Id, sendFileData.Size, Arg.Any()) + .Returns((true, sendFileData.Size)); - [Theory] - [InlineUserSendAutoData] - public async void UpdateFileToExistingSendAsync_InvalidSize(SutProvider sutProvider, - Send send) - { - var fileContents = "Test file content"; + await sutProvider.Sut.UploadFileToExistingSendAsync(new MemoryStream(Encoding.UTF8.GetBytes(fileContents)), send); + } - var sendFileData = new SendFileData + [Theory] + [InlineUserSendAutoData] + public async void UpdateFileToExistingSendAsync_InvalidSize(SutProvider sutProvider, + Send send) { - Id = "TEST", - Size = fileContents.Length, - }; + var fileContents = "Test file content"; - send.Type = SendType.File; - send.Data = JsonSerializer.Serialize(sendFileData); + var sendFileData = new SendFileData + { + Id = "TEST", + Size = fileContents.Length, + }; - sutProvider.GetDependency() - .ValidateFileAsync(send, sendFileData.Id, sendFileData.Size, Arg.Any()) - .Returns((false, sendFileData.Size)); + send.Type = SendType.File; + send.Data = JsonSerializer.Serialize(sendFileData); - var badRequest = await Assert.ThrowsAsync(() => - sutProvider.Sut.UploadFileToExistingSendAsync(new MemoryStream(Encoding.UTF8.GetBytes(fileContents)), send) - ); - } + sutProvider.GetDependency() + .ValidateFileAsync(send, sendFileData.Id, sendFileData.Size, Arg.Any()) + .Returns((false, sendFileData.Size)); - [Theory] - [InlineUserSendAutoData] - public void SendCanBeAccessed_Success(SutProvider sutProvider, Send send) - { - var now = DateTime.UtcNow; - send.MaxAccessCount = 10; - send.AccessCount = 5; - send.ExpirationDate = now.AddYears(1); - send.DeletionDate = now.AddYears(1); - send.Disabled = false; + var badRequest = await Assert.ThrowsAsync(() => + sutProvider.Sut.UploadFileToExistingSendAsync(new MemoryStream(Encoding.UTF8.GetBytes(fileContents)), send) + ); + } - sutProvider.GetDependency>() - .VerifyHashedPassword(Arg.Any(), send.Password, "TEST") - .Returns(PasswordVerificationResult.Success); + [Theory] + [InlineUserSendAutoData] + public void SendCanBeAccessed_Success(SutProvider sutProvider, Send send) + { + var now = DateTime.UtcNow; + send.MaxAccessCount = 10; + send.AccessCount = 5; + send.ExpirationDate = now.AddYears(1); + send.DeletionDate = now.AddYears(1); + send.Disabled = false; - var (grant, passwordRequiredError, passwordInvalidError) - = sutProvider.Sut.SendCanBeAccessed(send, "TEST"); + sutProvider.GetDependency>() + .VerifyHashedPassword(Arg.Any(), send.Password, "TEST") + .Returns(PasswordVerificationResult.Success); - Assert.True(grant); - Assert.False(passwordRequiredError); - Assert.False(passwordInvalidError); - } + var (grant, passwordRequiredError, passwordInvalidError) + = sutProvider.Sut.SendCanBeAccessed(send, "TEST"); - [Theory] - [InlineUserSendAutoData] - public void SendCanBeAccessed_NullMaxAccess_Success(SutProvider sutProvider, - Send send) - { - var now = DateTime.UtcNow; - send.MaxAccessCount = null; - send.AccessCount = 5; - send.ExpirationDate = now.AddYears(1); - send.DeletionDate = now.AddYears(1); - send.Disabled = false; + Assert.True(grant); + Assert.False(passwordRequiredError); + Assert.False(passwordInvalidError); + } - sutProvider.GetDependency>() - .VerifyHashedPassword(Arg.Any(), send.Password, "TEST") - .Returns(PasswordVerificationResult.Success); + [Theory] + [InlineUserSendAutoData] + public void SendCanBeAccessed_NullMaxAccess_Success(SutProvider sutProvider, + Send send) + { + var now = DateTime.UtcNow; + send.MaxAccessCount = null; + send.AccessCount = 5; + send.ExpirationDate = now.AddYears(1); + send.DeletionDate = now.AddYears(1); + send.Disabled = false; - var (grant, passwordRequiredError, passwordInvalidError) - = sutProvider.Sut.SendCanBeAccessed(send, "TEST"); + sutProvider.GetDependency>() + .VerifyHashedPassword(Arg.Any(), send.Password, "TEST") + .Returns(PasswordVerificationResult.Success); - Assert.True(grant); - Assert.False(passwordRequiredError); - Assert.False(passwordInvalidError); - } + var (grant, passwordRequiredError, passwordInvalidError) + = sutProvider.Sut.SendCanBeAccessed(send, "TEST"); - [Theory] - [InlineUserSendAutoData] - public void SendCanBeAccessed_NullSend_DoesNotGrantAccess(SutProvider sutProvider) - { - sutProvider.GetDependency>() - .VerifyHashedPassword(Arg.Any(), "TEST", "TEST") - .Returns(PasswordVerificationResult.Success); + Assert.True(grant); + Assert.False(passwordRequiredError); + Assert.False(passwordInvalidError); + } - var (grant, passwordRequiredError, passwordInvalidError) - = sutProvider.Sut.SendCanBeAccessed(null, "TEST"); + [Theory] + [InlineUserSendAutoData] + public void SendCanBeAccessed_NullSend_DoesNotGrantAccess(SutProvider sutProvider) + { + sutProvider.GetDependency>() + .VerifyHashedPassword(Arg.Any(), "TEST", "TEST") + .Returns(PasswordVerificationResult.Success); - Assert.False(grant); - Assert.False(passwordRequiredError); - Assert.False(passwordInvalidError); - } + var (grant, passwordRequiredError, passwordInvalidError) + = sutProvider.Sut.SendCanBeAccessed(null, "TEST"); - [Theory] - [InlineUserSendAutoData] - public void SendCanBeAccessed_NullPassword_PasswordRequiredErrorReturnsTrue(SutProvider sutProvider, - Send send) - { - var now = DateTime.UtcNow; - send.MaxAccessCount = null; - send.AccessCount = 5; - send.ExpirationDate = now.AddYears(1); - send.DeletionDate = now.AddYears(1); - send.Disabled = false; - send.Password = "HASH"; + Assert.False(grant); + Assert.False(passwordRequiredError); + Assert.False(passwordInvalidError); + } - sutProvider.GetDependency>() - .VerifyHashedPassword(Arg.Any(), "TEST", "TEST") - .Returns(PasswordVerificationResult.Success); + [Theory] + [InlineUserSendAutoData] + public void SendCanBeAccessed_NullPassword_PasswordRequiredErrorReturnsTrue(SutProvider sutProvider, + Send send) + { + var now = DateTime.UtcNow; + send.MaxAccessCount = null; + send.AccessCount = 5; + send.ExpirationDate = now.AddYears(1); + send.DeletionDate = now.AddYears(1); + send.Disabled = false; + send.Password = "HASH"; - var (grant, passwordRequiredError, passwordInvalidError) - = sutProvider.Sut.SendCanBeAccessed(send, null); + sutProvider.GetDependency>() + .VerifyHashedPassword(Arg.Any(), "TEST", "TEST") + .Returns(PasswordVerificationResult.Success); - Assert.False(grant); - Assert.True(passwordRequiredError); - Assert.False(passwordInvalidError); - } + var (grant, passwordRequiredError, passwordInvalidError) + = sutProvider.Sut.SendCanBeAccessed(send, null); - [Theory] - [InlineUserSendAutoData] - public void SendCanBeAccessed_RehashNeeded_RehashesPassword(SutProvider sutProvider, - Send send) - { - var now = DateTime.UtcNow; - send.MaxAccessCount = null; - send.AccessCount = 5; - send.ExpirationDate = now.AddYears(1); - send.DeletionDate = now.AddYears(1); - send.Disabled = false; - send.Password = "TEST"; + Assert.False(grant); + Assert.True(passwordRequiredError); + Assert.False(passwordInvalidError); + } - sutProvider.GetDependency>() - .VerifyHashedPassword(Arg.Any(), "TEST", "TEST") - .Returns(PasswordVerificationResult.SuccessRehashNeeded); + [Theory] + [InlineUserSendAutoData] + public void SendCanBeAccessed_RehashNeeded_RehashesPassword(SutProvider sutProvider, + Send send) + { + var now = DateTime.UtcNow; + send.MaxAccessCount = null; + send.AccessCount = 5; + send.ExpirationDate = now.AddYears(1); + send.DeletionDate = now.AddYears(1); + send.Disabled = false; + send.Password = "TEST"; - var (grant, passwordRequiredError, passwordInvalidError) - = sutProvider.Sut.SendCanBeAccessed(send, "TEST"); + sutProvider.GetDependency>() + .VerifyHashedPassword(Arg.Any(), "TEST", "TEST") + .Returns(PasswordVerificationResult.SuccessRehashNeeded); - sutProvider.GetDependency>() - .Received(1) - .HashPassword(Arg.Any(), "TEST"); + var (grant, passwordRequiredError, passwordInvalidError) + = sutProvider.Sut.SendCanBeAccessed(send, "TEST"); - Assert.True(grant); - Assert.False(passwordRequiredError); - Assert.False(passwordInvalidError); - } + sutProvider.GetDependency>() + .Received(1) + .HashPassword(Arg.Any(), "TEST"); - [Theory] - [InlineUserSendAutoData] - public void SendCanBeAccessed_VerifyFailed_PasswordInvalidReturnsTrue(SutProvider sutProvider, - Send send) - { - var now = DateTime.UtcNow; - send.MaxAccessCount = null; - send.AccessCount = 5; - send.ExpirationDate = now.AddYears(1); - send.DeletionDate = now.AddYears(1); - send.Disabled = false; - send.Password = "TEST"; + Assert.True(grant); + Assert.False(passwordRequiredError); + Assert.False(passwordInvalidError); + } - sutProvider.GetDependency>() - .VerifyHashedPassword(Arg.Any(), "TEST", "TEST") - .Returns(PasswordVerificationResult.Failed); + [Theory] + [InlineUserSendAutoData] + public void SendCanBeAccessed_VerifyFailed_PasswordInvalidReturnsTrue(SutProvider sutProvider, + Send send) + { + var now = DateTime.UtcNow; + send.MaxAccessCount = null; + send.AccessCount = 5; + send.ExpirationDate = now.AddYears(1); + send.DeletionDate = now.AddYears(1); + send.Disabled = false; + send.Password = "TEST"; - var (grant, passwordRequiredError, passwordInvalidError) - = sutProvider.Sut.SendCanBeAccessed(send, "TEST"); + sutProvider.GetDependency>() + .VerifyHashedPassword(Arg.Any(), "TEST", "TEST") + .Returns(PasswordVerificationResult.Failed); - Assert.False(grant); - Assert.False(passwordRequiredError); - Assert.True(passwordInvalidError); + var (grant, passwordRequiredError, passwordInvalidError) + = sutProvider.Sut.SendCanBeAccessed(send, "TEST"); + + Assert.False(grant); + Assert.False(passwordRequiredError); + Assert.True(passwordInvalidError); + } } } diff --git a/test/Core.Test/Services/SsoConfigServiceTests.cs b/test/Core.Test/Services/SsoConfigServiceTests.cs index fa5cb904ad..475a2a1c52 100644 --- a/test/Core.Test/Services/SsoConfigServiceTests.cs +++ b/test/Core.Test/Services/SsoConfigServiceTests.cs @@ -9,308 +9,309 @@ using Bit.Test.Common.AutoFixture.Attributes; using NSubstitute; using Xunit; -namespace Bit.Core.Test.Services; - -public class SsoConfigServiceTests +namespace Bit.Core.Test.Services { - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task SaveAsync_ExistingItem_UpdatesRevisionDateOnly(SutProvider sutProvider, - Organization organization) + public class SsoConfigServiceTests { - var utcNow = DateTime.UtcNow; - - var ssoConfig = new SsoConfig + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task SaveAsync_ExistingItem_UpdatesRevisionDateOnly(SutProvider sutProvider, + Organization organization) { - Id = 1, - Data = "{}", - Enabled = true, - OrganizationId = organization.Id, - CreationDate = utcNow.AddDays(-10), - RevisionDate = utcNow.AddDays(-10), - }; + var utcNow = DateTime.UtcNow; - sutProvider.GetDependency() - .UpsertAsync(ssoConfig).Returns(Task.CompletedTask); - - await sutProvider.Sut.SaveAsync(ssoConfig, organization); - - await sutProvider.GetDependency().Received() - .UpsertAsync(ssoConfig); - - Assert.Equal(utcNow.AddDays(-10), ssoConfig.CreationDate); - Assert.True(ssoConfig.RevisionDate - utcNow < TimeSpan.FromSeconds(1)); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task SaveAsync_NewItem_UpdatesCreationAndRevisionDate(SutProvider sutProvider, - Organization organization) - { - var utcNow = DateTime.UtcNow; - - var ssoConfig = new SsoConfig - { - Id = default, - Data = "{}", - Enabled = true, - OrganizationId = organization.Id, - CreationDate = utcNow.AddDays(-10), - RevisionDate = utcNow.AddDays(-10), - }; - - sutProvider.GetDependency() - .UpsertAsync(ssoConfig).Returns(Task.CompletedTask); - - await sutProvider.Sut.SaveAsync(ssoConfig, organization); - - await sutProvider.GetDependency().Received() - .UpsertAsync(ssoConfig); - - Assert.True(ssoConfig.CreationDate - utcNow < TimeSpan.FromSeconds(1)); - Assert.True(ssoConfig.RevisionDate - utcNow < TimeSpan.FromSeconds(1)); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task SaveAsync_PreventDisablingKeyConnector(SutProvider sutProvider, - Organization organization) - { - var utcNow = DateTime.UtcNow; - - var oldSsoConfig = new SsoConfig - { - Id = 1, - Data = new SsoConfigurationData - { - KeyConnectorEnabled = true, - }.Serialize(), - Enabled = true, - OrganizationId = organization.Id, - CreationDate = utcNow.AddDays(-10), - RevisionDate = utcNow.AddDays(-10), - }; - - var newSsoConfig = new SsoConfig - { - Id = 1, - Data = "{}", - Enabled = true, - OrganizationId = organization.Id, - CreationDate = utcNow.AddDays(-10), - RevisionDate = utcNow, - }; - - var ssoConfigRepository = sutProvider.GetDependency(); - ssoConfigRepository.GetByOrganizationIdAsync(organization.Id).Returns(oldSsoConfig); - ssoConfigRepository.UpsertAsync(newSsoConfig).Returns(Task.CompletedTask); - sutProvider.GetDependency().GetManyDetailsByOrganizationAsync(organization.Id) - .Returns(new[] { new OrganizationUserUserDetails { UsesKeyConnector = true } }); - - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.SaveAsync(newSsoConfig, organization)); - - Assert.Contains("Key Connector cannot be disabled at this moment.", exception.Message); - - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() - .UpsertAsync(default); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task SaveAsync_AllowDisablingKeyConnectorWhenNoUserIsUsingIt( - SutProvider sutProvider, Organization organization) - { - var utcNow = DateTime.UtcNow; - - var oldSsoConfig = new SsoConfig - { - Id = 1, - Data = new SsoConfigurationData - { - KeyConnectorEnabled = true, - }.Serialize(), - Enabled = true, - OrganizationId = organization.Id, - CreationDate = utcNow.AddDays(-10), - RevisionDate = utcNow.AddDays(-10), - }; - - var newSsoConfig = new SsoConfig - { - Id = 1, - Data = "{}", - Enabled = true, - OrganizationId = organization.Id, - CreationDate = utcNow.AddDays(-10), - RevisionDate = utcNow, - }; - - var ssoConfigRepository = sutProvider.GetDependency(); - ssoConfigRepository.GetByOrganizationIdAsync(organization.Id).Returns(oldSsoConfig); - ssoConfigRepository.UpsertAsync(newSsoConfig).Returns(Task.CompletedTask); - sutProvider.GetDependency().GetManyDetailsByOrganizationAsync(organization.Id) - .Returns(new[] { new OrganizationUserUserDetails { UsesKeyConnector = false } }); - - await sutProvider.Sut.SaveAsync(newSsoConfig, organization); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task SaveAsync_KeyConnector_SingleOrgNotEnabled_Throws(SutProvider sutProvider, - Organization organization) - { - var utcNow = DateTime.UtcNow; - - var ssoConfig = new SsoConfig - { - Id = default, - Data = new SsoConfigurationData - { - KeyConnectorEnabled = true, - }.Serialize(), - Enabled = true, - OrganizationId = organization.Id, - CreationDate = utcNow.AddDays(-10), - RevisionDate = utcNow.AddDays(-10), - }; - - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.SaveAsync(ssoConfig, organization)); - - Assert.Contains("Key Connector requires the Single Organization policy to be enabled.", exception.Message); - - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() - .UpsertAsync(default); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task SaveAsync_KeyConnector_SsoPolicyNotEnabled_Throws(SutProvider sutProvider, - Organization organization) - { - var utcNow = DateTime.UtcNow; - - var ssoConfig = new SsoConfig - { - Id = default, - Data = new SsoConfigurationData - { - KeyConnectorEnabled = true, - }.Serialize(), - Enabled = true, - OrganizationId = organization.Id, - CreationDate = utcNow.AddDays(-10), - RevisionDate = utcNow.AddDays(-10), - }; - - sutProvider.GetDependency().GetByOrganizationIdTypeAsync( - Arg.Any(), Enums.PolicyType.SingleOrg).Returns(new Policy - { - Enabled = true - }); - - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.SaveAsync(ssoConfig, organization)); - - Assert.Contains("Key Connector requires the Single Sign-On Authentication policy to be enabled.", exception.Message); - - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() - .UpsertAsync(default); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task SaveAsync_KeyConnector_SsoConfigNotEnabled_Throws(SutProvider sutProvider, - Organization organization) - { - var utcNow = DateTime.UtcNow; - - var ssoConfig = new SsoConfig - { - Id = default, - Data = new SsoConfigurationData - { - KeyConnectorEnabled = true, - }.Serialize(), - Enabled = false, - OrganizationId = organization.Id, - CreationDate = utcNow.AddDays(-10), - RevisionDate = utcNow.AddDays(-10), - }; - - sutProvider.GetDependency().GetByOrganizationIdTypeAsync( - Arg.Any(), Arg.Any()).Returns(new Policy - { - Enabled = true - }); - - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.SaveAsync(ssoConfig, organization)); - - Assert.Contains("You must enable SSO to use Key Connector.", exception.Message); - - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() - .UpsertAsync(default); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task SaveAsync_KeyConnector_KeyConnectorAbilityNotEnabled_Throws(SutProvider sutProvider, - Organization organization) - { - var utcNow = DateTime.UtcNow; - - organization.UseKeyConnector = false; - var ssoConfig = new SsoConfig - { - Id = default, - Data = new SsoConfigurationData - { - KeyConnectorEnabled = true, - }.Serialize(), - Enabled = true, - OrganizationId = organization.Id, - CreationDate = utcNow.AddDays(-10), - RevisionDate = utcNow.AddDays(-10), - }; - - sutProvider.GetDependency().GetByOrganizationIdTypeAsync( - Arg.Any(), Arg.Any()).Returns(new Policy + var ssoConfig = new SsoConfig { + Id = 1, + Data = "{}", Enabled = true, - }); + OrganizationId = organization.Id, + CreationDate = utcNow.AddDays(-10), + RevisionDate = utcNow.AddDays(-10), + }; - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.SaveAsync(ssoConfig, organization)); + sutProvider.GetDependency() + .UpsertAsync(ssoConfig).Returns(Task.CompletedTask); - Assert.Contains("Organization cannot use Key Connector.", exception.Message); + await sutProvider.Sut.SaveAsync(ssoConfig, organization); - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() - .UpsertAsync(default); - } + await sutProvider.GetDependency().Received() + .UpsertAsync(ssoConfig); - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task SaveAsync_KeyConnector_Success(SutProvider sutProvider, - Organization organization) - { - var utcNow = DateTime.UtcNow; + Assert.Equal(utcNow.AddDays(-10), ssoConfig.CreationDate); + Assert.True(ssoConfig.RevisionDate - utcNow < TimeSpan.FromSeconds(1)); + } - organization.UseKeyConnector = true; - var ssoConfig = new SsoConfig + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task SaveAsync_NewItem_UpdatesCreationAndRevisionDate(SutProvider sutProvider, + Organization organization) { - Id = default, - Data = new SsoConfigurationData - { - KeyConnectorEnabled = true, - }.Serialize(), - Enabled = true, - OrganizationId = organization.Id, - CreationDate = utcNow.AddDays(-10), - RevisionDate = utcNow.AddDays(-10), - }; + var utcNow = DateTime.UtcNow; - sutProvider.GetDependency().GetByOrganizationIdTypeAsync( - Arg.Any(), Arg.Any()).Returns(new Policy + var ssoConfig = new SsoConfig { + Id = default, + Data = "{}", Enabled = true, - }); + OrganizationId = organization.Id, + CreationDate = utcNow.AddDays(-10), + RevisionDate = utcNow.AddDays(-10), + }; - await sutProvider.Sut.SaveAsync(ssoConfig, organization); + sutProvider.GetDependency() + .UpsertAsync(ssoConfig).Returns(Task.CompletedTask); - await sutProvider.GetDependency().ReceivedWithAnyArgs() - .UpsertAsync(default); + await sutProvider.Sut.SaveAsync(ssoConfig, organization); + + await sutProvider.GetDependency().Received() + .UpsertAsync(ssoConfig); + + Assert.True(ssoConfig.CreationDate - utcNow < TimeSpan.FromSeconds(1)); + Assert.True(ssoConfig.RevisionDate - utcNow < TimeSpan.FromSeconds(1)); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task SaveAsync_PreventDisablingKeyConnector(SutProvider sutProvider, + Organization organization) + { + var utcNow = DateTime.UtcNow; + + var oldSsoConfig = new SsoConfig + { + Id = 1, + Data = new SsoConfigurationData + { + KeyConnectorEnabled = true, + }.Serialize(), + Enabled = true, + OrganizationId = organization.Id, + CreationDate = utcNow.AddDays(-10), + RevisionDate = utcNow.AddDays(-10), + }; + + var newSsoConfig = new SsoConfig + { + Id = 1, + Data = "{}", + Enabled = true, + OrganizationId = organization.Id, + CreationDate = utcNow.AddDays(-10), + RevisionDate = utcNow, + }; + + var ssoConfigRepository = sutProvider.GetDependency(); + ssoConfigRepository.GetByOrganizationIdAsync(organization.Id).Returns(oldSsoConfig); + ssoConfigRepository.UpsertAsync(newSsoConfig).Returns(Task.CompletedTask); + sutProvider.GetDependency().GetManyDetailsByOrganizationAsync(organization.Id) + .Returns(new[] { new OrganizationUserUserDetails { UsesKeyConnector = true } }); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.SaveAsync(newSsoConfig, organization)); + + Assert.Contains("Key Connector cannot be disabled at this moment.", exception.Message); + + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() + .UpsertAsync(default); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task SaveAsync_AllowDisablingKeyConnectorWhenNoUserIsUsingIt( + SutProvider sutProvider, Organization organization) + { + var utcNow = DateTime.UtcNow; + + var oldSsoConfig = new SsoConfig + { + Id = 1, + Data = new SsoConfigurationData + { + KeyConnectorEnabled = true, + }.Serialize(), + Enabled = true, + OrganizationId = organization.Id, + CreationDate = utcNow.AddDays(-10), + RevisionDate = utcNow.AddDays(-10), + }; + + var newSsoConfig = new SsoConfig + { + Id = 1, + Data = "{}", + Enabled = true, + OrganizationId = organization.Id, + CreationDate = utcNow.AddDays(-10), + RevisionDate = utcNow, + }; + + var ssoConfigRepository = sutProvider.GetDependency(); + ssoConfigRepository.GetByOrganizationIdAsync(organization.Id).Returns(oldSsoConfig); + ssoConfigRepository.UpsertAsync(newSsoConfig).Returns(Task.CompletedTask); + sutProvider.GetDependency().GetManyDetailsByOrganizationAsync(organization.Id) + .Returns(new[] { new OrganizationUserUserDetails { UsesKeyConnector = false } }); + + await sutProvider.Sut.SaveAsync(newSsoConfig, organization); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task SaveAsync_KeyConnector_SingleOrgNotEnabled_Throws(SutProvider sutProvider, + Organization organization) + { + var utcNow = DateTime.UtcNow; + + var ssoConfig = new SsoConfig + { + Id = default, + Data = new SsoConfigurationData + { + KeyConnectorEnabled = true, + }.Serialize(), + Enabled = true, + OrganizationId = organization.Id, + CreationDate = utcNow.AddDays(-10), + RevisionDate = utcNow.AddDays(-10), + }; + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.SaveAsync(ssoConfig, organization)); + + Assert.Contains("Key Connector requires the Single Organization policy to be enabled.", exception.Message); + + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() + .UpsertAsync(default); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task SaveAsync_KeyConnector_SsoPolicyNotEnabled_Throws(SutProvider sutProvider, + Organization organization) + { + var utcNow = DateTime.UtcNow; + + var ssoConfig = new SsoConfig + { + Id = default, + Data = new SsoConfigurationData + { + KeyConnectorEnabled = true, + }.Serialize(), + Enabled = true, + OrganizationId = organization.Id, + CreationDate = utcNow.AddDays(-10), + RevisionDate = utcNow.AddDays(-10), + }; + + sutProvider.GetDependency().GetByOrganizationIdTypeAsync( + Arg.Any(), Enums.PolicyType.SingleOrg).Returns(new Policy + { + Enabled = true + }); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.SaveAsync(ssoConfig, organization)); + + Assert.Contains("Key Connector requires the Single Sign-On Authentication policy to be enabled.", exception.Message); + + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() + .UpsertAsync(default); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task SaveAsync_KeyConnector_SsoConfigNotEnabled_Throws(SutProvider sutProvider, + Organization organization) + { + var utcNow = DateTime.UtcNow; + + var ssoConfig = new SsoConfig + { + Id = default, + Data = new SsoConfigurationData + { + KeyConnectorEnabled = true, + }.Serialize(), + Enabled = false, + OrganizationId = organization.Id, + CreationDate = utcNow.AddDays(-10), + RevisionDate = utcNow.AddDays(-10), + }; + + sutProvider.GetDependency().GetByOrganizationIdTypeAsync( + Arg.Any(), Arg.Any()).Returns(new Policy + { + Enabled = true + }); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.SaveAsync(ssoConfig, organization)); + + Assert.Contains("You must enable SSO to use Key Connector.", exception.Message); + + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() + .UpsertAsync(default); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task SaveAsync_KeyConnector_KeyConnectorAbilityNotEnabled_Throws(SutProvider sutProvider, + Organization organization) + { + var utcNow = DateTime.UtcNow; + + organization.UseKeyConnector = false; + var ssoConfig = new SsoConfig + { + Id = default, + Data = new SsoConfigurationData + { + KeyConnectorEnabled = true, + }.Serialize(), + Enabled = true, + OrganizationId = organization.Id, + CreationDate = utcNow.AddDays(-10), + RevisionDate = utcNow.AddDays(-10), + }; + + sutProvider.GetDependency().GetByOrganizationIdTypeAsync( + Arg.Any(), Arg.Any()).Returns(new Policy + { + Enabled = true, + }); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.SaveAsync(ssoConfig, organization)); + + Assert.Contains("Organization cannot use Key Connector.", exception.Message); + + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() + .UpsertAsync(default); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task SaveAsync_KeyConnector_Success(SutProvider sutProvider, + Organization organization) + { + var utcNow = DateTime.UtcNow; + + organization.UseKeyConnector = true; + var ssoConfig = new SsoConfig + { + Id = default, + Data = new SsoConfigurationData + { + KeyConnectorEnabled = true, + }.Serialize(), + Enabled = true, + OrganizationId = organization.Id, + CreationDate = utcNow.AddDays(-10), + RevisionDate = utcNow.AddDays(-10), + }; + + sutProvider.GetDependency().GetByOrganizationIdTypeAsync( + Arg.Any(), Arg.Any()).Returns(new Policy + { + Enabled = true, + }); + + await sutProvider.Sut.SaveAsync(ssoConfig, organization); + + await sutProvider.GetDependency().ReceivedWithAnyArgs() + .UpsertAsync(default); + } } } diff --git a/test/Core.Test/Services/StripePaymentServiceTests.cs b/test/Core.Test/Services/StripePaymentServiceTests.cs index a14f183d44..0c4ea5c031 100644 --- a/test/Core.Test/Services/StripePaymentServiceTests.cs +++ b/test/Core.Test/Services/StripePaymentServiceTests.cs @@ -12,362 +12,363 @@ using NSubstitute; using Xunit; using PaymentMethodType = Bit.Core.Enums.PaymentMethodType; -namespace Bit.Core.Test.Services; - -public class StripePaymentServiceTests +namespace Bit.Core.Test.Services { - [Theory] - [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PaymentMethodType.BitPay)] - [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PaymentMethodType.BitPay)] - [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PaymentMethodType.Credit)] - [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PaymentMethodType.WireTransfer)] - [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PaymentMethodType.AppleInApp)] - [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PaymentMethodType.GoogleInApp)] - [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PaymentMethodType.Check)] - public async void PurchaseOrganizationAsync_Invalid(PaymentMethodType paymentMethodType, SutProvider sutProvider) + public class StripePaymentServiceTests { - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.PurchaseOrganizationAsync(null, paymentMethodType, null, null, 0, 0, false, null)); - - Assert.Equal("Payment method is not supported at this time.", exception.Message); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async void PurchaseOrganizationAsync_Stripe(SutProvider sutProvider, Organization organization, string paymentToken, TaxInfo taxInfo) - { - var plan = StaticStore.Plans.First(p => p.Type == PlanType.EnterpriseAnnually); - - var stripeAdapter = sutProvider.GetDependency(); - stripeAdapter.CustomerCreateAsync(default).ReturnsForAnyArgs(new Stripe.Customer + [Theory] + [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PaymentMethodType.BitPay)] + [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PaymentMethodType.BitPay)] + [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PaymentMethodType.Credit)] + [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PaymentMethodType.WireTransfer)] + [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PaymentMethodType.AppleInApp)] + [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PaymentMethodType.GoogleInApp)] + [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PaymentMethodType.Check)] + public async void PurchaseOrganizationAsync_Invalid(PaymentMethodType paymentMethodType, SutProvider sutProvider) { - Id = "C-1", - }); - stripeAdapter.SubscriptionCreateAsync(default).ReturnsForAnyArgs(new Stripe.Subscription + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.PurchaseOrganizationAsync(null, paymentMethodType, null, null, 0, 0, false, null)); + + Assert.Equal("Payment method is not supported at this time.", exception.Message); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async void PurchaseOrganizationAsync_Stripe(SutProvider sutProvider, Organization organization, string paymentToken, TaxInfo taxInfo) { - Id = "S-1", - CurrentPeriodEnd = DateTime.Today.AddDays(10), - }); + var plan = StaticStore.Plans.First(p => p.Type == PlanType.EnterpriseAnnually); - var result = await sutProvider.Sut.PurchaseOrganizationAsync(organization, PaymentMethodType.Card, paymentToken, plan, 0, 0, false, taxInfo); - - Assert.Null(result); - Assert.Equal(GatewayType.Stripe, organization.Gateway); - Assert.Equal("C-1", organization.GatewayCustomerId); - Assert.Equal("S-1", organization.GatewaySubscriptionId); - Assert.True(organization.Enabled); - Assert.Equal(DateTime.Today.AddDays(10), organization.ExpirationDate); - - await stripeAdapter.Received().CustomerCreateAsync(Arg.Is(c => - c.Description == organization.BusinessName && - c.Email == organization.BillingEmail && - c.Source == paymentToken && - c.PaymentMethod == null && - !c.Metadata.Any() && - c.InvoiceSettings.DefaultPaymentMethod == null && - c.Address.Country == taxInfo.BillingAddressCountry && - c.Address.PostalCode == taxInfo.BillingAddressPostalCode && - c.Address.Line1 == taxInfo.BillingAddressLine1 && - c.Address.Line2 == taxInfo.BillingAddressLine2 && - c.Address.City == taxInfo.BillingAddressCity && - c.Address.State == taxInfo.BillingAddressState && - c.TaxIdData == null - )); - - await stripeAdapter.Received().SubscriptionCreateAsync(Arg.Is(s => - s.Customer == "C-1" && - s.Expand[0] == "latest_invoice.payment_intent" && - s.Metadata[organization.GatewayIdField()] == organization.Id.ToString() && - s.Items.Count == 0 - )); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async void PurchaseOrganizationAsync_Stripe_PM(SutProvider sutProvider, Organization organization, string paymentToken, TaxInfo taxInfo) - { - var plan = StaticStore.Plans.First(p => p.Type == PlanType.EnterpriseAnnually); - paymentToken = "pm_" + paymentToken; - - var stripeAdapter = sutProvider.GetDependency(); - stripeAdapter.CustomerCreateAsync(default).ReturnsForAnyArgs(new Stripe.Customer - { - Id = "C-1", - }); - stripeAdapter.SubscriptionCreateAsync(default).ReturnsForAnyArgs(new Stripe.Subscription - { - Id = "S-1", - CurrentPeriodEnd = DateTime.Today.AddDays(10), - }); - - var result = await sutProvider.Sut.PurchaseOrganizationAsync(organization, PaymentMethodType.Card, paymentToken, plan, 0, 0, false, taxInfo); - - Assert.Null(result); - Assert.Equal(GatewayType.Stripe, organization.Gateway); - Assert.Equal("C-1", organization.GatewayCustomerId); - Assert.Equal("S-1", organization.GatewaySubscriptionId); - Assert.True(organization.Enabled); - Assert.Equal(DateTime.Today.AddDays(10), organization.ExpirationDate); - - await stripeAdapter.Received().CustomerCreateAsync(Arg.Is(c => - c.Description == organization.BusinessName && - c.Email == organization.BillingEmail && - c.Source == null && - c.PaymentMethod == paymentToken && - !c.Metadata.Any() && - c.InvoiceSettings.DefaultPaymentMethod == paymentToken && - c.Address.Country == taxInfo.BillingAddressCountry && - c.Address.PostalCode == taxInfo.BillingAddressPostalCode && - c.Address.Line1 == taxInfo.BillingAddressLine1 && - c.Address.Line2 == taxInfo.BillingAddressLine2 && - c.Address.City == taxInfo.BillingAddressCity && - c.Address.State == taxInfo.BillingAddressState && - c.TaxIdData == null - )); - - await stripeAdapter.Received().SubscriptionCreateAsync(Arg.Is(s => - s.Customer == "C-1" && - s.Expand[0] == "latest_invoice.payment_intent" && - s.Metadata[organization.GatewayIdField()] == organization.Id.ToString() && - s.Items.Count == 0 - )); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async void PurchaseOrganizationAsync_Stripe_TaxRate(SutProvider sutProvider, Organization organization, string paymentToken, TaxInfo taxInfo) - { - var plan = StaticStore.Plans.First(p => p.Type == PlanType.EnterpriseAnnually); - - var stripeAdapter = sutProvider.GetDependency(); - stripeAdapter.CustomerCreateAsync(default).ReturnsForAnyArgs(new Stripe.Customer - { - Id = "C-1", - }); - stripeAdapter.SubscriptionCreateAsync(default).ReturnsForAnyArgs(new Stripe.Subscription - { - Id = "S-1", - CurrentPeriodEnd = DateTime.Today.AddDays(10), - }); - sutProvider.GetDependency().GetByLocationAsync(Arg.Is(t => - t.Country == taxInfo.BillingAddressCountry && t.PostalCode == taxInfo.BillingAddressPostalCode)) - .Returns(new List { new() { Id = "T-1" } }); - - var result = await sutProvider.Sut.PurchaseOrganizationAsync(organization, PaymentMethodType.Card, paymentToken, plan, 0, 0, false, taxInfo); - - Assert.Null(result); - - await stripeAdapter.Received().SubscriptionCreateAsync(Arg.Is(s => - s.DefaultTaxRates.Count == 1 && - s.DefaultTaxRates[0] == "T-1" - )); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async void PurchaseOrganizationAsync_Stripe_Declined(SutProvider sutProvider, Organization organization, string paymentToken, TaxInfo taxInfo) - { - var plan = StaticStore.Plans.First(p => p.Type == PlanType.EnterpriseAnnually); - paymentToken = "pm_" + paymentToken; - - var stripeAdapter = sutProvider.GetDependency(); - stripeAdapter.CustomerCreateAsync(default).ReturnsForAnyArgs(new Stripe.Customer - { - Id = "C-1", - }); - stripeAdapter.SubscriptionCreateAsync(default).ReturnsForAnyArgs(new Stripe.Subscription - { - Id = "S-1", - CurrentPeriodEnd = DateTime.Today.AddDays(10), - Status = "incomplete", - LatestInvoice = new Stripe.Invoice + var stripeAdapter = sutProvider.GetDependency(); + stripeAdapter.CustomerCreateAsync(default).ReturnsForAnyArgs(new Stripe.Customer { - PaymentIntent = new Stripe.PaymentIntent + Id = "C-1", + }); + stripeAdapter.SubscriptionCreateAsync(default).ReturnsForAnyArgs(new Stripe.Subscription + { + Id = "S-1", + CurrentPeriodEnd = DateTime.Today.AddDays(10), + }); + + var result = await sutProvider.Sut.PurchaseOrganizationAsync(organization, PaymentMethodType.Card, paymentToken, plan, 0, 0, false, taxInfo); + + Assert.Null(result); + Assert.Equal(GatewayType.Stripe, organization.Gateway); + Assert.Equal("C-1", organization.GatewayCustomerId); + Assert.Equal("S-1", organization.GatewaySubscriptionId); + Assert.True(organization.Enabled); + Assert.Equal(DateTime.Today.AddDays(10), organization.ExpirationDate); + + await stripeAdapter.Received().CustomerCreateAsync(Arg.Is(c => + c.Description == organization.BusinessName && + c.Email == organization.BillingEmail && + c.Source == paymentToken && + c.PaymentMethod == null && + !c.Metadata.Any() && + c.InvoiceSettings.DefaultPaymentMethod == null && + c.Address.Country == taxInfo.BillingAddressCountry && + c.Address.PostalCode == taxInfo.BillingAddressPostalCode && + c.Address.Line1 == taxInfo.BillingAddressLine1 && + c.Address.Line2 == taxInfo.BillingAddressLine2 && + c.Address.City == taxInfo.BillingAddressCity && + c.Address.State == taxInfo.BillingAddressState && + c.TaxIdData == null + )); + + await stripeAdapter.Received().SubscriptionCreateAsync(Arg.Is(s => + s.Customer == "C-1" && + s.Expand[0] == "latest_invoice.payment_intent" && + s.Metadata[organization.GatewayIdField()] == organization.Id.ToString() && + s.Items.Count == 0 + )); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async void PurchaseOrganizationAsync_Stripe_PM(SutProvider sutProvider, Organization organization, string paymentToken, TaxInfo taxInfo) + { + var plan = StaticStore.Plans.First(p => p.Type == PlanType.EnterpriseAnnually); + paymentToken = "pm_" + paymentToken; + + var stripeAdapter = sutProvider.GetDependency(); + stripeAdapter.CustomerCreateAsync(default).ReturnsForAnyArgs(new Stripe.Customer + { + Id = "C-1", + }); + stripeAdapter.SubscriptionCreateAsync(default).ReturnsForAnyArgs(new Stripe.Subscription + { + Id = "S-1", + CurrentPeriodEnd = DateTime.Today.AddDays(10), + }); + + var result = await sutProvider.Sut.PurchaseOrganizationAsync(organization, PaymentMethodType.Card, paymentToken, plan, 0, 0, false, taxInfo); + + Assert.Null(result); + Assert.Equal(GatewayType.Stripe, organization.Gateway); + Assert.Equal("C-1", organization.GatewayCustomerId); + Assert.Equal("S-1", organization.GatewaySubscriptionId); + Assert.True(organization.Enabled); + Assert.Equal(DateTime.Today.AddDays(10), organization.ExpirationDate); + + await stripeAdapter.Received().CustomerCreateAsync(Arg.Is(c => + c.Description == organization.BusinessName && + c.Email == organization.BillingEmail && + c.Source == null && + c.PaymentMethod == paymentToken && + !c.Metadata.Any() && + c.InvoiceSettings.DefaultPaymentMethod == paymentToken && + c.Address.Country == taxInfo.BillingAddressCountry && + c.Address.PostalCode == taxInfo.BillingAddressPostalCode && + c.Address.Line1 == taxInfo.BillingAddressLine1 && + c.Address.Line2 == taxInfo.BillingAddressLine2 && + c.Address.City == taxInfo.BillingAddressCity && + c.Address.State == taxInfo.BillingAddressState && + c.TaxIdData == null + )); + + await stripeAdapter.Received().SubscriptionCreateAsync(Arg.Is(s => + s.Customer == "C-1" && + s.Expand[0] == "latest_invoice.payment_intent" && + s.Metadata[organization.GatewayIdField()] == organization.Id.ToString() && + s.Items.Count == 0 + )); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async void PurchaseOrganizationAsync_Stripe_TaxRate(SutProvider sutProvider, Organization organization, string paymentToken, TaxInfo taxInfo) + { + var plan = StaticStore.Plans.First(p => p.Type == PlanType.EnterpriseAnnually); + + var stripeAdapter = sutProvider.GetDependency(); + stripeAdapter.CustomerCreateAsync(default).ReturnsForAnyArgs(new Stripe.Customer + { + Id = "C-1", + }); + stripeAdapter.SubscriptionCreateAsync(default).ReturnsForAnyArgs(new Stripe.Subscription + { + Id = "S-1", + CurrentPeriodEnd = DateTime.Today.AddDays(10), + }); + sutProvider.GetDependency().GetByLocationAsync(Arg.Is(t => + t.Country == taxInfo.BillingAddressCountry && t.PostalCode == taxInfo.BillingAddressPostalCode)) + .Returns(new List { new() { Id = "T-1" } }); + + var result = await sutProvider.Sut.PurchaseOrganizationAsync(organization, PaymentMethodType.Card, paymentToken, plan, 0, 0, false, taxInfo); + + Assert.Null(result); + + await stripeAdapter.Received().SubscriptionCreateAsync(Arg.Is(s => + s.DefaultTaxRates.Count == 1 && + s.DefaultTaxRates[0] == "T-1" + )); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async void PurchaseOrganizationAsync_Stripe_Declined(SutProvider sutProvider, Organization organization, string paymentToken, TaxInfo taxInfo) + { + var plan = StaticStore.Plans.First(p => p.Type == PlanType.EnterpriseAnnually); + paymentToken = "pm_" + paymentToken; + + var stripeAdapter = sutProvider.GetDependency(); + stripeAdapter.CustomerCreateAsync(default).ReturnsForAnyArgs(new Stripe.Customer + { + Id = "C-1", + }); + stripeAdapter.SubscriptionCreateAsync(default).ReturnsForAnyArgs(new Stripe.Subscription + { + Id = "S-1", + CurrentPeriodEnd = DateTime.Today.AddDays(10), + Status = "incomplete", + LatestInvoice = new Stripe.Invoice { - Status = "requires_payment_method", + PaymentIntent = new Stripe.PaymentIntent + { + Status = "requires_payment_method", + }, }, - }, - }); + }); - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.PurchaseOrganizationAsync(organization, PaymentMethodType.Card, paymentToken, plan, 0, 0, false, taxInfo)); + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.PurchaseOrganizationAsync(organization, PaymentMethodType.Card, paymentToken, plan, 0, 0, false, taxInfo)); - Assert.Equal("Payment method was declined.", exception.Message); + Assert.Equal("Payment method was declined.", exception.Message); - await stripeAdapter.Received(1).CustomerDeleteAsync("C-1"); - } + await stripeAdapter.Received(1).CustomerDeleteAsync("C-1"); + } - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async void PurchaseOrganizationAsync_Stripe_RequiresAction(SutProvider sutProvider, Organization organization, string paymentToken, TaxInfo taxInfo) - { - var plan = StaticStore.Plans.First(p => p.Type == PlanType.EnterpriseAnnually); - - var stripeAdapter = sutProvider.GetDependency(); - stripeAdapter.CustomerCreateAsync(default).ReturnsForAnyArgs(new Stripe.Customer + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async void PurchaseOrganizationAsync_Stripe_RequiresAction(SutProvider sutProvider, Organization organization, string paymentToken, TaxInfo taxInfo) { - Id = "C-1", - }); - stripeAdapter.SubscriptionCreateAsync(default).ReturnsForAnyArgs(new Stripe.Subscription - { - Id = "S-1", - CurrentPeriodEnd = DateTime.Today.AddDays(10), - Status = "incomplete", - LatestInvoice = new Stripe.Invoice + var plan = StaticStore.Plans.First(p => p.Type == PlanType.EnterpriseAnnually); + + var stripeAdapter = sutProvider.GetDependency(); + stripeAdapter.CustomerCreateAsync(default).ReturnsForAnyArgs(new Stripe.Customer { - PaymentIntent = new Stripe.PaymentIntent + Id = "C-1", + }); + stripeAdapter.SubscriptionCreateAsync(default).ReturnsForAnyArgs(new Stripe.Subscription + { + Id = "S-1", + CurrentPeriodEnd = DateTime.Today.AddDays(10), + Status = "incomplete", + LatestInvoice = new Stripe.Invoice { - Status = "requires_action", - ClientSecret = "clientSecret", + PaymentIntent = new Stripe.PaymentIntent + { + Status = "requires_action", + ClientSecret = "clientSecret", + }, }, - }, - }); + }); - var result = await sutProvider.Sut.PurchaseOrganizationAsync(organization, PaymentMethodType.Card, paymentToken, plan, 0, 0, false, taxInfo); + var result = await sutProvider.Sut.PurchaseOrganizationAsync(organization, PaymentMethodType.Card, paymentToken, plan, 0, 0, false, taxInfo); - Assert.Equal("clientSecret", result); - Assert.False(organization.Enabled); - } + Assert.Equal("clientSecret", result); + Assert.False(organization.Enabled); + } - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async void PurchaseOrganizationAsync_Paypal(SutProvider sutProvider, Organization organization, string paymentToken, TaxInfo taxInfo) - { - var plan = StaticStore.Plans.First(p => p.Type == PlanType.EnterpriseAnnually); - - var stripeAdapter = sutProvider.GetDependency(); - stripeAdapter.CustomerCreateAsync(default).ReturnsForAnyArgs(new Stripe.Customer + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async void PurchaseOrganizationAsync_Paypal(SutProvider sutProvider, Organization organization, string paymentToken, TaxInfo taxInfo) { - Id = "C-1", - }); - stripeAdapter.SubscriptionCreateAsync(default).ReturnsForAnyArgs(new Stripe.Subscription - { - Id = "S-1", - CurrentPeriodEnd = DateTime.Today.AddDays(10), - }); + var plan = StaticStore.Plans.First(p => p.Type == PlanType.EnterpriseAnnually); - var customer = Substitute.For(); - customer.Id.ReturnsForAnyArgs("Braintree-Id"); - customer.PaymentMethods.ReturnsForAnyArgs(new[] { Substitute.For() }); - var customerResult = Substitute.For>(); - customerResult.IsSuccess().Returns(true); - customerResult.Target.ReturnsForAnyArgs(customer); - - var braintreeGateway = sutProvider.GetDependency(); - braintreeGateway.Customer.CreateAsync(default).ReturnsForAnyArgs(customerResult); - - var result = await sutProvider.Sut.PurchaseOrganizationAsync(organization, PaymentMethodType.PayPal, paymentToken, plan, 0, 0, false, taxInfo); - - Assert.Null(result); - Assert.Equal(GatewayType.Stripe, organization.Gateway); - Assert.Equal("C-1", organization.GatewayCustomerId); - Assert.Equal("S-1", organization.GatewaySubscriptionId); - Assert.True(organization.Enabled); - Assert.Equal(DateTime.Today.AddDays(10), organization.ExpirationDate); - - await stripeAdapter.Received().CustomerCreateAsync(Arg.Is(c => - c.Description == organization.BusinessName && - c.Email == organization.BillingEmail && - c.PaymentMethod == null && - c.Metadata.Count == 1 && - c.Metadata["btCustomerId"] == "Braintree-Id" && - c.InvoiceSettings.DefaultPaymentMethod == null && - c.Address.Country == taxInfo.BillingAddressCountry && - c.Address.PostalCode == taxInfo.BillingAddressPostalCode && - c.Address.Line1 == taxInfo.BillingAddressLine1 && - c.Address.Line2 == taxInfo.BillingAddressLine2 && - c.Address.City == taxInfo.BillingAddressCity && - c.Address.State == taxInfo.BillingAddressState && - c.TaxIdData == null - )); - - await stripeAdapter.Received().SubscriptionCreateAsync(Arg.Is(s => - s.Customer == "C-1" && - s.Expand[0] == "latest_invoice.payment_intent" && - s.Metadata[organization.GatewayIdField()] == organization.Id.ToString() && - s.Items.Count == 0 - )); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async void PurchaseOrganizationAsync_Paypal_FailedCreate(SutProvider sutProvider, Organization organization, string paymentToken, TaxInfo taxInfo) - { - var plan = StaticStore.Plans.First(p => p.Type == PlanType.EnterpriseAnnually); - - var customerResult = Substitute.For>(); - customerResult.IsSuccess().Returns(false); - - var braintreeGateway = sutProvider.GetDependency(); - braintreeGateway.Customer.CreateAsync(default).ReturnsForAnyArgs(customerResult); - - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.PurchaseOrganizationAsync(organization, PaymentMethodType.PayPal, paymentToken, plan, 0, 0, false, taxInfo)); - - Assert.Equal("Failed to create PayPal customer record.", exception.Message); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async void PurchaseOrganizationAsync_PayPal_Declined(SutProvider sutProvider, Organization organization, string paymentToken, TaxInfo taxInfo) - { - var plan = StaticStore.Plans.First(p => p.Type == PlanType.EnterpriseAnnually); - paymentToken = "pm_" + paymentToken; - - var stripeAdapter = sutProvider.GetDependency(); - stripeAdapter.CustomerCreateAsync(default).ReturnsForAnyArgs(new Stripe.Customer - { - Id = "C-1", - }); - stripeAdapter.SubscriptionCreateAsync(default).ReturnsForAnyArgs(new Stripe.Subscription - { - Id = "S-1", - CurrentPeriodEnd = DateTime.Today.AddDays(10), - Status = "incomplete", - LatestInvoice = new Stripe.Invoice + var stripeAdapter = sutProvider.GetDependency(); + stripeAdapter.CustomerCreateAsync(default).ReturnsForAnyArgs(new Stripe.Customer { - PaymentIntent = new Stripe.PaymentIntent + Id = "C-1", + }); + stripeAdapter.SubscriptionCreateAsync(default).ReturnsForAnyArgs(new Stripe.Subscription + { + Id = "S-1", + CurrentPeriodEnd = DateTime.Today.AddDays(10), + }); + + var customer = Substitute.For(); + customer.Id.ReturnsForAnyArgs("Braintree-Id"); + customer.PaymentMethods.ReturnsForAnyArgs(new[] { Substitute.For() }); + var customerResult = Substitute.For>(); + customerResult.IsSuccess().Returns(true); + customerResult.Target.ReturnsForAnyArgs(customer); + + var braintreeGateway = sutProvider.GetDependency(); + braintreeGateway.Customer.CreateAsync(default).ReturnsForAnyArgs(customerResult); + + var result = await sutProvider.Sut.PurchaseOrganizationAsync(organization, PaymentMethodType.PayPal, paymentToken, plan, 0, 0, false, taxInfo); + + Assert.Null(result); + Assert.Equal(GatewayType.Stripe, organization.Gateway); + Assert.Equal("C-1", organization.GatewayCustomerId); + Assert.Equal("S-1", organization.GatewaySubscriptionId); + Assert.True(organization.Enabled); + Assert.Equal(DateTime.Today.AddDays(10), organization.ExpirationDate); + + await stripeAdapter.Received().CustomerCreateAsync(Arg.Is(c => + c.Description == organization.BusinessName && + c.Email == organization.BillingEmail && + c.PaymentMethod == null && + c.Metadata.Count == 1 && + c.Metadata["btCustomerId"] == "Braintree-Id" && + c.InvoiceSettings.DefaultPaymentMethod == null && + c.Address.Country == taxInfo.BillingAddressCountry && + c.Address.PostalCode == taxInfo.BillingAddressPostalCode && + c.Address.Line1 == taxInfo.BillingAddressLine1 && + c.Address.Line2 == taxInfo.BillingAddressLine2 && + c.Address.City == taxInfo.BillingAddressCity && + c.Address.State == taxInfo.BillingAddressState && + c.TaxIdData == null + )); + + await stripeAdapter.Received().SubscriptionCreateAsync(Arg.Is(s => + s.Customer == "C-1" && + s.Expand[0] == "latest_invoice.payment_intent" && + s.Metadata[organization.GatewayIdField()] == organization.Id.ToString() && + s.Items.Count == 0 + )); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async void PurchaseOrganizationAsync_Paypal_FailedCreate(SutProvider sutProvider, Organization organization, string paymentToken, TaxInfo taxInfo) + { + var plan = StaticStore.Plans.First(p => p.Type == PlanType.EnterpriseAnnually); + + var customerResult = Substitute.For>(); + customerResult.IsSuccess().Returns(false); + + var braintreeGateway = sutProvider.GetDependency(); + braintreeGateway.Customer.CreateAsync(default).ReturnsForAnyArgs(customerResult); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.PurchaseOrganizationAsync(organization, PaymentMethodType.PayPal, paymentToken, plan, 0, 0, false, taxInfo)); + + Assert.Equal("Failed to create PayPal customer record.", exception.Message); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async void PurchaseOrganizationAsync_PayPal_Declined(SutProvider sutProvider, Organization organization, string paymentToken, TaxInfo taxInfo) + { + var plan = StaticStore.Plans.First(p => p.Type == PlanType.EnterpriseAnnually); + paymentToken = "pm_" + paymentToken; + + var stripeAdapter = sutProvider.GetDependency(); + stripeAdapter.CustomerCreateAsync(default).ReturnsForAnyArgs(new Stripe.Customer + { + Id = "C-1", + }); + stripeAdapter.SubscriptionCreateAsync(default).ReturnsForAnyArgs(new Stripe.Subscription + { + Id = "S-1", + CurrentPeriodEnd = DateTime.Today.AddDays(10), + Status = "incomplete", + LatestInvoice = new Stripe.Invoice { - Status = "requires_payment_method", + PaymentIntent = new Stripe.PaymentIntent + { + Status = "requires_payment_method", + }, }, - }, - }); + }); - var customer = Substitute.For(); - customer.Id.ReturnsForAnyArgs("Braintree-Id"); - customer.PaymentMethods.ReturnsForAnyArgs(new[] { Substitute.For() }); - var customerResult = Substitute.For>(); - customerResult.IsSuccess().Returns(true); - customerResult.Target.ReturnsForAnyArgs(customer); + var customer = Substitute.For(); + customer.Id.ReturnsForAnyArgs("Braintree-Id"); + customer.PaymentMethods.ReturnsForAnyArgs(new[] { Substitute.For() }); + var customerResult = Substitute.For>(); + customerResult.IsSuccess().Returns(true); + customerResult.Target.ReturnsForAnyArgs(customer); - var braintreeGateway = sutProvider.GetDependency(); - braintreeGateway.Customer.CreateAsync(default).ReturnsForAnyArgs(customerResult); + var braintreeGateway = sutProvider.GetDependency(); + braintreeGateway.Customer.CreateAsync(default).ReturnsForAnyArgs(customerResult); - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.PurchaseOrganizationAsync(organization, PaymentMethodType.PayPal, paymentToken, plan, 0, 0, false, taxInfo)); + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.PurchaseOrganizationAsync(organization, PaymentMethodType.PayPal, paymentToken, plan, 0, 0, false, taxInfo)); - Assert.Equal("Payment method was declined.", exception.Message); + Assert.Equal("Payment method was declined.", exception.Message); - await stripeAdapter.Received(1).CustomerDeleteAsync("C-1"); - await braintreeGateway.Customer.Received(1).DeleteAsync("Braintree-Id"); - } + await stripeAdapter.Received(1).CustomerDeleteAsync("C-1"); + await braintreeGateway.Customer.Received(1).DeleteAsync("Braintree-Id"); + } - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async void UpgradeFreeOrganizationAsync_Success(SutProvider sutProvider, - Organization organization, TaxInfo taxInfo) - { - organization.GatewaySubscriptionId = null; - var stripeAdapter = sutProvider.GetDependency(); - stripeAdapter.CustomerGetAsync(default).ReturnsForAnyArgs(new Stripe.Customer + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async void UpgradeFreeOrganizationAsync_Success(SutProvider sutProvider, + Organization organization, TaxInfo taxInfo) { - Id = "C-1", - Metadata = new Dictionary + organization.GatewaySubscriptionId = null; + var stripeAdapter = sutProvider.GetDependency(); + stripeAdapter.CustomerGetAsync(default).ReturnsForAnyArgs(new Stripe.Customer { - { "btCustomerId", "B-123" }, - } - }); - stripeAdapter.InvoiceUpcomingAsync(default).ReturnsForAnyArgs(new Stripe.Invoice - { - PaymentIntent = new Stripe.PaymentIntent { Status = "requires_payment_method", }, - AmountDue = 0 - }); - stripeAdapter.SubscriptionCreateAsync(default).ReturnsForAnyArgs(new Stripe.Subscription { }); + Id = "C-1", + Metadata = new Dictionary + { + { "btCustomerId", "B-123" }, + } + }); + stripeAdapter.InvoiceUpcomingAsync(default).ReturnsForAnyArgs(new Stripe.Invoice + { + PaymentIntent = new Stripe.PaymentIntent { Status = "requires_payment_method", }, + AmountDue = 0 + }); + stripeAdapter.SubscriptionCreateAsync(default).ReturnsForAnyArgs(new Stripe.Subscription { }); - var plan = StaticStore.Plans.First(p => p.Type == PlanType.EnterpriseAnnually); - var result = await sutProvider.Sut.UpgradeFreeOrganizationAsync(organization, plan, 0, 0, false, taxInfo); + var plan = StaticStore.Plans.First(p => p.Type == PlanType.EnterpriseAnnually); + var result = await sutProvider.Sut.UpgradeFreeOrganizationAsync(organization, plan, 0, 0, false, taxInfo); - Assert.Null(result); + Assert.Null(result); + } } } diff --git a/test/Core.Test/Services/UserServiceTests.cs b/test/Core.Test/Services/UserServiceTests.cs index 5e82e4c40a..10a4beac54 100644 --- a/test/Core.Test/Services/UserServiceTests.cs +++ b/test/Core.Test/Services/UserServiceTests.cs @@ -14,375 +14,376 @@ using NSubstitute; using NSubstitute.ReceivedExtensions; using Xunit; -namespace Bit.Core.Test.Services; - -public class UserServiceTests +namespace Bit.Core.Test.Services { - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task UpdateLicenseAsync_Success(SutProvider sutProvider, - User user, UserLicense userLicense) + public class UserServiceTests { - using var tempDir = new TempDirectory(); - - var now = DateTime.UtcNow; - userLicense.Issued = now.AddDays(-10); - userLicense.Expires = now.AddDays(10); - userLicense.Version = 1; - userLicense.Premium = true; - - user.EmailVerified = true; - user.Email = userLicense.Email; - - sutProvider.GetDependency().SelfHosted = true; - sutProvider.GetDependency().LicenseDirectory = tempDir.Directory; - sutProvider.GetDependency() - .VerifyLicense(userLicense) - .Returns(true); - - await sutProvider.Sut.UpdateLicenseAsync(user, userLicense); - - var filePath = Path.Combine(tempDir.Directory, "user", $"{user.Id}.json"); - Assert.True(File.Exists(filePath)); - var document = JsonDocument.Parse(File.OpenRead(filePath)); - var root = document.RootElement; - Assert.Equal(JsonValueKind.Object, root.ValueKind); - // Sort of a lazy way to test that it is indented but not sure of a better way - Assert.Contains('\n', root.GetRawText()); - AssertHelper.AssertJsonProperty(root, "LicenseKey", JsonValueKind.String); - AssertHelper.AssertJsonProperty(root, "Id", JsonValueKind.String); - AssertHelper.AssertJsonProperty(root, "Premium", JsonValueKind.True); - var versionProp = AssertHelper.AssertJsonProperty(root, "Version", JsonValueKind.Number); - Assert.Equal(1, versionProp.GetInt32()); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task SendTwoFactorEmailAsync_Success(SutProvider sutProvider, User user) - { - var email = user.Email.ToLowerInvariant(); - var token = "thisisatokentocompare"; - - var userTwoFactorTokenProvider = Substitute.For>(); - userTwoFactorTokenProvider - .CanGenerateTwoFactorTokenAsync(Arg.Any>(), user) - .Returns(Task.FromResult(true)); - userTwoFactorTokenProvider - .GenerateAsync("2faEmail:" + email, Arg.Any>(), user) - .Returns(Task.FromResult(token)); - - sutProvider.Sut.RegisterTokenProvider("Email", userTwoFactorTokenProvider); - - user.SetTwoFactorProviders(new Dictionary + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task UpdateLicenseAsync_Success(SutProvider sutProvider, + User user, UserLicense userLicense) { - [TwoFactorProviderType.Email] = new TwoFactorProvider - { - MetaData = new Dictionary { ["Email"] = email }, - Enabled = true - } - }); - await sutProvider.Sut.SendTwoFactorEmailAsync(user); + using var tempDir = new TempDirectory(); - await sutProvider.GetDependency() - .Received(1) - .SendTwoFactorEmailAsync(email, token); - } + var now = DateTime.UtcNow; + userLicense.Issued = now.AddDays(-10); + userLicense.Expires = now.AddDays(10); + userLicense.Version = 1; + userLicense.Premium = true; - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task SendTwoFactorEmailBecauseNewDeviceLoginAsync_Success(SutProvider sutProvider, User user) - { - var email = user.Email.ToLowerInvariant(); - var token = "thisisatokentocompare"; + user.EmailVerified = true; + user.Email = userLicense.Email; - var userTwoFactorTokenProvider = Substitute.For>(); - userTwoFactorTokenProvider - .CanGenerateTwoFactorTokenAsync(Arg.Any>(), user) - .Returns(Task.FromResult(true)); - userTwoFactorTokenProvider - .GenerateAsync("2faEmail:" + email, Arg.Any>(), user) - .Returns(Task.FromResult(token)); + sutProvider.GetDependency().SelfHosted = true; + sutProvider.GetDependency().LicenseDirectory = tempDir.Directory; + sutProvider.GetDependency() + .VerifyLicense(userLicense) + .Returns(true); - sutProvider.Sut.RegisterTokenProvider("Email", userTwoFactorTokenProvider); + await sutProvider.Sut.UpdateLicenseAsync(user, userLicense); - user.SetTwoFactorProviders(new Dictionary + var filePath = Path.Combine(tempDir.Directory, "user", $"{user.Id}.json"); + Assert.True(File.Exists(filePath)); + var document = JsonDocument.Parse(File.OpenRead(filePath)); + var root = document.RootElement; + Assert.Equal(JsonValueKind.Object, root.ValueKind); + // Sort of a lazy way to test that it is indented but not sure of a better way + Assert.Contains('\n', root.GetRawText()); + AssertHelper.AssertJsonProperty(root, "LicenseKey", JsonValueKind.String); + AssertHelper.AssertJsonProperty(root, "Id", JsonValueKind.String); + AssertHelper.AssertJsonProperty(root, "Premium", JsonValueKind.True); + var versionProp = AssertHelper.AssertJsonProperty(root, "Version", JsonValueKind.Number); + Assert.Equal(1, versionProp.GetInt32()); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task SendTwoFactorEmailAsync_Success(SutProvider sutProvider, User user) { - [TwoFactorProviderType.Email] = new TwoFactorProvider + var email = user.Email.ToLowerInvariant(); + var token = "thisisatokentocompare"; + + var userTwoFactorTokenProvider = Substitute.For>(); + userTwoFactorTokenProvider + .CanGenerateTwoFactorTokenAsync(Arg.Any>(), user) + .Returns(Task.FromResult(true)); + userTwoFactorTokenProvider + .GenerateAsync("2faEmail:" + email, Arg.Any>(), user) + .Returns(Task.FromResult(token)); + + sutProvider.Sut.RegisterTokenProvider("Email", userTwoFactorTokenProvider); + + user.SetTwoFactorProviders(new Dictionary { - MetaData = new Dictionary { ["Email"] = email }, - Enabled = true - } - }); - await sutProvider.Sut.SendTwoFactorEmailAsync(user, true); + [TwoFactorProviderType.Email] = new TwoFactorProvider + { + MetaData = new Dictionary { ["Email"] = email }, + Enabled = true + } + }); + await sutProvider.Sut.SendTwoFactorEmailAsync(user); - await sutProvider.GetDependency() - .Received(1) - .SendNewDeviceLoginTwoFactorEmailAsync(email, token); - } + await sutProvider.GetDependency() + .Received(1) + .SendTwoFactorEmailAsync(email, token); + } - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task SendTwoFactorEmailAsync_ExceptionBecauseNoProviderOnUser(SutProvider sutProvider, User user) - { - user.TwoFactorProviders = null; - - await Assert.ThrowsAsync("No email.", () => sutProvider.Sut.SendTwoFactorEmailAsync(user)); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task SendTwoFactorEmailAsync_ExceptionBecauseNoProviderMetadataOnUser(SutProvider sutProvider, User user) - { - user.SetTwoFactorProviders(new Dictionary + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task SendTwoFactorEmailBecauseNewDeviceLoginAsync_Success(SutProvider sutProvider, User user) { - [TwoFactorProviderType.Email] = new TwoFactorProvider + var email = user.Email.ToLowerInvariant(); + var token = "thisisatokentocompare"; + + var userTwoFactorTokenProvider = Substitute.For>(); + userTwoFactorTokenProvider + .CanGenerateTwoFactorTokenAsync(Arg.Any>(), user) + .Returns(Task.FromResult(true)); + userTwoFactorTokenProvider + .GenerateAsync("2faEmail:" + email, Arg.Any>(), user) + .Returns(Task.FromResult(token)); + + sutProvider.Sut.RegisterTokenProvider("Email", userTwoFactorTokenProvider); + + user.SetTwoFactorProviders(new Dictionary { - MetaData = null, - Enabled = true - } - }); + [TwoFactorProviderType.Email] = new TwoFactorProvider + { + MetaData = new Dictionary { ["Email"] = email }, + Enabled = true + } + }); + await sutProvider.Sut.SendTwoFactorEmailAsync(user, true); - await Assert.ThrowsAsync("No email.", () => sutProvider.Sut.SendTwoFactorEmailAsync(user)); - } + await sutProvider.GetDependency() + .Received(1) + .SendNewDeviceLoginTwoFactorEmailAsync(email, token); + } - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task SendTwoFactorEmailAsync_ExceptionBecauseNoProviderEmailMetadataOnUser(SutProvider sutProvider, User user) - { - user.SetTwoFactorProviders(new Dictionary + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task SendTwoFactorEmailAsync_ExceptionBecauseNoProviderOnUser(SutProvider sutProvider, User user) { - [TwoFactorProviderType.Email] = new TwoFactorProvider - { - MetaData = new Dictionary { ["qweqwe"] = user.Email.ToLowerInvariant() }, - Enabled = true - } - }); + user.TwoFactorProviders = null; - await Assert.ThrowsAsync("No email.", () => sutProvider.Sut.SendTwoFactorEmailAsync(user)); - } + await Assert.ThrowsAsync("No email.", () => sutProvider.Sut.SendTwoFactorEmailAsync(user)); + } - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task Needs2FABecauseNewDeviceAsync_ReturnsTrue(SutProvider sutProvider, User user) - { - user.EmailVerified = true; - user.TwoFactorProviders = null; - user.UnknownDeviceVerificationEnabled = true; - const string deviceIdToCheck = "7b01b586-b210-499f-8d52-0c3fdaa646fc"; - const string deviceIdInRepo = "ea29126c-91b7-4cc4-8ce6-00105b37f64a"; - - sutProvider.GetDependency() - .GetManyByUserIdAsync(user.Id) - .Returns(Task.FromResult>(new List - { - new Device { Identifier = deviceIdInRepo } - })); - - - sutProvider.GetDependency().TwoFactorAuth.EmailOnNewDeviceLogin.Returns(true); - - Assert.True(await sutProvider.Sut.Needs2FABecauseNewDeviceAsync(user, deviceIdToCheck, "password")); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task Needs2FABecauseNewDeviceAsync_ReturnsFalse_When_GranType_Is_AuthorizationCode(SutProvider sutProvider, User user) - { - user.EmailVerified = true; - user.TwoFactorProviders = null; - const string deviceIdToCheck = "7b01b586-b210-499f-8d52-0c3fdaa646fc"; - const string deviceIdInRepo = "ea29126c-91b7-4cc4-8ce6-00105b37f64a"; - - sutProvider.GetDependency() - .GetManyByUserIdAsync(user.Id) - .Returns(Task.FromResult>(new List - { - new Device { Identifier = deviceIdInRepo } - })); - - Assert.False(await sutProvider.Sut.Needs2FABecauseNewDeviceAsync(user, deviceIdToCheck, "authorization_code")); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task Needs2FABecauseNewDeviceAsync_ReturnsFalse_When_Email_Is_Not_Verified(SutProvider sutProvider, User user) - { - user.EmailVerified = false; - user.TwoFactorProviders = null; - const string deviceIdToCheck = "7b01b586-b210-499f-8d52-0c3fdaa646fc"; - const string deviceIdInRepo = "ea29126c-91b7-4cc4-8ce6-00105b37f64a"; - - sutProvider.GetDependency() - .GetManyByUserIdAsync(user.Id) - .Returns(Task.FromResult>(new List - { - new Device { Identifier = deviceIdInRepo } - })); - - Assert.False(await sutProvider.Sut.Needs2FABecauseNewDeviceAsync(user, deviceIdToCheck, "password")); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task Needs2FABecauseNewDeviceAsync_ReturnsFalse_When_Is_The_First_Device(SutProvider sutProvider, User user) - { - user.EmailVerified = true; - user.TwoFactorProviders = null; - const string deviceIdToCheck = "7b01b586-b210-499f-8d52-0c3fdaa646fc"; - - sutProvider.GetDependency() - .GetManyByUserIdAsync(user.Id) - .Returns(Task.FromResult>(new List())); - - Assert.False(await sutProvider.Sut.Needs2FABecauseNewDeviceAsync(user, deviceIdToCheck, "password")); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task Needs2FABecauseNewDeviceAsync_ReturnsFalse_When_DeviceId_Is_Already_In_Repo(SutProvider sutProvider, User user) - { - user.EmailVerified = true; - user.TwoFactorProviders = null; - const string deviceIdToCheck = "7b01b586-b210-499f-8d52-0c3fdaa646fc"; - - sutProvider.GetDependency() - .GetManyByUserIdAsync(user.Id) - .Returns(Task.FromResult>(new List - { - new Device { Identifier = deviceIdToCheck } - })); - - Assert.False(await sutProvider.Sut.Needs2FABecauseNewDeviceAsync(user, deviceIdToCheck, "password")); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task Needs2FABecauseNewDeviceAsync_ReturnsFalse_When_GlobalSettings_2FA_EmailOnNewDeviceLogin_Is_Disabled(SutProvider sutProvider, User user) - { - user.EmailVerified = true; - user.TwoFactorProviders = null; - const string deviceIdToCheck = "7b01b586-b210-499f-8d52-0c3fdaa646fc"; - const string deviceIdInRepo = "ea29126c-91b7-4cc4-8ce6-00105b37f64a"; - - sutProvider.GetDependency() - .GetManyByUserIdAsync(user.Id) - .Returns(Task.FromResult>(new List - { - new Device { Identifier = deviceIdInRepo } - })); - - sutProvider.GetDependency().TwoFactorAuth.EmailOnNewDeviceLogin.Returns(false); - - Assert.False(await sutProvider.Sut.Needs2FABecauseNewDeviceAsync(user, deviceIdToCheck, "password")); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task Needs2FABecauseNewDeviceAsync_ReturnsFalse_When_UnknownDeviceVerification_Is_Disabled(SutProvider sutProvider, User user) - { - user.EmailVerified = true; - user.TwoFactorProviders = null; - user.UnknownDeviceVerificationEnabled = false; - const string deviceIdToCheck = "7b01b586-b210-499f-8d52-0c3fdaa646fc"; - const string deviceIdInRepo = "ea29126c-91b7-4cc4-8ce6-00105b37f64a"; - - sutProvider.GetDependency() - .GetManyByUserIdAsync(user.Id) - .Returns(Task.FromResult>(new List - { - new Device { Identifier = deviceIdInRepo } - })); - - sutProvider.GetDependency().TwoFactorAuth.EmailOnNewDeviceLogin.Returns(true); - - Assert.False(await sutProvider.Sut.Needs2FABecauseNewDeviceAsync(user, deviceIdToCheck, "password")); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public void CanEditDeviceVerificationSettings_ReturnsTrue(SutProvider sutProvider, User user) - { - user.EmailVerified = true; - user.TwoFactorProviders = null; - - sutProvider.GetDependency().TwoFactorAuth.EmailOnNewDeviceLogin.Returns(true); - - Assert.True(sutProvider.Sut.CanEditDeviceVerificationSettings(user)); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public void CanEditDeviceVerificationSettings_ReturnsFalse_When_GlobalSettings_2FA_EmailOnNewDeviceLogin_Is_Disabled(SutProvider sutProvider, User user) - { - user.EmailVerified = true; - user.TwoFactorProviders = null; - - sutProvider.GetDependency().TwoFactorAuth.EmailOnNewDeviceLogin.Returns(false); - - Assert.False(sutProvider.Sut.CanEditDeviceVerificationSettings(user)); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public void CanEditDeviceVerificationSettings_ReturnsFalse_When_Email_Is_Not_Verified(SutProvider sutProvider, User user) - { - user.EmailVerified = false; - user.TwoFactorProviders = null; - - sutProvider.GetDependency().TwoFactorAuth.EmailOnNewDeviceLogin.Returns(true); - - Assert.False(sutProvider.Sut.CanEditDeviceVerificationSettings(user)); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public void CanEditDeviceVerificationSettings_ReturnsFalse_When_User_Uses_Key_Connector(SutProvider sutProvider, User user) - { - user.EmailVerified = true; - user.TwoFactorProviders = null; - user.UsesKeyConnector = true; - - sutProvider.GetDependency().TwoFactorAuth.EmailOnNewDeviceLogin.Returns(true); - - Assert.False(sutProvider.Sut.CanEditDeviceVerificationSettings(user)); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public void CanEditDeviceVerificationSettings_ReturnsFalse_When_User_Has_A_2FA_Already_Set_Up(SutProvider sutProvider, User user) - { - user.EmailVerified = true; - user.SetTwoFactorProviders(new Dictionary + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task SendTwoFactorEmailAsync_ExceptionBecauseNoProviderMetadataOnUser(SutProvider sutProvider, User user) { - [TwoFactorProviderType.Email] = new TwoFactorProvider + user.SetTwoFactorProviders(new Dictionary { - MetaData = new Dictionary { ["Email"] = "asdfasf" }, - Enabled = true - } - }); + [TwoFactorProviderType.Email] = new TwoFactorProvider + { + MetaData = null, + Enabled = true + } + }); - sutProvider.GetDependency().TwoFactorAuth.EmailOnNewDeviceLogin.Returns(true); + await Assert.ThrowsAsync("No email.", () => sutProvider.Sut.SendTwoFactorEmailAsync(user)); + } - Assert.False(sutProvider.Sut.CanEditDeviceVerificationSettings(user)); - } + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task SendTwoFactorEmailAsync_ExceptionBecauseNoProviderEmailMetadataOnUser(SutProvider sutProvider, User user) + { + user.SetTwoFactorProviders(new Dictionary + { + [TwoFactorProviderType.Email] = new TwoFactorProvider + { + MetaData = new Dictionary { ["qweqwe"] = user.Email.ToLowerInvariant() }, + Enabled = true + } + }); - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async void HasPremiumFromOrganization_Returns_False_If_No_Orgs(SutProvider sutProvider, User user) - { - sutProvider.GetDependency().GetManyByUserAsync(user.Id).Returns(new List()); - Assert.False(await sutProvider.Sut.HasPremiumFromOrganization(user)); + await Assert.ThrowsAsync("No email.", () => sutProvider.Sut.SendTwoFactorEmailAsync(user)); + } - } + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task Needs2FABecauseNewDeviceAsync_ReturnsTrue(SutProvider sutProvider, User user) + { + user.EmailVerified = true; + user.TwoFactorProviders = null; + user.UnknownDeviceVerificationEnabled = true; + const string deviceIdToCheck = "7b01b586-b210-499f-8d52-0c3fdaa646fc"; + const string deviceIdInRepo = "ea29126c-91b7-4cc4-8ce6-00105b37f64a"; - [Theory] - [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, false, true)] - [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, true, false)] - public async void HasPremiumFromOrganization_Returns_False_If_Org_Not_Eligible(bool orgEnabled, bool orgUsersGetPremium, SutProvider sutProvider, User user, OrganizationUser orgUser, Organization organization) - { - orgUser.OrganizationId = organization.Id; - organization.Enabled = orgEnabled; - organization.UsersGetPremium = orgUsersGetPremium; - var orgAbilities = new Dictionary() { { organization.Id, new OrganizationAbility(organization) } }; + sutProvider.GetDependency() + .GetManyByUserIdAsync(user.Id) + .Returns(Task.FromResult>(new List + { + new Device { Identifier = deviceIdInRepo } + })); - sutProvider.GetDependency().GetManyByUserAsync(user.Id).Returns(new List() { orgUser }); - sutProvider.GetDependency().GetOrganizationAbilitiesAsync().Returns(orgAbilities); - Assert.False(await sutProvider.Sut.HasPremiumFromOrganization(user)); - } + sutProvider.GetDependency().TwoFactorAuth.EmailOnNewDeviceLogin.Returns(true); - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async void HasPremiumFromOrganization_Returns_True_If_Org_Eligible(SutProvider sutProvider, User user, OrganizationUser orgUser, Organization organization) - { - orgUser.OrganizationId = organization.Id; - organization.Enabled = true; - organization.UsersGetPremium = true; - var orgAbilities = new Dictionary() { { organization.Id, new OrganizationAbility(organization) } }; + Assert.True(await sutProvider.Sut.Needs2FABecauseNewDeviceAsync(user, deviceIdToCheck, "password")); + } - sutProvider.GetDependency().GetManyByUserAsync(user.Id).Returns(new List() { orgUser }); - sutProvider.GetDependency().GetOrganizationAbilitiesAsync().Returns(orgAbilities); + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task Needs2FABecauseNewDeviceAsync_ReturnsFalse_When_GranType_Is_AuthorizationCode(SutProvider sutProvider, User user) + { + user.EmailVerified = true; + user.TwoFactorProviders = null; + const string deviceIdToCheck = "7b01b586-b210-499f-8d52-0c3fdaa646fc"; + const string deviceIdInRepo = "ea29126c-91b7-4cc4-8ce6-00105b37f64a"; - Assert.True(await sutProvider.Sut.HasPremiumFromOrganization(user)); + sutProvider.GetDependency() + .GetManyByUserIdAsync(user.Id) + .Returns(Task.FromResult>(new List + { + new Device { Identifier = deviceIdInRepo } + })); + + Assert.False(await sutProvider.Sut.Needs2FABecauseNewDeviceAsync(user, deviceIdToCheck, "authorization_code")); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task Needs2FABecauseNewDeviceAsync_ReturnsFalse_When_Email_Is_Not_Verified(SutProvider sutProvider, User user) + { + user.EmailVerified = false; + user.TwoFactorProviders = null; + const string deviceIdToCheck = "7b01b586-b210-499f-8d52-0c3fdaa646fc"; + const string deviceIdInRepo = "ea29126c-91b7-4cc4-8ce6-00105b37f64a"; + + sutProvider.GetDependency() + .GetManyByUserIdAsync(user.Id) + .Returns(Task.FromResult>(new List + { + new Device { Identifier = deviceIdInRepo } + })); + + Assert.False(await sutProvider.Sut.Needs2FABecauseNewDeviceAsync(user, deviceIdToCheck, "password")); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task Needs2FABecauseNewDeviceAsync_ReturnsFalse_When_Is_The_First_Device(SutProvider sutProvider, User user) + { + user.EmailVerified = true; + user.TwoFactorProviders = null; + const string deviceIdToCheck = "7b01b586-b210-499f-8d52-0c3fdaa646fc"; + + sutProvider.GetDependency() + .GetManyByUserIdAsync(user.Id) + .Returns(Task.FromResult>(new List())); + + Assert.False(await sutProvider.Sut.Needs2FABecauseNewDeviceAsync(user, deviceIdToCheck, "password")); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task Needs2FABecauseNewDeviceAsync_ReturnsFalse_When_DeviceId_Is_Already_In_Repo(SutProvider sutProvider, User user) + { + user.EmailVerified = true; + user.TwoFactorProviders = null; + const string deviceIdToCheck = "7b01b586-b210-499f-8d52-0c3fdaa646fc"; + + sutProvider.GetDependency() + .GetManyByUserIdAsync(user.Id) + .Returns(Task.FromResult>(new List + { + new Device { Identifier = deviceIdToCheck } + })); + + Assert.False(await sutProvider.Sut.Needs2FABecauseNewDeviceAsync(user, deviceIdToCheck, "password")); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task Needs2FABecauseNewDeviceAsync_ReturnsFalse_When_GlobalSettings_2FA_EmailOnNewDeviceLogin_Is_Disabled(SutProvider sutProvider, User user) + { + user.EmailVerified = true; + user.TwoFactorProviders = null; + const string deviceIdToCheck = "7b01b586-b210-499f-8d52-0c3fdaa646fc"; + const string deviceIdInRepo = "ea29126c-91b7-4cc4-8ce6-00105b37f64a"; + + sutProvider.GetDependency() + .GetManyByUserIdAsync(user.Id) + .Returns(Task.FromResult>(new List + { + new Device { Identifier = deviceIdInRepo } + })); + + sutProvider.GetDependency().TwoFactorAuth.EmailOnNewDeviceLogin.Returns(false); + + Assert.False(await sutProvider.Sut.Needs2FABecauseNewDeviceAsync(user, deviceIdToCheck, "password")); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task Needs2FABecauseNewDeviceAsync_ReturnsFalse_When_UnknownDeviceVerification_Is_Disabled(SutProvider sutProvider, User user) + { + user.EmailVerified = true; + user.TwoFactorProviders = null; + user.UnknownDeviceVerificationEnabled = false; + const string deviceIdToCheck = "7b01b586-b210-499f-8d52-0c3fdaa646fc"; + const string deviceIdInRepo = "ea29126c-91b7-4cc4-8ce6-00105b37f64a"; + + sutProvider.GetDependency() + .GetManyByUserIdAsync(user.Id) + .Returns(Task.FromResult>(new List + { + new Device { Identifier = deviceIdInRepo } + })); + + sutProvider.GetDependency().TwoFactorAuth.EmailOnNewDeviceLogin.Returns(true); + + Assert.False(await sutProvider.Sut.Needs2FABecauseNewDeviceAsync(user, deviceIdToCheck, "password")); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public void CanEditDeviceVerificationSettings_ReturnsTrue(SutProvider sutProvider, User user) + { + user.EmailVerified = true; + user.TwoFactorProviders = null; + + sutProvider.GetDependency().TwoFactorAuth.EmailOnNewDeviceLogin.Returns(true); + + Assert.True(sutProvider.Sut.CanEditDeviceVerificationSettings(user)); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public void CanEditDeviceVerificationSettings_ReturnsFalse_When_GlobalSettings_2FA_EmailOnNewDeviceLogin_Is_Disabled(SutProvider sutProvider, User user) + { + user.EmailVerified = true; + user.TwoFactorProviders = null; + + sutProvider.GetDependency().TwoFactorAuth.EmailOnNewDeviceLogin.Returns(false); + + Assert.False(sutProvider.Sut.CanEditDeviceVerificationSettings(user)); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public void CanEditDeviceVerificationSettings_ReturnsFalse_When_Email_Is_Not_Verified(SutProvider sutProvider, User user) + { + user.EmailVerified = false; + user.TwoFactorProviders = null; + + sutProvider.GetDependency().TwoFactorAuth.EmailOnNewDeviceLogin.Returns(true); + + Assert.False(sutProvider.Sut.CanEditDeviceVerificationSettings(user)); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public void CanEditDeviceVerificationSettings_ReturnsFalse_When_User_Uses_Key_Connector(SutProvider sutProvider, User user) + { + user.EmailVerified = true; + user.TwoFactorProviders = null; + user.UsesKeyConnector = true; + + sutProvider.GetDependency().TwoFactorAuth.EmailOnNewDeviceLogin.Returns(true); + + Assert.False(sutProvider.Sut.CanEditDeviceVerificationSettings(user)); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public void CanEditDeviceVerificationSettings_ReturnsFalse_When_User_Has_A_2FA_Already_Set_Up(SutProvider sutProvider, User user) + { + user.EmailVerified = true; + user.SetTwoFactorProviders(new Dictionary + { + [TwoFactorProviderType.Email] = new TwoFactorProvider + { + MetaData = new Dictionary { ["Email"] = "asdfasf" }, + Enabled = true + } + }); + + sutProvider.GetDependency().TwoFactorAuth.EmailOnNewDeviceLogin.Returns(true); + + Assert.False(sutProvider.Sut.CanEditDeviceVerificationSettings(user)); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async void HasPremiumFromOrganization_Returns_False_If_No_Orgs(SutProvider sutProvider, User user) + { + sutProvider.GetDependency().GetManyByUserAsync(user.Id).Returns(new List()); + Assert.False(await sutProvider.Sut.HasPremiumFromOrganization(user)); + + } + + [Theory] + [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, false, true)] + [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, true, false)] + public async void HasPremiumFromOrganization_Returns_False_If_Org_Not_Eligible(bool orgEnabled, bool orgUsersGetPremium, SutProvider sutProvider, User user, OrganizationUser orgUser, Organization organization) + { + orgUser.OrganizationId = organization.Id; + organization.Enabled = orgEnabled; + organization.UsersGetPremium = orgUsersGetPremium; + var orgAbilities = new Dictionary() { { organization.Id, new OrganizationAbility(organization) } }; + + sutProvider.GetDependency().GetManyByUserAsync(user.Id).Returns(new List() { orgUser }); + sutProvider.GetDependency().GetOrganizationAbilitiesAsync().Returns(orgAbilities); + + Assert.False(await sutProvider.Sut.HasPremiumFromOrganization(user)); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async void HasPremiumFromOrganization_Returns_True_If_Org_Eligible(SutProvider sutProvider, User user, OrganizationUser orgUser, Organization organization) + { + orgUser.OrganizationId = organization.Id; + organization.Enabled = true; + organization.UsersGetPremium = true; + var orgAbilities = new Dictionary() { { organization.Id, new OrganizationAbility(organization) } }; + + sutProvider.GetDependency().GetManyByUserAsync(user.Id).Returns(new List() { orgUser }); + sutProvider.GetDependency().GetOrganizationAbilitiesAsync().Returns(orgAbilities); + + Assert.True(await sutProvider.Sut.HasPremiumFromOrganization(user)); + } } } diff --git a/test/Core.Test/TempDirectory.cs b/test/Core.Test/TempDirectory.cs index 832d8c79ca..9a1cd86af3 100644 --- a/test/Core.Test/TempDirectory.cs +++ b/test/Core.Test/TempDirectory.cs @@ -1,38 +1,39 @@ -namespace Bit.Core.Test; - -public class TempDirectory : IDisposable +namespace Bit.Core.Test { - public string Directory { get; private set; } - - public TempDirectory() + public class TempDirectory : IDisposable { - Directory = Path.Combine(Path.GetTempPath(), $"bitwarden_{Guid.NewGuid().ToString().Replace("-", "")}"); - } + public string Directory { get; private set; } - public override string ToString() => Directory; - - #region IDisposable implementation - ~TempDirectory() - { - Dispose(false); - } - - public void Dispose() - { - Dispose(true); - GC.SuppressFinalize(this); - } - - public void Dispose(bool disposing) - { - if (disposing) + public TempDirectory() { - try - { - System.IO.Directory.Delete(Directory, true); - } - catch { } + Directory = Path.Combine(Path.GetTempPath(), $"bitwarden_{Guid.NewGuid().ToString().Replace("-", "")}"); } + + public override string ToString() => Directory; + + #region IDisposable implementation + ~TempDirectory() + { + Dispose(false); + } + + public void Dispose() + { + Dispose(true); + GC.SuppressFinalize(this); + } + + public void Dispose(bool disposing) + { + if (disposing) + { + try + { + System.IO.Directory.Delete(Directory, true); + } + catch { } + } + } + # endregion } - # endregion } diff --git a/test/Core.Test/Tokens/DataProtectorTokenFactoryTests.cs b/test/Core.Test/Tokens/DataProtectorTokenFactoryTests.cs index 3837ae0261..8a75a07900 100644 --- a/test/Core.Test/Tokens/DataProtectorTokenFactoryTests.cs +++ b/test/Core.Test/Tokens/DataProtectorTokenFactoryTests.cs @@ -7,121 +7,122 @@ using Bit.Test.Common.Helpers; using Microsoft.AspNetCore.DataProtection; using Xunit; -namespace Bit.Core.Test.Tokens; - -[SutProviderCustomize] -public class DataProtectorTokenFactoryTests +namespace Bit.Core.Test.Tokens { - public static SutProvider> GetSutProvider() + [SutProviderCustomize] + public class DataProtectorTokenFactoryTests { - var fixture = new Fixture(); - return new SutProvider>(fixture) - .SetDependency(fixture.Create()) - .Create(); + public static SutProvider> GetSutProvider() + { + var fixture = new Fixture(); + return new SutProvider>(fixture) + .SetDependency(fixture.Create()) + .Create(); + } + + [Theory, BitAutoData] + public void CanRoundTripTokenables(TestTokenable tokenable) + { + var sutProvider = GetSutProvider(); + + var token = sutProvider.Sut.Protect(tokenable); + var recoveredTokenable = sutProvider.Sut.Unprotect(token); + + AssertHelper.AssertPropertyEqual(tokenable, recoveredTokenable); + } + + [Theory, BitAutoData] + public void PrependsClearText(TestTokenable tokenable) + { + var sutProvider = GetSutProvider(); + + var token = sutProvider.Sut.Protect(tokenable); + + Assert.StartsWith(sutProvider.GetDependency("clearTextPrefix"), token); + } + + [Theory, BitAutoData] + public void EncryptsToken(TestTokenable tokenable) + { + var sutProvider = GetSutProvider(); + var prefix = sutProvider.GetDependency("clearTextPrefix"); + + var token = sutProvider.Sut.Protect(tokenable); + + Assert.NotEqual(new Token(token).RemovePrefix(prefix), tokenable.ToToken()); + } + + [Theory, BitAutoData] + public void ThrowsIfUnprotectFails(TestTokenable tokenable) + { + var sutProvider = GetSutProvider(); + + var token = sutProvider.Sut.Protect(tokenable); + token += "stuff to make sure decryption fails"; + + Assert.Throws(() => sutProvider.Sut.Unprotect(token)); + } + + [Theory, BitAutoData] + public void TryUnprotect_FalseIfUnprotectFails(TestTokenable tokenable) + { + var sutProvider = GetSutProvider(); + var token = sutProvider.Sut.Protect(tokenable) + "fail decryption"; + + var result = sutProvider.Sut.TryUnprotect(token, out var data); + + Assert.False(result); + Assert.Null(data); + } + + [Theory, BitAutoData] + public void TokenValid_FalseIfUnprotectFails(TestTokenable tokenable) + { + var sutProvider = GetSutProvider(); + var token = sutProvider.Sut.Protect(tokenable) + "fail decryption"; + + var result = sutProvider.Sut.TokenValid(token); + + Assert.False(result); + } + + + [Theory, BitAutoData] + public void TokenValid_FalseIfTokenInvalid(TestTokenable tokenable) + { + var sutProvider = GetSutProvider(); + + tokenable.ForceInvalid = true; + var token = sutProvider.Sut.Protect(tokenable); + + var result = sutProvider.Sut.TokenValid(token); + + Assert.False(result); + } + + [Theory, BitAutoData] + public void TryUnprotect_TrueIfSuccess(TestTokenable tokenable) + { + var sutProvider = GetSutProvider(); + var token = sutProvider.Sut.Protect(tokenable); + + var result = sutProvider.Sut.TryUnprotect(token, out var data); + + Assert.True(result); + AssertHelper.AssertPropertyEqual(tokenable, data); + } + + [Theory, BitAutoData] + public void TokenValid_TrueIfSuccess(TestTokenable tokenable) + { + tokenable.ForceInvalid = false; + var sutProvider = GetSutProvider(); + var token = sutProvider.Sut.Protect(tokenable); + + var result = sutProvider.Sut.TokenValid(token); + + Assert.True(result); + } + } - - [Theory, BitAutoData] - public void CanRoundTripTokenables(TestTokenable tokenable) - { - var sutProvider = GetSutProvider(); - - var token = sutProvider.Sut.Protect(tokenable); - var recoveredTokenable = sutProvider.Sut.Unprotect(token); - - AssertHelper.AssertPropertyEqual(tokenable, recoveredTokenable); - } - - [Theory, BitAutoData] - public void PrependsClearText(TestTokenable tokenable) - { - var sutProvider = GetSutProvider(); - - var token = sutProvider.Sut.Protect(tokenable); - - Assert.StartsWith(sutProvider.GetDependency("clearTextPrefix"), token); - } - - [Theory, BitAutoData] - public void EncryptsToken(TestTokenable tokenable) - { - var sutProvider = GetSutProvider(); - var prefix = sutProvider.GetDependency("clearTextPrefix"); - - var token = sutProvider.Sut.Protect(tokenable); - - Assert.NotEqual(new Token(token).RemovePrefix(prefix), tokenable.ToToken()); - } - - [Theory, BitAutoData] - public void ThrowsIfUnprotectFails(TestTokenable tokenable) - { - var sutProvider = GetSutProvider(); - - var token = sutProvider.Sut.Protect(tokenable); - token += "stuff to make sure decryption fails"; - - Assert.Throws(() => sutProvider.Sut.Unprotect(token)); - } - - [Theory, BitAutoData] - public void TryUnprotect_FalseIfUnprotectFails(TestTokenable tokenable) - { - var sutProvider = GetSutProvider(); - var token = sutProvider.Sut.Protect(tokenable) + "fail decryption"; - - var result = sutProvider.Sut.TryUnprotect(token, out var data); - - Assert.False(result); - Assert.Null(data); - } - - [Theory, BitAutoData] - public void TokenValid_FalseIfUnprotectFails(TestTokenable tokenable) - { - var sutProvider = GetSutProvider(); - var token = sutProvider.Sut.Protect(tokenable) + "fail decryption"; - - var result = sutProvider.Sut.TokenValid(token); - - Assert.False(result); - } - - - [Theory, BitAutoData] - public void TokenValid_FalseIfTokenInvalid(TestTokenable tokenable) - { - var sutProvider = GetSutProvider(); - - tokenable.ForceInvalid = true; - var token = sutProvider.Sut.Protect(tokenable); - - var result = sutProvider.Sut.TokenValid(token); - - Assert.False(result); - } - - [Theory, BitAutoData] - public void TryUnprotect_TrueIfSuccess(TestTokenable tokenable) - { - var sutProvider = GetSutProvider(); - var token = sutProvider.Sut.Protect(tokenable); - - var result = sutProvider.Sut.TryUnprotect(token, out var data); - - Assert.True(result); - AssertHelper.AssertPropertyEqual(tokenable, data); - } - - [Theory, BitAutoData] - public void TokenValid_TrueIfSuccess(TestTokenable tokenable) - { - tokenable.ForceInvalid = false; - var sutProvider = GetSutProvider(); - var token = sutProvider.Sut.Protect(tokenable); - - var result = sutProvider.Sut.TokenValid(token); - - Assert.True(result); - } - } diff --git a/test/Core.Test/Tokens/ExpiringTokenTests.cs b/test/Core.Test/Tokens/ExpiringTokenTests.cs index 9154b65b69..33ce911786 100644 --- a/test/Core.Test/Tokens/ExpiringTokenTests.cs +++ b/test/Core.Test/Tokens/ExpiringTokenTests.cs @@ -3,68 +3,69 @@ using AutoFixture.Xunit2; using Bit.Core.Utilities; using Xunit; -namespace Bit.Core.Test.Tokens; - -public class ExpiringTokenTests +namespace Bit.Core.Test.Tokens { - [Theory, AutoData] - public void ExpirationSerializesToEpochMilliseconds(DateTime expirationDate) + public class ExpiringTokenTests { - var sut = new TestExpiringTokenable + [Theory, AutoData] + public void ExpirationSerializesToEpochMilliseconds(DateTime expirationDate) { - ExpirationDate = expirationDate - }; + var sut = new TestExpiringTokenable + { + ExpirationDate = expirationDate + }; - var result = JsonSerializer.Serialize(sut); - var expectedDate = CoreHelpers.ToEpocMilliseconds(expirationDate); + var result = JsonSerializer.Serialize(sut); + var expectedDate = CoreHelpers.ToEpocMilliseconds(expirationDate); - Assert.Contains($"\"ExpirationDate\":{expectedDate}", result); - } + Assert.Contains($"\"ExpirationDate\":{expectedDate}", result); + } - [Theory, AutoData] - public void ExpirationSerializationRoundTrip(DateTime expirationDate) - { - var sut = new TestExpiringTokenable + [Theory, AutoData] + public void ExpirationSerializationRoundTrip(DateTime expirationDate) { - ExpirationDate = expirationDate - }; + var sut = new TestExpiringTokenable + { + ExpirationDate = expirationDate + }; - var intermediate = JsonSerializer.Serialize(sut); - var result = JsonSerializer.Deserialize(intermediate); + var intermediate = JsonSerializer.Serialize(sut); + var result = JsonSerializer.Deserialize(intermediate); - Assert.Equal(sut.ExpirationDate, result.ExpirationDate, TimeSpan.FromMilliseconds(100)); - } + Assert.Equal(sut.ExpirationDate, result.ExpirationDate, TimeSpan.FromMilliseconds(100)); + } - [Fact] - public void InvalidIfPastExpiryDate() - { - var sut = new TestExpiringTokenable + [Fact] + public void InvalidIfPastExpiryDate() { - ExpirationDate = DateTime.UtcNow.AddHours(-1) - }; + var sut = new TestExpiringTokenable + { + ExpirationDate = DateTime.UtcNow.AddHours(-1) + }; - Assert.False(sut.Valid); - } + Assert.False(sut.Valid); + } - [Fact] - public void ValidIfWithinExpirationAndTokenReportsValid() - { - var sut = new TestExpiringTokenable + [Fact] + public void ValidIfWithinExpirationAndTokenReportsValid() { - ExpirationDate = DateTime.UtcNow.AddHours(1) - }; + var sut = new TestExpiringTokenable + { + ExpirationDate = DateTime.UtcNow.AddHours(1) + }; - Assert.True(sut.Valid); - } + Assert.True(sut.Valid); + } - [Fact] - public void HonorsTokenIsValidAbstractMember() - { - var sut = new TestExpiringTokenable(forceInvalid: true) + [Fact] + public void HonorsTokenIsValidAbstractMember() { - ExpirationDate = DateTime.UtcNow.AddHours(1) - }; + var sut = new TestExpiringTokenable(forceInvalid: true) + { + ExpirationDate = DateTime.UtcNow.AddHours(1) + }; - Assert.False(sut.Valid); + Assert.False(sut.Valid); + } } } diff --git a/test/Core.Test/Tokens/TestTokenable.cs b/test/Core.Test/Tokens/TestTokenable.cs index c8dee643b1..7e73cd5e9d 100644 --- a/test/Core.Test/Tokens/TestTokenable.cs +++ b/test/Core.Test/Tokens/TestTokenable.cs @@ -1,25 +1,26 @@ using System.Text.Json.Serialization; using Bit.Core.Tokens; -namespace Bit.Core.Test.Tokens; - -public class TestTokenable : Tokenable +namespace Bit.Core.Test.Tokens { - public bool ForceInvalid { get; set; } = false; - - [JsonIgnore] - public override bool Valid => !ForceInvalid; -} - -public class TestExpiringTokenable : ExpiringTokenable -{ - private bool _forceInvalid; - - public TestExpiringTokenable() : this(false) { } - - public TestExpiringTokenable(bool forceInvalid) + public class TestTokenable : Tokenable { - _forceInvalid = forceInvalid; + public bool ForceInvalid { get; set; } = false; + + [JsonIgnore] + public override bool Valid => !ForceInvalid; + } + + public class TestExpiringTokenable : ExpiringTokenable + { + private bool _forceInvalid; + + public TestExpiringTokenable() : this(false) { } + + public TestExpiringTokenable(bool forceInvalid) + { + _forceInvalid = forceInvalid; + } + protected override bool TokenIsValid() => !_forceInvalid; } - protected override bool TokenIsValid() => !_forceInvalid; } diff --git a/test/Core.Test/Tokens/TokenTests.cs b/test/Core.Test/Tokens/TokenTests.cs index 1afad24127..bc1ad85688 100644 --- a/test/Core.Test/Tokens/TokenTests.cs +++ b/test/Core.Test/Tokens/TokenTests.cs @@ -2,37 +2,38 @@ using Bit.Core.Tokens; using Xunit; -namespace Bit.Core.Test.Tokens; - -public class TokenTests +namespace Bit.Core.Test.Tokens { - [Theory, AutoData] - public void InitializeWithString_ReturnsString(string initString) + public class TokenTests { - var token = new Token(initString); + [Theory, AutoData] + public void InitializeWithString_ReturnsString(string initString) + { + var token = new Token(initString); - Assert.Equal(initString, token.ToString()); - } + Assert.Equal(initString, token.ToString()); + } - [Theory, AutoData] - public void AddsPrefix(Token token, string prefix) - { - Assert.Equal($"{prefix}{token.ToString()}", token.WithPrefix(prefix).ToString()); - } + [Theory, AutoData] + public void AddsPrefix(Token token, string prefix) + { + Assert.Equal($"{prefix}{token.ToString()}", token.WithPrefix(prefix).ToString()); + } - [Theory, AutoData] - public void RemovePrefix_WithPrefix_RemovesPrefix(string initString, string prefix) - { - var token = new Token(initString).WithPrefix(prefix); + [Theory, AutoData] + public void RemovePrefix_WithPrefix_RemovesPrefix(string initString, string prefix) + { + var token = new Token(initString).WithPrefix(prefix); - Assert.Equal(initString, token.RemovePrefix(prefix).ToString()); - } + Assert.Equal(initString, token.RemovePrefix(prefix).ToString()); + } - [Theory, AutoData] - public void RemovePrefix_WithoutPrefix_Throws(Token token, string prefix) - { - var exception = Assert.Throws(() => token.RemovePrefix(prefix)); + [Theory, AutoData] + public void RemovePrefix_WithoutPrefix_Throws(Token token, string prefix) + { + var exception = Assert.Throws(() => token.RemovePrefix(prefix)); - Assert.Equal($"Expected prefix, {prefix}, was not present.", exception.Message); + Assert.Equal($"Expected prefix, {prefix}, was not present.", exception.Message); + } } } diff --git a/test/Core.Test/Utilities/ClaimsExtensionsTests.cs b/test/Core.Test/Utilities/ClaimsExtensionsTests.cs index 665c647797..d6b5c90dbd 100644 --- a/test/Core.Test/Utilities/ClaimsExtensionsTests.cs +++ b/test/Core.Test/Utilities/ClaimsExtensionsTests.cs @@ -2,35 +2,36 @@ using Bit.Core.Utilities; using Xunit; -namespace Bit.Core.Test.Utilities; - -public class ClaimsExtensionsTests +namespace Bit.Core.Test.Utilities { - [Fact] - public void HasSSOIdP_Returns_True_When_The_Claims_Has_One_Of_Type_IdP_And_Value_Sso() + public class ClaimsExtensionsTests { - var claims = new List { new Claim("idp", "sso") }; - Assert.True(claims.HasSsoIdP()); - } + [Fact] + public void HasSSOIdP_Returns_True_When_The_Claims_Has_One_Of_Type_IdP_And_Value_Sso() + { + var claims = new List { new Claim("idp", "sso") }; + Assert.True(claims.HasSsoIdP()); + } - [Fact] - public void HasSSOIdP_Returns_False_When_The_Claims_Has_One_Of_Type_IdP_And_Value_Is_Not_Sso() - { - var claims = new List { new Claim("idp", "asdfasfd") }; - Assert.False(claims.HasSsoIdP()); - } + [Fact] + public void HasSSOIdP_Returns_False_When_The_Claims_Has_One_Of_Type_IdP_And_Value_Is_Not_Sso() + { + var claims = new List { new Claim("idp", "asdfasfd") }; + Assert.False(claims.HasSsoIdP()); + } - [Fact] - public void HasSSOIdP_Returns_False_When_The_Claims_Has_No_One_Of_Type_IdP() - { - var claims = new List { new Claim("qweqweq", "sso") }; - Assert.False(claims.HasSsoIdP()); - } + [Fact] + public void HasSSOIdP_Returns_False_When_The_Claims_Has_No_One_Of_Type_IdP() + { + var claims = new List { new Claim("qweqweq", "sso") }; + Assert.False(claims.HasSsoIdP()); + } - [Fact] - public void HasSSOIdP_Returns_False_When_The_Claims_Are_Empty() - { - var claims = new List(); - Assert.False(claims.HasSsoIdP()); + [Fact] + public void HasSSOIdP_Returns_False_When_The_Claims_Are_Empty() + { + var claims = new List(); + Assert.False(claims.HasSsoIdP()); + } } } diff --git a/test/Core.Test/Utilities/CoreHelpersTests.cs b/test/Core.Test/Utilities/CoreHelpersTests.cs index 76db48fe3e..37b9c22df8 100644 --- a/test/Core.Test/Utilities/CoreHelpersTests.cs +++ b/test/Core.Test/Utilities/CoreHelpersTests.cs @@ -12,433 +12,434 @@ using IdentityModel; using Microsoft.AspNetCore.DataProtection; using Xunit; -namespace Bit.Core.Test.Utilities; - -public class CoreHelpersTests +namespace Bit.Core.Test.Utilities { - public static IEnumerable _epochTestCases = new[] + public class CoreHelpersTests { - new object[] {new DateTime(2020, 12, 30, 11, 49, 12, DateTimeKind.Utc), 1609328952000L}, - }; - - [Fact] - public void GenerateComb_Success() - { - // Arrange & Act - var comb = CoreHelpers.GenerateComb(); - - // Assert - Assert.NotEqual(Guid.Empty, comb); - // TODO: Add more asserts to make sure important aspects of - // the comb are working properly - } - - public static IEnumerable GenerateCombCases = new[] - { - new object[] + public static IEnumerable _epochTestCases = new[] { - Guid.Parse("a58db474-43d8-42f1-b4ee-0c17647cd0c0"), // Input Guid - new DateTime(2022, 3, 12, 12, 12, 0, DateTimeKind.Utc), // Input Time - Guid.Parse("a58db474-43d8-42f1-b4ee-ae5600c90cc1"), // Expected Comb - }, - new object[] + new object[] {new DateTime(2020, 12, 30, 11, 49, 12, DateTimeKind.Utc), 1609328952000L}, + }; + + [Fact] + public void GenerateComb_Success() { - Guid.Parse("f776e6ee-511f-4352-bb28-88513002bdeb"), - new DateTime(2021, 5, 10, 10, 52, 0, DateTimeKind.Utc), - Guid.Parse("f776e6ee-511f-4352-bb28-ad2400b313c1"), - }, - new object[] - { - Guid.Parse("51a25fc7-3cad-497d-8e2f-8d77011648a1"), - new DateTime(1999, 2, 26, 16, 53, 13, DateTimeKind.Utc), - Guid.Parse("51a25fc7-3cad-497d-8e2f-8d77011649cd"), - }, - new object[] - { - Guid.Parse("bfb8f353-3b32-4a9e-bef6-24fe0b54bfb0"), - new DateTime(2024, 10, 20, 1, 32, 16, DateTimeKind.Utc), - Guid.Parse("bfb8f353-3b32-4a9e-bef6-b20f00195780"), - } - }; + // Arrange & Act + var comb = CoreHelpers.GenerateComb(); - [Theory] - [MemberData(nameof(GenerateCombCases))] - public void GenerateComb_WithInputs_Success(Guid inputGuid, DateTime inputTime, Guid expectedComb) - { - var comb = CoreHelpers.GenerateComb(inputGuid, inputTime); - - Assert.Equal(expectedComb, comb); - } - - [Theory] - [InlineData(2, 5, new[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 0 })] - [InlineData(2, 3, new[] { 1, 2, 3, 4, 5 })] - [InlineData(2, 1, new[] { 1, 2 })] - [InlineData(1, 1, new[] { 1 })] - [InlineData(2, 2, new[] { 1, 2, 3 })] - public void Batch_Success(int batchSize, int totalBatches, int[] collection) - { - // Arrange - var remainder = collection.Length % batchSize; - - // Act - var batches = collection.Batch(batchSize); - - // Assert - Assert.Equal(totalBatches, batches.Count()); - - foreach (var batch in batches.Take(totalBatches - 1)) - { - Assert.Equal(batchSize, batch.Count()); + // Assert + Assert.NotEqual(Guid.Empty, comb); + // TODO: Add more asserts to make sure important aspects of + // the comb are working properly } - Assert.Equal(batches.Last().Count(), remainder == 0 ? batchSize : remainder); - } - - /* - [Fact] - public void ToGuidIdArrayTVP_Success() - { - // Arrange - var item0 = Guid.NewGuid(); - var item1 = Guid.NewGuid(); - - var ids = new[] { item0, item1 }; - - // Act - var dt = ids.ToGuidIdArrayTVP(); - - // Assert - Assert.Single(dt.Columns); - Assert.Equal("GuidId", dt.Columns[0].ColumnName); - Assert.Equal(2, dt.Rows.Count); - Assert.Equal(item0, dt.Rows[0][0]); - Assert.Equal(item1, dt.Rows[1][0]); - } - */ - - // TODO: Test the other ToArrayTVP Methods - - [Theory] - [InlineData("12345&6789", "123456789")] - [InlineData("abcdef", "ABCDEF")] - [InlineData("1!@#$%&*()_+", "1")] - [InlineData("\u00C6123abc\u00C7", "123ABC")] - [InlineData("123\u00C6ABC", "123ABC")] - [InlineData("\r\nHello", "E")] - [InlineData("\tdef", "DEF")] - [InlineData("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUV1234567890", "ABCDEFABCDEF1234567890")] - public void CleanCertificateThumbprint_Success(string input, string output) - { - // Arrange & Act - var sanitizedInput = CoreHelpers.CleanCertificateThumbprint(input); - - // Assert - Assert.Equal(output, sanitizedInput); - } - - // TODO: Add more tests - [Theory] - [MemberData(nameof(_epochTestCases))] - public void ToEpocMilliseconds_Success(DateTime date, long milliseconds) - { - // Act & Assert - Assert.Equal(milliseconds, CoreHelpers.ToEpocMilliseconds(date)); - } - - [Theory] - [MemberData(nameof(_epochTestCases))] - public void FromEpocMilliseconds(DateTime date, long milliseconds) - { - // Act & Assert - Assert.Equal(date, CoreHelpers.FromEpocMilliseconds(milliseconds)); - } - - [Fact] - public void SecureRandomString_Success() - { - // Arrange & Act - var @string = CoreHelpers.SecureRandomString(8); - - // Assert - // TODO: Should probably add more Asserts down the line - Assert.Equal(8, @string.Length); - } - - [Theory] - [InlineData(1, "1 Bytes")] - [InlineData(-5L, "-5 Bytes")] - [InlineData(1023L, "1023 Bytes")] - [InlineData(1024L, "1 KB")] - [InlineData(1025L, "1 KB")] - [InlineData(-1023L, "-1023 Bytes")] - [InlineData(-1024L, "-1 KB")] - [InlineData(-1025L, "-1 KB")] - [InlineData(1048575L, "1024 KB")] - [InlineData(1048576L, "1 MB")] - [InlineData(1048577L, "1 MB")] - [InlineData(-1048575L, "-1024 KB")] - [InlineData(-1048576L, "-1 MB")] - [InlineData(-1048577L, "-1 MB")] - [InlineData(1073741823L, "1024 MB")] - [InlineData(1073741824L, "1 GB")] - [InlineData(1073741825L, "1 GB")] - [InlineData(-1073741823L, "-1024 MB")] - [InlineData(-1073741824L, "-1 GB")] - [InlineData(-1073741825L, "-1 GB")] - [InlineData(long.MaxValue, "8589934592 GB")] - public void ReadableBytesSize_Success(long size, string readable) - { - // Act & Assert - Assert.Equal(readable, CoreHelpers.ReadableBytesSize(size)); - } - - [Fact] - public void CloneObject_Success() - { - var original = new { Message = "Message" }; - - var copy = CoreHelpers.CloneObject(original); - - Assert.Equal(original.Message, copy.Message); - } - - [Fact] - public void ExtendQuery_AddNewParameter_Success() - { - // Arrange - var uri = new Uri("https://bitwarden.com/?param1=value1"); - - // Act - var newUri = CoreHelpers.ExtendQuery(uri, - new Dictionary { { "param2", "value2" } }); - - // Assert - Assert.Equal("https://bitwarden.com/?param1=value1¶m2=value2", newUri.ToString()); - } - - [Fact] - public void ExtendQuery_AddTwoNewParameters_Success() - { - // Arrange - var uri = new Uri("https://bitwarden.com/?param1=value1"); - - // Act - var newUri = CoreHelpers.ExtendQuery(uri, - new Dictionary + public static IEnumerable GenerateCombCases = new[] + { + new object[] { - { "param2", "value2" }, - { "param3", "value3" } - }); - - // Assert - Assert.Equal("https://bitwarden.com/?param1=value1¶m2=value2¶m3=value3", newUri.ToString()); - } - - [Fact] - public void ExtendQuery_AddExistingParameter_Success() - { - // Arrange - var uri = new Uri("https://bitwarden.com/?param1=value1¶m2=value2"); - - // Act - var newUri = CoreHelpers.ExtendQuery(uri, - new Dictionary { { "param1", "test_value" } }); - - // Assert - Assert.Equal("https://bitwarden.com/?param1=test_value¶m2=value2", newUri.ToString()); - } - - [Fact] - public void ExtendQuery_AddNoParameters_Success() - { - // Arrange - const string startingUri = "https://bitwarden.com/?param1=value1"; - - var uri = new Uri(startingUri); - - // Act - var newUri = CoreHelpers.ExtendQuery(uri, new Dictionary()); - - // Assert - Assert.Equal(startingUri, newUri.ToString()); - } - - [Theory] - [InlineData("bücher.com", "xn--bcher-kva.com")] - [InlineData("bücher.cömé", "xn--bcher-kva.xn--cm-cja4c")] - [InlineData("hello@bücher.com", "hello@xn--bcher-kva.com")] - [InlineData("hello@world.cömé", "hello@world.xn--cm-cja4c")] - [InlineData("hello@bücher.cömé", "hello@xn--bcher-kva.xn--cm-cja4c")] - [InlineData("ascii.com", "ascii.com")] - [InlineData("", "")] - [InlineData(null, null)] - public void PunyEncode_Success(string text, string expected) - { - var actual = CoreHelpers.PunyEncode(text); - Assert.Equal(expected, actual); - } - - [Fact] - public void GetEmbeddedResourceContentsAsync_Success() - { - var fileContents = CoreHelpers.GetEmbeddedResourceContentsAsync("data.embeddedResource.txt"); - Assert.Equal("Contents of embeddedResource.txt\n", fileContents.Replace("\r\n", "\n")); - } - - [Theory, CustomAutoData(typeof(UserFixture))] - public void BuildIdentityClaims_BaseClaims_Success(User user, bool isPremium) - { - var expected = new Dictionary - { - { "premium", isPremium ? "true" : "false" }, - { JwtClaimTypes.Email, user.Email }, - { JwtClaimTypes.EmailVerified, user.EmailVerified ? "true" : "false" }, - { JwtClaimTypes.Name, user.Name }, - { "sstamp", user.SecurityStamp }, - }.ToList(); - - var actual = CoreHelpers.BuildIdentityClaims(user, Array.Empty(), - Array.Empty(), isPremium); - - foreach (var claim in expected) - { - Assert.Contains(claim, actual); - } - Assert.Equal(expected.Count, actual.Count); - } - - [Theory, CustomAutoData(typeof(UserFixture))] - public void BuildIdentityClaims_NonCustomOrganizationUserType_Success(User user) - { - var fixture = new Fixture().WithAutoNSubstitutions(); - foreach (var organizationUserType in Enum.GetValues().Except(new[] { OrganizationUserType.Custom })) - { - var org = fixture.Create(); - org.Type = organizationUserType; - - var expected = new KeyValuePair($"org{organizationUserType.ToString().ToLower()}", org.Id.ToString()); - var actual = CoreHelpers.BuildIdentityClaims(user, new[] { org }, Array.Empty(), false); - - Assert.Contains(expected, actual); - } - } - - [Theory, CustomAutoData(typeof(UserFixture))] - public void BuildIdentityClaims_CustomOrganizationUserClaims_Success(User user, CurrentContentOrganization org) - { - var fixture = new Fixture().WithAutoNSubstitutions(); - org.Type = OrganizationUserType.Custom; - - var actual = CoreHelpers.BuildIdentityClaims(user, new[] { org }, Array.Empty(), false); - foreach (var (permitted, claimName) in org.Permissions.ClaimsMap) - { - var claim = new KeyValuePair(claimName, org.Id.ToString()); - if (permitted) + Guid.Parse("a58db474-43d8-42f1-b4ee-0c17647cd0c0"), // Input Guid + new DateTime(2022, 3, 12, 12, 12, 0, DateTimeKind.Utc), // Input Time + Guid.Parse("a58db474-43d8-42f1-b4ee-ae5600c90cc1"), // Expected Comb + }, + new object[] { + Guid.Parse("f776e6ee-511f-4352-bb28-88513002bdeb"), + new DateTime(2021, 5, 10, 10, 52, 0, DateTimeKind.Utc), + Guid.Parse("f776e6ee-511f-4352-bb28-ad2400b313c1"), + }, + new object[] + { + Guid.Parse("51a25fc7-3cad-497d-8e2f-8d77011648a1"), + new DateTime(1999, 2, 26, 16, 53, 13, DateTimeKind.Utc), + Guid.Parse("51a25fc7-3cad-497d-8e2f-8d77011649cd"), + }, + new object[] + { + Guid.Parse("bfb8f353-3b32-4a9e-bef6-24fe0b54bfb0"), + new DateTime(2024, 10, 20, 1, 32, 16, DateTimeKind.Utc), + Guid.Parse("bfb8f353-3b32-4a9e-bef6-b20f00195780"), + } + }; + [Theory] + [MemberData(nameof(GenerateCombCases))] + public void GenerateComb_WithInputs_Success(Guid inputGuid, DateTime inputTime, Guid expectedComb) + { + var comb = CoreHelpers.GenerateComb(inputGuid, inputTime); + + Assert.Equal(expectedComb, comb); + } + + [Theory] + [InlineData(2, 5, new[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 0 })] + [InlineData(2, 3, new[] { 1, 2, 3, 4, 5 })] + [InlineData(2, 1, new[] { 1, 2 })] + [InlineData(1, 1, new[] { 1 })] + [InlineData(2, 2, new[] { 1, 2, 3 })] + public void Batch_Success(int batchSize, int totalBatches, int[] collection) + { + // Arrange + var remainder = collection.Length % batchSize; + + // Act + var batches = collection.Batch(batchSize); + + // Assert + Assert.Equal(totalBatches, batches.Count()); + + foreach (var batch in batches.Take(totalBatches - 1)) + { + Assert.Equal(batchSize, batch.Count()); + } + + Assert.Equal(batches.Last().Count(), remainder == 0 ? batchSize : remainder); + } + + /* + [Fact] + public void ToGuidIdArrayTVP_Success() + { + // Arrange + var item0 = Guid.NewGuid(); + var item1 = Guid.NewGuid(); + + var ids = new[] { item0, item1 }; + + // Act + var dt = ids.ToGuidIdArrayTVP(); + + // Assert + Assert.Single(dt.Columns); + Assert.Equal("GuidId", dt.Columns[0].ColumnName); + Assert.Equal(2, dt.Rows.Count); + Assert.Equal(item0, dt.Rows[0][0]); + Assert.Equal(item1, dt.Rows[1][0]); + } + */ + + // TODO: Test the other ToArrayTVP Methods + + [Theory] + [InlineData("12345&6789", "123456789")] + [InlineData("abcdef", "ABCDEF")] + [InlineData("1!@#$%&*()_+", "1")] + [InlineData("\u00C6123abc\u00C7", "123ABC")] + [InlineData("123\u00C6ABC", "123ABC")] + [InlineData("\r\nHello", "E")] + [InlineData("\tdef", "DEF")] + [InlineData("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUV1234567890", "ABCDEFABCDEF1234567890")] + public void CleanCertificateThumbprint_Success(string input, string output) + { + // Arrange & Act + var sanitizedInput = CoreHelpers.CleanCertificateThumbprint(input); + + // Assert + Assert.Equal(output, sanitizedInput); + } + + // TODO: Add more tests + [Theory] + [MemberData(nameof(_epochTestCases))] + public void ToEpocMilliseconds_Success(DateTime date, long milliseconds) + { + // Act & Assert + Assert.Equal(milliseconds, CoreHelpers.ToEpocMilliseconds(date)); + } + + [Theory] + [MemberData(nameof(_epochTestCases))] + public void FromEpocMilliseconds(DateTime date, long milliseconds) + { + // Act & Assert + Assert.Equal(date, CoreHelpers.FromEpocMilliseconds(milliseconds)); + } + + [Fact] + public void SecureRandomString_Success() + { + // Arrange & Act + var @string = CoreHelpers.SecureRandomString(8); + + // Assert + // TODO: Should probably add more Asserts down the line + Assert.Equal(8, @string.Length); + } + + [Theory] + [InlineData(1, "1 Bytes")] + [InlineData(-5L, "-5 Bytes")] + [InlineData(1023L, "1023 Bytes")] + [InlineData(1024L, "1 KB")] + [InlineData(1025L, "1 KB")] + [InlineData(-1023L, "-1023 Bytes")] + [InlineData(-1024L, "-1 KB")] + [InlineData(-1025L, "-1 KB")] + [InlineData(1048575L, "1024 KB")] + [InlineData(1048576L, "1 MB")] + [InlineData(1048577L, "1 MB")] + [InlineData(-1048575L, "-1024 KB")] + [InlineData(-1048576L, "-1 MB")] + [InlineData(-1048577L, "-1 MB")] + [InlineData(1073741823L, "1024 MB")] + [InlineData(1073741824L, "1 GB")] + [InlineData(1073741825L, "1 GB")] + [InlineData(-1073741823L, "-1024 MB")] + [InlineData(-1073741824L, "-1 GB")] + [InlineData(-1073741825L, "-1 GB")] + [InlineData(long.MaxValue, "8589934592 GB")] + public void ReadableBytesSize_Success(long size, string readable) + { + // Act & Assert + Assert.Equal(readable, CoreHelpers.ReadableBytesSize(size)); + } + + [Fact] + public void CloneObject_Success() + { + var original = new { Message = "Message" }; + + var copy = CoreHelpers.CloneObject(original); + + Assert.Equal(original.Message, copy.Message); + } + + [Fact] + public void ExtendQuery_AddNewParameter_Success() + { + // Arrange + var uri = new Uri("https://bitwarden.com/?param1=value1"); + + // Act + var newUri = CoreHelpers.ExtendQuery(uri, + new Dictionary { { "param2", "value2" } }); + + // Assert + Assert.Equal("https://bitwarden.com/?param1=value1¶m2=value2", newUri.ToString()); + } + + [Fact] + public void ExtendQuery_AddTwoNewParameters_Success() + { + // Arrange + var uri = new Uri("https://bitwarden.com/?param1=value1"); + + // Act + var newUri = CoreHelpers.ExtendQuery(uri, + new Dictionary + { + { "param2", "value2" }, + { "param3", "value3" } + }); + + // Assert + Assert.Equal("https://bitwarden.com/?param1=value1¶m2=value2¶m3=value3", newUri.ToString()); + } + + [Fact] + public void ExtendQuery_AddExistingParameter_Success() + { + // Arrange + var uri = new Uri("https://bitwarden.com/?param1=value1¶m2=value2"); + + // Act + var newUri = CoreHelpers.ExtendQuery(uri, + new Dictionary { { "param1", "test_value" } }); + + // Assert + Assert.Equal("https://bitwarden.com/?param1=test_value¶m2=value2", newUri.ToString()); + } + + [Fact] + public void ExtendQuery_AddNoParameters_Success() + { + // Arrange + const string startingUri = "https://bitwarden.com/?param1=value1"; + + var uri = new Uri(startingUri); + + // Act + var newUri = CoreHelpers.ExtendQuery(uri, new Dictionary()); + + // Assert + Assert.Equal(startingUri, newUri.ToString()); + } + + [Theory] + [InlineData("bücher.com", "xn--bcher-kva.com")] + [InlineData("bücher.cömé", "xn--bcher-kva.xn--cm-cja4c")] + [InlineData("hello@bücher.com", "hello@xn--bcher-kva.com")] + [InlineData("hello@world.cömé", "hello@world.xn--cm-cja4c")] + [InlineData("hello@bücher.cömé", "hello@xn--bcher-kva.xn--cm-cja4c")] + [InlineData("ascii.com", "ascii.com")] + [InlineData("", "")] + [InlineData(null, null)] + public void PunyEncode_Success(string text, string expected) + { + var actual = CoreHelpers.PunyEncode(text); + Assert.Equal(expected, actual); + } + + [Fact] + public void GetEmbeddedResourceContentsAsync_Success() + { + var fileContents = CoreHelpers.GetEmbeddedResourceContentsAsync("data.embeddedResource.txt"); + Assert.Equal("Contents of embeddedResource.txt\n", fileContents.Replace("\r\n", "\n")); + } + + [Theory, CustomAutoData(typeof(UserFixture))] + public void BuildIdentityClaims_BaseClaims_Success(User user, bool isPremium) + { + var expected = new Dictionary + { + { "premium", isPremium ? "true" : "false" }, + { JwtClaimTypes.Email, user.Email }, + { JwtClaimTypes.EmailVerified, user.EmailVerified ? "true" : "false" }, + { JwtClaimTypes.Name, user.Name }, + { "sstamp", user.SecurityStamp }, + }.ToList(); + + var actual = CoreHelpers.BuildIdentityClaims(user, Array.Empty(), + Array.Empty(), isPremium); + + foreach (var claim in expected) + { Assert.Contains(claim, actual); } - else + Assert.Equal(expected.Count, actual.Count); + } + + [Theory, CustomAutoData(typeof(UserFixture))] + public void BuildIdentityClaims_NonCustomOrganizationUserType_Success(User user) + { + var fixture = new Fixture().WithAutoNSubstitutions(); + foreach (var organizationUserType in Enum.GetValues().Except(new[] { OrganizationUserType.Custom })) { - Assert.DoesNotContain(claim, actual); + var org = fixture.Create(); + org.Type = organizationUserType; + + var expected = new KeyValuePair($"org{organizationUserType.ToString().ToLower()}", org.Id.ToString()); + var actual = CoreHelpers.BuildIdentityClaims(user, new[] { org }, Array.Empty(), false); + + Assert.Contains(expected, actual); } } - } - [Theory, CustomAutoData(typeof(UserFixture))] - public void BuildIdentityClaims_ProviderClaims_Success(User user) - { - var fixture = new Fixture().WithAutoNSubstitutions(); - var providers = new List(); - foreach (var providerUserType in Enum.GetValues()) + [Theory, CustomAutoData(typeof(UserFixture))] + public void BuildIdentityClaims_CustomOrganizationUserClaims_Success(User user, CurrentContentOrganization org) { - var provider = fixture.Create(); - provider.Type = providerUserType; - providers.Add(provider); - } + var fixture = new Fixture().WithAutoNSubstitutions(); + org.Type = OrganizationUserType.Custom; - var claims = new List>(); - - if (providers.Any()) - { - foreach (var group in providers.GroupBy(o => o.Type)) + var actual = CoreHelpers.BuildIdentityClaims(user, new[] { org }, Array.Empty(), false); + foreach (var (permitted, claimName) in org.Permissions.ClaimsMap) { - switch (group.Key) + var claim = new KeyValuePair(claimName, org.Id.ToString()); + if (permitted) { - case ProviderUserType.ProviderAdmin: - foreach (var provider in group) - { - claims.Add(new KeyValuePair("providerprovideradmin", provider.Id.ToString())); - } - break; - case ProviderUserType.ServiceUser: - foreach (var provider in group) - { - claims.Add(new KeyValuePair("providerserviceuser", provider.Id.ToString())); - } - break; + + Assert.Contains(claim, actual); + } + else + { + Assert.DoesNotContain(claim, actual); } } } - var actual = CoreHelpers.BuildIdentityClaims(user, Array.Empty(), providers, false); - foreach (var claim in claims) + [Theory, CustomAutoData(typeof(UserFixture))] + public void BuildIdentityClaims_ProviderClaims_Success(User user) { - Assert.Contains(claim, actual); - } - } - - public static IEnumerable TokenIsValidData() - { - return new[] - { - new object[] + var fixture = new Fixture().WithAutoNSubstitutions(); + var providers = new List(); + foreach (var providerUserType in Enum.GetValues()) { - "first_part 476669d4-9642-4af8-9b29-9366efad4ed3 test@email.com {0}", // unprotectedTokenTemplate - "first_part", // firstPart - "test@email.com", // email - Guid.Parse("476669d4-9642-4af8-9b29-9366efad4ed3"), // id - DateTime.UtcNow.AddHours(-1), // creationTime - 12, // expirationInHours - true, // isValid + var provider = fixture.Create(); + provider.Type = providerUserType; + providers.Add(provider); } - }; - } - [Theory] - [MemberData(nameof(TokenIsValidData))] - public void TokenIsValid_Success(string unprotectedTokenTemplate, string firstPart, string userEmail, Guid id, DateTime creationTime, double expirationInHours, bool isValid) - { - var protector = new TestDataProtector(string.Format(unprotectedTokenTemplate, CoreHelpers.ToEpocMilliseconds(creationTime))); + var claims = new List>(); - Assert.Equal(isValid, CoreHelpers.TokenIsValid(firstPart, protector, "protected_token", userEmail, id, expirationInHours)); - } + if (providers.Any()) + { + foreach (var group in providers.GroupBy(o => o.Type)) + { + switch (group.Key) + { + case ProviderUserType.ProviderAdmin: + foreach (var provider in group) + { + claims.Add(new KeyValuePair("providerprovideradmin", provider.Id.ToString())); + } + break; + case ProviderUserType.ServiceUser: + foreach (var provider in group) + { + claims.Add(new KeyValuePair("providerserviceuser", provider.Id.ToString())); + } + break; + } + } + } - private class TestDataProtector : IDataProtector - { - private readonly string _token; - public TestDataProtector(string token) - { - _token = token; + var actual = CoreHelpers.BuildIdentityClaims(user, Array.Empty(), providers, false); + foreach (var claim in claims) + { + Assert.Contains(claim, actual); + } } - public IDataProtector CreateProtector(string purpose) => throw new NotImplementedException(); - public byte[] Protect(byte[] plaintext) => throw new NotImplementedException(); - public byte[] Unprotect(byte[] protectedData) - { - return Encoding.UTF8.GetBytes(_token); - } - } - [Theory] - [InlineData("hi@email.com", "hi@email.com")] // Short email with no room to obfuscate - [InlineData("name@email.com", "na**@email.com")] // Can obfuscate - [InlineData("reallylongnamethatnooneshouldhave@email", "re*******************************@email")] // Really long email and no .com, .net, etc - [InlineData("name@", "name@")] // @ symbol but no domain - [InlineData("", "")] // Empty string - [InlineData(null, null)] // null - public void ObfuscateEmail_Success(string input, string expected) - { - Assert.Equal(expected, CoreHelpers.ObfuscateEmail(input)); + public static IEnumerable TokenIsValidData() + { + return new[] + { + new object[] + { + "first_part 476669d4-9642-4af8-9b29-9366efad4ed3 test@email.com {0}", // unprotectedTokenTemplate + "first_part", // firstPart + "test@email.com", // email + Guid.Parse("476669d4-9642-4af8-9b29-9366efad4ed3"), // id + DateTime.UtcNow.AddHours(-1), // creationTime + 12, // expirationInHours + true, // isValid + } + }; + } + + [Theory] + [MemberData(nameof(TokenIsValidData))] + public void TokenIsValid_Success(string unprotectedTokenTemplate, string firstPart, string userEmail, Guid id, DateTime creationTime, double expirationInHours, bool isValid) + { + var protector = new TestDataProtector(string.Format(unprotectedTokenTemplate, CoreHelpers.ToEpocMilliseconds(creationTime))); + + Assert.Equal(isValid, CoreHelpers.TokenIsValid(firstPart, protector, "protected_token", userEmail, id, expirationInHours)); + } + + private class TestDataProtector : IDataProtector + { + private readonly string _token; + public TestDataProtector(string token) + { + _token = token; + } + public IDataProtector CreateProtector(string purpose) => throw new NotImplementedException(); + public byte[] Protect(byte[] plaintext) => throw new NotImplementedException(); + public byte[] Unprotect(byte[] protectedData) + { + return Encoding.UTF8.GetBytes(_token); + } + } + + [Theory] + [InlineData("hi@email.com", "hi@email.com")] // Short email with no room to obfuscate + [InlineData("name@email.com", "na**@email.com")] // Can obfuscate + [InlineData("reallylongnamethatnooneshouldhave@email", "re*******************************@email")] // Really long email and no .com, .net, etc + [InlineData("name@", "name@")] // @ symbol but no domain + [InlineData("", "")] // Empty string + [InlineData(null, null)] // null + public void ObfuscateEmail_Success(string input, string expected) + { + Assert.Equal(expected, CoreHelpers.ObfuscateEmail(input)); + } } } diff --git a/test/Core.Test/Utilities/EncryptedStringAttributeTests.cs b/test/Core.Test/Utilities/EncryptedStringAttributeTests.cs index c16a983cf9..09ee18847f 100644 --- a/test/Core.Test/Utilities/EncryptedStringAttributeTests.cs +++ b/test/Core.Test/Utilities/EncryptedStringAttributeTests.cs @@ -1,42 +1,43 @@ using Bit.Core.Utilities; using Xunit; -namespace Bit.Core.Test.Utilities; - -public class EncryptedStringAttributeTests +namespace Bit.Core.Test.Utilities { - [Theory] - [InlineData(null)] - [InlineData("aXY=|Y3Q=")] // Valid AesCbc256_B64 - [InlineData("aXY=|Y3Q=|cnNhQ3Q=")] // Valid AesCbc128_HmacSha256_B64 - [InlineData("Rsa2048_OaepSha256_B64.cnNhQ3Q=")] - public void IsValid_ReturnsTrue_WhenValid(string input) + public class EncryptedStringAttributeTests { - var sut = new EncryptedStringAttribute(); + [Theory] + [InlineData(null)] + [InlineData("aXY=|Y3Q=")] // Valid AesCbc256_B64 + [InlineData("aXY=|Y3Q=|cnNhQ3Q=")] // Valid AesCbc128_HmacSha256_B64 + [InlineData("Rsa2048_OaepSha256_B64.cnNhQ3Q=")] + public void IsValid_ReturnsTrue_WhenValid(string input) + { + var sut = new EncryptedStringAttribute(); - var actual = sut.IsValid(input); + var actual = sut.IsValid(input); - Assert.True(actual); - } + Assert.True(actual); + } - [Theory] - [InlineData("")] - [InlineData(".")] - [InlineData("|")] - [InlineData("!|!")] // Invalid base 64 - [InlineData("Rsa2048_OaepSha1_HmacSha256_B64.1")] // Invalid length - [InlineData("Rsa2048_OaepSha1_HmacSha256_B64.|")] // Empty iv & ct - [InlineData("AesCbc128_HmacSha256_B64.1")] // Invalid length - [InlineData("AesCbc128_HmacSha256_B64.aXY=|Y3Q=|")] // Empty mac - [InlineData("Rsa2048_OaepSha1_HmacSha256_B64.aXY=|Y3Q=|")] // Empty mac - [InlineData("Rsa2048_OaepSha256_B64.1|2")] // Invalid length - [InlineData("Rsa2048_OaepSha1_HmacSha256_B64.aXY=|")] // Empty mac - public void IsValid_ReturnsFalse_WhenInvalid(string input) - { - var sut = new EncryptedStringAttribute(); + [Theory] + [InlineData("")] + [InlineData(".")] + [InlineData("|")] + [InlineData("!|!")] // Invalid base 64 + [InlineData("Rsa2048_OaepSha1_HmacSha256_B64.1")] // Invalid length + [InlineData("Rsa2048_OaepSha1_HmacSha256_B64.|")] // Empty iv & ct + [InlineData("AesCbc128_HmacSha256_B64.1")] // Invalid length + [InlineData("AesCbc128_HmacSha256_B64.aXY=|Y3Q=|")] // Empty mac + [InlineData("Rsa2048_OaepSha1_HmacSha256_B64.aXY=|Y3Q=|")] // Empty mac + [InlineData("Rsa2048_OaepSha256_B64.1|2")] // Invalid length + [InlineData("Rsa2048_OaepSha1_HmacSha256_B64.aXY=|")] // Empty mac + public void IsValid_ReturnsFalse_WhenInvalid(string input) + { + var sut = new EncryptedStringAttribute(); - var actual = sut.IsValid(input); + var actual = sut.IsValid(input); - Assert.False(actual); + Assert.False(actual); + } } } diff --git a/test/Core.Test/Utilities/JsonHelpersTests.cs b/test/Core.Test/Utilities/JsonHelpersTests.cs index 8c12cf22eb..8a9a26614d 100644 --- a/test/Core.Test/Utilities/JsonHelpersTests.cs +++ b/test/Core.Test/Utilities/JsonHelpersTests.cs @@ -2,64 +2,65 @@ using Bit.Core.Utilities; using Xunit; -namespace Bit.Core.Test.Helpers; - -public class JsonHelpersTests +namespace Bit.Core.Test.Helpers { - private static void CompareJson(T value, JsonSerializerOptions options, Newtonsoft.Json.JsonSerializerSettings settings) + public class JsonHelpersTests { - var stgJson = JsonSerializer.Serialize(value, options); - var nsJson = Newtonsoft.Json.JsonConvert.SerializeObject(value, settings); + private static void CompareJson(T value, JsonSerializerOptions options, Newtonsoft.Json.JsonSerializerSettings settings) + { + var stgJson = JsonSerializer.Serialize(value, options); + var nsJson = Newtonsoft.Json.JsonConvert.SerializeObject(value, settings); - Assert.Equal(stgJson, nsJson); + Assert.Equal(stgJson, nsJson); + } + + + [Fact] + public void DefaultJsonOptions() + { + var testObject = new SimpleTestObject + { + Id = 0, + Name = "Test", + }; + + CompareJson(testObject, JsonHelpers.Default, new Newtonsoft.Json.JsonSerializerSettings()); + } + + [Fact] + public void IndentedJsonOptions() + { + var testObject = new SimpleTestObject + { + Id = 10, + Name = "Test Name" + }; + + CompareJson(testObject, JsonHelpers.Indented, new Newtonsoft.Json.JsonSerializerSettings + { + Formatting = Newtonsoft.Json.Formatting.Indented, + }); + } + + [Fact] + public void NullValueHandlingJsonOptions() + { + var testObject = new SimpleTestObject + { + Id = 14, + Name = null, + }; + + CompareJson(testObject, JsonHelpers.IgnoreWritingNull, new Newtonsoft.Json.JsonSerializerSettings + { + NullValueHandling = Newtonsoft.Json.NullValueHandling.Ignore, + }); + } } - - [Fact] - public void DefaultJsonOptions() + public class SimpleTestObject { - var testObject = new SimpleTestObject - { - Id = 0, - Name = "Test", - }; - - CompareJson(testObject, JsonHelpers.Default, new Newtonsoft.Json.JsonSerializerSettings()); - } - - [Fact] - public void IndentedJsonOptions() - { - var testObject = new SimpleTestObject - { - Id = 10, - Name = "Test Name" - }; - - CompareJson(testObject, JsonHelpers.Indented, new Newtonsoft.Json.JsonSerializerSettings - { - Formatting = Newtonsoft.Json.Formatting.Indented, - }); - } - - [Fact] - public void NullValueHandlingJsonOptions() - { - var testObject = new SimpleTestObject - { - Id = 14, - Name = null, - }; - - CompareJson(testObject, JsonHelpers.IgnoreWritingNull, new Newtonsoft.Json.JsonSerializerSettings - { - NullValueHandling = Newtonsoft.Json.NullValueHandling.Ignore, - }); + public int Id { get; set; } + public string Name { get; set; } } } - -public class SimpleTestObject -{ - public int Id { get; set; } - public string Name { get; set; } -} diff --git a/test/Core.Test/Utilities/PermissiveStringConverterTests.cs b/test/Core.Test/Utilities/PermissiveStringConverterTests.cs index dc23b1acb7..396d277e41 100644 --- a/test/Core.Test/Utilities/PermissiveStringConverterTests.cs +++ b/test/Core.Test/Utilities/PermissiveStringConverterTests.cs @@ -5,164 +5,165 @@ using Bit.Core.Utilities; using Bit.Test.Common.Helpers; using Xunit; -namespace Bit.Core.Test.Utilities; - -public class PermissiveStringConverterTests +namespace Bit.Core.Test.Utilities { - private const string numberJson = "{ \"StringProp\": 1, \"EnumerableStringProp\": [ 2, 3 ]}"; - private const string stringJson = "{ \"StringProp\": \"1\", \"EnumerableStringProp\": [ \"2\", \"3\" ]}"; - private const string nullAndEmptyJson = "{ \"StringProp\": null, \"EnumerableStringProp\": [] }"; - private const string singleValueJson = "{ \"StringProp\": 1, \"EnumerableStringProp\": \"Hello!\" }"; - private const string nullJson = "{ \"StringProp\": null, \"EnumerableStringProp\": null }"; - private const string boolJson = "{ \"StringProp\": true, \"EnumerableStringProp\": [ false, 1.2]}"; - private const string objectJsonOne = "{ \"StringProp\": { \"Message\": \"Hi\"}, \"EnumerableStringProp\": []}"; - private const string objectJsonTwo = "{ \"StringProp\": \"Hi\", \"EnumerableStringProp\": {}}"; - private readonly string bigNumbersJson = - "{ \"StringProp\":" + decimal.MinValue + ", \"EnumerableStringProp\": [" + ulong.MaxValue + ", " + long.MinValue + "]}"; - - [Theory] - [InlineData(numberJson)] - [InlineData(stringJson)] - public void Read_Success(string json) + public class PermissiveStringConverterTests { - var obj = JsonSerializer.Deserialize(json); - Assert.Equal("1", obj.StringProp); - Assert.Equal(2, obj.EnumerableStringProp.Count()); - Assert.Equal("2", obj.EnumerableStringProp.ElementAt(0)); - Assert.Equal("3", obj.EnumerableStringProp.ElementAt(1)); - } + private const string numberJson = "{ \"StringProp\": 1, \"EnumerableStringProp\": [ 2, 3 ]}"; + private const string stringJson = "{ \"StringProp\": \"1\", \"EnumerableStringProp\": [ \"2\", \"3\" ]}"; + private const string nullAndEmptyJson = "{ \"StringProp\": null, \"EnumerableStringProp\": [] }"; + private const string singleValueJson = "{ \"StringProp\": 1, \"EnumerableStringProp\": \"Hello!\" }"; + private const string nullJson = "{ \"StringProp\": null, \"EnumerableStringProp\": null }"; + private const string boolJson = "{ \"StringProp\": true, \"EnumerableStringProp\": [ false, 1.2]}"; + private const string objectJsonOne = "{ \"StringProp\": { \"Message\": \"Hi\"}, \"EnumerableStringProp\": []}"; + private const string objectJsonTwo = "{ \"StringProp\": \"Hi\", \"EnumerableStringProp\": {}}"; + private readonly string bigNumbersJson = + "{ \"StringProp\":" + decimal.MinValue + ", \"EnumerableStringProp\": [" + ulong.MaxValue + ", " + long.MinValue + "]}"; - [Fact] - public void Read_Boolean_Success() - { - var obj = JsonSerializer.Deserialize(boolJson); - Assert.Equal("True", obj.StringProp); - Assert.Equal(2, obj.EnumerableStringProp.Count()); - Assert.Equal("False", obj.EnumerableStringProp.ElementAt(0)); - Assert.Equal("1.2", obj.EnumerableStringProp.ElementAt(1)); - } - - [Fact] - public void Read_Float_Success_Culture() - { - var ci = new CultureInfo("sv-SE"); - Thread.CurrentThread.CurrentCulture = ci; - Thread.CurrentThread.CurrentUICulture = ci; - - var obj = JsonSerializer.Deserialize(boolJson); - Assert.Equal("1.2", obj.EnumerableStringProp.ElementAt(1)); - } - - [Fact] - public void Read_BigNumbers_Success() - { - var obj = JsonSerializer.Deserialize(bigNumbersJson); - Assert.Equal(decimal.MinValue.ToString(), obj.StringProp); - Assert.Equal(2, obj.EnumerableStringProp.Count()); - Assert.Equal(ulong.MaxValue.ToString(), obj.EnumerableStringProp.ElementAt(0)); - Assert.Equal(long.MinValue.ToString(), obj.EnumerableStringProp.ElementAt(1)); - } - - [Fact] - public void Read_SingleValue_Success() - { - var obj = JsonSerializer.Deserialize(singleValueJson); - Assert.Equal("1", obj.StringProp); - Assert.Single(obj.EnumerableStringProp); - Assert.Equal("Hello!", obj.EnumerableStringProp.ElementAt(0)); - } - - [Fact] - public void Read_NullAndEmptyJson_Success() - { - var obj = JsonSerializer.Deserialize(nullAndEmptyJson); - Assert.Null(obj.StringProp); - Assert.Empty(obj.EnumerableStringProp); - } - - [Fact] - public void Read_Null_Success() - { - var obj = JsonSerializer.Deserialize(nullJson); - Assert.Null(obj.StringProp); - Assert.Null(obj.EnumerableStringProp); - } - - [Theory] - [InlineData(objectJsonOne)] - [InlineData(objectJsonTwo)] - public void Read_Object_Throws(string json) - { - var exception = Assert.Throws(() => JsonSerializer.Deserialize(json)); - } - - [Fact] - public void Write_Success() - { - var json = JsonSerializer.Serialize(new TestObject + [Theory] + [InlineData(numberJson)] + [InlineData(stringJson)] + public void Read_Success(string json) { - StringProp = "1", - EnumerableStringProp = new List + var obj = JsonSerializer.Deserialize(json); + Assert.Equal("1", obj.StringProp); + Assert.Equal(2, obj.EnumerableStringProp.Count()); + Assert.Equal("2", obj.EnumerableStringProp.ElementAt(0)); + Assert.Equal("3", obj.EnumerableStringProp.ElementAt(1)); + } + + [Fact] + public void Read_Boolean_Success() + { + var obj = JsonSerializer.Deserialize(boolJson); + Assert.Equal("True", obj.StringProp); + Assert.Equal(2, obj.EnumerableStringProp.Count()); + Assert.Equal("False", obj.EnumerableStringProp.ElementAt(0)); + Assert.Equal("1.2", obj.EnumerableStringProp.ElementAt(1)); + } + + [Fact] + public void Read_Float_Success_Culture() + { + var ci = new CultureInfo("sv-SE"); + Thread.CurrentThread.CurrentCulture = ci; + Thread.CurrentThread.CurrentUICulture = ci; + + var obj = JsonSerializer.Deserialize(boolJson); + Assert.Equal("1.2", obj.EnumerableStringProp.ElementAt(1)); + } + + [Fact] + public void Read_BigNumbers_Success() + { + var obj = JsonSerializer.Deserialize(bigNumbersJson); + Assert.Equal(decimal.MinValue.ToString(), obj.StringProp); + Assert.Equal(2, obj.EnumerableStringProp.Count()); + Assert.Equal(ulong.MaxValue.ToString(), obj.EnumerableStringProp.ElementAt(0)); + Assert.Equal(long.MinValue.ToString(), obj.EnumerableStringProp.ElementAt(1)); + } + + [Fact] + public void Read_SingleValue_Success() + { + var obj = JsonSerializer.Deserialize(singleValueJson); + Assert.Equal("1", obj.StringProp); + Assert.Single(obj.EnumerableStringProp); + Assert.Equal("Hello!", obj.EnumerableStringProp.ElementAt(0)); + } + + [Fact] + public void Read_NullAndEmptyJson_Success() + { + var obj = JsonSerializer.Deserialize(nullAndEmptyJson); + Assert.Null(obj.StringProp); + Assert.Empty(obj.EnumerableStringProp); + } + + [Fact] + public void Read_Null_Success() + { + var obj = JsonSerializer.Deserialize(nullJson); + Assert.Null(obj.StringProp); + Assert.Null(obj.EnumerableStringProp); + } + + [Theory] + [InlineData(objectJsonOne)] + [InlineData(objectJsonTwo)] + public void Read_Object_Throws(string json) + { + var exception = Assert.Throws(() => JsonSerializer.Deserialize(json)); + } + + [Fact] + public void Write_Success() + { + var json = JsonSerializer.Serialize(new TestObject { - "2", - "3", - }, - }); + StringProp = "1", + EnumerableStringProp = new List + { + "2", + "3", + }, + }); - var jsonElement = JsonDocument.Parse(json).RootElement; + var jsonElement = JsonDocument.Parse(json).RootElement; - var stringProp = AssertHelper.AssertJsonProperty(jsonElement, "StringProp", JsonValueKind.String); - Assert.Equal("1", stringProp.GetString()); - var list = AssertHelper.AssertJsonProperty(jsonElement, "EnumerableStringProp", JsonValueKind.Array); - Assert.Equal(2, list.GetArrayLength()); - var firstElement = list[0]; - Assert.Equal(JsonValueKind.String, firstElement.ValueKind); - Assert.Equal("2", firstElement.GetString()); - var secondElement = list[1]; - Assert.Equal(JsonValueKind.String, secondElement.ValueKind); - Assert.Equal("3", secondElement.GetString()); + var stringProp = AssertHelper.AssertJsonProperty(jsonElement, "StringProp", JsonValueKind.String); + Assert.Equal("1", stringProp.GetString()); + var list = AssertHelper.AssertJsonProperty(jsonElement, "EnumerableStringProp", JsonValueKind.Array); + Assert.Equal(2, list.GetArrayLength()); + var firstElement = list[0]; + Assert.Equal(JsonValueKind.String, firstElement.ValueKind); + Assert.Equal("2", firstElement.GetString()); + var secondElement = list[1]; + Assert.Equal(JsonValueKind.String, secondElement.ValueKind); + Assert.Equal("3", secondElement.GetString()); + } + + [Fact] + public void Write_Null() + { + // When the values are null the converters aren't actually ran and it automatically serializes null + var json = JsonSerializer.Serialize(new TestObject + { + StringProp = null, + EnumerableStringProp = null, + }); + + var jsonElement = JsonDocument.Parse(json).RootElement; + + AssertHelper.AssertJsonProperty(jsonElement, "StringProp", JsonValueKind.Null); + AssertHelper.AssertJsonProperty(jsonElement, "EnumerableStringProp", JsonValueKind.Null); + } + + [Fact] + public void Write_Empty() + { + // When the values are null the converters aren't actually ran and it automatically serializes null + var json = JsonSerializer.Serialize(new TestObject + { + StringProp = "", + EnumerableStringProp = Enumerable.Empty(), + }); + + var jsonElement = JsonDocument.Parse(json).RootElement; + + var stringVal = AssertHelper.AssertJsonProperty(jsonElement, "StringProp", JsonValueKind.String).GetString(); + Assert.Equal("", stringVal); + var array = AssertHelper.AssertJsonProperty(jsonElement, "EnumerableStringProp", JsonValueKind.Array); + Assert.Equal(0, array.GetArrayLength()); + } } - [Fact] - public void Write_Null() + public class TestObject { - // When the values are null the converters aren't actually ran and it automatically serializes null - var json = JsonSerializer.Serialize(new TestObject - { - StringProp = null, - EnumerableStringProp = null, - }); + [JsonConverter(typeof(PermissiveStringConverter))] + public string StringProp { get; set; } - var jsonElement = JsonDocument.Parse(json).RootElement; - - AssertHelper.AssertJsonProperty(jsonElement, "StringProp", JsonValueKind.Null); - AssertHelper.AssertJsonProperty(jsonElement, "EnumerableStringProp", JsonValueKind.Null); - } - - [Fact] - public void Write_Empty() - { - // When the values are null the converters aren't actually ran and it automatically serializes null - var json = JsonSerializer.Serialize(new TestObject - { - StringProp = "", - EnumerableStringProp = Enumerable.Empty(), - }); - - var jsonElement = JsonDocument.Parse(json).RootElement; - - var stringVal = AssertHelper.AssertJsonProperty(jsonElement, "StringProp", JsonValueKind.String).GetString(); - Assert.Equal("", stringVal); - var array = AssertHelper.AssertJsonProperty(jsonElement, "EnumerableStringProp", JsonValueKind.Array); - Assert.Equal(0, array.GetArrayLength()); + [JsonConverter(typeof(PermissiveStringEnumerableConverter))] + public IEnumerable EnumerableStringProp { get; set; } } } - -public class TestObject -{ - [JsonConverter(typeof(PermissiveStringConverter))] - public string StringProp { get; set; } - - [JsonConverter(typeof(PermissiveStringEnumerableConverter))] - public IEnumerable EnumerableStringProp { get; set; } -} diff --git a/test/Core.Test/Utilities/SelfHostedAttributeTests.cs b/test/Core.Test/Utilities/SelfHostedAttributeTests.cs index 564c328395..4261cf7f53 100644 --- a/test/Core.Test/Utilities/SelfHostedAttributeTests.cs +++ b/test/Core.Test/Utilities/SelfHostedAttributeTests.cs @@ -10,82 +10,83 @@ using Microsoft.Extensions.DependencyInjection; using NSubstitute; using Xunit; -namespace Bit.Core.Test.Utilities; - -public class SelfHostedAttributeTests +namespace Bit.Core.Test.Utilities { - [Fact] - public void NotSelfHosted_Throws_When_SelfHosted() + public class SelfHostedAttributeTests { - // Arrange - var sha = new SelfHostedAttribute { NotSelfHostedOnly = true }; - - // Act & Assert - Assert.Throws(() => sha.OnActionExecuting(GetContext(selfHosted: true))); - } - - [Fact] - public void NotSelfHosted_Success_When_NotSelfHosted() - { - // Arrange - var sha = new SelfHostedAttribute { NotSelfHostedOnly = true }; - - // Act - sha.OnActionExecuting(GetContext(selfHosted: false)); - - // Assert - // The Assert here is just NOT throwing an exception - } - - - [Fact] - public void SelfHosted_Success_When_SelfHosted() - { - // Arrange - var sha = new SelfHostedAttribute { SelfHostedOnly = true }; - - // Act - sha.OnActionExecuting(GetContext(selfHosted: true)); - - // Assert - // The Assert here is just NOT throwing an exception - } - - [Fact] - public void SelfHosted_Throws_When_NotSelfHosted() - { - // Arrange - var sha = new SelfHostedAttribute { SelfHostedOnly = true }; - - // Act & Assert - Assert.Throws(() => sha.OnActionExecuting(GetContext(selfHosted: false))); - } - - - // This generates a ActionExecutingContext with the needed injected - // service with the given value. - private ActionExecutingContext GetContext(bool selfHosted) - { - IServiceCollection services = new ServiceCollection(); - - var globalSettings = new GlobalSettings + [Fact] + public void NotSelfHosted_Throws_When_SelfHosted() { - SelfHosted = selfHosted - }; + // Arrange + var sha = new SelfHostedAttribute { NotSelfHostedOnly = true }; - services.AddSingleton(globalSettings); + // Act & Assert + Assert.Throws(() => sha.OnActionExecuting(GetContext(selfHosted: true))); + } - var httpContext = new DefaultHttpContext(); - httpContext.RequestServices = services.BuildServiceProvider(); + [Fact] + public void NotSelfHosted_Success_When_NotSelfHosted() + { + // Arrange + var sha = new SelfHostedAttribute { NotSelfHostedOnly = true }; - var context = Substitute.For( - Substitute.For(httpContext, - new RouteData(), - Substitute.For()), - new List(), - new Dictionary(), - Substitute.For()); + // Act + sha.OnActionExecuting(GetContext(selfHosted: false)); - return context; + // Assert + // The Assert here is just NOT throwing an exception + } + + + [Fact] + public void SelfHosted_Success_When_SelfHosted() + { + // Arrange + var sha = new SelfHostedAttribute { SelfHostedOnly = true }; + + // Act + sha.OnActionExecuting(GetContext(selfHosted: true)); + + // Assert + // The Assert here is just NOT throwing an exception + } + + [Fact] + public void SelfHosted_Throws_When_NotSelfHosted() + { + // Arrange + var sha = new SelfHostedAttribute { SelfHostedOnly = true }; + + // Act & Assert + Assert.Throws(() => sha.OnActionExecuting(GetContext(selfHosted: false))); + } + + + // This generates a ActionExecutingContext with the needed injected + // service with the given value. + private ActionExecutingContext GetContext(bool selfHosted) + { + IServiceCollection services = new ServiceCollection(); + + var globalSettings = new GlobalSettings + { + SelfHosted = selfHosted + }; + + services.AddSingleton(globalSettings); + + var httpContext = new DefaultHttpContext(); + httpContext.RequestServices = services.BuildServiceProvider(); + + var context = Substitute.For( + Substitute.For(httpContext, + new RouteData(), + Substitute.For()), + new List(), + new Dictionary(), + Substitute.For()); + + return context; + } } } diff --git a/test/Core.Test/Utilities/StrictEmailAddressAttributeTests.cs b/test/Core.Test/Utilities/StrictEmailAddressAttributeTests.cs index bcd3efcc13..6fac595621 100644 --- a/test/Core.Test/Utilities/StrictEmailAddressAttributeTests.cs +++ b/test/Core.Test/Utilities/StrictEmailAddressAttributeTests.cs @@ -1,58 +1,59 @@ using Bit.Core.Utilities; using Xunit; -namespace Bit.Core.Test.Utilities; - -public class StrictEmailAttributeTests +namespace Bit.Core.Test.Utilities { - [Theory] - [InlineData("hello@world.com")] // regular email address - [InlineData("hello@world.planet.com")] // subdomain - [InlineData("hello+1@world.com")] // alias - [InlineData("hello.there@world.com")] // period in local-part - [InlineData("hello@wörldé.com")] // unicode domain - [InlineData("hello@world.cömé")] // unicode top-level domain - public void IsValid_ReturnsTrueWhenValid(string email) + public class StrictEmailAttributeTests { - var sut = new StrictEmailAddressAttribute(); + [Theory] + [InlineData("hello@world.com")] // regular email address + [InlineData("hello@world.planet.com")] // subdomain + [InlineData("hello+1@world.com")] // alias + [InlineData("hello.there@world.com")] // period in local-part + [InlineData("hello@wörldé.com")] // unicode domain + [InlineData("hello@world.cömé")] // unicode top-level domain + public void IsValid_ReturnsTrueWhenValid(string email) + { + var sut = new StrictEmailAddressAttribute(); - var actual = sut.IsValid(email); + var actual = sut.IsValid(email); - Assert.True(actual); - } + Assert.True(actual); + } - [Theory] - [InlineData(null)] // null - [InlineData("hello@world.com\t")] // trailing tab char - [InlineData("\thello@world.com")] // leading tab char - [InlineData("hel\tlo@world.com")] // local-part tab char - [InlineData("hello@world.com\b")] // trailing backspace char - [InlineData("\" \"hello@world.com")] // leading spaces in quotes - [InlineData("hello@world.com\" \"")] // trailing spaces in quotes - [InlineData("hel\" \"lo@world.com")] // local-part spaces in quotes - [InlineData("hello there@world.com")] // unescaped unquoted spaces - [InlineData("Hello ")] // friendly from - [InlineData("")] // wrapped angle brackets - [InlineData("hello(com)there@world.com")] // comment - [InlineData("hello@world.com.")] // trailing period - [InlineData(".hello@world.com")] // leading period - [InlineData("hello@world.com;")] // trailing semicolon - [InlineData(";hello@world.com")] // leading semicolon - [InlineData("hello@world.com; hello@world.com")] // semicolon separated list - [InlineData("hello@world.com, hello@world.com")] // comma separated list - [InlineData("hellothere@worldcom")] // dotless domain - [InlineData("hello.there@worldcom")] // dotless domain - [InlineData("hellothere@.worldcom")] // domain beginning with dot - [InlineData("hellothere@worldcom.")] // domain ending in dot - [InlineData("hellothere@world.com-")] // domain ending in hyphen - [InlineData("hellö@world.com")] // unicode at end of local-part - [InlineData("héllo@world.com")] // unicode in middle of local-part - public void IsValid_ReturnsFalseWhenInvalid(string email) - { - var sut = new StrictEmailAddressAttribute(); + [Theory] + [InlineData(null)] // null + [InlineData("hello@world.com\t")] // trailing tab char + [InlineData("\thello@world.com")] // leading tab char + [InlineData("hel\tlo@world.com")] // local-part tab char + [InlineData("hello@world.com\b")] // trailing backspace char + [InlineData("\" \"hello@world.com")] // leading spaces in quotes + [InlineData("hello@world.com\" \"")] // trailing spaces in quotes + [InlineData("hel\" \"lo@world.com")] // local-part spaces in quotes + [InlineData("hello there@world.com")] // unescaped unquoted spaces + [InlineData("Hello ")] // friendly from + [InlineData("")] // wrapped angle brackets + [InlineData("hello(com)there@world.com")] // comment + [InlineData("hello@world.com.")] // trailing period + [InlineData(".hello@world.com")] // leading period + [InlineData("hello@world.com;")] // trailing semicolon + [InlineData(";hello@world.com")] // leading semicolon + [InlineData("hello@world.com; hello@world.com")] // semicolon separated list + [InlineData("hello@world.com, hello@world.com")] // comma separated list + [InlineData("hellothere@worldcom")] // dotless domain + [InlineData("hello.there@worldcom")] // dotless domain + [InlineData("hellothere@.worldcom")] // domain beginning with dot + [InlineData("hellothere@worldcom.")] // domain ending in dot + [InlineData("hellothere@world.com-")] // domain ending in hyphen + [InlineData("hellö@world.com")] // unicode at end of local-part + [InlineData("héllo@world.com")] // unicode in middle of local-part + public void IsValid_ReturnsFalseWhenInvalid(string email) + { + var sut = new StrictEmailAddressAttribute(); - var actual = sut.IsValid(email); + var actual = sut.IsValid(email); - Assert.False(actual); + Assert.False(actual); + } } } diff --git a/test/Core.Test/Utilities/StrictEmailAddressListAttributeTests.cs b/test/Core.Test/Utilities/StrictEmailAddressListAttributeTests.cs index 2ec5a45689..2f31a75dcc 100644 --- a/test/Core.Test/Utilities/StrictEmailAddressListAttributeTests.cs +++ b/test/Core.Test/Utilities/StrictEmailAddressListAttributeTests.cs @@ -1,53 +1,54 @@ using Bit.Core.Utilities; using Xunit; -namespace Bit.Core.Test.Utilities; - -public class StrictEmailAddressListAttributeTests +namespace Bit.Core.Test.Utilities { - public static List EmailList => new() + public class StrictEmailAddressListAttributeTests { - new object[] { new List { "test@domain.com", "test@sub.domain.com", "hello@world.planet.com" }, true }, - new object[] { new List { "/hello@world.com", "hello@##world.pla net.com", "''thello@world.com" }, false }, - new object[] { new List { "/hello.com", "test@domain.com", "''thello@world.com" }, false }, - new object[] { new List { "héllö@world.com", "hello@world.planet.com", "hello@world.planet.com" }, false }, - new object[] { new List { }, false }, - new object[] { new List - { - "test1@domain.com", "test2@domain.com", "test3@domain.com", "test4@domain.com", "test5@domain.com", - "test6@domain.com", "test7@domain.com", "test8@domain.com", "test9@domain.com", "test10@domain.com", - "test11@domain.com", "test12@domain.com", "test13@domain.com", "test14@domain.com", "test15@domain.com", - "test16@domain.com", "test17@domain.com", "test18@domain.com", "test19@domain.com", "test20@domain.com", - "test21@domain.com", "test22@domain.com", "test23@domain.com", "test24@domain.com", "test25@domain.com", - }, false }, - new object[] { new List - { - "test1domaincomtest2domaincomtest3domaincomtest4domaincomtest5domaincomtest6domaincomtest7domaincomtest8domaincomtest9domaincomtest10domaincomtest1domaincomtest2domaincomtest3domaincomtest4domaincomtest5domaincomtest6domaincomtest7domaincomtest8domaincomtest9domaincomtest10domaincom@test.com", - "test@domain.com" - }, false } // > 256 character email + public static List EmailList => new() + { + new object[] { new List { "test@domain.com", "test@sub.domain.com", "hello@world.planet.com" }, true }, + new object[] { new List { "/hello@world.com", "hello@##world.pla net.com", "''thello@world.com" }, false }, + new object[] { new List { "/hello.com", "test@domain.com", "''thello@world.com" }, false }, + new object[] { new List { "héllö@world.com", "hello@world.planet.com", "hello@world.planet.com" }, false }, + new object[] { new List { }, false }, + new object[] { new List + { + "test1@domain.com", "test2@domain.com", "test3@domain.com", "test4@domain.com", "test5@domain.com", + "test6@domain.com", "test7@domain.com", "test8@domain.com", "test9@domain.com", "test10@domain.com", + "test11@domain.com", "test12@domain.com", "test13@domain.com", "test14@domain.com", "test15@domain.com", + "test16@domain.com", "test17@domain.com", "test18@domain.com", "test19@domain.com", "test20@domain.com", + "test21@domain.com", "test22@domain.com", "test23@domain.com", "test24@domain.com", "test25@domain.com", + }, false }, + new object[] { new List + { + "test1domaincomtest2domaincomtest3domaincomtest4domaincomtest5domaincomtest6domaincomtest7domaincomtest8domaincomtest9domaincomtest10domaincomtest1domaincomtest2domaincomtest3domaincomtest4domaincomtest5domaincomtest6domaincomtest7domaincomtest8domaincomtest9domaincomtest10domaincom@test.com", + "test@domain.com" + }, false } // > 256 character email - }; + }; - [Theory] - [MemberData(nameof(EmailList))] - public void IsListValid_ReturnsTrue_WhenValid(List emailList, bool valid) - { - var sut = new StrictEmailAddressListAttribute(); + [Theory] + [MemberData(nameof(EmailList))] + public void IsListValid_ReturnsTrue_WhenValid(List emailList, bool valid) + { + var sut = new StrictEmailAddressListAttribute(); - var actual = sut.IsValid(emailList); + var actual = sut.IsValid(emailList); - Assert.Equal(actual, valid); - } + Assert.Equal(actual, valid); + } - [Theory] - [InlineData("single@email.com", false)] - [InlineData(null, false)] - public void IsValid_ReturnsTrue_WhenValid(string email, bool valid) - { - var sut = new StrictEmailAddressListAttribute(); + [Theory] + [InlineData("single@email.com", false)] + [InlineData(null, false)] + public void IsValid_ReturnsTrue_WhenValid(string email, bool valid) + { + var sut = new StrictEmailAddressListAttribute(); - var actual = sut.IsValid(email); + var actual = sut.IsValid(email); - Assert.Equal(actual, valid); + Assert.Equal(actual, valid); + } } } diff --git a/test/Icons.Test/Resources/VerifyResources.cs b/test/Icons.Test/Resources/VerifyResources.cs index 208bd5077b..ad5d8d681a 100644 --- a/test/Icons.Test/Resources/VerifyResources.cs +++ b/test/Icons.Test/Resources/VerifyResources.cs @@ -1,19 +1,20 @@ using Xunit; -namespace Bit.Icons.Test.Resources; - -public class VerifyResources +namespace Bit.Icons.Test.Resources { - [Theory] - [InlineData("Bit.Icons.Resources.public_suffix_list.dat")] - public void Resources_FoundAndReadable(string resourceName) + public class VerifyResources { - var assembly = typeof(Program).Assembly; - - using (var resource = assembly.GetManifestResourceStream(resourceName)) + [Theory] + [InlineData("Bit.Icons.Resources.public_suffix_list.dat")] + public void Resources_FoundAndReadable(string resourceName) { - Assert.NotNull(resource); - Assert.True(resource.CanRead); + var assembly = typeof(Program).Assembly; + + using (var resource = assembly.GetManifestResourceStream(resourceName)) + { + Assert.NotNull(resource); + Assert.True(resource.CanRead); + } } } } diff --git a/test/Icons.Test/Services/IconFetchingServiceTests.cs b/test/Icons.Test/Services/IconFetchingServiceTests.cs index 59f25af244..ed317fb62c 100644 --- a/test/Icons.Test/Services/IconFetchingServiceTests.cs +++ b/test/Icons.Test/Services/IconFetchingServiceTests.cs @@ -3,48 +3,49 @@ using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; using Xunit; -namespace Bit.Icons.Test.Services; - -public class IconFetchingServiceTests +namespace Bit.Icons.Test.Services { - [Theory] - [InlineData("www.google.com")] // https site - [InlineData("neverssl.com")] // http site - [InlineData("ameritrade.com")] - [InlineData("icloud.com")] - [InlineData("bofa.com", Skip = "Broken in pipeline for .NET 6. Tracking link: https://bitwarden.atlassian.net/browse/PS-982")] - public async Task GetIconAsync_Success(string domain) + public class IconFetchingServiceTests { - var sut = new IconFetchingService(GetLogger()); - var result = await sut.GetIconAsync(domain); - - Assert.NotNull(result); - Assert.NotNull(result.Icon); - } - - [Theory] - [InlineData("1.1.1.1")] - [InlineData("")] - [InlineData("localhost")] - public async Task GetIconAsync_ReturnsNull(string domain) - { - var sut = new IconFetchingService(GetLogger()); - var result = await sut.GetIconAsync(domain); - - Assert.Null(result); - } - - private static ILogger GetLogger() - { - var services = new ServiceCollection(); - services.AddLogging(b => + [Theory] + [InlineData("www.google.com")] // https site + [InlineData("neverssl.com")] // http site + [InlineData("ameritrade.com")] + [InlineData("icloud.com")] + [InlineData("bofa.com", Skip = "Broken in pipeline for .NET 6. Tracking link: https://bitwarden.atlassian.net/browse/PS-982")] + public async Task GetIconAsync_Success(string domain) { - b.ClearProviders(); - b.AddDebug(); - }); + var sut = new IconFetchingService(GetLogger()); + var result = await sut.GetIconAsync(domain); - var provider = services.BuildServiceProvider(); + Assert.NotNull(result); + Assert.NotNull(result.Icon); + } - return provider.GetRequiredService>(); + [Theory] + [InlineData("1.1.1.1")] + [InlineData("")] + [InlineData("localhost")] + public async Task GetIconAsync_ReturnsNull(string domain) + { + var sut = new IconFetchingService(GetLogger()); + var result = await sut.GetIconAsync(domain); + + Assert.Null(result); + } + + private static ILogger GetLogger() + { + var services = new ServiceCollection(); + services.AddLogging(b => + { + b.ClearProviders(); + b.AddDebug(); + }); + + var provider = services.BuildServiceProvider(); + + return provider.GetRequiredService>(); + } } } diff --git a/test/Identity.IntegrationTest/Controllers/AccountsControllerTests.cs b/test/Identity.IntegrationTest/Controllers/AccountsControllerTests.cs index 3d03d39a92..31cab0e3cf 100644 --- a/test/Identity.IntegrationTest/Controllers/AccountsControllerTests.cs +++ b/test/Identity.IntegrationTest/Controllers/AccountsControllerTests.cs @@ -3,32 +3,33 @@ using Bit.IntegrationTestCommon.Factories; using Microsoft.EntityFrameworkCore; using Xunit; -namespace Bit.Identity.IntegrationTest.Controllers; - -public class AccountsControllerTests : IClassFixture +namespace Bit.Identity.IntegrationTest.Controllers { - private readonly IdentityApplicationFactory _factory; - - public AccountsControllerTests(IdentityApplicationFactory factory) + public class AccountsControllerTests : IClassFixture { - _factory = factory; - } + private readonly IdentityApplicationFactory _factory; - [Fact] - public async Task PostRegister_Success() - { - var context = await _factory.RegisterAsync(new RegisterRequestModel + public AccountsControllerTests(IdentityApplicationFactory factory) { - Email = "test+register@email.com", - MasterPasswordHash = "master_password_hash" - }); + _factory = factory; + } - Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode); + [Fact] + public async Task PostRegister_Success() + { + var context = await _factory.RegisterAsync(new RegisterRequestModel + { + Email = "test+register@email.com", + MasterPasswordHash = "master_password_hash" + }); - var database = _factory.GetDatabaseContext(); - var user = await database.Users - .SingleAsync(u => u.Email == "test+register@email.com"); + Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode); - Assert.NotNull(user); + var database = _factory.GetDatabaseContext(); + var user = await database.Users + .SingleAsync(u => u.Email == "test+register@email.com"); + + Assert.NotNull(user); + } } } diff --git a/test/Identity.IntegrationTest/Endpoints/IdentityServerTests.cs b/test/Identity.IntegrationTest/Endpoints/IdentityServerTests.cs index 14600ba26e..1d7d0dd8df 100644 --- a/test/Identity.IntegrationTest/Endpoints/IdentityServerTests.cs +++ b/test/Identity.IntegrationTest/Endpoints/IdentityServerTests.cs @@ -9,425 +9,51 @@ using Bit.Test.Common.Helpers; using Microsoft.EntityFrameworkCore; using Xunit; -namespace Bit.Identity.IntegrationTest.Endpoints; - -public class IdentityServerTests : IClassFixture +namespace Bit.Identity.IntegrationTest.Endpoints { - private const int SecondsInMinute = 60; - private const int MinutesInHour = 60; - private const int SecondsInHour = SecondsInMinute * MinutesInHour; - private readonly IdentityApplicationFactory _factory; - - public IdentityServerTests(IdentityApplicationFactory factory) + public class IdentityServerTests : IClassFixture { - _factory = factory; - } + private const int SecondsInMinute = 60; + private const int MinutesInHour = 60; + private const int SecondsInHour = SecondsInMinute * MinutesInHour; + private readonly IdentityApplicationFactory _factory; - [Fact] - public async Task WellKnownEndpoint_Success() - { - var context = await _factory.Server.GetAsync("/.well-known/openid-configuration"); - - using var body = await AssertHelper.AssertResponseTypeIs(context); - var endpointRoot = body.RootElement; - - // WARNING: Edits to this file should NOT just be made to "get the test to work" they should be made when intentional - // changes were made to this endpoint and proper testing will take place to ensure clients are backwards compatible - // or loss of functionality is properly noted. - await using var fs = File.OpenRead("openid-configuration.json"); - using var knownConfiguration = await JsonSerializer.DeserializeAsync(fs); - var knownConfigurationRoot = knownConfiguration.RootElement; - - AssertHelper.AssertEqualJson(endpointRoot, knownConfigurationRoot); - } - - [Fact] - public async Task TokenEndpoint_GrantTypePassword_Success() - { - var deviceId = "92b9d953-b9b6-4eaf-9d3e-11d57144dfeb"; - var username = "test+tokenpassword@email.com"; - - await _factory.RegisterAsync(new RegisterRequestModel + public IdentityServerTests(IdentityApplicationFactory factory) { - Email = username, - MasterPasswordHash = "master_password_hash" - }); - - var context = await _factory.Server.PostAsync("/connect/token", new FormUrlEncodedContent(new Dictionary - { - { "scope", "api offline_access" }, - { "client_id", "web" }, - { "deviceType", DeviceTypeAsString(DeviceType.FirefoxBrowser) }, - { "deviceIdentifier", deviceId }, - { "deviceName", "firefox" }, - { "grant_type", "password" }, - { "username", username }, - { "password", "master_password_hash" }, - }), context => context.SetAuthEmail(username)); - - using var body = await AssertDefaultTokenBodyAsync(context); - var root = body.RootElement; - AssertRefreshTokenExists(root); - AssertHelper.AssertJsonProperty(root, "ForcePasswordReset", JsonValueKind.False); - AssertHelper.AssertJsonProperty(root, "ResetMasterPassword", JsonValueKind.False); - var kdf = AssertHelper.AssertJsonProperty(root, "Kdf", JsonValueKind.Number).GetInt32(); - Assert.Equal(0, kdf); - var kdfIterations = AssertHelper.AssertJsonProperty(root, "KdfIterations", JsonValueKind.Number).GetInt32(); - Assert.Equal(5000, kdfIterations); - } - - [Fact] - public async Task TokenEndpoint_GrantTypePassword_NoAuthEmailHeader_Fails() - { - var deviceId = "92b9d953-b9b6-4eaf-9d3e-11d57144dfeb"; - var username = "test+noauthemailheader@email.com"; - - await _factory.RegisterAsync(new RegisterRequestModel - { - Email = username, - MasterPasswordHash = "master_password_hash", - }); - - var context = await _factory.Server.PostAsync("/connect/token", new FormUrlEncodedContent(new Dictionary - { - { "scope", "api offline_access" }, - { "client_id", "web" }, - { "deviceType", DeviceTypeAsString(DeviceType.FirefoxBrowser) }, - { "deviceIdentifier", deviceId }, - { "deviceName", "firefox" }, - { "grant_type", "password" }, - { "username", username }, - { "password", "master_password_hash" }, - })); - - Assert.Equal(StatusCodes.Status400BadRequest, context.Response.StatusCode); - - var body = await AssertHelper.AssertResponseTypeIs(context); - var root = body.RootElement; - - var error = AssertHelper.AssertJsonProperty(root, "error", JsonValueKind.String).GetString(); - Assert.Equal("invalid_grant", error); - AssertHelper.AssertJsonProperty(root, "error_description", JsonValueKind.String); - } - - [Fact] - public async Task TokenEndpoint_GrantTypePassword_InvalidBase64AuthEmailHeader_Fails() - { - var deviceId = "92b9d953-b9b6-4eaf-9d3e-11d57144dfeb"; - var username = "test+badauthheader@email.com"; - - await _factory.RegisterAsync(new RegisterRequestModel - { - Email = username, - MasterPasswordHash = "master_password_hash", - }); - - var context = await _factory.Server.PostAsync("/connect/token", new FormUrlEncodedContent(new Dictionary - { - { "scope", "api offline_access" }, - { "client_id", "web" }, - { "deviceType", DeviceTypeAsString(DeviceType.FirefoxBrowser) }, - { "deviceIdentifier", deviceId }, - { "deviceName", "firefox" }, - { "grant_type", "password" }, - { "username", username }, - { "password", "master_password_hash" }, - }), context => context.Request.Headers.Add("Auth-Email", "bad_value")); - - Assert.Equal(StatusCodes.Status400BadRequest, context.Response.StatusCode); - - var body = await AssertHelper.AssertResponseTypeIs(context); - var root = body.RootElement; - - var error = AssertHelper.AssertJsonProperty(root, "error", JsonValueKind.String).GetString(); - Assert.Equal("invalid_grant", error); - AssertHelper.AssertJsonProperty(root, "error_description", JsonValueKind.String); - } - - [Fact] - public async Task TokenEndpoint_GrantTypePassword_WrongAuthEmailHeader_Fails() - { - var deviceId = "92b9d953-b9b6-4eaf-9d3e-11d57144dfeb"; - var username = "test+badauthheader@email.com"; - - await _factory.RegisterAsync(new RegisterRequestModel - { - Email = username, - MasterPasswordHash = "master_password_hash", - }); - - var context = await _factory.Server.PostAsync("/connect/token", new FormUrlEncodedContent(new Dictionary - { - { "scope", "api offline_access" }, - { "client_id", "web" }, - { "deviceType", DeviceTypeAsString(DeviceType.FirefoxBrowser) }, - { "deviceIdentifier", deviceId }, - { "deviceName", "firefox" }, - { "grant_type", "password" }, - { "username", username }, - { "password", "master_password_hash" }, - }), context => context.SetAuthEmail("bad_value")); - - Assert.Equal(StatusCodes.Status400BadRequest, context.Response.StatusCode); - - var body = await AssertHelper.AssertResponseTypeIs(context); - var root = body.RootElement; - - var error = AssertHelper.AssertJsonProperty(root, "error", JsonValueKind.String).GetString(); - Assert.Equal("invalid_grant", error); - AssertHelper.AssertJsonProperty(root, "error_description", JsonValueKind.String); - } - - [Fact] - public async Task TokenEndpoint_GrantTypeRefreshToken_Success() - { - var deviceId = "5a7b19df-0c9d-46bf-a104-8034b5a17182"; - var username = "test+tokenrefresh@email.com"; - - await _factory.RegisterAsync(new RegisterRequestModel - { - Email = username, - MasterPasswordHash = "master_password_hash", - }); - - var (_, refreshToken) = await _factory.TokenFromPasswordAsync(username, "master_password_hash", deviceId); - - var context = await _factory.Server.PostAsync("/connect/token", new FormUrlEncodedContent(new Dictionary - { - { "grant_type", "refresh_token" }, - { "client_id", "web" }, - { "refresh_token", refreshToken }, - })); - - using var body = await AssertDefaultTokenBodyAsync(context); - AssertRefreshTokenExists(body.RootElement); - } - - [Fact] - public async Task TokenEndpoint_GrantTypeClientCredentials_Success() - { - var username = "test+tokenclientcredentials@email.com"; - var deviceId = "8f14a393-edfe-40ba-8c67-a856cb89c509"; - - await _factory.RegisterAsync(new RegisterRequestModel - { - Email = username, - MasterPasswordHash = "master_password_hash", - }); - - var database = _factory.GetDatabaseContext(); - var user = await database.Users - .FirstAsync(u => u.Email == username); - - var context = await _factory.Server.PostAsync("/connect/token", new FormUrlEncodedContent(new Dictionary - { - { "grant_type", "client_credentials" }, - { "client_id", $"user.{user.Id}" }, - { "client_secret", user.ApiKey }, - { "scope", "api" }, - { "DeviceIdentifier", deviceId }, - { "DeviceType", DeviceTypeAsString(DeviceType.FirefoxBrowser) }, - { "DeviceName", "firefox" }, - })); - - await AssertDefaultTokenBodyAsync(context, "api"); - } - - [Theory, BitAutoData] - public async Task TokenEndpoint_GrantTypeClientCredentials_AsOrganization_Success(Organization organization, OrganizationApiKey organizationApiKey) - { - var orgRepo = _factory.Services.GetRequiredService(); - organization.Enabled = true; - organization.UseApi = true; - organization = await orgRepo.CreateAsync(organization); - organizationApiKey.OrganizationId = organization.Id; - organizationApiKey.Type = OrganizationApiKeyType.Default; - - var orgApiKeyRepo = _factory.Services.GetRequiredService(); - await orgApiKeyRepo.CreateAsync(organizationApiKey); - - var context = await _factory.Server.PostAsync("/connect/token", new FormUrlEncodedContent(new Dictionary - { - { "grant_type", "client_credentials" }, - { "client_id", $"organization.{organization.Id}" }, - { "client_secret", organizationApiKey.ApiKey }, - { "scope", "api.organization" }, - })); - - Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode); - - await AssertDefaultTokenBodyAsync(context, "api.organization"); - } - - [Fact] - public async Task TokenEndpoint_GrantTypeClientCredentials_AsOrganization_BadOrgId_Fails() - { - var context = await _factory.Server.PostAsync("/connect/token", new FormUrlEncodedContent(new Dictionary - { - { "grant_type", "client_credentials" }, - { "client_id", "organization.bad_guid_zz&" }, - { "client_secret", "something" }, - { "scope", "api.organization" }, - })); - - Assert.Equal(StatusCodes.Status400BadRequest, context.Response.StatusCode); - - var errorBody = await AssertHelper.AssertResponseTypeIs(context); - var error = AssertHelper.AssertJsonProperty(errorBody.RootElement, "error", JsonValueKind.String).GetString(); - Assert.Equal("invalid_client", error); - } - - /// - /// This test currently does not test any code that is not covered by other tests but - /// it shows that we probably have some dead code in - /// for installation, organization, and user they split on a '.' but have already checked that at least one - /// '.' exists in the client_id by checking it with - /// I believe that idParts.Length > 1 will ALWAYS return true - /// - [Fact] - public async Task TokenEndpoint_GrantTypeClientCredentials_AsOrganization_NoIdPart_Fails() - { - var context = await _factory.Server.PostAsync("/connect/token", new FormUrlEncodedContent(new Dictionary - { - { "grant_type", "client_credentials" }, - { "client_id", "organization." }, - { "client_secret", "something" }, - { "scope", "api.organization" }, - })); - - Assert.Equal(StatusCodes.Status400BadRequest, context.Response.StatusCode); - - var errorBody = await AssertHelper.AssertResponseTypeIs(context); - var error = AssertHelper.AssertJsonProperty(errorBody.RootElement, "error", JsonValueKind.String).GetString(); - Assert.Equal("invalid_client", error); - } - - [Fact] - public async Task TokenEndpoint_GrantTypeClientCredentials_AsOrganization_OrgDoesNotExist_Fails() - { - var context = await _factory.Server.PostAsync("/connect/token", new FormUrlEncodedContent(new Dictionary - { - { "grant_type", "client_credentials" }, - { "client_id", $"organization.{Guid.NewGuid()}" }, - { "client_secret", "something" }, - { "scope", "api.organization" }, - })); - - Assert.Equal(StatusCodes.Status400BadRequest, context.Response.StatusCode); - - var errorBody = await AssertHelper.AssertResponseTypeIs(context); - var error = AssertHelper.AssertJsonProperty(errorBody.RootElement, "error", JsonValueKind.String).GetString(); - Assert.Equal("invalid_client", error); - } - - [Theory, BitAutoData] - public async Task TokenEndpoint_GrantTypeClientCredentials_AsInstallation_InstallationExists_Succeeds(Installation installation) - { - var installationRepo = _factory.Services.GetRequiredService(); - installation = await installationRepo.CreateAsync(installation); - - var context = await _factory.Server.PostAsync("/connect/token", new FormUrlEncodedContent(new Dictionary - { - { "grant_type", "client_credentials" }, - { "client_id", $"installation.{installation.Id}" }, - { "client_secret", installation.Key }, - { "scope", "api.push" }, - })); - - Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode); - await AssertDefaultTokenBodyAsync(context, "api.push", 24 * SecondsInHour); - } - - [Fact] - public async Task TokenEndpoint_GrantTypeClientCredentials_AsInstallation_InstallationDoesNotExist_Fails() - { - var context = await _factory.Server.PostAsync("/connect/token", new FormUrlEncodedContent(new Dictionary - { - { "grant_type", "client_credentials" }, - { "client_id", $"installation.{Guid.NewGuid()}" }, - { "client_secret", "something" }, - { "scope", "api.push" }, - })); - - Assert.Equal(StatusCodes.Status400BadRequest, context.Response.StatusCode); - - var errorBody = await AssertHelper.AssertResponseTypeIs(context); - var error = AssertHelper.AssertJsonProperty(errorBody.RootElement, "error", JsonValueKind.String).GetString(); - Assert.Equal("invalid_client", error); - } - - [Fact] - public async Task TokenEndpoint_GrantTypeClientCredentials_AsInstallation_BadInsallationId_Fails() - { - var context = await _factory.Server.PostAsync("/connect/token", new FormUrlEncodedContent(new Dictionary - { - { "grant_type", "client_credentials" }, - { "client_id", "organization.bad_guid_zz&" }, - { "client_secret", "something" }, - { "scope", "api.organization" }, - })); - - Assert.Equal(StatusCodes.Status400BadRequest, context.Response.StatusCode); - - var errorBody = await AssertHelper.AssertResponseTypeIs(context); - var error = AssertHelper.AssertJsonProperty(errorBody.RootElement, "error", JsonValueKind.String).GetString(); - Assert.Equal("invalid_client", error); - } - - /// - [Fact] - public async Task TokenEndpoint_GrantTypeClientCredentials_AsInstallation_NoIdPart_Fails() - { - var context = await _factory.Server.PostAsync("/connect/token", new FormUrlEncodedContent(new Dictionary - { - { "grant_type", "client_credentials" }, - { "client_id", "installation." }, - { "client_secret", "something" }, - { "scope", "api.push" }, - })); - - Assert.Equal(StatusCodes.Status400BadRequest, context.Response.StatusCode); - - var errorBody = await AssertHelper.AssertResponseTypeIs(context); - var error = AssertHelper.AssertJsonProperty(errorBody.RootElement, "error", JsonValueKind.String).GetString(); - Assert.Equal("invalid_client", error); - } - - [Fact] - public async Task TokenEndpoint_ToQuickInOneSecond_BlockRequest() - { - const int AmountInOneSecondAllowed = 5; - - // The rule we are testing is 10 requests in 1 second - var username = "test+ratelimiting@email.com"; - var deviceId = "8f14a393-edfe-40ba-8c67-a856cb89c509"; - - await _factory.RegisterAsync(new RegisterRequestModel - { - Email = username, - MasterPasswordHash = "master_password_hash", - }); - - var database = _factory.GetDatabaseContext(); - var user = await database.Users - .FirstAsync(u => u.Email == username); - - var tasks = new Task[AmountInOneSecondAllowed + 1]; - - for (var i = 0; i < AmountInOneSecondAllowed + 1; i++) - { - // Queue all the amount of calls allowed plus 1 - tasks[i] = MakeRequest(); + _factory = factory; } - var responses = (await Task.WhenAll(tasks)).ToList(); - - Assert.Equal(5, responses.Count(c => c.Response.StatusCode == StatusCodes.Status200OK)); - Assert.Equal(1, responses.Count(c => c.Response.StatusCode == StatusCodes.Status429TooManyRequests)); - - Task MakeRequest() + [Fact] + public async Task WellKnownEndpoint_Success() { - return _factory.Server.PostAsync("/connect/token", new FormUrlEncodedContent(new Dictionary + var context = await _factory.Server.GetAsync("/.well-known/openid-configuration"); + + using var body = await AssertHelper.AssertResponseTypeIs(context); + var endpointRoot = body.RootElement; + + // WARNING: Edits to this file should NOT just be made to "get the test to work" they should be made when intentional + // changes were made to this endpoint and proper testing will take place to ensure clients are backwards compatible + // or loss of functionality is properly noted. + await using var fs = File.OpenRead("openid-configuration.json"); + using var knownConfiguration = await JsonSerializer.DeserializeAsync(fs); + var knownConfigurationRoot = knownConfiguration.RootElement; + + AssertHelper.AssertEqualJson(endpointRoot, knownConfigurationRoot); + } + + [Fact] + public async Task TokenEndpoint_GrantTypePassword_Success() + { + var deviceId = "92b9d953-b9b6-4eaf-9d3e-11d57144dfeb"; + var username = "test+tokenpassword@email.com"; + + await _factory.RegisterAsync(new RegisterRequestModel + { + Email = username, + MasterPasswordHash = "master_password_hash" + }); + + var context = await _factory.Server.PostAsync("/connect/token", new FormUrlEncodedContent(new Dictionary { { "scope", "api offline_access" }, { "client_id", "web" }, @@ -437,59 +63,434 @@ public class IdentityServerTests : IClassFixture { "grant_type", "password" }, { "username", username }, { "password", "master_password_hash" }, - }), context => context.SetAuthEmail(username).SetIp("1.1.1.2")); + }), context => context.SetAuthEmail(username)); + + using var body = await AssertDefaultTokenBodyAsync(context); + var root = body.RootElement; + AssertRefreshTokenExists(root); + AssertHelper.AssertJsonProperty(root, "ForcePasswordReset", JsonValueKind.False); + AssertHelper.AssertJsonProperty(root, "ResetMasterPassword", JsonValueKind.False); + var kdf = AssertHelper.AssertJsonProperty(root, "Kdf", JsonValueKind.Number).GetInt32(); + Assert.Equal(0, kdf); + var kdfIterations = AssertHelper.AssertJsonProperty(root, "KdfIterations", JsonValueKind.Number).GetInt32(); + Assert.Equal(5000, kdfIterations); + } + + [Fact] + public async Task TokenEndpoint_GrantTypePassword_NoAuthEmailHeader_Fails() + { + var deviceId = "92b9d953-b9b6-4eaf-9d3e-11d57144dfeb"; + var username = "test+noauthemailheader@email.com"; + + await _factory.RegisterAsync(new RegisterRequestModel + { + Email = username, + MasterPasswordHash = "master_password_hash", + }); + + var context = await _factory.Server.PostAsync("/connect/token", new FormUrlEncodedContent(new Dictionary + { + { "scope", "api offline_access" }, + { "client_id", "web" }, + { "deviceType", DeviceTypeAsString(DeviceType.FirefoxBrowser) }, + { "deviceIdentifier", deviceId }, + { "deviceName", "firefox" }, + { "grant_type", "password" }, + { "username", username }, + { "password", "master_password_hash" }, + })); + + Assert.Equal(StatusCodes.Status400BadRequest, context.Response.StatusCode); + + var body = await AssertHelper.AssertResponseTypeIs(context); + var root = body.RootElement; + + var error = AssertHelper.AssertJsonProperty(root, "error", JsonValueKind.String).GetString(); + Assert.Equal("invalid_grant", error); + AssertHelper.AssertJsonProperty(root, "error_description", JsonValueKind.String); + } + + [Fact] + public async Task TokenEndpoint_GrantTypePassword_InvalidBase64AuthEmailHeader_Fails() + { + var deviceId = "92b9d953-b9b6-4eaf-9d3e-11d57144dfeb"; + var username = "test+badauthheader@email.com"; + + await _factory.RegisterAsync(new RegisterRequestModel + { + Email = username, + MasterPasswordHash = "master_password_hash", + }); + + var context = await _factory.Server.PostAsync("/connect/token", new FormUrlEncodedContent(new Dictionary + { + { "scope", "api offline_access" }, + { "client_id", "web" }, + { "deviceType", DeviceTypeAsString(DeviceType.FirefoxBrowser) }, + { "deviceIdentifier", deviceId }, + { "deviceName", "firefox" }, + { "grant_type", "password" }, + { "username", username }, + { "password", "master_password_hash" }, + }), context => context.Request.Headers.Add("Auth-Email", "bad_value")); + + Assert.Equal(StatusCodes.Status400BadRequest, context.Response.StatusCode); + + var body = await AssertHelper.AssertResponseTypeIs(context); + var root = body.RootElement; + + var error = AssertHelper.AssertJsonProperty(root, "error", JsonValueKind.String).GetString(); + Assert.Equal("invalid_grant", error); + AssertHelper.AssertJsonProperty(root, "error_description", JsonValueKind.String); + } + + [Fact] + public async Task TokenEndpoint_GrantTypePassword_WrongAuthEmailHeader_Fails() + { + var deviceId = "92b9d953-b9b6-4eaf-9d3e-11d57144dfeb"; + var username = "test+badauthheader@email.com"; + + await _factory.RegisterAsync(new RegisterRequestModel + { + Email = username, + MasterPasswordHash = "master_password_hash", + }); + + var context = await _factory.Server.PostAsync("/connect/token", new FormUrlEncodedContent(new Dictionary + { + { "scope", "api offline_access" }, + { "client_id", "web" }, + { "deviceType", DeviceTypeAsString(DeviceType.FirefoxBrowser) }, + { "deviceIdentifier", deviceId }, + { "deviceName", "firefox" }, + { "grant_type", "password" }, + { "username", username }, + { "password", "master_password_hash" }, + }), context => context.SetAuthEmail("bad_value")); + + Assert.Equal(StatusCodes.Status400BadRequest, context.Response.StatusCode); + + var body = await AssertHelper.AssertResponseTypeIs(context); + var root = body.RootElement; + + var error = AssertHelper.AssertJsonProperty(root, "error", JsonValueKind.String).GetString(); + Assert.Equal("invalid_grant", error); + AssertHelper.AssertJsonProperty(root, "error_description", JsonValueKind.String); + } + + [Fact] + public async Task TokenEndpoint_GrantTypeRefreshToken_Success() + { + var deviceId = "5a7b19df-0c9d-46bf-a104-8034b5a17182"; + var username = "test+tokenrefresh@email.com"; + + await _factory.RegisterAsync(new RegisterRequestModel + { + Email = username, + MasterPasswordHash = "master_password_hash", + }); + + var (_, refreshToken) = await _factory.TokenFromPasswordAsync(username, "master_password_hash", deviceId); + + var context = await _factory.Server.PostAsync("/connect/token", new FormUrlEncodedContent(new Dictionary + { + { "grant_type", "refresh_token" }, + { "client_id", "web" }, + { "refresh_token", refreshToken }, + })); + + using var body = await AssertDefaultTokenBodyAsync(context); + AssertRefreshTokenExists(body.RootElement); + } + + [Fact] + public async Task TokenEndpoint_GrantTypeClientCredentials_Success() + { + var username = "test+tokenclientcredentials@email.com"; + var deviceId = "8f14a393-edfe-40ba-8c67-a856cb89c509"; + + await _factory.RegisterAsync(new RegisterRequestModel + { + Email = username, + MasterPasswordHash = "master_password_hash", + }); + + var database = _factory.GetDatabaseContext(); + var user = await database.Users + .FirstAsync(u => u.Email == username); + + var context = await _factory.Server.PostAsync("/connect/token", new FormUrlEncodedContent(new Dictionary + { + { "grant_type", "client_credentials" }, + { "client_id", $"user.{user.Id}" }, + { "client_secret", user.ApiKey }, + { "scope", "api" }, + { "DeviceIdentifier", deviceId }, + { "DeviceType", DeviceTypeAsString(DeviceType.FirefoxBrowser) }, + { "DeviceName", "firefox" }, + })); + + await AssertDefaultTokenBodyAsync(context, "api"); + } + + [Theory, BitAutoData] + public async Task TokenEndpoint_GrantTypeClientCredentials_AsOrganization_Success(Organization organization, OrganizationApiKey organizationApiKey) + { + var orgRepo = _factory.Services.GetRequiredService(); + organization.Enabled = true; + organization.UseApi = true; + organization = await orgRepo.CreateAsync(organization); + organizationApiKey.OrganizationId = organization.Id; + organizationApiKey.Type = OrganizationApiKeyType.Default; + + var orgApiKeyRepo = _factory.Services.GetRequiredService(); + await orgApiKeyRepo.CreateAsync(organizationApiKey); + + var context = await _factory.Server.PostAsync("/connect/token", new FormUrlEncodedContent(new Dictionary + { + { "grant_type", "client_credentials" }, + { "client_id", $"organization.{organization.Id}" }, + { "client_secret", organizationApiKey.ApiKey }, + { "scope", "api.organization" }, + })); + + Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode); + + await AssertDefaultTokenBodyAsync(context, "api.organization"); + } + + [Fact] + public async Task TokenEndpoint_GrantTypeClientCredentials_AsOrganization_BadOrgId_Fails() + { + var context = await _factory.Server.PostAsync("/connect/token", new FormUrlEncodedContent(new Dictionary + { + { "grant_type", "client_credentials" }, + { "client_id", "organization.bad_guid_zz&" }, + { "client_secret", "something" }, + { "scope", "api.organization" }, + })); + + Assert.Equal(StatusCodes.Status400BadRequest, context.Response.StatusCode); + + var errorBody = await AssertHelper.AssertResponseTypeIs(context); + var error = AssertHelper.AssertJsonProperty(errorBody.RootElement, "error", JsonValueKind.String).GetString(); + Assert.Equal("invalid_client", error); + } + + /// + /// This test currently does not test any code that is not covered by other tests but + /// it shows that we probably have some dead code in + /// for installation, organization, and user they split on a '.' but have already checked that at least one + /// '.' exists in the client_id by checking it with + /// I believe that idParts.Length > 1 will ALWAYS return true + /// + [Fact] + public async Task TokenEndpoint_GrantTypeClientCredentials_AsOrganization_NoIdPart_Fails() + { + var context = await _factory.Server.PostAsync("/connect/token", new FormUrlEncodedContent(new Dictionary + { + { "grant_type", "client_credentials" }, + { "client_id", "organization." }, + { "client_secret", "something" }, + { "scope", "api.organization" }, + })); + + Assert.Equal(StatusCodes.Status400BadRequest, context.Response.StatusCode); + + var errorBody = await AssertHelper.AssertResponseTypeIs(context); + var error = AssertHelper.AssertJsonProperty(errorBody.RootElement, "error", JsonValueKind.String).GetString(); + Assert.Equal("invalid_client", error); + } + + [Fact] + public async Task TokenEndpoint_GrantTypeClientCredentials_AsOrganization_OrgDoesNotExist_Fails() + { + var context = await _factory.Server.PostAsync("/connect/token", new FormUrlEncodedContent(new Dictionary + { + { "grant_type", "client_credentials" }, + { "client_id", $"organization.{Guid.NewGuid()}" }, + { "client_secret", "something" }, + { "scope", "api.organization" }, + })); + + Assert.Equal(StatusCodes.Status400BadRequest, context.Response.StatusCode); + + var errorBody = await AssertHelper.AssertResponseTypeIs(context); + var error = AssertHelper.AssertJsonProperty(errorBody.RootElement, "error", JsonValueKind.String).GetString(); + Assert.Equal("invalid_client", error); + } + + [Theory, BitAutoData] + public async Task TokenEndpoint_GrantTypeClientCredentials_AsInstallation_InstallationExists_Succeeds(Installation installation) + { + var installationRepo = _factory.Services.GetRequiredService(); + installation = await installationRepo.CreateAsync(installation); + + var context = await _factory.Server.PostAsync("/connect/token", new FormUrlEncodedContent(new Dictionary + { + { "grant_type", "client_credentials" }, + { "client_id", $"installation.{installation.Id}" }, + { "client_secret", installation.Key }, + { "scope", "api.push" }, + })); + + Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode); + await AssertDefaultTokenBodyAsync(context, "api.push", 24 * SecondsInHour); + } + + [Fact] + public async Task TokenEndpoint_GrantTypeClientCredentials_AsInstallation_InstallationDoesNotExist_Fails() + { + var context = await _factory.Server.PostAsync("/connect/token", new FormUrlEncodedContent(new Dictionary + { + { "grant_type", "client_credentials" }, + { "client_id", $"installation.{Guid.NewGuid()}" }, + { "client_secret", "something" }, + { "scope", "api.push" }, + })); + + Assert.Equal(StatusCodes.Status400BadRequest, context.Response.StatusCode); + + var errorBody = await AssertHelper.AssertResponseTypeIs(context); + var error = AssertHelper.AssertJsonProperty(errorBody.RootElement, "error", JsonValueKind.String).GetString(); + Assert.Equal("invalid_client", error); + } + + [Fact] + public async Task TokenEndpoint_GrantTypeClientCredentials_AsInstallation_BadInsallationId_Fails() + { + var context = await _factory.Server.PostAsync("/connect/token", new FormUrlEncodedContent(new Dictionary + { + { "grant_type", "client_credentials" }, + { "client_id", "organization.bad_guid_zz&" }, + { "client_secret", "something" }, + { "scope", "api.organization" }, + })); + + Assert.Equal(StatusCodes.Status400BadRequest, context.Response.StatusCode); + + var errorBody = await AssertHelper.AssertResponseTypeIs(context); + var error = AssertHelper.AssertJsonProperty(errorBody.RootElement, "error", JsonValueKind.String).GetString(); + Assert.Equal("invalid_client", error); + } + + /// + [Fact] + public async Task TokenEndpoint_GrantTypeClientCredentials_AsInstallation_NoIdPart_Fails() + { + var context = await _factory.Server.PostAsync("/connect/token", new FormUrlEncodedContent(new Dictionary + { + { "grant_type", "client_credentials" }, + { "client_id", "installation." }, + { "client_secret", "something" }, + { "scope", "api.push" }, + })); + + Assert.Equal(StatusCodes.Status400BadRequest, context.Response.StatusCode); + + var errorBody = await AssertHelper.AssertResponseTypeIs(context); + var error = AssertHelper.AssertJsonProperty(errorBody.RootElement, "error", JsonValueKind.String).GetString(); + Assert.Equal("invalid_client", error); + } + + [Fact] + public async Task TokenEndpoint_ToQuickInOneSecond_BlockRequest() + { + const int AmountInOneSecondAllowed = 5; + + // The rule we are testing is 10 requests in 1 second + var username = "test+ratelimiting@email.com"; + var deviceId = "8f14a393-edfe-40ba-8c67-a856cb89c509"; + + await _factory.RegisterAsync(new RegisterRequestModel + { + Email = username, + MasterPasswordHash = "master_password_hash", + }); + + var database = _factory.GetDatabaseContext(); + var user = await database.Users + .FirstAsync(u => u.Email == username); + + var tasks = new Task[AmountInOneSecondAllowed + 1]; + + for (var i = 0; i < AmountInOneSecondAllowed + 1; i++) + { + // Queue all the amount of calls allowed plus 1 + tasks[i] = MakeRequest(); + } + + var responses = (await Task.WhenAll(tasks)).ToList(); + + Assert.Equal(5, responses.Count(c => c.Response.StatusCode == StatusCodes.Status200OK)); + Assert.Equal(1, responses.Count(c => c.Response.StatusCode == StatusCodes.Status429TooManyRequests)); + + Task MakeRequest() + { + return _factory.Server.PostAsync("/connect/token", new FormUrlEncodedContent(new Dictionary + { + { "scope", "api offline_access" }, + { "client_id", "web" }, + { "deviceType", DeviceTypeAsString(DeviceType.FirefoxBrowser) }, + { "deviceIdentifier", deviceId }, + { "deviceName", "firefox" }, + { "grant_type", "password" }, + { "username", username }, + { "password", "master_password_hash" }, + }), context => context.SetAuthEmail(username).SetIp("1.1.1.2")); + } + } + + private static string DeviceTypeAsString(DeviceType deviceType) + { + return ((int)deviceType).ToString(); + } + + private static async Task AssertDefaultTokenBodyAsync(HttpContext httpContext, string expectedScope = "api offline_access", int expectedExpiresIn = SecondsInHour * 1) + { + var body = await AssertHelper.AssertResponseTypeIs(httpContext); + var root = body.RootElement; + + Assert.Equal(JsonValueKind.Object, root.ValueKind); + AssertAccessTokenExists(root); + AssertExpiresIn(root, expectedExpiresIn); + AssertTokenType(root); + AssertScope(root, expectedScope); + return body; + } + + private static void AssertTokenType(JsonElement tokenResponse) + { + var tokenTypeProperty = AssertHelper.AssertJsonProperty(tokenResponse, "token_type", JsonValueKind.String).GetString(); + Assert.Equal("Bearer", tokenTypeProperty); + } + + private static int AssertExpiresIn(JsonElement tokenResponse, int expectedExpiresIn = 3600) + { + var expiresIn = AssertHelper.AssertJsonProperty(tokenResponse, "expires_in", JsonValueKind.Number).GetInt32(); + Assert.Equal(expectedExpiresIn, expiresIn); + return expiresIn; + } + + private static string AssertAccessTokenExists(JsonElement tokenResponse) + { + return AssertHelper.AssertJsonProperty(tokenResponse, "access_token", JsonValueKind.String).GetString(); + } + + private static string AssertRefreshTokenExists(JsonElement tokenResponse) + { + return AssertHelper.AssertJsonProperty(tokenResponse, "refresh_token", JsonValueKind.String).GetString(); + } + + private static string AssertScopeExists(JsonElement tokenResponse) + { + return AssertHelper.AssertJsonProperty(tokenResponse, "scope", JsonValueKind.String).GetString(); + } + + private static void AssertScope(JsonElement tokenResponse, string expectedScope) + { + var actualScope = AssertScopeExists(tokenResponse); + Assert.Equal(expectedScope, actualScope); } } - - private static string DeviceTypeAsString(DeviceType deviceType) - { - return ((int)deviceType).ToString(); - } - - private static async Task AssertDefaultTokenBodyAsync(HttpContext httpContext, string expectedScope = "api offline_access", int expectedExpiresIn = SecondsInHour * 1) - { - var body = await AssertHelper.AssertResponseTypeIs(httpContext); - var root = body.RootElement; - - Assert.Equal(JsonValueKind.Object, root.ValueKind); - AssertAccessTokenExists(root); - AssertExpiresIn(root, expectedExpiresIn); - AssertTokenType(root); - AssertScope(root, expectedScope); - return body; - } - - private static void AssertTokenType(JsonElement tokenResponse) - { - var tokenTypeProperty = AssertHelper.AssertJsonProperty(tokenResponse, "token_type", JsonValueKind.String).GetString(); - Assert.Equal("Bearer", tokenTypeProperty); - } - - private static int AssertExpiresIn(JsonElement tokenResponse, int expectedExpiresIn = 3600) - { - var expiresIn = AssertHelper.AssertJsonProperty(tokenResponse, "expires_in", JsonValueKind.Number).GetInt32(); - Assert.Equal(expectedExpiresIn, expiresIn); - return expiresIn; - } - - private static string AssertAccessTokenExists(JsonElement tokenResponse) - { - return AssertHelper.AssertJsonProperty(tokenResponse, "access_token", JsonValueKind.String).GetString(); - } - - private static string AssertRefreshTokenExists(JsonElement tokenResponse) - { - return AssertHelper.AssertJsonProperty(tokenResponse, "refresh_token", JsonValueKind.String).GetString(); - } - - private static string AssertScopeExists(JsonElement tokenResponse) - { - return AssertHelper.AssertJsonProperty(tokenResponse, "scope", JsonValueKind.String).GetString(); - } - - private static void AssertScope(JsonElement tokenResponse, string expectedScope) - { - var actualScope = AssertScopeExists(tokenResponse); - Assert.Equal(expectedScope, actualScope); - } } diff --git a/test/Identity.Test/Controllers/AccountsControllerTests.cs b/test/Identity.Test/Controllers/AccountsControllerTests.cs index 54b5856547..6fa8f493cd 100644 --- a/test/Identity.Test/Controllers/AccountsControllerTests.cs +++ b/test/Identity.Test/Controllers/AccountsControllerTests.cs @@ -11,101 +11,102 @@ using Microsoft.Extensions.Logging; using NSubstitute; using Xunit; -namespace Bit.Identity.Test.Controllers; - -public class AccountsControllerTests : IDisposable +namespace Bit.Identity.Test.Controllers { - - private readonly AccountsController _sut; - private readonly ILogger _logger; - private readonly IUserRepository _userRepository; - private readonly IUserService _userService; - - public AccountsControllerTests() + public class AccountsControllerTests : IDisposable { - _logger = Substitute.For>(); - _userRepository = Substitute.For(); - _userService = Substitute.For(); - _sut = new AccountsController( - _logger, - _userRepository, - _userService - ); - } - public void Dispose() - { - _sut?.Dispose(); - } + private readonly AccountsController _sut; + private readonly ILogger _logger; + private readonly IUserRepository _userRepository; + private readonly IUserService _userService; - [Fact] - public async Task PostPrelogin_WhenUserExists_ShouldReturnUserKdfInfo() - { - var userKdfInfo = new UserKdfInformation + public AccountsControllerTests() { - Kdf = KdfType.PBKDF2_SHA256, - KdfIterations = 5000 - }; - _userRepository.GetKdfInformationByEmailAsync(Arg.Any()).Returns(Task.FromResult(userKdfInfo)); + _logger = Substitute.For>(); + _userRepository = Substitute.For(); + _userService = Substitute.For(); + _sut = new AccountsController( + _logger, + _userRepository, + _userService + ); + } - var response = await _sut.PostPrelogin(new PreloginRequestModel { Email = "user@example.com" }); - - Assert.Equal(userKdfInfo.Kdf, response.Kdf); - Assert.Equal(userKdfInfo.KdfIterations, response.KdfIterations); - } - - [Fact] - public async Task PostPrelogin_WhenUserDoesNotExist_ShouldDefaultToSha256And100000Iterations() - { - _userRepository.GetKdfInformationByEmailAsync(Arg.Any()).Returns(Task.FromResult(null!)); - - var response = await _sut.PostPrelogin(new PreloginRequestModel { Email = "user@example.com" }); - - Assert.Equal(KdfType.PBKDF2_SHA256, response.Kdf); - Assert.Equal(100000, response.KdfIterations); - } - - [Fact] - public async Task PostRegister_ShouldRegisterUser() - { - var passwordHash = "abcdef"; - var token = "123456"; - var userGuid = new Guid(); - _userService.RegisterUserAsync(Arg.Any(), passwordHash, token, userGuid) - .Returns(Task.FromResult(IdentityResult.Success)); - var request = new RegisterRequestModel + public void Dispose() { - Name = "Example User", - Email = "user@example.com", - MasterPasswordHash = passwordHash, - MasterPasswordHint = "example", - Token = token, - OrganizationUserId = userGuid - }; + _sut?.Dispose(); + } - await _sut.PostRegister(request); - - await _userService.Received(1).RegisterUserAsync(Arg.Any(), passwordHash, token, userGuid); - } - - [Fact] - public async Task PostRegister_WhenUserServiceFails_ShouldThrowBadRequestException() - { - var passwordHash = "abcdef"; - var token = "123456"; - var userGuid = new Guid(); - _userService.RegisterUserAsync(Arg.Any(), passwordHash, token, userGuid) - .Returns(Task.FromResult(IdentityResult.Failed())); - var request = new RegisterRequestModel + [Fact] + public async Task PostPrelogin_WhenUserExists_ShouldReturnUserKdfInfo() { - Name = "Example User", - Email = "user@example.com", - MasterPasswordHash = passwordHash, - MasterPasswordHint = "example", - Token = token, - OrganizationUserId = userGuid - }; + var userKdfInfo = new UserKdfInformation + { + Kdf = KdfType.PBKDF2_SHA256, + KdfIterations = 5000 + }; + _userRepository.GetKdfInformationByEmailAsync(Arg.Any()).Returns(Task.FromResult(userKdfInfo)); - await Assert.ThrowsAsync(() => _sut.PostRegister(request)); + var response = await _sut.PostPrelogin(new PreloginRequestModel { Email = "user@example.com" }); + + Assert.Equal(userKdfInfo.Kdf, response.Kdf); + Assert.Equal(userKdfInfo.KdfIterations, response.KdfIterations); + } + + [Fact] + public async Task PostPrelogin_WhenUserDoesNotExist_ShouldDefaultToSha256And100000Iterations() + { + _userRepository.GetKdfInformationByEmailAsync(Arg.Any()).Returns(Task.FromResult(null!)); + + var response = await _sut.PostPrelogin(new PreloginRequestModel { Email = "user@example.com" }); + + Assert.Equal(KdfType.PBKDF2_SHA256, response.Kdf); + Assert.Equal(100000, response.KdfIterations); + } + + [Fact] + public async Task PostRegister_ShouldRegisterUser() + { + var passwordHash = "abcdef"; + var token = "123456"; + var userGuid = new Guid(); + _userService.RegisterUserAsync(Arg.Any(), passwordHash, token, userGuid) + .Returns(Task.FromResult(IdentityResult.Success)); + var request = new RegisterRequestModel + { + Name = "Example User", + Email = "user@example.com", + MasterPasswordHash = passwordHash, + MasterPasswordHint = "example", + Token = token, + OrganizationUserId = userGuid + }; + + await _sut.PostRegister(request); + + await _userService.Received(1).RegisterUserAsync(Arg.Any(), passwordHash, token, userGuid); + } + + [Fact] + public async Task PostRegister_WhenUserServiceFails_ShouldThrowBadRequestException() + { + var passwordHash = "abcdef"; + var token = "123456"; + var userGuid = new Guid(); + _userService.RegisterUserAsync(Arg.Any(), passwordHash, token, userGuid) + .Returns(Task.FromResult(IdentityResult.Failed())); + var request = new RegisterRequestModel + { + Name = "Example User", + Email = "user@example.com", + MasterPasswordHash = passwordHash, + MasterPasswordHint = "example", + Token = token, + OrganizationUserId = userGuid + }; + + await Assert.ThrowsAsync(() => _sut.PostRegister(request)); + } } } diff --git a/test/Infrastructure.EFIntegration.Test/AutoFixture/CipherFixtures.cs b/test/Infrastructure.EFIntegration.Test/AutoFixture/CipherFixtures.cs index 13d222316b..a027dc2408 100644 --- a/test/Infrastructure.EFIntegration.Test/AutoFixture/CipherFixtures.cs +++ b/test/Infrastructure.EFIntegration.Test/AutoFixture/CipherFixtures.cs @@ -9,111 +9,112 @@ using Bit.Infrastructure.EntityFramework.Repositories; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; -namespace Bit.Infrastructure.EFIntegration.Test.AutoFixture; - -internal class CipherBuilder : ISpecimenBuilder +namespace Bit.Infrastructure.EFIntegration.Test.AutoFixture { - public bool OrganizationOwned { get; set; } - public object Create(object request, ISpecimenContext context) + internal class CipherBuilder : ISpecimenBuilder { - if (context == null) + public bool OrganizationOwned { get; set; } + public object Create(object request, ISpecimenContext context) { - throw new ArgumentNullException(nameof(context)); - } - - var type = request as Type; - if (type == null || (type != typeof(Cipher) && type != typeof(List))) - { - return new NoSpecimen(); - } - - var fixture = new Fixture(); - fixture.Customizations.Insert(0, new MaxLengthStringRelay()); - - if (!OrganizationOwned) - { - fixture.Customize(composer => composer - .Without(c => c.OrganizationId)); - } - - // Can't test valid Favorites and Folders without creating those values inide each test, - // since we won't have any UserIds until the test is running & creating data - fixture.Customize(c => c - .Without(e => e.Favorites) - .Without(e => e.Folders)); - // - var serializerOptions = new JsonSerializerOptions() - { - PropertyNamingPolicy = JsonNamingPolicy.CamelCase - }; - - if (type == typeof(Cipher)) - { - var obj = fixture.WithAutoNSubstitutions().Create(); - var cipherData = fixture.WithAutoNSubstitutions().Create(); - var cipherAttachements = fixture.WithAutoNSubstitutions().Create>(); - obj.Data = JsonSerializer.Serialize(cipherData, serializerOptions); - obj.Attachments = JsonSerializer.Serialize(cipherAttachements, serializerOptions); - - return obj; - } - if (type == typeof(List)) - { - var ciphers = fixture.WithAutoNSubstitutions().CreateMany().ToArray(); - for (var i = 0; i < ciphers.Count(); i++) + if (context == null) { - var cipherData = fixture.WithAutoNSubstitutions().Create(); - var cipherAttachements = fixture.WithAutoNSubstitutions().Create>(); - ciphers[i].Data = JsonSerializer.Serialize(cipherData, serializerOptions); - ciphers[i].Attachments = JsonSerializer.Serialize(cipherAttachements, serializerOptions); + throw new ArgumentNullException(nameof(context)); } - return ciphers; + var type = request as Type; + if (type == null || (type != typeof(Cipher) && type != typeof(List))) + { + return new NoSpecimen(); + } + + var fixture = new Fixture(); + fixture.Customizations.Insert(0, new MaxLengthStringRelay()); + + if (!OrganizationOwned) + { + fixture.Customize(composer => composer + .Without(c => c.OrganizationId)); + } + + // Can't test valid Favorites and Folders without creating those values inide each test, + // since we won't have any UserIds until the test is running & creating data + fixture.Customize(c => c + .Without(e => e.Favorites) + .Without(e => e.Folders)); + // + var serializerOptions = new JsonSerializerOptions() + { + PropertyNamingPolicy = JsonNamingPolicy.CamelCase + }; + + if (type == typeof(Cipher)) + { + var obj = fixture.WithAutoNSubstitutions().Create(); + var cipherData = fixture.WithAutoNSubstitutions().Create(); + var cipherAttachements = fixture.WithAutoNSubstitutions().Create>(); + obj.Data = JsonSerializer.Serialize(cipherData, serializerOptions); + obj.Attachments = JsonSerializer.Serialize(cipherAttachements, serializerOptions); + + return obj; + } + if (type == typeof(List)) + { + var ciphers = fixture.WithAutoNSubstitutions().CreateMany().ToArray(); + for (var i = 0; i < ciphers.Count(); i++) + { + var cipherData = fixture.WithAutoNSubstitutions().Create(); + var cipherAttachements = fixture.WithAutoNSubstitutions().Create>(); + ciphers[i].Data = JsonSerializer.Serialize(cipherData, serializerOptions); + ciphers[i].Attachments = JsonSerializer.Serialize(cipherAttachements, serializerOptions); + } + + return ciphers; + } + + return new NoSpecimen(); } - - return new NoSpecimen(); } -} -internal class EfCipher : ICustomization -{ - public bool OrganizationOwned { get; set; } - public void Customize(IFixture fixture) + internal class EfCipher : ICustomization { - fixture.Customizations.Add(new GlobalSettingsBuilder()); - fixture.Customizations.Add(new CipherBuilder() + public bool OrganizationOwned { get; set; } + public void Customize(IFixture fixture) { - OrganizationOwned = OrganizationOwned - }); - fixture.Customizations.Add(new UserBuilder()); - fixture.Customizations.Add(new OrganizationBuilder()); - fixture.Customizations.Add(new OrganizationUserBuilder()); - fixture.Customizations.Add(new EfRepositoryListBuilder()); - fixture.Customizations.Add(new EfRepositoryListBuilder()); - fixture.Customizations.Add(new EfRepositoryListBuilder()); - fixture.Customizations.Add(new EfRepositoryListBuilder()); - fixture.Customizations.Add(new EfRepositoryListBuilder()); + fixture.Customizations.Add(new GlobalSettingsBuilder()); + fixture.Customizations.Add(new CipherBuilder() + { + OrganizationOwned = OrganizationOwned + }); + fixture.Customizations.Add(new UserBuilder()); + fixture.Customizations.Add(new OrganizationBuilder()); + fixture.Customizations.Add(new OrganizationUserBuilder()); + fixture.Customizations.Add(new EfRepositoryListBuilder()); + fixture.Customizations.Add(new EfRepositoryListBuilder()); + fixture.Customizations.Add(new EfRepositoryListBuilder()); + fixture.Customizations.Add(new EfRepositoryListBuilder()); + fixture.Customizations.Add(new EfRepositoryListBuilder()); + } + } + + internal class EfUserCipherAutoDataAttribute : CustomAutoDataAttribute + { + public EfUserCipherAutoDataAttribute() : base(new SutProviderCustomization(), new EfCipher()) + { } + } + + internal class EfOrganizationCipherAutoDataAttribute : CustomAutoDataAttribute + { + public EfOrganizationCipherAutoDataAttribute() : base(new SutProviderCustomization(), new EfCipher() + { + OrganizationOwned = true, + }) + { } + } + + internal class InlineEfCipherAutoDataAttribute : InlineCustomAutoDataAttribute + { + public InlineEfCipherAutoDataAttribute(params object[] values) : base(new[] { typeof(SutProviderCustomization), + typeof(EfCipher) }, values) + { } } } - -internal class EfUserCipherAutoDataAttribute : CustomAutoDataAttribute -{ - public EfUserCipherAutoDataAttribute() : base(new SutProviderCustomization(), new EfCipher()) - { } -} - -internal class EfOrganizationCipherAutoDataAttribute : CustomAutoDataAttribute -{ - public EfOrganizationCipherAutoDataAttribute() : base(new SutProviderCustomization(), new EfCipher() - { - OrganizationOwned = true, - }) - { } -} - -internal class InlineEfCipherAutoDataAttribute : InlineCustomAutoDataAttribute -{ - public InlineEfCipherAutoDataAttribute(params object[] values) : base(new[] { typeof(SutProviderCustomization), - typeof(EfCipher) }, values) - { } -} diff --git a/test/Infrastructure.EFIntegration.Test/AutoFixture/CollectionCipherFixtures.cs b/test/Infrastructure.EFIntegration.Test/AutoFixture/CollectionCipherFixtures.cs index 89ffccb2b7..873e424398 100644 --- a/test/Infrastructure.EFIntegration.Test/AutoFixture/CollectionCipherFixtures.cs +++ b/test/Infrastructure.EFIntegration.Test/AutoFixture/CollectionCipherFixtures.cs @@ -7,56 +7,57 @@ using Bit.Infrastructure.EntityFramework.Repositories; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; -namespace Bit.Infrastructure.EFIntegration.Test.AutoFixture; - -internal class CollectionCipherBuilder : ISpecimenBuilder +namespace Bit.Infrastructure.EFIntegration.Test.AutoFixture { - public object Create(object request, ISpecimenContext context) + internal class CollectionCipherBuilder : ISpecimenBuilder { - if (context == null) + public object Create(object request, ISpecimenContext context) { - throw new ArgumentNullException(nameof(context)); - } + if (context == null) + { + throw new ArgumentNullException(nameof(context)); + } - var type = request as Type; - if (type == null || type != typeof(CollectionCipher)) + var type = request as Type; + if (type == null || type != typeof(CollectionCipher)) + { + return new NoSpecimen(); + } + + var fixture = new Fixture(); + fixture.Customizations.Insert(0, new MaxLengthStringRelay()); + var obj = fixture.WithAutoNSubstitutions().Create(); + return obj; + } + } + + internal class EfCollectionCipher : ICustomization + { + public void Customize(IFixture fixture) { - return new NoSpecimen(); + fixture.Customizations.Add(new IgnoreVirtualMembersCustomization()); + fixture.Customizations.Add(new GlobalSettingsBuilder()); + fixture.Customizations.Add(new CollectionCipherBuilder()); + fixture.Customizations.Add(new CollectionBuilder()); + fixture.Customizations.Add(new CipherBuilder()); + fixture.Customizations.Add(new UserBuilder()); + fixture.Customizations.Add(new EfRepositoryListBuilder()); + fixture.Customizations.Add(new EfRepositoryListBuilder()); + fixture.Customizations.Add(new EfRepositoryListBuilder()); + fixture.Customizations.Add(new EfRepositoryListBuilder()); } + } - var fixture = new Fixture(); - fixture.Customizations.Insert(0, new MaxLengthStringRelay()); - var obj = fixture.WithAutoNSubstitutions().Create(); - return obj; + internal class EfCollectionCipherAutoDataAttribute : CustomAutoDataAttribute + { + public EfCollectionCipherAutoDataAttribute() : base(new SutProviderCustomization(), new EfCollectionCipher()) + { } + } + + internal class InlineEfCollectionCipherAutoDataAttribute : InlineCustomAutoDataAttribute + { + public InlineEfCollectionCipherAutoDataAttribute(params object[] values) : base(new[] { typeof(SutProviderCustomization), + typeof(EfCollectionCipher) }, values) + { } } } - -internal class EfCollectionCipher : ICustomization -{ - public void Customize(IFixture fixture) - { - fixture.Customizations.Add(new IgnoreVirtualMembersCustomization()); - fixture.Customizations.Add(new GlobalSettingsBuilder()); - fixture.Customizations.Add(new CollectionCipherBuilder()); - fixture.Customizations.Add(new CollectionBuilder()); - fixture.Customizations.Add(new CipherBuilder()); - fixture.Customizations.Add(new UserBuilder()); - fixture.Customizations.Add(new EfRepositoryListBuilder()); - fixture.Customizations.Add(new EfRepositoryListBuilder()); - fixture.Customizations.Add(new EfRepositoryListBuilder()); - fixture.Customizations.Add(new EfRepositoryListBuilder()); - } -} - -internal class EfCollectionCipherAutoDataAttribute : CustomAutoDataAttribute -{ - public EfCollectionCipherAutoDataAttribute() : base(new SutProviderCustomization(), new EfCollectionCipher()) - { } -} - -internal class InlineEfCollectionCipherAutoDataAttribute : InlineCustomAutoDataAttribute -{ - public InlineEfCollectionCipherAutoDataAttribute(params object[] values) : base(new[] { typeof(SutProviderCustomization), - typeof(EfCollectionCipher) }, values) - { } -} diff --git a/test/Infrastructure.EFIntegration.Test/AutoFixture/CollectionFixtures.cs b/test/Infrastructure.EFIntegration.Test/AutoFixture/CollectionFixtures.cs index 4cb6cfbd45..1d96bccdc5 100644 --- a/test/Infrastructure.EFIntegration.Test/AutoFixture/CollectionFixtures.cs +++ b/test/Infrastructure.EFIntegration.Test/AutoFixture/CollectionFixtures.cs @@ -6,52 +6,53 @@ using Bit.Infrastructure.EntityFramework.Repositories; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; -namespace Bit.Infrastructure.EFIntegration.Test.AutoFixture; - -internal class CollectionBuilder : ISpecimenBuilder +namespace Bit.Infrastructure.EFIntegration.Test.AutoFixture { - public object Create(object request, ISpecimenContext context) + internal class CollectionBuilder : ISpecimenBuilder { - if (context == null) + public object Create(object request, ISpecimenContext context) { - throw new ArgumentNullException(nameof(context)); - } + if (context == null) + { + throw new ArgumentNullException(nameof(context)); + } - var type = request as Type; - if (type == null || type != typeof(Collection)) + var type = request as Type; + if (type == null || type != typeof(Collection)) + { + return new NoSpecimen(); + } + + var fixture = new Fixture(); + fixture.Customizations.Insert(0, new MaxLengthStringRelay()); + var obj = fixture.WithAutoNSubstitutions().Create(); + return obj; + } + } + + internal class EfCollection : ICustomization + { + public void Customize(IFixture fixture) { - return new NoSpecimen(); + fixture.Customizations.Add(new IgnoreVirtualMembersCustomization()); + fixture.Customizations.Add(new GlobalSettingsBuilder()); + fixture.Customizations.Add(new CollectionBuilder()); + fixture.Customizations.Add(new OrganizationBuilder()); + fixture.Customizations.Add(new EfRepositoryListBuilder()); + fixture.Customizations.Add(new EfRepositoryListBuilder()); } + } - var fixture = new Fixture(); - fixture.Customizations.Insert(0, new MaxLengthStringRelay()); - var obj = fixture.WithAutoNSubstitutions().Create(); - return obj; + internal class EfCollectionAutoDataAttribute : CustomAutoDataAttribute + { + public EfCollectionAutoDataAttribute() : base(new SutProviderCustomization(), new EfCollection()) + { } + } + + internal class InlineEfCollectionAutoDataAttribute : InlineCustomAutoDataAttribute + { + public InlineEfCollectionAutoDataAttribute(params object[] values) : base(new[] { typeof(SutProviderCustomization), + typeof(EfCollection) }, values) + { } } } - -internal class EfCollection : ICustomization -{ - public void Customize(IFixture fixture) - { - fixture.Customizations.Add(new IgnoreVirtualMembersCustomization()); - fixture.Customizations.Add(new GlobalSettingsBuilder()); - fixture.Customizations.Add(new CollectionBuilder()); - fixture.Customizations.Add(new OrganizationBuilder()); - fixture.Customizations.Add(new EfRepositoryListBuilder()); - fixture.Customizations.Add(new EfRepositoryListBuilder()); - } -} - -internal class EfCollectionAutoDataAttribute : CustomAutoDataAttribute -{ - public EfCollectionAutoDataAttribute() : base(new SutProviderCustomization(), new EfCollection()) - { } -} - -internal class InlineEfCollectionAutoDataAttribute : InlineCustomAutoDataAttribute -{ - public InlineEfCollectionAutoDataAttribute(params object[] values) : base(new[] { typeof(SutProviderCustomization), - typeof(EfCollection) }, values) - { } -} diff --git a/test/Infrastructure.EFIntegration.Test/AutoFixture/DeviceFixtures.cs b/test/Infrastructure.EFIntegration.Test/AutoFixture/DeviceFixtures.cs index da5b5b7676..9100af6a8a 100644 --- a/test/Infrastructure.EFIntegration.Test/AutoFixture/DeviceFixtures.cs +++ b/test/Infrastructure.EFIntegration.Test/AutoFixture/DeviceFixtures.cs @@ -7,53 +7,54 @@ using Bit.Infrastructure.EntityFramework.Repositories; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; -namespace Bit.Infrastructure.EFIntegration.Test.AutoFixture; - -internal class DeviceBuilder : ISpecimenBuilder +namespace Bit.Infrastructure.EFIntegration.Test.AutoFixture { - public object Create(object request, ISpecimenContext context) + internal class DeviceBuilder : ISpecimenBuilder { - if (context == null) + public object Create(object request, ISpecimenContext context) { - throw new ArgumentNullException(nameof(context)); - } + if (context == null) + { + throw new ArgumentNullException(nameof(context)); + } - var type = request as Type; - if (type == null || type != typeof(Device)) + var type = request as Type; + if (type == null || type != typeof(Device)) + { + return new NoSpecimen(); + } + + var fixture = new Fixture(); + fixture.Customizations.Insert(0, new MaxLengthStringRelay()); + var obj = fixture.WithAutoNSubstitutions().Create(); + return obj; + } + } + + internal class EfDevice : ICustomization + { + public void Customize(IFixture fixture) { - return new NoSpecimen(); + fixture.Customizations.Add(new IgnoreVirtualMembersCustomization()); + fixture.Customizations.Add(new GlobalSettingsBuilder()); + fixture.Customizations.Add(new DeviceBuilder()); + fixture.Customizations.Add(new UserBuilder()); + fixture.Customizations.Add(new EfRepositoryListBuilder()); + fixture.Customizations.Add(new EfRepositoryListBuilder()); } + } - var fixture = new Fixture(); - fixture.Customizations.Insert(0, new MaxLengthStringRelay()); - var obj = fixture.WithAutoNSubstitutions().Create(); - return obj; + internal class EfDeviceAutoDataAttribute : CustomAutoDataAttribute + { + public EfDeviceAutoDataAttribute() : base(new SutProviderCustomization(), new EfDevice()) + { } + } + + internal class InlineEfDeviceAutoDataAttribute : InlineCustomAutoDataAttribute + { + public InlineEfDeviceAutoDataAttribute(params object[] values) : base(new[] { typeof(SutProviderCustomization), + typeof(EfDevice) }, values) + { } } } -internal class EfDevice : ICustomization -{ - public void Customize(IFixture fixture) - { - fixture.Customizations.Add(new IgnoreVirtualMembersCustomization()); - fixture.Customizations.Add(new GlobalSettingsBuilder()); - fixture.Customizations.Add(new DeviceBuilder()); - fixture.Customizations.Add(new UserBuilder()); - fixture.Customizations.Add(new EfRepositoryListBuilder()); - fixture.Customizations.Add(new EfRepositoryListBuilder()); - } -} - -internal class EfDeviceAutoDataAttribute : CustomAutoDataAttribute -{ - public EfDeviceAutoDataAttribute() : base(new SutProviderCustomization(), new EfDevice()) - { } -} - -internal class InlineEfDeviceAutoDataAttribute : InlineCustomAutoDataAttribute -{ - public InlineEfDeviceAutoDataAttribute(params object[] values) : base(new[] { typeof(SutProviderCustomization), - typeof(EfDevice) }, values) - { } -} - diff --git a/test/Infrastructure.EFIntegration.Test/AutoFixture/EmergencyAccessFixtures.cs b/test/Infrastructure.EFIntegration.Test/AutoFixture/EmergencyAccessFixtures.cs index 87a8f796c2..82bc25f751 100644 --- a/test/Infrastructure.EFIntegration.Test/AutoFixture/EmergencyAccessFixtures.cs +++ b/test/Infrastructure.EFIntegration.Test/AutoFixture/EmergencyAccessFixtures.cs @@ -7,54 +7,55 @@ using Bit.Infrastructure.EntityFramework.Repositories; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; -namespace Bit.Infrastructure.EFIntegration.Test.AutoFixture; - -internal class EmergencyAccessBuilder : ISpecimenBuilder +namespace Bit.Infrastructure.EFIntegration.Test.AutoFixture { - public object Create(object request, ISpecimenContext context) + internal class EmergencyAccessBuilder : ISpecimenBuilder { - if (context == null) + public object Create(object request, ISpecimenContext context) { - throw new ArgumentNullException(nameof(context)); - } + if (context == null) + { + throw new ArgumentNullException(nameof(context)); + } - var type = request as Type; - if (type == null || type != typeof(EmergencyAccess)) + var type = request as Type; + if (type == null || type != typeof(EmergencyAccess)) + { + return new NoSpecimen(); + } + + var fixture = new Fixture(); + fixture.Customizations.Insert(0, new MaxLengthStringRelay()); + var obj = fixture.Create(); + return obj; + } + } + + internal class EfEmergencyAccess : ICustomization + { + public void Customize(IFixture fixture) { - return new NoSpecimen(); + // TODO: Make a base EF Customization with IgnoreVirtualMembers/GlobalSettings/All repos and inherit + fixture.Customizations.Add(new IgnoreVirtualMembersCustomization()); + fixture.Customizations.Add(new GlobalSettingsBuilder()); + fixture.Customizations.Add(new EmergencyAccessBuilder()); + fixture.Customizations.Add(new UserBuilder()); + fixture.Customizations.Add(new EfRepositoryListBuilder()); + fixture.Customizations.Add(new EfRepositoryListBuilder()); } + } - var fixture = new Fixture(); - fixture.Customizations.Insert(0, new MaxLengthStringRelay()); - var obj = fixture.Create(); - return obj; + internal class EfEmergencyAccessAutoDataAttribute : CustomAutoDataAttribute + { + public EfEmergencyAccessAutoDataAttribute() : base(new SutProviderCustomization(), new EfEmergencyAccess()) + { } + } + + internal class InlineEfEmergencyAccessAutoDataAttribute : InlineCustomAutoDataAttribute + { + public InlineEfEmergencyAccessAutoDataAttribute(params object[] values) : base(new[] { typeof(SutProviderCustomization), + typeof(EfEmergencyAccess) }, values) + { } } } -internal class EfEmergencyAccess : ICustomization -{ - public void Customize(IFixture fixture) - { - // TODO: Make a base EF Customization with IgnoreVirtualMembers/GlobalSettings/All repos and inherit - fixture.Customizations.Add(new IgnoreVirtualMembersCustomization()); - fixture.Customizations.Add(new GlobalSettingsBuilder()); - fixture.Customizations.Add(new EmergencyAccessBuilder()); - fixture.Customizations.Add(new UserBuilder()); - fixture.Customizations.Add(new EfRepositoryListBuilder()); - fixture.Customizations.Add(new EfRepositoryListBuilder()); - } -} - -internal class EfEmergencyAccessAutoDataAttribute : CustomAutoDataAttribute -{ - public EfEmergencyAccessAutoDataAttribute() : base(new SutProviderCustomization(), new EfEmergencyAccess()) - { } -} - -internal class InlineEfEmergencyAccessAutoDataAttribute : InlineCustomAutoDataAttribute -{ - public InlineEfEmergencyAccessAutoDataAttribute(params object[] values) : base(new[] { typeof(SutProviderCustomization), - typeof(EfEmergencyAccess) }, values) - { } -} - diff --git a/test/Infrastructure.EFIntegration.Test/AutoFixture/EntityFrameworkRepositoryFixtures.cs b/test/Infrastructure.EFIntegration.Test/AutoFixture/EntityFrameworkRepositoryFixtures.cs index 4a403b70bc..4c83062b6a 100644 --- a/test/Infrastructure.EFIntegration.Test/AutoFixture/EntityFrameworkRepositoryFixtures.cs +++ b/test/Infrastructure.EFIntegration.Test/AutoFixture/EntityFrameworkRepositoryFixtures.cs @@ -10,112 +10,113 @@ using Microsoft.EntityFrameworkCore; using Microsoft.Extensions.DependencyInjection; using Moq; -namespace Bit.Infrastructure.EFIntegration.Test.AutoFixture; - -internal class ServiceScopeFactoryBuilder : ISpecimenBuilder +namespace Bit.Infrastructure.EFIntegration.Test.AutoFixture { - private DbContextOptions _options { get; set; } - public ServiceScopeFactoryBuilder(DbContextOptions options) + internal class ServiceScopeFactoryBuilder : ISpecimenBuilder { - _options = options; - } - - public object Create(object request, ISpecimenContext context) - { - var fixture = new Fixture(); - var serviceProvider = new Mock(); - var dbContext = new DatabaseContext(_options); - serviceProvider - .Setup(x => x.GetService(typeof(DatabaseContext))) - .Returns(dbContext); - - var serviceScope = new Mock(); - serviceScope.Setup(x => x.ServiceProvider).Returns(serviceProvider.Object); - - var serviceScopeFactory = new Mock(); - serviceScopeFactory - .Setup(x => x.CreateScope()) - .Returns(serviceScope.Object); - return serviceScopeFactory.Object; - } -} - -public class EfRepositoryListBuilder : ISpecimenBuilder where T : BaseEntityFrameworkRepository -{ - public object Create(object request, ISpecimenContext context) - { - if (context == null) + private DbContextOptions _options { get; set; } + public ServiceScopeFactoryBuilder(DbContextOptions options) { - throw new ArgumentNullException(nameof(context)); + _options = options; } - var t = request as ParameterInfo; - if (t == null || t.ParameterType != typeof(List)) - { - return new NoSpecimen(); - } - - var list = new List(); - foreach (var option in DatabaseOptionsFactory.Options) + public object Create(object request, ISpecimenContext context) { var fixture = new Fixture(); - fixture.Customize(x => x.FromFactory(new ServiceScopeFactoryBuilder(option))); - fixture.Customize(x => x.FromFactory(() => - new MapperConfiguration(cfg => - { - cfg.AddProfile(); - cfg.AddProfile(); - cfg.AddProfile(); - cfg.AddProfile(); - cfg.AddProfile(); - cfg.AddProfile(); - cfg.AddProfile(); - cfg.AddProfile(); - cfg.AddProfile(); - cfg.AddProfile(); - cfg.AddProfile(); - cfg.AddProfile(); - cfg.AddProfile(); - cfg.AddProfile(); - cfg.AddProfile(); - cfg.AddProfile(); - cfg.AddProfile(); - cfg.AddProfile(); - cfg.AddProfile(); - cfg.AddProfile(); - cfg.AddProfile(); - cfg.AddProfile(); - cfg.AddProfile(); - cfg.AddProfile(); - }) - .CreateMapper())); + var serviceProvider = new Mock(); + var dbContext = new DatabaseContext(_options); + serviceProvider + .Setup(x => x.GetService(typeof(DatabaseContext))) + .Returns(dbContext); - var repo = fixture.Create(); - list.Add(repo); + var serviceScope = new Mock(); + serviceScope.Setup(x => x.ServiceProvider).Returns(serviceProvider.Object); + + var serviceScopeFactory = new Mock(); + serviceScopeFactory + .Setup(x => x.CreateScope()) + .Returns(serviceScope.Object); + return serviceScopeFactory.Object; } - return list; } -} -public class IgnoreVirtualMembersCustomization : ISpecimenBuilder -{ - public object Create(object request, ISpecimenContext context) + public class EfRepositoryListBuilder : ISpecimenBuilder where T : BaseEntityFrameworkRepository { - if (context == null) + public object Create(object request, ISpecimenContext context) { - throw new ArgumentNullException("context"); - } + if (context == null) + { + throw new ArgumentNullException(nameof(context)); + } - var pi = request as PropertyInfo; - if (pi == null) + var t = request as ParameterInfo; + if (t == null || t.ParameterType != typeof(List)) + { + return new NoSpecimen(); + } + + var list = new List(); + foreach (var option in DatabaseOptionsFactory.Options) + { + var fixture = new Fixture(); + fixture.Customize(x => x.FromFactory(new ServiceScopeFactoryBuilder(option))); + fixture.Customize(x => x.FromFactory(() => + new MapperConfiguration(cfg => + { + cfg.AddProfile(); + cfg.AddProfile(); + cfg.AddProfile(); + cfg.AddProfile(); + cfg.AddProfile(); + cfg.AddProfile(); + cfg.AddProfile(); + cfg.AddProfile(); + cfg.AddProfile(); + cfg.AddProfile(); + cfg.AddProfile(); + cfg.AddProfile(); + cfg.AddProfile(); + cfg.AddProfile(); + cfg.AddProfile(); + cfg.AddProfile(); + cfg.AddProfile(); + cfg.AddProfile(); + cfg.AddProfile(); + cfg.AddProfile(); + cfg.AddProfile(); + cfg.AddProfile(); + cfg.AddProfile(); + cfg.AddProfile(); + }) + .CreateMapper())); + + var repo = fixture.Create(); + list.Add(repo); + } + return list; + } + } + + public class IgnoreVirtualMembersCustomization : ISpecimenBuilder + { + public object Create(object request, ISpecimenContext context) { + if (context == null) + { + throw new ArgumentNullException("context"); + } + + var pi = request as PropertyInfo; + if (pi == null) + { + return new NoSpecimen(); + } + + if (pi.GetGetMethod().IsVirtual && pi.DeclaringType != typeof(GlobalSettings)) + { + return null; + } return new NoSpecimen(); } - - if (pi.GetGetMethod().IsVirtual && pi.DeclaringType != typeof(GlobalSettings)) - { - return null; - } - return new NoSpecimen(); } } diff --git a/test/Infrastructure.EFIntegration.Test/AutoFixture/EventFixtures.cs b/test/Infrastructure.EFIntegration.Test/AutoFixture/EventFixtures.cs index 70b2e9bc9e..ecb4f0ef91 100644 --- a/test/Infrastructure.EFIntegration.Test/AutoFixture/EventFixtures.cs +++ b/test/Infrastructure.EFIntegration.Test/AutoFixture/EventFixtures.cs @@ -6,51 +6,52 @@ using Bit.Infrastructure.EntityFramework.Repositories; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; -namespace Bit.Infrastructure.EFIntegration.Test.AutoFixture; - -internal class EventBuilder : ISpecimenBuilder +namespace Bit.Infrastructure.EFIntegration.Test.AutoFixture { - public object Create(object request, ISpecimenContext context) + internal class EventBuilder : ISpecimenBuilder { - if (context == null) + public object Create(object request, ISpecimenContext context) { - throw new ArgumentNullException(nameof(context)); - } + if (context == null) + { + throw new ArgumentNullException(nameof(context)); + } - var type = request as Type; - if (type == null || type != typeof(Event)) + var type = request as Type; + if (type == null || type != typeof(Event)) + { + return new NoSpecimen(); + } + + var fixture = new Fixture(); + fixture.Customizations.Insert(0, new MaxLengthStringRelay()); + var obj = fixture.WithAutoNSubstitutions().Create(); + return obj; + } + } + + internal class EfEvent : ICustomization + { + public void Customize(IFixture fixture) { - return new NoSpecimen(); + fixture.Customizations.Add(new IgnoreVirtualMembersCustomization()); + fixture.Customizations.Add(new GlobalSettingsBuilder()); + fixture.Customizations.Add(new EventBuilder()); + fixture.Customizations.Add(new EfRepositoryListBuilder()); } + } - var fixture = new Fixture(); - fixture.Customizations.Insert(0, new MaxLengthStringRelay()); - var obj = fixture.WithAutoNSubstitutions().Create(); - return obj; + internal class EfEventAutoDataAttribute : CustomAutoDataAttribute + { + public EfEventAutoDataAttribute() : base(new SutProviderCustomization(), new EfEvent()) + { } + } + + internal class InlineEfEventAutoDataAttribute : InlineCustomAutoDataAttribute + { + public InlineEfEventAutoDataAttribute(params object[] values) : base(new[] { typeof(SutProviderCustomization), + typeof(EfEvent) }, values) + { } } } -internal class EfEvent : ICustomization -{ - public void Customize(IFixture fixture) - { - fixture.Customizations.Add(new IgnoreVirtualMembersCustomization()); - fixture.Customizations.Add(new GlobalSettingsBuilder()); - fixture.Customizations.Add(new EventBuilder()); - fixture.Customizations.Add(new EfRepositoryListBuilder()); - } -} - -internal class EfEventAutoDataAttribute : CustomAutoDataAttribute -{ - public EfEventAutoDataAttribute() : base(new SutProviderCustomization(), new EfEvent()) - { } -} - -internal class InlineEfEventAutoDataAttribute : InlineCustomAutoDataAttribute -{ - public InlineEfEventAutoDataAttribute(params object[] values) : base(new[] { typeof(SutProviderCustomization), - typeof(EfEvent) }, values) - { } -} - diff --git a/test/Infrastructure.EFIntegration.Test/AutoFixture/FolderFixtures.cs b/test/Infrastructure.EFIntegration.Test/AutoFixture/FolderFixtures.cs index 884933ffd5..290fffb603 100644 --- a/test/Infrastructure.EFIntegration.Test/AutoFixture/FolderFixtures.cs +++ b/test/Infrastructure.EFIntegration.Test/AutoFixture/FolderFixtures.cs @@ -7,53 +7,54 @@ using Bit.Infrastructure.EntityFramework.Repositories; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; -namespace Bit.Infrastructure.EFIntegration.Test.AutoFixture; - -internal class FolderBuilder : ISpecimenBuilder +namespace Bit.Infrastructure.EFIntegration.Test.AutoFixture { - public object Create(object request, ISpecimenContext context) + internal class FolderBuilder : ISpecimenBuilder { - if (context == null) + public object Create(object request, ISpecimenContext context) { - throw new ArgumentNullException(nameof(context)); - } + if (context == null) + { + throw new ArgumentNullException(nameof(context)); + } - var type = request as Type; - if (type == null || type != typeof(Folder)) + var type = request as Type; + if (type == null || type != typeof(Folder)) + { + return new NoSpecimen(); + } + + var fixture = new Fixture(); + fixture.Customizations.Insert(0, new MaxLengthStringRelay()); + var obj = fixture.WithAutoNSubstitutions().Create(); + return obj; + } + } + + internal class EfFolder : ICustomization + { + public void Customize(IFixture fixture) { - return new NoSpecimen(); + fixture.Customizations.Add(new IgnoreVirtualMembersCustomization()); + fixture.Customizations.Add(new GlobalSettingsBuilder()); + fixture.Customizations.Add(new FolderBuilder()); + fixture.Customizations.Add(new UserBuilder()); + fixture.Customizations.Add(new EfRepositoryListBuilder()); + fixture.Customizations.Add(new EfRepositoryListBuilder()); } + } - var fixture = new Fixture(); - fixture.Customizations.Insert(0, new MaxLengthStringRelay()); - var obj = fixture.WithAutoNSubstitutions().Create(); - return obj; + internal class EfFolderAutoDataAttribute : CustomAutoDataAttribute + { + public EfFolderAutoDataAttribute() : base(new SutProviderCustomization(), new EfFolder()) + { } + } + + internal class InlineEfFolderAutoDataAttribute : InlineCustomAutoDataAttribute + { + public InlineEfFolderAutoDataAttribute(params object[] values) : base(new[] { typeof(SutProviderCustomization), + typeof(EfFolder) }, values) + { } } } -internal class EfFolder : ICustomization -{ - public void Customize(IFixture fixture) - { - fixture.Customizations.Add(new IgnoreVirtualMembersCustomization()); - fixture.Customizations.Add(new GlobalSettingsBuilder()); - fixture.Customizations.Add(new FolderBuilder()); - fixture.Customizations.Add(new UserBuilder()); - fixture.Customizations.Add(new EfRepositoryListBuilder()); - fixture.Customizations.Add(new EfRepositoryListBuilder()); - } -} - -internal class EfFolderAutoDataAttribute : CustomAutoDataAttribute -{ - public EfFolderAutoDataAttribute() : base(new SutProviderCustomization(), new EfFolder()) - { } -} - -internal class InlineEfFolderAutoDataAttribute : InlineCustomAutoDataAttribute -{ - public InlineEfFolderAutoDataAttribute(params object[] values) : base(new[] { typeof(SutProviderCustomization), - typeof(EfFolder) }, values) - { } -} - diff --git a/test/Infrastructure.EFIntegration.Test/AutoFixture/GrantFixtures.cs b/test/Infrastructure.EFIntegration.Test/AutoFixture/GrantFixtures.cs index d431132de4..7824426bb4 100644 --- a/test/Infrastructure.EFIntegration.Test/AutoFixture/GrantFixtures.cs +++ b/test/Infrastructure.EFIntegration.Test/AutoFixture/GrantFixtures.cs @@ -6,50 +6,51 @@ using Bit.Infrastructure.EntityFramework.Repositories; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; -namespace Bit.Infrastructure.EFIntegration.Test.AutoFixture; - -internal class GrantBuilder : ISpecimenBuilder +namespace Bit.Infrastructure.EFIntegration.Test.AutoFixture { - public object Create(object request, ISpecimenContext context) + internal class GrantBuilder : ISpecimenBuilder { - if (context == null) + public object Create(object request, ISpecimenContext context) { - throw new ArgumentNullException(nameof(context)); - } + if (context == null) + { + throw new ArgumentNullException(nameof(context)); + } - var type = request as Type; - if (type == null || type != typeof(Grant)) + var type = request as Type; + if (type == null || type != typeof(Grant)) + { + return new NoSpecimen(); + } + + var fixture = new Fixture(); + fixture.Customizations.Insert(0, new MaxLengthStringRelay()); + var obj = fixture.WithAutoNSubstitutions().Create(); + return obj; + } + } + + internal class EfGrant : ICustomization + { + public void Customize(IFixture fixture) { - return new NoSpecimen(); + fixture.Customizations.Add(new IgnoreVirtualMembersCustomization()); + fixture.Customizations.Add(new GlobalSettingsBuilder()); + fixture.Customizations.Add(new GrantBuilder()); + fixture.Customizations.Add(new EfRepositoryListBuilder()); } + } - var fixture = new Fixture(); - fixture.Customizations.Insert(0, new MaxLengthStringRelay()); - var obj = fixture.WithAutoNSubstitutions().Create(); - return obj; + internal class EfGrantAutoDataAttribute : CustomAutoDataAttribute + { + public EfGrantAutoDataAttribute() : base(new SutProviderCustomization(), new EfGrant()) + { } + } + + internal class InlineEfGrantAutoDataAttribute : InlineCustomAutoDataAttribute + { + public InlineEfGrantAutoDataAttribute(params object[] values) : base(new[] { typeof(SutProviderCustomization), + typeof(EfGrant) }, values) + { } } } - -internal class EfGrant : ICustomization -{ - public void Customize(IFixture fixture) - { - fixture.Customizations.Add(new IgnoreVirtualMembersCustomization()); - fixture.Customizations.Add(new GlobalSettingsBuilder()); - fixture.Customizations.Add(new GrantBuilder()); - fixture.Customizations.Add(new EfRepositoryListBuilder()); - } -} - -internal class EfGrantAutoDataAttribute : CustomAutoDataAttribute -{ - public EfGrantAutoDataAttribute() : base(new SutProviderCustomization(), new EfGrant()) - { } -} - -internal class InlineEfGrantAutoDataAttribute : InlineCustomAutoDataAttribute -{ - public InlineEfGrantAutoDataAttribute(params object[] values) : base(new[] { typeof(SutProviderCustomization), - typeof(EfGrant) }, values) - { } -} diff --git a/test/Infrastructure.EFIntegration.Test/AutoFixture/GroupFixtures.cs b/test/Infrastructure.EFIntegration.Test/AutoFixture/GroupFixtures.cs index c6cca49015..cfb232ab1d 100644 --- a/test/Infrastructure.EFIntegration.Test/AutoFixture/GroupFixtures.cs +++ b/test/Infrastructure.EFIntegration.Test/AutoFixture/GroupFixtures.cs @@ -6,52 +6,53 @@ using Bit.Infrastructure.EntityFramework.Repositories; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; -namespace Bit.Infrastructure.EFIntegration.Test.AutoFixture; - -internal class GroupBuilder : ISpecimenBuilder +namespace Bit.Infrastructure.EFIntegration.Test.AutoFixture { - public object Create(object request, ISpecimenContext context) + internal class GroupBuilder : ISpecimenBuilder { - if (context == null) + public object Create(object request, ISpecimenContext context) { - throw new ArgumentNullException(nameof(context)); - } + if (context == null) + { + throw new ArgumentNullException(nameof(context)); + } - var type = request as Type; - if (type == null || type != typeof(Group)) + var type = request as Type; + if (type == null || type != typeof(Group)) + { + return new NoSpecimen(); + } + + var fixture = new Fixture(); + fixture.Customizations.Insert(0, new MaxLengthStringRelay()); + var obj = fixture.WithAutoNSubstitutions().Create(); + return obj; + } + } + + internal class EfGroup : ICustomization + { + public void Customize(IFixture fixture) { - return new NoSpecimen(); + fixture.Customizations.Add(new IgnoreVirtualMembersCustomization()); + fixture.Customizations.Add(new GlobalSettingsBuilder()); + fixture.Customizations.Add(new GroupBuilder()); + fixture.Customizations.Add(new OrganizationBuilder()); + fixture.Customizations.Add(new EfRepositoryListBuilder()); + fixture.Customizations.Add(new EfRepositoryListBuilder()); } + } - var fixture = new Fixture(); - fixture.Customizations.Insert(0, new MaxLengthStringRelay()); - var obj = fixture.WithAutoNSubstitutions().Create(); - return obj; + internal class EfGroupAutoDataAttribute : CustomAutoDataAttribute + { + public EfGroupAutoDataAttribute() : base(new SutProviderCustomization(), new EfGroup()) + { } + } + + internal class InlineEfGroupAutoDataAttribute : InlineCustomAutoDataAttribute + { + public InlineEfGroupAutoDataAttribute(params object[] values) : base(new[] { typeof(SutProviderCustomization), + typeof(EfGroup) }, values) + { } } } - -internal class EfGroup : ICustomization -{ - public void Customize(IFixture fixture) - { - fixture.Customizations.Add(new IgnoreVirtualMembersCustomization()); - fixture.Customizations.Add(new GlobalSettingsBuilder()); - fixture.Customizations.Add(new GroupBuilder()); - fixture.Customizations.Add(new OrganizationBuilder()); - fixture.Customizations.Add(new EfRepositoryListBuilder()); - fixture.Customizations.Add(new EfRepositoryListBuilder()); - } -} - -internal class EfGroupAutoDataAttribute : CustomAutoDataAttribute -{ - public EfGroupAutoDataAttribute() : base(new SutProviderCustomization(), new EfGroup()) - { } -} - -internal class InlineEfGroupAutoDataAttribute : InlineCustomAutoDataAttribute -{ - public InlineEfGroupAutoDataAttribute(params object[] values) : base(new[] { typeof(SutProviderCustomization), - typeof(EfGroup) }, values) - { } -} diff --git a/test/Infrastructure.EFIntegration.Test/AutoFixture/GroupUserFixtures.cs b/test/Infrastructure.EFIntegration.Test/AutoFixture/GroupUserFixtures.cs index 2b68cde322..d7303b59ca 100644 --- a/test/Infrastructure.EFIntegration.Test/AutoFixture/GroupUserFixtures.cs +++ b/test/Infrastructure.EFIntegration.Test/AutoFixture/GroupUserFixtures.cs @@ -5,50 +5,51 @@ using Bit.Infrastructure.EntityFramework.Repositories; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; -namespace Bit.Infrastructure.EFIntegration.Test.AutoFixture; - -internal class GroupUserBuilder : ISpecimenBuilder +namespace Bit.Infrastructure.EFIntegration.Test.AutoFixture { - public object Create(object request, ISpecimenContext context) + internal class GroupUserBuilder : ISpecimenBuilder { - if (context == null) + public object Create(object request, ISpecimenContext context) { - throw new ArgumentNullException(nameof(context)); - } + if (context == null) + { + throw new ArgumentNullException(nameof(context)); + } - var type = request as Type; - if (type == null || type != typeof(GroupUser)) + var type = request as Type; + if (type == null || type != typeof(GroupUser)) + { + return new NoSpecimen(); + } + + var fixture = new Fixture(); + var obj = fixture.WithAutoNSubstitutions().Create(); + return obj; + } + } + + internal class EfGroupUser : ICustomization + { + public void Customize(IFixture fixture) { - return new NoSpecimen(); + fixture.Customizations.Add(new IgnoreVirtualMembersCustomization()); + fixture.Customizations.Add(new GlobalSettingsBuilder()); + fixture.Customizations.Add(new GroupUserBuilder()); + fixture.Customizations.Add(new EfRepositoryListBuilder()); } + } - var fixture = new Fixture(); - var obj = fixture.WithAutoNSubstitutions().Create(); - return obj; + internal class EfGroupUserAutoDataAttribute : CustomAutoDataAttribute + { + public EfGroupUserAutoDataAttribute() : base(new SutProviderCustomization(), new EfGroupUser()) + { } + } + + internal class InlineEfGroupUserAutoDataAttribute : InlineCustomAutoDataAttribute + { + public InlineEfGroupUserAutoDataAttribute(params object[] values) : base(new[] { typeof(SutProviderCustomization), + typeof(EfGroupUser) }, values) + { } } } -internal class EfGroupUser : ICustomization -{ - public void Customize(IFixture fixture) - { - fixture.Customizations.Add(new IgnoreVirtualMembersCustomization()); - fixture.Customizations.Add(new GlobalSettingsBuilder()); - fixture.Customizations.Add(new GroupUserBuilder()); - fixture.Customizations.Add(new EfRepositoryListBuilder()); - } -} - -internal class EfGroupUserAutoDataAttribute : CustomAutoDataAttribute -{ - public EfGroupUserAutoDataAttribute() : base(new SutProviderCustomization(), new EfGroupUser()) - { } -} - -internal class InlineEfGroupUserAutoDataAttribute : InlineCustomAutoDataAttribute -{ - public InlineEfGroupUserAutoDataAttribute(params object[] values) : base(new[] { typeof(SutProviderCustomization), - typeof(EfGroupUser) }, values) - { } -} - diff --git a/test/Infrastructure.EFIntegration.Test/AutoFixture/InstallationFixtures.cs b/test/Infrastructure.EFIntegration.Test/AutoFixture/InstallationFixtures.cs index c090a2e38e..1a8c546271 100644 --- a/test/Infrastructure.EFIntegration.Test/AutoFixture/InstallationFixtures.cs +++ b/test/Infrastructure.EFIntegration.Test/AutoFixture/InstallationFixtures.cs @@ -5,50 +5,51 @@ using Bit.Infrastructure.EntityFramework.Repositories; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; -namespace Bit.Infrastructure.EFIntegration.Test.AutoFixture; - -internal class InstallationBuilder : ISpecimenBuilder +namespace Bit.Infrastructure.EFIntegration.Test.AutoFixture { - public object Create(object request, ISpecimenContext context) + internal class InstallationBuilder : ISpecimenBuilder { - if (context == null) + public object Create(object request, ISpecimenContext context) { - throw new ArgumentNullException(nameof(context)); - } + if (context == null) + { + throw new ArgumentNullException(nameof(context)); + } - var type = request as Type; - if (type == null || type != typeof(Installation)) + var type = request as Type; + if (type == null || type != typeof(Installation)) + { + return new NoSpecimen(); + } + + var fixture = new Fixture(); + var obj = fixture.WithAutoNSubstitutions().Create(); + return obj; + } + } + + internal class EfInstallation : ICustomization + { + public void Customize(IFixture fixture) { - return new NoSpecimen(); + fixture.Customizations.Add(new IgnoreVirtualMembersCustomization()); + fixture.Customizations.Add(new GlobalSettingsBuilder()); + fixture.Customizations.Add(new InstallationBuilder()); + fixture.Customizations.Add(new EfRepositoryListBuilder()); } + } - var fixture = new Fixture(); - var obj = fixture.WithAutoNSubstitutions().Create(); - return obj; + internal class EfInstallationAutoDataAttribute : CustomAutoDataAttribute + { + public EfInstallationAutoDataAttribute() : base(new SutProviderCustomization(), new EfInstallation()) + { } + } + + internal class InlineEfInstallationAutoDataAttribute : InlineCustomAutoDataAttribute + { + public InlineEfInstallationAutoDataAttribute(params object[] values) : base(new[] { typeof(SutProviderCustomization), + typeof(EfInstallation) }, values) + { } } } -internal class EfInstallation : ICustomization -{ - public void Customize(IFixture fixture) - { - fixture.Customizations.Add(new IgnoreVirtualMembersCustomization()); - fixture.Customizations.Add(new GlobalSettingsBuilder()); - fixture.Customizations.Add(new InstallationBuilder()); - fixture.Customizations.Add(new EfRepositoryListBuilder()); - } -} - -internal class EfInstallationAutoDataAttribute : CustomAutoDataAttribute -{ - public EfInstallationAutoDataAttribute() : base(new SutProviderCustomization(), new EfInstallation()) - { } -} - -internal class InlineEfInstallationAutoDataAttribute : InlineCustomAutoDataAttribute -{ - public InlineEfInstallationAutoDataAttribute(params object[] values) : base(new[] { typeof(SutProviderCustomization), - typeof(EfInstallation) }, values) - { } -} - diff --git a/test/Infrastructure.EFIntegration.Test/AutoFixture/OrganizationFixtures.cs b/test/Infrastructure.EFIntegration.Test/AutoFixture/OrganizationFixtures.cs index 800ee14d20..f097603900 100644 --- a/test/Infrastructure.EFIntegration.Test/AutoFixture/OrganizationFixtures.cs +++ b/test/Infrastructure.EFIntegration.Test/AutoFixture/OrganizationFixtures.cs @@ -7,51 +7,52 @@ using Bit.Infrastructure.EntityFramework.Repositories; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; -namespace Bit.Infrastructure.EFIntegration.Test.AutoFixture; - -internal class OrganizationBuilder : ISpecimenBuilder +namespace Bit.Infrastructure.EFIntegration.Test.AutoFixture { - public object Create(object request, ISpecimenContext context) + internal class OrganizationBuilder : ISpecimenBuilder { - if (context == null) + public object Create(object request, ISpecimenContext context) { - throw new ArgumentNullException(nameof(context)); - } + if (context == null) + { + throw new ArgumentNullException(nameof(context)); + } - var type = request as Type; - if (type == null || type != typeof(Organization)) + var type = request as Type; + if (type == null || type != typeof(Organization)) + { + return new NoSpecimen(); + } + + var fixture = new Fixture(); + var providers = fixture.Create>(); + var organization = new Fixture().WithAutoNSubstitutions().Create(); + organization.SetTwoFactorProviders(providers); + return organization; + } + } + + internal class EfOrganization : ICustomization + { + public void Customize(IFixture fixture) { - return new NoSpecimen(); + fixture.Customizations.Add(new IgnoreVirtualMembersCustomization()); + fixture.Customizations.Add(new GlobalSettingsBuilder()); + fixture.Customizations.Add(new OrganizationBuilder()); + fixture.Customizations.Add(new EfRepositoryListBuilder()); } + } - var fixture = new Fixture(); - var providers = fixture.Create>(); - var organization = new Fixture().WithAutoNSubstitutions().Create(); - organization.SetTwoFactorProviders(providers); - return organization; + internal class EfOrganizationAutoDataAttribute : CustomAutoDataAttribute + { + public EfOrganizationAutoDataAttribute() : base(new SutProviderCustomization(), new EfOrganization()) + { } + } + + internal class InlineEfOrganizationAutoDataAttribute : InlineCustomAutoDataAttribute + { + public InlineEfOrganizationAutoDataAttribute(params object[] values) : base(new[] { typeof(SutProviderCustomization), + typeof(EfOrganization) }, values) + { } } } - -internal class EfOrganization : ICustomization -{ - public void Customize(IFixture fixture) - { - fixture.Customizations.Add(new IgnoreVirtualMembersCustomization()); - fixture.Customizations.Add(new GlobalSettingsBuilder()); - fixture.Customizations.Add(new OrganizationBuilder()); - fixture.Customizations.Add(new EfRepositoryListBuilder()); - } -} - -internal class EfOrganizationAutoDataAttribute : CustomAutoDataAttribute -{ - public EfOrganizationAutoDataAttribute() : base(new SutProviderCustomization(), new EfOrganization()) - { } -} - -internal class InlineEfOrganizationAutoDataAttribute : InlineCustomAutoDataAttribute -{ - public InlineEfOrganizationAutoDataAttribute(params object[] values) : base(new[] { typeof(SutProviderCustomization), - typeof(EfOrganization) }, values) - { } -} diff --git a/test/Infrastructure.EFIntegration.Test/AutoFixture/OrganizationSponsorshipFixtures.cs b/test/Infrastructure.EFIntegration.Test/AutoFixture/OrganizationSponsorshipFixtures.cs index ede2d2129a..c4b97ad4e1 100644 --- a/test/Infrastructure.EFIntegration.Test/AutoFixture/OrganizationSponsorshipFixtures.cs +++ b/test/Infrastructure.EFIntegration.Test/AutoFixture/OrganizationSponsorshipFixtures.cs @@ -5,51 +5,52 @@ using Bit.Infrastructure.EntityFramework.Repositories; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; -namespace Bit.Infrastructure.EFIntegration.Test.AutoFixture; - -internal class OrganizationSponsorshipBuilder : ISpecimenBuilder +namespace Bit.Infrastructure.EFIntegration.Test.AutoFixture { - public object Create(object request, ISpecimenContext context) + internal class OrganizationSponsorshipBuilder : ISpecimenBuilder { - if (context == null) + public object Create(object request, ISpecimenContext context) { - throw new ArgumentNullException(nameof(context)); - } + if (context == null) + { + throw new ArgumentNullException(nameof(context)); + } - var type = request as Type; - if (type == null || type != typeof(OrganizationSponsorship)) + var type = request as Type; + if (type == null || type != typeof(OrganizationSponsorship)) + { + return new NoSpecimen(); + } + + var fixture = new Fixture(); + var obj = fixture.WithAutoNSubstitutions().Create(); + return obj; + } + } + + internal class EfOrganizationSponsorship : ICustomization + { + public void Customize(IFixture fixture) { - return new NoSpecimen(); + fixture.Customizations.Add(new IgnoreVirtualMembersCustomization()); + fixture.Customizations.Add(new GlobalSettingsBuilder()); + fixture.Customizations.Add(new OrganizationSponsorshipBuilder()); + fixture.Customizations.Add(new OrganizationUserBuilder()); + fixture.Customizations.Add(new EfRepositoryListBuilder()); + fixture.Customizations.Add(new EfRepositoryListBuilder()); } + } - var fixture = new Fixture(); - var obj = fixture.WithAutoNSubstitutions().Create(); - return obj; + internal class EfOrganizationSponsorshipAutoDataAttribute : CustomAutoDataAttribute + { + public EfOrganizationSponsorshipAutoDataAttribute() : base(new SutProviderCustomization(), new EfOrganizationSponsorship(), new EfOrganization()) + { } + } + + internal class InlineEfOrganizationSponsorshipAutoDataAttribute : InlineCustomAutoDataAttribute + { + public InlineEfOrganizationSponsorshipAutoDataAttribute(params object[] values) : base(new[] { typeof(SutProviderCustomization), + typeof(EfOrganizationSponsorship), typeof(EfOrganization) }, values) + { } } } - -internal class EfOrganizationSponsorship : ICustomization -{ - public void Customize(IFixture fixture) - { - fixture.Customizations.Add(new IgnoreVirtualMembersCustomization()); - fixture.Customizations.Add(new GlobalSettingsBuilder()); - fixture.Customizations.Add(new OrganizationSponsorshipBuilder()); - fixture.Customizations.Add(new OrganizationUserBuilder()); - fixture.Customizations.Add(new EfRepositoryListBuilder()); - fixture.Customizations.Add(new EfRepositoryListBuilder()); - } -} - -internal class EfOrganizationSponsorshipAutoDataAttribute : CustomAutoDataAttribute -{ - public EfOrganizationSponsorshipAutoDataAttribute() : base(new SutProviderCustomization(), new EfOrganizationSponsorship(), new EfOrganization()) - { } -} - -internal class InlineEfOrganizationSponsorshipAutoDataAttribute : InlineCustomAutoDataAttribute -{ - public InlineEfOrganizationSponsorshipAutoDataAttribute(params object[] values) : base(new[] { typeof(SutProviderCustomization), - typeof(EfOrganizationSponsorship), typeof(EfOrganization) }, values) - { } -} diff --git a/test/Infrastructure.EFIntegration.Test/AutoFixture/OrganizationUserFixtures.cs b/test/Infrastructure.EFIntegration.Test/AutoFixture/OrganizationUserFixtures.cs index c457a463d9..1ae72117e7 100644 --- a/test/Infrastructure.EFIntegration.Test/AutoFixture/OrganizationUserFixtures.cs +++ b/test/Infrastructure.EFIntegration.Test/AutoFixture/OrganizationUserFixtures.cs @@ -11,72 +11,73 @@ using Bit.Infrastructure.EntityFramework.Repositories; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; -namespace Bit.Infrastructure.EFIntegration.Test.AutoFixture; - -internal class OrganizationUserBuilder : ISpecimenBuilder +namespace Bit.Infrastructure.EFIntegration.Test.AutoFixture { - public object Create(object request, ISpecimenContext context) + internal class OrganizationUserBuilder : ISpecimenBuilder { - if (context == null) + public object Create(object request, ISpecimenContext context) { - throw new ArgumentNullException(nameof(context)); - } + if (context == null) + { + throw new ArgumentNullException(nameof(context)); + } - var type = request as Type; - if (type == typeof(OrganizationUserCustomization)) - { - var fixture = new Fixture(); - var orgUser = fixture.WithAutoNSubstitutions().Create(); - var orgUserPermissions = fixture.WithAutoNSubstitutions().Create(); - orgUser.Permissions = JsonSerializer.Serialize(orgUserPermissions, new JsonSerializerOptions() + var type = request as Type; + if (type == typeof(OrganizationUserCustomization)) { - PropertyNamingPolicy = JsonNamingPolicy.CamelCase, - }); - return orgUser; - } - else if (type == typeof(List)) - { - var fixture = new Fixture(); - var orgUsers = fixture.WithAutoNSubstitutions().CreateMany(2); - foreach (var orgUser in orgUsers) - { - var providers = fixture.Create>(); + var fixture = new Fixture(); + var orgUser = fixture.WithAutoNSubstitutions().Create(); var orgUserPermissions = fixture.WithAutoNSubstitutions().Create(); orgUser.Permissions = JsonSerializer.Serialize(orgUserPermissions, new JsonSerializerOptions() { PropertyNamingPolicy = JsonNamingPolicy.CamelCase, }); + return orgUser; } - return orgUsers; + else if (type == typeof(List)) + { + var fixture = new Fixture(); + var orgUsers = fixture.WithAutoNSubstitutions().CreateMany(2); + foreach (var orgUser in orgUsers) + { + var providers = fixture.Create>(); + var orgUserPermissions = fixture.WithAutoNSubstitutions().Create(); + orgUser.Permissions = JsonSerializer.Serialize(orgUserPermissions, new JsonSerializerOptions() + { + PropertyNamingPolicy = JsonNamingPolicy.CamelCase, + }); + } + return orgUsers; + } + return new NoSpecimen(); } - return new NoSpecimen(); } -} -internal class EfOrganizationUser : ICustomization -{ - public void Customize(IFixture fixture) + internal class EfOrganizationUser : ICustomization { - fixture.Customizations.Add(new IgnoreVirtualMembersCustomization()); - fixture.Customizations.Add(new GlobalSettingsBuilder()); - fixture.Customizations.Add(new OrganizationUserBuilder()); - fixture.Customizations.Add(new OrganizationBuilder()); - fixture.Customizations.Add(new UserBuilder()); - fixture.Customizations.Add(new EfRepositoryListBuilder()); - fixture.Customizations.Add(new EfRepositoryListBuilder()); - fixture.Customizations.Add(new EfRepositoryListBuilder()); + public void Customize(IFixture fixture) + { + fixture.Customizations.Add(new IgnoreVirtualMembersCustomization()); + fixture.Customizations.Add(new GlobalSettingsBuilder()); + fixture.Customizations.Add(new OrganizationUserBuilder()); + fixture.Customizations.Add(new OrganizationBuilder()); + fixture.Customizations.Add(new UserBuilder()); + fixture.Customizations.Add(new EfRepositoryListBuilder()); + fixture.Customizations.Add(new EfRepositoryListBuilder()); + fixture.Customizations.Add(new EfRepositoryListBuilder()); + } + } + + internal class EfOrganizationUserAutoDataAttribute : CustomAutoDataAttribute + { + public EfOrganizationUserAutoDataAttribute() : base(new SutProviderCustomization(), new EfOrganizationUser()) + { } + } + + internal class InlineEfOrganizationUserAutoDataAttribute : InlineCustomAutoDataAttribute + { + public InlineEfOrganizationUserAutoDataAttribute(params object[] values) : base(new[] { typeof(SutProviderCustomization), + typeof(EfOrganizationUser) }, values) + { } } } - -internal class EfOrganizationUserAutoDataAttribute : CustomAutoDataAttribute -{ - public EfOrganizationUserAutoDataAttribute() : base(new SutProviderCustomization(), new EfOrganizationUser()) - { } -} - -internal class InlineEfOrganizationUserAutoDataAttribute : InlineCustomAutoDataAttribute -{ - public InlineEfOrganizationUserAutoDataAttribute(params object[] values) : base(new[] { typeof(SutProviderCustomization), - typeof(EfOrganizationUser) }, values) - { } -} diff --git a/test/Infrastructure.EFIntegration.Test/AutoFixture/PolicyFixtures.cs b/test/Infrastructure.EFIntegration.Test/AutoFixture/PolicyFixtures.cs index 70cea3e011..0b6424d549 100644 --- a/test/Infrastructure.EFIntegration.Test/AutoFixture/PolicyFixtures.cs +++ b/test/Infrastructure.EFIntegration.Test/AutoFixture/PolicyFixtures.cs @@ -5,75 +5,76 @@ using Bit.Infrastructure.EntityFramework.Repositories; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; -namespace Bit.Infrastructure.EFIntegration.Test.AutoFixture; - -internal class PolicyBuilder : ISpecimenBuilder +namespace Bit.Infrastructure.EFIntegration.Test.AutoFixture { - public object Create(object request, ISpecimenContext context) + internal class PolicyBuilder : ISpecimenBuilder { - if (context == null) + public object Create(object request, ISpecimenContext context) { - throw new ArgumentNullException(nameof(context)); - } + if (context == null) + { + throw new ArgumentNullException(nameof(context)); + } - var type = request as Type; - if (type == null || type != typeof(Policy)) + var type = request as Type; + if (type == null || type != typeof(Policy)) + { + return new NoSpecimen(); + } + + var fixture = new Fixture(); + var obj = fixture.WithAutoNSubstitutions().Create(); + return obj; + } + } + + internal class EfPolicy : ICustomization + { + public void Customize(IFixture fixture) { - return new NoSpecimen(); + fixture.Customizations.Add(new IgnoreVirtualMembersCustomization()); + fixture.Customizations.Add(new GlobalSettingsBuilder()); + fixture.Customizations.Add(new PolicyBuilder()); + fixture.Customizations.Add(new OrganizationBuilder()); + fixture.Customizations.Add(new EfRepositoryListBuilder()); + fixture.Customizations.Add(new EfRepositoryListBuilder()); } - - var fixture = new Fixture(); - var obj = fixture.WithAutoNSubstitutions().Create(); - return obj; } -} -internal class EfPolicy : ICustomization -{ - public void Customize(IFixture fixture) + internal class EfPolicyApplicableToUser : ICustomization { - fixture.Customizations.Add(new IgnoreVirtualMembersCustomization()); - fixture.Customizations.Add(new GlobalSettingsBuilder()); - fixture.Customizations.Add(new PolicyBuilder()); - fixture.Customizations.Add(new OrganizationBuilder()); - fixture.Customizations.Add(new EfRepositoryListBuilder()); - fixture.Customizations.Add(new EfRepositoryListBuilder()); + public void Customize(IFixture fixture) + { + fixture.Customizations.Add(new IgnoreVirtualMembersCustomization()); + fixture.Customizations.Add(new GlobalSettingsBuilder()); + fixture.Customizations.Add(new PolicyBuilder()); + fixture.Customizations.Add(new OrganizationBuilder()); + fixture.Customizations.Add(new EfRepositoryListBuilder()); + fixture.Customizations.Add(new EfRepositoryListBuilder()); + fixture.Customizations.Add(new EfRepositoryListBuilder()); + fixture.Customizations.Add(new EfRepositoryListBuilder()); + fixture.Customizations.Add(new EfRepositoryListBuilder()); + fixture.Customizations.Add(new EfRepositoryListBuilder()); + fixture.Customizations.Add(new EfRepositoryListBuilder()); + } } -} -internal class EfPolicyApplicableToUser : ICustomization -{ - public void Customize(IFixture fixture) + internal class EfPolicyAutoDataAttribute : CustomAutoDataAttribute { - fixture.Customizations.Add(new IgnoreVirtualMembersCustomization()); - fixture.Customizations.Add(new GlobalSettingsBuilder()); - fixture.Customizations.Add(new PolicyBuilder()); - fixture.Customizations.Add(new OrganizationBuilder()); - fixture.Customizations.Add(new EfRepositoryListBuilder()); - fixture.Customizations.Add(new EfRepositoryListBuilder()); - fixture.Customizations.Add(new EfRepositoryListBuilder()); - fixture.Customizations.Add(new EfRepositoryListBuilder()); - fixture.Customizations.Add(new EfRepositoryListBuilder()); - fixture.Customizations.Add(new EfRepositoryListBuilder()); - fixture.Customizations.Add(new EfRepositoryListBuilder()); + public EfPolicyAutoDataAttribute() : base(new SutProviderCustomization(), new EfPolicy()) + { } + } + + internal class EfPolicyApplicableToUserInlineAutoDataAttribute : InlineCustomAutoDataAttribute + { + public EfPolicyApplicableToUserInlineAutoDataAttribute(params object[] values) : base(new[] { typeof(SutProviderCustomization), typeof(EfPolicyApplicableToUser) }, values) + { } + } + + internal class InlineEfPolicyAutoDataAttribute : InlineCustomAutoDataAttribute + { + public InlineEfPolicyAutoDataAttribute(params object[] values) : base(new[] { typeof(SutProviderCustomization), + typeof(EfPolicy) }, values) + { } } } - -internal class EfPolicyAutoDataAttribute : CustomAutoDataAttribute -{ - public EfPolicyAutoDataAttribute() : base(new SutProviderCustomization(), new EfPolicy()) - { } -} - -internal class EfPolicyApplicableToUserInlineAutoDataAttribute : InlineCustomAutoDataAttribute -{ - public EfPolicyApplicableToUserInlineAutoDataAttribute(params object[] values) : base(new[] { typeof(SutProviderCustomization), typeof(EfPolicyApplicableToUser) }, values) - { } -} - -internal class InlineEfPolicyAutoDataAttribute : InlineCustomAutoDataAttribute -{ - public InlineEfPolicyAutoDataAttribute(params object[] values) : base(new[] { typeof(SutProviderCustomization), - typeof(EfPolicy) }, values) - { } -} diff --git a/test/Infrastructure.EFIntegration.Test/AutoFixture/Relays/MaxLengthStringRelay.cs b/test/Infrastructure.EFIntegration.Test/AutoFixture/Relays/MaxLengthStringRelay.cs index 75f03e34bc..e2a3812cc9 100644 --- a/test/Infrastructure.EFIntegration.Test/AutoFixture/Relays/MaxLengthStringRelay.cs +++ b/test/Infrastructure.EFIntegration.Test/AutoFixture/Relays/MaxLengthStringRelay.cs @@ -2,39 +2,40 @@ using System.Reflection; using AutoFixture.Kernel; -namespace Bit.Infrastructure.EFIntegration.Test.AutoFixture.Relays; - -// Creates a string the same length as any availible MaxLength data annotation -// Modified version of the StringLenfthRelay provided by AutoFixture -// https://github.com/AutoFixture/AutoFixture/blob/master/Src/AutoFixture/DataAnnotations/StringLengthAttributeRelay.cs -public class MaxLengthStringRelay : ISpecimenBuilder +namespace Bit.Infrastructure.EFIntegration.Test.AutoFixture.Relays { - public object Create(object request, ISpecimenContext context) + // Creates a string the same length as any availible MaxLength data annotation + // Modified version of the StringLenfthRelay provided by AutoFixture + // https://github.com/AutoFixture/AutoFixture/blob/master/Src/AutoFixture/DataAnnotations/StringLengthAttributeRelay.cs + public class MaxLengthStringRelay : ISpecimenBuilder { - if (request == null) + public object Create(object request, ISpecimenContext context) { - return new NoSpecimen(); + if (request == null) + { + return new NoSpecimen(); + } + + if (context == null) + { + throw new ArgumentNullException(nameof(context)); + } + + var p = request as PropertyInfo; + if (p == null) + { + return new NoSpecimen(); + } + + var a = (MaxLengthAttribute)p.GetCustomAttributes(typeof(MaxLengthAttribute), false).SingleOrDefault(); + + if (a == null) + { + return new NoSpecimen(); + } + + return context.Resolve(new ConstrainedStringRequest(a.Length, a.Length)); } - - if (context == null) - { - throw new ArgumentNullException(nameof(context)); - } - - var p = request as PropertyInfo; - if (p == null) - { - return new NoSpecimen(); - } - - var a = (MaxLengthAttribute)p.GetCustomAttributes(typeof(MaxLengthAttribute), false).SingleOrDefault(); - - if (a == null) - { - return new NoSpecimen(); - } - - return context.Resolve(new ConstrainedStringRequest(a.Length, a.Length)); } } diff --git a/test/Infrastructure.EFIntegration.Test/AutoFixture/SendFixtures.cs b/test/Infrastructure.EFIntegration.Test/AutoFixture/SendFixtures.cs index 162bdf6e5b..222ea4ac0e 100644 --- a/test/Infrastructure.EFIntegration.Test/AutoFixture/SendFixtures.cs +++ b/test/Infrastructure.EFIntegration.Test/AutoFixture/SendFixtures.cs @@ -7,63 +7,64 @@ using Bit.Infrastructure.EntityFramework.Repositories; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; -namespace Bit.Infrastructure.EFIntegration.Test.AutoFixture; - -internal class SendBuilder : ISpecimenBuilder +namespace Bit.Infrastructure.EFIntegration.Test.AutoFixture { - public bool OrganizationOwned { get; set; } - public object Create(object request, ISpecimenContext context) + internal class SendBuilder : ISpecimenBuilder { - if (context == null) + public bool OrganizationOwned { get; set; } + public object Create(object request, ISpecimenContext context) { - throw new ArgumentNullException(nameof(context)); - } + if (context == null) + { + throw new ArgumentNullException(nameof(context)); + } - var type = request as Type; - if (type == null || type != typeof(Send)) - { - return new NoSpecimen(); - } + var type = request as Type; + if (type == null || type != typeof(Send)) + { + return new NoSpecimen(); + } - var fixture = new Fixture(); - fixture.Customizations.Insert(0, new MaxLengthStringRelay()); - if (!OrganizationOwned) - { - fixture.Customize(composer => composer - .Without(c => c.OrganizationId)); + var fixture = new Fixture(); + fixture.Customizations.Insert(0, new MaxLengthStringRelay()); + if (!OrganizationOwned) + { + fixture.Customize(composer => composer + .Without(c => c.OrganizationId)); + } + var obj = fixture.WithAutoNSubstitutions().Create(); + return obj; } - var obj = fixture.WithAutoNSubstitutions().Create(); - return obj; + } + + internal class EfSend : ICustomization + { + public bool OrganizationOwned { get; set; } + public void Customize(IFixture fixture) + { + fixture.Customizations.Add(new IgnoreVirtualMembersCustomization()); + fixture.Customizations.Add(new GlobalSettingsBuilder()); + fixture.Customizations.Add(new SendBuilder()); + fixture.Customizations.Add(new UserBuilder()); + fixture.Customizations.Add(new OrganizationBuilder()); + fixture.Customizations.Add(new EfRepositoryListBuilder()); + fixture.Customizations.Add(new EfRepositoryListBuilder()); + fixture.Customizations.Add(new EfRepositoryListBuilder()); + } + } + + internal class EfUserSendAutoDataAttribute : CustomAutoDataAttribute + { + public EfUserSendAutoDataAttribute() : base(new SutProviderCustomization(), new EfSend()) + { } + } + + internal class EfOrganizationSendAutoDataAttribute : CustomAutoDataAttribute + { + public EfOrganizationSendAutoDataAttribute() : base(new SutProviderCustomization(), new EfSend() + { + OrganizationOwned = true, + }) + { } } } - -internal class EfSend : ICustomization -{ - public bool OrganizationOwned { get; set; } - public void Customize(IFixture fixture) - { - fixture.Customizations.Add(new IgnoreVirtualMembersCustomization()); - fixture.Customizations.Add(new GlobalSettingsBuilder()); - fixture.Customizations.Add(new SendBuilder()); - fixture.Customizations.Add(new UserBuilder()); - fixture.Customizations.Add(new OrganizationBuilder()); - fixture.Customizations.Add(new EfRepositoryListBuilder()); - fixture.Customizations.Add(new EfRepositoryListBuilder()); - fixture.Customizations.Add(new EfRepositoryListBuilder()); - } -} - -internal class EfUserSendAutoDataAttribute : CustomAutoDataAttribute -{ - public EfUserSendAutoDataAttribute() : base(new SutProviderCustomization(), new EfSend()) - { } -} - -internal class EfOrganizationSendAutoDataAttribute : CustomAutoDataAttribute -{ - public EfOrganizationSendAutoDataAttribute() : base(new SutProviderCustomization(), new EfSend() - { - OrganizationOwned = true, - }) - { } -} diff --git a/test/Infrastructure.EFIntegration.Test/AutoFixture/SsoConfigFixtures.cs b/test/Infrastructure.EFIntegration.Test/AutoFixture/SsoConfigFixtures.cs index 4cad2154f8..83f3064f35 100644 --- a/test/Infrastructure.EFIntegration.Test/AutoFixture/SsoConfigFixtures.cs +++ b/test/Infrastructure.EFIntegration.Test/AutoFixture/SsoConfigFixtures.cs @@ -6,53 +6,54 @@ using Bit.Infrastructure.EntityFramework.Repositories; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; -namespace Bit.Infrastructure.EFIntegration.Test.AutoFixture; - -internal class SsoConfigBuilder : ISpecimenBuilder +namespace Bit.Infrastructure.EFIntegration.Test.AutoFixture { - public object Create(object request, ISpecimenContext context) + internal class SsoConfigBuilder : ISpecimenBuilder { - if (context == null) + public object Create(object request, ISpecimenContext context) { - throw new ArgumentNullException(nameof(context)); - } + if (context == null) + { + throw new ArgumentNullException(nameof(context)); + } - var type = request as Type; - if (type == null || type != typeof(SsoConfig)) + var type = request as Type; + if (type == null || type != typeof(SsoConfig)) + { + return new NoSpecimen(); + } + + var fixture = new Fixture(); + var ssoConfig = fixture.WithAutoNSubstitutions().Create(); + var ssoConfigData = fixture.WithAutoNSubstitutions().Create(); + ssoConfig.SetData(ssoConfigData); + return ssoConfig; + } + } + + internal class EfSsoConfig : ICustomization + { + public void Customize(IFixture fixture) { - return new NoSpecimen(); + fixture.Customizations.Add(new IgnoreVirtualMembersCustomization()); + fixture.Customizations.Add(new GlobalSettingsBuilder()); + fixture.Customizations.Add(new OrganizationBuilder()); + fixture.Customizations.Add(new SsoConfigBuilder()); + fixture.Customizations.Add(new EfRepositoryListBuilder()); + fixture.Customizations.Add(new EfRepositoryListBuilder()); } + } - var fixture = new Fixture(); - var ssoConfig = fixture.WithAutoNSubstitutions().Create(); - var ssoConfigData = fixture.WithAutoNSubstitutions().Create(); - ssoConfig.SetData(ssoConfigData); - return ssoConfig; + internal class EfSsoConfigAutoDataAttribute : CustomAutoDataAttribute + { + public EfSsoConfigAutoDataAttribute() : base(new SutProviderCustomization(), new EfSsoConfig()) + { } + } + + internal class InlineEfSsoConfigAutoDataAttribute : InlineCustomAutoDataAttribute + { + public InlineEfSsoConfigAutoDataAttribute(params object[] values) : base(new[] { typeof(SutProviderCustomization), + typeof(EfSsoConfig) }, values) + { } } } - -internal class EfSsoConfig : ICustomization -{ - public void Customize(IFixture fixture) - { - fixture.Customizations.Add(new IgnoreVirtualMembersCustomization()); - fixture.Customizations.Add(new GlobalSettingsBuilder()); - fixture.Customizations.Add(new OrganizationBuilder()); - fixture.Customizations.Add(new SsoConfigBuilder()); - fixture.Customizations.Add(new EfRepositoryListBuilder()); - fixture.Customizations.Add(new EfRepositoryListBuilder()); - } -} - -internal class EfSsoConfigAutoDataAttribute : CustomAutoDataAttribute -{ - public EfSsoConfigAutoDataAttribute() : base(new SutProviderCustomization(), new EfSsoConfig()) - { } -} - -internal class InlineEfSsoConfigAutoDataAttribute : InlineCustomAutoDataAttribute -{ - public InlineEfSsoConfigAutoDataAttribute(params object[] values) : base(new[] { typeof(SutProviderCustomization), - typeof(EfSsoConfig) }, values) - { } -} diff --git a/test/Infrastructure.EFIntegration.Test/AutoFixture/SsoUserFixtures.cs b/test/Infrastructure.EFIntegration.Test/AutoFixture/SsoUserFixtures.cs index f2712e0186..32b6ddf247 100644 --- a/test/Infrastructure.EFIntegration.Test/AutoFixture/SsoUserFixtures.cs +++ b/test/Infrastructure.EFIntegration.Test/AutoFixture/SsoUserFixtures.cs @@ -5,32 +5,33 @@ using Bit.Infrastructure.EntityFramework.Repositories; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; -namespace Bit.Infrastructure.EFIntegration.Test.AutoFixture; - -internal class EfSsoUser : ICustomization +namespace Bit.Infrastructure.EFIntegration.Test.AutoFixture { - public void Customize(IFixture fixture) + internal class EfSsoUser : ICustomization { - fixture.Customizations.Add(new IgnoreVirtualMembersCustomization()); - fixture.Customizations.Add(new GlobalSettingsBuilder()); - fixture.Customizations.Add(new UserBuilder()); - fixture.Customizations.Add(new OrganizationBuilder()); - fixture.Customize(composer => composer.Without(ou => ou.Id)); - fixture.Customizations.Add(new EfRepositoryListBuilder()); - fixture.Customizations.Add(new EfRepositoryListBuilder()); - fixture.Customizations.Add(new EfRepositoryListBuilder()); + public void Customize(IFixture fixture) + { + fixture.Customizations.Add(new IgnoreVirtualMembersCustomization()); + fixture.Customizations.Add(new GlobalSettingsBuilder()); + fixture.Customizations.Add(new UserBuilder()); + fixture.Customizations.Add(new OrganizationBuilder()); + fixture.Customize(composer => composer.Without(ou => ou.Id)); + fixture.Customizations.Add(new EfRepositoryListBuilder()); + fixture.Customizations.Add(new EfRepositoryListBuilder()); + fixture.Customizations.Add(new EfRepositoryListBuilder()); + } + } + + internal class EfSsoUserAutoDataAttribute : CustomAutoDataAttribute + { + public EfSsoUserAutoDataAttribute() : base(new SutProviderCustomization(), new EfSsoUser()) + { } + } + + internal class InlineEfSsoUserAutoDataAttribute : InlineCustomAutoDataAttribute + { + public InlineEfSsoUserAutoDataAttribute(params object[] values) : base(new[] { typeof(SutProviderCustomization), + typeof(EfSsoUser) }, values) + { } } } - -internal class EfSsoUserAutoDataAttribute : CustomAutoDataAttribute -{ - public EfSsoUserAutoDataAttribute() : base(new SutProviderCustomization(), new EfSsoUser()) - { } -} - -internal class InlineEfSsoUserAutoDataAttribute : InlineCustomAutoDataAttribute -{ - public InlineEfSsoUserAutoDataAttribute(params object[] values) : base(new[] { typeof(SutProviderCustomization), - typeof(EfSsoUser) }, values) - { } -} diff --git a/test/Infrastructure.EFIntegration.Test/AutoFixture/TaxRateFixtures.cs b/test/Infrastructure.EFIntegration.Test/AutoFixture/TaxRateFixtures.cs index c8cd8c692c..b22c6d8c26 100644 --- a/test/Infrastructure.EFIntegration.Test/AutoFixture/TaxRateFixtures.cs +++ b/test/Infrastructure.EFIntegration.Test/AutoFixture/TaxRateFixtures.cs @@ -6,51 +6,52 @@ using Bit.Infrastructure.EntityFramework.Repositories; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; -namespace Bit.Infrastructure.EFIntegration.Test.AutoFixture; - -internal class TaxRateBuilder : ISpecimenBuilder +namespace Bit.Infrastructure.EFIntegration.Test.AutoFixture { - public object Create(object request, ISpecimenContext context) + internal class TaxRateBuilder : ISpecimenBuilder { - if (context == null) + public object Create(object request, ISpecimenContext context) { - throw new ArgumentNullException(nameof(context)); - } + if (context == null) + { + throw new ArgumentNullException(nameof(context)); + } - var type = request as Type; - if (type == null || type != typeof(TaxRate)) + var type = request as Type; + if (type == null || type != typeof(TaxRate)) + { + return new NoSpecimen(); + } + + var fixture = new Fixture(); + fixture.Customizations.Insert(0, new MaxLengthStringRelay()); + var obj = fixture.WithAutoNSubstitutions().Create(); + return obj; + } + } + + internal class EfTaxRate : ICustomization + { + public void Customize(IFixture fixture) { - return new NoSpecimen(); + fixture.Customizations.Add(new IgnoreVirtualMembersCustomization()); + fixture.Customizations.Add(new GlobalSettingsBuilder()); + fixture.Customizations.Add(new TaxRateBuilder()); + fixture.Customizations.Add(new EfRepositoryListBuilder()); } + } - var fixture = new Fixture(); - fixture.Customizations.Insert(0, new MaxLengthStringRelay()); - var obj = fixture.WithAutoNSubstitutions().Create(); - return obj; + internal class EfTaxRateAutoDataAttribute : CustomAutoDataAttribute + { + public EfTaxRateAutoDataAttribute() : base(new SutProviderCustomization(), new EfTaxRate()) + { } + } + + internal class InlineEfTaxRateAutoDataAttribute : InlineCustomAutoDataAttribute + { + public InlineEfTaxRateAutoDataAttribute(params object[] values) : base(new[] { typeof(SutProviderCustomization), + typeof(EfTaxRate) }, values) + { } } } -internal class EfTaxRate : ICustomization -{ - public void Customize(IFixture fixture) - { - fixture.Customizations.Add(new IgnoreVirtualMembersCustomization()); - fixture.Customizations.Add(new GlobalSettingsBuilder()); - fixture.Customizations.Add(new TaxRateBuilder()); - fixture.Customizations.Add(new EfRepositoryListBuilder()); - } -} - -internal class EfTaxRateAutoDataAttribute : CustomAutoDataAttribute -{ - public EfTaxRateAutoDataAttribute() : base(new SutProviderCustomization(), new EfTaxRate()) - { } -} - -internal class InlineEfTaxRateAutoDataAttribute : InlineCustomAutoDataAttribute -{ - public InlineEfTaxRateAutoDataAttribute(params object[] values) : base(new[] { typeof(SutProviderCustomization), - typeof(EfTaxRate) }, values) - { } -} - diff --git a/test/Infrastructure.EFIntegration.Test/AutoFixture/TransactionFixutres.cs b/test/Infrastructure.EFIntegration.Test/AutoFixture/TransactionFixutres.cs index 7dbe42fc11..437cdcd2a5 100644 --- a/test/Infrastructure.EFIntegration.Test/AutoFixture/TransactionFixutres.cs +++ b/test/Infrastructure.EFIntegration.Test/AutoFixture/TransactionFixutres.cs @@ -7,63 +7,64 @@ using Bit.Infrastructure.EntityFramework.Repositories; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; -namespace Bit.Infrastructure.EFIntegration.Test.AutoFixture; - -internal class TransactionBuilder : ISpecimenBuilder +namespace Bit.Infrastructure.EFIntegration.Test.AutoFixture { - public bool OrganizationOwned { get; set; } - public object Create(object request, ISpecimenContext context) + internal class TransactionBuilder : ISpecimenBuilder { - if (context == null) + public bool OrganizationOwned { get; set; } + public object Create(object request, ISpecimenContext context) { - throw new ArgumentNullException(nameof(context)); - } + if (context == null) + { + throw new ArgumentNullException(nameof(context)); + } - var type = request as Type; - if (type == null || type != typeof(Transaction)) - { - return new NoSpecimen(); - } + var type = request as Type; + if (type == null || type != typeof(Transaction)) + { + return new NoSpecimen(); + } - var fixture = new Fixture(); - if (!OrganizationOwned) - { - fixture.Customize(composer => composer - .Without(c => c.OrganizationId)); + var fixture = new Fixture(); + if (!OrganizationOwned) + { + fixture.Customize(composer => composer + .Without(c => c.OrganizationId)); + } + fixture.Customizations.Add(new MaxLengthStringRelay()); + var obj = fixture.WithAutoNSubstitutions().Create(); + return obj; } - fixture.Customizations.Add(new MaxLengthStringRelay()); - var obj = fixture.WithAutoNSubstitutions().Create(); - return obj; + } + + internal class EfTransaction : ICustomization + { + public bool OrganizationOwned { get; set; } + public void Customize(IFixture fixture) + { + fixture.Customizations.Add(new IgnoreVirtualMembersCustomization()); + fixture.Customizations.Add(new GlobalSettingsBuilder()); + fixture.Customizations.Add(new TransactionBuilder()); + fixture.Customizations.Add(new UserBuilder()); + fixture.Customizations.Add(new OrganizationBuilder()); + fixture.Customizations.Add(new EfRepositoryListBuilder()); + fixture.Customizations.Add(new EfRepositoryListBuilder()); + fixture.Customizations.Add(new EfRepositoryListBuilder()); + } + } + + internal class EfUserTransactionAutoDataAttribute : CustomAutoDataAttribute + { + public EfUserTransactionAutoDataAttribute() : base(new SutProviderCustomization(), new EfTransaction()) + { } + } + + internal class EfOrganizationTransactionAutoDataAttribute : CustomAutoDataAttribute + { + public EfOrganizationTransactionAutoDataAttribute() : base(new SutProviderCustomization(), new EfTransaction() + { + OrganizationOwned = true, + }) + { } } } - -internal class EfTransaction : ICustomization -{ - public bool OrganizationOwned { get; set; } - public void Customize(IFixture fixture) - { - fixture.Customizations.Add(new IgnoreVirtualMembersCustomization()); - fixture.Customizations.Add(new GlobalSettingsBuilder()); - fixture.Customizations.Add(new TransactionBuilder()); - fixture.Customizations.Add(new UserBuilder()); - fixture.Customizations.Add(new OrganizationBuilder()); - fixture.Customizations.Add(new EfRepositoryListBuilder()); - fixture.Customizations.Add(new EfRepositoryListBuilder()); - fixture.Customizations.Add(new EfRepositoryListBuilder()); - } -} - -internal class EfUserTransactionAutoDataAttribute : CustomAutoDataAttribute -{ - public EfUserTransactionAutoDataAttribute() : base(new SutProviderCustomization(), new EfTransaction()) - { } -} - -internal class EfOrganizationTransactionAutoDataAttribute : CustomAutoDataAttribute -{ - public EfOrganizationTransactionAutoDataAttribute() : base(new SutProviderCustomization(), new EfTransaction() - { - OrganizationOwned = true, - }) - { } -} diff --git a/test/Infrastructure.EFIntegration.Test/AutoFixture/UserFixtures.cs b/test/Infrastructure.EFIntegration.Test/AutoFixture/UserFixtures.cs index 98222e8f32..f54b7b758d 100644 --- a/test/Infrastructure.EFIntegration.Test/AutoFixture/UserFixtures.cs +++ b/test/Infrastructure.EFIntegration.Test/AutoFixture/UserFixtures.cs @@ -3,29 +3,30 @@ using Bit.Core.Test.AutoFixture.UserFixtures; using Bit.Infrastructure.EntityFramework.Repositories; using Bit.Test.Common.AutoFixture.Attributes; -namespace Bit.Infrastructure.EFIntegration.Test.AutoFixture; - -internal class EfUser : UserFixture +namespace Bit.Infrastructure.EFIntegration.Test.AutoFixture { - public override void Customize(IFixture fixture) + internal class EfUser : UserFixture { - fixture.Customizations.Add(new IgnoreVirtualMembersCustomization()); - base.Customize(fixture); - fixture.Customizations.Add(new EfRepositoryListBuilder()); - fixture.Customizations.Add(new EfRepositoryListBuilder()); - fixture.Customizations.Add(new EfRepositoryListBuilder()); + public override void Customize(IFixture fixture) + { + fixture.Customizations.Add(new IgnoreVirtualMembersCustomization()); + base.Customize(fixture); + fixture.Customizations.Add(new EfRepositoryListBuilder()); + fixture.Customizations.Add(new EfRepositoryListBuilder()); + fixture.Customizations.Add(new EfRepositoryListBuilder()); + } + } + + internal class EfUserAutoDataAttribute : CustomAutoDataAttribute + { + public EfUserAutoDataAttribute() : base(new SutProviderCustomization(), new EfUser()) + { } + } + + internal class InlineEfUserAutoDataAttribute : InlineCustomAutoDataAttribute + { + public InlineEfUserAutoDataAttribute(params object[] values) : base(new[] { typeof(SutProviderCustomization), + typeof(EfUser) }, values) + { } } } - -internal class EfUserAutoDataAttribute : CustomAutoDataAttribute -{ - public EfUserAutoDataAttribute() : base(new SutProviderCustomization(), new EfUser()) - { } -} - -internal class InlineEfUserAutoDataAttribute : InlineCustomAutoDataAttribute -{ - public InlineEfUserAutoDataAttribute(params object[] values) : base(new[] { typeof(SutProviderCustomization), - typeof(EfUser) }, values) - { } -} diff --git a/test/Infrastructure.EFIntegration.Test/Helpers/DatabaseOptionsFactory.cs b/test/Infrastructure.EFIntegration.Test/Helpers/DatabaseOptionsFactory.cs index fbf0d98286..25ac5912b1 100644 --- a/test/Infrastructure.EFIntegration.Test/Helpers/DatabaseOptionsFactory.cs +++ b/test/Infrastructure.EFIntegration.Test/Helpers/DatabaseOptionsFactory.cs @@ -2,24 +2,25 @@ using Bit.Infrastructure.EntityFramework.Repositories; using Microsoft.EntityFrameworkCore; -namespace Bit.Infrastructure.EFIntegration.Test.Helpers; - -public static class DatabaseOptionsFactory +namespace Bit.Infrastructure.EFIntegration.Test.Helpers { - public static List> Options { get; } = new(); - - static DatabaseOptionsFactory() + public static class DatabaseOptionsFactory { - var globalSettings = GlobalSettingsFactory.GlobalSettings; - if (!string.IsNullOrWhiteSpace(GlobalSettingsFactory.GlobalSettings.PostgreSql?.ConnectionString)) + public static List> Options { get; } = new(); + + static DatabaseOptionsFactory() { - AppContext.SetSwitch("Npgsql.EnableLegacyTimestampBehavior", true); - Options.Add(new DbContextOptionsBuilder().UseNpgsql(globalSettings.PostgreSql.ConnectionString).Options); - } - if (!string.IsNullOrWhiteSpace(GlobalSettingsFactory.GlobalSettings.MySql?.ConnectionString)) - { - var mySqlConnectionString = globalSettings.MySql.ConnectionString; - Options.Add(new DbContextOptionsBuilder().UseMySql(mySqlConnectionString, ServerVersion.AutoDetect(mySqlConnectionString)).Options); + var globalSettings = GlobalSettingsFactory.GlobalSettings; + if (!string.IsNullOrWhiteSpace(GlobalSettingsFactory.GlobalSettings.PostgreSql?.ConnectionString)) + { + AppContext.SetSwitch("Npgsql.EnableLegacyTimestampBehavior", true); + Options.Add(new DbContextOptionsBuilder().UseNpgsql(globalSettings.PostgreSql.ConnectionString).Options); + } + if (!string.IsNullOrWhiteSpace(GlobalSettingsFactory.GlobalSettings.MySql?.ConnectionString)) + { + var mySqlConnectionString = globalSettings.MySql.ConnectionString; + Options.Add(new DbContextOptionsBuilder().UseMySql(mySqlConnectionString, ServerVersion.AutoDetect(mySqlConnectionString)).Options); + } } } } diff --git a/test/Infrastructure.EFIntegration.Test/Repositories/CipherRepositoryTests.cs b/test/Infrastructure.EFIntegration.Test/Repositories/CipherRepositoryTests.cs index 9b70bffe76..21e9f4ee17 100644 --- a/test/Infrastructure.EFIntegration.Test/Repositories/CipherRepositoryTests.cs +++ b/test/Infrastructure.EFIntegration.Test/Repositories/CipherRepositoryTests.cs @@ -9,183 +9,184 @@ using Xunit; using EfRepo = Bit.Infrastructure.EntityFramework.Repositories; using SqlRepo = Bit.Infrastructure.Dapper.Repositories; -namespace Bit.Infrastructure.EFIntegration.Test.Repositories; - -public class CipherRepositoryTests +namespace Bit.Infrastructure.EFIntegration.Test.Repositories { - [Theory(Skip = "Run ad-hoc"), EfUserCipherAutoData] - public async void RefreshDb(List suts) + public class CipherRepositoryTests { - foreach (var sut in suts) + [Theory(Skip = "Run ad-hoc"), EfUserCipherAutoData] + public async void RefreshDb(List suts) { - await sut.RefreshDb(); - } - } - - [CiSkippedTheory, EfUserCipherAutoData, EfOrganizationCipherAutoData] - public async void CreateAsync_Works_DataMatches(Cipher cipher, User user, Organization org, - CipherCompare equalityComparer, List suts, List efUserRepos, - List efOrgRepos, SqlRepo.CipherRepository sqlCipherRepo, - SqlRepo.UserRepository sqlUserRepo, SqlRepo.OrganizationRepository sqlOrgRepo) - { - var savedCiphers = new List(); - foreach (var sut in suts) - { - var i = suts.IndexOf(sut); - - var efUser = await efUserRepos[i].CreateAsync(user); - sut.ClearChangeTracking(); - cipher.UserId = efUser.Id; - - if (cipher.OrganizationId.HasValue) + foreach (var sut in suts) { - var efOrg = await efOrgRepos[i].CreateAsync(org); + await sut.RefreshDb(); + } + } + + [CiSkippedTheory, EfUserCipherAutoData, EfOrganizationCipherAutoData] + public async void CreateAsync_Works_DataMatches(Cipher cipher, User user, Organization org, + CipherCompare equalityComparer, List suts, List efUserRepos, + List efOrgRepos, SqlRepo.CipherRepository sqlCipherRepo, + SqlRepo.UserRepository sqlUserRepo, SqlRepo.OrganizationRepository sqlOrgRepo) + { + var savedCiphers = new List(); + foreach (var sut in suts) + { + var i = suts.IndexOf(sut); + + var efUser = await efUserRepos[i].CreateAsync(user); sut.ClearChangeTracking(); - cipher.OrganizationId = efOrg.Id; - } + cipher.UserId = efUser.Id; - var postEfCipher = await sut.CreateAsync(cipher); - sut.ClearChangeTracking(); - - var savedCipher = await sut.GetByIdAsync(postEfCipher.Id); - savedCiphers.Add(savedCipher); - } - - var sqlUser = await sqlUserRepo.CreateAsync(user); - cipher.UserId = sqlUser.Id; - - if (cipher.OrganizationId.HasValue) - { - var sqlOrg = await sqlOrgRepo.CreateAsync(org); - cipher.OrganizationId = sqlOrg.Id; - } - - var sqlCipher = await sqlCipherRepo.CreateAsync(cipher); - var savedSqlCipher = await sqlCipherRepo.GetByIdAsync(sqlCipher.Id); - savedCiphers.Add(savedSqlCipher); - - var distinctItems = savedCiphers.Distinct(equalityComparer); - Assert.True(!distinctItems.Skip(1).Any()); - } - - [CiSkippedTheory, EfUserCipherAutoData] - public async void CreateAsync_BumpsUserAccountRevisionDate(Cipher cipher, User user, List suts, List efUserRepos) - { - var bumpedUsers = new List(); - foreach (var sut in suts) - { - var i = suts.IndexOf(sut); - - var efUser = await efUserRepos[i].CreateAsync(user); - efUserRepos[i].ClearChangeTracking(); - cipher.UserId = efUser.Id; - cipher.OrganizationId = null; - - var postEfCipher = await sut.CreateAsync(cipher); - sut.ClearChangeTracking(); - - var bumpedUser = await efUserRepos[i].GetByIdAsync(efUser.Id); - bumpedUsers.Add(bumpedUser); - } - - Assert.True(bumpedUsers.All(u => u.AccountRevisionDate.ToShortDateString() == DateTime.UtcNow.ToShortDateString())); - } - - [CiSkippedTheory, EfOrganizationCipherAutoData] - public async void CreateAsync_BumpsOrgUserAccountRevisionDates(Cipher cipher, List users, - List orgUsers, Collection collection, Organization org, List suts, List efUserRepos, List efOrgRepos, - List efOrgUserRepos, List efCollectionRepos) - { - var savedCiphers = new List(); - foreach (var sut in suts) - { - var i = suts.IndexOf(sut); - - var efUsers = await efUserRepos[i].CreateMany(users); - efUserRepos[i].ClearChangeTracking(); - var efOrg = await efOrgRepos[i].CreateAsync(org); - efOrgRepos[i].ClearChangeTracking(); - - cipher.OrganizationId = efOrg.Id; - - collection.OrganizationId = efOrg.Id; - var efCollection = await efCollectionRepos[i].CreateAsync(collection); - efCollectionRepos[i].ClearChangeTracking(); - - IEnumerable[] lists = { efUsers, orgUsers }; - var maxOrgUsers = lists.Min(l => l.Count()); - - orgUsers = orgUsers.Take(maxOrgUsers).ToList(); - efUsers = efUsers.Take(maxOrgUsers).ToList(); - - for (var j = 0; j < maxOrgUsers; j++) - { - orgUsers[j].OrganizationId = efOrg.Id; - orgUsers[j].UserId = efUsers[j].Id; - } - - orgUsers = await efOrgUserRepos[i].CreateMany(orgUsers); - - var selectionReadOnlyList = new List(); - orgUsers.ForEach(ou => selectionReadOnlyList.Add(new SelectionReadOnly() { Id = ou.Id })); - - await efCollectionRepos[i].UpdateUsersAsync(efCollection.Id, selectionReadOnlyList); - efCollectionRepos[i].ClearChangeTracking(); - - foreach (var ou in orgUsers) - { - var collectionUser = new CollectionUser() + if (cipher.OrganizationId.HasValue) { - CollectionId = efCollection.Id, - OrganizationUserId = ou.Id - }; + var efOrg = await efOrgRepos[i].CreateAsync(org); + sut.ClearChangeTracking(); + cipher.OrganizationId = efOrg.Id; + } + + var postEfCipher = await sut.CreateAsync(cipher); + sut.ClearChangeTracking(); + + var savedCipher = await sut.GetByIdAsync(postEfCipher.Id); + savedCiphers.Add(savedCipher); } - cipher.UserId = null; - var postEfCipher = await sut.CreateAsync(cipher); - sut.ClearChangeTracking(); - - var query = new UserBumpAccountRevisionDateByCipherIdQuery(cipher); - var modifiedUsers = await sut.Run(query).ToListAsync(); - Assert.True(modifiedUsers - .All(u => u.AccountRevisionDate.ToShortDateString() == - DateTime.UtcNow.ToShortDateString())); - } - } - - [CiSkippedTheory, EfUserCipherAutoData, EfOrganizationCipherAutoData] - public async void DeleteAsync_CipherIsDeleted( - Cipher cipher, - User user, - Organization org, - List suts, - List efUserRepos, - List efOrgRepos - ) - { - foreach (var sut in suts) - { - var i = suts.IndexOf(sut); - - var postEfOrg = await efOrgRepos[i].CreateAsync(org); - efOrgRepos[i].ClearChangeTracking(); - var postEfUser = await efUserRepos[i].CreateAsync(user); - efUserRepos[i].ClearChangeTracking(); + var sqlUser = await sqlUserRepo.CreateAsync(user); + cipher.UserId = sqlUser.Id; if (cipher.OrganizationId.HasValue) { - cipher.OrganizationId = postEfOrg.Id; + var sqlOrg = await sqlOrgRepo.CreateAsync(org); + cipher.OrganizationId = sqlOrg.Id; } - cipher.UserId = postEfUser.Id; - await sut.CreateAsync(cipher); - sut.ClearChangeTracking(); + var sqlCipher = await sqlCipherRepo.CreateAsync(cipher); + var savedSqlCipher = await sqlCipherRepo.GetByIdAsync(sqlCipher.Id); + savedCiphers.Add(savedSqlCipher); - await sut.DeleteAsync(cipher); - sut.ClearChangeTracking(); + var distinctItems = savedCiphers.Distinct(equalityComparer); + Assert.True(!distinctItems.Skip(1).Any()); + } - var savedCipher = await sut.GetByIdAsync(cipher.Id); - Assert.True(savedCipher == null); + [CiSkippedTheory, EfUserCipherAutoData] + public async void CreateAsync_BumpsUserAccountRevisionDate(Cipher cipher, User user, List suts, List efUserRepos) + { + var bumpedUsers = new List(); + foreach (var sut in suts) + { + var i = suts.IndexOf(sut); + + var efUser = await efUserRepos[i].CreateAsync(user); + efUserRepos[i].ClearChangeTracking(); + cipher.UserId = efUser.Id; + cipher.OrganizationId = null; + + var postEfCipher = await sut.CreateAsync(cipher); + sut.ClearChangeTracking(); + + var bumpedUser = await efUserRepos[i].GetByIdAsync(efUser.Id); + bumpedUsers.Add(bumpedUser); + } + + Assert.True(bumpedUsers.All(u => u.AccountRevisionDate.ToShortDateString() == DateTime.UtcNow.ToShortDateString())); + } + + [CiSkippedTheory, EfOrganizationCipherAutoData] + public async void CreateAsync_BumpsOrgUserAccountRevisionDates(Cipher cipher, List users, + List orgUsers, Collection collection, Organization org, List suts, List efUserRepos, List efOrgRepos, + List efOrgUserRepos, List efCollectionRepos) + { + var savedCiphers = new List(); + foreach (var sut in suts) + { + var i = suts.IndexOf(sut); + + var efUsers = await efUserRepos[i].CreateMany(users); + efUserRepos[i].ClearChangeTracking(); + var efOrg = await efOrgRepos[i].CreateAsync(org); + efOrgRepos[i].ClearChangeTracking(); + + cipher.OrganizationId = efOrg.Id; + + collection.OrganizationId = efOrg.Id; + var efCollection = await efCollectionRepos[i].CreateAsync(collection); + efCollectionRepos[i].ClearChangeTracking(); + + IEnumerable[] lists = { efUsers, orgUsers }; + var maxOrgUsers = lists.Min(l => l.Count()); + + orgUsers = orgUsers.Take(maxOrgUsers).ToList(); + efUsers = efUsers.Take(maxOrgUsers).ToList(); + + for (var j = 0; j < maxOrgUsers; j++) + { + orgUsers[j].OrganizationId = efOrg.Id; + orgUsers[j].UserId = efUsers[j].Id; + } + + orgUsers = await efOrgUserRepos[i].CreateMany(orgUsers); + + var selectionReadOnlyList = new List(); + orgUsers.ForEach(ou => selectionReadOnlyList.Add(new SelectionReadOnly() { Id = ou.Id })); + + await efCollectionRepos[i].UpdateUsersAsync(efCollection.Id, selectionReadOnlyList); + efCollectionRepos[i].ClearChangeTracking(); + + foreach (var ou in orgUsers) + { + var collectionUser = new CollectionUser() + { + CollectionId = efCollection.Id, + OrganizationUserId = ou.Id + }; + } + + cipher.UserId = null; + var postEfCipher = await sut.CreateAsync(cipher); + sut.ClearChangeTracking(); + + var query = new UserBumpAccountRevisionDateByCipherIdQuery(cipher); + var modifiedUsers = await sut.Run(query).ToListAsync(); + Assert.True(modifiedUsers + .All(u => u.AccountRevisionDate.ToShortDateString() == + DateTime.UtcNow.ToShortDateString())); + } + } + + [CiSkippedTheory, EfUserCipherAutoData, EfOrganizationCipherAutoData] + public async void DeleteAsync_CipherIsDeleted( + Cipher cipher, + User user, + Organization org, + List suts, + List efUserRepos, + List efOrgRepos + ) + { + foreach (var sut in suts) + { + var i = suts.IndexOf(sut); + + var postEfOrg = await efOrgRepos[i].CreateAsync(org); + efOrgRepos[i].ClearChangeTracking(); + var postEfUser = await efUserRepos[i].CreateAsync(user); + efUserRepos[i].ClearChangeTracking(); + + if (cipher.OrganizationId.HasValue) + { + cipher.OrganizationId = postEfOrg.Id; + } + cipher.UserId = postEfUser.Id; + + await sut.CreateAsync(cipher); + sut.ClearChangeTracking(); + + await sut.DeleteAsync(cipher); + sut.ClearChangeTracking(); + + var savedCipher = await sut.GetByIdAsync(cipher.Id); + Assert.True(savedCipher == null); + } } } } diff --git a/test/Infrastructure.EFIntegration.Test/Repositories/CollectionRepository.cs b/test/Infrastructure.EFIntegration.Test/Repositories/CollectionRepository.cs index 1fb20c684e..ed2bcf74be 100644 --- a/test/Infrastructure.EFIntegration.Test/Repositories/CollectionRepository.cs +++ b/test/Infrastructure.EFIntegration.Test/Repositories/CollectionRepository.cs @@ -6,44 +6,45 @@ using Xunit; using EfRepo = Bit.Infrastructure.EntityFramework.Repositories; using SqlRepo = Bit.Infrastructure.Dapper.Repositories; -namespace Bit.Infrastructure.EFIntegration.Test.Repositories; - -public class CollectionRepositoryTests +namespace Bit.Infrastructure.EFIntegration.Test.Repositories { - [CiSkippedTheory, EfCollectionAutoData] - public async void CreateAsync_Works_DataMatches( - Collection collection, - Organization organization, - CollectionCompare equalityComparer, - List suts, - List efOrganizationRepos, - SqlRepo.CollectionRepository sqlCollectionRepo, - SqlRepo.OrganizationRepository sqlOrganizationRepo - ) + public class CollectionRepositoryTests { - var savedCollections = new List(); - foreach (var sut in suts) + [CiSkippedTheory, EfCollectionAutoData] + public async void CreateAsync_Works_DataMatches( + Collection collection, + Organization organization, + CollectionCompare equalityComparer, + List suts, + List efOrganizationRepos, + SqlRepo.CollectionRepository sqlCollectionRepo, + SqlRepo.OrganizationRepository sqlOrganizationRepo + ) { - var i = suts.IndexOf(sut); - var efOrganization = await efOrganizationRepos[i].CreateAsync(organization); - sut.ClearChangeTracking(); + var savedCollections = new List(); + foreach (var sut in suts) + { + var i = suts.IndexOf(sut); + var efOrganization = await efOrganizationRepos[i].CreateAsync(organization); + sut.ClearChangeTracking(); - collection.OrganizationId = efOrganization.Id; - var postEfCollection = await sut.CreateAsync(collection); - sut.ClearChangeTracking(); + collection.OrganizationId = efOrganization.Id; + var postEfCollection = await sut.CreateAsync(collection); + sut.ClearChangeTracking(); - var savedCollection = await sut.GetByIdAsync(postEfCollection.Id); - savedCollections.Add(savedCollection); + var savedCollection = await sut.GetByIdAsync(postEfCollection.Id); + savedCollections.Add(savedCollection); + } + + var sqlOrganization = await sqlOrganizationRepo.CreateAsync(organization); + collection.OrganizationId = sqlOrganization.Id; + + var sqlCollection = await sqlCollectionRepo.CreateAsync(collection); + var savedSqlCollection = await sqlCollectionRepo.GetByIdAsync(sqlCollection.Id); + savedCollections.Add(savedSqlCollection); + + var distinctItems = savedCollections.Distinct(equalityComparer); + Assert.True(!distinctItems.Skip(1).Any()); } - - var sqlOrganization = await sqlOrganizationRepo.CreateAsync(organization); - collection.OrganizationId = sqlOrganization.Id; - - var sqlCollection = await sqlCollectionRepo.CreateAsync(collection); - var savedSqlCollection = await sqlCollectionRepo.GetByIdAsync(sqlCollection.Id); - savedCollections.Add(savedSqlCollection); - - var distinctItems = savedCollections.Distinct(equalityComparer); - Assert.True(!distinctItems.Skip(1).Any()); } } diff --git a/test/Infrastructure.EFIntegration.Test/Repositories/DeviceRepositoryTests.cs b/test/Infrastructure.EFIntegration.Test/Repositories/DeviceRepositoryTests.cs index fc1f5c8b31..4c5de177c1 100644 --- a/test/Infrastructure.EFIntegration.Test/Repositories/DeviceRepositoryTests.cs +++ b/test/Infrastructure.EFIntegration.Test/Repositories/DeviceRepositoryTests.cs @@ -6,41 +6,42 @@ using Xunit; using EfRepo = Bit.Infrastructure.EntityFramework.Repositories; using SqlRepo = Bit.Infrastructure.Dapper.Repositories; -namespace Bit.Infrastructure.EFIntegration.Test.Repositories; - -public class DeviceRepositoryTests +namespace Bit.Infrastructure.EFIntegration.Test.Repositories { - [CiSkippedTheory, EfDeviceAutoData] - public async void CreateAsync_Works_DataMatches(Device device, User user, - DeviceCompare equalityComparer, List suts, - List efUserRepos, SqlRepo.DeviceRepository sqlDeviceRepo, - SqlRepo.UserRepository sqlUserRepo) + public class DeviceRepositoryTests { - var savedDevices = new List(); - foreach (var sut in suts) + [CiSkippedTheory, EfDeviceAutoData] + public async void CreateAsync_Works_DataMatches(Device device, User user, + DeviceCompare equalityComparer, List suts, + List efUserRepos, SqlRepo.DeviceRepository sqlDeviceRepo, + SqlRepo.UserRepository sqlUserRepo) { - var i = suts.IndexOf(sut); + var savedDevices = new List(); + foreach (var sut in suts) + { + var i = suts.IndexOf(sut); - var efUser = await efUserRepos[i].CreateAsync(user); - device.UserId = efUser.Id; - sut.ClearChangeTracking(); + var efUser = await efUserRepos[i].CreateAsync(user); + device.UserId = efUser.Id; + sut.ClearChangeTracking(); - var postEfDevice = await sut.CreateAsync(device); - sut.ClearChangeTracking(); + var postEfDevice = await sut.CreateAsync(device); + sut.ClearChangeTracking(); - var savedDevice = await sut.GetByIdAsync(postEfDevice.Id); - savedDevices.Add(savedDevice); + var savedDevice = await sut.GetByIdAsync(postEfDevice.Id); + savedDevices.Add(savedDevice); + } + + var sqlUser = await sqlUserRepo.CreateAsync(user); + device.UserId = sqlUser.Id; + + var sqlDevice = await sqlDeviceRepo.CreateAsync(device); + var savedSqlDevice = await sqlDeviceRepo.GetByIdAsync(sqlDevice.Id); + savedDevices.Add(savedSqlDevice); + + var distinctItems = savedDevices.Distinct(equalityComparer); + Assert.True(!distinctItems.Skip(1).Any()); } - var sqlUser = await sqlUserRepo.CreateAsync(user); - device.UserId = sqlUser.Id; - - var sqlDevice = await sqlDeviceRepo.CreateAsync(device); - var savedSqlDevice = await sqlDeviceRepo.GetByIdAsync(sqlDevice.Id); - savedDevices.Add(savedSqlDevice); - - var distinctItems = savedDevices.Distinct(equalityComparer); - Assert.True(!distinctItems.Skip(1).Any()); } - } diff --git a/test/Infrastructure.EFIntegration.Test/Repositories/EmergencyAccessRepositoryTests.cs b/test/Infrastructure.EFIntegration.Test/Repositories/EmergencyAccessRepositoryTests.cs index d014d463a5..1bb31d4762 100644 --- a/test/Infrastructure.EFIntegration.Test/Repositories/EmergencyAccessRepositoryTests.cs +++ b/test/Infrastructure.EFIntegration.Test/Repositories/EmergencyAccessRepositoryTests.cs @@ -6,53 +6,54 @@ using Xunit; using EfRepo = Bit.Infrastructure.EntityFramework.Repositories; using SqlRepo = Bit.Infrastructure.Dapper.Repositories; -namespace Bit.Infrastructure.EFIntegration.Test.Repositories; - -public class EmergencyAccessRepositoryTests +namespace Bit.Infrastructure.EFIntegration.Test.Repositories { - [CiSkippedTheory, EfEmergencyAccessAutoData] - public async void CreateAsync_Works_DataMatches( - EmergencyAccess emergencyAccess, - List users, - EmergencyAccessCompare equalityComparer, - List suts, - List efUserRepos, - SqlRepo.EmergencyAccessRepository sqlEmergencyAccessRepo, - SqlRepo.UserRepository sqlUserRepo - ) + public class EmergencyAccessRepositoryTests { - var savedEmergencyAccesss = new List(); - foreach (var sut in suts) + [CiSkippedTheory, EfEmergencyAccessAutoData] + public async void CreateAsync_Works_DataMatches( + EmergencyAccess emergencyAccess, + List users, + EmergencyAccessCompare equalityComparer, + List suts, + List efUserRepos, + SqlRepo.EmergencyAccessRepository sqlEmergencyAccessRepo, + SqlRepo.UserRepository sqlUserRepo + ) { - var i = suts.IndexOf(sut); + var savedEmergencyAccesss = new List(); + foreach (var sut in suts) + { + var i = suts.IndexOf(sut); + + for (int j = 0; j < users.Count; j++) + { + users[j] = await efUserRepos[i].CreateAsync(users[j]); + } + sut.ClearChangeTracking(); + + emergencyAccess.GrantorId = users[0].Id; + emergencyAccess.GranteeId = users[0].Id; + var postEfEmergencyAccess = await sut.CreateAsync(emergencyAccess); + sut.ClearChangeTracking(); + + var savedEmergencyAccess = await sut.GetByIdAsync(postEfEmergencyAccess.Id); + savedEmergencyAccesss.Add(savedEmergencyAccess); + } for (int j = 0; j < users.Count; j++) { - users[j] = await efUserRepos[i].CreateAsync(users[j]); + users[j] = await sqlUserRepo.CreateAsync(users[j]); } - sut.ClearChangeTracking(); emergencyAccess.GrantorId = users[0].Id; emergencyAccess.GranteeId = users[0].Id; - var postEfEmergencyAccess = await sut.CreateAsync(emergencyAccess); - sut.ClearChangeTracking(); + var sqlEmergencyAccess = await sqlEmergencyAccessRepo.CreateAsync(emergencyAccess); + var savedSqlEmergencyAccess = await sqlEmergencyAccessRepo.GetByIdAsync(sqlEmergencyAccess.Id); + savedEmergencyAccesss.Add(savedSqlEmergencyAccess); - var savedEmergencyAccess = await sut.GetByIdAsync(postEfEmergencyAccess.Id); - savedEmergencyAccesss.Add(savedEmergencyAccess); + var distinctItems = savedEmergencyAccesss.Distinct(equalityComparer); + Assert.True(!distinctItems.Skip(1).Any()); } - - for (int j = 0; j < users.Count; j++) - { - users[j] = await sqlUserRepo.CreateAsync(users[j]); - } - - emergencyAccess.GrantorId = users[0].Id; - emergencyAccess.GranteeId = users[0].Id; - var sqlEmergencyAccess = await sqlEmergencyAccessRepo.CreateAsync(emergencyAccess); - var savedSqlEmergencyAccess = await sqlEmergencyAccessRepo.GetByIdAsync(sqlEmergencyAccess.Id); - savedEmergencyAccesss.Add(savedSqlEmergencyAccess); - - var distinctItems = savedEmergencyAccesss.Distinct(equalityComparer); - Assert.True(!distinctItems.Skip(1).Any()); } } diff --git a/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/CipherCompare.cs b/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/CipherCompare.cs index 230b51dd69..f5be069bdd 100644 --- a/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/CipherCompare.cs +++ b/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/CipherCompare.cs @@ -1,20 +1,21 @@ using System.Diagnostics.CodeAnalysis; using Bit.Core.Entities; -namespace Bit.Infrastructure.EFIntegration.Test.Repositories.EqualityComparers; - -public class CipherCompare : IEqualityComparer +namespace Bit.Infrastructure.EFIntegration.Test.Repositories.EqualityComparers { - public bool Equals(Cipher x, Cipher y) + public class CipherCompare : IEqualityComparer { - return x.Type == y.Type && - x.Data == y.Data && - x.Favorites == y.Favorites && - x.Attachments == y.Attachments; - } + public bool Equals(Cipher x, Cipher y) + { + return x.Type == y.Type && + x.Data == y.Data && + x.Favorites == y.Favorites && + x.Attachments == y.Attachments; + } - public int GetHashCode([DisallowNull] Cipher obj) - { - return base.GetHashCode(); + public int GetHashCode([DisallowNull] Cipher obj) + { + return base.GetHashCode(); + } } } diff --git a/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/CollectionCompare.cs b/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/CollectionCompare.cs index 56cb0acf7a..a7cef8f6d8 100644 --- a/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/CollectionCompare.cs +++ b/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/CollectionCompare.cs @@ -1,18 +1,19 @@ using System.Diagnostics.CodeAnalysis; using Bit.Core.Entities; -namespace Bit.Infrastructure.EFIntegration.Test.Repositories.EqualityComparers; - -public class CollectionCompare : IEqualityComparer +namespace Bit.Infrastructure.EFIntegration.Test.Repositories.EqualityComparers { - public bool Equals(Collection x, Collection y) + public class CollectionCompare : IEqualityComparer { - return x.Name == y.Name && - x.ExternalId == y.ExternalId; - } + public bool Equals(Collection x, Collection y) + { + return x.Name == y.Name && + x.ExternalId == y.ExternalId; + } - public int GetHashCode([DisallowNull] Collection obj) - { - return base.GetHashCode(); + public int GetHashCode([DisallowNull] Collection obj) + { + return base.GetHashCode(); + } } } diff --git a/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/DeviceCompare.cs b/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/DeviceCompare.cs index 086199b380..ac8a24d203 100644 --- a/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/DeviceCompare.cs +++ b/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/DeviceCompare.cs @@ -1,20 +1,21 @@ using System.Diagnostics.CodeAnalysis; using Bit.Core.Entities; -namespace Bit.Infrastructure.EFIntegration.Test.Repositories.EqualityComparers; - -public class DeviceCompare : IEqualityComparer +namespace Bit.Infrastructure.EFIntegration.Test.Repositories.EqualityComparers { - public bool Equals(Device x, Device y) + public class DeviceCompare : IEqualityComparer { - return x.Name == y.Name && - x.Type == y.Type && - x.Identifier == y.Identifier && - x.PushToken == y.PushToken; - } + public bool Equals(Device x, Device y) + { + return x.Name == y.Name && + x.Type == y.Type && + x.Identifier == y.Identifier && + x.PushToken == y.PushToken; + } - public int GetHashCode([DisallowNull] Device obj) - { - return base.GetHashCode(); + public int GetHashCode([DisallowNull] Device obj) + { + return base.GetHashCode(); + } } } diff --git a/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/EmergencyAccessCompare.cs b/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/EmergencyAccessCompare.cs index eb182d6e9a..bc2592f439 100644 --- a/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/EmergencyAccessCompare.cs +++ b/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/EmergencyAccessCompare.cs @@ -1,23 +1,24 @@ using System.Diagnostics.CodeAnalysis; using Bit.Core.Entities; -namespace Bit.Infrastructure.EFIntegration.Test.Repositories.EqualityComparers; - -public class EmergencyAccessCompare : IEqualityComparer +namespace Bit.Infrastructure.EFIntegration.Test.Repositories.EqualityComparers { - public bool Equals(EmergencyAccess x, EmergencyAccess y) + public class EmergencyAccessCompare : IEqualityComparer { - return x.Email == y.Email && - x.KeyEncrypted == y.KeyEncrypted && - x.Type == y.Type && - x.Status == y.Status && - x.WaitTimeDays == y.WaitTimeDays && - x.RecoveryInitiatedDate == y.RecoveryInitiatedDate && - x.LastNotificationDate == y.LastNotificationDate; - } + public bool Equals(EmergencyAccess x, EmergencyAccess y) + { + return x.Email == y.Email && + x.KeyEncrypted == y.KeyEncrypted && + x.Type == y.Type && + x.Status == y.Status && + x.WaitTimeDays == y.WaitTimeDays && + x.RecoveryInitiatedDate == y.RecoveryInitiatedDate && + x.LastNotificationDate == y.LastNotificationDate; + } - public int GetHashCode([DisallowNull] EmergencyAccess obj) - { - return base.GetHashCode(); + public int GetHashCode([DisallowNull] EmergencyAccess obj) + { + return base.GetHashCode(); + } } } diff --git a/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/EventCompare.cs b/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/EventCompare.cs index e414f7c253..a42f8cb5ed 100644 --- a/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/EventCompare.cs +++ b/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/EventCompare.cs @@ -1,19 +1,20 @@ using System.Diagnostics.CodeAnalysis; using Bit.Core.Entities; -namespace Bit.Infrastructure.EFIntegration.Test.Repositories.EqualityComparers; - -public class EventCompare : IEqualityComparer +namespace Bit.Infrastructure.EFIntegration.Test.Repositories.EqualityComparers { - public bool Equals(Event x, Event y) + public class EventCompare : IEqualityComparer { - return x.Date.ToShortDateString() == y.Date.ToShortDateString() && - x.Type == y.Type && - x.IpAddress == y.IpAddress; - } + public bool Equals(Event x, Event y) + { + return x.Date.ToShortDateString() == y.Date.ToShortDateString() && + x.Type == y.Type && + x.IpAddress == y.IpAddress; + } - public int GetHashCode([DisallowNull] Event obj) - { - return base.GetHashCode(); + public int GetHashCode([DisallowNull] Event obj) + { + return base.GetHashCode(); + } } } diff --git a/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/FolderCompare.cs b/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/FolderCompare.cs index 2bdb71385c..61e261f8a7 100644 --- a/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/FolderCompare.cs +++ b/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/FolderCompare.cs @@ -1,17 +1,18 @@ using System.Diagnostics.CodeAnalysis; using Bit.Core.Entities; -namespace Bit.Infrastructure.EFIntegration.Test.Repositories.EqualityComparers; - -public class FolderCompare : IEqualityComparer +namespace Bit.Infrastructure.EFIntegration.Test.Repositories.EqualityComparers { - public bool Equals(Folder x, Folder y) + public class FolderCompare : IEqualityComparer { - return x.Name == y.Name; - } + public bool Equals(Folder x, Folder y) + { + return x.Name == y.Name; + } - public int GetHashCode([DisallowNull] Folder obj) - { - return base.GetHashCode(); + public int GetHashCode([DisallowNull] Folder obj) + { + return base.GetHashCode(); + } } } diff --git a/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/GrantCompare.cs b/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/GrantCompare.cs index 7621577166..978d4d62dd 100644 --- a/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/GrantCompare.cs +++ b/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/GrantCompare.cs @@ -1,24 +1,25 @@ using System.Diagnostics.CodeAnalysis; using Bit.Core.Entities; -namespace Bit.Infrastructure.EFIntegration.Test.Repositories.EqualityComparers; - -public class GrantCompare : IEqualityComparer +namespace Bit.Infrastructure.EFIntegration.Test.Repositories.EqualityComparers { - public bool Equals(Grant x, Grant y) + public class GrantCompare : IEqualityComparer { - return x.Key == y.Key && - x.Type == y.Type && - x.SubjectId == y.SubjectId && - x.ClientId == y.ClientId && - x.Description == y.Description && - x.ExpirationDate == y.ExpirationDate && - x.ConsumedDate == y.ConsumedDate && - x.Data == y.Data; - } + public bool Equals(Grant x, Grant y) + { + return x.Key == y.Key && + x.Type == y.Type && + x.SubjectId == y.SubjectId && + x.ClientId == y.ClientId && + x.Description == y.Description && + x.ExpirationDate == y.ExpirationDate && + x.ConsumedDate == y.ConsumedDate && + x.Data == y.Data; + } - public int GetHashCode([DisallowNull] Grant obj) - { - return base.GetHashCode(); + public int GetHashCode([DisallowNull] Grant obj) + { + return base.GetHashCode(); + } } } diff --git a/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/GroupCompare.cs b/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/GroupCompare.cs index dcb0be2ff1..aa2e1ae898 100644 --- a/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/GroupCompare.cs +++ b/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/GroupCompare.cs @@ -1,19 +1,20 @@ using System.Diagnostics.CodeAnalysis; using Bit.Core.Entities; -namespace Bit.Infrastructure.EFIntegration.Test.Repositories.EqualityComparers; - -public class GroupCompare : IEqualityComparer +namespace Bit.Infrastructure.EFIntegration.Test.Repositories.EqualityComparers { - public bool Equals(Group x, Group y) + public class GroupCompare : IEqualityComparer { - return x.Name == y.Name && - x.AccessAll == y.AccessAll && - x.ExternalId == y.ExternalId; - } + public bool Equals(Group x, Group y) + { + return x.Name == y.Name && + x.AccessAll == y.AccessAll && + x.ExternalId == y.ExternalId; + } - public int GetHashCode([DisallowNull] Group obj) - { - return base.GetHashCode(); + public int GetHashCode([DisallowNull] Group obj) + { + return base.GetHashCode(); + } } } diff --git a/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/InstallationCompare.cs b/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/InstallationCompare.cs index 7794785b31..38a92daa37 100644 --- a/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/InstallationCompare.cs +++ b/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/InstallationCompare.cs @@ -1,19 +1,20 @@ using System.Diagnostics.CodeAnalysis; using Bit.Core.Entities; -namespace Bit.Infrastructure.EFIntegration.Test.Repositories.EqualityComparers; - -public class InstallationCompare : IEqualityComparer +namespace Bit.Infrastructure.EFIntegration.Test.Repositories.EqualityComparers { - public bool Equals(Installation x, Installation y) + public class InstallationCompare : IEqualityComparer { - return x.Email == y.Email && - x.Key == y.Key && - x.Enabled == y.Enabled; - } + public bool Equals(Installation x, Installation y) + { + return x.Email == y.Email && + x.Key == y.Key && + x.Enabled == y.Enabled; + } - public int GetHashCode([DisallowNull] Installation obj) - { - return base.GetHashCode(); + public int GetHashCode([DisallowNull] Installation obj) + { + return base.GetHashCode(); + } } } diff --git a/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/OrganizationCompare.cs b/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/OrganizationCompare.cs index f1879937af..a8f32643eb 100644 --- a/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/OrganizationCompare.cs +++ b/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/OrganizationCompare.cs @@ -1,53 +1,54 @@ using System.Diagnostics.CodeAnalysis; using Bit.Core.Entities; -namespace Bit.Infrastructure.EFIntegration.Test.Repositories.EqualityComparers; - -public class OrganizationCompare : IEqualityComparer +namespace Bit.Infrastructure.EFIntegration.Test.Repositories.EqualityComparers { - public bool Equals(Organization x, Organization y) + public class OrganizationCompare : IEqualityComparer { - var a = x.ExpirationDate.ToString(); - var b = y.ExpirationDate.ToString(); - return x.Identifier.Equals(y.Identifier) && - x.Name.Equals(y.Name) && - x.BusinessName.Equals(y.BusinessName) && - x.BusinessAddress1.Equals(y.BusinessAddress1) && - x.BusinessAddress2.Equals(y.BusinessAddress2) && - x.BusinessAddress3.Equals(y.BusinessAddress3) && - x.BusinessCountry.Equals(y.BusinessCountry) && - x.BusinessTaxNumber.Equals(y.BusinessTaxNumber) && - x.BillingEmail.Equals(y.BillingEmail) && - x.Plan.Equals(y.Plan) && - x.PlanType.Equals(y.PlanType) && - x.Seats.Equals(y.Seats) && - x.MaxCollections.Equals(y.MaxCollections) && - x.UsePolicies.Equals(y.UsePolicies) && - x.UseSso.Equals(y.UseSso) && - x.UseKeyConnector.Equals(y.UseKeyConnector) && - x.UseScim.Equals(y.UseScim) && - x.UseGroups.Equals(y.UseGroups) && - x.UseDirectory.Equals(y.UseDirectory) && - x.UseEvents.Equals(y.UseEvents) && - x.UseTotp.Equals(y.UseTotp) && - x.Use2fa.Equals(y.Use2fa) && - x.UseApi.Equals(y.UseApi) && - x.SelfHost.Equals(y.SelfHost) && - x.UsersGetPremium.Equals(y.UsersGetPremium) && - x.Storage.Equals(y.Storage) && - x.MaxStorageGb.Equals(y.MaxStorageGb) && - x.Gateway.Equals(y.Gateway) && - x.GatewayCustomerId.Equals(y.GatewayCustomerId) && - x.GatewaySubscriptionId.Equals(y.GatewaySubscriptionId) && - x.ReferenceData.Equals(y.ReferenceData) && - x.Enabled.Equals(y.Enabled) && - x.LicenseKey.Equals(y.LicenseKey) && - x.TwoFactorProviders.Equals(y.TwoFactorProviders) && - x.ExpirationDate.ToString().Equals(y.ExpirationDate.ToString()); - } + public bool Equals(Organization x, Organization y) + { + var a = x.ExpirationDate.ToString(); + var b = y.ExpirationDate.ToString(); + return x.Identifier.Equals(y.Identifier) && + x.Name.Equals(y.Name) && + x.BusinessName.Equals(y.BusinessName) && + x.BusinessAddress1.Equals(y.BusinessAddress1) && + x.BusinessAddress2.Equals(y.BusinessAddress2) && + x.BusinessAddress3.Equals(y.BusinessAddress3) && + x.BusinessCountry.Equals(y.BusinessCountry) && + x.BusinessTaxNumber.Equals(y.BusinessTaxNumber) && + x.BillingEmail.Equals(y.BillingEmail) && + x.Plan.Equals(y.Plan) && + x.PlanType.Equals(y.PlanType) && + x.Seats.Equals(y.Seats) && + x.MaxCollections.Equals(y.MaxCollections) && + x.UsePolicies.Equals(y.UsePolicies) && + x.UseSso.Equals(y.UseSso) && + x.UseKeyConnector.Equals(y.UseKeyConnector) && + x.UseScim.Equals(y.UseScim) && + x.UseGroups.Equals(y.UseGroups) && + x.UseDirectory.Equals(y.UseDirectory) && + x.UseEvents.Equals(y.UseEvents) && + x.UseTotp.Equals(y.UseTotp) && + x.Use2fa.Equals(y.Use2fa) && + x.UseApi.Equals(y.UseApi) && + x.SelfHost.Equals(y.SelfHost) && + x.UsersGetPremium.Equals(y.UsersGetPremium) && + x.Storage.Equals(y.Storage) && + x.MaxStorageGb.Equals(y.MaxStorageGb) && + x.Gateway.Equals(y.Gateway) && + x.GatewayCustomerId.Equals(y.GatewayCustomerId) && + x.GatewaySubscriptionId.Equals(y.GatewaySubscriptionId) && + x.ReferenceData.Equals(y.ReferenceData) && + x.Enabled.Equals(y.Enabled) && + x.LicenseKey.Equals(y.LicenseKey) && + x.TwoFactorProviders.Equals(y.TwoFactorProviders) && + x.ExpirationDate.ToString().Equals(y.ExpirationDate.ToString()); + } - public int GetHashCode([DisallowNull] Organization obj) - { - return base.GetHashCode(); + public int GetHashCode([DisallowNull] Organization obj) + { + return base.GetHashCode(); + } } } diff --git a/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/OrganizationSponsorshipCompare.cs b/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/OrganizationSponsorshipCompare.cs index e17e765922..c90aaf065d 100644 --- a/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/OrganizationSponsorshipCompare.cs +++ b/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/OrganizationSponsorshipCompare.cs @@ -1,22 +1,23 @@ using System.Diagnostics.CodeAnalysis; using Bit.Core.Entities; -namespace Bit.Infrastructure.EFIntegration.Test.Repositories.EqualityComparers; - -public class OrganizationSponsorshipCompare : IEqualityComparer +namespace Bit.Infrastructure.EFIntegration.Test.Repositories.EqualityComparers { - public bool Equals(OrganizationSponsorship x, OrganizationSponsorship y) + public class OrganizationSponsorshipCompare : IEqualityComparer { - return x.SponsoringOrganizationId.Equals(y.SponsoringOrganizationId) && - x.SponsoringOrganizationUserId.Equals(y.SponsoringOrganizationUserId) && - x.SponsoredOrganizationId.Equals(y.SponsoredOrganizationId) && - x.OfferedToEmail.Equals(y.OfferedToEmail) && - x.ToDelete.Equals(y.ToDelete) && - x.ValidUntil.ToString().Equals(y.ValidUntil.ToString()); - } + public bool Equals(OrganizationSponsorship x, OrganizationSponsorship y) + { + return x.SponsoringOrganizationId.Equals(y.SponsoringOrganizationId) && + x.SponsoringOrganizationUserId.Equals(y.SponsoringOrganizationUserId) && + x.SponsoredOrganizationId.Equals(y.SponsoredOrganizationId) && + x.OfferedToEmail.Equals(y.OfferedToEmail) && + x.ToDelete.Equals(y.ToDelete) && + x.ValidUntil.ToString().Equals(y.ValidUntil.ToString()); + } - public int GetHashCode([DisallowNull] OrganizationSponsorship obj) - { - return base.GetHashCode(); + public int GetHashCode([DisallowNull] OrganizationSponsorship obj) + { + return base.GetHashCode(); + } } } diff --git a/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/OrganizationUserCompare.cs b/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/OrganizationUserCompare.cs index 6d947cc6c7..bb7895a2f6 100644 --- a/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/OrganizationUserCompare.cs +++ b/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/OrganizationUserCompare.cs @@ -1,22 +1,23 @@ using System.Diagnostics.CodeAnalysis; using Bit.Core.Entities; -namespace Bit.Infrastructure.EFIntegration.Test.Repositories.EqualityComparers; - -public class OrganizationUserCompare : IEqualityComparer +namespace Bit.Infrastructure.EFIntegration.Test.Repositories.EqualityComparers { - public bool Equals(OrganizationUser x, OrganizationUser y) + public class OrganizationUserCompare : IEqualityComparer { - return x.Email == y.Email && - x.Status == y.Status && - x.Type == y.Type && - x.AccessAll == y.AccessAll && - x.ExternalId == y.ExternalId && - x.Permissions == y.Permissions; - } + public bool Equals(OrganizationUser x, OrganizationUser y) + { + return x.Email == y.Email && + x.Status == y.Status && + x.Type == y.Type && + x.AccessAll == y.AccessAll && + x.ExternalId == y.ExternalId && + x.Permissions == y.Permissions; + } - public int GetHashCode([DisallowNull] OrganizationUser obj) - { - return base.GetHashCode(); + public int GetHashCode([DisallowNull] OrganizationUser obj) + { + return base.GetHashCode(); + } } } diff --git a/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/PolicyCompare.cs b/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/PolicyCompare.cs index f3bd7dc7a9..758675c5a8 100644 --- a/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/PolicyCompare.cs +++ b/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/PolicyCompare.cs @@ -1,28 +1,29 @@ using System.Diagnostics.CodeAnalysis; using Bit.Core.Entities; -namespace Bit.Infrastructure.EFIntegration.Test.Repositories.EqualityComparers; - -public class PolicyCompare : IEqualityComparer +namespace Bit.Infrastructure.EFIntegration.Test.Repositories.EqualityComparers { - public bool Equals(Policy x, Policy y) + public class PolicyCompare : IEqualityComparer { - return x.Type == y.Type && - x.Data == y.Data && - x.Enabled == y.Enabled; + public bool Equals(Policy x, Policy y) + { + return x.Type == y.Type && + x.Data == y.Data && + x.Enabled == y.Enabled; + } + + public int GetHashCode([DisallowNull] Policy obj) + { + return base.GetHashCode(); + } } - public int GetHashCode([DisallowNull] Policy obj) + public class PolicyCompareIncludingOrganization : PolicyCompare { - return base.GetHashCode(); - } -} - -public class PolicyCompareIncludingOrganization : PolicyCompare -{ - public new bool Equals(Policy x, Policy y) - { - return base.Equals(x, y) && - x.OrganizationId == y.OrganizationId; + public new bool Equals(Policy x, Policy y) + { + return base.Equals(x, y) && + x.OrganizationId == y.OrganizationId; + } } } diff --git a/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/SendCompare.cs b/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/SendCompare.cs index b4723051c6..7057997799 100644 --- a/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/SendCompare.cs +++ b/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/SendCompare.cs @@ -1,26 +1,27 @@ using System.Diagnostics.CodeAnalysis; using Bit.Core.Entities; -namespace Bit.Infrastructure.EFIntegration.Test.Repositories.EqualityComparers; - -public class SendCompare : IEqualityComparer +namespace Bit.Infrastructure.EFIntegration.Test.Repositories.EqualityComparers { - public bool Equals(Send x, Send y) + public class SendCompare : IEqualityComparer { - return x.Type == y.Type && - x.Data == y.Data && - x.Key == y.Key && - x.Password == y.Password && - x.MaxAccessCount == y.MaxAccessCount && - x.AccessCount == y.AccessCount && - x.ExpirationDate?.ToShortDateString() == y.ExpirationDate?.ToShortDateString() && - x.DeletionDate.ToShortDateString() == y.DeletionDate.ToShortDateString() && - x.Disabled == y.Disabled && - x.HideEmail == y.HideEmail; - } + public bool Equals(Send x, Send y) + { + return x.Type == y.Type && + x.Data == y.Data && + x.Key == y.Key && + x.Password == y.Password && + x.MaxAccessCount == y.MaxAccessCount && + x.AccessCount == y.AccessCount && + x.ExpirationDate?.ToShortDateString() == y.ExpirationDate?.ToShortDateString() && + x.DeletionDate.ToShortDateString() == y.DeletionDate.ToShortDateString() && + x.Disabled == y.Disabled && + x.HideEmail == y.HideEmail; + } - public int GetHashCode([DisallowNull] Send obj) - { - return base.GetHashCode(); + public int GetHashCode([DisallowNull] Send obj) + { + return base.GetHashCode(); + } } } diff --git a/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/SsoConfigCompare.cs b/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/SsoConfigCompare.cs index 766b8c6857..8d6accd86a 100644 --- a/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/SsoConfigCompare.cs +++ b/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/SsoConfigCompare.cs @@ -1,19 +1,20 @@ using System.Diagnostics.CodeAnalysis; using Bit.Core.Entities; -namespace Bit.Infrastructure.EFIntegration.Test.Repositories.EqualityComparers; - -public class SsoConfigCompare : IEqualityComparer +namespace Bit.Infrastructure.EFIntegration.Test.Repositories.EqualityComparers { - public bool Equals(SsoConfig x, SsoConfig y) + public class SsoConfigCompare : IEqualityComparer { - return x.Enabled == y.Enabled && - x.OrganizationId == y.OrganizationId && - x.Data == y.Data; - } + public bool Equals(SsoConfig x, SsoConfig y) + { + return x.Enabled == y.Enabled && + x.OrganizationId == y.OrganizationId && + x.Data == y.Data; + } - public int GetHashCode([DisallowNull] SsoConfig obj) - { - return base.GetHashCode(); + public int GetHashCode([DisallowNull] SsoConfig obj) + { + return base.GetHashCode(); + } } } diff --git a/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/SsoUserCompare.cs b/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/SsoUserCompare.cs index fffd512c6d..a500545143 100644 --- a/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/SsoUserCompare.cs +++ b/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/SsoUserCompare.cs @@ -1,17 +1,18 @@ using System.Diagnostics.CodeAnalysis; using Bit.Core.Entities; -namespace Bit.Infrastructure.EFIntegration.Test.Repositories.EqualityComparers; - -public class SsoUserCompare : IEqualityComparer +namespace Bit.Infrastructure.EFIntegration.Test.Repositories.EqualityComparers { - public bool Equals(SsoUser x, SsoUser y) + public class SsoUserCompare : IEqualityComparer { - return x.ExternalId == y.ExternalId; - } + public bool Equals(SsoUser x, SsoUser y) + { + return x.ExternalId == y.ExternalId; + } - public int GetHashCode([DisallowNull] SsoUser obj) - { - return base.GetHashCode(); + public int GetHashCode([DisallowNull] SsoUser obj) + { + return base.GetHashCode(); + } } } diff --git a/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/TaxRateCompare.cs b/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/TaxRateCompare.cs index ff3c0a600f..c2305b959d 100644 --- a/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/TaxRateCompare.cs +++ b/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/TaxRateCompare.cs @@ -1,21 +1,22 @@ using System.Diagnostics.CodeAnalysis; using Bit.Core.Entities; -namespace Bit.Infrastructure.EFIntegration.Test.Repositories.EqualityComparers; - -public class TaxRateCompare : IEqualityComparer +namespace Bit.Infrastructure.EFIntegration.Test.Repositories.EqualityComparers { - public bool Equals(TaxRate x, TaxRate y) + public class TaxRateCompare : IEqualityComparer { - return x.Country == y.Country && - x.State == y.State && - x.PostalCode == y.PostalCode && - x.Rate == y.Rate && - x.Active == y.Active; - } + public bool Equals(TaxRate x, TaxRate y) + { + return x.Country == y.Country && + x.State == y.State && + x.PostalCode == y.PostalCode && + x.Rate == y.Rate && + x.Active == y.Active; + } - public int GetHashCode([DisallowNull] TaxRate obj) - { - return base.GetHashCode(); + public int GetHashCode([DisallowNull] TaxRate obj) + { + return base.GetHashCode(); + } } } diff --git a/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/TransactionCompare.cs b/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/TransactionCompare.cs index fadcdf5b1d..2ce594ec4c 100644 --- a/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/TransactionCompare.cs +++ b/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/TransactionCompare.cs @@ -1,23 +1,24 @@ using System.Diagnostics.CodeAnalysis; using Bit.Core.Entities; -namespace Bit.Infrastructure.EFIntegration.Test.Repositories.EqualityComparers; - -public class TransactionCompare : IEqualityComparer +namespace Bit.Infrastructure.EFIntegration.Test.Repositories.EqualityComparers { - public bool Equals(Transaction x, Transaction y) + public class TransactionCompare : IEqualityComparer { - return x.Type == y.Type && - x.Amount == y.Amount && - x.Refunded == y.Refunded && - x.Details == y.Details && - x.PaymentMethodType == y.PaymentMethodType && - x.Gateway == y.Gateway && - x.GatewayId == y.GatewayId; - } + public bool Equals(Transaction x, Transaction y) + { + return x.Type == y.Type && + x.Amount == y.Amount && + x.Refunded == y.Refunded && + x.Details == y.Details && + x.PaymentMethodType == y.PaymentMethodType && + x.Gateway == y.Gateway && + x.GatewayId == y.GatewayId; + } - public int GetHashCode([DisallowNull] Transaction obj) - { - return base.GetHashCode(); + public int GetHashCode([DisallowNull] Transaction obj) + { + return base.GetHashCode(); + } } } diff --git a/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/UserCompare.cs b/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/UserCompare.cs index 90a6af51bd..311d4a01fe 100644 --- a/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/UserCompare.cs +++ b/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/UserCompare.cs @@ -1,39 +1,40 @@ using System.Diagnostics.CodeAnalysis; using Bit.Core.Entities; -namespace Bit.Infrastructure.EFIntegration.Test.Repositories.EqualityComparers; - -public class UserCompare : IEqualityComparer +namespace Bit.Infrastructure.EFIntegration.Test.Repositories.EqualityComparers { - public bool Equals(User x, User y) + public class UserCompare : IEqualityComparer { - return x.Name == y.Name && - x.Email == y.Email && - x.EmailVerified == y.EmailVerified && - x.MasterPassword == y.MasterPassword && - x.MasterPasswordHint == y.MasterPasswordHint && - x.Culture == y.Culture && - x.SecurityStamp == y.SecurityStamp && - x.TwoFactorProviders == y.TwoFactorProviders && - x.TwoFactorRecoveryCode == y.TwoFactorRecoveryCode && - x.EquivalentDomains == y.EquivalentDomains && - x.Key == y.Key && - x.PublicKey == y.PublicKey && - x.PrivateKey == y.PrivateKey && - x.Premium == y.Premium && - x.Storage == y.Storage && - x.MaxStorageGb == y.MaxStorageGb && - x.Gateway == y.Gateway && - x.GatewayCustomerId == y.GatewayCustomerId && - x.ReferenceData == y.ReferenceData && - x.LicenseKey == y.LicenseKey && - x.ApiKey == y.ApiKey && - x.Kdf == y.Kdf && - x.KdfIterations == y.KdfIterations; - } + public bool Equals(User x, User y) + { + return x.Name == y.Name && + x.Email == y.Email && + x.EmailVerified == y.EmailVerified && + x.MasterPassword == y.MasterPassword && + x.MasterPasswordHint == y.MasterPasswordHint && + x.Culture == y.Culture && + x.SecurityStamp == y.SecurityStamp && + x.TwoFactorProviders == y.TwoFactorProviders && + x.TwoFactorRecoveryCode == y.TwoFactorRecoveryCode && + x.EquivalentDomains == y.EquivalentDomains && + x.Key == y.Key && + x.PublicKey == y.PublicKey && + x.PrivateKey == y.PrivateKey && + x.Premium == y.Premium && + x.Storage == y.Storage && + x.MaxStorageGb == y.MaxStorageGb && + x.Gateway == y.Gateway && + x.GatewayCustomerId == y.GatewayCustomerId && + x.ReferenceData == y.ReferenceData && + x.LicenseKey == y.LicenseKey && + x.ApiKey == y.ApiKey && + x.Kdf == y.Kdf && + x.KdfIterations == y.KdfIterations; + } - public int GetHashCode([DisallowNull] User obj) - { - return base.GetHashCode(); + public int GetHashCode([DisallowNull] User obj) + { + return base.GetHashCode(); + } } } diff --git a/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/UserKdfInformation.cs b/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/UserKdfInformation.cs index 079d37c3fc..143903de38 100644 --- a/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/UserKdfInformation.cs +++ b/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/UserKdfInformation.cs @@ -1,18 +1,19 @@ using System.Diagnostics.CodeAnalysis; using Bit.Core.Models.Data; -namespace Bit.Infrastructure.EFIntegration.Test.Repositories.EqualityComparers; - -public class UserKdfInformationCompare : IEqualityComparer +namespace Bit.Infrastructure.EFIntegration.Test.Repositories.EqualityComparers { - public bool Equals(UserKdfInformation x, UserKdfInformation y) + public class UserKdfInformationCompare : IEqualityComparer { - return x.Kdf == y.Kdf && - x.KdfIterations == y.KdfIterations; - } + public bool Equals(UserKdfInformation x, UserKdfInformation y) + { + return x.Kdf == y.Kdf && + x.KdfIterations == y.KdfIterations; + } - public int GetHashCode([DisallowNull] UserKdfInformation obj) - { - return base.GetHashCode(); + public int GetHashCode([DisallowNull] UserKdfInformation obj) + { + return base.GetHashCode(); + } } } diff --git a/test/Infrastructure.EFIntegration.Test/Repositories/FolderRepositoryTests.cs b/test/Infrastructure.EFIntegration.Test/Repositories/FolderRepositoryTests.cs index 53edbd3c44..ae3f4fe9bd 100644 --- a/test/Infrastructure.EFIntegration.Test/Repositories/FolderRepositoryTests.cs +++ b/test/Infrastructure.EFIntegration.Test/Repositories/FolderRepositoryTests.cs @@ -6,43 +6,44 @@ using Xunit; using EfRepo = Bit.Infrastructure.EntityFramework.Repositories; using SqlRepo = Bit.Infrastructure.Dapper.Repositories; -namespace Bit.Infrastructure.EFIntegration.Test.Repositories; - -public class FolderRepositoryTests +namespace Bit.Infrastructure.EFIntegration.Test.Repositories { - [CiSkippedTheory, EfFolderAutoData] - public async void CreateAsync_Works_DataMatches( - Folder folder, - User user, - FolderCompare equalityComparer, - List suts, - List efUserRepos, - SqlRepo.FolderRepository sqlFolderRepo, - SqlRepo.UserRepository sqlUserRepo) + public class FolderRepositoryTests { - var savedFolders = new List(); - foreach (var sut in suts) + [CiSkippedTheory, EfFolderAutoData] + public async void CreateAsync_Works_DataMatches( + Folder folder, + User user, + FolderCompare equalityComparer, + List suts, + List efUserRepos, + SqlRepo.FolderRepository sqlFolderRepo, + SqlRepo.UserRepository sqlUserRepo) { - var i = suts.IndexOf(sut); + var savedFolders = new List(); + foreach (var sut in suts) + { + var i = suts.IndexOf(sut); - var efUser = await efUserRepos[i].CreateAsync(user); - sut.ClearChangeTracking(); + var efUser = await efUserRepos[i].CreateAsync(user); + sut.ClearChangeTracking(); - folder.UserId = efUser.Id; - var postEfFolder = await sut.CreateAsync(folder); - sut.ClearChangeTracking(); + folder.UserId = efUser.Id; + var postEfFolder = await sut.CreateAsync(folder); + sut.ClearChangeTracking(); - var savedFolder = await sut.GetByIdAsync(folder.Id); - savedFolders.Add(savedFolder); + var savedFolder = await sut.GetByIdAsync(folder.Id); + savedFolders.Add(savedFolder); + } + + var sqlUser = await sqlUserRepo.CreateAsync(user); + + folder.UserId = sqlUser.Id; + var sqlFolder = await sqlFolderRepo.CreateAsync(folder); + savedFolders.Add(await sqlFolderRepo.GetByIdAsync(sqlFolder.Id)); + + var distinctItems = savedFolders.Distinct(equalityComparer); + Assert.True(!distinctItems.Skip(1).Any()); } - - var sqlUser = await sqlUserRepo.CreateAsync(user); - - folder.UserId = sqlUser.Id; - var sqlFolder = await sqlFolderRepo.CreateAsync(folder); - savedFolders.Add(await sqlFolderRepo.GetByIdAsync(sqlFolder.Id)); - - var distinctItems = savedFolders.Distinct(equalityComparer); - Assert.True(!distinctItems.Skip(1).Any()); } } diff --git a/test/Infrastructure.EFIntegration.Test/Repositories/InstallationRepositoryTests.cs b/test/Infrastructure.EFIntegration.Test/Repositories/InstallationRepositoryTests.cs index 9827b0c03f..90b8d5bbc7 100644 --- a/test/Infrastructure.EFIntegration.Test/Repositories/InstallationRepositoryTests.cs +++ b/test/Infrastructure.EFIntegration.Test/Repositories/InstallationRepositoryTests.cs @@ -6,33 +6,34 @@ using Xunit; using EfRepo = Bit.Infrastructure.EntityFramework.Repositories; using SqlRepo = Bit.Infrastructure.Dapper.Repositories; -namespace Bit.Infrastructure.EFIntegration.Test.Repositories; - -public class InstallationRepositoryTests +namespace Bit.Infrastructure.EFIntegration.Test.Repositories { - [CiSkippedTheory, EfInstallationAutoData] - public async void CreateAsync_Works_DataMatches( - Installation installation, - InstallationCompare equalityComparer, - List suts, - SqlRepo.InstallationRepository sqlInstallationRepo - ) + public class InstallationRepositoryTests { - var savedInstallations = new List(); - foreach (var sut in suts) + [CiSkippedTheory, EfInstallationAutoData] + public async void CreateAsync_Works_DataMatches( + Installation installation, + InstallationCompare equalityComparer, + List suts, + SqlRepo.InstallationRepository sqlInstallationRepo + ) { - var postEfInstallation = await sut.CreateAsync(installation); - sut.ClearChangeTracking(); + var savedInstallations = new List(); + foreach (var sut in suts) + { + var postEfInstallation = await sut.CreateAsync(installation); + sut.ClearChangeTracking(); - var savedInstallation = await sut.GetByIdAsync(postEfInstallation.Id); - savedInstallations.Add(savedInstallation); + var savedInstallation = await sut.GetByIdAsync(postEfInstallation.Id); + savedInstallations.Add(savedInstallation); + } + + var sqlInstallation = await sqlInstallationRepo.CreateAsync(installation); + var savedSqlInstallation = await sqlInstallationRepo.GetByIdAsync(sqlInstallation.Id); + savedInstallations.Add(savedSqlInstallation); + + var distinctItems = savedInstallations.Distinct(equalityComparer); + Assert.True(!distinctItems.Skip(1).Any()); } - - var sqlInstallation = await sqlInstallationRepo.CreateAsync(installation); - var savedSqlInstallation = await sqlInstallationRepo.GetByIdAsync(sqlInstallation.Id); - savedInstallations.Add(savedSqlInstallation); - - var distinctItems = savedInstallations.Distinct(equalityComparer); - Assert.True(!distinctItems.Skip(1).Any()); } } diff --git a/test/Infrastructure.EFIntegration.Test/Repositories/OrganizationRepositoryTests.cs b/test/Infrastructure.EFIntegration.Test/Repositories/OrganizationRepositoryTests.cs index 04e314d560..eb6713afbc 100644 --- a/test/Infrastructure.EFIntegration.Test/Repositories/OrganizationRepositoryTests.cs +++ b/test/Infrastructure.EFIntegration.Test/Repositories/OrganizationRepositoryTests.cs @@ -7,143 +7,144 @@ using EfRepo = Bit.Infrastructure.EntityFramework.Repositories; using Organization = Bit.Core.Entities.Organization; using SqlRepo = Bit.Infrastructure.Dapper.Repositories; -namespace Bit.Infrastructure.EFIntegration.Test.Repositories; - -public class OrganizationRepositoryTests +namespace Bit.Infrastructure.EFIntegration.Test.Repositories { - [CiSkippedTheory, EfOrganizationAutoData] - public async void CreateAsync_Works_DataMatches( - Organization organization, - SqlRepo.OrganizationRepository sqlOrganizationRepo, OrganizationCompare equalityComparer, - List suts) + public class OrganizationRepositoryTests { - var savedOrganizations = new List(); - foreach (var sut in suts) + [CiSkippedTheory, EfOrganizationAutoData] + public async void CreateAsync_Works_DataMatches( + Organization organization, + SqlRepo.OrganizationRepository sqlOrganizationRepo, OrganizationCompare equalityComparer, + List suts) { - var postEfOrganization = await sut.CreateAsync(organization); - sut.ClearChangeTracking(); + var savedOrganizations = new List(); + foreach (var sut in suts) + { + var postEfOrganization = await sut.CreateAsync(organization); + sut.ClearChangeTracking(); - var savedOrganization = await sut.GetByIdAsync(organization.Id); - savedOrganizations.Add(savedOrganization); + var savedOrganization = await sut.GetByIdAsync(organization.Id); + savedOrganizations.Add(savedOrganization); + } + + var sqlOrganization = await sqlOrganizationRepo.CreateAsync(organization); + savedOrganizations.Add(await sqlOrganizationRepo.GetByIdAsync(sqlOrganization.Id)); + + var distinctItems = savedOrganizations.Distinct(equalityComparer); + Assert.True(!distinctItems.Skip(1).Any()); } - var sqlOrganization = await sqlOrganizationRepo.CreateAsync(organization); - savedOrganizations.Add(await sqlOrganizationRepo.GetByIdAsync(sqlOrganization.Id)); - - var distinctItems = savedOrganizations.Distinct(equalityComparer); - Assert.True(!distinctItems.Skip(1).Any()); - } - - [CiSkippedTheory, EfOrganizationAutoData] - public async void ReplaceAsync_Works_DataMatches(Organization postOrganization, - Organization replaceOrganization, SqlRepo.OrganizationRepository sqlOrganizationRepo, - OrganizationCompare equalityComparer, List suts) - { - var savedOrganizations = new List(); - foreach (var sut in suts) + [CiSkippedTheory, EfOrganizationAutoData] + public async void ReplaceAsync_Works_DataMatches(Organization postOrganization, + Organization replaceOrganization, SqlRepo.OrganizationRepository sqlOrganizationRepo, + OrganizationCompare equalityComparer, List suts) { - var postEfOrganization = await sut.CreateAsync(postOrganization); - sut.ClearChangeTracking(); + var savedOrganizations = new List(); + foreach (var sut in suts) + { + var postEfOrganization = await sut.CreateAsync(postOrganization); + sut.ClearChangeTracking(); - replaceOrganization.Id = postEfOrganization.Id; - await sut.ReplaceAsync(replaceOrganization); - sut.ClearChangeTracking(); + replaceOrganization.Id = postEfOrganization.Id; + await sut.ReplaceAsync(replaceOrganization); + sut.ClearChangeTracking(); - var replacedOrganization = await sut.GetByIdAsync(replaceOrganization.Id); - savedOrganizations.Add(replacedOrganization); + var replacedOrganization = await sut.GetByIdAsync(replaceOrganization.Id); + savedOrganizations.Add(replacedOrganization); + } + + var postSqlOrganization = await sqlOrganizationRepo.CreateAsync(postOrganization); + replaceOrganization.Id = postSqlOrganization.Id; + await sqlOrganizationRepo.ReplaceAsync(replaceOrganization); + savedOrganizations.Add(await sqlOrganizationRepo.GetByIdAsync(replaceOrganization.Id)); + + var distinctItems = savedOrganizations.Distinct(equalityComparer); + Assert.True(!distinctItems.Skip(1).Any()); } - var postSqlOrganization = await sqlOrganizationRepo.CreateAsync(postOrganization); - replaceOrganization.Id = postSqlOrganization.Id; - await sqlOrganizationRepo.ReplaceAsync(replaceOrganization); - savedOrganizations.Add(await sqlOrganizationRepo.GetByIdAsync(replaceOrganization.Id)); - - var distinctItems = savedOrganizations.Distinct(equalityComparer); - Assert.True(!distinctItems.Skip(1).Any()); - } - - [CiSkippedTheory, EfOrganizationAutoData] - public async void DeleteAsync_Works_DataMatches(Organization organization, - SqlRepo.OrganizationRepository sqlOrganizationRepo, List suts) - { - foreach (var sut in suts) + [CiSkippedTheory, EfOrganizationAutoData] + public async void DeleteAsync_Works_DataMatches(Organization organization, + SqlRepo.OrganizationRepository sqlOrganizationRepo, List suts) { - var postEfOrganization = await sut.CreateAsync(organization); - sut.ClearChangeTracking(); + foreach (var sut in suts) + { + var postEfOrganization = await sut.CreateAsync(organization); + sut.ClearChangeTracking(); - var savedEfOrganization = await sut.GetByIdAsync(postEfOrganization.Id); - sut.ClearChangeTracking(); - Assert.True(savedEfOrganization != null); + var savedEfOrganization = await sut.GetByIdAsync(postEfOrganization.Id); + sut.ClearChangeTracking(); + Assert.True(savedEfOrganization != null); - await sut.DeleteAsync(savedEfOrganization); - sut.ClearChangeTracking(); + await sut.DeleteAsync(savedEfOrganization); + sut.ClearChangeTracking(); - savedEfOrganization = await sut.GetByIdAsync(savedEfOrganization.Id); - Assert.True(savedEfOrganization == null); + savedEfOrganization = await sut.GetByIdAsync(savedEfOrganization.Id); + Assert.True(savedEfOrganization == null); + } + + var postSqlOrganization = await sqlOrganizationRepo.CreateAsync(organization); + var savedSqlOrganization = await sqlOrganizationRepo.GetByIdAsync(postSqlOrganization.Id); + Assert.True(savedSqlOrganization != null); + + await sqlOrganizationRepo.DeleteAsync(postSqlOrganization); + savedSqlOrganization = await sqlOrganizationRepo.GetByIdAsync(postSqlOrganization.Id); + Assert.True(savedSqlOrganization == null); } - var postSqlOrganization = await sqlOrganizationRepo.CreateAsync(organization); - var savedSqlOrganization = await sqlOrganizationRepo.GetByIdAsync(postSqlOrganization.Id); - Assert.True(savedSqlOrganization != null); - - await sqlOrganizationRepo.DeleteAsync(postSqlOrganization); - savedSqlOrganization = await sqlOrganizationRepo.GetByIdAsync(postSqlOrganization.Id); - Assert.True(savedSqlOrganization == null); - } - - [CiSkippedTheory, EfOrganizationAutoData] - public async void GetByIdentifierAsync_Works_DataMatches(Organization organization, - SqlRepo.OrganizationRepository sqlOrganizationRepo, OrganizationCompare equalityComparer, - List suts) - { - var returnedOrgs = new List(); - foreach (var sut in suts) + [CiSkippedTheory, EfOrganizationAutoData] + public async void GetByIdentifierAsync_Works_DataMatches(Organization organization, + SqlRepo.OrganizationRepository sqlOrganizationRepo, OrganizationCompare equalityComparer, + List suts) { - var postEfOrg = await sut.CreateAsync(organization); - sut.ClearChangeTracking(); + var returnedOrgs = new List(); + foreach (var sut in suts) + { + var postEfOrg = await sut.CreateAsync(organization); + sut.ClearChangeTracking(); - var returnedOrg = await sut.GetByIdentifierAsync(postEfOrg.Identifier.ToUpperInvariant()); - returnedOrgs.Add(returnedOrg); + var returnedOrg = await sut.GetByIdentifierAsync(postEfOrg.Identifier.ToUpperInvariant()); + returnedOrgs.Add(returnedOrg); + } + + var postSqlOrg = await sqlOrganizationRepo.CreateAsync(organization); + returnedOrgs.Add(await sqlOrganizationRepo.GetByIdentifierAsync(postSqlOrg.Identifier.ToUpperInvariant())); + + var distinctItems = returnedOrgs.Distinct(equalityComparer); + Assert.True(!distinctItems.Skip(1).Any()); } - var postSqlOrg = await sqlOrganizationRepo.CreateAsync(organization); - returnedOrgs.Add(await sqlOrganizationRepo.GetByIdentifierAsync(postSqlOrg.Identifier.ToUpperInvariant())); - - var distinctItems = returnedOrgs.Distinct(equalityComparer); - Assert.True(!distinctItems.Skip(1).Any()); - } - - [CiSkippedTheory, EfOrganizationAutoData] - public async void GetManyByEnabledAsync_Works_DataMatches(Organization organization, - SqlRepo.OrganizationRepository sqlOrganizationRepo, List suts) - { - var returnedOrgs = new List(); - foreach (var sut in suts) + [CiSkippedTheory, EfOrganizationAutoData] + public async void GetManyByEnabledAsync_Works_DataMatches(Organization organization, + SqlRepo.OrganizationRepository sqlOrganizationRepo, List suts) { - var postEfOrg = await sut.CreateAsync(organization); - sut.ClearChangeTracking(); + var returnedOrgs = new List(); + foreach (var sut in suts) + { + var postEfOrg = await sut.CreateAsync(organization); + sut.ClearChangeTracking(); - var efReturnedOrgs = await sut.GetManyByEnabledAsync(); - returnedOrgs.Concat(efReturnedOrgs); + var efReturnedOrgs = await sut.GetManyByEnabledAsync(); + returnedOrgs.Concat(efReturnedOrgs); + } + + var postSqlOrg = await sqlOrganizationRepo.CreateAsync(organization); + returnedOrgs.Concat(await sqlOrganizationRepo.GetManyByEnabledAsync()); + + Assert.True(returnedOrgs.All(o => o.Enabled)); } - var postSqlOrg = await sqlOrganizationRepo.CreateAsync(organization); - returnedOrgs.Concat(await sqlOrganizationRepo.GetManyByEnabledAsync()); - - Assert.True(returnedOrgs.All(o => o.Enabled)); - } - - // testing data matches here would require manipulating all organization abilities in the db - [CiSkippedTheory, EfOrganizationAutoData] - public async void GetManyAbilitiesAsync_Works(SqlRepo.OrganizationRepository sqlOrganizationRepo, List suts) - { - var list = new List(); - foreach (var sut in suts) + // testing data matches here would require manipulating all organization abilities in the db + [CiSkippedTheory, EfOrganizationAutoData] + public async void GetManyAbilitiesAsync_Works(SqlRepo.OrganizationRepository sqlOrganizationRepo, List suts) { - list.Concat(await sut.GetManyAbilitiesAsync()); - } + var list = new List(); + foreach (var sut in suts) + { + list.Concat(await sut.GetManyAbilitiesAsync()); + } - list.Concat(await sqlOrganizationRepo.GetManyAbilitiesAsync()); - Assert.True(list.All(x => x.GetType() == typeof(OrganizationAbility))); + list.Concat(await sqlOrganizationRepo.GetManyAbilitiesAsync()); + Assert.True(list.All(x => x.GetType() == typeof(OrganizationAbility))); + } } } diff --git a/test/Infrastructure.EFIntegration.Test/Repositories/OrganizationSponsorshipRepositoryTests.cs b/test/Infrastructure.EFIntegration.Test/Repositories/OrganizationSponsorshipRepositoryTests.cs index ee7d0d271c..29482df29a 100644 --- a/test/Infrastructure.EFIntegration.Test/Repositories/OrganizationSponsorshipRepositoryTests.cs +++ b/test/Infrastructure.EFIntegration.Test/Repositories/OrganizationSponsorshipRepositoryTests.cs @@ -6,126 +6,127 @@ using Xunit; using EfRepo = Bit.Infrastructure.EntityFramework.Repositories; using SqlRepo = Bit.Infrastructure.Dapper.Repositories; -namespace Bit.Infrastructure.EFIntegration.Test.Repositories; - -public class OrganizationSponsorshipRepositoryTests +namespace Bit.Infrastructure.EFIntegration.Test.Repositories { - [CiSkippedTheory, EfOrganizationSponsorshipAutoData] - public async void CreateAsync_Works_DataMatches( - OrganizationSponsorship organizationSponsorship, Organization sponsoringOrg, - List efOrgRepos, - SqlRepo.OrganizationRepository sqlOrganizationRepo, - SqlRepo.OrganizationSponsorshipRepository sqlOrganizationSponsorshipRepo, - OrganizationSponsorshipCompare equalityComparer, - List suts) + public class OrganizationSponsorshipRepositoryTests { - organizationSponsorship.SponsoredOrganizationId = null; - - var savedOrganizationSponsorships = new List(); - foreach (var (sut, orgRepo) in suts.Zip(efOrgRepos)) + [CiSkippedTheory, EfOrganizationSponsorshipAutoData] + public async void CreateAsync_Works_DataMatches( + OrganizationSponsorship organizationSponsorship, Organization sponsoringOrg, + List efOrgRepos, + SqlRepo.OrganizationRepository sqlOrganizationRepo, + SqlRepo.OrganizationSponsorshipRepository sqlOrganizationSponsorshipRepo, + OrganizationSponsorshipCompare equalityComparer, + List suts) { - var efSponsoringOrg = await orgRepo.CreateAsync(sponsoringOrg); - sut.ClearChangeTracking(); - organizationSponsorship.SponsoringOrganizationId = efSponsoringOrg.Id; + organizationSponsorship.SponsoredOrganizationId = null; - await sut.CreateAsync(organizationSponsorship); - sut.ClearChangeTracking(); + var savedOrganizationSponsorships = new List(); + foreach (var (sut, orgRepo) in suts.Zip(efOrgRepos)) + { + var efSponsoringOrg = await orgRepo.CreateAsync(sponsoringOrg); + sut.ClearChangeTracking(); + organizationSponsorship.SponsoringOrganizationId = efSponsoringOrg.Id; - var savedOrganizationSponsorship = await sut.GetByIdAsync(organizationSponsorship.Id); - savedOrganizationSponsorships.Add(savedOrganizationSponsorship); + await sut.CreateAsync(organizationSponsorship); + sut.ClearChangeTracking(); + + var savedOrganizationSponsorship = await sut.GetByIdAsync(organizationSponsorship.Id); + savedOrganizationSponsorships.Add(savedOrganizationSponsorship); + } + + var sqlSponsoringOrg = await sqlOrganizationRepo.CreateAsync(sponsoringOrg); + organizationSponsorship.SponsoringOrganizationId = sqlSponsoringOrg.Id; + + var sqlOrganizationSponsorship = await sqlOrganizationSponsorshipRepo.CreateAsync(organizationSponsorship); + savedOrganizationSponsorships.Add(await sqlOrganizationSponsorshipRepo.GetByIdAsync(sqlOrganizationSponsorship.Id)); + + var distinctItems = savedOrganizationSponsorships.Distinct(equalityComparer); + Assert.True(!distinctItems.Skip(1).Any()); } - var sqlSponsoringOrg = await sqlOrganizationRepo.CreateAsync(sponsoringOrg); - organizationSponsorship.SponsoringOrganizationId = sqlSponsoringOrg.Id; - - var sqlOrganizationSponsorship = await sqlOrganizationSponsorshipRepo.CreateAsync(organizationSponsorship); - savedOrganizationSponsorships.Add(await sqlOrganizationSponsorshipRepo.GetByIdAsync(sqlOrganizationSponsorship.Id)); - - var distinctItems = savedOrganizationSponsorships.Distinct(equalityComparer); - Assert.True(!distinctItems.Skip(1).Any()); - } - - [CiSkippedTheory, EfOrganizationSponsorshipAutoData] - public async void ReplaceAsync_Works_DataMatches(OrganizationSponsorship postOrganizationSponsorship, - OrganizationSponsorship replaceOrganizationSponsorship, Organization sponsoringOrg, - List efOrgRepos, - SqlRepo.OrganizationRepository sqlOrganizationRepo, - SqlRepo.OrganizationSponsorshipRepository sqlOrganizationSponsorshipRepo, - OrganizationSponsorshipCompare equalityComparer, List suts) - { - postOrganizationSponsorship.SponsoredOrganizationId = null; - replaceOrganizationSponsorship.SponsoredOrganizationId = null; - - var savedOrganizationSponsorships = new List(); - foreach (var (sut, orgRepo) in suts.Zip(efOrgRepos)) + [CiSkippedTheory, EfOrganizationSponsorshipAutoData] + public async void ReplaceAsync_Works_DataMatches(OrganizationSponsorship postOrganizationSponsorship, + OrganizationSponsorship replaceOrganizationSponsorship, Organization sponsoringOrg, + List efOrgRepos, + SqlRepo.OrganizationRepository sqlOrganizationRepo, + SqlRepo.OrganizationSponsorshipRepository sqlOrganizationSponsorshipRepo, + OrganizationSponsorshipCompare equalityComparer, List suts) { - var efSponsoringOrg = await orgRepo.CreateAsync(sponsoringOrg); - sut.ClearChangeTracking(); - postOrganizationSponsorship.SponsoringOrganizationId = efSponsoringOrg.Id; - replaceOrganizationSponsorship.SponsoringOrganizationId = efSponsoringOrg.Id; + postOrganizationSponsorship.SponsoredOrganizationId = null; + replaceOrganizationSponsorship.SponsoredOrganizationId = null; - var postEfOrganizationSponsorship = await sut.CreateAsync(postOrganizationSponsorship); - sut.ClearChangeTracking(); + var savedOrganizationSponsorships = new List(); + foreach (var (sut, orgRepo) in suts.Zip(efOrgRepos)) + { + var efSponsoringOrg = await orgRepo.CreateAsync(sponsoringOrg); + sut.ClearChangeTracking(); + postOrganizationSponsorship.SponsoringOrganizationId = efSponsoringOrg.Id; + replaceOrganizationSponsorship.SponsoringOrganizationId = efSponsoringOrg.Id; - replaceOrganizationSponsorship.Id = postEfOrganizationSponsorship.Id; - await sut.ReplaceAsync(replaceOrganizationSponsorship); - sut.ClearChangeTracking(); + var postEfOrganizationSponsorship = await sut.CreateAsync(postOrganizationSponsorship); + sut.ClearChangeTracking(); - var replacedOrganizationSponsorship = await sut.GetByIdAsync(replaceOrganizationSponsorship.Id); - savedOrganizationSponsorships.Add(replacedOrganizationSponsorship); + replaceOrganizationSponsorship.Id = postEfOrganizationSponsorship.Id; + await sut.ReplaceAsync(replaceOrganizationSponsorship); + sut.ClearChangeTracking(); + + var replacedOrganizationSponsorship = await sut.GetByIdAsync(replaceOrganizationSponsorship.Id); + savedOrganizationSponsorships.Add(replacedOrganizationSponsorship); + } + + var sqlSponsoringOrg = await sqlOrganizationRepo.CreateAsync(sponsoringOrg); + postOrganizationSponsorship.SponsoringOrganizationId = sqlSponsoringOrg.Id; + + var postSqlOrganization = await sqlOrganizationSponsorshipRepo.CreateAsync(postOrganizationSponsorship); + replaceOrganizationSponsorship.Id = postSqlOrganization.Id; + await sqlOrganizationSponsorshipRepo.ReplaceAsync(replaceOrganizationSponsorship); + savedOrganizationSponsorships.Add(await sqlOrganizationSponsorshipRepo.GetByIdAsync(replaceOrganizationSponsorship.Id)); + + var distinctItems = savedOrganizationSponsorships.Distinct(equalityComparer); + Assert.True(!distinctItems.Skip(1).Any()); } - var sqlSponsoringOrg = await sqlOrganizationRepo.CreateAsync(sponsoringOrg); - postOrganizationSponsorship.SponsoringOrganizationId = sqlSponsoringOrg.Id; - - var postSqlOrganization = await sqlOrganizationSponsorshipRepo.CreateAsync(postOrganizationSponsorship); - replaceOrganizationSponsorship.Id = postSqlOrganization.Id; - await sqlOrganizationSponsorshipRepo.ReplaceAsync(replaceOrganizationSponsorship); - savedOrganizationSponsorships.Add(await sqlOrganizationSponsorshipRepo.GetByIdAsync(replaceOrganizationSponsorship.Id)); - - var distinctItems = savedOrganizationSponsorships.Distinct(equalityComparer); - Assert.True(!distinctItems.Skip(1).Any()); - } - - [CiSkippedTheory, EfOrganizationSponsorshipAutoData] - public async void DeleteAsync_Works_DataMatches(OrganizationSponsorship organizationSponsorship, - Organization sponsoringOrg, - List efOrgRepos, - SqlRepo.OrganizationRepository sqlOrganizationRepo, - SqlRepo.OrganizationSponsorshipRepository sqlOrganizationSponsorshipRepo, - List suts) - { - organizationSponsorship.SponsoredOrganizationId = null; - - foreach (var (sut, orgRepo) in suts.Zip(efOrgRepos)) + [CiSkippedTheory, EfOrganizationSponsorshipAutoData] + public async void DeleteAsync_Works_DataMatches(OrganizationSponsorship organizationSponsorship, + Organization sponsoringOrg, + List efOrgRepos, + SqlRepo.OrganizationRepository sqlOrganizationRepo, + SqlRepo.OrganizationSponsorshipRepository sqlOrganizationSponsorshipRepo, + List suts) { - var efSponsoringOrg = await orgRepo.CreateAsync(sponsoringOrg); - sut.ClearChangeTracking(); - organizationSponsorship.SponsoringOrganizationId = efSponsoringOrg.Id; + organizationSponsorship.SponsoredOrganizationId = null; - var postEfOrganizationSponsorship = await sut.CreateAsync(organizationSponsorship); - sut.ClearChangeTracking(); + foreach (var (sut, orgRepo) in suts.Zip(efOrgRepos)) + { + var efSponsoringOrg = await orgRepo.CreateAsync(sponsoringOrg); + sut.ClearChangeTracking(); + organizationSponsorship.SponsoringOrganizationId = efSponsoringOrg.Id; - var savedEfOrganizationSponsorship = await sut.GetByIdAsync(postEfOrganizationSponsorship.Id); - sut.ClearChangeTracking(); - Assert.True(savedEfOrganizationSponsorship != null); + var postEfOrganizationSponsorship = await sut.CreateAsync(organizationSponsorship); + sut.ClearChangeTracking(); - await sut.DeleteAsync(savedEfOrganizationSponsorship); - sut.ClearChangeTracking(); + var savedEfOrganizationSponsorship = await sut.GetByIdAsync(postEfOrganizationSponsorship.Id); + sut.ClearChangeTracking(); + Assert.True(savedEfOrganizationSponsorship != null); - savedEfOrganizationSponsorship = await sut.GetByIdAsync(savedEfOrganizationSponsorship.Id); - Assert.True(savedEfOrganizationSponsorship == null); + await sut.DeleteAsync(savedEfOrganizationSponsorship); + sut.ClearChangeTracking(); + + savedEfOrganizationSponsorship = await sut.GetByIdAsync(savedEfOrganizationSponsorship.Id); + Assert.True(savedEfOrganizationSponsorship == null); + } + + var sqlSponsoringOrg = await sqlOrganizationRepo.CreateAsync(sponsoringOrg); + organizationSponsorship.SponsoringOrganizationId = sqlSponsoringOrg.Id; + + var postSqlOrganizationSponsorship = await sqlOrganizationSponsorshipRepo.CreateAsync(organizationSponsorship); + var savedSqlOrganizationSponsorship = await sqlOrganizationSponsorshipRepo.GetByIdAsync(postSqlOrganizationSponsorship.Id); + Assert.True(savedSqlOrganizationSponsorship != null); + + await sqlOrganizationSponsorshipRepo.DeleteAsync(postSqlOrganizationSponsorship); + savedSqlOrganizationSponsorship = await sqlOrganizationSponsorshipRepo.GetByIdAsync(postSqlOrganizationSponsorship.Id); + Assert.True(savedSqlOrganizationSponsorship == null); } - - var sqlSponsoringOrg = await sqlOrganizationRepo.CreateAsync(sponsoringOrg); - organizationSponsorship.SponsoringOrganizationId = sqlSponsoringOrg.Id; - - var postSqlOrganizationSponsorship = await sqlOrganizationSponsorshipRepo.CreateAsync(organizationSponsorship); - var savedSqlOrganizationSponsorship = await sqlOrganizationSponsorshipRepo.GetByIdAsync(postSqlOrganizationSponsorship.Id); - Assert.True(savedSqlOrganizationSponsorship != null); - - await sqlOrganizationSponsorshipRepo.DeleteAsync(postSqlOrganizationSponsorship); - savedSqlOrganizationSponsorship = await sqlOrganizationSponsorshipRepo.GetByIdAsync(postSqlOrganizationSponsorship.Id); - Assert.True(savedSqlOrganizationSponsorship == null); } } diff --git a/test/Infrastructure.EFIntegration.Test/Repositories/OrganizationUserRepositoryTests.cs b/test/Infrastructure.EFIntegration.Test/Repositories/OrganizationUserRepositoryTests.cs index 2becc0fc65..34f1c6f4b1 100644 --- a/test/Infrastructure.EFIntegration.Test/Repositories/OrganizationUserRepositoryTests.cs +++ b/test/Infrastructure.EFIntegration.Test/Repositories/OrganizationUserRepositoryTests.cs @@ -7,141 +7,142 @@ using EfRepo = Bit.Infrastructure.EntityFramework.Repositories; using OrganizationUser = Bit.Core.Entities.OrganizationUser; using SqlRepo = Bit.Infrastructure.Dapper.Repositories; -namespace Bit.Infrastructure.EFIntegration.Test.Repositories; - -public class OrganizationUserRepositoryTests +namespace Bit.Infrastructure.EFIntegration.Test.Repositories { - [CiSkippedTheory, EfOrganizationUserAutoData] - public async void CreateAsync_Works_DataMatches(OrganizationUser orgUser, User user, Organization org, - OrganizationUserCompare equalityComparer, List suts, - List efOrgRepos, List efUserRepos, - SqlRepo.OrganizationUserRepository sqlOrgUserRepo, SqlRepo.UserRepository sqlUserRepo, - SqlRepo.OrganizationRepository sqlOrgRepo) + public class OrganizationUserRepositoryTests { - var savedOrgUsers = new List(); - foreach (var sut in suts) + [CiSkippedTheory, EfOrganizationUserAutoData] + public async void CreateAsync_Works_DataMatches(OrganizationUser orgUser, User user, Organization org, + OrganizationUserCompare equalityComparer, List suts, + List efOrgRepos, List efUserRepos, + SqlRepo.OrganizationUserRepository sqlOrgUserRepo, SqlRepo.UserRepository sqlUserRepo, + SqlRepo.OrganizationRepository sqlOrgRepo) { - var i = suts.IndexOf(sut); - var postEfUser = await efUserRepos[i].CreateAsync(user); - var postEfOrg = await efOrgRepos[i].CreateAsync(org); - sut.ClearChangeTracking(); + var savedOrgUsers = new List(); + foreach (var sut in suts) + { + var i = suts.IndexOf(sut); + var postEfUser = await efUserRepos[i].CreateAsync(user); + var postEfOrg = await efOrgRepos[i].CreateAsync(org); + sut.ClearChangeTracking(); - orgUser.UserId = postEfUser.Id; - orgUser.OrganizationId = postEfOrg.Id; - var postEfOrgUser = await sut.CreateAsync(orgUser); - sut.ClearChangeTracking(); + orgUser.UserId = postEfUser.Id; + orgUser.OrganizationId = postEfOrg.Id; + var postEfOrgUser = await sut.CreateAsync(orgUser); + sut.ClearChangeTracking(); - var savedOrgUser = await sut.GetByIdAsync(postEfOrgUser.Id); - savedOrgUsers.Add(savedOrgUser); + var savedOrgUser = await sut.GetByIdAsync(postEfOrgUser.Id); + savedOrgUsers.Add(savedOrgUser); + } + + var postSqlUser = await sqlUserRepo.CreateAsync(user); + var postSqlOrg = await sqlOrgRepo.CreateAsync(org); + + orgUser.UserId = postSqlUser.Id; + orgUser.OrganizationId = postSqlOrg.Id; + var sqlOrgUser = await sqlOrgUserRepo.CreateAsync(orgUser); + + var savedSqlOrgUser = await sqlOrgUserRepo.GetByIdAsync(sqlOrgUser.Id); + savedOrgUsers.Add(savedSqlOrgUser); + + var distinctItems = savedOrgUsers.Distinct(equalityComparer); + Assert.True(!distinctItems.Skip(1).Any()); } - var postSqlUser = await sqlUserRepo.CreateAsync(user); - var postSqlOrg = await sqlOrgRepo.CreateAsync(org); - - orgUser.UserId = postSqlUser.Id; - orgUser.OrganizationId = postSqlOrg.Id; - var sqlOrgUser = await sqlOrgUserRepo.CreateAsync(orgUser); - - var savedSqlOrgUser = await sqlOrgUserRepo.GetByIdAsync(sqlOrgUser.Id); - savedOrgUsers.Add(savedSqlOrgUser); - - var distinctItems = savedOrgUsers.Distinct(equalityComparer); - Assert.True(!distinctItems.Skip(1).Any()); - } - - [CiSkippedTheory, EfOrganizationUserAutoData] - public async void ReplaceAsync_Works_DataMatches( - OrganizationUser postOrgUser, - OrganizationUser replaceOrgUser, - User user, - Organization org, - OrganizationUserCompare equalityComparer, - List suts, - List efUserRepos, - List efOrgRepos, - SqlRepo.OrganizationUserRepository sqlOrgUserRepo, - SqlRepo.UserRepository sqlUserRepo, - SqlRepo.OrganizationRepository sqlOrgRepo - ) - { - var savedOrgUsers = new List(); - foreach (var sut in suts) + [CiSkippedTheory, EfOrganizationUserAutoData] + public async void ReplaceAsync_Works_DataMatches( + OrganizationUser postOrgUser, + OrganizationUser replaceOrgUser, + User user, + Organization org, + OrganizationUserCompare equalityComparer, + List suts, + List efUserRepos, + List efOrgRepos, + SqlRepo.OrganizationUserRepository sqlOrgUserRepo, + SqlRepo.UserRepository sqlUserRepo, + SqlRepo.OrganizationRepository sqlOrgRepo + ) { - var i = suts.IndexOf(sut); - var postEfUser = await efUserRepos[i].CreateAsync(user); - var postEfOrg = await efOrgRepos[i].CreateAsync(org); - sut.ClearChangeTracking(); + var savedOrgUsers = new List(); + foreach (var sut in suts) + { + var i = suts.IndexOf(sut); + var postEfUser = await efUserRepos[i].CreateAsync(user); + var postEfOrg = await efOrgRepos[i].CreateAsync(org); + sut.ClearChangeTracking(); - postOrgUser.UserId = replaceOrgUser.UserId = postEfUser.Id; - postOrgUser.OrganizationId = replaceOrgUser.OrganizationId = postEfOrg.Id; - var postEfOrgUser = await sut.CreateAsync(postOrgUser); - sut.ClearChangeTracking(); + postOrgUser.UserId = replaceOrgUser.UserId = postEfUser.Id; + postOrgUser.OrganizationId = replaceOrgUser.OrganizationId = postEfOrg.Id; + var postEfOrgUser = await sut.CreateAsync(postOrgUser); + sut.ClearChangeTracking(); - replaceOrgUser.Id = postOrgUser.Id; - await sut.ReplaceAsync(replaceOrgUser); - sut.ClearChangeTracking(); + replaceOrgUser.Id = postOrgUser.Id; + await sut.ReplaceAsync(replaceOrgUser); + sut.ClearChangeTracking(); - var replacedOrganizationUser = await sut.GetByIdAsync(replaceOrgUser.Id); - savedOrgUsers.Add(replacedOrganizationUser); + var replacedOrganizationUser = await sut.GetByIdAsync(replaceOrgUser.Id); + savedOrgUsers.Add(replacedOrganizationUser); + } + + var postSqlUser = await sqlUserRepo.CreateAsync(user); + var postSqlOrg = await sqlOrgRepo.CreateAsync(org); + + postOrgUser.UserId = replaceOrgUser.UserId = postSqlUser.Id; + postOrgUser.OrganizationId = replaceOrgUser.OrganizationId = postSqlOrg.Id; + var postSqlOrgUser = await sqlOrgUserRepo.CreateAsync(postOrgUser); + + replaceOrgUser.Id = postSqlOrgUser.Id; + await sqlOrgUserRepo.ReplaceAsync(replaceOrgUser); + + var replacedSqlUser = await sqlOrgUserRepo.GetByIdAsync(replaceOrgUser.Id); + + var distinctItems = savedOrgUsers.Distinct(equalityComparer); + Assert.True(!distinctItems.Skip(1).Any()); } - var postSqlUser = await sqlUserRepo.CreateAsync(user); - var postSqlOrg = await sqlOrgRepo.CreateAsync(org); - - postOrgUser.UserId = replaceOrgUser.UserId = postSqlUser.Id; - postOrgUser.OrganizationId = replaceOrgUser.OrganizationId = postSqlOrg.Id; - var postSqlOrgUser = await sqlOrgUserRepo.CreateAsync(postOrgUser); - - replaceOrgUser.Id = postSqlOrgUser.Id; - await sqlOrgUserRepo.ReplaceAsync(replaceOrgUser); - - var replacedSqlUser = await sqlOrgUserRepo.GetByIdAsync(replaceOrgUser.Id); - - var distinctItems = savedOrgUsers.Distinct(equalityComparer); - Assert.True(!distinctItems.Skip(1).Any()); - } - - [CiSkippedTheory, EfOrganizationUserAutoData] - public async void DeleteAsync_Works_DataMatches(OrganizationUser orgUser, User user, Organization org, List suts, - List efUserRepos, List efOrgRepos, - SqlRepo.OrganizationUserRepository sqlOrgUserRepo, SqlRepo.UserRepository sqlUserRepo, - SqlRepo.OrganizationRepository sqlOrgRepo) - { - foreach (var sut in suts) + [CiSkippedTheory, EfOrganizationUserAutoData] + public async void DeleteAsync_Works_DataMatches(OrganizationUser orgUser, User user, Organization org, List suts, + List efUserRepos, List efOrgRepos, + SqlRepo.OrganizationUserRepository sqlOrgUserRepo, SqlRepo.UserRepository sqlUserRepo, + SqlRepo.OrganizationRepository sqlOrgRepo) { - var i = suts.IndexOf(sut); - var postEfUser = await efUserRepos[i].CreateAsync(user); - var postEfOrg = await efOrgRepos[i].CreateAsync(org); - sut.ClearChangeTracking(); + foreach (var sut in suts) + { + var i = suts.IndexOf(sut); + var postEfUser = await efUserRepos[i].CreateAsync(user); + var postEfOrg = await efOrgRepos[i].CreateAsync(org); + sut.ClearChangeTracking(); - orgUser.UserId = postEfUser.Id; - orgUser.OrganizationId = postEfOrg.Id; - var postEfOrgUser = await sut.CreateAsync(orgUser); - sut.ClearChangeTracking(); + orgUser.UserId = postEfUser.Id; + orgUser.OrganizationId = postEfOrg.Id; + var postEfOrgUser = await sut.CreateAsync(orgUser); + sut.ClearChangeTracking(); - var savedEfOrgUser = await sut.GetByIdAsync(postEfOrgUser.Id); - Assert.True(savedEfOrgUser != null); - sut.ClearChangeTracking(); + var savedEfOrgUser = await sut.GetByIdAsync(postEfOrgUser.Id); + Assert.True(savedEfOrgUser != null); + sut.ClearChangeTracking(); - await sut.DeleteAsync(savedEfOrgUser); - sut.ClearChangeTracking(); + await sut.DeleteAsync(savedEfOrgUser); + sut.ClearChangeTracking(); - savedEfOrgUser = await sut.GetByIdAsync(savedEfOrgUser.Id); - Assert.True(savedEfOrgUser == null); + savedEfOrgUser = await sut.GetByIdAsync(savedEfOrgUser.Id); + Assert.True(savedEfOrgUser == null); + } + + var postSqlUser = await sqlUserRepo.CreateAsync(user); + var postSqlOrg = await sqlOrgRepo.CreateAsync(org); + + orgUser.UserId = postSqlUser.Id; + orgUser.OrganizationId = postSqlOrg.Id; + var postSqlOrgUser = await sqlOrgUserRepo.CreateAsync(orgUser); + + var savedSqlOrgUser = await sqlOrgUserRepo.GetByIdAsync(postSqlOrgUser.Id); + Assert.True(savedSqlOrgUser != null); + + await sqlOrgUserRepo.DeleteAsync(postSqlOrgUser); + savedSqlOrgUser = await sqlOrgUserRepo.GetByIdAsync(postSqlOrgUser.Id); + Assert.True(savedSqlOrgUser == null); } - - var postSqlUser = await sqlUserRepo.CreateAsync(user); - var postSqlOrg = await sqlOrgRepo.CreateAsync(org); - - orgUser.UserId = postSqlUser.Id; - orgUser.OrganizationId = postSqlOrg.Id; - var postSqlOrgUser = await sqlOrgUserRepo.CreateAsync(orgUser); - - var savedSqlOrgUser = await sqlOrgUserRepo.GetByIdAsync(postSqlOrgUser.Id); - Assert.True(savedSqlOrgUser != null); - - await sqlOrgUserRepo.DeleteAsync(postSqlOrgUser); - savedSqlOrgUser = await sqlOrgUserRepo.GetByIdAsync(postSqlOrgUser.Id); - Assert.True(savedSqlOrgUser == null); } } diff --git a/test/Infrastructure.EFIntegration.Test/Repositories/PolicyRepositoryTests.cs b/test/Infrastructure.EFIntegration.Test/Repositories/PolicyRepositoryTests.cs index 18a2676cd7..d013de430e 100644 --- a/test/Infrastructure.EFIntegration.Test/Repositories/PolicyRepositoryTests.cs +++ b/test/Infrastructure.EFIntegration.Test/Repositories/PolicyRepositoryTests.cs @@ -12,184 +12,185 @@ using EfRepo = Bit.Infrastructure.EntityFramework.Repositories; using Policy = Bit.Core.Entities.Policy; using SqlRepo = Bit.Infrastructure.Dapper.Repositories; -namespace Bit.Infrastructure.EFIntegration.Test.Repositories; - -public class PolicyRepositoryTests +namespace Bit.Infrastructure.EFIntegration.Test.Repositories { - [CiSkippedTheory, EfPolicyAutoData] - public async void CreateAsync_Works_DataMatches( - Policy policy, - Organization organization, - PolicyCompare equalityComparer, - List suts, - List efOrganizationRepos, - SqlRepo.PolicyRepository sqlPolicyRepo, - SqlRepo.OrganizationRepository sqlOrganizationRepo - ) + public class PolicyRepositoryTests { - var savedPolicys = new List(); - foreach (var sut in suts) + [CiSkippedTheory, EfPolicyAutoData] + public async void CreateAsync_Works_DataMatches( + Policy policy, + Organization organization, + PolicyCompare equalityComparer, + List suts, + List efOrganizationRepos, + SqlRepo.PolicyRepository sqlPolicyRepo, + SqlRepo.OrganizationRepository sqlOrganizationRepo + ) { - var i = suts.IndexOf(sut); + var savedPolicys = new List(); + foreach (var sut in suts) + { + var i = suts.IndexOf(sut); - var efOrganization = await efOrganizationRepos[i].CreateAsync(organization); - sut.ClearChangeTracking(); + var efOrganization = await efOrganizationRepos[i].CreateAsync(organization); + sut.ClearChangeTracking(); - policy.OrganizationId = efOrganization.Id; - var postEfPolicy = await sut.CreateAsync(policy); - sut.ClearChangeTracking(); + policy.OrganizationId = efOrganization.Id; + var postEfPolicy = await sut.CreateAsync(policy); + sut.ClearChangeTracking(); - var savedPolicy = await sut.GetByIdAsync(postEfPolicy.Id); - savedPolicys.Add(savedPolicy); + var savedPolicy = await sut.GetByIdAsync(postEfPolicy.Id); + savedPolicys.Add(savedPolicy); + } + + var sqlOrganization = await sqlOrganizationRepo.CreateAsync(organization); + + policy.OrganizationId = sqlOrganization.Id; + var sqlPolicy = await sqlPolicyRepo.CreateAsync(policy); + var savedSqlPolicy = await sqlPolicyRepo.GetByIdAsync(sqlPolicy.Id); + savedPolicys.Add(savedSqlPolicy); + + var distinctItems = savedPolicys.Distinct(equalityComparer); + Assert.True(!distinctItems.Skip(1).Any()); } - var sqlOrganization = await sqlOrganizationRepo.CreateAsync(organization); + [CiSkippedTheory] + [EfPolicyApplicableToUserInlineAutoData(OrganizationUserType.User, false, OrganizationUserStatusType.Confirmed, false, true, true, false)] // Ordinary user + [EfPolicyApplicableToUserInlineAutoData(OrganizationUserType.User, false, OrganizationUserStatusType.Invited, true, true, true, false)] // Invited user + [EfPolicyApplicableToUserInlineAutoData(OrganizationUserType.Owner, false, OrganizationUserStatusType.Confirmed, false, true, true, false)] // Owner + [EfPolicyApplicableToUserInlineAutoData(OrganizationUserType.Admin, false, OrganizationUserStatusType.Confirmed, false, true, true, false)] // Admin + [EfPolicyApplicableToUserInlineAutoData(OrganizationUserType.User, true, OrganizationUserStatusType.Confirmed, false, true, true, false)] // canManagePolicies + [EfPolicyApplicableToUserInlineAutoData(OrganizationUserType.User, false, OrganizationUserStatusType.Confirmed, false, true, true, true)] // Provider + [EfPolicyApplicableToUserInlineAutoData(OrganizationUserType.User, false, OrganizationUserStatusType.Confirmed, false, false, true, false)] // Policy disabled + [EfPolicyApplicableToUserInlineAutoData(OrganizationUserType.User, false, OrganizationUserStatusType.Confirmed, false, true, false, false)] // No policy of Type + [EfPolicyApplicableToUserInlineAutoData(OrganizationUserType.User, false, OrganizationUserStatusType.Invited, false, true, true, false)] // User not minStatus - policy.OrganizationId = sqlOrganization.Id; - var sqlPolicy = await sqlPolicyRepo.CreateAsync(policy); - var savedSqlPolicy = await sqlPolicyRepo.GetByIdAsync(sqlPolicy.Id); - savedPolicys.Add(savedSqlPolicy); + public async void GetManyByTypeApplicableToUser_Works_DataMatches( + // Inline data + OrganizationUserType userType, + bool canManagePolicies, + OrganizationUserStatusType orgUserStatus, + bool includeInvited, + bool policyEnabled, + bool policySameType, + bool isProvider, - var distinctItems = savedPolicys.Distinct(equalityComparer); - Assert.True(!distinctItems.Skip(1).Any()); - } + // Auto data - models + Policy policy, + User user, + Organization organization, + OrganizationUser orgUser, + Provider provider, + ProviderOrganization providerOrganization, + ProviderUser providerUser, + PolicyCompareIncludingOrganization equalityComparer, - [CiSkippedTheory] - [EfPolicyApplicableToUserInlineAutoData(OrganizationUserType.User, false, OrganizationUserStatusType.Confirmed, false, true, true, false)] // Ordinary user - [EfPolicyApplicableToUserInlineAutoData(OrganizationUserType.User, false, OrganizationUserStatusType.Invited, true, true, true, false)] // Invited user - [EfPolicyApplicableToUserInlineAutoData(OrganizationUserType.Owner, false, OrganizationUserStatusType.Confirmed, false, true, true, false)] // Owner - [EfPolicyApplicableToUserInlineAutoData(OrganizationUserType.Admin, false, OrganizationUserStatusType.Confirmed, false, true, true, false)] // Admin - [EfPolicyApplicableToUserInlineAutoData(OrganizationUserType.User, true, OrganizationUserStatusType.Confirmed, false, true, true, false)] // canManagePolicies - [EfPolicyApplicableToUserInlineAutoData(OrganizationUserType.User, false, OrganizationUserStatusType.Confirmed, false, true, true, true)] // Provider - [EfPolicyApplicableToUserInlineAutoData(OrganizationUserType.User, false, OrganizationUserStatusType.Confirmed, false, false, true, false)] // Policy disabled - [EfPolicyApplicableToUserInlineAutoData(OrganizationUserType.User, false, OrganizationUserStatusType.Confirmed, false, true, false, false)] // No policy of Type - [EfPolicyApplicableToUserInlineAutoData(OrganizationUserType.User, false, OrganizationUserStatusType.Invited, false, true, true, false)] // User not minStatus + // Auto data - EF repos + List suts, + List efUserRepository, + List efOrganizationRepository, + List efOrganizationUserRepository, + List efProviderRepository, + List efProviderOrganizationRepository, + List efProviderUserRepository, - public async void GetManyByTypeApplicableToUser_Works_DataMatches( - // Inline data - OrganizationUserType userType, - bool canManagePolicies, - OrganizationUserStatusType orgUserStatus, - bool includeInvited, - bool policyEnabled, - bool policySameType, - bool isProvider, - - // Auto data - models - Policy policy, - User user, - Organization organization, - OrganizationUser orgUser, - Provider provider, - ProviderOrganization providerOrganization, - ProviderUser providerUser, - PolicyCompareIncludingOrganization equalityComparer, - - // Auto data - EF repos - List suts, - List efUserRepository, - List efOrganizationRepository, - List efOrganizationUserRepository, - List efProviderRepository, - List efProviderOrganizationRepository, - List efProviderUserRepository, - - // Auto data - SQL repos - SqlRepo.PolicyRepository sqlPolicyRepo, - SqlRepo.UserRepository sqlUserRepo, - SqlRepo.OrganizationRepository sqlOrganizationRepo, - SqlRepo.ProviderRepository sqlProviderRepo, - SqlRepo.OrganizationUserRepository sqlOrganizationUserRepo, - SqlRepo.ProviderOrganizationRepository sqlProviderOrganizationRepo, - SqlRepo.ProviderUserRepository sqlProviderUserRepo - ) - { - // Combine EF and SQL repos into one list per type - var policyRepos = suts.ToList(); - policyRepos.Add(sqlPolicyRepo); - var userRepos = efUserRepository.ToList(); - userRepos.Add(sqlUserRepo); - var orgRepos = efOrganizationRepository.ToList(); - orgRepos.Add(sqlOrganizationRepo); - var orgUserRepos = efOrganizationUserRepository.ToList(); - orgUserRepos.Add(sqlOrganizationUserRepo); - var providerRepos = efProviderRepository.ToList(); - providerRepos.Add(sqlProviderRepo); - var providerOrgRepos = efProviderOrganizationRepository.ToList(); - providerOrgRepos.Add(sqlProviderOrganizationRepo); - var providerUserRepos = efProviderUserRepository.ToList(); - providerUserRepos.Add(sqlProviderUserRepo); - - // Arrange data - var savedPolicyType = PolicyType.SingleOrg; - var queriedPolicyType = policySameType ? savedPolicyType : PolicyType.DisableSend; - - orgUser.Type = userType; - orgUser.Status = orgUserStatus; - var permissionsData = new Permissions { ManagePolicies = canManagePolicies }; - orgUser.Permissions = JsonSerializer.Serialize(permissionsData, new JsonSerializerOptions + // Auto data - SQL repos + SqlRepo.PolicyRepository sqlPolicyRepo, + SqlRepo.UserRepository sqlUserRepo, + SqlRepo.OrganizationRepository sqlOrganizationRepo, + SqlRepo.ProviderRepository sqlProviderRepo, + SqlRepo.OrganizationUserRepository sqlOrganizationUserRepo, + SqlRepo.ProviderOrganizationRepository sqlProviderOrganizationRepo, + SqlRepo.ProviderUserRepository sqlProviderUserRepo + ) { - PropertyNamingPolicy = JsonNamingPolicy.CamelCase, - }); + // Combine EF and SQL repos into one list per type + var policyRepos = suts.ToList(); + policyRepos.Add(sqlPolicyRepo); + var userRepos = efUserRepository.ToList(); + userRepos.Add(sqlUserRepo); + var orgRepos = efOrganizationRepository.ToList(); + orgRepos.Add(sqlOrganizationRepo); + var orgUserRepos = efOrganizationUserRepository.ToList(); + orgUserRepos.Add(sqlOrganizationUserRepo); + var providerRepos = efProviderRepository.ToList(); + providerRepos.Add(sqlProviderRepo); + var providerOrgRepos = efProviderOrganizationRepository.ToList(); + providerOrgRepos.Add(sqlProviderOrganizationRepo); + var providerUserRepos = efProviderUserRepository.ToList(); + providerUserRepos.Add(sqlProviderUserRepo); - policy.Enabled = policyEnabled; - policy.Type = savedPolicyType; + // Arrange data + var savedPolicyType = PolicyType.SingleOrg; + var queriedPolicyType = policySameType ? savedPolicyType : PolicyType.DisableSend; - var results = new List(); - - foreach (var policyRepo in policyRepos) - { - var i = policyRepos.IndexOf(policyRepo); - - // Seed database - var savedUser = await userRepos[i].CreateAsync(user); - var savedOrg = await orgRepos[i].CreateAsync(organization); - - // Invited orgUsers are not associated with an account yet, so they are identified by Email not UserId - if (orgUserStatus == OrganizationUserStatusType.Invited) + orgUser.Type = userType; + orgUser.Status = orgUserStatus; + var permissionsData = new Permissions { ManagePolicies = canManagePolicies }; + orgUser.Permissions = JsonSerializer.Serialize(permissionsData, new JsonSerializerOptions { - orgUser.Email = savedUser.Email; - orgUser.UserId = null; - } - else + PropertyNamingPolicy = JsonNamingPolicy.CamelCase, + }); + + policy.Enabled = policyEnabled; + policy.Type = savedPolicyType; + + var results = new List(); + + foreach (var policyRepo in policyRepos) { - orgUser.UserId = savedUser.Id; + var i = policyRepos.IndexOf(policyRepo); + + // Seed database + var savedUser = await userRepos[i].CreateAsync(user); + var savedOrg = await orgRepos[i].CreateAsync(organization); + + // Invited orgUsers are not associated with an account yet, so they are identified by Email not UserId + if (orgUserStatus == OrganizationUserStatusType.Invited) + { + orgUser.Email = savedUser.Email; + orgUser.UserId = null; + } + else + { + orgUser.UserId = savedUser.Id; + } + + orgUser.OrganizationId = savedOrg.Id; + await orgUserRepos[i].CreateAsync(orgUser); + + if (isProvider) + { + var savedProvider = await providerRepos[i].CreateAsync(provider); + + providerOrganization.OrganizationId = savedOrg.Id; + providerOrganization.ProviderId = savedProvider.Id; + await providerOrgRepos[i].CreateAsync(providerOrganization); + + providerUser.UserId = savedUser.Id; + providerUser.ProviderId = savedProvider.Id; + await providerUserRepos[i].CreateAsync(providerUser); + } + + policy.OrganizationId = savedOrg.Id; + await policyRepo.CreateAsync(policy); + if (suts.Contains(policyRepo)) + { + (policyRepo as EfRepo.BaseEntityFrameworkRepository).ClearChangeTracking(); + } + + var minStatus = includeInvited ? OrganizationUserStatusType.Invited : OrganizationUserStatusType.Accepted; + + // Act + var result = await policyRepo.GetManyByTypeApplicableToUserIdAsync(savedUser.Id, queriedPolicyType, minStatus); + results.Add(result.FirstOrDefault()); } - orgUser.OrganizationId = savedOrg.Id; - await orgUserRepos[i].CreateAsync(orgUser); + // Assert + var distinctItems = results.Distinct(equalityComparer); - if (isProvider) - { - var savedProvider = await providerRepos[i].CreateAsync(provider); - - providerOrganization.OrganizationId = savedOrg.Id; - providerOrganization.ProviderId = savedProvider.Id; - await providerOrgRepos[i].CreateAsync(providerOrganization); - - providerUser.UserId = savedUser.Id; - providerUser.ProviderId = savedProvider.Id; - await providerUserRepos[i].CreateAsync(providerUser); - } - - policy.OrganizationId = savedOrg.Id; - await policyRepo.CreateAsync(policy); - if (suts.Contains(policyRepo)) - { - (policyRepo as EfRepo.BaseEntityFrameworkRepository).ClearChangeTracking(); - } - - var minStatus = includeInvited ? OrganizationUserStatusType.Invited : OrganizationUserStatusType.Accepted; - - // Act - var result = await policyRepo.GetManyByTypeApplicableToUserIdAsync(savedUser.Id, queriedPolicyType, minStatus); - results.Add(result.FirstOrDefault()); + Assert.True(results.All(r => r == null) || + !distinctItems.Skip(1).Any()); } - - // Assert - var distinctItems = results.Distinct(equalityComparer); - - Assert.True(results.All(r => r == null) || - !distinctItems.Skip(1).Any()); } } diff --git a/test/Infrastructure.EFIntegration.Test/Repositories/SendRepositoryTests.cs b/test/Infrastructure.EFIntegration.Test/Repositories/SendRepositoryTests.cs index 628b5562c4..6158be3ee9 100644 --- a/test/Infrastructure.EFIntegration.Test/Repositories/SendRepositoryTests.cs +++ b/test/Infrastructure.EFIntegration.Test/Repositories/SendRepositoryTests.cs @@ -6,59 +6,60 @@ using Xunit; using EfRepo = Bit.Infrastructure.EntityFramework.Repositories; using SqlRepo = Bit.Infrastructure.Dapper.Repositories; -namespace Bit.Infrastructure.EFIntegration.Test.Repositories; - -public class SendRepositoryTests +namespace Bit.Infrastructure.EFIntegration.Test.Repositories { - [CiSkippedTheory, EfUserSendAutoData, EfOrganizationSendAutoData] - public async void CreateAsync_Works_DataMatches( - Send send, - User user, - Organization org, - SendCompare equalityComparer, - List suts, - List efUserRepos, - List efOrgRepos, - SqlRepo.SendRepository sqlSendRepo, - SqlRepo.UserRepository sqlUserRepo, - SqlRepo.OrganizationRepository sqlOrgRepo - ) + public class SendRepositoryTests { - var savedSends = new List(); - foreach (var sut in suts) + [CiSkippedTheory, EfUserSendAutoData, EfOrganizationSendAutoData] + public async void CreateAsync_Works_DataMatches( + Send send, + User user, + Organization org, + SendCompare equalityComparer, + List suts, + List efUserRepos, + List efOrgRepos, + SqlRepo.SendRepository sqlSendRepo, + SqlRepo.UserRepository sqlUserRepo, + SqlRepo.OrganizationRepository sqlOrgRepo + ) { - var i = suts.IndexOf(sut); + var savedSends = new List(); + foreach (var sut in suts) + { + var i = suts.IndexOf(sut); + if (send.OrganizationId.HasValue) + { + var efOrg = await efOrgRepos[i].CreateAsync(org); + sut.ClearChangeTracking(); + send.OrganizationId = efOrg.Id; + } + var efUser = await efUserRepos[i].CreateAsync(user); + sut.ClearChangeTracking(); + + send.UserId = efUser.Id; + var postEfSend = await sut.CreateAsync(send); + sut.ClearChangeTracking(); + + var savedSend = await sut.GetByIdAsync(postEfSend.Id); + savedSends.Add(savedSend); + } + + var sqlUser = await sqlUserRepo.CreateAsync(user); if (send.OrganizationId.HasValue) { - var efOrg = await efOrgRepos[i].CreateAsync(org); - sut.ClearChangeTracking(); - send.OrganizationId = efOrg.Id; + var sqlOrg = await sqlOrgRepo.CreateAsync(org); + send.OrganizationId = sqlOrg.Id; } - var efUser = await efUserRepos[i].CreateAsync(user); - sut.ClearChangeTracking(); - send.UserId = efUser.Id; - var postEfSend = await sut.CreateAsync(send); - sut.ClearChangeTracking(); + send.UserId = sqlUser.Id; + var sqlSend = await sqlSendRepo.CreateAsync(send); + var savedSqlSend = await sqlSendRepo.GetByIdAsync(sqlSend.Id); + savedSends.Add(savedSqlSend); - var savedSend = await sut.GetByIdAsync(postEfSend.Id); - savedSends.Add(savedSend); + var distinctItems = savedSends.Distinct(equalityComparer); + Assert.True(!distinctItems.Skip(1).Any()); } - - var sqlUser = await sqlUserRepo.CreateAsync(user); - if (send.OrganizationId.HasValue) - { - var sqlOrg = await sqlOrgRepo.CreateAsync(org); - send.OrganizationId = sqlOrg.Id; - } - - send.UserId = sqlUser.Id; - var sqlSend = await sqlSendRepo.CreateAsync(send); - var savedSqlSend = await sqlSendRepo.GetByIdAsync(sqlSend.Id); - savedSends.Add(savedSqlSend); - - var distinctItems = savedSends.Distinct(equalityComparer); - Assert.True(!distinctItems.Skip(1).Any()); } } diff --git a/test/Infrastructure.EFIntegration.Test/Repositories/SsoConfigRepositoryTests.cs b/test/Infrastructure.EFIntegration.Test/Repositories/SsoConfigRepositoryTests.cs index 7858bc1f0e..c36c9efb44 100644 --- a/test/Infrastructure.EFIntegration.Test/Repositories/SsoConfigRepositoryTests.cs +++ b/test/Infrastructure.EFIntegration.Test/Repositories/SsoConfigRepositoryTests.cs @@ -6,221 +6,222 @@ using Xunit; using EfRepo = Bit.Infrastructure.EntityFramework.Repositories; using SqlRepo = Bit.Infrastructure.Dapper.Repositories; -namespace Bit.Infrastructure.EFIntegration.Test.Repositories; - -public class SsoConfigRepositoryTests +namespace Bit.Infrastructure.EFIntegration.Test.Repositories { - [CiSkippedTheory, EfSsoConfigAutoData] - public async void CreateAsync_Works_DataMatches(SsoConfig ssoConfig, Organization org, - SsoConfigCompare equalityComparer, List suts, - List efOrgRepos, SqlRepo.SsoConfigRepository sqlSsoConfigRepo, - SqlRepo.OrganizationRepository sqlOrganizationRepo) + public class SsoConfigRepositoryTests { - var savedSsoConfigs = new List(); - - foreach (var sut in suts) + [CiSkippedTheory, EfSsoConfigAutoData] + public async void CreateAsync_Works_DataMatches(SsoConfig ssoConfig, Organization org, + SsoConfigCompare equalityComparer, List suts, + List efOrgRepos, SqlRepo.SsoConfigRepository sqlSsoConfigRepo, + SqlRepo.OrganizationRepository sqlOrganizationRepo) { - var i = suts.IndexOf(sut); + var savedSsoConfigs = new List(); - var savedEfOrg = await efOrgRepos[i].CreateAsync(org); - sut.ClearChangeTracking(); + foreach (var sut in suts) + { + var i = suts.IndexOf(sut); - ssoConfig.OrganizationId = savedEfOrg.Id; - var postEfSsoConfig = await sut.CreateAsync(ssoConfig); - sut.ClearChangeTracking(); + var savedEfOrg = await efOrgRepos[i].CreateAsync(org); + sut.ClearChangeTracking(); - var savedEfSsoConfig = await sut.GetByIdAsync(ssoConfig.Id); - Assert.True(savedEfSsoConfig != null); - savedSsoConfigs.Add(savedEfSsoConfig); + ssoConfig.OrganizationId = savedEfOrg.Id; + var postEfSsoConfig = await sut.CreateAsync(ssoConfig); + sut.ClearChangeTracking(); + + var savedEfSsoConfig = await sut.GetByIdAsync(ssoConfig.Id); + Assert.True(savedEfSsoConfig != null); + savedSsoConfigs.Add(savedEfSsoConfig); + } + + var sqlOrganization = await sqlOrganizationRepo.CreateAsync(org); + ssoConfig.OrganizationId = sqlOrganization.Id; + + var sqlSsoConfig = await sqlSsoConfigRepo.CreateAsync(ssoConfig); + var savedSqlSsoConfig = await sqlSsoConfigRepo.GetByIdAsync(sqlSsoConfig.Id); + Assert.True(savedSqlSsoConfig != null); + savedSsoConfigs.Add(savedSqlSsoConfig); + + var distinctItems = savedSsoConfigs.Distinct(equalityComparer); + Assert.True(!distinctItems.Skip(1).Any()); } - var sqlOrganization = await sqlOrganizationRepo.CreateAsync(org); - ssoConfig.OrganizationId = sqlOrganization.Id; - - var sqlSsoConfig = await sqlSsoConfigRepo.CreateAsync(ssoConfig); - var savedSqlSsoConfig = await sqlSsoConfigRepo.GetByIdAsync(sqlSsoConfig.Id); - Assert.True(savedSqlSsoConfig != null); - savedSsoConfigs.Add(savedSqlSsoConfig); - - var distinctItems = savedSsoConfigs.Distinct(equalityComparer); - Assert.True(!distinctItems.Skip(1).Any()); - } - - [CiSkippedTheory, EfSsoConfigAutoData] - public async void ReplaceAsync_Works_DataMatches(SsoConfig postSsoConfig, SsoConfig replaceSsoConfig, - Organization org, SsoConfigCompare equalityComparer, List suts, - List efOrgRepos, SqlRepo.SsoConfigRepository sqlSsoConfigRepo, - SqlRepo.OrganizationRepository sqlOrganizationRepo) - { - var savedSsoConfigs = new List(); - - foreach (var sut in suts) + [CiSkippedTheory, EfSsoConfigAutoData] + public async void ReplaceAsync_Works_DataMatches(SsoConfig postSsoConfig, SsoConfig replaceSsoConfig, + Organization org, SsoConfigCompare equalityComparer, List suts, + List efOrgRepos, SqlRepo.SsoConfigRepository sqlSsoConfigRepo, + SqlRepo.OrganizationRepository sqlOrganizationRepo) { - var i = suts.IndexOf(sut); + var savedSsoConfigs = new List(); - var savedEfOrg = await efOrgRepos[i].CreateAsync(org); - sut.ClearChangeTracking(); + foreach (var sut in suts) + { + var i = suts.IndexOf(sut); - postSsoConfig.OrganizationId = replaceSsoConfig.OrganizationId = savedEfOrg.Id; - var postEfSsoConfig = await sut.CreateAsync(postSsoConfig); - sut.ClearChangeTracking(); + var savedEfOrg = await efOrgRepos[i].CreateAsync(org); + sut.ClearChangeTracking(); - replaceSsoConfig.Id = postEfSsoConfig.Id; - savedSsoConfigs.Add(postEfSsoConfig); - await sut.ReplaceAsync(replaceSsoConfig); - sut.ClearChangeTracking(); + postSsoConfig.OrganizationId = replaceSsoConfig.OrganizationId = savedEfOrg.Id; + var postEfSsoConfig = await sut.CreateAsync(postSsoConfig); + sut.ClearChangeTracking(); - var replacedSsoConfig = await sut.GetByIdAsync(replaceSsoConfig.Id); - Assert.True(replacedSsoConfig != null); - savedSsoConfigs.Add(replacedSsoConfig); + replaceSsoConfig.Id = postEfSsoConfig.Id; + savedSsoConfigs.Add(postEfSsoConfig); + await sut.ReplaceAsync(replaceSsoConfig); + sut.ClearChangeTracking(); + + var replacedSsoConfig = await sut.GetByIdAsync(replaceSsoConfig.Id); + Assert.True(replacedSsoConfig != null); + savedSsoConfigs.Add(replacedSsoConfig); + } + + var sqlOrganization = await sqlOrganizationRepo.CreateAsync(org); + postSsoConfig.OrganizationId = sqlOrganization.Id; + + var postSqlSsoConfig = await sqlSsoConfigRepo.CreateAsync(postSsoConfig); + replaceSsoConfig.Id = postSqlSsoConfig.Id; + savedSsoConfigs.Add(postSqlSsoConfig); + + await sqlSsoConfigRepo.ReplaceAsync(replaceSsoConfig); + var replacedSqlSsoConfig = await sqlSsoConfigRepo.GetByIdAsync(replaceSsoConfig.Id); + Assert.True(replacedSqlSsoConfig != null); + savedSsoConfigs.Add(replacedSqlSsoConfig); + + var distinctItems = savedSsoConfigs.Distinct(equalityComparer); + Assert.True(!distinctItems.Skip(2).Any()); } - var sqlOrganization = await sqlOrganizationRepo.CreateAsync(org); - postSsoConfig.OrganizationId = sqlOrganization.Id; - - var postSqlSsoConfig = await sqlSsoConfigRepo.CreateAsync(postSsoConfig); - replaceSsoConfig.Id = postSqlSsoConfig.Id; - savedSsoConfigs.Add(postSqlSsoConfig); - - await sqlSsoConfigRepo.ReplaceAsync(replaceSsoConfig); - var replacedSqlSsoConfig = await sqlSsoConfigRepo.GetByIdAsync(replaceSsoConfig.Id); - Assert.True(replacedSqlSsoConfig != null); - savedSsoConfigs.Add(replacedSqlSsoConfig); - - var distinctItems = savedSsoConfigs.Distinct(equalityComparer); - Assert.True(!distinctItems.Skip(2).Any()); - } - - [CiSkippedTheory, EfSsoConfigAutoData] - public async void DeleteAsync_Works_DataMatches(SsoConfig ssoConfig, Organization org, List suts, - List efOrgRepos, SqlRepo.SsoConfigRepository sqlSsoConfigRepo, - SqlRepo.OrganizationRepository sqlOrganizationRepo) - { - foreach (var sut in suts) + [CiSkippedTheory, EfSsoConfigAutoData] + public async void DeleteAsync_Works_DataMatches(SsoConfig ssoConfig, Organization org, List suts, + List efOrgRepos, SqlRepo.SsoConfigRepository sqlSsoConfigRepo, + SqlRepo.OrganizationRepository sqlOrganizationRepo) { - var i = suts.IndexOf(sut); + foreach (var sut in suts) + { + var i = suts.IndexOf(sut); - var savedEfOrg = await efOrgRepos[i].CreateAsync(org); - sut.ClearChangeTracking(); + var savedEfOrg = await efOrgRepos[i].CreateAsync(org); + sut.ClearChangeTracking(); - ssoConfig.OrganizationId = savedEfOrg.Id; - var postEfSsoConfig = await sut.CreateAsync(ssoConfig); - sut.ClearChangeTracking(); + ssoConfig.OrganizationId = savedEfOrg.Id; + var postEfSsoConfig = await sut.CreateAsync(ssoConfig); + sut.ClearChangeTracking(); - var savedEfSsoConfig = await sut.GetByIdAsync(postEfSsoConfig.Id); - Assert.True(savedEfSsoConfig != null); - sut.ClearChangeTracking(); + var savedEfSsoConfig = await sut.GetByIdAsync(postEfSsoConfig.Id); + Assert.True(savedEfSsoConfig != null); + sut.ClearChangeTracking(); - await sut.DeleteAsync(savedEfSsoConfig); - var deletedEfSsoConfig = await sut.GetByIdAsync(savedEfSsoConfig.Id); - Assert.True(deletedEfSsoConfig == null); + await sut.DeleteAsync(savedEfSsoConfig); + var deletedEfSsoConfig = await sut.GetByIdAsync(savedEfSsoConfig.Id); + Assert.True(deletedEfSsoConfig == null); + } + + var sqlOrganization = await sqlOrganizationRepo.CreateAsync(org); + ssoConfig.OrganizationId = sqlOrganization.Id; + + var postSqlSsoConfig = await sqlSsoConfigRepo.CreateAsync(ssoConfig); + var savedSqlSsoConfig = await sqlSsoConfigRepo.GetByIdAsync(postSqlSsoConfig.Id); + Assert.True(savedSqlSsoConfig != null); + + await sqlSsoConfigRepo.DeleteAsync(savedSqlSsoConfig); + savedSqlSsoConfig = await sqlSsoConfigRepo.GetByIdAsync(postSqlSsoConfig.Id); + Assert.True(savedSqlSsoConfig == null); } - var sqlOrganization = await sqlOrganizationRepo.CreateAsync(org); - ssoConfig.OrganizationId = sqlOrganization.Id; - - var postSqlSsoConfig = await sqlSsoConfigRepo.CreateAsync(ssoConfig); - var savedSqlSsoConfig = await sqlSsoConfigRepo.GetByIdAsync(postSqlSsoConfig.Id); - Assert.True(savedSqlSsoConfig != null); - - await sqlSsoConfigRepo.DeleteAsync(savedSqlSsoConfig); - savedSqlSsoConfig = await sqlSsoConfigRepo.GetByIdAsync(postSqlSsoConfig.Id); - Assert.True(savedSqlSsoConfig == null); - } - - [CiSkippedTheory, EfSsoConfigAutoData] - public async void GetByOrganizationIdAsync_Works_DataMatches(SsoConfig ssoConfig, Organization org, - SsoConfigCompare equalityComparer, List suts, - List efOrgRepos, SqlRepo.SsoConfigRepository sqlSsoConfigRepo, - SqlRepo.OrganizationRepository sqlOrgRepo) - { - var returnedList = new List(); - - foreach (var sut in suts) + [CiSkippedTheory, EfSsoConfigAutoData] + public async void GetByOrganizationIdAsync_Works_DataMatches(SsoConfig ssoConfig, Organization org, + SsoConfigCompare equalityComparer, List suts, + List efOrgRepos, SqlRepo.SsoConfigRepository sqlSsoConfigRepo, + SqlRepo.OrganizationRepository sqlOrgRepo) { - var i = suts.IndexOf(sut); + var returnedList = new List(); - var savedEfOrg = await efOrgRepos[i].CreateAsync(org); - sut.ClearChangeTracking(); + foreach (var sut in suts) + { + var i = suts.IndexOf(sut); - ssoConfig.OrganizationId = savedEfOrg.Id; - await sut.CreateAsync(ssoConfig); - sut.ClearChangeTracking(); + var savedEfOrg = await efOrgRepos[i].CreateAsync(org); + sut.ClearChangeTracking(); - var savedEfUser = await sut.GetByOrganizationIdAsync(savedEfOrg.Id); - Assert.True(savedEfUser != null); - returnedList.Add(savedEfUser); + ssoConfig.OrganizationId = savedEfOrg.Id; + await sut.CreateAsync(ssoConfig); + sut.ClearChangeTracking(); + + var savedEfUser = await sut.GetByOrganizationIdAsync(savedEfOrg.Id); + Assert.True(savedEfUser != null); + returnedList.Add(savedEfUser); + } + + var savedSqlOrg = await sqlOrgRepo.CreateAsync(org); + ssoConfig.OrganizationId = savedSqlOrg.Id; + + var postSqlSsoConfig = await sqlSsoConfigRepo.CreateAsync(ssoConfig); + + var savedSqlSsoConfig = await sqlSsoConfigRepo.GetByOrganizationIdAsync(ssoConfig.OrganizationId); + Assert.True(savedSqlSsoConfig != null); + returnedList.Add(savedSqlSsoConfig); + + var distinctItems = returnedList.Distinct(equalityComparer); + Assert.True(!distinctItems.Skip(1).Any()); } - var savedSqlOrg = await sqlOrgRepo.CreateAsync(org); - ssoConfig.OrganizationId = savedSqlOrg.Id; - - var postSqlSsoConfig = await sqlSsoConfigRepo.CreateAsync(ssoConfig); - - var savedSqlSsoConfig = await sqlSsoConfigRepo.GetByOrganizationIdAsync(ssoConfig.OrganizationId); - Assert.True(savedSqlSsoConfig != null); - returnedList.Add(savedSqlSsoConfig); - - var distinctItems = returnedList.Distinct(equalityComparer); - Assert.True(!distinctItems.Skip(1).Any()); - } - - [CiSkippedTheory, EfSsoConfigAutoData] - public async void GetByIdentifierAsync_Works_DataMatches(SsoConfig ssoConfig, Organization org, - SsoConfigCompare equalityComparer, List suts, - List efOrgRepos, SqlRepo.SsoConfigRepository sqlSsoConfigRepo, - SqlRepo.OrganizationRepository sqlOrgRepo) - { - var returnedList = new List(); - - foreach (var sut in suts) + [CiSkippedTheory, EfSsoConfigAutoData] + public async void GetByIdentifierAsync_Works_DataMatches(SsoConfig ssoConfig, Organization org, + SsoConfigCompare equalityComparer, List suts, + List efOrgRepos, SqlRepo.SsoConfigRepository sqlSsoConfigRepo, + SqlRepo.OrganizationRepository sqlOrgRepo) { - var i = suts.IndexOf(sut); + var returnedList = new List(); - var savedEfOrg = await efOrgRepos[i].CreateAsync(org); - sut.ClearChangeTracking(); + foreach (var sut in suts) + { + var i = suts.IndexOf(sut); - ssoConfig.OrganizationId = savedEfOrg.Id; - await sut.CreateAsync(ssoConfig); - sut.ClearChangeTracking(); + var savedEfOrg = await efOrgRepos[i].CreateAsync(org); + sut.ClearChangeTracking(); - var savedEfSsoConfig = await sut.GetByIdentifierAsync(org.Identifier); - Assert.True(savedEfSsoConfig != null); - returnedList.Add(savedEfSsoConfig); + ssoConfig.OrganizationId = savedEfOrg.Id; + await sut.CreateAsync(ssoConfig); + sut.ClearChangeTracking(); + + var savedEfSsoConfig = await sut.GetByIdentifierAsync(org.Identifier); + Assert.True(savedEfSsoConfig != null); + returnedList.Add(savedEfSsoConfig); + } + + var savedSqlOrg = await sqlOrgRepo.CreateAsync(org); + ssoConfig.OrganizationId = savedSqlOrg.Id; + + var postSqlSsoConfig = await sqlSsoConfigRepo.CreateAsync(ssoConfig); + + var savedSqlSsoConfig = await sqlSsoConfigRepo.GetByIdentifierAsync(org.Identifier); + Assert.True(savedSqlSsoConfig != null); + returnedList.Add(savedSqlSsoConfig); + + var distinctItems = returnedList.Distinct(equalityComparer); + Assert.True(!distinctItems.Skip(1).Any()); } - var savedSqlOrg = await sqlOrgRepo.CreateAsync(org); - ssoConfig.OrganizationId = savedSqlOrg.Id; - - var postSqlSsoConfig = await sqlSsoConfigRepo.CreateAsync(ssoConfig); - - var savedSqlSsoConfig = await sqlSsoConfigRepo.GetByIdentifierAsync(org.Identifier); - Assert.True(savedSqlSsoConfig != null); - returnedList.Add(savedSqlSsoConfig); - - var distinctItems = returnedList.Distinct(equalityComparer); - Assert.True(!distinctItems.Skip(1).Any()); - } - - // Testing that data matches here would involve manipulating all SsoConfig records in the db - [CiSkippedTheory, EfSsoConfigAutoData] - public async void GetManyByRevisionNotBeforeDate_Works(SsoConfig ssoConfig, DateTime notBeforeDate, - Organization org, List suts, - List efOrgRepos) - { - foreach (var sut in suts) + // Testing that data matches here would involve manipulating all SsoConfig records in the db + [CiSkippedTheory, EfSsoConfigAutoData] + public async void GetManyByRevisionNotBeforeDate_Works(SsoConfig ssoConfig, DateTime notBeforeDate, + Organization org, List suts, + List efOrgRepos) { - var i = suts.IndexOf(sut); + foreach (var sut in suts) + { + var i = suts.IndexOf(sut); - var savedEfOrg = await efOrgRepos[i].CreateAsync(org); - sut.ClearChangeTracking(); + var savedEfOrg = await efOrgRepos[i].CreateAsync(org); + sut.ClearChangeTracking(); - ssoConfig.OrganizationId = savedEfOrg.Id; - await sut.CreateAsync(ssoConfig); - sut.ClearChangeTracking(); + ssoConfig.OrganizationId = savedEfOrg.Id; + await sut.CreateAsync(ssoConfig); + sut.ClearChangeTracking(); - var returnedEfSsoConfigs = await sut.GetManyByRevisionNotBeforeDate(notBeforeDate); - Assert.True(returnedEfSsoConfigs.All(sc => sc.RevisionDate >= notBeforeDate)); + var returnedEfSsoConfigs = await sut.GetManyByRevisionNotBeforeDate(notBeforeDate); + Assert.True(returnedEfSsoConfigs.All(sc => sc.RevisionDate >= notBeforeDate)); + } } } } diff --git a/test/Infrastructure.EFIntegration.Test/Repositories/SsoUserRepositoryTests.cs b/test/Infrastructure.EFIntegration.Test/Repositories/SsoUserRepositoryTests.cs index bc43a05261..9e9b66eeaa 100644 --- a/test/Infrastructure.EFIntegration.Test/Repositories/SsoUserRepositoryTests.cs +++ b/test/Infrastructure.EFIntegration.Test/Repositories/SsoUserRepositoryTests.cs @@ -6,181 +6,182 @@ using Xunit; using EfRepo = Bit.Infrastructure.EntityFramework.Repositories; using SqlRepo = Bit.Infrastructure.Dapper.Repositories; -namespace Bit.Infrastructure.EFIntegration.Test.Repositories; - -public class SsoUserRepositoryTests +namespace Bit.Infrastructure.EFIntegration.Test.Repositories { - [CiSkippedTheory, EfSsoUserAutoData] - public async void CreateAsync_Works_DataMatches(SsoUser ssoUser, User user, Organization org, - SsoUserCompare equalityComparer, List suts, - List efOrgRepos, List efUserRepos, - SqlRepo.SsoUserRepository sqlSsoUserRepo, SqlRepo.OrganizationRepository sqlOrgRepo, - SqlRepo.UserRepository sqlUserRepo) + public class SsoUserRepositoryTests { - var createdSsoUsers = new List(); - foreach (var sut in suts) + [CiSkippedTheory, EfSsoUserAutoData] + public async void CreateAsync_Works_DataMatches(SsoUser ssoUser, User user, Organization org, + SsoUserCompare equalityComparer, List suts, + List efOrgRepos, List efUserRepos, + SqlRepo.SsoUserRepository sqlSsoUserRepo, SqlRepo.OrganizationRepository sqlOrgRepo, + SqlRepo.UserRepository sqlUserRepo) { - var i = suts.IndexOf(sut); + var createdSsoUsers = new List(); + foreach (var sut in suts) + { + var i = suts.IndexOf(sut); - var efUser = await efUserRepos[i].CreateAsync(user); - var efOrg = await efOrgRepos[i].CreateAsync(org); - sut.ClearChangeTracking(); + var efUser = await efUserRepos[i].CreateAsync(user); + var efOrg = await efOrgRepos[i].CreateAsync(org); + sut.ClearChangeTracking(); - ssoUser.UserId = efUser.Id; - ssoUser.OrganizationId = efOrg.Id; - var postEfSsoUser = await sut.CreateAsync(ssoUser); - sut.ClearChangeTracking(); + ssoUser.UserId = efUser.Id; + ssoUser.OrganizationId = efOrg.Id; + var postEfSsoUser = await sut.CreateAsync(ssoUser); + sut.ClearChangeTracking(); - var savedSsoUser = await sut.GetByIdAsync(ssoUser.Id); - createdSsoUsers.Add(savedSsoUser); + var savedSsoUser = await sut.GetByIdAsync(ssoUser.Id); + createdSsoUsers.Add(savedSsoUser); + } + + var sqlUser = await sqlUserRepo.CreateAsync(user); + var sqlOrganization = await sqlOrgRepo.CreateAsync(org); + + ssoUser.UserId = sqlUser.Id; + ssoUser.OrganizationId = sqlOrganization.Id; + var sqlSsoUser = await sqlSsoUserRepo.CreateAsync(ssoUser); + + createdSsoUsers.Add(await sqlSsoUserRepo.GetByIdAsync(sqlSsoUser.Id)); + + var distinctSsoUsers = createdSsoUsers.Distinct(equalityComparer); + Assert.True(!distinctSsoUsers.Skip(1).Any()); } - var sqlUser = await sqlUserRepo.CreateAsync(user); - var sqlOrganization = await sqlOrgRepo.CreateAsync(org); - - ssoUser.UserId = sqlUser.Id; - ssoUser.OrganizationId = sqlOrganization.Id; - var sqlSsoUser = await sqlSsoUserRepo.CreateAsync(ssoUser); - - createdSsoUsers.Add(await sqlSsoUserRepo.GetByIdAsync(sqlSsoUser.Id)); - - var distinctSsoUsers = createdSsoUsers.Distinct(equalityComparer); - Assert.True(!distinctSsoUsers.Skip(1).Any()); - } - - [CiSkippedTheory, EfSsoUserAutoData] - public async void ReplaceAsync_Works_DataMatches(SsoUser postSsoUser, SsoUser replaceSsoUser, - Organization org, User user, SsoUserCompare equalityComparer, - List suts, List efUserRepos, - List efOrgRepos, SqlRepo.SsoUserRepository sqlSsoUserRepo, - SqlRepo.OrganizationRepository sqlOrgRepo, SqlRepo.UserRepository sqlUserRepo) - { - var savedSsoUsers = new List(); - foreach (var sut in suts) + [CiSkippedTheory, EfSsoUserAutoData] + public async void ReplaceAsync_Works_DataMatches(SsoUser postSsoUser, SsoUser replaceSsoUser, + Organization org, User user, SsoUserCompare equalityComparer, + List suts, List efUserRepos, + List efOrgRepos, SqlRepo.SsoUserRepository sqlSsoUserRepo, + SqlRepo.OrganizationRepository sqlOrgRepo, SqlRepo.UserRepository sqlUserRepo) { - var i = suts.IndexOf(sut); + var savedSsoUsers = new List(); + foreach (var sut in suts) + { + var i = suts.IndexOf(sut); - var efUser = await efUserRepos[i].CreateAsync(user); - var efOrg = await efOrgRepos[i].CreateAsync(org); - sut.ClearChangeTracking(); + var efUser = await efUserRepos[i].CreateAsync(user); + var efOrg = await efOrgRepos[i].CreateAsync(org); + sut.ClearChangeTracking(); - postSsoUser.UserId = efUser.Id; - postSsoUser.OrganizationId = efOrg.Id; - var postEfSsoUser = await sut.CreateAsync(postSsoUser); - sut.ClearChangeTracking(); + postSsoUser.UserId = efUser.Id; + postSsoUser.OrganizationId = efOrg.Id; + var postEfSsoUser = await sut.CreateAsync(postSsoUser); + sut.ClearChangeTracking(); - replaceSsoUser.Id = postEfSsoUser.Id; - replaceSsoUser.UserId = postEfSsoUser.UserId; - replaceSsoUser.OrganizationId = postEfSsoUser.OrganizationId; - await sut.ReplaceAsync(replaceSsoUser); - sut.ClearChangeTracking(); + replaceSsoUser.Id = postEfSsoUser.Id; + replaceSsoUser.UserId = postEfSsoUser.UserId; + replaceSsoUser.OrganizationId = postEfSsoUser.OrganizationId; + await sut.ReplaceAsync(replaceSsoUser); + sut.ClearChangeTracking(); - var replacedSsoUser = await sut.GetByIdAsync(replaceSsoUser.Id); - savedSsoUsers.Add(replacedSsoUser); + var replacedSsoUser = await sut.GetByIdAsync(replaceSsoUser.Id); + savedSsoUsers.Add(replacedSsoUser); + } + + var sqlUser = await sqlUserRepo.CreateAsync(user); + var sqlOrganization = await sqlOrgRepo.CreateAsync(org); + + postSsoUser.UserId = sqlUser.Id; + postSsoUser.OrganizationId = sqlOrganization.Id; + var postSqlSsoUser = await sqlSsoUserRepo.CreateAsync(postSsoUser); + + replaceSsoUser.Id = postSqlSsoUser.Id; + replaceSsoUser.UserId = postSqlSsoUser.UserId; + replaceSsoUser.OrganizationId = postSqlSsoUser.OrganizationId; + await sqlSsoUserRepo.ReplaceAsync(replaceSsoUser); + + savedSsoUsers.Add(await sqlSsoUserRepo.GetByIdAsync(replaceSsoUser.Id)); + + var distinctItems = savedSsoUsers.Distinct(equalityComparer); + Assert.True(!distinctItems.Skip(1).Any()); } - var sqlUser = await sqlUserRepo.CreateAsync(user); - var sqlOrganization = await sqlOrgRepo.CreateAsync(org); - - postSsoUser.UserId = sqlUser.Id; - postSsoUser.OrganizationId = sqlOrganization.Id; - var postSqlSsoUser = await sqlSsoUserRepo.CreateAsync(postSsoUser); - - replaceSsoUser.Id = postSqlSsoUser.Id; - replaceSsoUser.UserId = postSqlSsoUser.UserId; - replaceSsoUser.OrganizationId = postSqlSsoUser.OrganizationId; - await sqlSsoUserRepo.ReplaceAsync(replaceSsoUser); - - savedSsoUsers.Add(await sqlSsoUserRepo.GetByIdAsync(replaceSsoUser.Id)); - - var distinctItems = savedSsoUsers.Distinct(equalityComparer); - Assert.True(!distinctItems.Skip(1).Any()); - } - - [CiSkippedTheory, EfSsoUserAutoData] - public async void DeleteAsync_Works_DataMatches(SsoUser ssoUser, Organization org, User user, List suts, - List efUserRepos, List efOrgRepos, - SqlRepo.SsoUserRepository sqlSsoUserRepo, SqlRepo.UserRepository sqlUserRepo, - SqlRepo.OrganizationRepository sqlOrganizationRepo) - { - foreach (var sut in suts) + [CiSkippedTheory, EfSsoUserAutoData] + public async void DeleteAsync_Works_DataMatches(SsoUser ssoUser, Organization org, User user, List suts, + List efUserRepos, List efOrgRepos, + SqlRepo.SsoUserRepository sqlSsoUserRepo, SqlRepo.UserRepository sqlUserRepo, + SqlRepo.OrganizationRepository sqlOrganizationRepo) { - var i = suts.IndexOf(sut); + foreach (var sut in suts) + { + var i = suts.IndexOf(sut); - var savedEfUser = await efUserRepos[i].CreateAsync(user); - var savedEfOrg = await efOrgRepos[i].CreateAsync(org); - sut.ClearChangeTracking(); + var savedEfUser = await efUserRepos[i].CreateAsync(user); + var savedEfOrg = await efOrgRepos[i].CreateAsync(org); + sut.ClearChangeTracking(); - ssoUser.UserId = savedEfUser.Id; - ssoUser.OrganizationId = savedEfOrg.Id; - var postEfSsoUser = await sut.CreateAsync(ssoUser); - sut.ClearChangeTracking(); + ssoUser.UserId = savedEfUser.Id; + ssoUser.OrganizationId = savedEfOrg.Id; + var postEfSsoUser = await sut.CreateAsync(ssoUser); + sut.ClearChangeTracking(); - var savedEfSsoUser = await sut.GetByIdAsync(postEfSsoUser.Id); - Assert.True(savedEfSsoUser != null); - sut.ClearChangeTracking(); + var savedEfSsoUser = await sut.GetByIdAsync(postEfSsoUser.Id); + Assert.True(savedEfSsoUser != null); + sut.ClearChangeTracking(); - await sut.DeleteAsync(savedEfSsoUser); - savedEfSsoUser = await sut.GetByIdAsync(savedEfSsoUser.Id); - Assert.True(savedEfSsoUser == null); + await sut.DeleteAsync(savedEfSsoUser); + savedEfSsoUser = await sut.GetByIdAsync(savedEfSsoUser.Id); + Assert.True(savedEfSsoUser == null); + } + + var sqlUser = await sqlUserRepo.CreateAsync(user); + var sqlOrganization = await sqlOrganizationRepo.CreateAsync(org); + ssoUser.UserId = sqlUser.Id; + ssoUser.OrganizationId = sqlOrganization.Id; + + var postSqlSsoUser = await sqlSsoUserRepo.CreateAsync(ssoUser); + var savedSqlSsoUser = await sqlSsoUserRepo.GetByIdAsync(postSqlSsoUser.Id); + Assert.True(savedSqlSsoUser != null); + + await sqlSsoUserRepo.DeleteAsync(savedSqlSsoUser); + savedSqlSsoUser = await sqlSsoUserRepo.GetByIdAsync(postSqlSsoUser.Id); + Assert.True(savedSqlSsoUser == null); } - var sqlUser = await sqlUserRepo.CreateAsync(user); - var sqlOrganization = await sqlOrganizationRepo.CreateAsync(org); - ssoUser.UserId = sqlUser.Id; - ssoUser.OrganizationId = sqlOrganization.Id; - - var postSqlSsoUser = await sqlSsoUserRepo.CreateAsync(ssoUser); - var savedSqlSsoUser = await sqlSsoUserRepo.GetByIdAsync(postSqlSsoUser.Id); - Assert.True(savedSqlSsoUser != null); - - await sqlSsoUserRepo.DeleteAsync(savedSqlSsoUser); - savedSqlSsoUser = await sqlSsoUserRepo.GetByIdAsync(postSqlSsoUser.Id); - Assert.True(savedSqlSsoUser == null); - } - - [CiSkippedTheory, EfSsoUserAutoData] - public async void DeleteAsync_UserIdOrganizationId_Works_DataMatches(SsoUser ssoUser, - User user, Organization org, List suts, - List efUserRepos, List efOrgRepos, - SqlRepo.SsoUserRepository sqlSsoUserRepo, SqlRepo.UserRepository sqlUserRepo, SqlRepo.OrganizationRepository sqlOrgRepo - ) - { - foreach (var sut in suts) + [CiSkippedTheory, EfSsoUserAutoData] + public async void DeleteAsync_UserIdOrganizationId_Works_DataMatches(SsoUser ssoUser, + User user, Organization org, List suts, + List efUserRepos, List efOrgRepos, + SqlRepo.SsoUserRepository sqlSsoUserRepo, SqlRepo.UserRepository sqlUserRepo, SqlRepo.OrganizationRepository sqlOrgRepo + ) { - var i = suts.IndexOf(sut); + foreach (var sut in suts) + { + var i = suts.IndexOf(sut); - var savedEfUser = await efUserRepos[i].CreateAsync(user); - var savedEfOrg = await efOrgRepos[i].CreateAsync(org); - sut.ClearChangeTracking(); + var savedEfUser = await efUserRepos[i].CreateAsync(user); + var savedEfOrg = await efOrgRepos[i].CreateAsync(org); + sut.ClearChangeTracking(); - ssoUser.UserId = savedEfUser.Id; - ssoUser.OrganizationId = savedEfOrg.Id; - var postEfSsoUser = await sut.CreateAsync(ssoUser); - sut.ClearChangeTracking(); + ssoUser.UserId = savedEfUser.Id; + ssoUser.OrganizationId = savedEfOrg.Id; + var postEfSsoUser = await sut.CreateAsync(ssoUser); + sut.ClearChangeTracking(); - var savedEfSsoUser = await sut.GetByIdAsync(postEfSsoUser.Id); - Assert.True(savedEfSsoUser != null); - sut.ClearChangeTracking(); + var savedEfSsoUser = await sut.GetByIdAsync(postEfSsoUser.Id); + Assert.True(savedEfSsoUser != null); + sut.ClearChangeTracking(); - await sut.DeleteAsync(savedEfSsoUser.UserId, savedEfSsoUser.OrganizationId); - sut.ClearChangeTracking(); + await sut.DeleteAsync(savedEfSsoUser.UserId, savedEfSsoUser.OrganizationId); + sut.ClearChangeTracking(); - savedEfSsoUser = await sut.GetByIdAsync(savedEfSsoUser.Id); - Assert.True(savedEfSsoUser == null); + savedEfSsoUser = await sut.GetByIdAsync(savedEfSsoUser.Id); + Assert.True(savedEfSsoUser == null); + } + + var sqlUser = await sqlUserRepo.CreateAsync(user); + var sqlOrganization = await sqlOrgRepo.CreateAsync(org); + ssoUser.UserId = sqlUser.Id; + ssoUser.OrganizationId = sqlOrganization.Id; + + var postSqlSsoUser = await sqlSsoUserRepo.CreateAsync(ssoUser); + var savedSqlSsoUser = await sqlSsoUserRepo.GetByIdAsync(postSqlSsoUser.Id); + Assert.True(savedSqlSsoUser != null); + + await sqlSsoUserRepo.DeleteAsync(savedSqlSsoUser.UserId, savedSqlSsoUser.OrganizationId); + savedSqlSsoUser = await sqlSsoUserRepo.GetByIdAsync(postSqlSsoUser.Id); + Assert.True(savedSqlSsoUser == null); } - - var sqlUser = await sqlUserRepo.CreateAsync(user); - var sqlOrganization = await sqlOrgRepo.CreateAsync(org); - ssoUser.UserId = sqlUser.Id; - ssoUser.OrganizationId = sqlOrganization.Id; - - var postSqlSsoUser = await sqlSsoUserRepo.CreateAsync(ssoUser); - var savedSqlSsoUser = await sqlSsoUserRepo.GetByIdAsync(postSqlSsoUser.Id); - Assert.True(savedSqlSsoUser != null); - - await sqlSsoUserRepo.DeleteAsync(savedSqlSsoUser.UserId, savedSqlSsoUser.OrganizationId); - savedSqlSsoUser = await sqlSsoUserRepo.GetByIdAsync(postSqlSsoUser.Id); - Assert.True(savedSqlSsoUser == null); } } diff --git a/test/Infrastructure.EFIntegration.Test/Repositories/TaxRateRepositoryTests.cs b/test/Infrastructure.EFIntegration.Test/Repositories/TaxRateRepositoryTests.cs index 8892f6c70d..d5616f78e5 100644 --- a/test/Infrastructure.EFIntegration.Test/Repositories/TaxRateRepositoryTests.cs +++ b/test/Infrastructure.EFIntegration.Test/Repositories/TaxRateRepositoryTests.cs @@ -6,34 +6,35 @@ using Xunit; using EfRepo = Bit.Infrastructure.EntityFramework.Repositories; using SqlRepo = Bit.Infrastructure.Dapper.Repositories; -namespace Bit.Infrastructure.EFIntegration.Test.Repositories; - -public class TaxRateRepositoryTests +namespace Bit.Infrastructure.EFIntegration.Test.Repositories { - [CiSkippedTheory, EfTaxRateAutoData] - public async void CreateAsync_Works_DataMatches( - TaxRate taxRate, - TaxRateCompare equalityComparer, - List suts, - SqlRepo.TaxRateRepository sqlTaxRateRepo - ) + public class TaxRateRepositoryTests { - var savedTaxRates = new List(); - foreach (var sut in suts) + [CiSkippedTheory, EfTaxRateAutoData] + public async void CreateAsync_Works_DataMatches( + TaxRate taxRate, + TaxRateCompare equalityComparer, + List suts, + SqlRepo.TaxRateRepository sqlTaxRateRepo + ) { - var i = suts.IndexOf(sut); - var postEfTaxRate = await sut.CreateAsync(taxRate); - sut.ClearChangeTracking(); + var savedTaxRates = new List(); + foreach (var sut in suts) + { + var i = suts.IndexOf(sut); + var postEfTaxRate = await sut.CreateAsync(taxRate); + sut.ClearChangeTracking(); - var savedTaxRate = await sut.GetByIdAsync(postEfTaxRate.Id); - savedTaxRates.Add(savedTaxRate); + var savedTaxRate = await sut.GetByIdAsync(postEfTaxRate.Id); + savedTaxRates.Add(savedTaxRate); + } + + var sqlTaxRate = await sqlTaxRateRepo.CreateAsync(taxRate); + var savedSqlTaxRate = await sqlTaxRateRepo.GetByIdAsync(sqlTaxRate.Id); + savedTaxRates.Add(savedSqlTaxRate); + + var distinctItems = savedTaxRates.Distinct(equalityComparer); + Assert.True(!distinctItems.Skip(1).Any()); } - - var sqlTaxRate = await sqlTaxRateRepo.CreateAsync(taxRate); - var savedSqlTaxRate = await sqlTaxRateRepo.GetByIdAsync(sqlTaxRate.Id); - savedTaxRates.Add(savedSqlTaxRate); - - var distinctItems = savedTaxRates.Distinct(equalityComparer); - Assert.True(!distinctItems.Skip(1).Any()); } } diff --git a/test/Infrastructure.EFIntegration.Test/Repositories/TransactionRepositoryTests.cs b/test/Infrastructure.EFIntegration.Test/Repositories/TransactionRepositoryTests.cs index 2f0d2cd8aa..563a0377e3 100644 --- a/test/Infrastructure.EFIntegration.Test/Repositories/TransactionRepositoryTests.cs +++ b/test/Infrastructure.EFIntegration.Test/Repositories/TransactionRepositoryTests.cs @@ -6,58 +6,59 @@ using Xunit; using EfRepo = Bit.Infrastructure.EntityFramework.Repositories; using SqlRepo = Bit.Infrastructure.Dapper.Repositories; -namespace Bit.Infrastructure.EFIntegration.Test.Repositories; - -public class TransactionRepositoryTests +namespace Bit.Infrastructure.EFIntegration.Test.Repositories { - - [CiSkippedTheory, EfUserTransactionAutoData, EfOrganizationTransactionAutoData] - public async void CreateAsync_Works_DataMatches( - Transaction transaction, - User user, - Organization org, - TransactionCompare equalityComparer, - List suts, - List efUserRepos, - List efOrgRepos, - SqlRepo.TransactionRepository sqlTransactionRepo, - SqlRepo.UserRepository sqlUserRepo, - SqlRepo.OrganizationRepository sqlOrgRepo - ) + public class TransactionRepositoryTests { - var savedTransactions = new List(); - foreach (var sut in suts) + + [CiSkippedTheory, EfUserTransactionAutoData, EfOrganizationTransactionAutoData] + public async void CreateAsync_Works_DataMatches( + Transaction transaction, + User user, + Organization org, + TransactionCompare equalityComparer, + List suts, + List efUserRepos, + List efOrgRepos, + SqlRepo.TransactionRepository sqlTransactionRepo, + SqlRepo.UserRepository sqlUserRepo, + SqlRepo.OrganizationRepository sqlOrgRepo + ) { - var i = suts.IndexOf(sut); - var efUser = await efUserRepos[i].CreateAsync(user); + var savedTransactions = new List(); + foreach (var sut in suts) + { + var i = suts.IndexOf(sut); + var efUser = await efUserRepos[i].CreateAsync(user); + if (transaction.OrganizationId.HasValue) + { + var efOrg = await efOrgRepos[i].CreateAsync(org); + transaction.OrganizationId = efOrg.Id; + } + sut.ClearChangeTracking(); + + transaction.UserId = efUser.Id; + var postEfTransaction = await sut.CreateAsync(transaction); + sut.ClearChangeTracking(); + + var savedTransaction = await sut.GetByIdAsync(postEfTransaction.Id); + savedTransactions.Add(savedTransaction); + } + + var sqlUser = await sqlUserRepo.CreateAsync(user); if (transaction.OrganizationId.HasValue) { - var efOrg = await efOrgRepos[i].CreateAsync(org); - transaction.OrganizationId = efOrg.Id; + var sqlOrg = await sqlOrgRepo.CreateAsync(org); + transaction.OrganizationId = sqlOrg.Id; } - sut.ClearChangeTracking(); - transaction.UserId = efUser.Id; - var postEfTransaction = await sut.CreateAsync(transaction); - sut.ClearChangeTracking(); + transaction.UserId = sqlUser.Id; + var sqlTransaction = await sqlTransactionRepo.CreateAsync(transaction); + var savedSqlTransaction = await sqlTransactionRepo.GetByIdAsync(sqlTransaction.Id); + savedTransactions.Add(savedSqlTransaction); - var savedTransaction = await sut.GetByIdAsync(postEfTransaction.Id); - savedTransactions.Add(savedTransaction); + var distinctItems = savedTransactions.Distinct(equalityComparer); + Assert.True(!distinctItems.Skip(1).Any()); } - - var sqlUser = await sqlUserRepo.CreateAsync(user); - if (transaction.OrganizationId.HasValue) - { - var sqlOrg = await sqlOrgRepo.CreateAsync(org); - transaction.OrganizationId = sqlOrg.Id; - } - - transaction.UserId = sqlUser.Id; - var sqlTransaction = await sqlTransactionRepo.CreateAsync(transaction); - var savedSqlTransaction = await sqlTransactionRepo.GetByIdAsync(sqlTransaction.Id); - savedTransactions.Add(savedSqlTransaction); - - var distinctItems = savedTransactions.Distinct(equalityComparer); - Assert.True(!distinctItems.Skip(1).Any()); } } diff --git a/test/Infrastructure.EFIntegration.Test/Repositories/UserRepositoryTests.cs b/test/Infrastructure.EFIntegration.Test/Repositories/UserRepositoryTests.cs index ce04ffdfba..d362cb954d 100644 --- a/test/Infrastructure.EFIntegration.Test/Repositories/UserRepositoryTests.cs +++ b/test/Infrastructure.EFIntegration.Test/Repositories/UserRepositoryTests.cs @@ -7,283 +7,284 @@ using Xunit; using EfRepo = Bit.Infrastructure.EntityFramework.Repositories; using SqlRepo = Bit.Infrastructure.Dapper.Repositories; -namespace Bit.Infrastructure.EFIntegration.Test.Repositories; - -public class UserRepositoryTests +namespace Bit.Infrastructure.EFIntegration.Test.Repositories { - [CiSkippedTheory, EfUserAutoData] - public async void CreateAsync_Works_DataMatches( - User user, UserCompare equalityComparer, - List suts, - SqlRepo.UserRepository sqlUserRepo - ) + public class UserRepositoryTests { - var savedUsers = new List(); - - foreach (var sut in suts) + [CiSkippedTheory, EfUserAutoData] + public async void CreateAsync_Works_DataMatches( + User user, UserCompare equalityComparer, + List suts, + SqlRepo.UserRepository sqlUserRepo + ) { - var postEfUser = await sut.CreateAsync(user); + var savedUsers = new List(); - sut.ClearChangeTracking(); + foreach (var sut in suts) + { + var postEfUser = await sut.CreateAsync(user); - var savedUser = await sut.GetByIdAsync(postEfUser.Id); - savedUsers.Add(savedUser); + sut.ClearChangeTracking(); + + var savedUser = await sut.GetByIdAsync(postEfUser.Id); + savedUsers.Add(savedUser); + } + + var sqlUser = await sqlUserRepo.CreateAsync(user); + savedUsers.Add(await sqlUserRepo.GetByIdAsync(sqlUser.Id)); + + var distinctItems = savedUsers.Distinct(equalityComparer); + Assert.True(!distinctItems.Skip(1).Any()); } - var sqlUser = await sqlUserRepo.CreateAsync(user); - savedUsers.Add(await sqlUserRepo.GetByIdAsync(sqlUser.Id)); - - var distinctItems = savedUsers.Distinct(equalityComparer); - Assert.True(!distinctItems.Skip(1).Any()); - } - - [CiSkippedTheory, EfUserAutoData] - public async void ReplaceAsync_Works_DataMatches(User postUser, User replaceUser, - UserCompare equalityComparer, List suts, - SqlRepo.UserRepository sqlUserRepo) - { - var savedUsers = new List(); - foreach (var sut in suts) + [CiSkippedTheory, EfUserAutoData] + public async void ReplaceAsync_Works_DataMatches(User postUser, User replaceUser, + UserCompare equalityComparer, List suts, + SqlRepo.UserRepository sqlUserRepo) { - var postEfUser = await sut.CreateAsync(postUser); - replaceUser.Id = postEfUser.Id; - await sut.ReplaceAsync(replaceUser); - var replacedUser = await sut.GetByIdAsync(replaceUser.Id); - savedUsers.Add(replacedUser); + var savedUsers = new List(); + foreach (var sut in suts) + { + var postEfUser = await sut.CreateAsync(postUser); + replaceUser.Id = postEfUser.Id; + await sut.ReplaceAsync(replaceUser); + var replacedUser = await sut.GetByIdAsync(replaceUser.Id); + savedUsers.Add(replacedUser); + } + + var postSqlUser = await sqlUserRepo.CreateAsync(postUser); + replaceUser.Id = postSqlUser.Id; + await sqlUserRepo.ReplaceAsync(replaceUser); + savedUsers.Add(await sqlUserRepo.GetByIdAsync(replaceUser.Id)); + + var distinctItems = savedUsers.Distinct(equalityComparer); + Assert.True(!distinctItems.Skip(1).Any()); } - var postSqlUser = await sqlUserRepo.CreateAsync(postUser); - replaceUser.Id = postSqlUser.Id; - await sqlUserRepo.ReplaceAsync(replaceUser); - savedUsers.Add(await sqlUserRepo.GetByIdAsync(replaceUser.Id)); - - var distinctItems = savedUsers.Distinct(equalityComparer); - Assert.True(!distinctItems.Skip(1).Any()); - } - - [CiSkippedTheory, EfUserAutoData] - public async void DeleteAsync_Works_DataMatches(User user, List suts, SqlRepo.UserRepository sqlUserRepo) - { - foreach (var sut in suts) + [CiSkippedTheory, EfUserAutoData] + public async void DeleteAsync_Works_DataMatches(User user, List suts, SqlRepo.UserRepository sqlUserRepo) { - var postEfUser = await sut.CreateAsync(user); - sut.ClearChangeTracking(); + foreach (var sut in suts) + { + var postEfUser = await sut.CreateAsync(user); + sut.ClearChangeTracking(); - var savedEfUser = await sut.GetByIdAsync(postEfUser.Id); - Assert.True(savedEfUser != null); - sut.ClearChangeTracking(); + var savedEfUser = await sut.GetByIdAsync(postEfUser.Id); + Assert.True(savedEfUser != null); + sut.ClearChangeTracking(); - await sut.DeleteAsync(savedEfUser); - sut.ClearChangeTracking(); + await sut.DeleteAsync(savedEfUser); + sut.ClearChangeTracking(); - savedEfUser = await sut.GetByIdAsync(savedEfUser.Id); - Assert.True(savedEfUser == null); + savedEfUser = await sut.GetByIdAsync(savedEfUser.Id); + Assert.True(savedEfUser == null); + } + + var postSqlUser = await sqlUserRepo.CreateAsync(user); + var savedSqlUser = await sqlUserRepo.GetByIdAsync(postSqlUser.Id); + Assert.True(savedSqlUser != null); + + await sqlUserRepo.DeleteAsync(postSqlUser); + savedSqlUser = await sqlUserRepo.GetByIdAsync(postSqlUser.Id); + Assert.True(savedSqlUser == null); } - var postSqlUser = await sqlUserRepo.CreateAsync(user); - var savedSqlUser = await sqlUserRepo.GetByIdAsync(postSqlUser.Id); - Assert.True(savedSqlUser != null); + [CiSkippedTheory, EfUserAutoData] + public async void GetByEmailAsync_Works_DataMatches(User user, UserCompare equalityComparer, + List suts, SqlRepo.UserRepository sqlUserRepo) + { + var savedUsers = new List(); + foreach (var sut in suts) + { + var postEfUser = await sut.CreateAsync(user); + sut.ClearChangeTracking(); + var savedUser = await sut.GetByEmailAsync(postEfUser.Email.ToUpperInvariant()); + savedUsers.Add(savedUser); + } - await sqlUserRepo.DeleteAsync(postSqlUser); - savedSqlUser = await sqlUserRepo.GetByIdAsync(postSqlUser.Id); - Assert.True(savedSqlUser == null); - } + var postSqlUser = await sqlUserRepo.CreateAsync(user); + savedUsers.Add(await sqlUserRepo.GetByEmailAsync(postSqlUser.Email.ToUpperInvariant())); - [CiSkippedTheory, EfUserAutoData] - public async void GetByEmailAsync_Works_DataMatches(User user, UserCompare equalityComparer, + var distinctItems = savedUsers.Distinct(equalityComparer); + Assert.True(!distinctItems.Skip(1).Any()); + } + + [CiSkippedTheory, EfUserAutoData] + public async void GetKdfInformationByEmailAsync_Works_DataMatches(User user, + UserKdfInformationCompare equalityComparer, List suts, + SqlRepo.UserRepository sqlUserRepo) + { + var savedKdfInformation = new List(); + foreach (var sut in suts) + { + var postEfUser = await sut.CreateAsync(user); + sut.ClearChangeTracking(); + var kdfInformation = await sut.GetKdfInformationByEmailAsync(postEfUser.Email.ToUpperInvariant()); + savedKdfInformation.Add(kdfInformation); + } + + var postSqlUser = await sqlUserRepo.CreateAsync(user); + var sqlKdfInformation = await sqlUserRepo.GetKdfInformationByEmailAsync(postSqlUser.Email); + savedKdfInformation.Add(sqlKdfInformation); + + var distinctItems = savedKdfInformation.Distinct(equalityComparer); + Assert.True(!distinctItems.Skip(1).Any()); + } + + [CiSkippedTheory, EfUserAutoData] + public async void SearchAsync_Works_DataMatches(User user, int skip, int take, + UserCompare equalityCompare, List suts, + SqlRepo.UserRepository sqlUserRepo) + { + var searchedEfUsers = new List(); + foreach (var sut in suts) + { + var postEfUser = await sut.CreateAsync(user); + sut.ClearChangeTracking(); + + var searchedEfUsersCollection = await sut.SearchAsync(postEfUser.Email.ToUpperInvariant(), skip, take); + searchedEfUsers.Concat(searchedEfUsersCollection.ToList()); + } + + var postSqlUser = await sqlUserRepo.CreateAsync(user); + var searchedSqlUsers = await sqlUserRepo.SearchAsync(postSqlUser.Email.ToUpperInvariant(), skip, take); + + var distinctItems = searchedEfUsers.Concat(searchedSqlUsers).Distinct(equalityCompare); + Assert.True(!distinctItems.Skip(1).Any()); + } + + [CiSkippedTheory, EfUserAutoData] + public async void GetManyByPremiumAsync_Works_DataMatches(User user, List suts, SqlRepo.UserRepository sqlUserRepo) - { - var savedUsers = new List(); - foreach (var sut in suts) { - var postEfUser = await sut.CreateAsync(user); - sut.ClearChangeTracking(); - var savedUser = await sut.GetByEmailAsync(postEfUser.Email.ToUpperInvariant()); - savedUsers.Add(savedUser); + var returnedUsers = new List(); + foreach (var sut in suts) + { + var postEfUser = await sut.CreateAsync(user); + sut.ClearChangeTracking(); + + var searchedEfUsers = await sut.GetManyByPremiumAsync(user.Premium); + returnedUsers.Concat(searchedEfUsers.ToList()); + } + + var postSqlUser = await sqlUserRepo.CreateAsync(user); + var searchedSqlUsers = await sqlUserRepo.GetManyByPremiumAsync(user.Premium); + returnedUsers.Concat(searchedSqlUsers.ToList()); + + Assert.True(returnedUsers.All(x => x.Premium == user.Premium)); } - var postSqlUser = await sqlUserRepo.CreateAsync(user); - savedUsers.Add(await sqlUserRepo.GetByEmailAsync(postSqlUser.Email.ToUpperInvariant())); - - var distinctItems = savedUsers.Distinct(equalityComparer); - Assert.True(!distinctItems.Skip(1).Any()); - } - - [CiSkippedTheory, EfUserAutoData] - public async void GetKdfInformationByEmailAsync_Works_DataMatches(User user, - UserKdfInformationCompare equalityComparer, List suts, - SqlRepo.UserRepository sqlUserRepo) - { - var savedKdfInformation = new List(); - foreach (var sut in suts) + [CiSkippedTheory, EfUserAutoData] + public async void GetPublicKeyAsync_Works_DataMatches(User user, List suts, + SqlRepo.UserRepository sqlUserRepo) { - var postEfUser = await sut.CreateAsync(user); - sut.ClearChangeTracking(); - var kdfInformation = await sut.GetKdfInformationByEmailAsync(postEfUser.Email.ToUpperInvariant()); - savedKdfInformation.Add(kdfInformation); + var returnedKeys = new List(); + foreach (var sut in suts) + { + var postEfUser = await sut.CreateAsync(user); + sut.ClearChangeTracking(); + + var efKey = await sut.GetPublicKeyAsync(postEfUser.Id); + returnedKeys.Add(efKey); + } + + var postSqlUser = await sqlUserRepo.CreateAsync(user); + var sqlKey = await sqlUserRepo.GetPublicKeyAsync(postSqlUser.Id); + returnedKeys.Add(sqlKey); + + Assert.True(!returnedKeys.Distinct().Skip(1).Any()); } - var postSqlUser = await sqlUserRepo.CreateAsync(user); - var sqlKdfInformation = await sqlUserRepo.GetKdfInformationByEmailAsync(postSqlUser.Email); - savedKdfInformation.Add(sqlKdfInformation); - - var distinctItems = savedKdfInformation.Distinct(equalityComparer); - Assert.True(!distinctItems.Skip(1).Any()); - } - - [CiSkippedTheory, EfUserAutoData] - public async void SearchAsync_Works_DataMatches(User user, int skip, int take, - UserCompare equalityCompare, List suts, - SqlRepo.UserRepository sqlUserRepo) - { - var searchedEfUsers = new List(); - foreach (var sut in suts) + [CiSkippedTheory, EfUserAutoData] + public async void GetAccountRevisionDateAsync(User user, List suts, + SqlRepo.UserRepository sqlUserRepo) { - var postEfUser = await sut.CreateAsync(user); - sut.ClearChangeTracking(); + var returnedKeys = new List(); + foreach (var sut in suts) + { + var postEfUser = await sut.CreateAsync(user); + sut.ClearChangeTracking(); - var searchedEfUsersCollection = await sut.SearchAsync(postEfUser.Email.ToUpperInvariant(), skip, take); - searchedEfUsers.Concat(searchedEfUsersCollection.ToList()); + var efKey = await sut.GetPublicKeyAsync(postEfUser.Id); + returnedKeys.Add(efKey); + } + + var postSqlUser = await sqlUserRepo.CreateAsync(user); + var sqlKey = await sqlUserRepo.GetPublicKeyAsync(postSqlUser.Id); + returnedKeys.Add(sqlKey); + + Assert.True(!returnedKeys.Distinct().Skip(1).Any()); } - var postSqlUser = await sqlUserRepo.CreateAsync(user); - var searchedSqlUsers = await sqlUserRepo.SearchAsync(postSqlUser.Email.ToUpperInvariant(), skip, take); - - var distinctItems = searchedEfUsers.Concat(searchedSqlUsers).Distinct(equalityCompare); - Assert.True(!distinctItems.Skip(1).Any()); - } - - [CiSkippedTheory, EfUserAutoData] - public async void GetManyByPremiumAsync_Works_DataMatches(User user, - List suts, SqlRepo.UserRepository sqlUserRepo) - { - var returnedUsers = new List(); - foreach (var sut in suts) + [CiSkippedTheory, EfUserAutoData] + public async void UpdateRenewalReminderDateAsync_Works_DataMatches(User user, + DateTime updatedReminderDate, List suts, + SqlRepo.UserRepository sqlUserRepo) { - var postEfUser = await sut.CreateAsync(user); - sut.ClearChangeTracking(); + var savedDates = new List(); + foreach (var sut in suts) + { + var postEfUser = user; + postEfUser = await sut.CreateAsync(user); + sut.ClearChangeTracking(); - var searchedEfUsers = await sut.GetManyByPremiumAsync(user.Premium); - returnedUsers.Concat(searchedEfUsers.ToList()); + await sut.UpdateRenewalReminderDateAsync(postEfUser.Id, updatedReminderDate); + sut.ClearChangeTracking(); + + var replacedUser = await sut.GetByIdAsync(postEfUser.Id); + savedDates.Add(replacedUser.RenewalReminderDate); + } + + var postSqlUser = await sqlUserRepo.CreateAsync(user); + await sqlUserRepo.UpdateRenewalReminderDateAsync(postSqlUser.Id, updatedReminderDate); + var replacedSqlUser = await sqlUserRepo.GetByIdAsync(postSqlUser.Id); + savedDates.Add(replacedSqlUser.RenewalReminderDate); + + var distinctItems = savedDates.GroupBy(e => e.ToString()); + Assert.True(!distinctItems.Skip(1).Any() && + savedDates.All(e => e.ToString() == updatedReminderDate.ToString())); } - var postSqlUser = await sqlUserRepo.CreateAsync(user); - var searchedSqlUsers = await sqlUserRepo.GetManyByPremiumAsync(user.Premium); - returnedUsers.Concat(searchedSqlUsers.ToList()); - - Assert.True(returnedUsers.All(x => x.Premium == user.Premium)); - } - - [CiSkippedTheory, EfUserAutoData] - public async void GetPublicKeyAsync_Works_DataMatches(User user, List suts, - SqlRepo.UserRepository sqlUserRepo) - { - var returnedKeys = new List(); - foreach (var sut in suts) + [CiSkippedTheory, EfUserAutoData] + public async void GetBySsoUserAsync_Works_DataMatches(User user, Organization org, + SsoUser ssoUser, UserCompare equalityComparer, List suts, + List ssoUserRepos, List orgRepos, + SqlRepo.UserRepository sqlUserRepo, SqlRepo.SsoUserRepository sqlSsoUserRepo, + SqlRepo.OrganizationRepository sqlOrgRepo) { - var postEfUser = await sut.CreateAsync(user); - sut.ClearChangeTracking(); + var returnedList = new List(); + foreach (var sut in suts) + { + var i = suts.IndexOf(sut); - var efKey = await sut.GetPublicKeyAsync(postEfUser.Id); - returnedKeys.Add(efKey); + var postEfUser = await sut.CreateAsync(user); + sut.ClearChangeTracking(); + + var efOrg = await orgRepos[i].CreateAsync(org); + sut.ClearChangeTracking(); + + ssoUser.UserId = postEfUser.Id; + ssoUser.OrganizationId = efOrg.Id; + var postEfSsoUser = await ssoUserRepos[i].CreateAsync(ssoUser); + sut.ClearChangeTracking(); + + var returnedUser = await sut.GetBySsoUserAsync(postEfSsoUser.ExternalId.ToUpperInvariant(), efOrg.Id); + returnedList.Add(returnedUser); + } + + var sqlUser = await sqlUserRepo.CreateAsync(user); + var sqlOrganization = await sqlOrgRepo.CreateAsync(org); + + ssoUser.UserId = sqlUser.Id; + ssoUser.OrganizationId = sqlOrganization.Id; + var postSqlSsoUser = await sqlSsoUserRepo.CreateAsync(ssoUser); + + var returnedSqlUser = await sqlUserRepo + .GetBySsoUserAsync(postSqlSsoUser.ExternalId, sqlOrganization.Id); + returnedList.Add(returnedSqlUser); + + var distinctItems = returnedList.Distinct(equalityComparer); + Assert.True(!distinctItems.Skip(1).Any()); } - - var postSqlUser = await sqlUserRepo.CreateAsync(user); - var sqlKey = await sqlUserRepo.GetPublicKeyAsync(postSqlUser.Id); - returnedKeys.Add(sqlKey); - - Assert.True(!returnedKeys.Distinct().Skip(1).Any()); - } - - [CiSkippedTheory, EfUserAutoData] - public async void GetAccountRevisionDateAsync(User user, List suts, - SqlRepo.UserRepository sqlUserRepo) - { - var returnedKeys = new List(); - foreach (var sut in suts) - { - var postEfUser = await sut.CreateAsync(user); - sut.ClearChangeTracking(); - - var efKey = await sut.GetPublicKeyAsync(postEfUser.Id); - returnedKeys.Add(efKey); - } - - var postSqlUser = await sqlUserRepo.CreateAsync(user); - var sqlKey = await sqlUserRepo.GetPublicKeyAsync(postSqlUser.Id); - returnedKeys.Add(sqlKey); - - Assert.True(!returnedKeys.Distinct().Skip(1).Any()); - } - - [CiSkippedTheory, EfUserAutoData] - public async void UpdateRenewalReminderDateAsync_Works_DataMatches(User user, - DateTime updatedReminderDate, List suts, - SqlRepo.UserRepository sqlUserRepo) - { - var savedDates = new List(); - foreach (var sut in suts) - { - var postEfUser = user; - postEfUser = await sut.CreateAsync(user); - sut.ClearChangeTracking(); - - await sut.UpdateRenewalReminderDateAsync(postEfUser.Id, updatedReminderDate); - sut.ClearChangeTracking(); - - var replacedUser = await sut.GetByIdAsync(postEfUser.Id); - savedDates.Add(replacedUser.RenewalReminderDate); - } - - var postSqlUser = await sqlUserRepo.CreateAsync(user); - await sqlUserRepo.UpdateRenewalReminderDateAsync(postSqlUser.Id, updatedReminderDate); - var replacedSqlUser = await sqlUserRepo.GetByIdAsync(postSqlUser.Id); - savedDates.Add(replacedSqlUser.RenewalReminderDate); - - var distinctItems = savedDates.GroupBy(e => e.ToString()); - Assert.True(!distinctItems.Skip(1).Any() && - savedDates.All(e => e.ToString() == updatedReminderDate.ToString())); - } - - [CiSkippedTheory, EfUserAutoData] - public async void GetBySsoUserAsync_Works_DataMatches(User user, Organization org, - SsoUser ssoUser, UserCompare equalityComparer, List suts, - List ssoUserRepos, List orgRepos, - SqlRepo.UserRepository sqlUserRepo, SqlRepo.SsoUserRepository sqlSsoUserRepo, - SqlRepo.OrganizationRepository sqlOrgRepo) - { - var returnedList = new List(); - foreach (var sut in suts) - { - var i = suts.IndexOf(sut); - - var postEfUser = await sut.CreateAsync(user); - sut.ClearChangeTracking(); - - var efOrg = await orgRepos[i].CreateAsync(org); - sut.ClearChangeTracking(); - - ssoUser.UserId = postEfUser.Id; - ssoUser.OrganizationId = efOrg.Id; - var postEfSsoUser = await ssoUserRepos[i].CreateAsync(ssoUser); - sut.ClearChangeTracking(); - - var returnedUser = await sut.GetBySsoUserAsync(postEfSsoUser.ExternalId.ToUpperInvariant(), efOrg.Id); - returnedList.Add(returnedUser); - } - - var sqlUser = await sqlUserRepo.CreateAsync(user); - var sqlOrganization = await sqlOrgRepo.CreateAsync(org); - - ssoUser.UserId = sqlUser.Id; - ssoUser.OrganizationId = sqlOrganization.Id; - var postSqlSsoUser = await sqlSsoUserRepo.CreateAsync(ssoUser); - - var returnedSqlUser = await sqlUserRepo - .GetBySsoUserAsync(postSqlSsoUser.ExternalId, sqlOrganization.Id); - returnedList.Add(returnedSqlUser); - - var distinctItems = returnedList.Distinct(equalityComparer); - Assert.True(!distinctItems.Skip(1).Any()); } } diff --git a/test/IntegrationTestCommon/Factories/IdentityApplicationFactory.cs b/test/IntegrationTestCommon/Factories/IdentityApplicationFactory.cs index 501ded6134..0a8741b554 100644 --- a/test/IntegrationTestCommon/Factories/IdentityApplicationFactory.cs +++ b/test/IntegrationTestCommon/Factories/IdentityApplicationFactory.cs @@ -7,39 +7,40 @@ using Bit.Identity; using Bit.Test.Common.Helpers; using Microsoft.AspNetCore.Http; -namespace Bit.IntegrationTestCommon.Factories; - -public class IdentityApplicationFactory : WebApplicationFactoryBase +namespace Bit.IntegrationTestCommon.Factories { - public const string DefaultDeviceIdentifier = "92b9d953-b9b6-4eaf-9d3e-11d57144dfeb"; - - public async Task RegisterAsync(RegisterRequestModel model) + public class IdentityApplicationFactory : WebApplicationFactoryBase { - return await Server.PostAsync("/accounts/register", JsonContent.Create(model)); - } + public const string DefaultDeviceIdentifier = "92b9d953-b9b6-4eaf-9d3e-11d57144dfeb"; - public async Task<(string Token, string RefreshToken)> TokenFromPasswordAsync(string username, - string password, - string deviceIdentifier = DefaultDeviceIdentifier, - string clientId = "web", - DeviceType deviceType = DeviceType.FirefoxBrowser, - string deviceName = "firefox") - { - var context = await Server.PostAsync("/connect/token", new FormUrlEncodedContent(new Dictionary + public async Task RegisterAsync(RegisterRequestModel model) { - { "scope", "api offline_access" }, - { "client_id", clientId }, - { "deviceType", ((int)deviceType).ToString() }, - { "deviceIdentifier", deviceIdentifier }, - { "deviceName", deviceName }, - { "grant_type", "password" }, - { "username", username }, - { "password", password }, - }), context => context.Request.Headers.Add("Auth-Email", CoreHelpers.Base64UrlEncodeString(username))); + return await Server.PostAsync("/accounts/register", JsonContent.Create(model)); + } - using var body = await AssertHelper.AssertResponseTypeIs(context); - var root = body.RootElement; + public async Task<(string Token, string RefreshToken)> TokenFromPasswordAsync(string username, + string password, + string deviceIdentifier = DefaultDeviceIdentifier, + string clientId = "web", + DeviceType deviceType = DeviceType.FirefoxBrowser, + string deviceName = "firefox") + { + var context = await Server.PostAsync("/connect/token", new FormUrlEncodedContent(new Dictionary + { + { "scope", "api offline_access" }, + { "client_id", clientId }, + { "deviceType", ((int)deviceType).ToString() }, + { "deviceIdentifier", deviceIdentifier }, + { "deviceName", deviceName }, + { "grant_type", "password" }, + { "username", username }, + { "password", password }, + }), context => context.Request.Headers.Add("Auth-Email", CoreHelpers.Base64UrlEncodeString(username))); - return (root.GetProperty("access_token").GetString(), root.GetProperty("refresh_token").GetString()); + using var body = await AssertHelper.AssertResponseTypeIs(context); + var root = body.RootElement; + + return (root.GetProperty("access_token").GetString(), root.GetProperty("refresh_token").GetString()); + } } } diff --git a/test/IntegrationTestCommon/Factories/WebApplicationFactoryBase.cs b/test/IntegrationTestCommon/Factories/WebApplicationFactoryBase.cs index 45a1454ae7..04b4c0de48 100644 --- a/test/IntegrationTestCommon/Factories/WebApplicationFactoryBase.cs +++ b/test/IntegrationTestCommon/Factories/WebApplicationFactoryBase.cs @@ -9,102 +9,103 @@ using Microsoft.EntityFrameworkCore; using Microsoft.Extensions.Configuration; using Microsoft.Extensions.DependencyInjection; -namespace Bit.IntegrationTestCommon.Factories; - -public static class FactoryConstants +namespace Bit.IntegrationTestCommon.Factories { - public const string DefaultDatabaseName = "test_database"; - public const string WhitelistedIp = "1.1.1.1"; -} - -public abstract class WebApplicationFactoryBase : WebApplicationFactory - where T : class -{ - /// - /// The database name to use for this instance of the factory. By default it will use a shared database name so all instances will connect to the same database during it's lifetime. - /// - /// - /// This will need to be set BEFORE using the Server property - /// - public string DatabaseName { get; set; } = FactoryConstants.DefaultDatabaseName; - - /// - /// Configure the web host to use an EF in memory database - /// - protected override void ConfigureWebHost(IWebHostBuilder builder) + public static class FactoryConstants { - builder.ConfigureAppConfiguration(c => + public const string DefaultDatabaseName = "test_database"; + public const string WhitelistedIp = "1.1.1.1"; + } + + public abstract class WebApplicationFactoryBase : WebApplicationFactory + where T : class + { + /// + /// The database name to use for this instance of the factory. By default it will use a shared database name so all instances will connect to the same database during it's lifetime. + /// + /// + /// This will need to be set BEFORE using the Server property + /// + public string DatabaseName { get; set; } = FactoryConstants.DefaultDatabaseName; + + /// + /// Configure the web host to use an EF in memory database + /// + protected override void ConfigureWebHost(IWebHostBuilder builder) { - c.SetBasePath(AppContext.BaseDirectory) - .AddJsonFile("appsettings.json") - .AddJsonFile("appsettings.Development.json"); - - c.AddUserSecrets(typeof(Identity.Startup).Assembly, optional: true); - c.AddInMemoryCollection(new Dictionary + builder.ConfigureAppConfiguration(c => { - // Manually insert a EF provider so that ConfigureServices will add EF repositories but we will override - // DbContextOptions to use an in memory database - { "globalSettings:databaseProvider", "postgres" }, - { "globalSettings:postgreSql:connectionString", "Host=localhost;Username=test;Password=test;Database=test" }, + c.SetBasePath(AppContext.BaseDirectory) + .AddJsonFile("appsettings.json") + .AddJsonFile("appsettings.Development.json"); - // Clear the redis connection string for distributed caching, forcing an in-memory implementation - { "globalSettings:redis:connectionString", ""} - }); - }); - - builder.ConfigureTestServices(services => - { - var dbContextOptions = services.First(sd => sd.ServiceType == typeof(DbContextOptions)); - services.Remove(dbContextOptions); - services.AddScoped(_ => - { - return new DbContextOptionsBuilder() - .UseInMemoryDatabase(DatabaseName) - .Options; - }); - - // QUESTION: The normal licensing service should run fine on developer machines but not in CI - // should we have a fork here to leave the normal service for developers? - // TODO: Eventually add the license file to CI - var licensingService = services.First(sd => sd.ServiceType == typeof(ILicensingService)); - services.Remove(licensingService); - services.AddSingleton(); - - // FUTURE CONSIDERATION: Add way to run this self hosted/cloud, for now it is cloud only - var pushRegistrationService = services.First(sd => sd.ServiceType == typeof(IPushRegistrationService)); - services.Remove(pushRegistrationService); - services.AddSingleton(); - - // Even though we are cloud we currently set this up as cloud, we can use the EF/selfhosted service - // instead of using Noop for this service - // TODO: Install and use azurite in CI pipeline - var eventWriteService = services.First(sd => sd.ServiceType == typeof(IEventWriteService)); - services.Remove(eventWriteService); - services.AddSingleton(); - - var eventRepositoryService = services.First(sd => sd.ServiceType == typeof(IEventRepository)); - services.Remove(eventRepositoryService); - services.AddSingleton(); - - // Our Rate limiter works so well that it begins to fail tests unless we carve out - // one whitelisted ip. We should still test the rate limiter though and they should change the Ip - // to something that is NOT whitelisted - services.Configure(options => - { - options.IpWhitelist = new List + c.AddUserSecrets(typeof(Identity.Startup).Assembly, optional: true); + c.AddInMemoryCollection(new Dictionary { - FactoryConstants.WhitelistedIp, - }; + // Manually insert a EF provider so that ConfigureServices will add EF repositories but we will override + // DbContextOptions to use an in memory database + { "globalSettings:databaseProvider", "postgres" }, + { "globalSettings:postgreSql:connectionString", "Host=localhost;Username=test;Password=test;Database=test" }, + + // Clear the redis connection string for distributed caching, forcing an in-memory implementation + { "globalSettings:redis:connectionString", ""} + }); }); - // Fix IP Rate Limiting - services.AddSingleton(); - }); - } + builder.ConfigureTestServices(services => + { + var dbContextOptions = services.First(sd => sd.ServiceType == typeof(DbContextOptions)); + services.Remove(dbContextOptions); + services.AddScoped(_ => + { + return new DbContextOptionsBuilder() + .UseInMemoryDatabase(DatabaseName) + .Options; + }); - public DatabaseContext GetDatabaseContext() - { - var scope = Services.CreateScope(); - return scope.ServiceProvider.GetRequiredService(); + // QUESTION: The normal licensing service should run fine on developer machines but not in CI + // should we have a fork here to leave the normal service for developers? + // TODO: Eventually add the license file to CI + var licensingService = services.First(sd => sd.ServiceType == typeof(ILicensingService)); + services.Remove(licensingService); + services.AddSingleton(); + + // FUTURE CONSIDERATION: Add way to run this self hosted/cloud, for now it is cloud only + var pushRegistrationService = services.First(sd => sd.ServiceType == typeof(IPushRegistrationService)); + services.Remove(pushRegistrationService); + services.AddSingleton(); + + // Even though we are cloud we currently set this up as cloud, we can use the EF/selfhosted service + // instead of using Noop for this service + // TODO: Install and use azurite in CI pipeline + var eventWriteService = services.First(sd => sd.ServiceType == typeof(IEventWriteService)); + services.Remove(eventWriteService); + services.AddSingleton(); + + var eventRepositoryService = services.First(sd => sd.ServiceType == typeof(IEventRepository)); + services.Remove(eventRepositoryService); + services.AddSingleton(); + + // Our Rate limiter works so well that it begins to fail tests unless we carve out + // one whitelisted ip. We should still test the rate limiter though and they should change the Ip + // to something that is NOT whitelisted + services.Configure(options => + { + options.IpWhitelist = new List + { + FactoryConstants.WhitelistedIp, + }; + }); + + // Fix IP Rate Limiting + services.AddSingleton(); + }); + } + + public DatabaseContext GetDatabaseContext() + { + var scope = Services.CreateScope(); + return scope.ServiceProvider.GetRequiredService(); + } } } diff --git a/test/IntegrationTestCommon/Factories/WebApplicationFactoryExtensions.cs b/test/IntegrationTestCommon/Factories/WebApplicationFactoryExtensions.cs index 88fc21006d..ed428a772a 100644 --- a/test/IntegrationTestCommon/Factories/WebApplicationFactoryExtensions.cs +++ b/test/IntegrationTestCommon/Factories/WebApplicationFactoryExtensions.cs @@ -4,64 +4,65 @@ using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.TestHost; using Microsoft.Extensions.Primitives; -namespace Bit.IntegrationTestCommon.Factories; - -public static class WebApplicationFactoryExtensions +namespace Bit.IntegrationTestCommon.Factories { - private static async Task SendAsync(this TestServer server, - HttpMethod method, - string requestUri, - HttpContent content = null, - Action extraConfiguration = null) + public static class WebApplicationFactoryExtensions { - return await server.SendAsync(httpContext => + private static async Task SendAsync(this TestServer server, + HttpMethod method, + string requestUri, + HttpContent content = null, + Action extraConfiguration = null) { - // Automatically set the whitelisted IP so normal tests do not run into rate limit issues - // to test rate limiter, use the extraConfiguration parameter to set Connection.RemoteIpAddress - // it runs after this so it will take precedence. - httpContext.Connection.RemoteIpAddress = IPAddress.Parse(FactoryConstants.WhitelistedIp); - - httpContext.Request.Path = new PathString(requestUri); - httpContext.Request.Method = method.Method; - - if (content != null) + return await server.SendAsync(httpContext => { - foreach (var header in content.Headers) + // Automatically set the whitelisted IP so normal tests do not run into rate limit issues + // to test rate limiter, use the extraConfiguration parameter to set Connection.RemoteIpAddress + // it runs after this so it will take precedence. + httpContext.Connection.RemoteIpAddress = IPAddress.Parse(FactoryConstants.WhitelistedIp); + + httpContext.Request.Path = new PathString(requestUri); + httpContext.Request.Method = method.Method; + + if (content != null) { - httpContext.Request.Headers.Add(header.Key, new StringValues(header.Value.ToArray())); + foreach (var header in content.Headers) + { + httpContext.Request.Headers.Add(header.Key, new StringValues(header.Value.ToArray())); + } + + httpContext.Request.Body = content.ReadAsStream(); } - httpContext.Request.Body = content.ReadAsStream(); - } + extraConfiguration?.Invoke(httpContext); + }); + } + public static Task PostAsync(this TestServer server, + string requestUri, + HttpContent content, + Action extraConfiguration = null) + => SendAsync(server, HttpMethod.Post, requestUri, content, extraConfiguration); + public static Task GetAsync(this TestServer server, + string requestUri, + Action extraConfiguration = null) + => SendAsync(server, HttpMethod.Get, requestUri, content: null, extraConfiguration); - extraConfiguration?.Invoke(httpContext); - }); - } - public static Task PostAsync(this TestServer server, - string requestUri, - HttpContent content, - Action extraConfiguration = null) - => SendAsync(server, HttpMethod.Post, requestUri, content, extraConfiguration); - public static Task GetAsync(this TestServer server, - string requestUri, - Action extraConfiguration = null) - => SendAsync(server, HttpMethod.Get, requestUri, content: null, extraConfiguration); + public static HttpContext SetAuthEmail(this HttpContext context, string username) + { + context.Request.Headers.Add("Auth-Email", CoreHelpers.Base64UrlEncodeString(username)); + return context; + } - public static HttpContext SetAuthEmail(this HttpContext context, string username) - { - context.Request.Headers.Add("Auth-Email", CoreHelpers.Base64UrlEncodeString(username)); - return context; - } + public static HttpContext SetIp(this HttpContext context, string ip) + { + context.Connection.RemoteIpAddress = IPAddress.Parse(ip); + return context; + } - public static HttpContext SetIp(this HttpContext context, string ip) - { - context.Connection.RemoteIpAddress = IPAddress.Parse(ip); - return context; - } - - public static async Task ReadBodyAsStringAsync(this HttpContext context) - { - using var sr = new StreamReader(context.Response.Body); - return await sr.ReadToEndAsync(); + public static async Task ReadBodyAsStringAsync(this HttpContext context) + { + using var sr = new StreamReader(context.Response.Body); + return await sr.ReadToEndAsync(); + } } } diff --git a/util/EfShared/MigrationBuilderExtensions.cs b/util/EfShared/MigrationBuilderExtensions.cs index cb9fad33c7..aa8fc04a84 100644 --- a/util/EfShared/MigrationBuilderExtensions.cs +++ b/util/EfShared/MigrationBuilderExtensions.cs @@ -4,28 +4,29 @@ using System.Runtime.CompilerServices; using Bit.Core.Utilities; using Microsoft.EntityFrameworkCore.Migrations; -namespace Bit; - -// This file is a manual addition to a project that it helps, a project that chooses to compile it -// should have a projet reference to Core.csproj and a package reference to Microsoft.EntityFrameworkCore.Design -// The reason for this is that if it belonged to it's own library you would have to add manual references to the above -// and manage the version for the EntityFrameworkCore package. This way it also doesn't create another dll -// To include this you can view examples in the MySqlMigrations and PostgresMigrations .csproj files. -// - -public static class MigrationBuilderExtensions +namespace Bit { - /// - /// Reads an embedded resource for it's SQL contents and formats it with the specified direction for easier custom migration steps - /// - /// The MigrationBuilder instance the sql should be applied to - /// The file name portion of the resource name, it is assumed to be in a Scripts folder - /// The direction of the migration taking place - public static void SqlResource(this MigrationBuilder migrationBuilder, string resourceName, [CallerMemberName] string dir = null) - { - var formattedResourceName = string.IsNullOrEmpty(dir) ? resourceName : string.Format(resourceName, dir); + // This file is a manual addition to a project that it helps, a project that chooses to compile it + // should have a projet reference to Core.csproj and a package reference to Microsoft.EntityFrameworkCore.Design + // The reason for this is that if it belonged to it's own library you would have to add manual references to the above + // and manage the version for the EntityFrameworkCore package. This way it also doesn't create another dll + // To include this you can view examples in the MySqlMigrations and PostgresMigrations .csproj files. + // - migrationBuilder.Sql(CoreHelpers.GetEmbeddedResourceContentsAsync( - $"Scripts.{formattedResourceName}")); + public static class MigrationBuilderExtensions + { + /// + /// Reads an embedded resource for it's SQL contents and formats it with the specified direction for easier custom migration steps + /// + /// The MigrationBuilder instance the sql should be applied to + /// The file name portion of the resource name, it is assumed to be in a Scripts folder + /// The direction of the migration taking place + public static void SqlResource(this MigrationBuilder migrationBuilder, string resourceName, [CallerMemberName] string dir = null) + { + var formattedResourceName = string.IsNullOrEmpty(dir) ? resourceName : string.Format(resourceName, dir); + + migrationBuilder.Sql(CoreHelpers.GetEmbeddedResourceContentsAsync( + $"Scripts.{formattedResourceName}")); + } } } diff --git a/util/Migrator/DbMigrator.cs b/util/Migrator/DbMigrator.cs index ad62691fc3..d0463a00ee 100644 --- a/util/Migrator/DbMigrator.cs +++ b/util/Migrator/DbMigrator.cs @@ -5,103 +5,104 @@ using Bit.Core; using DbUp; using Microsoft.Extensions.Logging; -namespace Bit.Migrator; - -public class DbMigrator +namespace Bit.Migrator { - private readonly string _connectionString; - private readonly ILogger _logger; - private readonly string _masterConnectionString; - - public DbMigrator(string connectionString, ILogger logger) + public class DbMigrator { - _connectionString = connectionString; - _logger = logger; - _masterConnectionString = new SqlConnectionStringBuilder(connectionString) - { - InitialCatalog = "master" - }.ConnectionString; - } + private readonly string _connectionString; + private readonly ILogger _logger; + private readonly string _masterConnectionString; - public bool MigrateMsSqlDatabase(bool enableLogging = true, - CancellationToken cancellationToken = default(CancellationToken)) - { - if (enableLogging && _logger != null) + public DbMigrator(string connectionString, ILogger logger) { - _logger.LogInformation(Constants.BypassFiltersEventId, "Migrating database."); + _connectionString = connectionString; + _logger = logger; + _masterConnectionString = new SqlConnectionStringBuilder(connectionString) + { + InitialCatalog = "master" + }.ConnectionString; } - using (var connection = new SqlConnection(_masterConnectionString)) + public bool MigrateMsSqlDatabase(bool enableLogging = true, + CancellationToken cancellationToken = default(CancellationToken)) { - var databaseName = new SqlConnectionStringBuilder(_connectionString).InitialCatalog; - if (string.IsNullOrWhiteSpace(databaseName)) + if (enableLogging && _logger != null) { - databaseName = "vault"; + _logger.LogInformation(Constants.BypassFiltersEventId, "Migrating database."); } - var databaseNameQuoted = new SqlCommandBuilder().QuoteIdentifier(databaseName); - var command = new SqlCommand( - "IF ((SELECT COUNT(1) FROM sys.databases WHERE [name] = @DatabaseName) = 0) " + - "CREATE DATABASE " + databaseNameQuoted + ";", connection); - command.Parameters.Add("@DatabaseName", SqlDbType.VarChar).Value = databaseName; - command.Connection.Open(); - command.ExecuteNonQuery(); + using (var connection = new SqlConnection(_masterConnectionString)) + { + var databaseName = new SqlConnectionStringBuilder(_connectionString).InitialCatalog; + if (string.IsNullOrWhiteSpace(databaseName)) + { + databaseName = "vault"; + } - command.CommandText = "IF ((SELECT DATABASEPROPERTYEX([name], 'IsAutoClose') " + - "FROM sys.databases WHERE [name] = @DatabaseName) = 1) " + - "ALTER DATABASE " + databaseNameQuoted + " SET AUTO_CLOSE OFF;"; - command.ExecuteNonQuery(); + var databaseNameQuoted = new SqlCommandBuilder().QuoteIdentifier(databaseName); + var command = new SqlCommand( + "IF ((SELECT COUNT(1) FROM sys.databases WHERE [name] = @DatabaseName) = 0) " + + "CREATE DATABASE " + databaseNameQuoted + ";", connection); + command.Parameters.Add("@DatabaseName", SqlDbType.VarChar).Value = databaseName; + command.Connection.Open(); + command.ExecuteNonQuery(); + + command.CommandText = "IF ((SELECT DATABASEPROPERTYEX([name], 'IsAutoClose') " + + "FROM sys.databases WHERE [name] = @DatabaseName) = 1) " + + "ALTER DATABASE " + databaseNameQuoted + " SET AUTO_CLOSE OFF;"; + command.ExecuteNonQuery(); + } + + cancellationToken.ThrowIfCancellationRequested(); + using (var connection = new SqlConnection(_connectionString)) + { + // Rename old migration scripts to new namespace. + var command = new SqlCommand( + "IF OBJECT_ID('Migration','U') IS NOT NULL " + + "UPDATE [dbo].[Migration] SET " + + "[ScriptName] = REPLACE([ScriptName], 'Bit.Setup.', 'Bit.Migrator.');", connection); + command.Connection.Open(); + command.ExecuteNonQuery(); + } + + cancellationToken.ThrowIfCancellationRequested(); + var builder = DeployChanges.To + .SqlDatabase(_connectionString) + .JournalToSqlTable("dbo", "Migration") + .WithScriptsAndCodeEmbeddedInAssembly(Assembly.GetExecutingAssembly(), + s => s.Contains($".DbScripts.") && !s.Contains(".Archive.")) + .WithTransaction() + .WithExecutionTimeout(new TimeSpan(0, 5, 0)); + + if (enableLogging) + { + if (_logger != null) + { + builder.LogTo(new DbUpLogger(_logger)); + } + else + { + builder.LogToConsole(); + } + } + + var upgrader = builder.Build(); + var result = upgrader.PerformUpgrade(); + + if (enableLogging && _logger != null) + { + if (result.Successful) + { + _logger.LogInformation(Constants.BypassFiltersEventId, "Migration successful."); + } + else + { + _logger.LogError(Constants.BypassFiltersEventId, result.Error, "Migration failed."); + } + } + + cancellationToken.ThrowIfCancellationRequested(); + return result.Successful; } - - cancellationToken.ThrowIfCancellationRequested(); - using (var connection = new SqlConnection(_connectionString)) - { - // Rename old migration scripts to new namespace. - var command = new SqlCommand( - "IF OBJECT_ID('Migration','U') IS NOT NULL " + - "UPDATE [dbo].[Migration] SET " + - "[ScriptName] = REPLACE([ScriptName], 'Bit.Setup.', 'Bit.Migrator.');", connection); - command.Connection.Open(); - command.ExecuteNonQuery(); - } - - cancellationToken.ThrowIfCancellationRequested(); - var builder = DeployChanges.To - .SqlDatabase(_connectionString) - .JournalToSqlTable("dbo", "Migration") - .WithScriptsAndCodeEmbeddedInAssembly(Assembly.GetExecutingAssembly(), - s => s.Contains($".DbScripts.") && !s.Contains(".Archive.")) - .WithTransaction() - .WithExecutionTimeout(new TimeSpan(0, 5, 0)); - - if (enableLogging) - { - if (_logger != null) - { - builder.LogTo(new DbUpLogger(_logger)); - } - else - { - builder.LogToConsole(); - } - } - - var upgrader = builder.Build(); - var result = upgrader.PerformUpgrade(); - - if (enableLogging && _logger != null) - { - if (result.Successful) - { - _logger.LogInformation(Constants.BypassFiltersEventId, "Migration successful."); - } - else - { - _logger.LogError(Constants.BypassFiltersEventId, result.Error, "Migration failed."); - } - } - - cancellationToken.ThrowIfCancellationRequested(); - return result.Successful; } } diff --git a/util/Migrator/DbUpLogger.cs b/util/Migrator/DbUpLogger.cs index a65b3ec0ed..1c1707d214 100644 --- a/util/Migrator/DbUpLogger.cs +++ b/util/Migrator/DbUpLogger.cs @@ -2,29 +2,30 @@ using DbUp.Engine.Output; using Microsoft.Extensions.Logging; -namespace Bit.Migrator; - -public class DbUpLogger : IUpgradeLog +namespace Bit.Migrator { - private readonly ILogger _logger; - - public DbUpLogger(ILogger logger) + public class DbUpLogger : IUpgradeLog { - _logger = logger; - } + private readonly ILogger _logger; - public void WriteError(string format, params object[] args) - { - _logger.LogError(Constants.BypassFiltersEventId, format, args); - } + public DbUpLogger(ILogger logger) + { + _logger = logger; + } - public void WriteInformation(string format, params object[] args) - { - _logger.LogInformation(Constants.BypassFiltersEventId, format, args); - } + public void WriteError(string format, params object[] args) + { + _logger.LogError(Constants.BypassFiltersEventId, format, args); + } - public void WriteWarning(string format, params object[] args) - { - _logger.LogWarning(Constants.BypassFiltersEventId, format, args); + public void WriteInformation(string format, params object[] args) + { + _logger.LogInformation(Constants.BypassFiltersEventId, format, args); + } + + public void WriteWarning(string format, params object[] args) + { + _logger.LogWarning(Constants.BypassFiltersEventId, format, args); + } } } diff --git a/util/MySqlMigrations/Factories.cs b/util/MySqlMigrations/Factories.cs index 538c39612f..734b88dd8c 100644 --- a/util/MySqlMigrations/Factories.cs +++ b/util/MySqlMigrations/Factories.cs @@ -4,34 +4,35 @@ using Microsoft.EntityFrameworkCore; using Microsoft.EntityFrameworkCore.Design; using Microsoft.Extensions.Configuration; -namespace MySqlMigrations; - -public static class GlobalSettingsFactory +namespace MySqlMigrations { - public static GlobalSettings GlobalSettings { get; } = new GlobalSettings(); - static GlobalSettingsFactory() + public static class GlobalSettingsFactory { - var configBuilder = new ConfigurationBuilder().AddUserSecrets(); - var Configuration = configBuilder.Build(); - ConfigurationBinder.Bind(Configuration.GetSection("GlobalSettings"), GlobalSettings); - } -} - -public class DatabaseContextFactory : IDesignTimeDbContextFactory -{ - public DatabaseContext CreateDbContext(string[] args) - { - var globalSettings = GlobalSettingsFactory.GlobalSettings; - var optionsBuilder = new DbContextOptionsBuilder(); - var connectionString = globalSettings.MySql?.ConnectionString; - if (string.IsNullOrWhiteSpace(connectionString)) + public static GlobalSettings GlobalSettings { get; } = new GlobalSettings(); + static GlobalSettingsFactory() { - throw new Exception("No MySql connection string found."); + var configBuilder = new ConfigurationBuilder().AddUserSecrets(); + var Configuration = configBuilder.Build(); + ConfigurationBinder.Bind(Configuration.GetSection("GlobalSettings"), GlobalSettings); + } + } + + public class DatabaseContextFactory : IDesignTimeDbContextFactory + { + public DatabaseContext CreateDbContext(string[] args) + { + var globalSettings = GlobalSettingsFactory.GlobalSettings; + var optionsBuilder = new DbContextOptionsBuilder(); + var connectionString = globalSettings.MySql?.ConnectionString; + if (string.IsNullOrWhiteSpace(connectionString)) + { + throw new Exception("No MySql connection string found."); + } + optionsBuilder.UseMySql( + connectionString, + ServerVersion.AutoDetect(connectionString), + b => b.MigrationsAssembly("MySqlMigrations")); + return new DatabaseContext(optionsBuilder.Options); } - optionsBuilder.UseMySql( - connectionString, - ServerVersion.AutoDetect(connectionString), - b => b.MigrationsAssembly("MySqlMigrations")); - return new DatabaseContext(optionsBuilder.Options); } } diff --git a/util/MySqlMigrations/Migrations/20210617183900_Init.cs b/util/MySqlMigrations/Migrations/20210617183900_Init.cs index d85ad6a1ec..859091b726 100644 --- a/util/MySqlMigrations/Migrations/20210617183900_Init.cs +++ b/util/MySqlMigrations/Migrations/20210617183900_Init.cs @@ -1,1128 +1,1129 @@ using Microsoft.EntityFrameworkCore.Metadata; using Microsoft.EntityFrameworkCore.Migrations; -namespace Bit.MySqlMigrations.Migrations; - -public partial class Init : Migration +namespace Bit.MySqlMigrations.Migrations { - protected override void Up(MigrationBuilder migrationBuilder) + public partial class Init : Migration { - migrationBuilder.AlterDatabase() - .Annotation("MySql:CharSet", "utf8mb4"); - - migrationBuilder.CreateTable( - name: "Event", - columns: table => new - { - Id = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), - Date = table.Column(type: "datetime(6)", nullable: false), - Type = table.Column(type: "int", nullable: false), - UserId = table.Column(type: "char(36)", nullable: true, collation: "ascii_general_ci"), - OrganizationId = table.Column(type: "char(36)", nullable: true, collation: "ascii_general_ci"), - CipherId = table.Column(type: "char(36)", nullable: true, collation: "ascii_general_ci"), - CollectionId = table.Column(type: "char(36)", nullable: true, collation: "ascii_general_ci"), - PolicyId = table.Column(type: "char(36)", nullable: true, collation: "ascii_general_ci"), - GroupId = table.Column(type: "char(36)", nullable: true, collation: "ascii_general_ci"), - OrganizationUserId = table.Column(type: "char(36)", nullable: true, collation: "ascii_general_ci"), - DeviceType = table.Column(type: "tinyint unsigned", nullable: true), - IpAddress = table.Column(type: "varchar(50)", maxLength: 50, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - ActingUserId = table.Column(type: "char(36)", nullable: true, collation: "ascii_general_ci") - }, - constraints: table => - { - table.PrimaryKey("PK_Event", x => x.Id); - }) - .Annotation("MySql:CharSet", "utf8mb4"); - - migrationBuilder.CreateTable( - name: "Grant", - columns: table => new - { - Key = table.Column(type: "varchar(200)", maxLength: 200, nullable: false) - .Annotation("MySql:CharSet", "utf8mb4"), - Type = table.Column(type: "varchar(50)", maxLength: 50, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - SubjectId = table.Column(type: "varchar(200)", maxLength: 200, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - SessionId = table.Column(type: "varchar(100)", maxLength: 100, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - ClientId = table.Column(type: "varchar(200)", maxLength: 200, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - Description = table.Column(type: "varchar(200)", maxLength: 200, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - CreationDate = table.Column(type: "datetime(6)", nullable: false), - ExpirationDate = table.Column(type: "datetime(6)", nullable: true), - ConsumedDate = table.Column(type: "datetime(6)", nullable: true), - Data = table.Column(type: "longtext", nullable: true) - .Annotation("MySql:CharSet", "utf8mb4") - }, - constraints: table => - { - table.PrimaryKey("PK_Grant", x => x.Key); - }) - .Annotation("MySql:CharSet", "utf8mb4"); - - migrationBuilder.CreateTable( - name: "Installation", - columns: table => new - { - Id = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), - Email = table.Column(type: "varchar(256)", maxLength: 256, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - Key = table.Column(type: "varchar(150)", maxLength: 150, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - Enabled = table.Column(type: "tinyint(1)", nullable: false), - CreationDate = table.Column(type: "datetime(6)", nullable: false) - }, - constraints: table => - { - table.PrimaryKey("PK_Installation", x => x.Id); - }) - .Annotation("MySql:CharSet", "utf8mb4"); - - migrationBuilder.CreateTable( - name: "Organization", - columns: table => new - { - Id = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), - Identifier = table.Column(type: "varchar(50)", maxLength: 50, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - Name = table.Column(type: "varchar(50)", maxLength: 50, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - BusinessName = table.Column(type: "varchar(50)", maxLength: 50, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - BusinessAddress1 = table.Column(type: "varchar(50)", maxLength: 50, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - BusinessAddress2 = table.Column(type: "varchar(50)", maxLength: 50, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - BusinessAddress3 = table.Column(type: "varchar(50)", maxLength: 50, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - BusinessCountry = table.Column(type: "varchar(2)", maxLength: 2, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - BusinessTaxNumber = table.Column(type: "varchar(30)", maxLength: 30, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - BillingEmail = table.Column(type: "varchar(256)", maxLength: 256, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - Plan = table.Column(type: "varchar(50)", maxLength: 50, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - PlanType = table.Column(type: "tinyint unsigned", nullable: false), - Seats = table.Column(type: "int", nullable: true), - MaxCollections = table.Column(type: "smallint", nullable: true), - UsePolicies = table.Column(type: "tinyint(1)", nullable: false), - UseSso = table.Column(type: "tinyint(1)", nullable: false), - UseGroups = table.Column(type: "tinyint(1)", nullable: false), - UseDirectory = table.Column(type: "tinyint(1)", nullable: false), - UseEvents = table.Column(type: "tinyint(1)", nullable: false), - UseTotp = table.Column(type: "tinyint(1)", nullable: false), - Use2fa = table.Column(type: "tinyint(1)", nullable: false), - UseApi = table.Column(type: "tinyint(1)", nullable: false), - UseResetPassword = table.Column(type: "tinyint(1)", nullable: false), - SelfHost = table.Column(type: "tinyint(1)", nullable: false), - UsersGetPremium = table.Column(type: "tinyint(1)", nullable: false), - Storage = table.Column(type: "bigint", nullable: true), - MaxStorageGb = table.Column(type: "smallint", nullable: true), - Gateway = table.Column(type: "tinyint unsigned", nullable: true), - GatewayCustomerId = table.Column(type: "varchar(50)", maxLength: 50, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - GatewaySubscriptionId = table.Column(type: "varchar(50)", maxLength: 50, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - ReferenceData = table.Column(type: "longtext", nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - Enabled = table.Column(type: "tinyint(1)", nullable: false), - LicenseKey = table.Column(type: "varchar(100)", maxLength: 100, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - ApiKey = table.Column(type: "varchar(30)", maxLength: 30, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - PublicKey = table.Column(type: "longtext", nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - PrivateKey = table.Column(type: "longtext", nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - TwoFactorProviders = table.Column(type: "longtext", nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - ExpirationDate = table.Column(type: "datetime(6)", nullable: true), - CreationDate = table.Column(type: "datetime(6)", nullable: false), - RevisionDate = table.Column(type: "datetime(6)", nullable: false) - }, - constraints: table => - { - table.PrimaryKey("PK_Organization", x => x.Id); - }) - .Annotation("MySql:CharSet", "utf8mb4"); - - migrationBuilder.CreateTable( - name: "Provider", - columns: table => new - { - Id = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), - Name = table.Column(type: "longtext", nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - BusinessName = table.Column(type: "longtext", nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - BusinessAddress1 = table.Column(type: "longtext", nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - BusinessAddress2 = table.Column(type: "longtext", nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - BusinessAddress3 = table.Column(type: "longtext", nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - BusinessCountry = table.Column(type: "longtext", nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - BusinessTaxNumber = table.Column(type: "longtext", nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - BillingEmail = table.Column(type: "longtext", nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - Status = table.Column(type: "tinyint unsigned", nullable: false), - Enabled = table.Column(type: "tinyint(1)", nullable: false), - CreationDate = table.Column(type: "datetime(6)", nullable: false), - RevisionDate = table.Column(type: "datetime(6)", nullable: false) - }, - constraints: table => - { - table.PrimaryKey("PK_Provider", x => x.Id); - }) - .Annotation("MySql:CharSet", "utf8mb4"); - - migrationBuilder.CreateTable( - name: "TaxRate", - columns: table => new - { - Id = table.Column(type: "varchar(40)", maxLength: 40, nullable: false) - .Annotation("MySql:CharSet", "utf8mb4"), - Country = table.Column(type: "varchar(50)", maxLength: 50, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - State = table.Column(type: "varchar(2)", maxLength: 2, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - PostalCode = table.Column(type: "varchar(10)", maxLength: 10, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - Rate = table.Column(type: "decimal(65,30)", nullable: false), - Active = table.Column(type: "tinyint(1)", nullable: false) - }, - constraints: table => - { - table.PrimaryKey("PK_TaxRate", x => x.Id); - }) - .Annotation("MySql:CharSet", "utf8mb4"); - - migrationBuilder.CreateTable( - name: "User", - columns: table => new - { - Id = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), - Name = table.Column(type: "varchar(50)", maxLength: 50, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - Email = table.Column(type: "varchar(256)", maxLength: 256, nullable: false) - .Annotation("MySql:CharSet", "utf8mb4"), - EmailVerified = table.Column(type: "tinyint(1)", nullable: false), - MasterPassword = table.Column(type: "varchar(300)", maxLength: 300, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - MasterPasswordHint = table.Column(type: "varchar(50)", maxLength: 50, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - Culture = table.Column(type: "varchar(10)", maxLength: 10, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - SecurityStamp = table.Column(type: "varchar(50)", maxLength: 50, nullable: false) - .Annotation("MySql:CharSet", "utf8mb4"), - TwoFactorProviders = table.Column(type: "longtext", nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - TwoFactorRecoveryCode = table.Column(type: "varchar(32)", maxLength: 32, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - EquivalentDomains = table.Column(type: "longtext", nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - ExcludedGlobalEquivalentDomains = table.Column(type: "longtext", nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - AccountRevisionDate = table.Column(type: "datetime(6)", nullable: false), - Key = table.Column(type: "longtext", nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - PublicKey = table.Column(type: "longtext", nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - PrivateKey = table.Column(type: "longtext", nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - Premium = table.Column(type: "tinyint(1)", nullable: false), - PremiumExpirationDate = table.Column(type: "datetime(6)", nullable: true), - RenewalReminderDate = table.Column(type: "datetime(6)", nullable: true), - Storage = table.Column(type: "bigint", nullable: true), - MaxStorageGb = table.Column(type: "smallint", nullable: true), - Gateway = table.Column(type: "tinyint unsigned", nullable: true), - GatewayCustomerId = table.Column(type: "varchar(50)", maxLength: 50, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - GatewaySubscriptionId = table.Column(type: "varchar(50)", maxLength: 50, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - ReferenceData = table.Column(type: "longtext", nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - LicenseKey = table.Column(type: "varchar(100)", maxLength: 100, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - ApiKey = table.Column(type: "varchar(30)", maxLength: 30, nullable: false) - .Annotation("MySql:CharSet", "utf8mb4"), - Kdf = table.Column(type: "tinyint unsigned", nullable: false), - KdfIterations = table.Column(type: "int", nullable: false), - CreationDate = table.Column(type: "datetime(6)", nullable: false), - RevisionDate = table.Column(type: "datetime(6)", nullable: false) - }, - constraints: table => - { - table.PrimaryKey("PK_User", x => x.Id); - }) - .Annotation("MySql:CharSet", "utf8mb4"); - - migrationBuilder.CreateTable( - name: "Collection", - columns: table => new - { - Id = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), - OrganizationId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), - Name = table.Column(type: "longtext", nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - ExternalId = table.Column(type: "varchar(300)", maxLength: 300, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - CreationDate = table.Column(type: "datetime(6)", nullable: false), - RevisionDate = table.Column(type: "datetime(6)", nullable: false) - }, - constraints: table => - { - table.PrimaryKey("PK_Collection", x => x.Id); - table.ForeignKey( - name: "FK_Collection_Organization_OrganizationId", - column: x => x.OrganizationId, - principalTable: "Organization", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - }) - .Annotation("MySql:CharSet", "utf8mb4"); - - migrationBuilder.CreateTable( - name: "Group", - columns: table => new - { - Id = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), - OrganizationId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), - Name = table.Column(type: "varchar(100)", maxLength: 100, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - AccessAll = table.Column(type: "tinyint(1)", nullable: false), - ExternalId = table.Column(type: "varchar(300)", maxLength: 300, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - CreationDate = table.Column(type: "datetime(6)", nullable: false), - RevisionDate = table.Column(type: "datetime(6)", nullable: false) - }, - constraints: table => - { - table.PrimaryKey("PK_Group", x => x.Id); - table.ForeignKey( - name: "FK_Group_Organization_OrganizationId", - column: x => x.OrganizationId, - principalTable: "Organization", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - }) - .Annotation("MySql:CharSet", "utf8mb4"); - - migrationBuilder.CreateTable( - name: "Policy", - columns: table => new - { - Id = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), - OrganizationId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), - Type = table.Column(type: "tinyint unsigned", nullable: false), - Data = table.Column(type: "longtext", nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - Enabled = table.Column(type: "tinyint(1)", nullable: false), - CreationDate = table.Column(type: "datetime(6)", nullable: false), - RevisionDate = table.Column(type: "datetime(6)", nullable: false) - }, - constraints: table => - { - table.PrimaryKey("PK_Policy", x => x.Id); - table.ForeignKey( - name: "FK_Policy_Organization_OrganizationId", - column: x => x.OrganizationId, - principalTable: "Organization", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - }) - .Annotation("MySql:CharSet", "utf8mb4"); - - migrationBuilder.CreateTable( - name: "SsoConfig", - columns: table => new - { - Id = table.Column(type: "bigint", nullable: false) - .Annotation("MySql:ValueGenerationStrategy", MySqlValueGenerationStrategy.IdentityColumn), - Enabled = table.Column(type: "tinyint(1)", nullable: false), - OrganizationId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), - Data = table.Column(type: "longtext", nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - CreationDate = table.Column(type: "datetime(6)", nullable: false), - RevisionDate = table.Column(type: "datetime(6)", nullable: false) - }, - constraints: table => - { - table.PrimaryKey("PK_SsoConfig", x => x.Id); - table.ForeignKey( - name: "FK_SsoConfig_Organization_OrganizationId", - column: x => x.OrganizationId, - principalTable: "Organization", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - }) - .Annotation("MySql:CharSet", "utf8mb4"); - - migrationBuilder.CreateTable( - name: "ProviderOrganization", - columns: table => new - { - Id = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), - ProviderId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), - OrganizationId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), - Key = table.Column(type: "longtext", nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - Settings = table.Column(type: "longtext", nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - CreationDate = table.Column(type: "datetime(6)", nullable: false), - RevisionDate = table.Column(type: "datetime(6)", nullable: false) - }, - constraints: table => - { - table.PrimaryKey("PK_ProviderOrganization", x => x.Id); - table.ForeignKey( - name: "FK_ProviderOrganization_Organization_OrganizationId", - column: x => x.OrganizationId, - principalTable: "Organization", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - table.ForeignKey( - name: "FK_ProviderOrganization_Provider_ProviderId", - column: x => x.ProviderId, - principalTable: "Provider", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - }) - .Annotation("MySql:CharSet", "utf8mb4"); - - migrationBuilder.CreateTable( - name: "Cipher", - columns: table => new - { - Id = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), - UserId = table.Column(type: "char(36)", nullable: true, collation: "ascii_general_ci"), - OrganizationId = table.Column(type: "char(36)", nullable: true, collation: "ascii_general_ci"), - Type = table.Column(type: "tinyint unsigned", nullable: false), - Data = table.Column(type: "longtext", nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - Favorites = table.Column(type: "longtext", nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - Folders = table.Column(type: "longtext", nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - Attachments = table.Column(type: "longtext", nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - CreationDate = table.Column(type: "datetime(6)", nullable: false), - RevisionDate = table.Column(type: "datetime(6)", nullable: false), - DeletedDate = table.Column(type: "datetime(6)", nullable: true), - Reprompt = table.Column(type: "tinyint unsigned", nullable: true) - }, - constraints: table => - { - table.PrimaryKey("PK_Cipher", x => x.Id); - table.ForeignKey( - name: "FK_Cipher_Organization_OrganizationId", - column: x => x.OrganizationId, - principalTable: "Organization", - principalColumn: "Id", - onDelete: ReferentialAction.Restrict); - table.ForeignKey( - name: "FK_Cipher_User_UserId", - column: x => x.UserId, - principalTable: "User", - principalColumn: "Id", - onDelete: ReferentialAction.Restrict); - }) - .Annotation("MySql:CharSet", "utf8mb4"); - - migrationBuilder.CreateTable( - name: "Device", - columns: table => new - { - Id = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), - UserId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), - Name = table.Column(type: "varchar(50)", maxLength: 50, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - Type = table.Column(type: "tinyint unsigned", nullable: false), - Identifier = table.Column(type: "varchar(50)", maxLength: 50, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - PushToken = table.Column(type: "varchar(255)", maxLength: 255, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - CreationDate = table.Column(type: "datetime(6)", nullable: false), - RevisionDate = table.Column(type: "datetime(6)", nullable: false) - }, - constraints: table => - { - table.PrimaryKey("PK_Device", x => x.Id); - table.ForeignKey( - name: "FK_Device_User_UserId", - column: x => x.UserId, - principalTable: "User", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - }) - .Annotation("MySql:CharSet", "utf8mb4"); - - migrationBuilder.CreateTable( - name: "EmergencyAccess", - columns: table => new - { - Id = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), - GrantorId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), - GranteeId = table.Column(type: "char(36)", nullable: true, collation: "ascii_general_ci"), - Email = table.Column(type: "varchar(256)", maxLength: 256, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - KeyEncrypted = table.Column(type: "longtext", nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - Type = table.Column(type: "tinyint unsigned", nullable: false), - Status = table.Column(type: "tinyint unsigned", nullable: false), - WaitTimeDays = table.Column(type: "int", nullable: false), - RecoveryInitiatedDate = table.Column(type: "datetime(6)", nullable: true), - LastNotificationDate = table.Column(type: "datetime(6)", nullable: true), - CreationDate = table.Column(type: "datetime(6)", nullable: false), - RevisionDate = table.Column(type: "datetime(6)", nullable: false) - }, - constraints: table => - { - table.PrimaryKey("PK_EmergencyAccess", x => x.Id); - table.ForeignKey( - name: "FK_EmergencyAccess_User_GranteeId", - column: x => x.GranteeId, - principalTable: "User", - principalColumn: "Id", - onDelete: ReferentialAction.Restrict); - table.ForeignKey( - name: "FK_EmergencyAccess_User_GrantorId", - column: x => x.GrantorId, - principalTable: "User", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - }) - .Annotation("MySql:CharSet", "utf8mb4"); - - migrationBuilder.CreateTable( - name: "Folder", - columns: table => new - { - Id = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), - UserId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), - Name = table.Column(type: "longtext", nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - CreationDate = table.Column(type: "datetime(6)", nullable: false), - RevisionDate = table.Column(type: "datetime(6)", nullable: false) - }, - constraints: table => - { - table.PrimaryKey("PK_Folder", x => x.Id); - table.ForeignKey( - name: "FK_Folder_User_UserId", - column: x => x.UserId, - principalTable: "User", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - }) - .Annotation("MySql:CharSet", "utf8mb4"); - - migrationBuilder.CreateTable( - name: "OrganizationUser", - columns: table => new - { - Id = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), - OrganizationId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), - UserId = table.Column(type: "char(36)", nullable: true, collation: "ascii_general_ci"), - Email = table.Column(type: "varchar(256)", maxLength: 256, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - Key = table.Column(type: "longtext", nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - ResetPasswordKey = table.Column(type: "longtext", nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - Status = table.Column(type: "tinyint unsigned", nullable: false), - Type = table.Column(type: "tinyint unsigned", nullable: false), - AccessAll = table.Column(type: "tinyint(1)", nullable: false), - ExternalId = table.Column(type: "varchar(300)", maxLength: 300, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - CreationDate = table.Column(type: "datetime(6)", nullable: false), - RevisionDate = table.Column(type: "datetime(6)", nullable: false), - Permissions = table.Column(type: "longtext", nullable: true) - .Annotation("MySql:CharSet", "utf8mb4") - }, - constraints: table => - { - table.PrimaryKey("PK_OrganizationUser", x => x.Id); - table.ForeignKey( - name: "FK_OrganizationUser_Organization_OrganizationId", - column: x => x.OrganizationId, - principalTable: "Organization", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - table.ForeignKey( - name: "FK_OrganizationUser_User_UserId", - column: x => x.UserId, - principalTable: "User", - principalColumn: "Id", - onDelete: ReferentialAction.Restrict); - }) - .Annotation("MySql:CharSet", "utf8mb4"); - - migrationBuilder.CreateTable( - name: "ProviderUser", - columns: table => new - { - Id = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), - ProviderId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), - UserId = table.Column(type: "char(36)", nullable: true, collation: "ascii_general_ci"), - Email = table.Column(type: "longtext", nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - Key = table.Column(type: "longtext", nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - Status = table.Column(type: "tinyint unsigned", nullable: false), - Type = table.Column(type: "tinyint unsigned", nullable: false), - Permissions = table.Column(type: "longtext", nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - CreationDate = table.Column(type: "datetime(6)", nullable: false), - RevisionDate = table.Column(type: "datetime(6)", nullable: false) - }, - constraints: table => - { - table.PrimaryKey("PK_ProviderUser", x => x.Id); - table.ForeignKey( - name: "FK_ProviderUser_Provider_ProviderId", - column: x => x.ProviderId, - principalTable: "Provider", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - table.ForeignKey( - name: "FK_ProviderUser_User_UserId", - column: x => x.UserId, - principalTable: "User", - principalColumn: "Id", - onDelete: ReferentialAction.Restrict); - }) - .Annotation("MySql:CharSet", "utf8mb4"); - - migrationBuilder.CreateTable( - name: "Send", - columns: table => new - { - Id = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), - UserId = table.Column(type: "char(36)", nullable: true, collation: "ascii_general_ci"), - OrganizationId = table.Column(type: "char(36)", nullable: true, collation: "ascii_general_ci"), - Type = table.Column(type: "tinyint unsigned", nullable: false), - Data = table.Column(type: "longtext", nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - Key = table.Column(type: "longtext", nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - Password = table.Column(type: "varchar(300)", maxLength: 300, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - MaxAccessCount = table.Column(type: "int", nullable: true), - AccessCount = table.Column(type: "int", nullable: false), - CreationDate = table.Column(type: "datetime(6)", nullable: false), - RevisionDate = table.Column(type: "datetime(6)", nullable: false), - ExpirationDate = table.Column(type: "datetime(6)", nullable: true), - DeletionDate = table.Column(type: "datetime(6)", nullable: false), - Disabled = table.Column(type: "tinyint(1)", nullable: false), - HideEmail = table.Column(type: "tinyint(1)", nullable: true) - }, - constraints: table => - { - table.PrimaryKey("PK_Send", x => x.Id); - table.ForeignKey( - name: "FK_Send_Organization_OrganizationId", - column: x => x.OrganizationId, - principalTable: "Organization", - principalColumn: "Id", - onDelete: ReferentialAction.Restrict); - table.ForeignKey( - name: "FK_Send_User_UserId", - column: x => x.UserId, - principalTable: "User", - principalColumn: "Id", - onDelete: ReferentialAction.Restrict); - }) - .Annotation("MySql:CharSet", "utf8mb4"); - - migrationBuilder.CreateTable( - name: "SsoUser", - columns: table => new - { - Id = table.Column(type: "bigint", nullable: false) - .Annotation("MySql:ValueGenerationStrategy", MySqlValueGenerationStrategy.IdentityColumn), - UserId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), - OrganizationId = table.Column(type: "char(36)", nullable: true, collation: "ascii_general_ci"), - ExternalId = table.Column(type: "varchar(50)", maxLength: 50, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - CreationDate = table.Column(type: "datetime(6)", nullable: false) - }, - constraints: table => - { - table.PrimaryKey("PK_SsoUser", x => x.Id); - table.ForeignKey( - name: "FK_SsoUser_Organization_OrganizationId", - column: x => x.OrganizationId, - principalTable: "Organization", - principalColumn: "Id", - onDelete: ReferentialAction.Restrict); - table.ForeignKey( - name: "FK_SsoUser_User_UserId", - column: x => x.UserId, - principalTable: "User", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - }) - .Annotation("MySql:CharSet", "utf8mb4"); - - migrationBuilder.CreateTable( - name: "Transaction", - columns: table => new - { - Id = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), - UserId = table.Column(type: "char(36)", nullable: true, collation: "ascii_general_ci"), - OrganizationId = table.Column(type: "char(36)", nullable: true, collation: "ascii_general_ci"), - Type = table.Column(type: "tinyint unsigned", nullable: false), - Amount = table.Column(type: "decimal(65,30)", nullable: false), - Refunded = table.Column(type: "tinyint(1)", nullable: true), - RefundedAmount = table.Column(type: "decimal(65,30)", nullable: true), - Details = table.Column(type: "varchar(100)", maxLength: 100, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - PaymentMethodType = table.Column(type: "tinyint unsigned", nullable: true), - Gateway = table.Column(type: "tinyint unsigned", nullable: true), - GatewayId = table.Column(type: "varchar(50)", maxLength: 50, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - CreationDate = table.Column(type: "datetime(6)", nullable: false) - }, - constraints: table => - { - table.PrimaryKey("PK_Transaction", x => x.Id); - table.ForeignKey( - name: "FK_Transaction_Organization_OrganizationId", - column: x => x.OrganizationId, - principalTable: "Organization", - principalColumn: "Id", - onDelete: ReferentialAction.Restrict); - table.ForeignKey( - name: "FK_Transaction_User_UserId", - column: x => x.UserId, - principalTable: "User", - principalColumn: "Id", - onDelete: ReferentialAction.Restrict); - }) - .Annotation("MySql:CharSet", "utf8mb4"); - - migrationBuilder.CreateTable( - name: "U2f", - columns: table => new - { - Id = table.Column(type: "int", nullable: false) - .Annotation("MySql:ValueGenerationStrategy", MySqlValueGenerationStrategy.IdentityColumn), - UserId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), - KeyHandle = table.Column(type: "varchar(200)", maxLength: 200, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - Challenge = table.Column(type: "varchar(200)", maxLength: 200, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - AppId = table.Column(type: "varchar(50)", maxLength: 50, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - Version = table.Column(type: "varchar(20)", maxLength: 20, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - CreationDate = table.Column(type: "datetime(6)", nullable: false) - }, - constraints: table => - { - table.PrimaryKey("PK_U2f", x => x.Id); - table.ForeignKey( - name: "FK_U2f_User_UserId", - column: x => x.UserId, - principalTable: "User", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - }) - .Annotation("MySql:CharSet", "utf8mb4"); - - migrationBuilder.CreateTable( - name: "CollectionGroups", - columns: table => new - { - CollectionId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), - GroupId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), - ReadOnly = table.Column(type: "tinyint(1)", nullable: false), - HidePasswords = table.Column(type: "tinyint(1)", nullable: false) - }, - constraints: table => - { - table.PrimaryKey("PK_CollectionGroups", x => new { x.CollectionId, x.GroupId }); - table.ForeignKey( - name: "FK_CollectionGroups_Collection_CollectionId", - column: x => x.CollectionId, - principalTable: "Collection", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - table.ForeignKey( - name: "FK_CollectionGroups_Group_GroupId", - column: x => x.GroupId, - principalTable: "Group", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - }) - .Annotation("MySql:CharSet", "utf8mb4"); - - migrationBuilder.CreateTable( - name: "CollectionCipher", - columns: table => new - { - CollectionId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), - CipherId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci") - }, - constraints: table => - { - table.PrimaryKey("PK_CollectionCipher", x => new { x.CollectionId, x.CipherId }); - table.ForeignKey( - name: "FK_CollectionCipher_Cipher_CipherId", - column: x => x.CipherId, - principalTable: "Cipher", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - table.ForeignKey( - name: "FK_CollectionCipher_Collection_CollectionId", - column: x => x.CollectionId, - principalTable: "Collection", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - }) - .Annotation("MySql:CharSet", "utf8mb4"); - - migrationBuilder.CreateTable( - name: "CollectionUsers", - columns: table => new - { - CollectionId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), - OrganizationUserId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), - UserId = table.Column(type: "char(36)", nullable: true, collation: "ascii_general_ci"), - ReadOnly = table.Column(type: "tinyint(1)", nullable: false), - HidePasswords = table.Column(type: "tinyint(1)", nullable: false) - }, - constraints: table => - { - table.PrimaryKey("PK_CollectionUsers", x => new { x.CollectionId, x.OrganizationUserId }); - table.ForeignKey( - name: "FK_CollectionUsers_Collection_CollectionId", - column: x => x.CollectionId, - principalTable: "Collection", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - table.ForeignKey( - name: "FK_CollectionUsers_OrganizationUser_OrganizationUserId", - column: x => x.OrganizationUserId, - principalTable: "OrganizationUser", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - table.ForeignKey( - name: "FK_CollectionUsers_User_UserId", - column: x => x.UserId, - principalTable: "User", - principalColumn: "Id", - onDelete: ReferentialAction.Restrict); - }) - .Annotation("MySql:CharSet", "utf8mb4"); - - migrationBuilder.CreateTable( - name: "GroupUser", - columns: table => new - { - GroupId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), - OrganizationUserId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), - UserId = table.Column(type: "char(36)", nullable: true, collation: "ascii_general_ci") - }, - constraints: table => - { - table.PrimaryKey("PK_GroupUser", x => new { x.GroupId, x.OrganizationUserId }); - table.ForeignKey( - name: "FK_GroupUser_Group_GroupId", - column: x => x.GroupId, - principalTable: "Group", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - table.ForeignKey( - name: "FK_GroupUser_OrganizationUser_OrganizationUserId", - column: x => x.OrganizationUserId, - principalTable: "OrganizationUser", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - table.ForeignKey( - name: "FK_GroupUser_User_UserId", - column: x => x.UserId, - principalTable: "User", - principalColumn: "Id", - onDelete: ReferentialAction.Restrict); - }) - .Annotation("MySql:CharSet", "utf8mb4"); - - migrationBuilder.CreateTable( - name: "ProviderOrganizationProviderUser", - columns: table => new - { - Id = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), - ProviderOrganizationId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), - ProviderUserId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), - Type = table.Column(type: "tinyint unsigned", nullable: false), - Permissions = table.Column(type: "longtext", nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - CreationDate = table.Column(type: "datetime(6)", nullable: false), - RevisionDate = table.Column(type: "datetime(6)", nullable: false) - }, - constraints: table => - { - table.PrimaryKey("PK_ProviderOrganizationProviderUser", x => x.Id); - table.ForeignKey( - name: "FK_ProviderOrganizationProviderUser_ProviderOrganization_Provid~", - column: x => x.ProviderOrganizationId, - principalTable: "ProviderOrganization", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - table.ForeignKey( - name: "FK_ProviderOrganizationProviderUser_ProviderUser_ProviderUserId", - column: x => x.ProviderUserId, - principalTable: "ProviderUser", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - }) - .Annotation("MySql:CharSet", "utf8mb4"); - - migrationBuilder.CreateIndex( - name: "IX_Cipher_OrganizationId", - table: "Cipher", - column: "OrganizationId"); - - migrationBuilder.CreateIndex( - name: "IX_Cipher_UserId", - table: "Cipher", - column: "UserId"); - - migrationBuilder.CreateIndex( - name: "IX_Collection_OrganizationId", - table: "Collection", - column: "OrganizationId"); - - migrationBuilder.CreateIndex( - name: "IX_CollectionCipher_CipherId", - table: "CollectionCipher", - column: "CipherId"); - - migrationBuilder.CreateIndex( - name: "IX_CollectionGroups_GroupId", - table: "CollectionGroups", - column: "GroupId"); - - migrationBuilder.CreateIndex( - name: "IX_CollectionUsers_OrganizationUserId", - table: "CollectionUsers", - column: "OrganizationUserId"); - - migrationBuilder.CreateIndex( - name: "IX_CollectionUsers_UserId", - table: "CollectionUsers", - column: "UserId"); - - migrationBuilder.CreateIndex( - name: "IX_Device_UserId", - table: "Device", - column: "UserId"); - - migrationBuilder.CreateIndex( - name: "IX_EmergencyAccess_GranteeId", - table: "EmergencyAccess", - column: "GranteeId"); - - migrationBuilder.CreateIndex( - name: "IX_EmergencyAccess_GrantorId", - table: "EmergencyAccess", - column: "GrantorId"); - - migrationBuilder.CreateIndex( - name: "IX_Folder_UserId", - table: "Folder", - column: "UserId"); - - migrationBuilder.CreateIndex( - name: "IX_Group_OrganizationId", - table: "Group", - column: "OrganizationId"); - - migrationBuilder.CreateIndex( - name: "IX_GroupUser_OrganizationUserId", - table: "GroupUser", - column: "OrganizationUserId"); - - migrationBuilder.CreateIndex( - name: "IX_GroupUser_UserId", - table: "GroupUser", - column: "UserId"); - - migrationBuilder.CreateIndex( - name: "IX_OrganizationUser_OrganizationId", - table: "OrganizationUser", - column: "OrganizationId"); - - migrationBuilder.CreateIndex( - name: "IX_OrganizationUser_UserId", - table: "OrganizationUser", - column: "UserId"); - - migrationBuilder.CreateIndex( - name: "IX_Policy_OrganizationId", - table: "Policy", - column: "OrganizationId"); - - migrationBuilder.CreateIndex( - name: "IX_ProviderOrganization_OrganizationId", - table: "ProviderOrganization", - column: "OrganizationId"); - - migrationBuilder.CreateIndex( - name: "IX_ProviderOrganization_ProviderId", - table: "ProviderOrganization", - column: "ProviderId"); - - migrationBuilder.CreateIndex( - name: "IX_ProviderOrganizationProviderUser_ProviderOrganizationId", - table: "ProviderOrganizationProviderUser", - column: "ProviderOrganizationId"); - - migrationBuilder.CreateIndex( - name: "IX_ProviderOrganizationProviderUser_ProviderUserId", - table: "ProviderOrganizationProviderUser", - column: "ProviderUserId"); - - migrationBuilder.CreateIndex( - name: "IX_ProviderUser_ProviderId", - table: "ProviderUser", - column: "ProviderId"); - - migrationBuilder.CreateIndex( - name: "IX_ProviderUser_UserId", - table: "ProviderUser", - column: "UserId"); - - migrationBuilder.CreateIndex( - name: "IX_Send_OrganizationId", - table: "Send", - column: "OrganizationId"); - - migrationBuilder.CreateIndex( - name: "IX_Send_UserId", - table: "Send", - column: "UserId"); - - migrationBuilder.CreateIndex( - name: "IX_SsoConfig_OrganizationId", - table: "SsoConfig", - column: "OrganizationId"); + protected override void Up(MigrationBuilder migrationBuilder) + { + migrationBuilder.AlterDatabase() + .Annotation("MySql:CharSet", "utf8mb4"); + + migrationBuilder.CreateTable( + name: "Event", + columns: table => new + { + Id = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), + Date = table.Column(type: "datetime(6)", nullable: false), + Type = table.Column(type: "int", nullable: false), + UserId = table.Column(type: "char(36)", nullable: true, collation: "ascii_general_ci"), + OrganizationId = table.Column(type: "char(36)", nullable: true, collation: "ascii_general_ci"), + CipherId = table.Column(type: "char(36)", nullable: true, collation: "ascii_general_ci"), + CollectionId = table.Column(type: "char(36)", nullable: true, collation: "ascii_general_ci"), + PolicyId = table.Column(type: "char(36)", nullable: true, collation: "ascii_general_ci"), + GroupId = table.Column(type: "char(36)", nullable: true, collation: "ascii_general_ci"), + OrganizationUserId = table.Column(type: "char(36)", nullable: true, collation: "ascii_general_ci"), + DeviceType = table.Column(type: "tinyint unsigned", nullable: true), + IpAddress = table.Column(type: "varchar(50)", maxLength: 50, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + ActingUserId = table.Column(type: "char(36)", nullable: true, collation: "ascii_general_ci") + }, + constraints: table => + { + table.PrimaryKey("PK_Event", x => x.Id); + }) + .Annotation("MySql:CharSet", "utf8mb4"); + + migrationBuilder.CreateTable( + name: "Grant", + columns: table => new + { + Key = table.Column(type: "varchar(200)", maxLength: 200, nullable: false) + .Annotation("MySql:CharSet", "utf8mb4"), + Type = table.Column(type: "varchar(50)", maxLength: 50, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + SubjectId = table.Column(type: "varchar(200)", maxLength: 200, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + SessionId = table.Column(type: "varchar(100)", maxLength: 100, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + ClientId = table.Column(type: "varchar(200)", maxLength: 200, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + Description = table.Column(type: "varchar(200)", maxLength: 200, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + CreationDate = table.Column(type: "datetime(6)", nullable: false), + ExpirationDate = table.Column(type: "datetime(6)", nullable: true), + ConsumedDate = table.Column(type: "datetime(6)", nullable: true), + Data = table.Column(type: "longtext", nullable: true) + .Annotation("MySql:CharSet", "utf8mb4") + }, + constraints: table => + { + table.PrimaryKey("PK_Grant", x => x.Key); + }) + .Annotation("MySql:CharSet", "utf8mb4"); + + migrationBuilder.CreateTable( + name: "Installation", + columns: table => new + { + Id = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), + Email = table.Column(type: "varchar(256)", maxLength: 256, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + Key = table.Column(type: "varchar(150)", maxLength: 150, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + Enabled = table.Column(type: "tinyint(1)", nullable: false), + CreationDate = table.Column(type: "datetime(6)", nullable: false) + }, + constraints: table => + { + table.PrimaryKey("PK_Installation", x => x.Id); + }) + .Annotation("MySql:CharSet", "utf8mb4"); + + migrationBuilder.CreateTable( + name: "Organization", + columns: table => new + { + Id = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), + Identifier = table.Column(type: "varchar(50)", maxLength: 50, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + Name = table.Column(type: "varchar(50)", maxLength: 50, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + BusinessName = table.Column(type: "varchar(50)", maxLength: 50, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + BusinessAddress1 = table.Column(type: "varchar(50)", maxLength: 50, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + BusinessAddress2 = table.Column(type: "varchar(50)", maxLength: 50, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + BusinessAddress3 = table.Column(type: "varchar(50)", maxLength: 50, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + BusinessCountry = table.Column(type: "varchar(2)", maxLength: 2, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + BusinessTaxNumber = table.Column(type: "varchar(30)", maxLength: 30, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + BillingEmail = table.Column(type: "varchar(256)", maxLength: 256, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + Plan = table.Column(type: "varchar(50)", maxLength: 50, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + PlanType = table.Column(type: "tinyint unsigned", nullable: false), + Seats = table.Column(type: "int", nullable: true), + MaxCollections = table.Column(type: "smallint", nullable: true), + UsePolicies = table.Column(type: "tinyint(1)", nullable: false), + UseSso = table.Column(type: "tinyint(1)", nullable: false), + UseGroups = table.Column(type: "tinyint(1)", nullable: false), + UseDirectory = table.Column(type: "tinyint(1)", nullable: false), + UseEvents = table.Column(type: "tinyint(1)", nullable: false), + UseTotp = table.Column(type: "tinyint(1)", nullable: false), + Use2fa = table.Column(type: "tinyint(1)", nullable: false), + UseApi = table.Column(type: "tinyint(1)", nullable: false), + UseResetPassword = table.Column(type: "tinyint(1)", nullable: false), + SelfHost = table.Column(type: "tinyint(1)", nullable: false), + UsersGetPremium = table.Column(type: "tinyint(1)", nullable: false), + Storage = table.Column(type: "bigint", nullable: true), + MaxStorageGb = table.Column(type: "smallint", nullable: true), + Gateway = table.Column(type: "tinyint unsigned", nullable: true), + GatewayCustomerId = table.Column(type: "varchar(50)", maxLength: 50, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + GatewaySubscriptionId = table.Column(type: "varchar(50)", maxLength: 50, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + ReferenceData = table.Column(type: "longtext", nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + Enabled = table.Column(type: "tinyint(1)", nullable: false), + LicenseKey = table.Column(type: "varchar(100)", maxLength: 100, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + ApiKey = table.Column(type: "varchar(30)", maxLength: 30, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + PublicKey = table.Column(type: "longtext", nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + PrivateKey = table.Column(type: "longtext", nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + TwoFactorProviders = table.Column(type: "longtext", nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + ExpirationDate = table.Column(type: "datetime(6)", nullable: true), + CreationDate = table.Column(type: "datetime(6)", nullable: false), + RevisionDate = table.Column(type: "datetime(6)", nullable: false) + }, + constraints: table => + { + table.PrimaryKey("PK_Organization", x => x.Id); + }) + .Annotation("MySql:CharSet", "utf8mb4"); + + migrationBuilder.CreateTable( + name: "Provider", + columns: table => new + { + Id = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), + Name = table.Column(type: "longtext", nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + BusinessName = table.Column(type: "longtext", nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + BusinessAddress1 = table.Column(type: "longtext", nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + BusinessAddress2 = table.Column(type: "longtext", nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + BusinessAddress3 = table.Column(type: "longtext", nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + BusinessCountry = table.Column(type: "longtext", nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + BusinessTaxNumber = table.Column(type: "longtext", nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + BillingEmail = table.Column(type: "longtext", nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + Status = table.Column(type: "tinyint unsigned", nullable: false), + Enabled = table.Column(type: "tinyint(1)", nullable: false), + CreationDate = table.Column(type: "datetime(6)", nullable: false), + RevisionDate = table.Column(type: "datetime(6)", nullable: false) + }, + constraints: table => + { + table.PrimaryKey("PK_Provider", x => x.Id); + }) + .Annotation("MySql:CharSet", "utf8mb4"); + + migrationBuilder.CreateTable( + name: "TaxRate", + columns: table => new + { + Id = table.Column(type: "varchar(40)", maxLength: 40, nullable: false) + .Annotation("MySql:CharSet", "utf8mb4"), + Country = table.Column(type: "varchar(50)", maxLength: 50, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + State = table.Column(type: "varchar(2)", maxLength: 2, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + PostalCode = table.Column(type: "varchar(10)", maxLength: 10, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + Rate = table.Column(type: "decimal(65,30)", nullable: false), + Active = table.Column(type: "tinyint(1)", nullable: false) + }, + constraints: table => + { + table.PrimaryKey("PK_TaxRate", x => x.Id); + }) + .Annotation("MySql:CharSet", "utf8mb4"); + + migrationBuilder.CreateTable( + name: "User", + columns: table => new + { + Id = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), + Name = table.Column(type: "varchar(50)", maxLength: 50, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + Email = table.Column(type: "varchar(256)", maxLength: 256, nullable: false) + .Annotation("MySql:CharSet", "utf8mb4"), + EmailVerified = table.Column(type: "tinyint(1)", nullable: false), + MasterPassword = table.Column(type: "varchar(300)", maxLength: 300, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + MasterPasswordHint = table.Column(type: "varchar(50)", maxLength: 50, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + Culture = table.Column(type: "varchar(10)", maxLength: 10, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + SecurityStamp = table.Column(type: "varchar(50)", maxLength: 50, nullable: false) + .Annotation("MySql:CharSet", "utf8mb4"), + TwoFactorProviders = table.Column(type: "longtext", nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + TwoFactorRecoveryCode = table.Column(type: "varchar(32)", maxLength: 32, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + EquivalentDomains = table.Column(type: "longtext", nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + ExcludedGlobalEquivalentDomains = table.Column(type: "longtext", nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + AccountRevisionDate = table.Column(type: "datetime(6)", nullable: false), + Key = table.Column(type: "longtext", nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + PublicKey = table.Column(type: "longtext", nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + PrivateKey = table.Column(type: "longtext", nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + Premium = table.Column(type: "tinyint(1)", nullable: false), + PremiumExpirationDate = table.Column(type: "datetime(6)", nullable: true), + RenewalReminderDate = table.Column(type: "datetime(6)", nullable: true), + Storage = table.Column(type: "bigint", nullable: true), + MaxStorageGb = table.Column(type: "smallint", nullable: true), + Gateway = table.Column(type: "tinyint unsigned", nullable: true), + GatewayCustomerId = table.Column(type: "varchar(50)", maxLength: 50, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + GatewaySubscriptionId = table.Column(type: "varchar(50)", maxLength: 50, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + ReferenceData = table.Column(type: "longtext", nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + LicenseKey = table.Column(type: "varchar(100)", maxLength: 100, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + ApiKey = table.Column(type: "varchar(30)", maxLength: 30, nullable: false) + .Annotation("MySql:CharSet", "utf8mb4"), + Kdf = table.Column(type: "tinyint unsigned", nullable: false), + KdfIterations = table.Column(type: "int", nullable: false), + CreationDate = table.Column(type: "datetime(6)", nullable: false), + RevisionDate = table.Column(type: "datetime(6)", nullable: false) + }, + constraints: table => + { + table.PrimaryKey("PK_User", x => x.Id); + }) + .Annotation("MySql:CharSet", "utf8mb4"); + + migrationBuilder.CreateTable( + name: "Collection", + columns: table => new + { + Id = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), + OrganizationId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), + Name = table.Column(type: "longtext", nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + ExternalId = table.Column(type: "varchar(300)", maxLength: 300, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + CreationDate = table.Column(type: "datetime(6)", nullable: false), + RevisionDate = table.Column(type: "datetime(6)", nullable: false) + }, + constraints: table => + { + table.PrimaryKey("PK_Collection", x => x.Id); + table.ForeignKey( + name: "FK_Collection_Organization_OrganizationId", + column: x => x.OrganizationId, + principalTable: "Organization", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + }) + .Annotation("MySql:CharSet", "utf8mb4"); + + migrationBuilder.CreateTable( + name: "Group", + columns: table => new + { + Id = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), + OrganizationId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), + Name = table.Column(type: "varchar(100)", maxLength: 100, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + AccessAll = table.Column(type: "tinyint(1)", nullable: false), + ExternalId = table.Column(type: "varchar(300)", maxLength: 300, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + CreationDate = table.Column(type: "datetime(6)", nullable: false), + RevisionDate = table.Column(type: "datetime(6)", nullable: false) + }, + constraints: table => + { + table.PrimaryKey("PK_Group", x => x.Id); + table.ForeignKey( + name: "FK_Group_Organization_OrganizationId", + column: x => x.OrganizationId, + principalTable: "Organization", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + }) + .Annotation("MySql:CharSet", "utf8mb4"); + + migrationBuilder.CreateTable( + name: "Policy", + columns: table => new + { + Id = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), + OrganizationId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), + Type = table.Column(type: "tinyint unsigned", nullable: false), + Data = table.Column(type: "longtext", nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + Enabled = table.Column(type: "tinyint(1)", nullable: false), + CreationDate = table.Column(type: "datetime(6)", nullable: false), + RevisionDate = table.Column(type: "datetime(6)", nullable: false) + }, + constraints: table => + { + table.PrimaryKey("PK_Policy", x => x.Id); + table.ForeignKey( + name: "FK_Policy_Organization_OrganizationId", + column: x => x.OrganizationId, + principalTable: "Organization", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + }) + .Annotation("MySql:CharSet", "utf8mb4"); + + migrationBuilder.CreateTable( + name: "SsoConfig", + columns: table => new + { + Id = table.Column(type: "bigint", nullable: false) + .Annotation("MySql:ValueGenerationStrategy", MySqlValueGenerationStrategy.IdentityColumn), + Enabled = table.Column(type: "tinyint(1)", nullable: false), + OrganizationId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), + Data = table.Column(type: "longtext", nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + CreationDate = table.Column(type: "datetime(6)", nullable: false), + RevisionDate = table.Column(type: "datetime(6)", nullable: false) + }, + constraints: table => + { + table.PrimaryKey("PK_SsoConfig", x => x.Id); + table.ForeignKey( + name: "FK_SsoConfig_Organization_OrganizationId", + column: x => x.OrganizationId, + principalTable: "Organization", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + }) + .Annotation("MySql:CharSet", "utf8mb4"); + + migrationBuilder.CreateTable( + name: "ProviderOrganization", + columns: table => new + { + Id = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), + ProviderId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), + OrganizationId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), + Key = table.Column(type: "longtext", nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + Settings = table.Column(type: "longtext", nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + CreationDate = table.Column(type: "datetime(6)", nullable: false), + RevisionDate = table.Column(type: "datetime(6)", nullable: false) + }, + constraints: table => + { + table.PrimaryKey("PK_ProviderOrganization", x => x.Id); + table.ForeignKey( + name: "FK_ProviderOrganization_Organization_OrganizationId", + column: x => x.OrganizationId, + principalTable: "Organization", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + table.ForeignKey( + name: "FK_ProviderOrganization_Provider_ProviderId", + column: x => x.ProviderId, + principalTable: "Provider", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + }) + .Annotation("MySql:CharSet", "utf8mb4"); + + migrationBuilder.CreateTable( + name: "Cipher", + columns: table => new + { + Id = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), + UserId = table.Column(type: "char(36)", nullable: true, collation: "ascii_general_ci"), + OrganizationId = table.Column(type: "char(36)", nullable: true, collation: "ascii_general_ci"), + Type = table.Column(type: "tinyint unsigned", nullable: false), + Data = table.Column(type: "longtext", nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + Favorites = table.Column(type: "longtext", nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + Folders = table.Column(type: "longtext", nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + Attachments = table.Column(type: "longtext", nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + CreationDate = table.Column(type: "datetime(6)", nullable: false), + RevisionDate = table.Column(type: "datetime(6)", nullable: false), + DeletedDate = table.Column(type: "datetime(6)", nullable: true), + Reprompt = table.Column(type: "tinyint unsigned", nullable: true) + }, + constraints: table => + { + table.PrimaryKey("PK_Cipher", x => x.Id); + table.ForeignKey( + name: "FK_Cipher_Organization_OrganizationId", + column: x => x.OrganizationId, + principalTable: "Organization", + principalColumn: "Id", + onDelete: ReferentialAction.Restrict); + table.ForeignKey( + name: "FK_Cipher_User_UserId", + column: x => x.UserId, + principalTable: "User", + principalColumn: "Id", + onDelete: ReferentialAction.Restrict); + }) + .Annotation("MySql:CharSet", "utf8mb4"); + + migrationBuilder.CreateTable( + name: "Device", + columns: table => new + { + Id = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), + UserId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), + Name = table.Column(type: "varchar(50)", maxLength: 50, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + Type = table.Column(type: "tinyint unsigned", nullable: false), + Identifier = table.Column(type: "varchar(50)", maxLength: 50, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + PushToken = table.Column(type: "varchar(255)", maxLength: 255, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + CreationDate = table.Column(type: "datetime(6)", nullable: false), + RevisionDate = table.Column(type: "datetime(6)", nullable: false) + }, + constraints: table => + { + table.PrimaryKey("PK_Device", x => x.Id); + table.ForeignKey( + name: "FK_Device_User_UserId", + column: x => x.UserId, + principalTable: "User", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + }) + .Annotation("MySql:CharSet", "utf8mb4"); + + migrationBuilder.CreateTable( + name: "EmergencyAccess", + columns: table => new + { + Id = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), + GrantorId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), + GranteeId = table.Column(type: "char(36)", nullable: true, collation: "ascii_general_ci"), + Email = table.Column(type: "varchar(256)", maxLength: 256, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + KeyEncrypted = table.Column(type: "longtext", nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + Type = table.Column(type: "tinyint unsigned", nullable: false), + Status = table.Column(type: "tinyint unsigned", nullable: false), + WaitTimeDays = table.Column(type: "int", nullable: false), + RecoveryInitiatedDate = table.Column(type: "datetime(6)", nullable: true), + LastNotificationDate = table.Column(type: "datetime(6)", nullable: true), + CreationDate = table.Column(type: "datetime(6)", nullable: false), + RevisionDate = table.Column(type: "datetime(6)", nullable: false) + }, + constraints: table => + { + table.PrimaryKey("PK_EmergencyAccess", x => x.Id); + table.ForeignKey( + name: "FK_EmergencyAccess_User_GranteeId", + column: x => x.GranteeId, + principalTable: "User", + principalColumn: "Id", + onDelete: ReferentialAction.Restrict); + table.ForeignKey( + name: "FK_EmergencyAccess_User_GrantorId", + column: x => x.GrantorId, + principalTable: "User", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + }) + .Annotation("MySql:CharSet", "utf8mb4"); + + migrationBuilder.CreateTable( + name: "Folder", + columns: table => new + { + Id = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), + UserId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), + Name = table.Column(type: "longtext", nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + CreationDate = table.Column(type: "datetime(6)", nullable: false), + RevisionDate = table.Column(type: "datetime(6)", nullable: false) + }, + constraints: table => + { + table.PrimaryKey("PK_Folder", x => x.Id); + table.ForeignKey( + name: "FK_Folder_User_UserId", + column: x => x.UserId, + principalTable: "User", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + }) + .Annotation("MySql:CharSet", "utf8mb4"); + + migrationBuilder.CreateTable( + name: "OrganizationUser", + columns: table => new + { + Id = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), + OrganizationId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), + UserId = table.Column(type: "char(36)", nullable: true, collation: "ascii_general_ci"), + Email = table.Column(type: "varchar(256)", maxLength: 256, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + Key = table.Column(type: "longtext", nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + ResetPasswordKey = table.Column(type: "longtext", nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + Status = table.Column(type: "tinyint unsigned", nullable: false), + Type = table.Column(type: "tinyint unsigned", nullable: false), + AccessAll = table.Column(type: "tinyint(1)", nullable: false), + ExternalId = table.Column(type: "varchar(300)", maxLength: 300, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + CreationDate = table.Column(type: "datetime(6)", nullable: false), + RevisionDate = table.Column(type: "datetime(6)", nullable: false), + Permissions = table.Column(type: "longtext", nullable: true) + .Annotation("MySql:CharSet", "utf8mb4") + }, + constraints: table => + { + table.PrimaryKey("PK_OrganizationUser", x => x.Id); + table.ForeignKey( + name: "FK_OrganizationUser_Organization_OrganizationId", + column: x => x.OrganizationId, + principalTable: "Organization", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + table.ForeignKey( + name: "FK_OrganizationUser_User_UserId", + column: x => x.UserId, + principalTable: "User", + principalColumn: "Id", + onDelete: ReferentialAction.Restrict); + }) + .Annotation("MySql:CharSet", "utf8mb4"); + + migrationBuilder.CreateTable( + name: "ProviderUser", + columns: table => new + { + Id = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), + ProviderId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), + UserId = table.Column(type: "char(36)", nullable: true, collation: "ascii_general_ci"), + Email = table.Column(type: "longtext", nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + Key = table.Column(type: "longtext", nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + Status = table.Column(type: "tinyint unsigned", nullable: false), + Type = table.Column(type: "tinyint unsigned", nullable: false), + Permissions = table.Column(type: "longtext", nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + CreationDate = table.Column(type: "datetime(6)", nullable: false), + RevisionDate = table.Column(type: "datetime(6)", nullable: false) + }, + constraints: table => + { + table.PrimaryKey("PK_ProviderUser", x => x.Id); + table.ForeignKey( + name: "FK_ProviderUser_Provider_ProviderId", + column: x => x.ProviderId, + principalTable: "Provider", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + table.ForeignKey( + name: "FK_ProviderUser_User_UserId", + column: x => x.UserId, + principalTable: "User", + principalColumn: "Id", + onDelete: ReferentialAction.Restrict); + }) + .Annotation("MySql:CharSet", "utf8mb4"); + + migrationBuilder.CreateTable( + name: "Send", + columns: table => new + { + Id = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), + UserId = table.Column(type: "char(36)", nullable: true, collation: "ascii_general_ci"), + OrganizationId = table.Column(type: "char(36)", nullable: true, collation: "ascii_general_ci"), + Type = table.Column(type: "tinyint unsigned", nullable: false), + Data = table.Column(type: "longtext", nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + Key = table.Column(type: "longtext", nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + Password = table.Column(type: "varchar(300)", maxLength: 300, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + MaxAccessCount = table.Column(type: "int", nullable: true), + AccessCount = table.Column(type: "int", nullable: false), + CreationDate = table.Column(type: "datetime(6)", nullable: false), + RevisionDate = table.Column(type: "datetime(6)", nullable: false), + ExpirationDate = table.Column(type: "datetime(6)", nullable: true), + DeletionDate = table.Column(type: "datetime(6)", nullable: false), + Disabled = table.Column(type: "tinyint(1)", nullable: false), + HideEmail = table.Column(type: "tinyint(1)", nullable: true) + }, + constraints: table => + { + table.PrimaryKey("PK_Send", x => x.Id); + table.ForeignKey( + name: "FK_Send_Organization_OrganizationId", + column: x => x.OrganizationId, + principalTable: "Organization", + principalColumn: "Id", + onDelete: ReferentialAction.Restrict); + table.ForeignKey( + name: "FK_Send_User_UserId", + column: x => x.UserId, + principalTable: "User", + principalColumn: "Id", + onDelete: ReferentialAction.Restrict); + }) + .Annotation("MySql:CharSet", "utf8mb4"); + + migrationBuilder.CreateTable( + name: "SsoUser", + columns: table => new + { + Id = table.Column(type: "bigint", nullable: false) + .Annotation("MySql:ValueGenerationStrategy", MySqlValueGenerationStrategy.IdentityColumn), + UserId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), + OrganizationId = table.Column(type: "char(36)", nullable: true, collation: "ascii_general_ci"), + ExternalId = table.Column(type: "varchar(50)", maxLength: 50, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + CreationDate = table.Column(type: "datetime(6)", nullable: false) + }, + constraints: table => + { + table.PrimaryKey("PK_SsoUser", x => x.Id); + table.ForeignKey( + name: "FK_SsoUser_Organization_OrganizationId", + column: x => x.OrganizationId, + principalTable: "Organization", + principalColumn: "Id", + onDelete: ReferentialAction.Restrict); + table.ForeignKey( + name: "FK_SsoUser_User_UserId", + column: x => x.UserId, + principalTable: "User", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + }) + .Annotation("MySql:CharSet", "utf8mb4"); + + migrationBuilder.CreateTable( + name: "Transaction", + columns: table => new + { + Id = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), + UserId = table.Column(type: "char(36)", nullable: true, collation: "ascii_general_ci"), + OrganizationId = table.Column(type: "char(36)", nullable: true, collation: "ascii_general_ci"), + Type = table.Column(type: "tinyint unsigned", nullable: false), + Amount = table.Column(type: "decimal(65,30)", nullable: false), + Refunded = table.Column(type: "tinyint(1)", nullable: true), + RefundedAmount = table.Column(type: "decimal(65,30)", nullable: true), + Details = table.Column(type: "varchar(100)", maxLength: 100, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + PaymentMethodType = table.Column(type: "tinyint unsigned", nullable: true), + Gateway = table.Column(type: "tinyint unsigned", nullable: true), + GatewayId = table.Column(type: "varchar(50)", maxLength: 50, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + CreationDate = table.Column(type: "datetime(6)", nullable: false) + }, + constraints: table => + { + table.PrimaryKey("PK_Transaction", x => x.Id); + table.ForeignKey( + name: "FK_Transaction_Organization_OrganizationId", + column: x => x.OrganizationId, + principalTable: "Organization", + principalColumn: "Id", + onDelete: ReferentialAction.Restrict); + table.ForeignKey( + name: "FK_Transaction_User_UserId", + column: x => x.UserId, + principalTable: "User", + principalColumn: "Id", + onDelete: ReferentialAction.Restrict); + }) + .Annotation("MySql:CharSet", "utf8mb4"); + + migrationBuilder.CreateTable( + name: "U2f", + columns: table => new + { + Id = table.Column(type: "int", nullable: false) + .Annotation("MySql:ValueGenerationStrategy", MySqlValueGenerationStrategy.IdentityColumn), + UserId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), + KeyHandle = table.Column(type: "varchar(200)", maxLength: 200, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + Challenge = table.Column(type: "varchar(200)", maxLength: 200, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + AppId = table.Column(type: "varchar(50)", maxLength: 50, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + Version = table.Column(type: "varchar(20)", maxLength: 20, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + CreationDate = table.Column(type: "datetime(6)", nullable: false) + }, + constraints: table => + { + table.PrimaryKey("PK_U2f", x => x.Id); + table.ForeignKey( + name: "FK_U2f_User_UserId", + column: x => x.UserId, + principalTable: "User", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + }) + .Annotation("MySql:CharSet", "utf8mb4"); + + migrationBuilder.CreateTable( + name: "CollectionGroups", + columns: table => new + { + CollectionId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), + GroupId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), + ReadOnly = table.Column(type: "tinyint(1)", nullable: false), + HidePasswords = table.Column(type: "tinyint(1)", nullable: false) + }, + constraints: table => + { + table.PrimaryKey("PK_CollectionGroups", x => new { x.CollectionId, x.GroupId }); + table.ForeignKey( + name: "FK_CollectionGroups_Collection_CollectionId", + column: x => x.CollectionId, + principalTable: "Collection", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + table.ForeignKey( + name: "FK_CollectionGroups_Group_GroupId", + column: x => x.GroupId, + principalTable: "Group", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + }) + .Annotation("MySql:CharSet", "utf8mb4"); + + migrationBuilder.CreateTable( + name: "CollectionCipher", + columns: table => new + { + CollectionId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), + CipherId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci") + }, + constraints: table => + { + table.PrimaryKey("PK_CollectionCipher", x => new { x.CollectionId, x.CipherId }); + table.ForeignKey( + name: "FK_CollectionCipher_Cipher_CipherId", + column: x => x.CipherId, + principalTable: "Cipher", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + table.ForeignKey( + name: "FK_CollectionCipher_Collection_CollectionId", + column: x => x.CollectionId, + principalTable: "Collection", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + }) + .Annotation("MySql:CharSet", "utf8mb4"); + + migrationBuilder.CreateTable( + name: "CollectionUsers", + columns: table => new + { + CollectionId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), + OrganizationUserId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), + UserId = table.Column(type: "char(36)", nullable: true, collation: "ascii_general_ci"), + ReadOnly = table.Column(type: "tinyint(1)", nullable: false), + HidePasswords = table.Column(type: "tinyint(1)", nullable: false) + }, + constraints: table => + { + table.PrimaryKey("PK_CollectionUsers", x => new { x.CollectionId, x.OrganizationUserId }); + table.ForeignKey( + name: "FK_CollectionUsers_Collection_CollectionId", + column: x => x.CollectionId, + principalTable: "Collection", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + table.ForeignKey( + name: "FK_CollectionUsers_OrganizationUser_OrganizationUserId", + column: x => x.OrganizationUserId, + principalTable: "OrganizationUser", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + table.ForeignKey( + name: "FK_CollectionUsers_User_UserId", + column: x => x.UserId, + principalTable: "User", + principalColumn: "Id", + onDelete: ReferentialAction.Restrict); + }) + .Annotation("MySql:CharSet", "utf8mb4"); + + migrationBuilder.CreateTable( + name: "GroupUser", + columns: table => new + { + GroupId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), + OrganizationUserId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), + UserId = table.Column(type: "char(36)", nullable: true, collation: "ascii_general_ci") + }, + constraints: table => + { + table.PrimaryKey("PK_GroupUser", x => new { x.GroupId, x.OrganizationUserId }); + table.ForeignKey( + name: "FK_GroupUser_Group_GroupId", + column: x => x.GroupId, + principalTable: "Group", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + table.ForeignKey( + name: "FK_GroupUser_OrganizationUser_OrganizationUserId", + column: x => x.OrganizationUserId, + principalTable: "OrganizationUser", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + table.ForeignKey( + name: "FK_GroupUser_User_UserId", + column: x => x.UserId, + principalTable: "User", + principalColumn: "Id", + onDelete: ReferentialAction.Restrict); + }) + .Annotation("MySql:CharSet", "utf8mb4"); + + migrationBuilder.CreateTable( + name: "ProviderOrganizationProviderUser", + columns: table => new + { + Id = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), + ProviderOrganizationId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), + ProviderUserId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), + Type = table.Column(type: "tinyint unsigned", nullable: false), + Permissions = table.Column(type: "longtext", nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + CreationDate = table.Column(type: "datetime(6)", nullable: false), + RevisionDate = table.Column(type: "datetime(6)", nullable: false) + }, + constraints: table => + { + table.PrimaryKey("PK_ProviderOrganizationProviderUser", x => x.Id); + table.ForeignKey( + name: "FK_ProviderOrganizationProviderUser_ProviderOrganization_Provid~", + column: x => x.ProviderOrganizationId, + principalTable: "ProviderOrganization", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + table.ForeignKey( + name: "FK_ProviderOrganizationProviderUser_ProviderUser_ProviderUserId", + column: x => x.ProviderUserId, + principalTable: "ProviderUser", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + }) + .Annotation("MySql:CharSet", "utf8mb4"); + + migrationBuilder.CreateIndex( + name: "IX_Cipher_OrganizationId", + table: "Cipher", + column: "OrganizationId"); + + migrationBuilder.CreateIndex( + name: "IX_Cipher_UserId", + table: "Cipher", + column: "UserId"); + + migrationBuilder.CreateIndex( + name: "IX_Collection_OrganizationId", + table: "Collection", + column: "OrganizationId"); + + migrationBuilder.CreateIndex( + name: "IX_CollectionCipher_CipherId", + table: "CollectionCipher", + column: "CipherId"); + + migrationBuilder.CreateIndex( + name: "IX_CollectionGroups_GroupId", + table: "CollectionGroups", + column: "GroupId"); + + migrationBuilder.CreateIndex( + name: "IX_CollectionUsers_OrganizationUserId", + table: "CollectionUsers", + column: "OrganizationUserId"); + + migrationBuilder.CreateIndex( + name: "IX_CollectionUsers_UserId", + table: "CollectionUsers", + column: "UserId"); + + migrationBuilder.CreateIndex( + name: "IX_Device_UserId", + table: "Device", + column: "UserId"); + + migrationBuilder.CreateIndex( + name: "IX_EmergencyAccess_GranteeId", + table: "EmergencyAccess", + column: "GranteeId"); + + migrationBuilder.CreateIndex( + name: "IX_EmergencyAccess_GrantorId", + table: "EmergencyAccess", + column: "GrantorId"); + + migrationBuilder.CreateIndex( + name: "IX_Folder_UserId", + table: "Folder", + column: "UserId"); + + migrationBuilder.CreateIndex( + name: "IX_Group_OrganizationId", + table: "Group", + column: "OrganizationId"); + + migrationBuilder.CreateIndex( + name: "IX_GroupUser_OrganizationUserId", + table: "GroupUser", + column: "OrganizationUserId"); + + migrationBuilder.CreateIndex( + name: "IX_GroupUser_UserId", + table: "GroupUser", + column: "UserId"); + + migrationBuilder.CreateIndex( + name: "IX_OrganizationUser_OrganizationId", + table: "OrganizationUser", + column: "OrganizationId"); + + migrationBuilder.CreateIndex( + name: "IX_OrganizationUser_UserId", + table: "OrganizationUser", + column: "UserId"); + + migrationBuilder.CreateIndex( + name: "IX_Policy_OrganizationId", + table: "Policy", + column: "OrganizationId"); + + migrationBuilder.CreateIndex( + name: "IX_ProviderOrganization_OrganizationId", + table: "ProviderOrganization", + column: "OrganizationId"); + + migrationBuilder.CreateIndex( + name: "IX_ProviderOrganization_ProviderId", + table: "ProviderOrganization", + column: "ProviderId"); + + migrationBuilder.CreateIndex( + name: "IX_ProviderOrganizationProviderUser_ProviderOrganizationId", + table: "ProviderOrganizationProviderUser", + column: "ProviderOrganizationId"); + + migrationBuilder.CreateIndex( + name: "IX_ProviderOrganizationProviderUser_ProviderUserId", + table: "ProviderOrganizationProviderUser", + column: "ProviderUserId"); + + migrationBuilder.CreateIndex( + name: "IX_ProviderUser_ProviderId", + table: "ProviderUser", + column: "ProviderId"); + + migrationBuilder.CreateIndex( + name: "IX_ProviderUser_UserId", + table: "ProviderUser", + column: "UserId"); + + migrationBuilder.CreateIndex( + name: "IX_Send_OrganizationId", + table: "Send", + column: "OrganizationId"); + + migrationBuilder.CreateIndex( + name: "IX_Send_UserId", + table: "Send", + column: "UserId"); + + migrationBuilder.CreateIndex( + name: "IX_SsoConfig_OrganizationId", + table: "SsoConfig", + column: "OrganizationId"); - migrationBuilder.CreateIndex( - name: "IX_SsoUser_OrganizationId", - table: "SsoUser", - column: "OrganizationId"); + migrationBuilder.CreateIndex( + name: "IX_SsoUser_OrganizationId", + table: "SsoUser", + column: "OrganizationId"); - migrationBuilder.CreateIndex( - name: "IX_SsoUser_UserId", - table: "SsoUser", - column: "UserId"); + migrationBuilder.CreateIndex( + name: "IX_SsoUser_UserId", + table: "SsoUser", + column: "UserId"); - migrationBuilder.CreateIndex( - name: "IX_Transaction_OrganizationId", - table: "Transaction", - column: "OrganizationId"); + migrationBuilder.CreateIndex( + name: "IX_Transaction_OrganizationId", + table: "Transaction", + column: "OrganizationId"); - migrationBuilder.CreateIndex( - name: "IX_Transaction_UserId", - table: "Transaction", - column: "UserId"); + migrationBuilder.CreateIndex( + name: "IX_Transaction_UserId", + table: "Transaction", + column: "UserId"); - migrationBuilder.CreateIndex( - name: "IX_U2f_UserId", - table: "U2f", - column: "UserId"); - } + migrationBuilder.CreateIndex( + name: "IX_U2f_UserId", + table: "U2f", + column: "UserId"); + } - protected override void Down(MigrationBuilder migrationBuilder) - { - migrationBuilder.DropTable( - name: "CollectionCipher"); + protected override void Down(MigrationBuilder migrationBuilder) + { + migrationBuilder.DropTable( + name: "CollectionCipher"); - migrationBuilder.DropTable( - name: "CollectionGroups"); + migrationBuilder.DropTable( + name: "CollectionGroups"); - migrationBuilder.DropTable( - name: "CollectionUsers"); + migrationBuilder.DropTable( + name: "CollectionUsers"); - migrationBuilder.DropTable( - name: "Device"); + migrationBuilder.DropTable( + name: "Device"); - migrationBuilder.DropTable( - name: "EmergencyAccess"); + migrationBuilder.DropTable( + name: "EmergencyAccess"); - migrationBuilder.DropTable( - name: "Event"); + migrationBuilder.DropTable( + name: "Event"); - migrationBuilder.DropTable( - name: "Folder"); + migrationBuilder.DropTable( + name: "Folder"); - migrationBuilder.DropTable( - name: "Grant"); + migrationBuilder.DropTable( + name: "Grant"); - migrationBuilder.DropTable( - name: "GroupUser"); + migrationBuilder.DropTable( + name: "GroupUser"); - migrationBuilder.DropTable( - name: "Installation"); + migrationBuilder.DropTable( + name: "Installation"); - migrationBuilder.DropTable( - name: "Policy"); + migrationBuilder.DropTable( + name: "Policy"); - migrationBuilder.DropTable( - name: "ProviderOrganizationProviderUser"); + migrationBuilder.DropTable( + name: "ProviderOrganizationProviderUser"); - migrationBuilder.DropTable( - name: "Send"); + migrationBuilder.DropTable( + name: "Send"); - migrationBuilder.DropTable( - name: "SsoConfig"); + migrationBuilder.DropTable( + name: "SsoConfig"); - migrationBuilder.DropTable( - name: "SsoUser"); + migrationBuilder.DropTable( + name: "SsoUser"); - migrationBuilder.DropTable( - name: "TaxRate"); + migrationBuilder.DropTable( + name: "TaxRate"); - migrationBuilder.DropTable( - name: "Transaction"); + migrationBuilder.DropTable( + name: "Transaction"); - migrationBuilder.DropTable( - name: "U2f"); + migrationBuilder.DropTable( + name: "U2f"); - migrationBuilder.DropTable( - name: "Cipher"); + migrationBuilder.DropTable( + name: "Cipher"); - migrationBuilder.DropTable( - name: "Collection"); + migrationBuilder.DropTable( + name: "Collection"); - migrationBuilder.DropTable( - name: "Group"); + migrationBuilder.DropTable( + name: "Group"); - migrationBuilder.DropTable( - name: "OrganizationUser"); + migrationBuilder.DropTable( + name: "OrganizationUser"); - migrationBuilder.DropTable( - name: "ProviderOrganization"); + migrationBuilder.DropTable( + name: "ProviderOrganization"); - migrationBuilder.DropTable( - name: "ProviderUser"); + migrationBuilder.DropTable( + name: "ProviderUser"); - migrationBuilder.DropTable( - name: "Organization"); + migrationBuilder.DropTable( + name: "Organization"); - migrationBuilder.DropTable( - name: "Provider"); + migrationBuilder.DropTable( + name: "Provider"); - migrationBuilder.DropTable( - name: "User"); + migrationBuilder.DropTable( + name: "User"); + } } } diff --git a/util/MySqlMigrations/Migrations/20210709095522_RemoveProviderOrganizationProviderUser.cs b/util/MySqlMigrations/Migrations/20210709095522_RemoveProviderOrganizationProviderUser.cs index cc53719fa3..cf28c6fce3 100644 --- a/util/MySqlMigrations/Migrations/20210709095522_RemoveProviderOrganizationProviderUser.cs +++ b/util/MySqlMigrations/Migrations/20210709095522_RemoveProviderOrganizationProviderUser.cs @@ -1,89 +1,90 @@ using Microsoft.EntityFrameworkCore.Migrations; -namespace Bit.MySqlMigrations.Migrations; - -public partial class RemoveProviderOrganizationProviderUser : Migration +namespace Bit.MySqlMigrations.Migrations { - protected override void Up(MigrationBuilder migrationBuilder) + public partial class RemoveProviderOrganizationProviderUser : Migration { - migrationBuilder.DropTable( - name: "ProviderOrganizationProviderUser"); + protected override void Up(MigrationBuilder migrationBuilder) + { + migrationBuilder.DropTable( + name: "ProviderOrganizationProviderUser"); - migrationBuilder.AddColumn( - name: "UseEvents", - table: "Provider", - type: "tinyint(1)", - nullable: false, - defaultValue: false); + migrationBuilder.AddColumn( + name: "UseEvents", + table: "Provider", + type: "tinyint(1)", + nullable: false, + defaultValue: false); - migrationBuilder.AddColumn( - name: "ProviderId", - table: "Event", - type: "char(36)", - nullable: true, - collation: "ascii_general_ci"); + migrationBuilder.AddColumn( + name: "ProviderId", + table: "Event", + type: "char(36)", + nullable: true, + collation: "ascii_general_ci"); - migrationBuilder.AddColumn( - name: "ProviderUserId", - table: "Event", - type: "char(36)", - nullable: true, - collation: "ascii_general_ci"); - } + migrationBuilder.AddColumn( + name: "ProviderUserId", + table: "Event", + type: "char(36)", + nullable: true, + collation: "ascii_general_ci"); + } - protected override void Down(MigrationBuilder migrationBuilder) - { - migrationBuilder.DropColumn( - name: "UseEvents", - table: "Provider"); + protected override void Down(MigrationBuilder migrationBuilder) + { + migrationBuilder.DropColumn( + name: "UseEvents", + table: "Provider"); - migrationBuilder.DropColumn( - name: "ProviderId", - table: "Event"); + migrationBuilder.DropColumn( + name: "ProviderId", + table: "Event"); - migrationBuilder.DropColumn( - name: "ProviderUserId", - table: "Event"); + migrationBuilder.DropColumn( + name: "ProviderUserId", + table: "Event"); - migrationBuilder.CreateTable( - name: "ProviderOrganizationProviderUser", - columns: table => new - { - Id = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), - CreationDate = table.Column(type: "datetime(6)", nullable: false), - Permissions = table.Column(type: "longtext", nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - ProviderOrganizationId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), - ProviderUserId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), - RevisionDate = table.Column(type: "datetime(6)", nullable: false), - Type = table.Column(type: "tinyint unsigned", nullable: false) - }, - constraints: table => - { - table.PrimaryKey("PK_ProviderOrganizationProviderUser", x => x.Id); - table.ForeignKey( - name: "FK_ProviderOrganizationProviderUser_ProviderOrganization_Provid~", - column: x => x.ProviderOrganizationId, - principalTable: "ProviderOrganization", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - table.ForeignKey( - name: "FK_ProviderOrganizationProviderUser_ProviderUser_ProviderUserId", - column: x => x.ProviderUserId, - principalTable: "ProviderUser", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - }) - .Annotation("MySql:CharSet", "utf8mb4"); + migrationBuilder.CreateTable( + name: "ProviderOrganizationProviderUser", + columns: table => new + { + Id = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), + CreationDate = table.Column(type: "datetime(6)", nullable: false), + Permissions = table.Column(type: "longtext", nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + ProviderOrganizationId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), + ProviderUserId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), + RevisionDate = table.Column(type: "datetime(6)", nullable: false), + Type = table.Column(type: "tinyint unsigned", nullable: false) + }, + constraints: table => + { + table.PrimaryKey("PK_ProviderOrganizationProviderUser", x => x.Id); + table.ForeignKey( + name: "FK_ProviderOrganizationProviderUser_ProviderOrganization_Provid~", + column: x => x.ProviderOrganizationId, + principalTable: "ProviderOrganization", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + table.ForeignKey( + name: "FK_ProviderOrganizationProviderUser_ProviderUser_ProviderUserId", + column: x => x.ProviderUserId, + principalTable: "ProviderUser", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + }) + .Annotation("MySql:CharSet", "utf8mb4"); - migrationBuilder.CreateIndex( - name: "IX_ProviderOrganizationProviderUser_ProviderOrganizationId", - table: "ProviderOrganizationProviderUser", - column: "ProviderOrganizationId"); + migrationBuilder.CreateIndex( + name: "IX_ProviderOrganizationProviderUser_ProviderOrganizationId", + table: "ProviderOrganizationProviderUser", + column: "ProviderOrganizationId"); - migrationBuilder.CreateIndex( - name: "IX_ProviderOrganizationProviderUser_ProviderUserId", - table: "ProviderOrganizationProviderUser", - column: "ProviderUserId"); + migrationBuilder.CreateIndex( + name: "IX_ProviderOrganizationProviderUser_ProviderUserId", + table: "ProviderOrganizationProviderUser", + column: "ProviderUserId"); + } } } diff --git a/util/MySqlMigrations/Migrations/20210716142145_UserForcePasswordReset.cs b/util/MySqlMigrations/Migrations/20210716142145_UserForcePasswordReset.cs index 1c385968c6..762aa05467 100644 --- a/util/MySqlMigrations/Migrations/20210716142145_UserForcePasswordReset.cs +++ b/util/MySqlMigrations/Migrations/20210716142145_UserForcePasswordReset.cs @@ -1,23 +1,24 @@ using Microsoft.EntityFrameworkCore.Migrations; -namespace Bit.MySqlMigrations.Migrations; - -public partial class UserForcePasswordReset : Migration +namespace Bit.MySqlMigrations.Migrations { - protected override void Up(MigrationBuilder migrationBuilder) + public partial class UserForcePasswordReset : Migration { - migrationBuilder.AddColumn( - name: "ForcePasswordReset", - table: "User", - type: "tinyint(1)", - nullable: false, - defaultValue: false); - } + protected override void Up(MigrationBuilder migrationBuilder) + { + migrationBuilder.AddColumn( + name: "ForcePasswordReset", + table: "User", + type: "tinyint(1)", + nullable: false, + defaultValue: false); + } - protected override void Down(MigrationBuilder migrationBuilder) - { - migrationBuilder.DropColumn( - name: "ForcePasswordReset", - table: "User"); + protected override void Down(MigrationBuilder migrationBuilder) + { + migrationBuilder.DropColumn( + name: "ForcePasswordReset", + table: "User"); + } } } diff --git a/util/MySqlMigrations/Migrations/20210921132418_AddMaxAutoscaleSeatsToOrganization.cs b/util/MySqlMigrations/Migrations/20210921132418_AddMaxAutoscaleSeatsToOrganization.cs index 168667ecba..2ae51dc4a5 100644 --- a/util/MySqlMigrations/Migrations/20210921132418_AddMaxAutoscaleSeatsToOrganization.cs +++ b/util/MySqlMigrations/Migrations/20210921132418_AddMaxAutoscaleSeatsToOrganization.cs @@ -1,43 +1,44 @@ using Microsoft.EntityFrameworkCore.Migrations; -namespace Bit.MySqlMigrations.Migrations; - -public partial class AddMaxAutoscaleSeatsToOrganization : Migration +namespace Bit.MySqlMigrations.Migrations { - protected override void Up(MigrationBuilder migrationBuilder) + public partial class AddMaxAutoscaleSeatsToOrganization : Migration { - migrationBuilder.AddColumn( - name: "MaxAutoscaleSeats", - table: "Organization", - type: "int", - nullable: true); + protected override void Up(MigrationBuilder migrationBuilder) + { + migrationBuilder.AddColumn( + name: "MaxAutoscaleSeats", + table: "Organization", + type: "int", + nullable: true); - migrationBuilder.AddColumn( - name: "OwnersNotifiedOfAutoscaling", - table: "Organization", - type: "datetime(6)", - nullable: true); + migrationBuilder.AddColumn( + name: "OwnersNotifiedOfAutoscaling", + table: "Organization", + type: "datetime(6)", + nullable: true); - migrationBuilder.AddColumn( - name: "ProviderOrganizationId", - table: "Event", - type: "char(36)", - nullable: true, - collation: "ascii_general_ci"); - } + migrationBuilder.AddColumn( + name: "ProviderOrganizationId", + table: "Event", + type: "char(36)", + nullable: true, + collation: "ascii_general_ci"); + } - protected override void Down(MigrationBuilder migrationBuilder) - { - migrationBuilder.DropColumn( - name: "MaxAutoscaleSeats", - table: "Organization"); + protected override void Down(MigrationBuilder migrationBuilder) + { + migrationBuilder.DropColumn( + name: "MaxAutoscaleSeats", + table: "Organization"); - migrationBuilder.DropColumn( - name: "OwnersNotifiedOfAutoscaling", - table: "Organization"); + migrationBuilder.DropColumn( + name: "OwnersNotifiedOfAutoscaling", + table: "Organization"); - migrationBuilder.DropColumn( - name: "ProviderOrganizationId", - table: "Event"); + migrationBuilder.DropColumn( + name: "ProviderOrganizationId", + table: "Event"); + } } } diff --git a/util/MySqlMigrations/Migrations/20211011144835_SplitManageCollectionsPermissions2.cs b/util/MySqlMigrations/Migrations/20211011144835_SplitManageCollectionsPermissions2.cs index 19817d1285..8884d33400 100644 --- a/util/MySqlMigrations/Migrations/20211011144835_SplitManageCollectionsPermissions2.cs +++ b/util/MySqlMigrations/Migrations/20211011144835_SplitManageCollectionsPermissions2.cs @@ -1,20 +1,21 @@ using Bit.Core.Utilities; using Microsoft.EntityFrameworkCore.Migrations; -namespace Bit.MySqlMigrations.Migrations; - -public partial class SplitManageCollectionsPermissions2 : Migration +namespace Bit.MySqlMigrations.Migrations { - private const string _scriptLocation = - "MySqlMigrations.Scripts.2021-09-21_01_SplitManageCollectionsPermission.sql"; - - protected override void Up(MigrationBuilder migrationBuilder) + public partial class SplitManageCollectionsPermissions2 : Migration { - migrationBuilder.Sql(CoreHelpers.GetEmbeddedResourceContentsAsync(_scriptLocation)); - } + private const string _scriptLocation = + "MySqlMigrations.Scripts.2021-09-21_01_SplitManageCollectionsPermission.sql"; - protected override void Down(MigrationBuilder migrationBuilder) - { - throw new Exception("Irreversible migration"); + protected override void Up(MigrationBuilder migrationBuilder) + { + migrationBuilder.Sql(CoreHelpers.GetEmbeddedResourceContentsAsync(_scriptLocation)); + } + + protected override void Down(MigrationBuilder migrationBuilder) + { + throw new Exception("Irreversible migration"); + } } } diff --git a/util/MySqlMigrations/Migrations/20211021201150_SetMaxAutoscaleSeatsToCurrentSeatCount.cs b/util/MySqlMigrations/Migrations/20211021201150_SetMaxAutoscaleSeatsToCurrentSeatCount.cs index 00574ab65f..af746bcccc 100644 --- a/util/MySqlMigrations/Migrations/20211021201150_SetMaxAutoscaleSeatsToCurrentSeatCount.cs +++ b/util/MySqlMigrations/Migrations/20211021201150_SetMaxAutoscaleSeatsToCurrentSeatCount.cs @@ -1,20 +1,21 @@ using Bit.Core.Utilities; using Microsoft.EntityFrameworkCore.Migrations; -namespace Bit.MySqlMigrations.Migrations; - -public partial class SetMaxAutoscaleSeatsToCurrentSeatCount : Migration +namespace Bit.MySqlMigrations.Migrations { - private const string _scriptLocation = - "MySqlMigrations.Scripts.2021-10-21_00_SetMaxAutoscaleSeatCount.sql"; - - protected override void Up(MigrationBuilder migrationBuilder) + public partial class SetMaxAutoscaleSeatsToCurrentSeatCount : Migration { - migrationBuilder.Sql(CoreHelpers.GetEmbeddedResourceContentsAsync(_scriptLocation)); - } + private const string _scriptLocation = + "MySqlMigrations.Scripts.2021-10-21_00_SetMaxAutoscaleSeatCount.sql"; - protected override void Down(MigrationBuilder migrationBuilder) - { - throw new Exception("Irreversible migration"); + protected override void Up(MigrationBuilder migrationBuilder) + { + migrationBuilder.Sql(CoreHelpers.GetEmbeddedResourceContentsAsync(_scriptLocation)); + } + + protected override void Down(MigrationBuilder migrationBuilder) + { + throw new Exception("Irreversible migration"); + } } } diff --git a/util/MySqlMigrations/Migrations/20211108041911_KeyConnector.cs b/util/MySqlMigrations/Migrations/20211108041911_KeyConnector.cs index 59ed36ff4a..3e48440fa7 100644 --- a/util/MySqlMigrations/Migrations/20211108041911_KeyConnector.cs +++ b/util/MySqlMigrations/Migrations/20211108041911_KeyConnector.cs @@ -1,23 +1,24 @@ using Microsoft.EntityFrameworkCore.Migrations; -namespace Bit.MySqlMigrations.Migrations; - -public partial class KeyConnector : Migration +namespace Bit.MySqlMigrations.Migrations { - protected override void Up(MigrationBuilder migrationBuilder) + public partial class KeyConnector : Migration { - migrationBuilder.AddColumn( - name: "UsesKeyConnector", - table: "User", - type: "tinyint(1)", - nullable: false, - defaultValue: false); - } + protected override void Up(MigrationBuilder migrationBuilder) + { + migrationBuilder.AddColumn( + name: "UsesKeyConnector", + table: "User", + type: "tinyint(1)", + nullable: false, + defaultValue: false); + } - protected override void Down(MigrationBuilder migrationBuilder) - { - migrationBuilder.DropColumn( - name: "UsesKeyConnector", - table: "User"); + protected override void Down(MigrationBuilder migrationBuilder) + { + migrationBuilder.DropColumn( + name: "UsesKeyConnector", + table: "User"); + } } } diff --git a/util/MySqlMigrations/Migrations/20211108225243_OrganizationSponsorship.cs b/util/MySqlMigrations/Migrations/20211108225243_OrganizationSponsorship.cs index 155ecfc008..68dc13557c 100644 --- a/util/MySqlMigrations/Migrations/20211108225243_OrganizationSponsorship.cs +++ b/util/MySqlMigrations/Migrations/20211108225243_OrganizationSponsorship.cs @@ -1,84 +1,85 @@ using Microsoft.EntityFrameworkCore.Migrations; -namespace Bit.MySqlMigrations.Migrations; - -public partial class OrganizationSponsorship : Migration +namespace Bit.MySqlMigrations.Migrations { - protected override void Up(MigrationBuilder migrationBuilder) + public partial class OrganizationSponsorship : Migration { - migrationBuilder.AddColumn( - name: "UsesCryptoAgent", - table: "User", - type: "tinyint(1)", - nullable: false, - defaultValue: false); + protected override void Up(MigrationBuilder migrationBuilder) + { + migrationBuilder.AddColumn( + name: "UsesCryptoAgent", + table: "User", + type: "tinyint(1)", + nullable: false, + defaultValue: false); - migrationBuilder.CreateTable( - name: "OrganizationSponsorship", - columns: table => new - { - Id = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), - InstallationId = table.Column(type: "char(36)", nullable: true, collation: "ascii_general_ci"), - SponsoringOrganizationId = table.Column(type: "char(36)", nullable: true, collation: "ascii_general_ci"), - SponsoringOrganizationUserId = table.Column(type: "char(36)", nullable: true, collation: "ascii_general_ci"), - SponsoredOrganizationId = table.Column(type: "char(36)", nullable: true, collation: "ascii_general_ci"), - FriendlyName = table.Column(type: "varchar(256)", maxLength: 256, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - OfferedToEmail = table.Column(type: "varchar(256)", maxLength: 256, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - PlanSponsorshipType = table.Column(type: "tinyint unsigned", nullable: true), - CloudSponsor = table.Column(type: "tinyint(1)", nullable: false), - LastSyncDate = table.Column(type: "datetime(6)", nullable: true), - TimesRenewedWithoutValidation = table.Column(type: "tinyint unsigned", nullable: false), - SponsorshipLapsedDate = table.Column(type: "datetime(6)", nullable: true) - }, - constraints: table => - { - table.PrimaryKey("PK_OrganizationSponsorship", x => x.Id); - table.ForeignKey( - name: "FK_OrganizationSponsorship_Installation_InstallationId", - column: x => x.InstallationId, - principalTable: "Installation", - principalColumn: "Id", - onDelete: ReferentialAction.Restrict); - table.ForeignKey( - name: "FK_OrganizationSponsorship_Organization_SponsoredOrganizationId", - column: x => x.SponsoredOrganizationId, - principalTable: "Organization", - principalColumn: "Id", - onDelete: ReferentialAction.Restrict); - table.ForeignKey( - name: "FK_OrganizationSponsorship_Organization_SponsoringOrganizationId", - column: x => x.SponsoringOrganizationId, - principalTable: "Organization", - principalColumn: "Id", - onDelete: ReferentialAction.Restrict); - }) - .Annotation("MySql:CharSet", "utf8mb4"); + migrationBuilder.CreateTable( + name: "OrganizationSponsorship", + columns: table => new + { + Id = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), + InstallationId = table.Column(type: "char(36)", nullable: true, collation: "ascii_general_ci"), + SponsoringOrganizationId = table.Column(type: "char(36)", nullable: true, collation: "ascii_general_ci"), + SponsoringOrganizationUserId = table.Column(type: "char(36)", nullable: true, collation: "ascii_general_ci"), + SponsoredOrganizationId = table.Column(type: "char(36)", nullable: true, collation: "ascii_general_ci"), + FriendlyName = table.Column(type: "varchar(256)", maxLength: 256, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + OfferedToEmail = table.Column(type: "varchar(256)", maxLength: 256, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + PlanSponsorshipType = table.Column(type: "tinyint unsigned", nullable: true), + CloudSponsor = table.Column(type: "tinyint(1)", nullable: false), + LastSyncDate = table.Column(type: "datetime(6)", nullable: true), + TimesRenewedWithoutValidation = table.Column(type: "tinyint unsigned", nullable: false), + SponsorshipLapsedDate = table.Column(type: "datetime(6)", nullable: true) + }, + constraints: table => + { + table.PrimaryKey("PK_OrganizationSponsorship", x => x.Id); + table.ForeignKey( + name: "FK_OrganizationSponsorship_Installation_InstallationId", + column: x => x.InstallationId, + principalTable: "Installation", + principalColumn: "Id", + onDelete: ReferentialAction.Restrict); + table.ForeignKey( + name: "FK_OrganizationSponsorship_Organization_SponsoredOrganizationId", + column: x => x.SponsoredOrganizationId, + principalTable: "Organization", + principalColumn: "Id", + onDelete: ReferentialAction.Restrict); + table.ForeignKey( + name: "FK_OrganizationSponsorship_Organization_SponsoringOrganizationId", + column: x => x.SponsoringOrganizationId, + principalTable: "Organization", + principalColumn: "Id", + onDelete: ReferentialAction.Restrict); + }) + .Annotation("MySql:CharSet", "utf8mb4"); - migrationBuilder.CreateIndex( - name: "IX_OrganizationSponsorship_InstallationId", - table: "OrganizationSponsorship", - column: "InstallationId"); + migrationBuilder.CreateIndex( + name: "IX_OrganizationSponsorship_InstallationId", + table: "OrganizationSponsorship", + column: "InstallationId"); - migrationBuilder.CreateIndex( - name: "IX_OrganizationSponsorship_SponsoredOrganizationId", - table: "OrganizationSponsorship", - column: "SponsoredOrganizationId"); + migrationBuilder.CreateIndex( + name: "IX_OrganizationSponsorship_SponsoredOrganizationId", + table: "OrganizationSponsorship", + column: "SponsoredOrganizationId"); - migrationBuilder.CreateIndex( - name: "IX_OrganizationSponsorship_SponsoringOrganizationId", - table: "OrganizationSponsorship", - column: "SponsoringOrganizationId"); - } + migrationBuilder.CreateIndex( + name: "IX_OrganizationSponsorship_SponsoringOrganizationId", + table: "OrganizationSponsorship", + column: "SponsoringOrganizationId"); + } - protected override void Down(MigrationBuilder migrationBuilder) - { - migrationBuilder.DropTable( - name: "OrganizationSponsorship"); + protected override void Down(MigrationBuilder migrationBuilder) + { + migrationBuilder.DropTable( + name: "OrganizationSponsorship"); - migrationBuilder.DropColumn( - name: "UsesCryptoAgent", - table: "User"); + migrationBuilder.DropColumn( + name: "UsesCryptoAgent", + table: "User"); + } } } diff --git a/util/MySqlMigrations/Migrations/20211115145402_KeyConnectorFlag.cs b/util/MySqlMigrations/Migrations/20211115145402_KeyConnectorFlag.cs index 62d924f5c5..d68eb65a17 100644 --- a/util/MySqlMigrations/Migrations/20211115145402_KeyConnectorFlag.cs +++ b/util/MySqlMigrations/Migrations/20211115145402_KeyConnectorFlag.cs @@ -1,23 +1,24 @@ using Microsoft.EntityFrameworkCore.Migrations; -namespace Bit.MySqlMigrations.Migrations; - -public partial class KeyConnectorFlag : Migration +namespace Bit.MySqlMigrations.Migrations { - protected override void Up(MigrationBuilder migrationBuilder) + public partial class KeyConnectorFlag : Migration { - migrationBuilder.AddColumn( - name: "UseKeyConnector", - table: "Organization", - type: "tinyint(1)", - nullable: false, - defaultValue: false); - } + protected override void Up(MigrationBuilder migrationBuilder) + { + migrationBuilder.AddColumn( + name: "UseKeyConnector", + table: "Organization", + type: "tinyint(1)", + nullable: false, + defaultValue: false); + } - protected override void Down(MigrationBuilder migrationBuilder) - { - migrationBuilder.DropColumn( - name: "UseKeyConnector", - table: "Organization"); + protected override void Down(MigrationBuilder migrationBuilder) + { + migrationBuilder.DropColumn( + name: "UseKeyConnector", + table: "Organization"); + } } } diff --git a/util/MySqlMigrations/Migrations/20220121092546_RemoveU2F.cs b/util/MySqlMigrations/Migrations/20220121092546_RemoveU2F.cs index 8950ac914f..8d9250fee3 100644 --- a/util/MySqlMigrations/Migrations/20220121092546_RemoveU2F.cs +++ b/util/MySqlMigrations/Migrations/20220121092546_RemoveU2F.cs @@ -1,50 +1,51 @@ using Microsoft.EntityFrameworkCore.Metadata; using Microsoft.EntityFrameworkCore.Migrations; -namespace Bit.MySqlMigrations.Migrations; - -public partial class RemoveU2F : Migration +namespace Bit.MySqlMigrations.Migrations { - protected override void Up(MigrationBuilder migrationBuilder) + public partial class RemoveU2F : Migration { - migrationBuilder.DropTable( - name: "U2f"); - } + protected override void Up(MigrationBuilder migrationBuilder) + { + migrationBuilder.DropTable( + name: "U2f"); + } - protected override void Down(MigrationBuilder migrationBuilder) - { - migrationBuilder.CreateTable( - name: "U2f", - columns: table => new - { - Id = table.Column(type: "int", nullable: false) - .Annotation("MySql:ValueGenerationStrategy", MySqlValueGenerationStrategy.IdentityColumn), - AppId = table.Column(type: "varchar(50)", maxLength: 50, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - Challenge = table.Column(type: "varchar(200)", maxLength: 200, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - CreationDate = table.Column(type: "datetime(6)", nullable: false), - KeyHandle = table.Column(type: "varchar(200)", maxLength: 200, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - UserId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), - Version = table.Column(type: "varchar(20)", maxLength: 20, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4") - }, - constraints: table => - { - table.PrimaryKey("PK_U2f", x => x.Id); - table.ForeignKey( - name: "FK_U2f_User_UserId", - column: x => x.UserId, - principalTable: "User", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - }) - .Annotation("MySql:CharSet", "utf8mb4"); + protected override void Down(MigrationBuilder migrationBuilder) + { + migrationBuilder.CreateTable( + name: "U2f", + columns: table => new + { + Id = table.Column(type: "int", nullable: false) + .Annotation("MySql:ValueGenerationStrategy", MySqlValueGenerationStrategy.IdentityColumn), + AppId = table.Column(type: "varchar(50)", maxLength: 50, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + Challenge = table.Column(type: "varchar(200)", maxLength: 200, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + CreationDate = table.Column(type: "datetime(6)", nullable: false), + KeyHandle = table.Column(type: "varchar(200)", maxLength: 200, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + UserId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), + Version = table.Column(type: "varchar(20)", maxLength: 20, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4") + }, + constraints: table => + { + table.PrimaryKey("PK_U2f", x => x.Id); + table.ForeignKey( + name: "FK_U2f_User_UserId", + column: x => x.UserId, + principalTable: "User", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + }) + .Annotation("MySql:CharSet", "utf8mb4"); - migrationBuilder.CreateIndex( - name: "IX_U2f_UserId", - table: "U2f", - column: "UserId"); + migrationBuilder.CreateIndex( + name: "IX_U2f_UserId", + table: "U2f", + column: "UserId"); + } } } diff --git a/util/MySqlMigrations/Migrations/20220301215315_FailedLoginCaptcha.cs b/util/MySqlMigrations/Migrations/20220301215315_FailedLoginCaptcha.cs index 91245fc462..f7c8bcc57c 100644 --- a/util/MySqlMigrations/Migrations/20220301215315_FailedLoginCaptcha.cs +++ b/util/MySqlMigrations/Migrations/20220301215315_FailedLoginCaptcha.cs @@ -1,33 +1,34 @@ using Microsoft.EntityFrameworkCore.Migrations; -namespace Bit.MySqlMigrations.Migrations; - -public partial class FailedLoginCaptcha : Migration +namespace Bit.MySqlMigrations.Migrations { - protected override void Up(MigrationBuilder migrationBuilder) + public partial class FailedLoginCaptcha : Migration { - migrationBuilder.AddColumn( - name: "FailedLoginCount", - table: "User", - type: "int", - nullable: false, - defaultValue: 0); + protected override void Up(MigrationBuilder migrationBuilder) + { + migrationBuilder.AddColumn( + name: "FailedLoginCount", + table: "User", + type: "int", + nullable: false, + defaultValue: 0); - migrationBuilder.AddColumn( - name: "LastFailedLoginDate", - table: "User", - type: "datetime(6)", - nullable: true); - } + migrationBuilder.AddColumn( + name: "LastFailedLoginDate", + table: "User", + type: "datetime(6)", + nullable: true); + } - protected override void Down(MigrationBuilder migrationBuilder) - { - migrationBuilder.DropColumn( - name: "FailedLoginCount", - table: "User"); + protected override void Down(MigrationBuilder migrationBuilder) + { + migrationBuilder.DropColumn( + name: "FailedLoginCount", + table: "User"); - migrationBuilder.DropColumn( - name: "LastFailedLoginDate", - table: "User"); + migrationBuilder.DropColumn( + name: "LastFailedLoginDate", + table: "User"); + } } } diff --git a/util/MySqlMigrations/Migrations/20220322191314_SelfHostF4E.cs b/util/MySqlMigrations/Migrations/20220322191314_SelfHostF4E.cs index 993399e502..c6f9c89344 100644 --- a/util/MySqlMigrations/Migrations/20220322191314_SelfHostF4E.cs +++ b/util/MySqlMigrations/Migrations/20220322191314_SelfHostF4E.cs @@ -1,157 +1,158 @@ using Microsoft.EntityFrameworkCore.Migrations; -namespace Bit.MySqlMigrations.Migrations; - -public partial class SelfHostF4E : Migration +namespace Bit.MySqlMigrations.Migrations { - private const string _scriptLocationTemplate = "2022-03-01_00_{0}_MigrateOrganizationApiKeys.sql"; - - protected override void Up(MigrationBuilder migrationBuilder) + public partial class SelfHostF4E : Migration { - migrationBuilder.DropForeignKey( - name: "FK_OrganizationSponsorship_Installation_InstallationId", - table: "OrganizationSponsorship"); + private const string _scriptLocationTemplate = "2022-03-01_00_{0}_MigrateOrganizationApiKeys.sql"; - migrationBuilder.DropIndex( - name: "IX_OrganizationSponsorship_InstallationId", - table: "OrganizationSponsorship"); + protected override void Up(MigrationBuilder migrationBuilder) + { + migrationBuilder.DropForeignKey( + name: "FK_OrganizationSponsorship_Installation_InstallationId", + table: "OrganizationSponsorship"); - migrationBuilder.DropColumn( - name: "InstallationId", - table: "OrganizationSponsorship"); + migrationBuilder.DropIndex( + name: "IX_OrganizationSponsorship_InstallationId", + table: "OrganizationSponsorship"); - migrationBuilder.DropColumn( - name: "TimesRenewedWithoutValidation", - table: "OrganizationSponsorship"); + migrationBuilder.DropColumn( + name: "InstallationId", + table: "OrganizationSponsorship"); - migrationBuilder.CreateTable( - name: "OrganizationApiKey", - columns: table => new - { - Id = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), - OrganizationId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), - Type = table.Column(type: "tinyint unsigned", nullable: false), - ApiKey = table.Column(type: "varchar(30)", maxLength: 30, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - RevisionDate = table.Column(type: "datetime(6)", nullable: false) - }, - constraints: table => - { - table.PrimaryKey("PK_OrganizationApiKey", x => x.Id); - table.ForeignKey( - name: "FK_OrganizationApiKey_Organization_OrganizationId", - column: x => x.OrganizationId, - principalTable: "Organization", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - }) - .Annotation("MySql:CharSet", "utf8mb4"); + migrationBuilder.DropColumn( + name: "TimesRenewedWithoutValidation", + table: "OrganizationSponsorship"); - migrationBuilder.SqlResource(_scriptLocationTemplate); + migrationBuilder.CreateTable( + name: "OrganizationApiKey", + columns: table => new + { + Id = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), + OrganizationId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), + Type = table.Column(type: "tinyint unsigned", nullable: false), + ApiKey = table.Column(type: "varchar(30)", maxLength: 30, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + RevisionDate = table.Column(type: "datetime(6)", nullable: false) + }, + constraints: table => + { + table.PrimaryKey("PK_OrganizationApiKey", x => x.Id); + table.ForeignKey( + name: "FK_OrganizationApiKey_Organization_OrganizationId", + column: x => x.OrganizationId, + principalTable: "Organization", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + }) + .Annotation("MySql:CharSet", "utf8mb4"); - migrationBuilder.DropColumn( - name: "ApiKey", - table: "Organization"); + migrationBuilder.SqlResource(_scriptLocationTemplate); - migrationBuilder.RenameColumn( - name: "SponsorshipLapsedDate", - table: "OrganizationSponsorship", - newName: "ValidUntil"); + migrationBuilder.DropColumn( + name: "ApiKey", + table: "Organization"); - migrationBuilder.RenameColumn( - name: "CloudSponsor", - table: "OrganizationSponsorship", - newName: "ToDelete"); + migrationBuilder.RenameColumn( + name: "SponsorshipLapsedDate", + table: "OrganizationSponsorship", + newName: "ValidUntil"); + + migrationBuilder.RenameColumn( + name: "CloudSponsor", + table: "OrganizationSponsorship", + newName: "ToDelete"); - migrationBuilder.CreateTable( - name: "OrganizationConnection", - columns: table => new - { - Id = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), - Type = table.Column(type: "tinyint unsigned", nullable: false), - OrganizationId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), - Enabled = table.Column(type: "tinyint(1)", nullable: false), - Config = table.Column(type: "longtext", nullable: true) - .Annotation("MySql:CharSet", "utf8mb4") - }, - constraints: table => - { - table.PrimaryKey("PK_OrganizationConnection", x => x.Id); - table.ForeignKey( - name: "FK_OrganizationConnection_Organization_OrganizationId", - column: x => x.OrganizationId, - principalTable: "Organization", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - }) - .Annotation("MySql:CharSet", "utf8mb4"); + migrationBuilder.CreateTable( + name: "OrganizationConnection", + columns: table => new + { + Id = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), + Type = table.Column(type: "tinyint unsigned", nullable: false), + OrganizationId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), + Enabled = table.Column(type: "tinyint(1)", nullable: false), + Config = table.Column(type: "longtext", nullable: true) + .Annotation("MySql:CharSet", "utf8mb4") + }, + constraints: table => + { + table.PrimaryKey("PK_OrganizationConnection", x => x.Id); + table.ForeignKey( + name: "FK_OrganizationConnection_Organization_OrganizationId", + column: x => x.OrganizationId, + principalTable: "Organization", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + }) + .Annotation("MySql:CharSet", "utf8mb4"); - migrationBuilder.CreateIndex( - name: "IX_OrganizationApiKey_OrganizationId", - table: "OrganizationApiKey", - column: "OrganizationId"); + migrationBuilder.CreateIndex( + name: "IX_OrganizationApiKey_OrganizationId", + table: "OrganizationApiKey", + column: "OrganizationId"); - migrationBuilder.CreateIndex( - name: "IX_OrganizationConnection_OrganizationId", - table: "OrganizationConnection", - column: "OrganizationId"); - } + migrationBuilder.CreateIndex( + name: "IX_OrganizationConnection_OrganizationId", + table: "OrganizationConnection", + column: "OrganizationId"); + } - protected override void Down(MigrationBuilder migrationBuilder) - { - migrationBuilder.AddColumn( - name: "ApiKey", - table: "Organization", - type: "varchar(30)", - maxLength: 30, - nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"); + protected override void Down(MigrationBuilder migrationBuilder) + { + migrationBuilder.AddColumn( + name: "ApiKey", + table: "Organization", + type: "varchar(30)", + maxLength: 30, + nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"); - migrationBuilder.SqlResource(_scriptLocationTemplate); + migrationBuilder.SqlResource(_scriptLocationTemplate); - migrationBuilder.DropTable( - name: "OrganizationApiKey"); + migrationBuilder.DropTable( + name: "OrganizationApiKey"); - migrationBuilder.DropTable( - name: "OrganizationConnection"); + migrationBuilder.DropTable( + name: "OrganizationConnection"); - migrationBuilder.RenameColumn( - name: "ValidUntil", - table: "OrganizationSponsorship", - newName: "SponsorshipLapsedDate"); + migrationBuilder.RenameColumn( + name: "ValidUntil", + table: "OrganizationSponsorship", + newName: "SponsorshipLapsedDate"); - migrationBuilder.RenameColumn( - name: "ToDelete", - table: "OrganizationSponsorship", - newName: "CloudSponsor"); + migrationBuilder.RenameColumn( + name: "ToDelete", + table: "OrganizationSponsorship", + newName: "CloudSponsor"); - migrationBuilder.AddColumn( - name: "InstallationId", - table: "OrganizationSponsorship", - type: "char(36)", - nullable: true, - collation: "ascii_general_ci"); + migrationBuilder.AddColumn( + name: "InstallationId", + table: "OrganizationSponsorship", + type: "char(36)", + nullable: true, + collation: "ascii_general_ci"); - migrationBuilder.AddColumn( - name: "TimesRenewedWithoutValidation", - table: "OrganizationSponsorship", - type: "tinyint unsigned", - nullable: false, - defaultValue: (byte)0); + migrationBuilder.AddColumn( + name: "TimesRenewedWithoutValidation", + table: "OrganizationSponsorship", + type: "tinyint unsigned", + nullable: false, + defaultValue: (byte)0); - migrationBuilder.CreateIndex( - name: "IX_OrganizationSponsorship_InstallationId", - table: "OrganizationSponsorship", - column: "InstallationId"); + migrationBuilder.CreateIndex( + name: "IX_OrganizationSponsorship_InstallationId", + table: "OrganizationSponsorship", + column: "InstallationId"); - migrationBuilder.AddForeignKey( - name: "FK_OrganizationSponsorship_Installation_InstallationId", - table: "OrganizationSponsorship", - column: "InstallationId", - principalTable: "Installation", - principalColumn: "Id", - onDelete: ReferentialAction.Restrict); + migrationBuilder.AddForeignKey( + name: "FK_OrganizationSponsorship_Installation_InstallationId", + table: "OrganizationSponsorship", + column: "InstallationId", + principalTable: "Installation", + principalColumn: "Id", + onDelete: ReferentialAction.Restrict); + } } } diff --git a/util/MySqlMigrations/Migrations/20220411191518_SponsorshipBulkActions.cs b/util/MySqlMigrations/Migrations/20220411191518_SponsorshipBulkActions.cs index 9b66e00cda..30e31e0153 100644 --- a/util/MySqlMigrations/Migrations/20220411191518_SponsorshipBulkActions.cs +++ b/util/MySqlMigrations/Migrations/20220411191518_SponsorshipBulkActions.cs @@ -1,80 +1,81 @@ using Microsoft.EntityFrameworkCore.Migrations; -namespace Bit.MySqlMigrations.Migrations; - -public partial class SponsorshipBulkActions : Migration +namespace Bit.MySqlMigrations.Migrations { - protected override void Up(MigrationBuilder migrationBuilder) + public partial class SponsorshipBulkActions : Migration { - migrationBuilder.DropForeignKey( - name: "FK_OrganizationSponsorship_Organization_SponsoringOrganizationId", - table: "OrganizationSponsorship"); + protected override void Up(MigrationBuilder migrationBuilder) + { + migrationBuilder.DropForeignKey( + name: "FK_OrganizationSponsorship_Organization_SponsoringOrganizationId", + table: "OrganizationSponsorship"); - migrationBuilder.AlterColumn( - name: "SponsoringOrganizationUserId", - table: "OrganizationSponsorship", - type: "char(36)", - nullable: false, - defaultValue: new Guid("00000000-0000-0000-0000-000000000000"), - collation: "ascii_general_ci", - oldClrType: typeof(Guid), - oldType: "char(36)", - oldNullable: true) - .OldAnnotation("Relational:Collation", "ascii_general_ci"); + migrationBuilder.AlterColumn( + name: "SponsoringOrganizationUserId", + table: "OrganizationSponsorship", + type: "char(36)", + nullable: false, + defaultValue: new Guid("00000000-0000-0000-0000-000000000000"), + collation: "ascii_general_ci", + oldClrType: typeof(Guid), + oldType: "char(36)", + oldNullable: true) + .OldAnnotation("Relational:Collation", "ascii_general_ci"); - migrationBuilder.AlterColumn( - name: "SponsoringOrganizationId", - table: "OrganizationSponsorship", - type: "char(36)", - nullable: false, - defaultValue: new Guid("00000000-0000-0000-0000-000000000000"), - collation: "ascii_general_ci", - oldClrType: typeof(Guid), - oldType: "char(36)", - oldNullable: true) - .OldAnnotation("Relational:Collation", "ascii_general_ci"); + migrationBuilder.AlterColumn( + name: "SponsoringOrganizationId", + table: "OrganizationSponsorship", + type: "char(36)", + nullable: false, + defaultValue: new Guid("00000000-0000-0000-0000-000000000000"), + collation: "ascii_general_ci", + oldClrType: typeof(Guid), + oldType: "char(36)", + oldNullable: true) + .OldAnnotation("Relational:Collation", "ascii_general_ci"); - migrationBuilder.AddForeignKey( - name: "FK_OrganizationSponsorship_Organization_SponsoringOrganizationId", - table: "OrganizationSponsorship", - column: "SponsoringOrganizationId", - principalTable: "Organization", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - } + migrationBuilder.AddForeignKey( + name: "FK_OrganizationSponsorship_Organization_SponsoringOrganizationId", + table: "OrganizationSponsorship", + column: "SponsoringOrganizationId", + principalTable: "Organization", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + } - protected override void Down(MigrationBuilder migrationBuilder) - { - migrationBuilder.DropForeignKey( - name: "FK_OrganizationSponsorship_Organization_SponsoringOrganizationId", - table: "OrganizationSponsorship"); + protected override void Down(MigrationBuilder migrationBuilder) + { + migrationBuilder.DropForeignKey( + name: "FK_OrganizationSponsorship_Organization_SponsoringOrganizationId", + table: "OrganizationSponsorship"); - migrationBuilder.AlterColumn( - name: "SponsoringOrganizationUserId", - table: "OrganizationSponsorship", - type: "char(36)", - nullable: true, - collation: "ascii_general_ci", - oldClrType: typeof(Guid), - oldType: "char(36)") - .OldAnnotation("Relational:Collation", "ascii_general_ci"); + migrationBuilder.AlterColumn( + name: "SponsoringOrganizationUserId", + table: "OrganizationSponsorship", + type: "char(36)", + nullable: true, + collation: "ascii_general_ci", + oldClrType: typeof(Guid), + oldType: "char(36)") + .OldAnnotation("Relational:Collation", "ascii_general_ci"); - migrationBuilder.AlterColumn( - name: "SponsoringOrganizationId", - table: "OrganizationSponsorship", - type: "char(36)", - nullable: true, - collation: "ascii_general_ci", - oldClrType: typeof(Guid), - oldType: "char(36)") - .OldAnnotation("Relational:Collation", "ascii_general_ci"); + migrationBuilder.AlterColumn( + name: "SponsoringOrganizationId", + table: "OrganizationSponsorship", + type: "char(36)", + nullable: true, + collation: "ascii_general_ci", + oldClrType: typeof(Guid), + oldType: "char(36)") + .OldAnnotation("Relational:Collation", "ascii_general_ci"); - migrationBuilder.AddForeignKey( - name: "FK_OrganizationSponsorship_Organization_SponsoringOrganizationId", - table: "OrganizationSponsorship", - column: "SponsoringOrganizationId", - principalTable: "Organization", - principalColumn: "Id", - onDelete: ReferentialAction.Restrict); + migrationBuilder.AddForeignKey( + name: "FK_OrganizationSponsorship_Organization_SponsoringOrganizationId", + table: "OrganizationSponsorship", + column: "SponsoringOrganizationId", + principalTable: "Organization", + principalColumn: "Id", + onDelete: ReferentialAction.Restrict); + } } } diff --git a/util/MySqlMigrations/Migrations/20220420170738_AddInstallationIdToEvents.cs b/util/MySqlMigrations/Migrations/20220420170738_AddInstallationIdToEvents.cs index d07b0e41d5..77b3b5a69c 100644 --- a/util/MySqlMigrations/Migrations/20220420170738_AddInstallationIdToEvents.cs +++ b/util/MySqlMigrations/Migrations/20220420170738_AddInstallationIdToEvents.cs @@ -1,69 +1,70 @@ using Microsoft.EntityFrameworkCore.Migrations; -namespace Bit.MySqlMigrations.Migrations; - -public partial class AddInstallationIdToEvents : Migration +namespace Bit.MySqlMigrations.Migrations { - protected override void Up(MigrationBuilder migrationBuilder) + public partial class AddInstallationIdToEvents : Migration { - migrationBuilder.DropForeignKey( - name: "FK_OrganizationSponsorship_Organization_SponsoringOrganizationId", - table: "OrganizationSponsorship"); + protected override void Up(MigrationBuilder migrationBuilder) + { + migrationBuilder.DropForeignKey( + name: "FK_OrganizationSponsorship_Organization_SponsoringOrganizationId", + table: "OrganizationSponsorship"); - migrationBuilder.AlterColumn( - name: "SponsoringOrganizationId", - table: "OrganizationSponsorship", - type: "char(36)", - nullable: true, - collation: "ascii_general_ci", - oldClrType: typeof(Guid), - oldType: "char(36)") - .OldAnnotation("Relational:Collation", "ascii_general_ci"); + migrationBuilder.AlterColumn( + name: "SponsoringOrganizationId", + table: "OrganizationSponsorship", + type: "char(36)", + nullable: true, + collation: "ascii_general_ci", + oldClrType: typeof(Guid), + oldType: "char(36)") + .OldAnnotation("Relational:Collation", "ascii_general_ci"); - migrationBuilder.AddColumn( - name: "InstallationId", - table: "Event", - type: "char(36)", - nullable: true, - collation: "ascii_general_ci"); + migrationBuilder.AddColumn( + name: "InstallationId", + table: "Event", + type: "char(36)", + nullable: true, + collation: "ascii_general_ci"); - migrationBuilder.AddForeignKey( - name: "FK_OrganizationSponsorship_Organization_SponsoringOrganizationId", - table: "OrganizationSponsorship", - column: "SponsoringOrganizationId", - principalTable: "Organization", - principalColumn: "Id", - onDelete: ReferentialAction.Restrict); - } + migrationBuilder.AddForeignKey( + name: "FK_OrganizationSponsorship_Organization_SponsoringOrganizationId", + table: "OrganizationSponsorship", + column: "SponsoringOrganizationId", + principalTable: "Organization", + principalColumn: "Id", + onDelete: ReferentialAction.Restrict); + } - protected override void Down(MigrationBuilder migrationBuilder) - { - migrationBuilder.DropForeignKey( - name: "FK_OrganizationSponsorship_Organization_SponsoringOrganizationId", - table: "OrganizationSponsorship"); + protected override void Down(MigrationBuilder migrationBuilder) + { + migrationBuilder.DropForeignKey( + name: "FK_OrganizationSponsorship_Organization_SponsoringOrganizationId", + table: "OrganizationSponsorship"); - migrationBuilder.DropColumn( - name: "InstallationId", - table: "Event"); + migrationBuilder.DropColumn( + name: "InstallationId", + table: "Event"); - migrationBuilder.AlterColumn( - name: "SponsoringOrganizationId", - table: "OrganizationSponsorship", - type: "char(36)", - nullable: false, - defaultValue: new Guid("00000000-0000-0000-0000-000000000000"), - collation: "ascii_general_ci", - oldClrType: typeof(Guid), - oldType: "char(36)", - oldNullable: true) - .OldAnnotation("Relational:Collation", "ascii_general_ci"); + migrationBuilder.AlterColumn( + name: "SponsoringOrganizationId", + table: "OrganizationSponsorship", + type: "char(36)", + nullable: false, + defaultValue: new Guid("00000000-0000-0000-0000-000000000000"), + collation: "ascii_general_ci", + oldClrType: typeof(Guid), + oldType: "char(36)", + oldNullable: true) + .OldAnnotation("Relational:Collation", "ascii_general_ci"); - migrationBuilder.AddForeignKey( - name: "FK_OrganizationSponsorship_Organization_SponsoringOrganizationId", - table: "OrganizationSponsorship", - column: "SponsoringOrganizationId", - principalTable: "Organization", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); + migrationBuilder.AddForeignKey( + name: "FK_OrganizationSponsorship_Organization_SponsoringOrganizationId", + table: "OrganizationSponsorship", + column: "SponsoringOrganizationId", + principalTable: "Organization", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + } } } diff --git a/util/MySqlMigrations/Migrations/20220524171600_DeviceUnknownVerification.cs b/util/MySqlMigrations/Migrations/20220524171600_DeviceUnknownVerification.cs index 2ce7dadf28..23017d4da1 100644 --- a/util/MySqlMigrations/Migrations/20220524171600_DeviceUnknownVerification.cs +++ b/util/MySqlMigrations/Migrations/20220524171600_DeviceUnknownVerification.cs @@ -1,23 +1,24 @@ using Microsoft.EntityFrameworkCore.Migrations; -namespace Bit.MySqlMigrations.Migrations; - -public partial class DeviceUnknownVerification : Migration +namespace Bit.MySqlMigrations.Migrations { - protected override void Up(MigrationBuilder migrationBuilder) + public partial class DeviceUnknownVerification : Migration { - migrationBuilder.AddColumn( - name: "UnknownDeviceVerificationEnabled", - table: "User", - type: "tinyint(1)", - nullable: false, - defaultValue: true); - } + protected override void Up(MigrationBuilder migrationBuilder) + { + migrationBuilder.AddColumn( + name: "UnknownDeviceVerificationEnabled", + table: "User", + type: "tinyint(1)", + nullable: false, + defaultValue: true); + } - protected override void Down(MigrationBuilder migrationBuilder) - { - migrationBuilder.DropColumn( - name: "UnknownDeviceVerificationEnabled", - table: "User"); + protected override void Down(MigrationBuilder migrationBuilder) + { + migrationBuilder.DropColumn( + name: "UnknownDeviceVerificationEnabled", + table: "User"); + } } } diff --git a/util/MySqlMigrations/Migrations/20220608191914_DeactivatedUserStatus.cs b/util/MySqlMigrations/Migrations/20220608191914_DeactivatedUserStatus.cs index d0c5caf370..52b10dcb3e 100644 --- a/util/MySqlMigrations/Migrations/20220608191914_DeactivatedUserStatus.cs +++ b/util/MySqlMigrations/Migrations/20220608191914_DeactivatedUserStatus.cs @@ -1,28 +1,29 @@ using Microsoft.EntityFrameworkCore.Migrations; -namespace Bit.MySqlMigrations.Migrations; - -public partial class DeactivatedUserStatus : Migration +namespace Bit.MySqlMigrations.Migrations { - protected override void Up(MigrationBuilder migrationBuilder) + public partial class DeactivatedUserStatus : Migration { - migrationBuilder.AlterColumn( - name: "Status", - table: "OrganizationUser", - type: "smallint", - nullable: false, - oldClrType: typeof(byte), - oldType: "tinyint unsigned"); - } + protected override void Up(MigrationBuilder migrationBuilder) + { + migrationBuilder.AlterColumn( + name: "Status", + table: "OrganizationUser", + type: "smallint", + nullable: false, + oldClrType: typeof(byte), + oldType: "tinyint unsigned"); + } - protected override void Down(MigrationBuilder migrationBuilder) - { - migrationBuilder.AlterColumn( - name: "Status", - table: "OrganizationUser", - type: "tinyint unsigned", - nullable: false, - oldClrType: typeof(short), - oldType: "smallint"); + protected override void Down(MigrationBuilder migrationBuilder) + { + migrationBuilder.AlterColumn( + name: "Status", + table: "OrganizationUser", + type: "tinyint unsigned", + nullable: false, + oldClrType: typeof(short), + oldType: "smallint"); + } } } diff --git a/util/MySqlMigrations/Migrations/20220707163017_UseScimFlag.cs b/util/MySqlMigrations/Migrations/20220707163017_UseScimFlag.cs index c0033e60d8..4c29986125 100644 --- a/util/MySqlMigrations/Migrations/20220707163017_UseScimFlag.cs +++ b/util/MySqlMigrations/Migrations/20220707163017_UseScimFlag.cs @@ -2,24 +2,25 @@ #nullable disable -namespace Bit.MySqlMigrations.Migrations; - -public partial class UseScimFlag : Migration +namespace Bit.MySqlMigrations.Migrations { - protected override void Up(MigrationBuilder migrationBuilder) + public partial class UseScimFlag : Migration { - migrationBuilder.AddColumn( - name: "UseScim", - table: "Organization", - type: "tinyint(1)", - nullable: false, - defaultValue: false); - } + protected override void Up(MigrationBuilder migrationBuilder) + { + migrationBuilder.AddColumn( + name: "UseScim", + table: "Organization", + type: "tinyint(1)", + nullable: false, + defaultValue: false); + } - protected override void Down(MigrationBuilder migrationBuilder) - { - migrationBuilder.DropColumn( - name: "UseScim", - table: "Organization"); + protected override void Down(MigrationBuilder migrationBuilder) + { + migrationBuilder.DropColumn( + name: "UseScim", + table: "Organization"); + } } } diff --git a/util/PostgresMigrations/Factories.cs b/util/PostgresMigrations/Factories.cs index 5504fe58b3..532dddf739 100644 --- a/util/PostgresMigrations/Factories.cs +++ b/util/PostgresMigrations/Factories.cs @@ -4,33 +4,34 @@ using Microsoft.EntityFrameworkCore; using Microsoft.EntityFrameworkCore.Design; using Microsoft.Extensions.Configuration; -namespace MySqlMigrations; - -public static class GlobalSettingsFactory +namespace MySqlMigrations { - public static GlobalSettings GlobalSettings { get; } = new GlobalSettings(); - static GlobalSettingsFactory() + public static class GlobalSettingsFactory { - var configBuilder = new ConfigurationBuilder().AddUserSecrets(); - var Configuration = configBuilder.Build(); - ConfigurationBinder.Bind(Configuration.GetSection("GlobalSettings"), GlobalSettings); - } -} - -public class DatabaseContextFactory : IDesignTimeDbContextFactory -{ - public DatabaseContext CreateDbContext(string[] args) - { - var globalSettings = GlobalSettingsFactory.GlobalSettings; - var optionsBuilder = new DbContextOptionsBuilder(); - var connectionString = globalSettings.PostgreSql?.ConnectionString; - if (string.IsNullOrWhiteSpace(connectionString)) + public static GlobalSettings GlobalSettings { get; } = new GlobalSettings(); + static GlobalSettingsFactory() { - throw new Exception("No Postgres connection string found."); + var configBuilder = new ConfigurationBuilder().AddUserSecrets(); + var Configuration = configBuilder.Build(); + ConfigurationBinder.Bind(Configuration.GetSection("GlobalSettings"), GlobalSettings); + } + } + + public class DatabaseContextFactory : IDesignTimeDbContextFactory + { + public DatabaseContext CreateDbContext(string[] args) + { + var globalSettings = GlobalSettingsFactory.GlobalSettings; + var optionsBuilder = new DbContextOptionsBuilder(); + var connectionString = globalSettings.PostgreSql?.ConnectionString; + if (string.IsNullOrWhiteSpace(connectionString)) + { + throw new Exception("No Postgres connection string found."); + } + optionsBuilder.UseNpgsql( + connectionString, + b => b.MigrationsAssembly("PostgresMigrations")); + return new DatabaseContext(optionsBuilder.Options); } - optionsBuilder.UseNpgsql( - connectionString, - b => b.MigrationsAssembly("PostgresMigrations")); - return new DatabaseContext(optionsBuilder.Options); } } diff --git a/util/PostgresMigrations/Migrations/20210708191531_Init.cs b/util/PostgresMigrations/Migrations/20210708191531_Init.cs index 068e292ce3..3de407ff2e 100644 --- a/util/PostgresMigrations/Migrations/20210708191531_Init.cs +++ b/util/PostgresMigrations/Migrations/20210708191531_Init.cs @@ -1,1007 +1,1008 @@ using Microsoft.EntityFrameworkCore.Migrations; using Npgsql.EntityFrameworkCore.PostgreSQL.Metadata; -namespace Bit.PostgresMigrations.Migrations; - -public partial class Init : Migration +namespace Bit.PostgresMigrations.Migrations { - protected override void Up(MigrationBuilder migrationBuilder) + public partial class Init : Migration { - migrationBuilder.AlterDatabase() - .Annotation("Npgsql:CollationDefinition:postgresIndetermanisticCollation", "en-u-ks-primary,en-u-ks-primary,icu,False"); - - migrationBuilder.CreateTable( - name: "Event", - columns: table => new - { - Id = table.Column(type: "uuid", nullable: false), - Date = table.Column(type: "timestamp without time zone", nullable: false), - Type = table.Column(type: "integer", nullable: false), - UserId = table.Column(type: "uuid", nullable: true), - OrganizationId = table.Column(type: "uuid", nullable: true), - CipherId = table.Column(type: "uuid", nullable: true), - CollectionId = table.Column(type: "uuid", nullable: true), - PolicyId = table.Column(type: "uuid", nullable: true), - GroupId = table.Column(type: "uuid", nullable: true), - OrganizationUserId = table.Column(type: "uuid", nullable: true), - DeviceType = table.Column(type: "smallint", nullable: true), - IpAddress = table.Column(type: "character varying(50)", maxLength: 50, nullable: true), - ActingUserId = table.Column(type: "uuid", nullable: true) - }, - constraints: table => - { - table.PrimaryKey("PK_Event", x => x.Id); - }); - - migrationBuilder.CreateTable( - name: "Grant", - columns: table => new - { - Key = table.Column(type: "character varying(200)", maxLength: 200, nullable: false), - Type = table.Column(type: "character varying(50)", maxLength: 50, nullable: true), - SubjectId = table.Column(type: "character varying(200)", maxLength: 200, nullable: true), - SessionId = table.Column(type: "character varying(100)", maxLength: 100, nullable: true), - ClientId = table.Column(type: "character varying(200)", maxLength: 200, nullable: true), - Description = table.Column(type: "character varying(200)", maxLength: 200, nullable: true), - CreationDate = table.Column(type: "timestamp without time zone", nullable: false), - ExpirationDate = table.Column(type: "timestamp without time zone", nullable: true), - ConsumedDate = table.Column(type: "timestamp without time zone", nullable: true), - Data = table.Column(type: "text", nullable: true) - }, - constraints: table => - { - table.PrimaryKey("PK_Grant", x => x.Key); - }); - - migrationBuilder.CreateTable( - name: "Installation", - columns: table => new - { - Id = table.Column(type: "uuid", nullable: false), - Email = table.Column(type: "character varying(256)", maxLength: 256, nullable: true), - Key = table.Column(type: "character varying(150)", maxLength: 150, nullable: true), - Enabled = table.Column(type: "boolean", nullable: false), - CreationDate = table.Column(type: "timestamp without time zone", nullable: false) - }, - constraints: table => - { - table.PrimaryKey("PK_Installation", x => x.Id); - }); - - migrationBuilder.CreateTable( - name: "Organization", - columns: table => new - { - Id = table.Column(type: "uuid", nullable: false), - Identifier = table.Column(type: "character varying(50)", maxLength: 50, nullable: true, collation: "postgresIndetermanisticCollation"), - Name = table.Column(type: "character varying(50)", maxLength: 50, nullable: true), - BusinessName = table.Column(type: "character varying(50)", maxLength: 50, nullable: true), - BusinessAddress1 = table.Column(type: "character varying(50)", maxLength: 50, nullable: true), - BusinessAddress2 = table.Column(type: "character varying(50)", maxLength: 50, nullable: true), - BusinessAddress3 = table.Column(type: "character varying(50)", maxLength: 50, nullable: true), - BusinessCountry = table.Column(type: "character varying(2)", maxLength: 2, nullable: true), - BusinessTaxNumber = table.Column(type: "character varying(30)", maxLength: 30, nullable: true), - BillingEmail = table.Column(type: "character varying(256)", maxLength: 256, nullable: true), - Plan = table.Column(type: "character varying(50)", maxLength: 50, nullable: true), - PlanType = table.Column(type: "smallint", nullable: false), - Seats = table.Column(type: "integer", nullable: true), - MaxCollections = table.Column(type: "smallint", nullable: true), - UsePolicies = table.Column(type: "boolean", nullable: false), - UseSso = table.Column(type: "boolean", nullable: false), - UseGroups = table.Column(type: "boolean", nullable: false), - UseDirectory = table.Column(type: "boolean", nullable: false), - UseEvents = table.Column(type: "boolean", nullable: false), - UseTotp = table.Column(type: "boolean", nullable: false), - Use2fa = table.Column(type: "boolean", nullable: false), - UseApi = table.Column(type: "boolean", nullable: false), - UseResetPassword = table.Column(type: "boolean", nullable: false), - SelfHost = table.Column(type: "boolean", nullable: false), - UsersGetPremium = table.Column(type: "boolean", nullable: false), - Storage = table.Column(type: "bigint", nullable: true), - MaxStorageGb = table.Column(type: "smallint", nullable: true), - Gateway = table.Column(type: "smallint", nullable: true), - GatewayCustomerId = table.Column(type: "character varying(50)", maxLength: 50, nullable: true), - GatewaySubscriptionId = table.Column(type: "character varying(50)", maxLength: 50, nullable: true), - ReferenceData = table.Column(type: "text", nullable: true), - Enabled = table.Column(type: "boolean", nullable: false), - LicenseKey = table.Column(type: "character varying(100)", maxLength: 100, nullable: true), - ApiKey = table.Column(type: "character varying(30)", maxLength: 30, nullable: true), - PublicKey = table.Column(type: "text", nullable: true), - PrivateKey = table.Column(type: "text", nullable: true), - TwoFactorProviders = table.Column(type: "text", nullable: true), - ExpirationDate = table.Column(type: "timestamp without time zone", nullable: true), - CreationDate = table.Column(type: "timestamp without time zone", nullable: false), - RevisionDate = table.Column(type: "timestamp without time zone", nullable: false) - }, - constraints: table => - { - table.PrimaryKey("PK_Organization", x => x.Id); - }); - - migrationBuilder.CreateTable( - name: "Provider", - columns: table => new - { - Id = table.Column(type: "uuid", nullable: false), - Name = table.Column(type: "text", nullable: true), - BusinessName = table.Column(type: "text", nullable: true), - BusinessAddress1 = table.Column(type: "text", nullable: true), - BusinessAddress2 = table.Column(type: "text", nullable: true), - BusinessAddress3 = table.Column(type: "text", nullable: true), - BusinessCountry = table.Column(type: "text", nullable: true), - BusinessTaxNumber = table.Column(type: "text", nullable: true), - BillingEmail = table.Column(type: "text", nullable: true), - Status = table.Column(type: "smallint", nullable: false), - UseEvents = table.Column(type: "boolean", nullable: false), - Enabled = table.Column(type: "boolean", nullable: false), - CreationDate = table.Column(type: "timestamp without time zone", nullable: false), - RevisionDate = table.Column(type: "timestamp without time zone", nullable: false) - }, - constraints: table => - { - table.PrimaryKey("PK_Provider", x => x.Id); - }); - - migrationBuilder.CreateTable( - name: "TaxRate", - columns: table => new - { - Id = table.Column(type: "character varying(40)", maxLength: 40, nullable: false), - Country = table.Column(type: "character varying(50)", maxLength: 50, nullable: true), - State = table.Column(type: "character varying(2)", maxLength: 2, nullable: true), - PostalCode = table.Column(type: "character varying(10)", maxLength: 10, nullable: true), - Rate = table.Column(type: "numeric", nullable: false), - Active = table.Column(type: "boolean", nullable: false) - }, - constraints: table => - { - table.PrimaryKey("PK_TaxRate", x => x.Id); - }); - - migrationBuilder.CreateTable( - name: "User", - columns: table => new - { - Id = table.Column(type: "uuid", nullable: false), - Name = table.Column(type: "character varying(50)", maxLength: 50, nullable: true), - Email = table.Column(type: "character varying(256)", maxLength: 256, nullable: false, collation: "postgresIndetermanisticCollation"), - EmailVerified = table.Column(type: "boolean", nullable: false), - MasterPassword = table.Column(type: "character varying(300)", maxLength: 300, nullable: true), - MasterPasswordHint = table.Column(type: "character varying(50)", maxLength: 50, nullable: true), - Culture = table.Column(type: "character varying(10)", maxLength: 10, nullable: true), - SecurityStamp = table.Column(type: "character varying(50)", maxLength: 50, nullable: false), - TwoFactorProviders = table.Column(type: "text", nullable: true), - TwoFactorRecoveryCode = table.Column(type: "character varying(32)", maxLength: 32, nullable: true), - EquivalentDomains = table.Column(type: "text", nullable: true), - ExcludedGlobalEquivalentDomains = table.Column(type: "text", nullable: true), - AccountRevisionDate = table.Column(type: "timestamp without time zone", nullable: false), - Key = table.Column(type: "text", nullable: true), - PublicKey = table.Column(type: "text", nullable: true), - PrivateKey = table.Column(type: "text", nullable: true), - Premium = table.Column(type: "boolean", nullable: false), - PremiumExpirationDate = table.Column(type: "timestamp without time zone", nullable: true), - RenewalReminderDate = table.Column(type: "timestamp without time zone", nullable: true), - Storage = table.Column(type: "bigint", nullable: true), - MaxStorageGb = table.Column(type: "smallint", nullable: true), - Gateway = table.Column(type: "smallint", nullable: true), - GatewayCustomerId = table.Column(type: "character varying(50)", maxLength: 50, nullable: true), - GatewaySubscriptionId = table.Column(type: "character varying(50)", maxLength: 50, nullable: true), - ReferenceData = table.Column(type: "text", nullable: true), - LicenseKey = table.Column(type: "character varying(100)", maxLength: 100, nullable: true), - ApiKey = table.Column(type: "character varying(30)", maxLength: 30, nullable: false), - Kdf = table.Column(type: "smallint", nullable: false), - KdfIterations = table.Column(type: "integer", nullable: false), - CreationDate = table.Column(type: "timestamp without time zone", nullable: false), - RevisionDate = table.Column(type: "timestamp without time zone", nullable: false) - }, - constraints: table => - { - table.PrimaryKey("PK_User", x => x.Id); - }); - - migrationBuilder.CreateTable( - name: "Collection", - columns: table => new - { - Id = table.Column(type: "uuid", nullable: false), - OrganizationId = table.Column(type: "uuid", nullable: false), - Name = table.Column(type: "text", nullable: true), - ExternalId = table.Column(type: "character varying(300)", maxLength: 300, nullable: true), - CreationDate = table.Column(type: "timestamp without time zone", nullable: false), - RevisionDate = table.Column(type: "timestamp without time zone", nullable: false) - }, - constraints: table => - { - table.PrimaryKey("PK_Collection", x => x.Id); - table.ForeignKey( - name: "FK_Collection_Organization_OrganizationId", - column: x => x.OrganizationId, - principalTable: "Organization", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - }); - - migrationBuilder.CreateTable( - name: "Group", - columns: table => new - { - Id = table.Column(type: "uuid", nullable: false), - OrganizationId = table.Column(type: "uuid", nullable: false), - Name = table.Column(type: "character varying(100)", maxLength: 100, nullable: true), - AccessAll = table.Column(type: "boolean", nullable: false), - ExternalId = table.Column(type: "character varying(300)", maxLength: 300, nullable: true), - CreationDate = table.Column(type: "timestamp without time zone", nullable: false), - RevisionDate = table.Column(type: "timestamp without time zone", nullable: false) - }, - constraints: table => - { - table.PrimaryKey("PK_Group", x => x.Id); - table.ForeignKey( - name: "FK_Group_Organization_OrganizationId", - column: x => x.OrganizationId, - principalTable: "Organization", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - }); - - migrationBuilder.CreateTable( - name: "Policy", - columns: table => new - { - Id = table.Column(type: "uuid", nullable: false), - OrganizationId = table.Column(type: "uuid", nullable: false), - Type = table.Column(type: "smallint", nullable: false), - Data = table.Column(type: "text", nullable: true), - Enabled = table.Column(type: "boolean", nullable: false), - CreationDate = table.Column(type: "timestamp without time zone", nullable: false), - RevisionDate = table.Column(type: "timestamp without time zone", nullable: false) - }, - constraints: table => - { - table.PrimaryKey("PK_Policy", x => x.Id); - table.ForeignKey( - name: "FK_Policy_Organization_OrganizationId", - column: x => x.OrganizationId, - principalTable: "Organization", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - }); - - migrationBuilder.CreateTable( - name: "SsoConfig", - columns: table => new - { - Id = table.Column(type: "bigint", nullable: false) - .Annotation("Npgsql:ValueGenerationStrategy", NpgsqlValueGenerationStrategy.IdentityByDefaultColumn), - Enabled = table.Column(type: "boolean", nullable: false), - OrganizationId = table.Column(type: "uuid", nullable: false), - Data = table.Column(type: "text", nullable: true), - CreationDate = table.Column(type: "timestamp without time zone", nullable: false), - RevisionDate = table.Column(type: "timestamp without time zone", nullable: false) - }, - constraints: table => - { - table.PrimaryKey("PK_SsoConfig", x => x.Id); - table.ForeignKey( - name: "FK_SsoConfig_Organization_OrganizationId", - column: x => x.OrganizationId, - principalTable: "Organization", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - }); - - migrationBuilder.CreateTable( - name: "ProviderOrganization", - columns: table => new - { - Id = table.Column(type: "uuid", nullable: false), - ProviderId = table.Column(type: "uuid", nullable: false), - OrganizationId = table.Column(type: "uuid", nullable: false), - Key = table.Column(type: "text", nullable: true), - Settings = table.Column(type: "text", nullable: true), - CreationDate = table.Column(type: "timestamp without time zone", nullable: false), - RevisionDate = table.Column(type: "timestamp without time zone", nullable: false) - }, - constraints: table => - { - table.PrimaryKey("PK_ProviderOrganization", x => x.Id); - table.ForeignKey( - name: "FK_ProviderOrganization_Organization_OrganizationId", - column: x => x.OrganizationId, - principalTable: "Organization", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - table.ForeignKey( - name: "FK_ProviderOrganization_Provider_ProviderId", - column: x => x.ProviderId, - principalTable: "Provider", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - }); - - migrationBuilder.CreateTable( - name: "Cipher", - columns: table => new - { - Id = table.Column(type: "uuid", nullable: false), - UserId = table.Column(type: "uuid", nullable: true), - OrganizationId = table.Column(type: "uuid", nullable: true), - Type = table.Column(type: "smallint", nullable: false), - Data = table.Column(type: "text", nullable: true), - Favorites = table.Column(type: "text", nullable: true), - Folders = table.Column(type: "text", nullable: true), - Attachments = table.Column(type: "text", nullable: true), - CreationDate = table.Column(type: "timestamp without time zone", nullable: false), - RevisionDate = table.Column(type: "timestamp without time zone", nullable: false), - DeletedDate = table.Column(type: "timestamp without time zone", nullable: true), - Reprompt = table.Column(type: "smallint", nullable: true) - }, - constraints: table => - { - table.PrimaryKey("PK_Cipher", x => x.Id); - table.ForeignKey( - name: "FK_Cipher_Organization_OrganizationId", - column: x => x.OrganizationId, - principalTable: "Organization", - principalColumn: "Id", - onDelete: ReferentialAction.Restrict); - table.ForeignKey( - name: "FK_Cipher_User_UserId", - column: x => x.UserId, - principalTable: "User", - principalColumn: "Id", - onDelete: ReferentialAction.Restrict); - }); - - migrationBuilder.CreateTable( - name: "Device", - columns: table => new - { - Id = table.Column(type: "uuid", nullable: false), - UserId = table.Column(type: "uuid", nullable: false), - Name = table.Column(type: "character varying(50)", maxLength: 50, nullable: true), - Type = table.Column(type: "smallint", nullable: false), - Identifier = table.Column(type: "character varying(50)", maxLength: 50, nullable: true), - PushToken = table.Column(type: "character varying(255)", maxLength: 255, nullable: true), - CreationDate = table.Column(type: "timestamp without time zone", nullable: false), - RevisionDate = table.Column(type: "timestamp without time zone", nullable: false) - }, - constraints: table => - { - table.PrimaryKey("PK_Device", x => x.Id); - table.ForeignKey( - name: "FK_Device_User_UserId", - column: x => x.UserId, - principalTable: "User", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - }); - - migrationBuilder.CreateTable( - name: "EmergencyAccess", - columns: table => new - { - Id = table.Column(type: "uuid", nullable: false), - GrantorId = table.Column(type: "uuid", nullable: false), - GranteeId = table.Column(type: "uuid", nullable: true), - Email = table.Column(type: "character varying(256)", maxLength: 256, nullable: true), - KeyEncrypted = table.Column(type: "text", nullable: true), - Type = table.Column(type: "smallint", nullable: false), - Status = table.Column(type: "smallint", nullable: false), - WaitTimeDays = table.Column(type: "integer", nullable: false), - RecoveryInitiatedDate = table.Column(type: "timestamp without time zone", nullable: true), - LastNotificationDate = table.Column(type: "timestamp without time zone", nullable: true), - CreationDate = table.Column(type: "timestamp without time zone", nullable: false), - RevisionDate = table.Column(type: "timestamp without time zone", nullable: false) - }, - constraints: table => - { - table.PrimaryKey("PK_EmergencyAccess", x => x.Id); - table.ForeignKey( - name: "FK_EmergencyAccess_User_GranteeId", - column: x => x.GranteeId, - principalTable: "User", - principalColumn: "Id", - onDelete: ReferentialAction.Restrict); - table.ForeignKey( - name: "FK_EmergencyAccess_User_GrantorId", - column: x => x.GrantorId, - principalTable: "User", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - }); - - migrationBuilder.CreateTable( - name: "Folder", - columns: table => new - { - Id = table.Column(type: "uuid", nullable: false), - UserId = table.Column(type: "uuid", nullable: false), - Name = table.Column(type: "text", nullable: true), - CreationDate = table.Column(type: "timestamp without time zone", nullable: false), - RevisionDate = table.Column(type: "timestamp without time zone", nullable: false) - }, - constraints: table => - { - table.PrimaryKey("PK_Folder", x => x.Id); - table.ForeignKey( - name: "FK_Folder_User_UserId", - column: x => x.UserId, - principalTable: "User", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - }); - - migrationBuilder.CreateTable( - name: "OrganizationUser", - columns: table => new - { - Id = table.Column(type: "uuid", nullable: false), - OrganizationId = table.Column(type: "uuid", nullable: false), - UserId = table.Column(type: "uuid", nullable: true), - Email = table.Column(type: "character varying(256)", maxLength: 256, nullable: true), - Key = table.Column(type: "text", nullable: true), - ResetPasswordKey = table.Column(type: "text", nullable: true), - Status = table.Column(type: "smallint", nullable: false), - Type = table.Column(type: "smallint", nullable: false), - AccessAll = table.Column(type: "boolean", nullable: false), - ExternalId = table.Column(type: "character varying(300)", maxLength: 300, nullable: true), - CreationDate = table.Column(type: "timestamp without time zone", nullable: false), - RevisionDate = table.Column(type: "timestamp without time zone", nullable: false), - Permissions = table.Column(type: "text", nullable: true) - }, - constraints: table => - { - table.PrimaryKey("PK_OrganizationUser", x => x.Id); - table.ForeignKey( - name: "FK_OrganizationUser_Organization_OrganizationId", - column: x => x.OrganizationId, - principalTable: "Organization", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - table.ForeignKey( - name: "FK_OrganizationUser_User_UserId", - column: x => x.UserId, - principalTable: "User", - principalColumn: "Id", - onDelete: ReferentialAction.Restrict); - }); - - migrationBuilder.CreateTable( - name: "ProviderUser", - columns: table => new - { - Id = table.Column(type: "uuid", nullable: false), - ProviderId = table.Column(type: "uuid", nullable: false), - UserId = table.Column(type: "uuid", nullable: true), - Email = table.Column(type: "text", nullable: true), - Key = table.Column(type: "text", nullable: true), - Status = table.Column(type: "smallint", nullable: false), - Type = table.Column(type: "smallint", nullable: false), - Permissions = table.Column(type: "text", nullable: true), - CreationDate = table.Column(type: "timestamp without time zone", nullable: false), - RevisionDate = table.Column(type: "timestamp without time zone", nullable: false) - }, - constraints: table => - { - table.PrimaryKey("PK_ProviderUser", x => x.Id); - table.ForeignKey( - name: "FK_ProviderUser_Provider_ProviderId", - column: x => x.ProviderId, - principalTable: "Provider", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - table.ForeignKey( - name: "FK_ProviderUser_User_UserId", - column: x => x.UserId, - principalTable: "User", - principalColumn: "Id", - onDelete: ReferentialAction.Restrict); - }); - - migrationBuilder.CreateTable( - name: "Send", - columns: table => new - { - Id = table.Column(type: "uuid", nullable: false), - UserId = table.Column(type: "uuid", nullable: true), - OrganizationId = table.Column(type: "uuid", nullable: true), - Type = table.Column(type: "smallint", nullable: false), - Data = table.Column(type: "text", nullable: true), - Key = table.Column(type: "text", nullable: true), - Password = table.Column(type: "character varying(300)", maxLength: 300, nullable: true), - MaxAccessCount = table.Column(type: "integer", nullable: true), - AccessCount = table.Column(type: "integer", nullable: false), - CreationDate = table.Column(type: "timestamp without time zone", nullable: false), - RevisionDate = table.Column(type: "timestamp without time zone", nullable: false), - ExpirationDate = table.Column(type: "timestamp without time zone", nullable: true), - DeletionDate = table.Column(type: "timestamp without time zone", nullable: false), - Disabled = table.Column(type: "boolean", nullable: false), - HideEmail = table.Column(type: "boolean", nullable: true) - }, - constraints: table => - { - table.PrimaryKey("PK_Send", x => x.Id); - table.ForeignKey( - name: "FK_Send_Organization_OrganizationId", - column: x => x.OrganizationId, - principalTable: "Organization", - principalColumn: "Id", - onDelete: ReferentialAction.Restrict); - table.ForeignKey( - name: "FK_Send_User_UserId", - column: x => x.UserId, - principalTable: "User", - principalColumn: "Id", - onDelete: ReferentialAction.Restrict); - }); - - migrationBuilder.CreateTable( - name: "SsoUser", - columns: table => new - { - Id = table.Column(type: "bigint", nullable: false) - .Annotation("Npgsql:ValueGenerationStrategy", NpgsqlValueGenerationStrategy.IdentityByDefaultColumn), - UserId = table.Column(type: "uuid", nullable: false), - OrganizationId = table.Column(type: "uuid", nullable: true), - ExternalId = table.Column(type: "character varying(50)", maxLength: 50, nullable: true, collation: "postgresIndetermanisticCollation"), - CreationDate = table.Column(type: "timestamp without time zone", nullable: false) - }, - constraints: table => - { - table.PrimaryKey("PK_SsoUser", x => x.Id); - table.ForeignKey( - name: "FK_SsoUser_Organization_OrganizationId", - column: x => x.OrganizationId, - principalTable: "Organization", - principalColumn: "Id", - onDelete: ReferentialAction.Restrict); - table.ForeignKey( - name: "FK_SsoUser_User_UserId", - column: x => x.UserId, - principalTable: "User", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - }); - - migrationBuilder.CreateTable( - name: "Transaction", - columns: table => new - { - Id = table.Column(type: "uuid", nullable: false), - UserId = table.Column(type: "uuid", nullable: true), - OrganizationId = table.Column(type: "uuid", nullable: true), - Type = table.Column(type: "smallint", nullable: false), - Amount = table.Column(type: "numeric", nullable: false), - Refunded = table.Column(type: "boolean", nullable: true), - RefundedAmount = table.Column(type: "numeric", nullable: true), - Details = table.Column(type: "character varying(100)", maxLength: 100, nullable: true), - PaymentMethodType = table.Column(type: "smallint", nullable: true), - Gateway = table.Column(type: "smallint", nullable: true), - GatewayId = table.Column(type: "character varying(50)", maxLength: 50, nullable: true), - CreationDate = table.Column(type: "timestamp without time zone", nullable: false) - }, - constraints: table => - { - table.PrimaryKey("PK_Transaction", x => x.Id); - table.ForeignKey( - name: "FK_Transaction_Organization_OrganizationId", - column: x => x.OrganizationId, - principalTable: "Organization", - principalColumn: "Id", - onDelete: ReferentialAction.Restrict); - table.ForeignKey( - name: "FK_Transaction_User_UserId", - column: x => x.UserId, - principalTable: "User", - principalColumn: "Id", - onDelete: ReferentialAction.Restrict); - }); - - migrationBuilder.CreateTable( - name: "U2f", - columns: table => new - { - Id = table.Column(type: "integer", nullable: false) - .Annotation("Npgsql:ValueGenerationStrategy", NpgsqlValueGenerationStrategy.IdentityByDefaultColumn), - UserId = table.Column(type: "uuid", nullable: false), - KeyHandle = table.Column(type: "character varying(200)", maxLength: 200, nullable: true), - Challenge = table.Column(type: "character varying(200)", maxLength: 200, nullable: true), - AppId = table.Column(type: "character varying(50)", maxLength: 50, nullable: true), - Version = table.Column(type: "character varying(20)", maxLength: 20, nullable: true), - CreationDate = table.Column(type: "timestamp without time zone", nullable: false) - }, - constraints: table => - { - table.PrimaryKey("PK_U2f", x => x.Id); - table.ForeignKey( - name: "FK_U2f_User_UserId", - column: x => x.UserId, - principalTable: "User", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - }); - - migrationBuilder.CreateTable( - name: "CollectionGroups", - columns: table => new - { - CollectionId = table.Column(type: "uuid", nullable: false), - GroupId = table.Column(type: "uuid", nullable: false), - ReadOnly = table.Column(type: "boolean", nullable: false), - HidePasswords = table.Column(type: "boolean", nullable: false) - }, - constraints: table => - { - table.PrimaryKey("PK_CollectionGroups", x => new { x.CollectionId, x.GroupId }); - table.ForeignKey( - name: "FK_CollectionGroups_Collection_CollectionId", - column: x => x.CollectionId, - principalTable: "Collection", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - table.ForeignKey( - name: "FK_CollectionGroups_Group_GroupId", - column: x => x.GroupId, - principalTable: "Group", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - }); - - migrationBuilder.CreateTable( - name: "CollectionCipher", - columns: table => new - { - CollectionId = table.Column(type: "uuid", nullable: false), - CipherId = table.Column(type: "uuid", nullable: false) - }, - constraints: table => - { - table.PrimaryKey("PK_CollectionCipher", x => new { x.CollectionId, x.CipherId }); - table.ForeignKey( - name: "FK_CollectionCipher_Cipher_CipherId", - column: x => x.CipherId, - principalTable: "Cipher", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - table.ForeignKey( - name: "FK_CollectionCipher_Collection_CollectionId", - column: x => x.CollectionId, - principalTable: "Collection", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - }); - - migrationBuilder.CreateTable( - name: "CollectionUsers", - columns: table => new - { - CollectionId = table.Column(type: "uuid", nullable: false), - OrganizationUserId = table.Column(type: "uuid", nullable: false), - UserId = table.Column(type: "uuid", nullable: true), - ReadOnly = table.Column(type: "boolean", nullable: false), - HidePasswords = table.Column(type: "boolean", nullable: false) - }, - constraints: table => - { - table.PrimaryKey("PK_CollectionUsers", x => new { x.CollectionId, x.OrganizationUserId }); - table.ForeignKey( - name: "FK_CollectionUsers_Collection_CollectionId", - column: x => x.CollectionId, - principalTable: "Collection", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - table.ForeignKey( - name: "FK_CollectionUsers_OrganizationUser_OrganizationUserId", - column: x => x.OrganizationUserId, - principalTable: "OrganizationUser", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - table.ForeignKey( - name: "FK_CollectionUsers_User_UserId", - column: x => x.UserId, - principalTable: "User", - principalColumn: "Id", - onDelete: ReferentialAction.Restrict); - }); - - migrationBuilder.CreateTable( - name: "GroupUser", - columns: table => new - { - GroupId = table.Column(type: "uuid", nullable: false), - OrganizationUserId = table.Column(type: "uuid", nullable: false), - UserId = table.Column(type: "uuid", nullable: true) - }, - constraints: table => - { - table.PrimaryKey("PK_GroupUser", x => new { x.GroupId, x.OrganizationUserId }); - table.ForeignKey( - name: "FK_GroupUser_Group_GroupId", - column: x => x.GroupId, - principalTable: "Group", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - table.ForeignKey( - name: "FK_GroupUser_OrganizationUser_OrganizationUserId", - column: x => x.OrganizationUserId, - principalTable: "OrganizationUser", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - table.ForeignKey( - name: "FK_GroupUser_User_UserId", - column: x => x.UserId, - principalTable: "User", - principalColumn: "Id", - onDelete: ReferentialAction.Restrict); - }); - - migrationBuilder.CreateTable( - name: "ProviderOrganizationProviderUser", - columns: table => new - { - Id = table.Column(type: "uuid", nullable: false), - ProviderOrganizationId = table.Column(type: "uuid", nullable: false), - ProviderUserId = table.Column(type: "uuid", nullable: false), - Type = table.Column(type: "smallint", nullable: false), - Permissions = table.Column(type: "text", nullable: true), - CreationDate = table.Column(type: "timestamp without time zone", nullable: false), - RevisionDate = table.Column(type: "timestamp without time zone", nullable: false) - }, - constraints: table => - { - table.PrimaryKey("PK_ProviderOrganizationProviderUser", x => x.Id); - table.ForeignKey( - name: "FK_ProviderOrganizationProviderUser_ProviderOrganization_Provi~", - column: x => x.ProviderOrganizationId, - principalTable: "ProviderOrganization", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - table.ForeignKey( - name: "FK_ProviderOrganizationProviderUser_ProviderUser_ProviderUserId", - column: x => x.ProviderUserId, - principalTable: "ProviderUser", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - }); - - migrationBuilder.CreateIndex( - name: "IX_Cipher_OrganizationId", - table: "Cipher", - column: "OrganizationId"); - - migrationBuilder.CreateIndex( - name: "IX_Cipher_UserId", - table: "Cipher", - column: "UserId"); - - migrationBuilder.CreateIndex( - name: "IX_Collection_OrganizationId", - table: "Collection", - column: "OrganizationId"); - - migrationBuilder.CreateIndex( - name: "IX_CollectionCipher_CipherId", - table: "CollectionCipher", - column: "CipherId"); - - migrationBuilder.CreateIndex( - name: "IX_CollectionGroups_GroupId", - table: "CollectionGroups", - column: "GroupId"); - - migrationBuilder.CreateIndex( - name: "IX_CollectionUsers_OrganizationUserId", - table: "CollectionUsers", - column: "OrganizationUserId"); - - migrationBuilder.CreateIndex( - name: "IX_CollectionUsers_UserId", - table: "CollectionUsers", - column: "UserId"); - - migrationBuilder.CreateIndex( - name: "IX_Device_UserId", - table: "Device", - column: "UserId"); - - migrationBuilder.CreateIndex( - name: "IX_EmergencyAccess_GranteeId", - table: "EmergencyAccess", - column: "GranteeId"); - - migrationBuilder.CreateIndex( - name: "IX_EmergencyAccess_GrantorId", - table: "EmergencyAccess", - column: "GrantorId"); - - migrationBuilder.CreateIndex( - name: "IX_Folder_UserId", - table: "Folder", - column: "UserId"); - - migrationBuilder.CreateIndex( - name: "IX_Group_OrganizationId", - table: "Group", - column: "OrganizationId"); - - migrationBuilder.CreateIndex( - name: "IX_GroupUser_OrganizationUserId", - table: "GroupUser", - column: "OrganizationUserId"); - - migrationBuilder.CreateIndex( - name: "IX_GroupUser_UserId", - table: "GroupUser", - column: "UserId"); - - migrationBuilder.CreateIndex( - name: "IX_OrganizationUser_OrganizationId", - table: "OrganizationUser", - column: "OrganizationId"); - - migrationBuilder.CreateIndex( - name: "IX_OrganizationUser_UserId", - table: "OrganizationUser", - column: "UserId"); - - migrationBuilder.CreateIndex( - name: "IX_Policy_OrganizationId", - table: "Policy", - column: "OrganizationId"); - - migrationBuilder.CreateIndex( - name: "IX_ProviderOrganization_OrganizationId", - table: "ProviderOrganization", - column: "OrganizationId"); - - migrationBuilder.CreateIndex( - name: "IX_ProviderOrganization_ProviderId", - table: "ProviderOrganization", - column: "ProviderId"); - - migrationBuilder.CreateIndex( - name: "IX_ProviderOrganizationProviderUser_ProviderOrganizationId", - table: "ProviderOrganizationProviderUser", - column: "ProviderOrganizationId"); - - migrationBuilder.CreateIndex( - name: "IX_ProviderOrganizationProviderUser_ProviderUserId", - table: "ProviderOrganizationProviderUser", - column: "ProviderUserId"); - - migrationBuilder.CreateIndex( - name: "IX_ProviderUser_ProviderId", - table: "ProviderUser", - column: "ProviderId"); - - migrationBuilder.CreateIndex( - name: "IX_ProviderUser_UserId", - table: "ProviderUser", - column: "UserId"); - - migrationBuilder.CreateIndex( - name: "IX_Send_OrganizationId", - table: "Send", - column: "OrganizationId"); - - migrationBuilder.CreateIndex( - name: "IX_Send_UserId", - table: "Send", - column: "UserId"); - - migrationBuilder.CreateIndex( - name: "IX_SsoConfig_OrganizationId", - table: "SsoConfig", - column: "OrganizationId"); + protected override void Up(MigrationBuilder migrationBuilder) + { + migrationBuilder.AlterDatabase() + .Annotation("Npgsql:CollationDefinition:postgresIndetermanisticCollation", "en-u-ks-primary,en-u-ks-primary,icu,False"); + + migrationBuilder.CreateTable( + name: "Event", + columns: table => new + { + Id = table.Column(type: "uuid", nullable: false), + Date = table.Column(type: "timestamp without time zone", nullable: false), + Type = table.Column(type: "integer", nullable: false), + UserId = table.Column(type: "uuid", nullable: true), + OrganizationId = table.Column(type: "uuid", nullable: true), + CipherId = table.Column(type: "uuid", nullable: true), + CollectionId = table.Column(type: "uuid", nullable: true), + PolicyId = table.Column(type: "uuid", nullable: true), + GroupId = table.Column(type: "uuid", nullable: true), + OrganizationUserId = table.Column(type: "uuid", nullable: true), + DeviceType = table.Column(type: "smallint", nullable: true), + IpAddress = table.Column(type: "character varying(50)", maxLength: 50, nullable: true), + ActingUserId = table.Column(type: "uuid", nullable: true) + }, + constraints: table => + { + table.PrimaryKey("PK_Event", x => x.Id); + }); + + migrationBuilder.CreateTable( + name: "Grant", + columns: table => new + { + Key = table.Column(type: "character varying(200)", maxLength: 200, nullable: false), + Type = table.Column(type: "character varying(50)", maxLength: 50, nullable: true), + SubjectId = table.Column(type: "character varying(200)", maxLength: 200, nullable: true), + SessionId = table.Column(type: "character varying(100)", maxLength: 100, nullable: true), + ClientId = table.Column(type: "character varying(200)", maxLength: 200, nullable: true), + Description = table.Column(type: "character varying(200)", maxLength: 200, nullable: true), + CreationDate = table.Column(type: "timestamp without time zone", nullable: false), + ExpirationDate = table.Column(type: "timestamp without time zone", nullable: true), + ConsumedDate = table.Column(type: "timestamp without time zone", nullable: true), + Data = table.Column(type: "text", nullable: true) + }, + constraints: table => + { + table.PrimaryKey("PK_Grant", x => x.Key); + }); + + migrationBuilder.CreateTable( + name: "Installation", + columns: table => new + { + Id = table.Column(type: "uuid", nullable: false), + Email = table.Column(type: "character varying(256)", maxLength: 256, nullable: true), + Key = table.Column(type: "character varying(150)", maxLength: 150, nullable: true), + Enabled = table.Column(type: "boolean", nullable: false), + CreationDate = table.Column(type: "timestamp without time zone", nullable: false) + }, + constraints: table => + { + table.PrimaryKey("PK_Installation", x => x.Id); + }); + + migrationBuilder.CreateTable( + name: "Organization", + columns: table => new + { + Id = table.Column(type: "uuid", nullable: false), + Identifier = table.Column(type: "character varying(50)", maxLength: 50, nullable: true, collation: "postgresIndetermanisticCollation"), + Name = table.Column(type: "character varying(50)", maxLength: 50, nullable: true), + BusinessName = table.Column(type: "character varying(50)", maxLength: 50, nullable: true), + BusinessAddress1 = table.Column(type: "character varying(50)", maxLength: 50, nullable: true), + BusinessAddress2 = table.Column(type: "character varying(50)", maxLength: 50, nullable: true), + BusinessAddress3 = table.Column(type: "character varying(50)", maxLength: 50, nullable: true), + BusinessCountry = table.Column(type: "character varying(2)", maxLength: 2, nullable: true), + BusinessTaxNumber = table.Column(type: "character varying(30)", maxLength: 30, nullable: true), + BillingEmail = table.Column(type: "character varying(256)", maxLength: 256, nullable: true), + Plan = table.Column(type: "character varying(50)", maxLength: 50, nullable: true), + PlanType = table.Column(type: "smallint", nullable: false), + Seats = table.Column(type: "integer", nullable: true), + MaxCollections = table.Column(type: "smallint", nullable: true), + UsePolicies = table.Column(type: "boolean", nullable: false), + UseSso = table.Column(type: "boolean", nullable: false), + UseGroups = table.Column(type: "boolean", nullable: false), + UseDirectory = table.Column(type: "boolean", nullable: false), + UseEvents = table.Column(type: "boolean", nullable: false), + UseTotp = table.Column(type: "boolean", nullable: false), + Use2fa = table.Column(type: "boolean", nullable: false), + UseApi = table.Column(type: "boolean", nullable: false), + UseResetPassword = table.Column(type: "boolean", nullable: false), + SelfHost = table.Column(type: "boolean", nullable: false), + UsersGetPremium = table.Column(type: "boolean", nullable: false), + Storage = table.Column(type: "bigint", nullable: true), + MaxStorageGb = table.Column(type: "smallint", nullable: true), + Gateway = table.Column(type: "smallint", nullable: true), + GatewayCustomerId = table.Column(type: "character varying(50)", maxLength: 50, nullable: true), + GatewaySubscriptionId = table.Column(type: "character varying(50)", maxLength: 50, nullable: true), + ReferenceData = table.Column(type: "text", nullable: true), + Enabled = table.Column(type: "boolean", nullable: false), + LicenseKey = table.Column(type: "character varying(100)", maxLength: 100, nullable: true), + ApiKey = table.Column(type: "character varying(30)", maxLength: 30, nullable: true), + PublicKey = table.Column(type: "text", nullable: true), + PrivateKey = table.Column(type: "text", nullable: true), + TwoFactorProviders = table.Column(type: "text", nullable: true), + ExpirationDate = table.Column(type: "timestamp without time zone", nullable: true), + CreationDate = table.Column(type: "timestamp without time zone", nullable: false), + RevisionDate = table.Column(type: "timestamp without time zone", nullable: false) + }, + constraints: table => + { + table.PrimaryKey("PK_Organization", x => x.Id); + }); + + migrationBuilder.CreateTable( + name: "Provider", + columns: table => new + { + Id = table.Column(type: "uuid", nullable: false), + Name = table.Column(type: "text", nullable: true), + BusinessName = table.Column(type: "text", nullable: true), + BusinessAddress1 = table.Column(type: "text", nullable: true), + BusinessAddress2 = table.Column(type: "text", nullable: true), + BusinessAddress3 = table.Column(type: "text", nullable: true), + BusinessCountry = table.Column(type: "text", nullable: true), + BusinessTaxNumber = table.Column(type: "text", nullable: true), + BillingEmail = table.Column(type: "text", nullable: true), + Status = table.Column(type: "smallint", nullable: false), + UseEvents = table.Column(type: "boolean", nullable: false), + Enabled = table.Column(type: "boolean", nullable: false), + CreationDate = table.Column(type: "timestamp without time zone", nullable: false), + RevisionDate = table.Column(type: "timestamp without time zone", nullable: false) + }, + constraints: table => + { + table.PrimaryKey("PK_Provider", x => x.Id); + }); + + migrationBuilder.CreateTable( + name: "TaxRate", + columns: table => new + { + Id = table.Column(type: "character varying(40)", maxLength: 40, nullable: false), + Country = table.Column(type: "character varying(50)", maxLength: 50, nullable: true), + State = table.Column(type: "character varying(2)", maxLength: 2, nullable: true), + PostalCode = table.Column(type: "character varying(10)", maxLength: 10, nullable: true), + Rate = table.Column(type: "numeric", nullable: false), + Active = table.Column(type: "boolean", nullable: false) + }, + constraints: table => + { + table.PrimaryKey("PK_TaxRate", x => x.Id); + }); + + migrationBuilder.CreateTable( + name: "User", + columns: table => new + { + Id = table.Column(type: "uuid", nullable: false), + Name = table.Column(type: "character varying(50)", maxLength: 50, nullable: true), + Email = table.Column(type: "character varying(256)", maxLength: 256, nullable: false, collation: "postgresIndetermanisticCollation"), + EmailVerified = table.Column(type: "boolean", nullable: false), + MasterPassword = table.Column(type: "character varying(300)", maxLength: 300, nullable: true), + MasterPasswordHint = table.Column(type: "character varying(50)", maxLength: 50, nullable: true), + Culture = table.Column(type: "character varying(10)", maxLength: 10, nullable: true), + SecurityStamp = table.Column(type: "character varying(50)", maxLength: 50, nullable: false), + TwoFactorProviders = table.Column(type: "text", nullable: true), + TwoFactorRecoveryCode = table.Column(type: "character varying(32)", maxLength: 32, nullable: true), + EquivalentDomains = table.Column(type: "text", nullable: true), + ExcludedGlobalEquivalentDomains = table.Column(type: "text", nullable: true), + AccountRevisionDate = table.Column(type: "timestamp without time zone", nullable: false), + Key = table.Column(type: "text", nullable: true), + PublicKey = table.Column(type: "text", nullable: true), + PrivateKey = table.Column(type: "text", nullable: true), + Premium = table.Column(type: "boolean", nullable: false), + PremiumExpirationDate = table.Column(type: "timestamp without time zone", nullable: true), + RenewalReminderDate = table.Column(type: "timestamp without time zone", nullable: true), + Storage = table.Column(type: "bigint", nullable: true), + MaxStorageGb = table.Column(type: "smallint", nullable: true), + Gateway = table.Column(type: "smallint", nullable: true), + GatewayCustomerId = table.Column(type: "character varying(50)", maxLength: 50, nullable: true), + GatewaySubscriptionId = table.Column(type: "character varying(50)", maxLength: 50, nullable: true), + ReferenceData = table.Column(type: "text", nullable: true), + LicenseKey = table.Column(type: "character varying(100)", maxLength: 100, nullable: true), + ApiKey = table.Column(type: "character varying(30)", maxLength: 30, nullable: false), + Kdf = table.Column(type: "smallint", nullable: false), + KdfIterations = table.Column(type: "integer", nullable: false), + CreationDate = table.Column(type: "timestamp without time zone", nullable: false), + RevisionDate = table.Column(type: "timestamp without time zone", nullable: false) + }, + constraints: table => + { + table.PrimaryKey("PK_User", x => x.Id); + }); + + migrationBuilder.CreateTable( + name: "Collection", + columns: table => new + { + Id = table.Column(type: "uuid", nullable: false), + OrganizationId = table.Column(type: "uuid", nullable: false), + Name = table.Column(type: "text", nullable: true), + ExternalId = table.Column(type: "character varying(300)", maxLength: 300, nullable: true), + CreationDate = table.Column(type: "timestamp without time zone", nullable: false), + RevisionDate = table.Column(type: "timestamp without time zone", nullable: false) + }, + constraints: table => + { + table.PrimaryKey("PK_Collection", x => x.Id); + table.ForeignKey( + name: "FK_Collection_Organization_OrganizationId", + column: x => x.OrganizationId, + principalTable: "Organization", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + }); + + migrationBuilder.CreateTable( + name: "Group", + columns: table => new + { + Id = table.Column(type: "uuid", nullable: false), + OrganizationId = table.Column(type: "uuid", nullable: false), + Name = table.Column(type: "character varying(100)", maxLength: 100, nullable: true), + AccessAll = table.Column(type: "boolean", nullable: false), + ExternalId = table.Column(type: "character varying(300)", maxLength: 300, nullable: true), + CreationDate = table.Column(type: "timestamp without time zone", nullable: false), + RevisionDate = table.Column(type: "timestamp without time zone", nullable: false) + }, + constraints: table => + { + table.PrimaryKey("PK_Group", x => x.Id); + table.ForeignKey( + name: "FK_Group_Organization_OrganizationId", + column: x => x.OrganizationId, + principalTable: "Organization", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + }); + + migrationBuilder.CreateTable( + name: "Policy", + columns: table => new + { + Id = table.Column(type: "uuid", nullable: false), + OrganizationId = table.Column(type: "uuid", nullable: false), + Type = table.Column(type: "smallint", nullable: false), + Data = table.Column(type: "text", nullable: true), + Enabled = table.Column(type: "boolean", nullable: false), + CreationDate = table.Column(type: "timestamp without time zone", nullable: false), + RevisionDate = table.Column(type: "timestamp without time zone", nullable: false) + }, + constraints: table => + { + table.PrimaryKey("PK_Policy", x => x.Id); + table.ForeignKey( + name: "FK_Policy_Organization_OrganizationId", + column: x => x.OrganizationId, + principalTable: "Organization", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + }); + + migrationBuilder.CreateTable( + name: "SsoConfig", + columns: table => new + { + Id = table.Column(type: "bigint", nullable: false) + .Annotation("Npgsql:ValueGenerationStrategy", NpgsqlValueGenerationStrategy.IdentityByDefaultColumn), + Enabled = table.Column(type: "boolean", nullable: false), + OrganizationId = table.Column(type: "uuid", nullable: false), + Data = table.Column(type: "text", nullable: true), + CreationDate = table.Column(type: "timestamp without time zone", nullable: false), + RevisionDate = table.Column(type: "timestamp without time zone", nullable: false) + }, + constraints: table => + { + table.PrimaryKey("PK_SsoConfig", x => x.Id); + table.ForeignKey( + name: "FK_SsoConfig_Organization_OrganizationId", + column: x => x.OrganizationId, + principalTable: "Organization", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + }); + + migrationBuilder.CreateTable( + name: "ProviderOrganization", + columns: table => new + { + Id = table.Column(type: "uuid", nullable: false), + ProviderId = table.Column(type: "uuid", nullable: false), + OrganizationId = table.Column(type: "uuid", nullable: false), + Key = table.Column(type: "text", nullable: true), + Settings = table.Column(type: "text", nullable: true), + CreationDate = table.Column(type: "timestamp without time zone", nullable: false), + RevisionDate = table.Column(type: "timestamp without time zone", nullable: false) + }, + constraints: table => + { + table.PrimaryKey("PK_ProviderOrganization", x => x.Id); + table.ForeignKey( + name: "FK_ProviderOrganization_Organization_OrganizationId", + column: x => x.OrganizationId, + principalTable: "Organization", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + table.ForeignKey( + name: "FK_ProviderOrganization_Provider_ProviderId", + column: x => x.ProviderId, + principalTable: "Provider", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + }); + + migrationBuilder.CreateTable( + name: "Cipher", + columns: table => new + { + Id = table.Column(type: "uuid", nullable: false), + UserId = table.Column(type: "uuid", nullable: true), + OrganizationId = table.Column(type: "uuid", nullable: true), + Type = table.Column(type: "smallint", nullable: false), + Data = table.Column(type: "text", nullable: true), + Favorites = table.Column(type: "text", nullable: true), + Folders = table.Column(type: "text", nullable: true), + Attachments = table.Column(type: "text", nullable: true), + CreationDate = table.Column(type: "timestamp without time zone", nullable: false), + RevisionDate = table.Column(type: "timestamp without time zone", nullable: false), + DeletedDate = table.Column(type: "timestamp without time zone", nullable: true), + Reprompt = table.Column(type: "smallint", nullable: true) + }, + constraints: table => + { + table.PrimaryKey("PK_Cipher", x => x.Id); + table.ForeignKey( + name: "FK_Cipher_Organization_OrganizationId", + column: x => x.OrganizationId, + principalTable: "Organization", + principalColumn: "Id", + onDelete: ReferentialAction.Restrict); + table.ForeignKey( + name: "FK_Cipher_User_UserId", + column: x => x.UserId, + principalTable: "User", + principalColumn: "Id", + onDelete: ReferentialAction.Restrict); + }); + + migrationBuilder.CreateTable( + name: "Device", + columns: table => new + { + Id = table.Column(type: "uuid", nullable: false), + UserId = table.Column(type: "uuid", nullable: false), + Name = table.Column(type: "character varying(50)", maxLength: 50, nullable: true), + Type = table.Column(type: "smallint", nullable: false), + Identifier = table.Column(type: "character varying(50)", maxLength: 50, nullable: true), + PushToken = table.Column(type: "character varying(255)", maxLength: 255, nullable: true), + CreationDate = table.Column(type: "timestamp without time zone", nullable: false), + RevisionDate = table.Column(type: "timestamp without time zone", nullable: false) + }, + constraints: table => + { + table.PrimaryKey("PK_Device", x => x.Id); + table.ForeignKey( + name: "FK_Device_User_UserId", + column: x => x.UserId, + principalTable: "User", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + }); + + migrationBuilder.CreateTable( + name: "EmergencyAccess", + columns: table => new + { + Id = table.Column(type: "uuid", nullable: false), + GrantorId = table.Column(type: "uuid", nullable: false), + GranteeId = table.Column(type: "uuid", nullable: true), + Email = table.Column(type: "character varying(256)", maxLength: 256, nullable: true), + KeyEncrypted = table.Column(type: "text", nullable: true), + Type = table.Column(type: "smallint", nullable: false), + Status = table.Column(type: "smallint", nullable: false), + WaitTimeDays = table.Column(type: "integer", nullable: false), + RecoveryInitiatedDate = table.Column(type: "timestamp without time zone", nullable: true), + LastNotificationDate = table.Column(type: "timestamp without time zone", nullable: true), + CreationDate = table.Column(type: "timestamp without time zone", nullable: false), + RevisionDate = table.Column(type: "timestamp without time zone", nullable: false) + }, + constraints: table => + { + table.PrimaryKey("PK_EmergencyAccess", x => x.Id); + table.ForeignKey( + name: "FK_EmergencyAccess_User_GranteeId", + column: x => x.GranteeId, + principalTable: "User", + principalColumn: "Id", + onDelete: ReferentialAction.Restrict); + table.ForeignKey( + name: "FK_EmergencyAccess_User_GrantorId", + column: x => x.GrantorId, + principalTable: "User", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + }); + + migrationBuilder.CreateTable( + name: "Folder", + columns: table => new + { + Id = table.Column(type: "uuid", nullable: false), + UserId = table.Column(type: "uuid", nullable: false), + Name = table.Column(type: "text", nullable: true), + CreationDate = table.Column(type: "timestamp without time zone", nullable: false), + RevisionDate = table.Column(type: "timestamp without time zone", nullable: false) + }, + constraints: table => + { + table.PrimaryKey("PK_Folder", x => x.Id); + table.ForeignKey( + name: "FK_Folder_User_UserId", + column: x => x.UserId, + principalTable: "User", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + }); + + migrationBuilder.CreateTable( + name: "OrganizationUser", + columns: table => new + { + Id = table.Column(type: "uuid", nullable: false), + OrganizationId = table.Column(type: "uuid", nullable: false), + UserId = table.Column(type: "uuid", nullable: true), + Email = table.Column(type: "character varying(256)", maxLength: 256, nullable: true), + Key = table.Column(type: "text", nullable: true), + ResetPasswordKey = table.Column(type: "text", nullable: true), + Status = table.Column(type: "smallint", nullable: false), + Type = table.Column(type: "smallint", nullable: false), + AccessAll = table.Column(type: "boolean", nullable: false), + ExternalId = table.Column(type: "character varying(300)", maxLength: 300, nullable: true), + CreationDate = table.Column(type: "timestamp without time zone", nullable: false), + RevisionDate = table.Column(type: "timestamp without time zone", nullable: false), + Permissions = table.Column(type: "text", nullable: true) + }, + constraints: table => + { + table.PrimaryKey("PK_OrganizationUser", x => x.Id); + table.ForeignKey( + name: "FK_OrganizationUser_Organization_OrganizationId", + column: x => x.OrganizationId, + principalTable: "Organization", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + table.ForeignKey( + name: "FK_OrganizationUser_User_UserId", + column: x => x.UserId, + principalTable: "User", + principalColumn: "Id", + onDelete: ReferentialAction.Restrict); + }); + + migrationBuilder.CreateTable( + name: "ProviderUser", + columns: table => new + { + Id = table.Column(type: "uuid", nullable: false), + ProviderId = table.Column(type: "uuid", nullable: false), + UserId = table.Column(type: "uuid", nullable: true), + Email = table.Column(type: "text", nullable: true), + Key = table.Column(type: "text", nullable: true), + Status = table.Column(type: "smallint", nullable: false), + Type = table.Column(type: "smallint", nullable: false), + Permissions = table.Column(type: "text", nullable: true), + CreationDate = table.Column(type: "timestamp without time zone", nullable: false), + RevisionDate = table.Column(type: "timestamp without time zone", nullable: false) + }, + constraints: table => + { + table.PrimaryKey("PK_ProviderUser", x => x.Id); + table.ForeignKey( + name: "FK_ProviderUser_Provider_ProviderId", + column: x => x.ProviderId, + principalTable: "Provider", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + table.ForeignKey( + name: "FK_ProviderUser_User_UserId", + column: x => x.UserId, + principalTable: "User", + principalColumn: "Id", + onDelete: ReferentialAction.Restrict); + }); + + migrationBuilder.CreateTable( + name: "Send", + columns: table => new + { + Id = table.Column(type: "uuid", nullable: false), + UserId = table.Column(type: "uuid", nullable: true), + OrganizationId = table.Column(type: "uuid", nullable: true), + Type = table.Column(type: "smallint", nullable: false), + Data = table.Column(type: "text", nullable: true), + Key = table.Column(type: "text", nullable: true), + Password = table.Column(type: "character varying(300)", maxLength: 300, nullable: true), + MaxAccessCount = table.Column(type: "integer", nullable: true), + AccessCount = table.Column(type: "integer", nullable: false), + CreationDate = table.Column(type: "timestamp without time zone", nullable: false), + RevisionDate = table.Column(type: "timestamp without time zone", nullable: false), + ExpirationDate = table.Column(type: "timestamp without time zone", nullable: true), + DeletionDate = table.Column(type: "timestamp without time zone", nullable: false), + Disabled = table.Column(type: "boolean", nullable: false), + HideEmail = table.Column(type: "boolean", nullable: true) + }, + constraints: table => + { + table.PrimaryKey("PK_Send", x => x.Id); + table.ForeignKey( + name: "FK_Send_Organization_OrganizationId", + column: x => x.OrganizationId, + principalTable: "Organization", + principalColumn: "Id", + onDelete: ReferentialAction.Restrict); + table.ForeignKey( + name: "FK_Send_User_UserId", + column: x => x.UserId, + principalTable: "User", + principalColumn: "Id", + onDelete: ReferentialAction.Restrict); + }); + + migrationBuilder.CreateTable( + name: "SsoUser", + columns: table => new + { + Id = table.Column(type: "bigint", nullable: false) + .Annotation("Npgsql:ValueGenerationStrategy", NpgsqlValueGenerationStrategy.IdentityByDefaultColumn), + UserId = table.Column(type: "uuid", nullable: false), + OrganizationId = table.Column(type: "uuid", nullable: true), + ExternalId = table.Column(type: "character varying(50)", maxLength: 50, nullable: true, collation: "postgresIndetermanisticCollation"), + CreationDate = table.Column(type: "timestamp without time zone", nullable: false) + }, + constraints: table => + { + table.PrimaryKey("PK_SsoUser", x => x.Id); + table.ForeignKey( + name: "FK_SsoUser_Organization_OrganizationId", + column: x => x.OrganizationId, + principalTable: "Organization", + principalColumn: "Id", + onDelete: ReferentialAction.Restrict); + table.ForeignKey( + name: "FK_SsoUser_User_UserId", + column: x => x.UserId, + principalTable: "User", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + }); + + migrationBuilder.CreateTable( + name: "Transaction", + columns: table => new + { + Id = table.Column(type: "uuid", nullable: false), + UserId = table.Column(type: "uuid", nullable: true), + OrganizationId = table.Column(type: "uuid", nullable: true), + Type = table.Column(type: "smallint", nullable: false), + Amount = table.Column(type: "numeric", nullable: false), + Refunded = table.Column(type: "boolean", nullable: true), + RefundedAmount = table.Column(type: "numeric", nullable: true), + Details = table.Column(type: "character varying(100)", maxLength: 100, nullable: true), + PaymentMethodType = table.Column(type: "smallint", nullable: true), + Gateway = table.Column(type: "smallint", nullable: true), + GatewayId = table.Column(type: "character varying(50)", maxLength: 50, nullable: true), + CreationDate = table.Column(type: "timestamp without time zone", nullable: false) + }, + constraints: table => + { + table.PrimaryKey("PK_Transaction", x => x.Id); + table.ForeignKey( + name: "FK_Transaction_Organization_OrganizationId", + column: x => x.OrganizationId, + principalTable: "Organization", + principalColumn: "Id", + onDelete: ReferentialAction.Restrict); + table.ForeignKey( + name: "FK_Transaction_User_UserId", + column: x => x.UserId, + principalTable: "User", + principalColumn: "Id", + onDelete: ReferentialAction.Restrict); + }); + + migrationBuilder.CreateTable( + name: "U2f", + columns: table => new + { + Id = table.Column(type: "integer", nullable: false) + .Annotation("Npgsql:ValueGenerationStrategy", NpgsqlValueGenerationStrategy.IdentityByDefaultColumn), + UserId = table.Column(type: "uuid", nullable: false), + KeyHandle = table.Column(type: "character varying(200)", maxLength: 200, nullable: true), + Challenge = table.Column(type: "character varying(200)", maxLength: 200, nullable: true), + AppId = table.Column(type: "character varying(50)", maxLength: 50, nullable: true), + Version = table.Column(type: "character varying(20)", maxLength: 20, nullable: true), + CreationDate = table.Column(type: "timestamp without time zone", nullable: false) + }, + constraints: table => + { + table.PrimaryKey("PK_U2f", x => x.Id); + table.ForeignKey( + name: "FK_U2f_User_UserId", + column: x => x.UserId, + principalTable: "User", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + }); + + migrationBuilder.CreateTable( + name: "CollectionGroups", + columns: table => new + { + CollectionId = table.Column(type: "uuid", nullable: false), + GroupId = table.Column(type: "uuid", nullable: false), + ReadOnly = table.Column(type: "boolean", nullable: false), + HidePasswords = table.Column(type: "boolean", nullable: false) + }, + constraints: table => + { + table.PrimaryKey("PK_CollectionGroups", x => new { x.CollectionId, x.GroupId }); + table.ForeignKey( + name: "FK_CollectionGroups_Collection_CollectionId", + column: x => x.CollectionId, + principalTable: "Collection", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + table.ForeignKey( + name: "FK_CollectionGroups_Group_GroupId", + column: x => x.GroupId, + principalTable: "Group", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + }); + + migrationBuilder.CreateTable( + name: "CollectionCipher", + columns: table => new + { + CollectionId = table.Column(type: "uuid", nullable: false), + CipherId = table.Column(type: "uuid", nullable: false) + }, + constraints: table => + { + table.PrimaryKey("PK_CollectionCipher", x => new { x.CollectionId, x.CipherId }); + table.ForeignKey( + name: "FK_CollectionCipher_Cipher_CipherId", + column: x => x.CipherId, + principalTable: "Cipher", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + table.ForeignKey( + name: "FK_CollectionCipher_Collection_CollectionId", + column: x => x.CollectionId, + principalTable: "Collection", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + }); + + migrationBuilder.CreateTable( + name: "CollectionUsers", + columns: table => new + { + CollectionId = table.Column(type: "uuid", nullable: false), + OrganizationUserId = table.Column(type: "uuid", nullable: false), + UserId = table.Column(type: "uuid", nullable: true), + ReadOnly = table.Column(type: "boolean", nullable: false), + HidePasswords = table.Column(type: "boolean", nullable: false) + }, + constraints: table => + { + table.PrimaryKey("PK_CollectionUsers", x => new { x.CollectionId, x.OrganizationUserId }); + table.ForeignKey( + name: "FK_CollectionUsers_Collection_CollectionId", + column: x => x.CollectionId, + principalTable: "Collection", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + table.ForeignKey( + name: "FK_CollectionUsers_OrganizationUser_OrganizationUserId", + column: x => x.OrganizationUserId, + principalTable: "OrganizationUser", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + table.ForeignKey( + name: "FK_CollectionUsers_User_UserId", + column: x => x.UserId, + principalTable: "User", + principalColumn: "Id", + onDelete: ReferentialAction.Restrict); + }); + + migrationBuilder.CreateTable( + name: "GroupUser", + columns: table => new + { + GroupId = table.Column(type: "uuid", nullable: false), + OrganizationUserId = table.Column(type: "uuid", nullable: false), + UserId = table.Column(type: "uuid", nullable: true) + }, + constraints: table => + { + table.PrimaryKey("PK_GroupUser", x => new { x.GroupId, x.OrganizationUserId }); + table.ForeignKey( + name: "FK_GroupUser_Group_GroupId", + column: x => x.GroupId, + principalTable: "Group", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + table.ForeignKey( + name: "FK_GroupUser_OrganizationUser_OrganizationUserId", + column: x => x.OrganizationUserId, + principalTable: "OrganizationUser", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + table.ForeignKey( + name: "FK_GroupUser_User_UserId", + column: x => x.UserId, + principalTable: "User", + principalColumn: "Id", + onDelete: ReferentialAction.Restrict); + }); + + migrationBuilder.CreateTable( + name: "ProviderOrganizationProviderUser", + columns: table => new + { + Id = table.Column(type: "uuid", nullable: false), + ProviderOrganizationId = table.Column(type: "uuid", nullable: false), + ProviderUserId = table.Column(type: "uuid", nullable: false), + Type = table.Column(type: "smallint", nullable: false), + Permissions = table.Column(type: "text", nullable: true), + CreationDate = table.Column(type: "timestamp without time zone", nullable: false), + RevisionDate = table.Column(type: "timestamp without time zone", nullable: false) + }, + constraints: table => + { + table.PrimaryKey("PK_ProviderOrganizationProviderUser", x => x.Id); + table.ForeignKey( + name: "FK_ProviderOrganizationProviderUser_ProviderOrganization_Provi~", + column: x => x.ProviderOrganizationId, + principalTable: "ProviderOrganization", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + table.ForeignKey( + name: "FK_ProviderOrganizationProviderUser_ProviderUser_ProviderUserId", + column: x => x.ProviderUserId, + principalTable: "ProviderUser", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + }); + + migrationBuilder.CreateIndex( + name: "IX_Cipher_OrganizationId", + table: "Cipher", + column: "OrganizationId"); + + migrationBuilder.CreateIndex( + name: "IX_Cipher_UserId", + table: "Cipher", + column: "UserId"); + + migrationBuilder.CreateIndex( + name: "IX_Collection_OrganizationId", + table: "Collection", + column: "OrganizationId"); + + migrationBuilder.CreateIndex( + name: "IX_CollectionCipher_CipherId", + table: "CollectionCipher", + column: "CipherId"); + + migrationBuilder.CreateIndex( + name: "IX_CollectionGroups_GroupId", + table: "CollectionGroups", + column: "GroupId"); + + migrationBuilder.CreateIndex( + name: "IX_CollectionUsers_OrganizationUserId", + table: "CollectionUsers", + column: "OrganizationUserId"); + + migrationBuilder.CreateIndex( + name: "IX_CollectionUsers_UserId", + table: "CollectionUsers", + column: "UserId"); + + migrationBuilder.CreateIndex( + name: "IX_Device_UserId", + table: "Device", + column: "UserId"); + + migrationBuilder.CreateIndex( + name: "IX_EmergencyAccess_GranteeId", + table: "EmergencyAccess", + column: "GranteeId"); + + migrationBuilder.CreateIndex( + name: "IX_EmergencyAccess_GrantorId", + table: "EmergencyAccess", + column: "GrantorId"); + + migrationBuilder.CreateIndex( + name: "IX_Folder_UserId", + table: "Folder", + column: "UserId"); + + migrationBuilder.CreateIndex( + name: "IX_Group_OrganizationId", + table: "Group", + column: "OrganizationId"); + + migrationBuilder.CreateIndex( + name: "IX_GroupUser_OrganizationUserId", + table: "GroupUser", + column: "OrganizationUserId"); + + migrationBuilder.CreateIndex( + name: "IX_GroupUser_UserId", + table: "GroupUser", + column: "UserId"); + + migrationBuilder.CreateIndex( + name: "IX_OrganizationUser_OrganizationId", + table: "OrganizationUser", + column: "OrganizationId"); + + migrationBuilder.CreateIndex( + name: "IX_OrganizationUser_UserId", + table: "OrganizationUser", + column: "UserId"); + + migrationBuilder.CreateIndex( + name: "IX_Policy_OrganizationId", + table: "Policy", + column: "OrganizationId"); + + migrationBuilder.CreateIndex( + name: "IX_ProviderOrganization_OrganizationId", + table: "ProviderOrganization", + column: "OrganizationId"); + + migrationBuilder.CreateIndex( + name: "IX_ProviderOrganization_ProviderId", + table: "ProviderOrganization", + column: "ProviderId"); + + migrationBuilder.CreateIndex( + name: "IX_ProviderOrganizationProviderUser_ProviderOrganizationId", + table: "ProviderOrganizationProviderUser", + column: "ProviderOrganizationId"); + + migrationBuilder.CreateIndex( + name: "IX_ProviderOrganizationProviderUser_ProviderUserId", + table: "ProviderOrganizationProviderUser", + column: "ProviderUserId"); + + migrationBuilder.CreateIndex( + name: "IX_ProviderUser_ProviderId", + table: "ProviderUser", + column: "ProviderId"); + + migrationBuilder.CreateIndex( + name: "IX_ProviderUser_UserId", + table: "ProviderUser", + column: "UserId"); + + migrationBuilder.CreateIndex( + name: "IX_Send_OrganizationId", + table: "Send", + column: "OrganizationId"); + + migrationBuilder.CreateIndex( + name: "IX_Send_UserId", + table: "Send", + column: "UserId"); + + migrationBuilder.CreateIndex( + name: "IX_SsoConfig_OrganizationId", + table: "SsoConfig", + column: "OrganizationId"); - migrationBuilder.CreateIndex( - name: "IX_SsoUser_OrganizationId", - table: "SsoUser", - column: "OrganizationId"); + migrationBuilder.CreateIndex( + name: "IX_SsoUser_OrganizationId", + table: "SsoUser", + column: "OrganizationId"); - migrationBuilder.CreateIndex( - name: "IX_SsoUser_UserId", - table: "SsoUser", - column: "UserId"); + migrationBuilder.CreateIndex( + name: "IX_SsoUser_UserId", + table: "SsoUser", + column: "UserId"); - migrationBuilder.CreateIndex( - name: "IX_Transaction_OrganizationId", - table: "Transaction", - column: "OrganizationId"); + migrationBuilder.CreateIndex( + name: "IX_Transaction_OrganizationId", + table: "Transaction", + column: "OrganizationId"); - migrationBuilder.CreateIndex( - name: "IX_Transaction_UserId", - table: "Transaction", - column: "UserId"); + migrationBuilder.CreateIndex( + name: "IX_Transaction_UserId", + table: "Transaction", + column: "UserId"); - migrationBuilder.CreateIndex( - name: "IX_U2f_UserId", - table: "U2f", - column: "UserId"); - } + migrationBuilder.CreateIndex( + name: "IX_U2f_UserId", + table: "U2f", + column: "UserId"); + } - protected override void Down(MigrationBuilder migrationBuilder) - { - migrationBuilder.DropTable( - name: "CollectionCipher"); + protected override void Down(MigrationBuilder migrationBuilder) + { + migrationBuilder.DropTable( + name: "CollectionCipher"); - migrationBuilder.DropTable( - name: "CollectionGroups"); + migrationBuilder.DropTable( + name: "CollectionGroups"); - migrationBuilder.DropTable( - name: "CollectionUsers"); + migrationBuilder.DropTable( + name: "CollectionUsers"); - migrationBuilder.DropTable( - name: "Device"); + migrationBuilder.DropTable( + name: "Device"); - migrationBuilder.DropTable( - name: "EmergencyAccess"); + migrationBuilder.DropTable( + name: "EmergencyAccess"); - migrationBuilder.DropTable( - name: "Event"); + migrationBuilder.DropTable( + name: "Event"); - migrationBuilder.DropTable( - name: "Folder"); + migrationBuilder.DropTable( + name: "Folder"); - migrationBuilder.DropTable( - name: "Grant"); + migrationBuilder.DropTable( + name: "Grant"); - migrationBuilder.DropTable( - name: "GroupUser"); + migrationBuilder.DropTable( + name: "GroupUser"); - migrationBuilder.DropTable( - name: "Installation"); + migrationBuilder.DropTable( + name: "Installation"); - migrationBuilder.DropTable( - name: "Policy"); + migrationBuilder.DropTable( + name: "Policy"); - migrationBuilder.DropTable( - name: "ProviderOrganizationProviderUser"); + migrationBuilder.DropTable( + name: "ProviderOrganizationProviderUser"); - migrationBuilder.DropTable( - name: "Send"); + migrationBuilder.DropTable( + name: "Send"); - migrationBuilder.DropTable( - name: "SsoConfig"); + migrationBuilder.DropTable( + name: "SsoConfig"); - migrationBuilder.DropTable( - name: "SsoUser"); + migrationBuilder.DropTable( + name: "SsoUser"); - migrationBuilder.DropTable( - name: "TaxRate"); + migrationBuilder.DropTable( + name: "TaxRate"); - migrationBuilder.DropTable( - name: "Transaction"); + migrationBuilder.DropTable( + name: "Transaction"); - migrationBuilder.DropTable( - name: "U2f"); + migrationBuilder.DropTable( + name: "U2f"); - migrationBuilder.DropTable( - name: "Cipher"); + migrationBuilder.DropTable( + name: "Cipher"); - migrationBuilder.DropTable( - name: "Collection"); + migrationBuilder.DropTable( + name: "Collection"); - migrationBuilder.DropTable( - name: "Group"); + migrationBuilder.DropTable( + name: "Group"); - migrationBuilder.DropTable( - name: "OrganizationUser"); + migrationBuilder.DropTable( + name: "OrganizationUser"); - migrationBuilder.DropTable( - name: "ProviderOrganization"); + migrationBuilder.DropTable( + name: "ProviderOrganization"); - migrationBuilder.DropTable( - name: "ProviderUser"); + migrationBuilder.DropTable( + name: "ProviderUser"); - migrationBuilder.DropTable( - name: "Organization"); + migrationBuilder.DropTable( + name: "Organization"); - migrationBuilder.DropTable( - name: "Provider"); + migrationBuilder.DropTable( + name: "Provider"); - migrationBuilder.DropTable( - name: "User"); + migrationBuilder.DropTable( + name: "User"); + } } } diff --git a/util/PostgresMigrations/Migrations/20210709092227_RemoveProviderOrganizationProviderUser.cs b/util/PostgresMigrations/Migrations/20210709092227_RemoveProviderOrganizationProviderUser.cs index f0f0d235b9..ba7da780f8 100644 --- a/util/PostgresMigrations/Migrations/20210709092227_RemoveProviderOrganizationProviderUser.cs +++ b/util/PostgresMigrations/Migrations/20210709092227_RemoveProviderOrganizationProviderUser.cs @@ -1,74 +1,75 @@ using Microsoft.EntityFrameworkCore.Migrations; -namespace Bit.PostgresMigrations.Migrations; - -public partial class RemoveProviderOrganizationProviderUser : Migration +namespace Bit.PostgresMigrations.Migrations { - protected override void Up(MigrationBuilder migrationBuilder) + public partial class RemoveProviderOrganizationProviderUser : Migration { - migrationBuilder.DropTable( - name: "ProviderOrganizationProviderUser"); + protected override void Up(MigrationBuilder migrationBuilder) + { + migrationBuilder.DropTable( + name: "ProviderOrganizationProviderUser"); - migrationBuilder.AddColumn( - name: "ProviderId", - table: "Event", - type: "uuid", - nullable: true); + migrationBuilder.AddColumn( + name: "ProviderId", + table: "Event", + type: "uuid", + nullable: true); - migrationBuilder.AddColumn( - name: "ProviderUserId", - table: "Event", - type: "uuid", - nullable: true); - } + migrationBuilder.AddColumn( + name: "ProviderUserId", + table: "Event", + type: "uuid", + nullable: true); + } - protected override void Down(MigrationBuilder migrationBuilder) - { - migrationBuilder.DropColumn( - name: "ProviderId", - table: "Event"); + protected override void Down(MigrationBuilder migrationBuilder) + { + migrationBuilder.DropColumn( + name: "ProviderId", + table: "Event"); - migrationBuilder.DropColumn( - name: "ProviderUserId", - table: "Event"); + migrationBuilder.DropColumn( + name: "ProviderUserId", + table: "Event"); - migrationBuilder.CreateTable( - name: "ProviderOrganizationProviderUser", - columns: table => new - { - Id = table.Column(type: "uuid", nullable: false), - CreationDate = table.Column(type: "timestamp without time zone", nullable: false), - Permissions = table.Column(type: "text", nullable: true), - ProviderOrganizationId = table.Column(type: "uuid", nullable: false), - ProviderUserId = table.Column(type: "uuid", nullable: false), - RevisionDate = table.Column(type: "timestamp without time zone", nullable: false), - Type = table.Column(type: "smallint", nullable: false) - }, - constraints: table => - { - table.PrimaryKey("PK_ProviderOrganizationProviderUser", x => x.Id); - table.ForeignKey( - name: "FK_ProviderOrganizationProviderUser_ProviderOrganization_Provi~", - column: x => x.ProviderOrganizationId, - principalTable: "ProviderOrganization", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - table.ForeignKey( - name: "FK_ProviderOrganizationProviderUser_ProviderUser_ProviderUserId", - column: x => x.ProviderUserId, - principalTable: "ProviderUser", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - }); + migrationBuilder.CreateTable( + name: "ProviderOrganizationProviderUser", + columns: table => new + { + Id = table.Column(type: "uuid", nullable: false), + CreationDate = table.Column(type: "timestamp without time zone", nullable: false), + Permissions = table.Column(type: "text", nullable: true), + ProviderOrganizationId = table.Column(type: "uuid", nullable: false), + ProviderUserId = table.Column(type: "uuid", nullable: false), + RevisionDate = table.Column(type: "timestamp without time zone", nullable: false), + Type = table.Column(type: "smallint", nullable: false) + }, + constraints: table => + { + table.PrimaryKey("PK_ProviderOrganizationProviderUser", x => x.Id); + table.ForeignKey( + name: "FK_ProviderOrganizationProviderUser_ProviderOrganization_Provi~", + column: x => x.ProviderOrganizationId, + principalTable: "ProviderOrganization", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + table.ForeignKey( + name: "FK_ProviderOrganizationProviderUser_ProviderUser_ProviderUserId", + column: x => x.ProviderUserId, + principalTable: "ProviderUser", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + }); - migrationBuilder.CreateIndex( - name: "IX_ProviderOrganizationProviderUser_ProviderOrganizationId", - table: "ProviderOrganizationProviderUser", - column: "ProviderOrganizationId"); + migrationBuilder.CreateIndex( + name: "IX_ProviderOrganizationProviderUser_ProviderOrganizationId", + table: "ProviderOrganizationProviderUser", + column: "ProviderOrganizationId"); - migrationBuilder.CreateIndex( - name: "IX_ProviderOrganizationProviderUser_ProviderUserId", - table: "ProviderOrganizationProviderUser", - column: "ProviderUserId"); + migrationBuilder.CreateIndex( + name: "IX_ProviderOrganizationProviderUser_ProviderUserId", + table: "ProviderOrganizationProviderUser", + column: "ProviderUserId"); + } } } diff --git a/util/PostgresMigrations/Migrations/20210716141748_UserForcePasswordReset.cs b/util/PostgresMigrations/Migrations/20210716141748_UserForcePasswordReset.cs index 5b435b218b..bb39dfe4b6 100644 --- a/util/PostgresMigrations/Migrations/20210716141748_UserForcePasswordReset.cs +++ b/util/PostgresMigrations/Migrations/20210716141748_UserForcePasswordReset.cs @@ -1,23 +1,24 @@ using Microsoft.EntityFrameworkCore.Migrations; -namespace Bit.PostgresMigrations.Migrations; - -public partial class UserForcePasswordReset : Migration +namespace Bit.PostgresMigrations.Migrations { - protected override void Up(MigrationBuilder migrationBuilder) + public partial class UserForcePasswordReset : Migration { - migrationBuilder.AddColumn( - name: "ForcePasswordReset", - table: "User", - type: "boolean", - nullable: false, - defaultValue: false); - } + protected override void Up(MigrationBuilder migrationBuilder) + { + migrationBuilder.AddColumn( + name: "ForcePasswordReset", + table: "User", + type: "boolean", + nullable: false, + defaultValue: false); + } - protected override void Down(MigrationBuilder migrationBuilder) - { - migrationBuilder.DropColumn( - name: "ForcePasswordReset", - table: "User"); + protected override void Down(MigrationBuilder migrationBuilder) + { + migrationBuilder.DropColumn( + name: "ForcePasswordReset", + table: "User"); + } } } diff --git a/util/PostgresMigrations/Migrations/20210920201829_AddMaxAutoscaleSeatsToOrganization.cs b/util/PostgresMigrations/Migrations/20210920201829_AddMaxAutoscaleSeatsToOrganization.cs index 41ab20399d..98d2acce63 100644 --- a/util/PostgresMigrations/Migrations/20210920201829_AddMaxAutoscaleSeatsToOrganization.cs +++ b/util/PostgresMigrations/Migrations/20210920201829_AddMaxAutoscaleSeatsToOrganization.cs @@ -1,42 +1,43 @@ using Microsoft.EntityFrameworkCore.Migrations; -namespace Bit.PostgresMigrations.Migrations; - -public partial class AddMaxAutoscaleSeatsToOrganization : Migration +namespace Bit.PostgresMigrations.Migrations { - protected override void Up(MigrationBuilder migrationBuilder) + public partial class AddMaxAutoscaleSeatsToOrganization : Migration { - migrationBuilder.AddColumn( - name: "MaxAutoscaleSeats", - table: "Organization", - type: "integer", - nullable: true); + protected override void Up(MigrationBuilder migrationBuilder) + { + migrationBuilder.AddColumn( + name: "MaxAutoscaleSeats", + table: "Organization", + type: "integer", + nullable: true); - migrationBuilder.AddColumn( - name: "OwnersNotifiedOfAutoscaling", - table: "Organization", - type: "timestamp without time zone", - nullable: true); + migrationBuilder.AddColumn( + name: "OwnersNotifiedOfAutoscaling", + table: "Organization", + type: "timestamp without time zone", + nullable: true); - migrationBuilder.AddColumn( - name: "ProviderOrganizationId", - table: "Event", - type: "uuid", - nullable: true); - } + migrationBuilder.AddColumn( + name: "ProviderOrganizationId", + table: "Event", + type: "uuid", + nullable: true); + } - protected override void Down(MigrationBuilder migrationBuilder) - { - migrationBuilder.DropColumn( - name: "MaxAutoscaleSeats", - table: "Organization"); + protected override void Down(MigrationBuilder migrationBuilder) + { + migrationBuilder.DropColumn( + name: "MaxAutoscaleSeats", + table: "Organization"); - migrationBuilder.DropColumn( - name: "OwnersNotifiedOfAutoscaling", - table: "Organization"); + migrationBuilder.DropColumn( + name: "OwnersNotifiedOfAutoscaling", + table: "Organization"); - migrationBuilder.DropColumn( - name: "ProviderOrganizationId", - table: "Event"); + migrationBuilder.DropColumn( + name: "ProviderOrganizationId", + table: "Event"); + } } } diff --git a/util/PostgresMigrations/Migrations/20211011145128_SplitManageCollectionsPermissions2.cs b/util/PostgresMigrations/Migrations/20211011145128_SplitManageCollectionsPermissions2.cs index d1c08d3fbb..90b0884aba 100644 --- a/util/PostgresMigrations/Migrations/20211011145128_SplitManageCollectionsPermissions2.cs +++ b/util/PostgresMigrations/Migrations/20211011145128_SplitManageCollectionsPermissions2.cs @@ -1,20 +1,21 @@ using Bit.Core.Utilities; using Microsoft.EntityFrameworkCore.Migrations; -namespace Bit.PostgresMigrations.Migrations; - -public partial class SplitManageCollectionsPermissions2 : Migration +namespace Bit.PostgresMigrations.Migrations { - private const string _scriptLocation = - "PostgresMigrations.Scripts.2021-09-21_01_SplitManageCollectionsPermission.psql"; - - protected override void Up(MigrationBuilder migrationBuilder) + public partial class SplitManageCollectionsPermissions2 : Migration { - migrationBuilder.Sql(CoreHelpers.GetEmbeddedResourceContentsAsync(_scriptLocation)); - } + private const string _scriptLocation = + "PostgresMigrations.Scripts.2021-09-21_01_SplitManageCollectionsPermission.psql"; - protected override void Down(MigrationBuilder migrationBuilder) - { - throw new Exception("Irreversible migration"); + protected override void Up(MigrationBuilder migrationBuilder) + { + migrationBuilder.Sql(CoreHelpers.GetEmbeddedResourceContentsAsync(_scriptLocation)); + } + + protected override void Down(MigrationBuilder migrationBuilder) + { + throw new Exception("Irreversible migration"); + } } } diff --git a/util/PostgresMigrations/Migrations/20211021204521_SetMaxAutoscaleSeatsToCurrentSeatCount.cs b/util/PostgresMigrations/Migrations/20211021204521_SetMaxAutoscaleSeatsToCurrentSeatCount.cs index c569d7f1b0..c8c569a0d6 100644 --- a/util/PostgresMigrations/Migrations/20211021204521_SetMaxAutoscaleSeatsToCurrentSeatCount.cs +++ b/util/PostgresMigrations/Migrations/20211021204521_SetMaxAutoscaleSeatsToCurrentSeatCount.cs @@ -1,20 +1,21 @@ using Bit.Core.Utilities; using Microsoft.EntityFrameworkCore.Migrations; -namespace Bit.PostgresMigrations.Migrations; - -public partial class SetMaxAutoscaleSeatsToCurrentSeatCount : Migration +namespace Bit.PostgresMigrations.Migrations { - private const string _scriptLocation = - "PostgresMigrations.Scripts.2021-10-21_00_SetMaxAutoscaleSeatCount.psql"; - - protected override void Up(MigrationBuilder migrationBuilder) + public partial class SetMaxAutoscaleSeatsToCurrentSeatCount : Migration { - migrationBuilder.Sql(CoreHelpers.GetEmbeddedResourceContentsAsync(_scriptLocation)); - } + private const string _scriptLocation = + "PostgresMigrations.Scripts.2021-10-21_00_SetMaxAutoscaleSeatCount.psql"; - protected override void Down(MigrationBuilder migrationBuilder) - { - throw new Exception("Irreversible migration"); + protected override void Up(MigrationBuilder migrationBuilder) + { + migrationBuilder.Sql(CoreHelpers.GetEmbeddedResourceContentsAsync(_scriptLocation)); + } + + protected override void Down(MigrationBuilder migrationBuilder) + { + throw new Exception("Irreversible migration"); + } } } diff --git a/util/PostgresMigrations/Migrations/20211108041547_KeyConnector.cs b/util/PostgresMigrations/Migrations/20211108041547_KeyConnector.cs index 7619e76896..264869124f 100644 --- a/util/PostgresMigrations/Migrations/20211108041547_KeyConnector.cs +++ b/util/PostgresMigrations/Migrations/20211108041547_KeyConnector.cs @@ -1,23 +1,24 @@ using Microsoft.EntityFrameworkCore.Migrations; -namespace Bit.PostgresMigrations.Migrations; - -public partial class KeyConnector : Migration +namespace Bit.PostgresMigrations.Migrations { - protected override void Up(MigrationBuilder migrationBuilder) + public partial class KeyConnector : Migration { - migrationBuilder.AddColumn( - name: "UsesKeyConnector", - table: "User", - type: "boolean", - nullable: false, - defaultValue: false); - } + protected override void Up(MigrationBuilder migrationBuilder) + { + migrationBuilder.AddColumn( + name: "UsesKeyConnector", + table: "User", + type: "boolean", + nullable: false, + defaultValue: false); + } - protected override void Down(MigrationBuilder migrationBuilder) - { - migrationBuilder.DropColumn( - name: "UsesKeyConnector", - table: "User"); + protected override void Down(MigrationBuilder migrationBuilder) + { + migrationBuilder.DropColumn( + name: "UsesKeyConnector", + table: "User"); + } } } diff --git a/util/PostgresMigrations/Migrations/20211108225011_OrganizationSponsorship.cs b/util/PostgresMigrations/Migrations/20211108225011_OrganizationSponsorship.cs index a787141e78..6918e885d2 100644 --- a/util/PostgresMigrations/Migrations/20211108225011_OrganizationSponsorship.cs +++ b/util/PostgresMigrations/Migrations/20211108225011_OrganizationSponsorship.cs @@ -1,81 +1,82 @@ using Microsoft.EntityFrameworkCore.Migrations; -namespace Bit.PostgresMigrations.Migrations; - -public partial class OrganizationSponsorship : Migration +namespace Bit.PostgresMigrations.Migrations { - protected override void Up(MigrationBuilder migrationBuilder) + public partial class OrganizationSponsorship : Migration { - migrationBuilder.AddColumn( - name: "UsesCryptoAgent", - table: "User", - type: "boolean", - nullable: false, - defaultValue: false); + protected override void Up(MigrationBuilder migrationBuilder) + { + migrationBuilder.AddColumn( + name: "UsesCryptoAgent", + table: "User", + type: "boolean", + nullable: false, + defaultValue: false); - migrationBuilder.CreateTable( - name: "OrganizationSponsorship", - columns: table => new - { - Id = table.Column(type: "uuid", nullable: false), - InstallationId = table.Column(type: "uuid", nullable: true), - SponsoringOrganizationId = table.Column(type: "uuid", nullable: true), - SponsoringOrganizationUserId = table.Column(type: "uuid", nullable: true), - SponsoredOrganizationId = table.Column(type: "uuid", nullable: true), - FriendlyName = table.Column(type: "character varying(256)", maxLength: 256, nullable: true), - OfferedToEmail = table.Column(type: "character varying(256)", maxLength: 256, nullable: true), - PlanSponsorshipType = table.Column(type: "smallint", nullable: true), - CloudSponsor = table.Column(type: "boolean", nullable: false), - LastSyncDate = table.Column(type: "timestamp without time zone", nullable: true), - TimesRenewedWithoutValidation = table.Column(type: "smallint", nullable: false), - SponsorshipLapsedDate = table.Column(type: "timestamp without time zone", nullable: true) - }, - constraints: table => - { - table.PrimaryKey("PK_OrganizationSponsorship", x => x.Id); - table.ForeignKey( - name: "FK_OrganizationSponsorship_Installation_InstallationId", - column: x => x.InstallationId, - principalTable: "Installation", - principalColumn: "Id", - onDelete: ReferentialAction.Restrict); - table.ForeignKey( - name: "FK_OrganizationSponsorship_Organization_SponsoredOrganizationId", - column: x => x.SponsoredOrganizationId, - principalTable: "Organization", - principalColumn: "Id", - onDelete: ReferentialAction.Restrict); - table.ForeignKey( - name: "FK_OrganizationSponsorship_Organization_SponsoringOrganization~", - column: x => x.SponsoringOrganizationId, - principalTable: "Organization", - principalColumn: "Id", - onDelete: ReferentialAction.Restrict); - }); + migrationBuilder.CreateTable( + name: "OrganizationSponsorship", + columns: table => new + { + Id = table.Column(type: "uuid", nullable: false), + InstallationId = table.Column(type: "uuid", nullable: true), + SponsoringOrganizationId = table.Column(type: "uuid", nullable: true), + SponsoringOrganizationUserId = table.Column(type: "uuid", nullable: true), + SponsoredOrganizationId = table.Column(type: "uuid", nullable: true), + FriendlyName = table.Column(type: "character varying(256)", maxLength: 256, nullable: true), + OfferedToEmail = table.Column(type: "character varying(256)", maxLength: 256, nullable: true), + PlanSponsorshipType = table.Column(type: "smallint", nullable: true), + CloudSponsor = table.Column(type: "boolean", nullable: false), + LastSyncDate = table.Column(type: "timestamp without time zone", nullable: true), + TimesRenewedWithoutValidation = table.Column(type: "smallint", nullable: false), + SponsorshipLapsedDate = table.Column(type: "timestamp without time zone", nullable: true) + }, + constraints: table => + { + table.PrimaryKey("PK_OrganizationSponsorship", x => x.Id); + table.ForeignKey( + name: "FK_OrganizationSponsorship_Installation_InstallationId", + column: x => x.InstallationId, + principalTable: "Installation", + principalColumn: "Id", + onDelete: ReferentialAction.Restrict); + table.ForeignKey( + name: "FK_OrganizationSponsorship_Organization_SponsoredOrganizationId", + column: x => x.SponsoredOrganizationId, + principalTable: "Organization", + principalColumn: "Id", + onDelete: ReferentialAction.Restrict); + table.ForeignKey( + name: "FK_OrganizationSponsorship_Organization_SponsoringOrganization~", + column: x => x.SponsoringOrganizationId, + principalTable: "Organization", + principalColumn: "Id", + onDelete: ReferentialAction.Restrict); + }); - migrationBuilder.CreateIndex( - name: "IX_OrganizationSponsorship_InstallationId", - table: "OrganizationSponsorship", - column: "InstallationId"); + migrationBuilder.CreateIndex( + name: "IX_OrganizationSponsorship_InstallationId", + table: "OrganizationSponsorship", + column: "InstallationId"); - migrationBuilder.CreateIndex( - name: "IX_OrganizationSponsorship_SponsoredOrganizationId", - table: "OrganizationSponsorship", - column: "SponsoredOrganizationId"); + migrationBuilder.CreateIndex( + name: "IX_OrganizationSponsorship_SponsoredOrganizationId", + table: "OrganizationSponsorship", + column: "SponsoredOrganizationId"); - migrationBuilder.CreateIndex( - name: "IX_OrganizationSponsorship_SponsoringOrganizationId", - table: "OrganizationSponsorship", - column: "SponsoringOrganizationId"); - } + migrationBuilder.CreateIndex( + name: "IX_OrganizationSponsorship_SponsoringOrganizationId", + table: "OrganizationSponsorship", + column: "SponsoringOrganizationId"); + } - protected override void Down(MigrationBuilder migrationBuilder) - { - migrationBuilder.DropTable( - name: "OrganizationSponsorship"); + protected override void Down(MigrationBuilder migrationBuilder) + { + migrationBuilder.DropTable( + name: "OrganizationSponsorship"); - migrationBuilder.DropColumn( - name: "UsesCryptoAgent", - table: "User"); + migrationBuilder.DropColumn( + name: "UsesCryptoAgent", + table: "User"); + } } } diff --git a/util/PostgresMigrations/Migrations/20211115142623_KeyConnectorFlag.cs b/util/PostgresMigrations/Migrations/20211115142623_KeyConnectorFlag.cs index 225f67bf9b..edc5220861 100644 --- a/util/PostgresMigrations/Migrations/20211115142623_KeyConnectorFlag.cs +++ b/util/PostgresMigrations/Migrations/20211115142623_KeyConnectorFlag.cs @@ -1,23 +1,24 @@ using Microsoft.EntityFrameworkCore.Migrations; -namespace Bit.PostgresMigrations.Migrations; - -public partial class KeyConnectorFlag : Migration +namespace Bit.PostgresMigrations.Migrations { - protected override void Up(MigrationBuilder migrationBuilder) + public partial class KeyConnectorFlag : Migration { - migrationBuilder.AddColumn( - name: "UseKeyConnector", - table: "Organization", - type: "boolean", - nullable: false, - defaultValue: false); - } + protected override void Up(MigrationBuilder migrationBuilder) + { + migrationBuilder.AddColumn( + name: "UseKeyConnector", + table: "Organization", + type: "boolean", + nullable: false, + defaultValue: false); + } - protected override void Down(MigrationBuilder migrationBuilder) - { - migrationBuilder.DropColumn( - name: "UseKeyConnector", - table: "Organization"); + protected override void Down(MigrationBuilder migrationBuilder) + { + migrationBuilder.DropColumn( + name: "UseKeyConnector", + table: "Organization"); + } } } diff --git a/util/PostgresMigrations/Migrations/20220121092321_RemoveU2F.cs b/util/PostgresMigrations/Migrations/20220121092321_RemoveU2F.cs index 0679e212c5..906c30be4c 100644 --- a/util/PostgresMigrations/Migrations/20220121092321_RemoveU2F.cs +++ b/util/PostgresMigrations/Migrations/20220121092321_RemoveU2F.cs @@ -1,45 +1,46 @@ using Microsoft.EntityFrameworkCore.Migrations; using Npgsql.EntityFrameworkCore.PostgreSQL.Metadata; -namespace Bit.PostgresMigrations.Migrations; - -public partial class RemoveU2F : Migration +namespace Bit.PostgresMigrations.Migrations { - protected override void Up(MigrationBuilder migrationBuilder) + public partial class RemoveU2F : Migration { - migrationBuilder.DropTable( - name: "U2f"); - } + protected override void Up(MigrationBuilder migrationBuilder) + { + migrationBuilder.DropTable( + name: "U2f"); + } - protected override void Down(MigrationBuilder migrationBuilder) - { - migrationBuilder.CreateTable( - name: "U2f", - columns: table => new - { - Id = table.Column(type: "integer", nullable: false) - .Annotation("Npgsql:ValueGenerationStrategy", NpgsqlValueGenerationStrategy.IdentityByDefaultColumn), - AppId = table.Column(type: "character varying(50)", maxLength: 50, nullable: true), - Challenge = table.Column(type: "character varying(200)", maxLength: 200, nullable: true), - CreationDate = table.Column(type: "timestamp without time zone", nullable: false), - KeyHandle = table.Column(type: "character varying(200)", maxLength: 200, nullable: true), - UserId = table.Column(type: "uuid", nullable: false), - Version = table.Column(type: "character varying(20)", maxLength: 20, nullable: true) - }, - constraints: table => - { - table.PrimaryKey("PK_U2f", x => x.Id); - table.ForeignKey( - name: "FK_U2f_User_UserId", - column: x => x.UserId, - principalTable: "User", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - }); + protected override void Down(MigrationBuilder migrationBuilder) + { + migrationBuilder.CreateTable( + name: "U2f", + columns: table => new + { + Id = table.Column(type: "integer", nullable: false) + .Annotation("Npgsql:ValueGenerationStrategy", NpgsqlValueGenerationStrategy.IdentityByDefaultColumn), + AppId = table.Column(type: "character varying(50)", maxLength: 50, nullable: true), + Challenge = table.Column(type: "character varying(200)", maxLength: 200, nullable: true), + CreationDate = table.Column(type: "timestamp without time zone", nullable: false), + KeyHandle = table.Column(type: "character varying(200)", maxLength: 200, nullable: true), + UserId = table.Column(type: "uuid", nullable: false), + Version = table.Column(type: "character varying(20)", maxLength: 20, nullable: true) + }, + constraints: table => + { + table.PrimaryKey("PK_U2f", x => x.Id); + table.ForeignKey( + name: "FK_U2f_User_UserId", + column: x => x.UserId, + principalTable: "User", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + }); - migrationBuilder.CreateIndex( - name: "IX_U2f_UserId", - table: "U2f", - column: "UserId"); + migrationBuilder.CreateIndex( + name: "IX_U2f_UserId", + table: "U2f", + column: "UserId"); + } } } diff --git a/util/PostgresMigrations/Migrations/20220301211818_FailedLoginCaptcha.cs b/util/PostgresMigrations/Migrations/20220301211818_FailedLoginCaptcha.cs index 6015ef357f..6c57172fb7 100644 --- a/util/PostgresMigrations/Migrations/20220301211818_FailedLoginCaptcha.cs +++ b/util/PostgresMigrations/Migrations/20220301211818_FailedLoginCaptcha.cs @@ -1,33 +1,34 @@ using Microsoft.EntityFrameworkCore.Migrations; -namespace Bit.PostgresMigrations.Migrations; - -public partial class FailedLoginCaptcha : Migration +namespace Bit.PostgresMigrations.Migrations { - protected override void Up(MigrationBuilder migrationBuilder) + public partial class FailedLoginCaptcha : Migration { - migrationBuilder.AddColumn( - name: "FailedLoginCount", - table: "User", - type: "integer", - nullable: false, - defaultValue: 0); + protected override void Up(MigrationBuilder migrationBuilder) + { + migrationBuilder.AddColumn( + name: "FailedLoginCount", + table: "User", + type: "integer", + nullable: false, + defaultValue: 0); - migrationBuilder.AddColumn( - name: "LastFailedLoginDate", - table: "User", - type: "timestamp without time zone", - nullable: true); - } + migrationBuilder.AddColumn( + name: "LastFailedLoginDate", + table: "User", + type: "timestamp without time zone", + nullable: true); + } - protected override void Down(MigrationBuilder migrationBuilder) - { - migrationBuilder.DropColumn( - name: "FailedLoginCount", - table: "User"); + protected override void Down(MigrationBuilder migrationBuilder) + { + migrationBuilder.DropColumn( + name: "FailedLoginCount", + table: "User"); - migrationBuilder.DropColumn( - name: "LastFailedLoginDate", - table: "User"); + migrationBuilder.DropColumn( + name: "LastFailedLoginDate", + table: "User"); + } } } diff --git a/util/PostgresMigrations/Migrations/20220322183505_SelfHostF4E.cs b/util/PostgresMigrations/Migrations/20220322183505_SelfHostF4E.cs index 0c030f0dd2..b636101b00 100644 --- a/util/PostgresMigrations/Migrations/20220322183505_SelfHostF4E.cs +++ b/util/PostgresMigrations/Migrations/20220322183505_SelfHostF4E.cs @@ -1,153 +1,154 @@ using Microsoft.EntityFrameworkCore.Migrations; -namespace Bit.PostgresMigrations.Migrations; - -public partial class SelfHostF4E : Migration +namespace Bit.PostgresMigrations.Migrations { - private const string _scriptLocationTemplate = "2022-03-01_00_{0}_MigrateOrganizationApiKeys.psql"; - - protected override void Up(MigrationBuilder migrationBuilder) + public partial class SelfHostF4E : Migration { - migrationBuilder.DropForeignKey( - name: "FK_OrganizationSponsorship_Installation_InstallationId", - table: "OrganizationSponsorship"); + private const string _scriptLocationTemplate = "2022-03-01_00_{0}_MigrateOrganizationApiKeys.psql"; - migrationBuilder.DropIndex( - name: "IX_OrganizationSponsorship_InstallationId", - table: "OrganizationSponsorship"); + protected override void Up(MigrationBuilder migrationBuilder) + { + migrationBuilder.DropForeignKey( + name: "FK_OrganizationSponsorship_Installation_InstallationId", + table: "OrganizationSponsorship"); - migrationBuilder.DropColumn( - name: "InstallationId", - table: "OrganizationSponsorship"); + migrationBuilder.DropIndex( + name: "IX_OrganizationSponsorship_InstallationId", + table: "OrganizationSponsorship"); - migrationBuilder.DropColumn( - name: "TimesRenewedWithoutValidation", - table: "OrganizationSponsorship"); + migrationBuilder.DropColumn( + name: "InstallationId", + table: "OrganizationSponsorship"); - migrationBuilder.CreateTable( - name: "OrganizationApiKey", - columns: table => new - { - Id = table.Column(type: "uuid", nullable: false), - OrganizationId = table.Column(type: "uuid", nullable: false), - Type = table.Column(type: "smallint", nullable: false), - ApiKey = table.Column(type: "character varying(30)", maxLength: 30, nullable: true), - RevisionDate = table.Column(type: "timestamp without time zone", nullable: false) - }, - constraints: table => - { - table.PrimaryKey("PK_OrganizationApiKey", x => x.Id); - table.ForeignKey( - name: "FK_OrganizationApiKey_Organization_OrganizationId", - column: x => x.OrganizationId, - principalTable: "Organization", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - }); + migrationBuilder.DropColumn( + name: "TimesRenewedWithoutValidation", + table: "OrganizationSponsorship"); - migrationBuilder.SqlResource(_scriptLocationTemplate); + migrationBuilder.CreateTable( + name: "OrganizationApiKey", + columns: table => new + { + Id = table.Column(type: "uuid", nullable: false), + OrganizationId = table.Column(type: "uuid", nullable: false), + Type = table.Column(type: "smallint", nullable: false), + ApiKey = table.Column(type: "character varying(30)", maxLength: 30, nullable: true), + RevisionDate = table.Column(type: "timestamp without time zone", nullable: false) + }, + constraints: table => + { + table.PrimaryKey("PK_OrganizationApiKey", x => x.Id); + table.ForeignKey( + name: "FK_OrganizationApiKey_Organization_OrganizationId", + column: x => x.OrganizationId, + principalTable: "Organization", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + }); - migrationBuilder.DropColumn( - name: "ApiKey", - table: "Organization"); + migrationBuilder.SqlResource(_scriptLocationTemplate); - migrationBuilder.RenameColumn( - name: "SponsorshipLapsedDate", - table: "OrganizationSponsorship", - newName: "ValidUntil"); + migrationBuilder.DropColumn( + name: "ApiKey", + table: "Organization"); - migrationBuilder.RenameColumn( - name: "CloudSponsor", - table: "OrganizationSponsorship", - newName: "ToDelete"); + migrationBuilder.RenameColumn( + name: "SponsorshipLapsedDate", + table: "OrganizationSponsorship", + newName: "ValidUntil"); + + migrationBuilder.RenameColumn( + name: "CloudSponsor", + table: "OrganizationSponsorship", + newName: "ToDelete"); - migrationBuilder.CreateTable( - name: "OrganizationConnection", - columns: table => new - { - Id = table.Column(type: "uuid", nullable: false), - Type = table.Column(type: "smallint", nullable: false), - OrganizationId = table.Column(type: "uuid", nullable: false), - Enabled = table.Column(type: "boolean", nullable: false), - Config = table.Column(type: "text", nullable: true) - }, - constraints: table => - { - table.PrimaryKey("PK_OrganizationConnection", x => x.Id); - table.ForeignKey( - name: "FK_OrganizationConnection_Organization_OrganizationId", - column: x => x.OrganizationId, - principalTable: "Organization", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - }); + migrationBuilder.CreateTable( + name: "OrganizationConnection", + columns: table => new + { + Id = table.Column(type: "uuid", nullable: false), + Type = table.Column(type: "smallint", nullable: false), + OrganizationId = table.Column(type: "uuid", nullable: false), + Enabled = table.Column(type: "boolean", nullable: false), + Config = table.Column(type: "text", nullable: true) + }, + constraints: table => + { + table.PrimaryKey("PK_OrganizationConnection", x => x.Id); + table.ForeignKey( + name: "FK_OrganizationConnection_Organization_OrganizationId", + column: x => x.OrganizationId, + principalTable: "Organization", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + }); - migrationBuilder.CreateIndex( - name: "IX_OrganizationApiKey_OrganizationId", - table: "OrganizationApiKey", - column: "OrganizationId"); + migrationBuilder.CreateIndex( + name: "IX_OrganizationApiKey_OrganizationId", + table: "OrganizationApiKey", + column: "OrganizationId"); - migrationBuilder.CreateIndex( - name: "IX_OrganizationConnection_OrganizationId", - table: "OrganizationConnection", - column: "OrganizationId"); - } + migrationBuilder.CreateIndex( + name: "IX_OrganizationConnection_OrganizationId", + table: "OrganizationConnection", + column: "OrganizationId"); + } - protected override void Down(MigrationBuilder migrationBuilder) - { - migrationBuilder.AddColumn( - name: "ApiKey", - table: "Organization", - type: "character varying(30)", - maxLength: 30, - nullable: true); + protected override void Down(MigrationBuilder migrationBuilder) + { + migrationBuilder.AddColumn( + name: "ApiKey", + table: "Organization", + type: "character varying(30)", + maxLength: 30, + nullable: true); - migrationBuilder.SqlResource(_scriptLocationTemplate); + migrationBuilder.SqlResource(_scriptLocationTemplate); - migrationBuilder.DropTable( - name: "OrganizationApiKey"); + migrationBuilder.DropTable( + name: "OrganizationApiKey"); - migrationBuilder.DropTable( - name: "OrganizationConnection"); + migrationBuilder.DropTable( + name: "OrganizationConnection"); - migrationBuilder.RenameColumn( - name: "ValidUntil", - table: "OrganizationSponsorship", - newName: "SponsorshipLapsedDate"); + migrationBuilder.RenameColumn( + name: "ValidUntil", + table: "OrganizationSponsorship", + newName: "SponsorshipLapsedDate"); - migrationBuilder.RenameColumn( - name: "ToDelete", - table: "OrganizationSponsorship", - newName: "CloudSponsor"); + migrationBuilder.RenameColumn( + name: "ToDelete", + table: "OrganizationSponsorship", + newName: "CloudSponsor"); - migrationBuilder.AddColumn( - name: "InstallationId", - table: "OrganizationSponsorship", - type: "uuid", - nullable: true); + migrationBuilder.AddColumn( + name: "InstallationId", + table: "OrganizationSponsorship", + type: "uuid", + nullable: true); - migrationBuilder.AddColumn( - name: "TimesRenewedWithoutValidation", - table: "OrganizationSponsorship", - type: "smallint", - nullable: false, - defaultValue: (byte)0); + migrationBuilder.AddColumn( + name: "TimesRenewedWithoutValidation", + table: "OrganizationSponsorship", + type: "smallint", + nullable: false, + defaultValue: (byte)0); - migrationBuilder.CreateIndex( - name: "IX_OrganizationSponsorship_InstallationId", - table: "OrganizationSponsorship", - column: "InstallationId"); + migrationBuilder.CreateIndex( + name: "IX_OrganizationSponsorship_InstallationId", + table: "OrganizationSponsorship", + column: "InstallationId"); - migrationBuilder.AddForeignKey( - name: "FK_OrganizationSponsorship_Installation_InstallationId", - table: "OrganizationSponsorship", - column: "InstallationId", - principalTable: "Installation", - principalColumn: "Id", - onDelete: ReferentialAction.Restrict); + migrationBuilder.AddForeignKey( + name: "FK_OrganizationSponsorship_Installation_InstallationId", + table: "OrganizationSponsorship", + column: "InstallationId", + principalTable: "Installation", + principalColumn: "Id", + onDelete: ReferentialAction.Restrict); + } } } diff --git a/util/PostgresMigrations/Migrations/20220411190525_SponsorshipBulkActions.cs b/util/PostgresMigrations/Migrations/20220411190525_SponsorshipBulkActions.cs index 7b569a62ce..46b76b2bfd 100644 --- a/util/PostgresMigrations/Migrations/20220411190525_SponsorshipBulkActions.cs +++ b/util/PostgresMigrations/Migrations/20220411190525_SponsorshipBulkActions.cs @@ -1,72 +1,73 @@ using Microsoft.EntityFrameworkCore.Migrations; -namespace Bit.PostgresMigrations.Migrations; - -public partial class SponsorshipBulkActions : Migration +namespace Bit.PostgresMigrations.Migrations { - protected override void Up(MigrationBuilder migrationBuilder) + public partial class SponsorshipBulkActions : Migration { - migrationBuilder.DropForeignKey( - name: "FK_OrganizationSponsorship_Organization_SponsoringOrganization~", - table: "OrganizationSponsorship"); + protected override void Up(MigrationBuilder migrationBuilder) + { + migrationBuilder.DropForeignKey( + name: "FK_OrganizationSponsorship_Organization_SponsoringOrganization~", + table: "OrganizationSponsorship"); - migrationBuilder.AlterColumn( - name: "SponsoringOrganizationUserId", - table: "OrganizationSponsorship", - type: "uuid", - nullable: false, - defaultValue: new Guid("00000000-0000-0000-0000-000000000000"), - oldClrType: typeof(Guid), - oldType: "uuid", - oldNullable: true); + migrationBuilder.AlterColumn( + name: "SponsoringOrganizationUserId", + table: "OrganizationSponsorship", + type: "uuid", + nullable: false, + defaultValue: new Guid("00000000-0000-0000-0000-000000000000"), + oldClrType: typeof(Guid), + oldType: "uuid", + oldNullable: true); - migrationBuilder.AlterColumn( - name: "SponsoringOrganizationId", - table: "OrganizationSponsorship", - type: "uuid", - nullable: false, - defaultValue: new Guid("00000000-0000-0000-0000-000000000000"), - oldClrType: typeof(Guid), - oldType: "uuid", - oldNullable: true); + migrationBuilder.AlterColumn( + name: "SponsoringOrganizationId", + table: "OrganizationSponsorship", + type: "uuid", + nullable: false, + defaultValue: new Guid("00000000-0000-0000-0000-000000000000"), + oldClrType: typeof(Guid), + oldType: "uuid", + oldNullable: true); - migrationBuilder.AddForeignKey( - name: "FK_OrganizationSponsorship_Organization_SponsoringOrganization~", - table: "OrganizationSponsorship", - column: "SponsoringOrganizationId", - principalTable: "Organization", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - } + migrationBuilder.AddForeignKey( + name: "FK_OrganizationSponsorship_Organization_SponsoringOrganization~", + table: "OrganizationSponsorship", + column: "SponsoringOrganizationId", + principalTable: "Organization", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + } - protected override void Down(MigrationBuilder migrationBuilder) - { - migrationBuilder.DropForeignKey( - name: "FK_OrganizationSponsorship_Organization_SponsoringOrganization~", - table: "OrganizationSponsorship"); + protected override void Down(MigrationBuilder migrationBuilder) + { + migrationBuilder.DropForeignKey( + name: "FK_OrganizationSponsorship_Organization_SponsoringOrganization~", + table: "OrganizationSponsorship"); - migrationBuilder.AlterColumn( - name: "SponsoringOrganizationUserId", - table: "OrganizationSponsorship", - type: "uuid", - nullable: true, - oldClrType: typeof(Guid), - oldType: "uuid"); + migrationBuilder.AlterColumn( + name: "SponsoringOrganizationUserId", + table: "OrganizationSponsorship", + type: "uuid", + nullable: true, + oldClrType: typeof(Guid), + oldType: "uuid"); - migrationBuilder.AlterColumn( - name: "SponsoringOrganizationId", - table: "OrganizationSponsorship", - type: "uuid", - nullable: true, - oldClrType: typeof(Guid), - oldType: "uuid"); + migrationBuilder.AlterColumn( + name: "SponsoringOrganizationId", + table: "OrganizationSponsorship", + type: "uuid", + nullable: true, + oldClrType: typeof(Guid), + oldType: "uuid"); - migrationBuilder.AddForeignKey( - name: "FK_OrganizationSponsorship_Organization_SponsoringOrganization~", - table: "OrganizationSponsorship", - column: "SponsoringOrganizationId", - principalTable: "Organization", - principalColumn: "Id", - onDelete: ReferentialAction.Restrict); + migrationBuilder.AddForeignKey( + name: "FK_OrganizationSponsorship_Organization_SponsoringOrganization~", + table: "OrganizationSponsorship", + column: "SponsoringOrganizationId", + principalTable: "Organization", + principalColumn: "Id", + onDelete: ReferentialAction.Restrict); + } } } diff --git a/util/PostgresMigrations/Migrations/20220420171153_AddInstallationIdToEvents.cs b/util/PostgresMigrations/Migrations/20220420171153_AddInstallationIdToEvents.cs index a02d9e70ae..94bfa5b7c6 100644 --- a/util/PostgresMigrations/Migrations/20220420171153_AddInstallationIdToEvents.cs +++ b/util/PostgresMigrations/Migrations/20220420171153_AddInstallationIdToEvents.cs @@ -1,64 +1,65 @@ using Microsoft.EntityFrameworkCore.Migrations; -namespace Bit.PostgresMigrations.Migrations; - -public partial class AddInstallationIdToEvents : Migration +namespace Bit.PostgresMigrations.Migrations { - protected override void Up(MigrationBuilder migrationBuilder) + public partial class AddInstallationIdToEvents : Migration { - migrationBuilder.DropForeignKey( - name: "FK_OrganizationSponsorship_Organization_SponsoringOrganization~", - table: "OrganizationSponsorship"); + protected override void Up(MigrationBuilder migrationBuilder) + { + migrationBuilder.DropForeignKey( + name: "FK_OrganizationSponsorship_Organization_SponsoringOrganization~", + table: "OrganizationSponsorship"); - migrationBuilder.AlterColumn( - name: "SponsoringOrganizationId", - table: "OrganizationSponsorship", - type: "uuid", - nullable: true, - oldClrType: typeof(Guid), - oldType: "uuid"); + migrationBuilder.AlterColumn( + name: "SponsoringOrganizationId", + table: "OrganizationSponsorship", + type: "uuid", + nullable: true, + oldClrType: typeof(Guid), + oldType: "uuid"); - migrationBuilder.AddColumn( - name: "InstallationId", - table: "Event", - type: "uuid", - nullable: true); + migrationBuilder.AddColumn( + name: "InstallationId", + table: "Event", + type: "uuid", + nullable: true); - migrationBuilder.AddForeignKey( - name: "FK_OrganizationSponsorship_Organization_SponsoringOrganization~", - table: "OrganizationSponsorship", - column: "SponsoringOrganizationId", - principalTable: "Organization", - principalColumn: "Id", - onDelete: ReferentialAction.Restrict); - } + migrationBuilder.AddForeignKey( + name: "FK_OrganizationSponsorship_Organization_SponsoringOrganization~", + table: "OrganizationSponsorship", + column: "SponsoringOrganizationId", + principalTable: "Organization", + principalColumn: "Id", + onDelete: ReferentialAction.Restrict); + } - protected override void Down(MigrationBuilder migrationBuilder) - { - migrationBuilder.DropForeignKey( - name: "FK_OrganizationSponsorship_Organization_SponsoringOrganization~", - table: "OrganizationSponsorship"); + protected override void Down(MigrationBuilder migrationBuilder) + { + migrationBuilder.DropForeignKey( + name: "FK_OrganizationSponsorship_Organization_SponsoringOrganization~", + table: "OrganizationSponsorship"); - migrationBuilder.DropColumn( - name: "InstallationId", - table: "Event"); + migrationBuilder.DropColumn( + name: "InstallationId", + table: "Event"); - migrationBuilder.AlterColumn( - name: "SponsoringOrganizationId", - table: "OrganizationSponsorship", - type: "uuid", - nullable: false, - defaultValue: new Guid("00000000-0000-0000-0000-000000000000"), - oldClrType: typeof(Guid), - oldType: "uuid", - oldNullable: true); + migrationBuilder.AlterColumn( + name: "SponsoringOrganizationId", + table: "OrganizationSponsorship", + type: "uuid", + nullable: false, + defaultValue: new Guid("00000000-0000-0000-0000-000000000000"), + oldClrType: typeof(Guid), + oldType: "uuid", + oldNullable: true); - migrationBuilder.AddForeignKey( - name: "FK_OrganizationSponsorship_Organization_SponsoringOrganization~", - table: "OrganizationSponsorship", - column: "SponsoringOrganizationId", - principalTable: "Organization", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); + migrationBuilder.AddForeignKey( + name: "FK_OrganizationSponsorship_Organization_SponsoringOrganization~", + table: "OrganizationSponsorship", + column: "SponsoringOrganizationId", + principalTable: "Organization", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + } } } diff --git a/util/PostgresMigrations/Migrations/20220524170740_DeviceUnknownVerification.cs b/util/PostgresMigrations/Migrations/20220524170740_DeviceUnknownVerification.cs index 3dd1b4c5f0..880c659086 100644 --- a/util/PostgresMigrations/Migrations/20220524170740_DeviceUnknownVerification.cs +++ b/util/PostgresMigrations/Migrations/20220524170740_DeviceUnknownVerification.cs @@ -1,23 +1,24 @@ using Microsoft.EntityFrameworkCore.Migrations; -namespace Bit.PostgresMigrations.Migrations; - -public partial class DeviceUnknownVerification : Migration +namespace Bit.PostgresMigrations.Migrations { - protected override void Up(MigrationBuilder migrationBuilder) + public partial class DeviceUnknownVerification : Migration { - migrationBuilder.AddColumn( - name: "UnknownDeviceVerificationEnabled", - table: "User", - type: "boolean", - nullable: false, - defaultValue: true); - } + protected override void Up(MigrationBuilder migrationBuilder) + { + migrationBuilder.AddColumn( + name: "UnknownDeviceVerificationEnabled", + table: "User", + type: "boolean", + nullable: false, + defaultValue: true); + } - protected override void Down(MigrationBuilder migrationBuilder) - { - migrationBuilder.DropColumn( - name: "UnknownDeviceVerificationEnabled", - table: "User"); + protected override void Down(MigrationBuilder migrationBuilder) + { + migrationBuilder.DropColumn( + name: "UnknownDeviceVerificationEnabled", + table: "User"); + } } } diff --git a/util/PostgresMigrations/Migrations/20220707162231_UseScimFlag.cs b/util/PostgresMigrations/Migrations/20220707162231_UseScimFlag.cs index 02c7ca90ea..6a71e38fb3 100644 --- a/util/PostgresMigrations/Migrations/20220707162231_UseScimFlag.cs +++ b/util/PostgresMigrations/Migrations/20220707162231_UseScimFlag.cs @@ -2,24 +2,25 @@ #nullable disable -namespace Bit.PostgresMigrations.Migrations; - -public partial class UseScimFlag : Migration +namespace Bit.PostgresMigrations.Migrations { - protected override void Up(MigrationBuilder migrationBuilder) + public partial class UseScimFlag : Migration { - migrationBuilder.AddColumn( - name: "UseScim", - table: "Organization", - type: "boolean", - nullable: false, - defaultValue: false); - } + protected override void Up(MigrationBuilder migrationBuilder) + { + migrationBuilder.AddColumn( + name: "UseScim", + table: "Organization", + type: "boolean", + nullable: false, + defaultValue: false); + } - protected override void Down(MigrationBuilder migrationBuilder) - { - migrationBuilder.DropColumn( - name: "UseScim", - table: "Organization"); + protected override void Down(MigrationBuilder migrationBuilder) + { + migrationBuilder.DropColumn( + name: "UseScim", + table: "Organization"); + } } } diff --git a/util/Server/Program.cs b/util/Server/Program.cs index 767b965149..25f5fd4406 100644 --- a/util/Server/Program.cs +++ b/util/Server/Program.cs @@ -1,40 +1,41 @@ -namespace Bit.Server; - -public class Program +namespace Bit.Server { - public static void Main(string[] args) + public class Program { - var config = new ConfigurationBuilder() - .AddCommandLine(args) - .Build(); + public static void Main(string[] args) + { + var config = new ConfigurationBuilder() + .AddCommandLine(args) + .Build(); - var builder = new WebHostBuilder() - .UseConfiguration(config) - .UseKestrel() - .UseStartup() - .ConfigureLogging((hostingContext, logging) => + var builder = new WebHostBuilder() + .UseConfiguration(config) + .UseKestrel() + .UseStartup() + .ConfigureLogging((hostingContext, logging) => + { + logging.AddConsole().AddDebug(); + }) + .ConfigureKestrel((context, options) => { }); + + var contentRoot = config.GetValue("contentRoot"); + if (!string.IsNullOrWhiteSpace(contentRoot)) { - logging.AddConsole().AddDebug(); - }) - .ConfigureKestrel((context, options) => { }); + builder.UseContentRoot(contentRoot); + } + else + { + builder.UseContentRoot(Directory.GetCurrentDirectory()); + } - var contentRoot = config.GetValue("contentRoot"); - if (!string.IsNullOrWhiteSpace(contentRoot)) - { - builder.UseContentRoot(contentRoot); - } - else - { - builder.UseContentRoot(Directory.GetCurrentDirectory()); - } + var webRoot = config.GetValue("webRoot"); + if (string.IsNullOrWhiteSpace(webRoot)) + { + builder.UseWebRoot(webRoot); + } - var webRoot = config.GetValue("webRoot"); - if (string.IsNullOrWhiteSpace(webRoot)) - { - builder.UseWebRoot(webRoot); + var host = builder.Build(); + host.Run(); } - - var host = builder.Build(); - host.Run(); } } diff --git a/util/Server/Startup.cs b/util/Server/Startup.cs index 7b195beb53..362d87383a 100644 --- a/util/Server/Startup.cs +++ b/util/Server/Startup.cs @@ -1,89 +1,90 @@ using System.Globalization; using Microsoft.AspNetCore.StaticFiles; -namespace Bit.Server; - -public class Startup +namespace Bit.Server { - private readonly List _longCachedPaths = new List + public class Startup { - "/app/", "/locales/", "/fonts/", "/connectors/", "/scripts/" - }; - private readonly List _mediumCachedPaths = new List - { - "/images/" - }; - - public Startup() - { - CultureInfo.DefaultThreadCurrentCulture = new CultureInfo("en-US"); - } - - public void ConfigureServices(IServiceCollection services) - { - services.AddRouting(); - } - - public void Configure( - IApplicationBuilder app, - IConfiguration configuration) - { - if (configuration.GetValue("serveUnknown") ?? false) + private readonly List _longCachedPaths = new List { - app.UseStaticFiles(new StaticFileOptions - { - ServeUnknownFileTypes = true, - DefaultContentType = "application/octet-stream" - }); - app.UseRouting(); - app.UseEndpoints(endpoints => - { - endpoints.MapGet("/alive", - async context => await context.Response.WriteAsync(System.DateTime.UtcNow.ToString())); - }); + "/app/", "/locales/", "/fonts/", "/connectors/", "/scripts/" + }; + private readonly List _mediumCachedPaths = new List + { + "/images/" + }; + + public Startup() + { + CultureInfo.DefaultThreadCurrentCulture = new CultureInfo("en-US"); } - else if (configuration.GetValue("webVault") ?? false) - { - // TODO: This should be removed when asp.net natively support avif - var provider = new FileExtensionContentTypeProvider { Mappings = { [".avif"] = "image/avif" } }; - var options = new DefaultFilesOptions(); - options.DefaultFileNames.Clear(); - options.DefaultFileNames.Add("index.html"); - app.UseDefaultFiles(options); - app.UseStaticFiles(new StaticFileOptions + public void ConfigureServices(IServiceCollection services) + { + services.AddRouting(); + } + + public void Configure( + IApplicationBuilder app, + IConfiguration configuration) + { + if (configuration.GetValue("serveUnknown") ?? false) { - ContentTypeProvider = provider, - OnPrepareResponse = ctx => + app.UseStaticFiles(new StaticFileOptions { - if (!ctx.Context.Request.Path.HasValue || - ctx.Context.Response.Headers.ContainsKey("Cache-Control")) - { - return; - } - var path = ctx.Context.Request.Path.Value; - if (_longCachedPaths.Any(ext => path.StartsWith(ext))) - { - // 14 days - ctx.Context.Response.Headers.Append("Cache-Control", "max-age=1209600"); - } - if (_mediumCachedPaths.Any(ext => path.StartsWith(ext))) - { - // 7 days - ctx.Context.Response.Headers.Append("Cache-Control", "max-age=604800"); - } - } - }); - } - else - { - app.UseFileServer(); - app.UseRouting(); - app.UseEndpoints(endpoints => + ServeUnknownFileTypes = true, + DefaultContentType = "application/octet-stream" + }); + app.UseRouting(); + app.UseEndpoints(endpoints => + { + endpoints.MapGet("/alive", + async context => await context.Response.WriteAsync(System.DateTime.UtcNow.ToString())); + }); + } + else if (configuration.GetValue("webVault") ?? false) { - endpoints.MapGet("/alive", - async context => await context.Response.WriteAsync(System.DateTime.UtcNow.ToString())); - }); + // TODO: This should be removed when asp.net natively support avif + var provider = new FileExtensionContentTypeProvider { Mappings = { [".avif"] = "image/avif" } }; + + var options = new DefaultFilesOptions(); + options.DefaultFileNames.Clear(); + options.DefaultFileNames.Add("index.html"); + app.UseDefaultFiles(options); + app.UseStaticFiles(new StaticFileOptions + { + ContentTypeProvider = provider, + OnPrepareResponse = ctx => + { + if (!ctx.Context.Request.Path.HasValue || + ctx.Context.Response.Headers.ContainsKey("Cache-Control")) + { + return; + } + var path = ctx.Context.Request.Path.Value; + if (_longCachedPaths.Any(ext => path.StartsWith(ext))) + { + // 14 days + ctx.Context.Response.Headers.Append("Cache-Control", "max-age=1209600"); + } + if (_mediumCachedPaths.Any(ext => path.StartsWith(ext))) + { + // 7 days + ctx.Context.Response.Headers.Append("Cache-Control", "max-age=604800"); + } + } + }); + } + else + { + app.UseFileServer(); + app.UseRouting(); + app.UseEndpoints(endpoints => + { + endpoints.MapGet("/alive", + async context => await context.Response.WriteAsync(System.DateTime.UtcNow.ToString())); + }); + } } } } diff --git a/util/Setup/AppIdBuilder.cs b/util/Setup/AppIdBuilder.cs index 6e984aa904..46fe222b6b 100644 --- a/util/Setup/AppIdBuilder.cs +++ b/util/Setup/AppIdBuilder.cs @@ -1,33 +1,34 @@ -namespace Bit.Setup; - -public class AppIdBuilder +namespace Bit.Setup { - private readonly Context _context; - - public AppIdBuilder(Context context) + public class AppIdBuilder { - _context = context; - } + private readonly Context _context; - public void Build() - { - var model = new TemplateModel + public AppIdBuilder(Context context) { - Url = _context.Config.Url - }; + _context = context; + } - // Needed for backwards compatability with migrated U2F tokens. - Helpers.WriteLine(_context, "Building FIDO U2F app id."); - Directory.CreateDirectory("/bitwarden/web/"); - var template = Helpers.ReadTemplate("AppId"); - using (var sw = File.CreateText("/bitwarden/web/app-id.json")) + public void Build() { - sw.Write(template(model)); + var model = new TemplateModel + { + Url = _context.Config.Url + }; + + // Needed for backwards compatability with migrated U2F tokens. + Helpers.WriteLine(_context, "Building FIDO U2F app id."); + Directory.CreateDirectory("/bitwarden/web/"); + var template = Helpers.ReadTemplate("AppId"); + using (var sw = File.CreateText("/bitwarden/web/app-id.json")) + { + sw.Write(template(model)); + } + } + + public class TemplateModel + { + public string Url { get; set; } } } - - public class TemplateModel - { - public string Url { get; set; } - } } diff --git a/util/Setup/CertBuilder.cs b/util/Setup/CertBuilder.cs index a01e9d98bc..3a43888f27 100644 --- a/util/Setup/CertBuilder.cs +++ b/util/Setup/CertBuilder.cs @@ -1,111 +1,112 @@ -namespace Bit.Setup; - -public class CertBuilder +namespace Bit.Setup { - private readonly Context _context; - - public CertBuilder(Context context) + public class CertBuilder { - _context = context; - } + private readonly Context _context; - public void BuildForInstall() - { - if (_context.Stub) + public CertBuilder(Context context) { - _context.Config.Ssl = true; - _context.Install.Trusted = true; - _context.Install.SelfSignedCert = false; - _context.Install.DiffieHellman = false; - _context.Install.IdentityCertPassword = "IDENTITY_CERT_PASSWORD"; - return; + _context = context; } - _context.Config.Ssl = _context.Config.SslManagedLetsEncrypt; - - if (!_context.Config.Ssl) + public void BuildForInstall() { - var skipSSL = _context.Parameters.ContainsKey("skip-ssl") && (_context.Parameters["skip-ssl"] == "true" || _context.Parameters["skip-ssl"] == "1"); - - if (!skipSSL) + if (_context.Stub) { - _context.Config.Ssl = Helpers.ReadQuestion("Do you have a SSL certificate to use?"); - if (_context.Config.Ssl) + _context.Config.Ssl = true; + _context.Install.Trusted = true; + _context.Install.SelfSignedCert = false; + _context.Install.DiffieHellman = false; + _context.Install.IdentityCertPassword = "IDENTITY_CERT_PASSWORD"; + return; + } + + _context.Config.Ssl = _context.Config.SslManagedLetsEncrypt; + + if (!_context.Config.Ssl) + { + var skipSSL = _context.Parameters.ContainsKey("skip-ssl") && (_context.Parameters["skip-ssl"] == "true" || _context.Parameters["skip-ssl"] == "1"); + + if (!skipSSL) { - Directory.CreateDirectory($"/bitwarden/ssl/{_context.Install.Domain}/"); - var message = "Make sure 'certificate.crt' and 'private.key' are provided in the \n" + - "appropriate directory before running 'start' (see docs for info)."; - Helpers.ShowBanner(_context, "NOTE", message); - } - else if (Helpers.ReadQuestion("Do you want to generate a self-signed SSL certificate?")) - { - Directory.CreateDirectory($"/bitwarden/ssl/self/{_context.Install.Domain}/"); - Helpers.WriteLine(_context, "Generating self signed SSL certificate."); - _context.Config.Ssl = true; - _context.Install.Trusted = false; - _context.Install.SelfSignedCert = true; - Helpers.Exec("openssl req -x509 -newkey rsa:4096 -sha256 -nodes -days 36500 " + - $"-keyout /bitwarden/ssl/self/{_context.Install.Domain}/private.key " + - $"-out /bitwarden/ssl/self/{_context.Install.Domain}/certificate.crt " + - $"-reqexts SAN -extensions SAN " + - $"-config <(cat /usr/lib/ssl/openssl.cnf <(printf '[SAN]\nsubjectAltName=DNS:{_context.Install.Domain}\nbasicConstraints=CA:true')) " + - $"-subj \"/C=US/ST=California/L=Santa Barbara/O=Bitwarden Inc./OU=Bitwarden/CN={_context.Install.Domain}\""); + _context.Config.Ssl = Helpers.ReadQuestion("Do you have a SSL certificate to use?"); + if (_context.Config.Ssl) + { + Directory.CreateDirectory($"/bitwarden/ssl/{_context.Install.Domain}/"); + var message = "Make sure 'certificate.crt' and 'private.key' are provided in the \n" + + "appropriate directory before running 'start' (see docs for info)."; + Helpers.ShowBanner(_context, "NOTE", message); + } + else if (Helpers.ReadQuestion("Do you want to generate a self-signed SSL certificate?")) + { + Directory.CreateDirectory($"/bitwarden/ssl/self/{_context.Install.Domain}/"); + Helpers.WriteLine(_context, "Generating self signed SSL certificate."); + _context.Config.Ssl = true; + _context.Install.Trusted = false; + _context.Install.SelfSignedCert = true; + Helpers.Exec("openssl req -x509 -newkey rsa:4096 -sha256 -nodes -days 36500 " + + $"-keyout /bitwarden/ssl/self/{_context.Install.Domain}/private.key " + + $"-out /bitwarden/ssl/self/{_context.Install.Domain}/certificate.crt " + + $"-reqexts SAN -extensions SAN " + + $"-config <(cat /usr/lib/ssl/openssl.cnf <(printf '[SAN]\nsubjectAltName=DNS:{_context.Install.Domain}\nbasicConstraints=CA:true')) " + + $"-subj \"/C=US/ST=California/L=Santa Barbara/O=Bitwarden Inc./OU=Bitwarden/CN={_context.Install.Domain}\""); + } } } + + if (_context.Config.SslManagedLetsEncrypt) + { + _context.Install.Trusted = true; + _context.Install.DiffieHellman = true; + Directory.CreateDirectory($"/bitwarden/letsencrypt/live/{_context.Install.Domain}/"); + Helpers.Exec($"openssl dhparam -out " + + $"/bitwarden/letsencrypt/live/{_context.Install.Domain}/dhparam.pem 2048"); + } + else if (_context.Config.Ssl && !_context.Install.SelfSignedCert) + { + _context.Install.Trusted = Helpers.ReadQuestion("Is this a trusted SSL certificate " + + "(requires ca.crt, see docs)?"); + } + + Helpers.WriteLine(_context, "Generating key for IdentityServer."); + _context.Install.IdentityCertPassword = Helpers.SecureRandomString(32, alpha: true, numeric: true); + Directory.CreateDirectory("/bitwarden/identity/"); + Helpers.Exec("openssl req -x509 -newkey rsa:4096 -sha256 -nodes -keyout identity.key " + + "-out identity.crt -subj \"/CN=Bitwarden IdentityServer\" -days 36500"); + Helpers.Exec("openssl pkcs12 -export -out /bitwarden/identity/identity.pfx -inkey identity.key " + + $"-in identity.crt -passout pass:{_context.Install.IdentityCertPassword}"); + + Helpers.WriteLine(_context); + + if (!_context.Config.Ssl) + { + var message = "You are not using a SSL certificate. Bitwarden requires HTTPS to operate. \n" + + "You must front your installation with a HTTPS proxy or the web vault (and \n" + + "other Bitwarden apps) will not work properly."; + Helpers.ShowBanner(_context, "WARNING", message, ConsoleColor.Yellow); + } + else if (_context.Config.Ssl && !_context.Install.Trusted) + { + var message = "You are using an untrusted SSL certificate. This certificate will not be \n" + + "trusted by Bitwarden client applications. You must add this certificate to \n" + + "the trusted store on each device or else you will receive errors when trying \n" + + "to connect to your installation."; + Helpers.ShowBanner(_context, "WARNING", message, ConsoleColor.Yellow); + } } - if (_context.Config.SslManagedLetsEncrypt) + public void BuildForUpdater() { - _context.Install.Trusted = true; - _context.Install.DiffieHellman = true; - Directory.CreateDirectory($"/bitwarden/letsencrypt/live/{_context.Install.Domain}/"); - Helpers.Exec($"openssl dhparam -out " + - $"/bitwarden/letsencrypt/live/{_context.Install.Domain}/dhparam.pem 2048"); - } - else if (_context.Config.Ssl && !_context.Install.SelfSignedCert) - { - _context.Install.Trusted = Helpers.ReadQuestion("Is this a trusted SSL certificate " + - "(requires ca.crt, see docs)?"); - } - - Helpers.WriteLine(_context, "Generating key for IdentityServer."); - _context.Install.IdentityCertPassword = Helpers.SecureRandomString(32, alpha: true, numeric: true); - Directory.CreateDirectory("/bitwarden/identity/"); - Helpers.Exec("openssl req -x509 -newkey rsa:4096 -sha256 -nodes -keyout identity.key " + - "-out identity.crt -subj \"/CN=Bitwarden IdentityServer\" -days 36500"); - Helpers.Exec("openssl pkcs12 -export -out /bitwarden/identity/identity.pfx -inkey identity.key " + - $"-in identity.crt -passout pass:{_context.Install.IdentityCertPassword}"); - - Helpers.WriteLine(_context); - - if (!_context.Config.Ssl) - { - var message = "You are not using a SSL certificate. Bitwarden requires HTTPS to operate. \n" + - "You must front your installation with a HTTPS proxy or the web vault (and \n" + - "other Bitwarden apps) will not work properly."; - Helpers.ShowBanner(_context, "WARNING", message, ConsoleColor.Yellow); - } - else if (_context.Config.Ssl && !_context.Install.Trusted) - { - var message = "You are using an untrusted SSL certificate. This certificate will not be \n" + - "trusted by Bitwarden client applications. You must add this certificate to \n" + - "the trusted store on each device or else you will receive errors when trying \n" + - "to connect to your installation."; - Helpers.ShowBanner(_context, "WARNING", message, ConsoleColor.Yellow); - } - } - - public void BuildForUpdater() - { - if (_context.Config.EnableKeyConnector && !File.Exists("/bitwarden/key-connector/bwkc.pfx")) - { - Directory.CreateDirectory("/bitwarden/key-connector/"); - var keyConnectorCertPassword = Helpers.GetValueFromEnvFile("key-connector", - "keyConnectorSettings__certificate__filesystemPassword"); - Helpers.Exec("openssl req -x509 -newkey rsa:4096 -sha256 -nodes -keyout bwkc.key " + - "-out bwkc.crt -subj \"/CN=Bitwarden Key Connector\" -days 36500"); - Helpers.Exec("openssl pkcs12 -export -out /bitwarden/key-connector/bwkc.pfx -inkey bwkc.key " + - $"-in bwkc.crt -passout pass:{keyConnectorCertPassword}"); + if (_context.Config.EnableKeyConnector && !File.Exists("/bitwarden/key-connector/bwkc.pfx")) + { + Directory.CreateDirectory("/bitwarden/key-connector/"); + var keyConnectorCertPassword = Helpers.GetValueFromEnvFile("key-connector", + "keyConnectorSettings__certificate__filesystemPassword"); + Helpers.Exec("openssl req -x509 -newkey rsa:4096 -sha256 -nodes -keyout bwkc.key " + + "-out bwkc.crt -subj \"/CN=Bitwarden Key Connector\" -days 36500"); + Helpers.Exec("openssl pkcs12 -export -out /bitwarden/key-connector/bwkc.pfx -inkey bwkc.key " + + $"-in bwkc.crt -passout pass:{keyConnectorCertPassword}"); + } } } } diff --git a/util/Setup/Configuration.cs b/util/Setup/Configuration.cs index b58b87952a..e0062b522f 100644 --- a/util/Setup/Configuration.cs +++ b/util/Setup/Configuration.cs @@ -1,121 +1,122 @@ using System.ComponentModel; using YamlDotNet.Serialization; -namespace Bit.Setup; - -public class Configuration +namespace Bit.Setup { - [Description("Note: After making changes to this file you need to run the `rebuild` or `update`\n" + - "command for them to be applied.\n\n" + - - "Full URL for accessing the installation from a browser. (Required)")] - public string Url { get; set; } = "https://localhost"; - - [Description("Auto-generate the `./docker/docker-compose.yml` config file.\n" + - "WARNING: Disabling generated config files can break future updates. You will be\n" + - "responsible for maintaining this config file.\n" + - "Template: https://github.com/bitwarden/server/blob/master/util/Setup/Templates/DockerCompose.hbs")] - public bool GenerateComposeConfig { get; set; } = true; - - [Description("Auto-generate the `./nginx/default.conf` file.\n" + - "WARNING: Disabling generated config files can break future updates. You will be\n" + - "responsible for maintaining this config file.\n" + - "Template: https://github.com/bitwarden/server/blob/master/util/Setup/Templates/NginxConfig.hbs")] - public bool GenerateNginxConfig { get; set; } = true; - - [Description("Docker compose file port mapping for HTTP. Leave empty to remove the port mapping.\n" + - "Learn more: https://docs.docker.com/compose/compose-file/#ports")] - public string HttpPort { get; set; } = "80"; - - [Description("Docker compose file port mapping for HTTPS. Leave empty to remove the port mapping.\n" + - "Learn more: https://docs.docker.com/compose/compose-file/#ports")] - public string HttpsPort { get; set; } = "443"; - - [Description("Docker compose file version. Leave empty for default.\n" + - "Learn more: https://docs.docker.com/compose/compose-file/compose-versioning/")] - public string ComposeVersion { get; set; } - - [Description("Configure Nginx for Captcha.")] - public bool Captcha { get; set; } = false; - - [Description("Configure Nginx for SSL.")] - public bool Ssl { get; set; } = true; - - [Description("SSL versions used by Nginx (ssl_protocols). Leave empty for recommended default.\n" + - "Learn more: https://wiki.mozilla.org/Security/Server_Side_TLS")] - public string SslVersions { get; set; } - - [Description("SSL ciphersuites used by Nginx (ssl_ciphers). Leave empty for recommended default.\n" + - "Learn more: https://wiki.mozilla.org/Security/Server_Side_TLS")] - public string SslCiphersuites { get; set; } - - [Description("Installation uses a managed Let's Encrypt certificate.")] - public bool SslManagedLetsEncrypt { get; set; } - - [Description("The actual certificate. (Required if using SSL without managed Let's Encrypt)\n" + - "Note: Path uses the container's ssl directory. The `./ssl` host directory is mapped to\n" + - "`/etc/ssl` within the container.")] - public string SslCertificatePath { get; set; } - - [Description("The certificate's private key. (Required if using SSL without managed Let's Encrypt)\n" + - "Note: Path uses the container's ssl directory. The `./ssl` host directory is mapped to\n" + - "`/etc/ssl` within the container.")] - public string SslKeyPath { get; set; } - - [Description("If the certificate is trusted by a CA, you should provide the CA's certificate.\n" + - "Note: Path uses the container's ssl directory. The `./ssl` host directory is mapped to\n" + - "`/etc/ssl` within the container.")] - public string SslCaPath { get; set; } - - [Description("Diffie Hellman ephemeral parameters\n" + - "Learn more: https://security.stackexchange.com/q/94390/79072\n" + - "Note: Path uses the container's ssl directory. The `./ssl` host directory is mapped to\n" + - "`/etc/ssl` within the container.")] - public string SslDiffieHellmanPath { get; set; } - - [Description("Nginx Header Content-Security-Policy parameter\n" + - "WARNING: Reconfiguring this parameter may break features. By changing this parameter\n" + - "you become responsible for maintaining this value.")] - public string NginxHeaderContentSecurityPolicy { get; set; } = "default-src 'self'; style-src 'self' " + - "'unsafe-inline'; img-src 'self' data: https://haveibeenpwned.com https://www.gravatar.com; " + - "child-src 'self' https://*.duosecurity.com https://*.duofederal.com; " + - "frame-src 'self' https://*.duosecurity.com https://*.duofederal.com; " + - "connect-src 'self' wss://{0} https://api.pwnedpasswords.com " + - "https://2fa.directory; object-src 'self' blob:;"; - - [Description("Communicate with the Bitwarden push relay service (push.bitwarden.com) for mobile\n" + - "app live sync.")] - public bool PushNotifications { get; set; } = true; - - [Description("Use a docker volume (`mssql_data`) instead of a host-mapped volume for the persisted " + - "database.\n" + - "WARNING: Changing this value will cause you to lose access to the existing persisted database.\n" + - "Learn more: https://docs.docker.com/storage/volumes/")] - public bool DatabaseDockerVolume { get; set; } - - [Description("Defines \"real\" IPs in nginx.conf. Useful for defining proxy servers that forward the \n" + - "client IP address.\n" + - "Learn more: https://nginx.org/en/docs/http/ngx_http_realip_module.html\n\n" + - "Defined as a dictionary, e.g.:\n" + - "real_ips: ['10.10.0.0/24', '172.16.0.0/16']")] - public List RealIps { get; set; } - - [Description("Enable Key Connector (https://bitwarden.com/help/article/deploy-key-connector)")] - public bool EnableKeyConnector { get; set; } = false; - - [Description("Enable SCIM")] - public bool EnableScim { get; set; } = false; - - [YamlIgnore] - public string Domain + public class Configuration { - get + [Description("Note: After making changes to this file you need to run the `rebuild` or `update`\n" + + "command for them to be applied.\n\n" + + + "Full URL for accessing the installation from a browser. (Required)")] + public string Url { get; set; } = "https://localhost"; + + [Description("Auto-generate the `./docker/docker-compose.yml` config file.\n" + + "WARNING: Disabling generated config files can break future updates. You will be\n" + + "responsible for maintaining this config file.\n" + + "Template: https://github.com/bitwarden/server/blob/master/util/Setup/Templates/DockerCompose.hbs")] + public bool GenerateComposeConfig { get; set; } = true; + + [Description("Auto-generate the `./nginx/default.conf` file.\n" + + "WARNING: Disabling generated config files can break future updates. You will be\n" + + "responsible for maintaining this config file.\n" + + "Template: https://github.com/bitwarden/server/blob/master/util/Setup/Templates/NginxConfig.hbs")] + public bool GenerateNginxConfig { get; set; } = true; + + [Description("Docker compose file port mapping for HTTP. Leave empty to remove the port mapping.\n" + + "Learn more: https://docs.docker.com/compose/compose-file/#ports")] + public string HttpPort { get; set; } = "80"; + + [Description("Docker compose file port mapping for HTTPS. Leave empty to remove the port mapping.\n" + + "Learn more: https://docs.docker.com/compose/compose-file/#ports")] + public string HttpsPort { get; set; } = "443"; + + [Description("Docker compose file version. Leave empty for default.\n" + + "Learn more: https://docs.docker.com/compose/compose-file/compose-versioning/")] + public string ComposeVersion { get; set; } + + [Description("Configure Nginx for Captcha.")] + public bool Captcha { get; set; } = false; + + [Description("Configure Nginx for SSL.")] + public bool Ssl { get; set; } = true; + + [Description("SSL versions used by Nginx (ssl_protocols). Leave empty for recommended default.\n" + + "Learn more: https://wiki.mozilla.org/Security/Server_Side_TLS")] + public string SslVersions { get; set; } + + [Description("SSL ciphersuites used by Nginx (ssl_ciphers). Leave empty for recommended default.\n" + + "Learn more: https://wiki.mozilla.org/Security/Server_Side_TLS")] + public string SslCiphersuites { get; set; } + + [Description("Installation uses a managed Let's Encrypt certificate.")] + public bool SslManagedLetsEncrypt { get; set; } + + [Description("The actual certificate. (Required if using SSL without managed Let's Encrypt)\n" + + "Note: Path uses the container's ssl directory. The `./ssl` host directory is mapped to\n" + + "`/etc/ssl` within the container.")] + public string SslCertificatePath { get; set; } + + [Description("The certificate's private key. (Required if using SSL without managed Let's Encrypt)\n" + + "Note: Path uses the container's ssl directory. The `./ssl` host directory is mapped to\n" + + "`/etc/ssl` within the container.")] + public string SslKeyPath { get; set; } + + [Description("If the certificate is trusted by a CA, you should provide the CA's certificate.\n" + + "Note: Path uses the container's ssl directory. The `./ssl` host directory is mapped to\n" + + "`/etc/ssl` within the container.")] + public string SslCaPath { get; set; } + + [Description("Diffie Hellman ephemeral parameters\n" + + "Learn more: https://security.stackexchange.com/q/94390/79072\n" + + "Note: Path uses the container's ssl directory. The `./ssl` host directory is mapped to\n" + + "`/etc/ssl` within the container.")] + public string SslDiffieHellmanPath { get; set; } + + [Description("Nginx Header Content-Security-Policy parameter\n" + + "WARNING: Reconfiguring this parameter may break features. By changing this parameter\n" + + "you become responsible for maintaining this value.")] + public string NginxHeaderContentSecurityPolicy { get; set; } = "default-src 'self'; style-src 'self' " + + "'unsafe-inline'; img-src 'self' data: https://haveibeenpwned.com https://www.gravatar.com; " + + "child-src 'self' https://*.duosecurity.com https://*.duofederal.com; " + + "frame-src 'self' https://*.duosecurity.com https://*.duofederal.com; " + + "connect-src 'self' wss://{0} https://api.pwnedpasswords.com " + + "https://2fa.directory; object-src 'self' blob:;"; + + [Description("Communicate with the Bitwarden push relay service (push.bitwarden.com) for mobile\n" + + "app live sync.")] + public bool PushNotifications { get; set; } = true; + + [Description("Use a docker volume (`mssql_data`) instead of a host-mapped volume for the persisted " + + "database.\n" + + "WARNING: Changing this value will cause you to lose access to the existing persisted database.\n" + + "Learn more: https://docs.docker.com/storage/volumes/")] + public bool DatabaseDockerVolume { get; set; } + + [Description("Defines \"real\" IPs in nginx.conf. Useful for defining proxy servers that forward the \n" + + "client IP address.\n" + + "Learn more: https://nginx.org/en/docs/http/ngx_http_realip_module.html\n\n" + + "Defined as a dictionary, e.g.:\n" + + "real_ips: ['10.10.0.0/24', '172.16.0.0/16']")] + public List RealIps { get; set; } + + [Description("Enable Key Connector (https://bitwarden.com/help/article/deploy-key-connector)")] + public bool EnableKeyConnector { get; set; } = false; + + [Description("Enable SCIM")] + public bool EnableScim { get; set; } = false; + + [YamlIgnore] + public string Domain { - if (Uri.TryCreate(Url, UriKind.Absolute, out var uri)) + get { - return uri.Host; + if (Uri.TryCreate(Url, UriKind.Absolute, out var uri)) + { + return uri.Host; + } + return null; } - return null; } } } diff --git a/util/Setup/Context.cs b/util/Setup/Context.cs index f82e5005c6..cf8efa90ef 100644 --- a/util/Setup/Context.cs +++ b/util/Setup/Context.cs @@ -1,152 +1,153 @@ using YamlDotNet.Serialization; using YamlDotNet.Serialization.NamingConventions; -namespace Bit.Setup; - -public class Context +namespace Bit.Setup { - private const string ConfigPath = "/bitwarden/config.yml"; - - public string[] Args { get; set; } - public bool Quiet { get; set; } - public bool Stub { get; set; } - public IDictionary Parameters { get; set; } - public string OutputDir { get; set; } = "/etc/bitwarden"; - public string HostOS { get; set; } = "win"; - public string CoreVersion { get; set; } = "latest"; - public string WebVersion { get; set; } = "latest"; - public string KeyConnectorVersion { get; set; } = "latest"; - public Installation Install { get; set; } = new Installation(); - public Configuration Config { get; set; } = new Configuration(); - - public bool PrintToScreen() + public class Context { - return !Quiet || Parameters.ContainsKey("install"); - } + private const string ConfigPath = "/bitwarden/config.yml"; - public void LoadConfiguration() - { - if (!File.Exists(ConfigPath)) + public string[] Args { get; set; } + public bool Quiet { get; set; } + public bool Stub { get; set; } + public IDictionary Parameters { get; set; } + public string OutputDir { get; set; } = "/etc/bitwarden"; + public string HostOS { get; set; } = "win"; + public string CoreVersion { get; set; } = "latest"; + public string WebVersion { get; set; } = "latest"; + public string KeyConnectorVersion { get; set; } = "latest"; + public Installation Install { get; set; } = new Installation(); + public Configuration Config { get; set; } = new Configuration(); + + public bool PrintToScreen() { - Helpers.WriteLine(this, "No existing `config.yml` detected. Let's generate one."); - - // Looks like updating from older version. Try to create config file. - var url = Helpers.GetValueFromEnvFile("global", "globalSettings__baseServiceUri__vault"); - if (!Uri.TryCreate(url, UriKind.Absolute, out var uri)) - { - Helpers.WriteLine(this, "Unable to determine existing installation url."); - return; - } - Config.Url = url; - - var push = Helpers.GetValueFromEnvFile("global", "globalSettings__pushRelayBaseUri"); - Config.PushNotifications = push != "REPLACE"; - - var composeFile = "/bitwarden/docker/docker-compose.yml"; - if (File.Exists(composeFile)) - { - var fileLines = File.ReadAllLines(composeFile); - foreach (var line in fileLines) - { - if (!line.StartsWith("# Parameter:")) - { - continue; - } - - var paramParts = line.Split("="); - if (paramParts.Length < 2) - { - continue; - } - - if (paramParts[0] == "# Parameter:MssqlDataDockerVolume" && - bool.TryParse(paramParts[1], out var mssqlDataDockerVolume)) - { - Config.DatabaseDockerVolume = mssqlDataDockerVolume; - continue; - } - - if (paramParts[0] == "# Parameter:HttpPort" && int.TryParse(paramParts[1], out var httpPort)) - { - Config.HttpPort = httpPort == 0 ? null : httpPort.ToString(); - continue; - } - - if (paramParts[0] == "# Parameter:HttpsPort" && int.TryParse(paramParts[1], out var httpsPort)) - { - Config.HttpsPort = httpsPort == 0 ? null : httpsPort.ToString(); - continue; - } - } - } - - var nginxFile = "/bitwarden/nginx/default.conf"; - if (File.Exists(nginxFile)) - { - var confContent = File.ReadAllText(nginxFile); - var selfSigned = confContent.Contains("/etc/ssl/self/"); - Config.Ssl = confContent.Contains("ssl http2;"); - Config.SslManagedLetsEncrypt = !selfSigned && confContent.Contains("/etc/letsencrypt/live/"); - var diffieHellman = confContent.Contains("/dhparam.pem;"); - var trusted = confContent.Contains("ssl_trusted_certificate "); - if (Config.SslManagedLetsEncrypt) - { - Config.Ssl = true; - } - else if (Config.Ssl) - { - var sslPath = selfSigned ? $"/etc/ssl/self/{Config.Domain}" : $"/etc/ssl/{Config.Domain}"; - Config.SslCertificatePath = string.Concat(sslPath, "/", "certificate.crt"); - Config.SslKeyPath = string.Concat(sslPath, "/", "private.key"); - if (trusted) - { - Config.SslCaPath = string.Concat(sslPath, "/", "ca.crt"); - } - if (diffieHellman) - { - Config.SslDiffieHellmanPath = string.Concat(sslPath, "/", "dhparam.pem"); - } - } - } - - SaveConfiguration(); + return !Quiet || Parameters.ContainsKey("install"); } - var configText = File.ReadAllText(ConfigPath); - var deserializer = new DeserializerBuilder() - .WithNamingConvention(UnderscoredNamingConvention.Instance) - .Build(); - Config = deserializer.Deserialize(configText); - } - - public void SaveConfiguration() - { - if (Config == null) + public void LoadConfiguration() { - throw new Exception("Config is null."); - } - var serializer = new SerializerBuilder() - .WithNamingConvention(UnderscoredNamingConvention.Instance) - .WithTypeInspector(inner => new CommentGatheringTypeInspector(inner)) - .WithEmissionPhaseObjectGraphVisitor(args => new CommentsObjectGraphVisitor(args.InnerVisitor)) - .Build(); - var yaml = serializer.Serialize(Config); - Directory.CreateDirectory("/bitwarden/"); - using (var sw = File.CreateText(ConfigPath)) - { - sw.Write(yaml); - } - } + if (!File.Exists(ConfigPath)) + { + Helpers.WriteLine(this, "No existing `config.yml` detected. Let's generate one."); - public class Installation - { - public Guid InstallationId { get; set; } - public string InstallationKey { get; set; } - public bool DiffieHellman { get; set; } - public bool Trusted { get; set; } - public bool SelfSignedCert { get; set; } - public string IdentityCertPassword { get; set; } - public string Domain { get; set; } - public string Database { get; set; } + // Looks like updating from older version. Try to create config file. + var url = Helpers.GetValueFromEnvFile("global", "globalSettings__baseServiceUri__vault"); + if (!Uri.TryCreate(url, UriKind.Absolute, out var uri)) + { + Helpers.WriteLine(this, "Unable to determine existing installation url."); + return; + } + Config.Url = url; + + var push = Helpers.GetValueFromEnvFile("global", "globalSettings__pushRelayBaseUri"); + Config.PushNotifications = push != "REPLACE"; + + var composeFile = "/bitwarden/docker/docker-compose.yml"; + if (File.Exists(composeFile)) + { + var fileLines = File.ReadAllLines(composeFile); + foreach (var line in fileLines) + { + if (!line.StartsWith("# Parameter:")) + { + continue; + } + + var paramParts = line.Split("="); + if (paramParts.Length < 2) + { + continue; + } + + if (paramParts[0] == "# Parameter:MssqlDataDockerVolume" && + bool.TryParse(paramParts[1], out var mssqlDataDockerVolume)) + { + Config.DatabaseDockerVolume = mssqlDataDockerVolume; + continue; + } + + if (paramParts[0] == "# Parameter:HttpPort" && int.TryParse(paramParts[1], out var httpPort)) + { + Config.HttpPort = httpPort == 0 ? null : httpPort.ToString(); + continue; + } + + if (paramParts[0] == "# Parameter:HttpsPort" && int.TryParse(paramParts[1], out var httpsPort)) + { + Config.HttpsPort = httpsPort == 0 ? null : httpsPort.ToString(); + continue; + } + } + } + + var nginxFile = "/bitwarden/nginx/default.conf"; + if (File.Exists(nginxFile)) + { + var confContent = File.ReadAllText(nginxFile); + var selfSigned = confContent.Contains("/etc/ssl/self/"); + Config.Ssl = confContent.Contains("ssl http2;"); + Config.SslManagedLetsEncrypt = !selfSigned && confContent.Contains("/etc/letsencrypt/live/"); + var diffieHellman = confContent.Contains("/dhparam.pem;"); + var trusted = confContent.Contains("ssl_trusted_certificate "); + if (Config.SslManagedLetsEncrypt) + { + Config.Ssl = true; + } + else if (Config.Ssl) + { + var sslPath = selfSigned ? $"/etc/ssl/self/{Config.Domain}" : $"/etc/ssl/{Config.Domain}"; + Config.SslCertificatePath = string.Concat(sslPath, "/", "certificate.crt"); + Config.SslKeyPath = string.Concat(sslPath, "/", "private.key"); + if (trusted) + { + Config.SslCaPath = string.Concat(sslPath, "/", "ca.crt"); + } + if (diffieHellman) + { + Config.SslDiffieHellmanPath = string.Concat(sslPath, "/", "dhparam.pem"); + } + } + } + + SaveConfiguration(); + } + + var configText = File.ReadAllText(ConfigPath); + var deserializer = new DeserializerBuilder() + .WithNamingConvention(UnderscoredNamingConvention.Instance) + .Build(); + Config = deserializer.Deserialize(configText); + } + + public void SaveConfiguration() + { + if (Config == null) + { + throw new Exception("Config is null."); + } + var serializer = new SerializerBuilder() + .WithNamingConvention(UnderscoredNamingConvention.Instance) + .WithTypeInspector(inner => new CommentGatheringTypeInspector(inner)) + .WithEmissionPhaseObjectGraphVisitor(args => new CommentsObjectGraphVisitor(args.InnerVisitor)) + .Build(); + var yaml = serializer.Serialize(Config); + Directory.CreateDirectory("/bitwarden/"); + using (var sw = File.CreateText(ConfigPath)) + { + sw.Write(yaml); + } + } + + public class Installation + { + public Guid InstallationId { get; set; } + public string InstallationKey { get; set; } + public bool DiffieHellman { get; set; } + public bool Trusted { get; set; } + public bool SelfSignedCert { get; set; } + public string IdentityCertPassword { get; set; } + public string Domain { get; set; } + public string Database { get; set; } + } } } diff --git a/util/Setup/DockerComposeBuilder.cs b/util/Setup/DockerComposeBuilder.cs index 0d76dc9e92..d007ffe1c5 100644 --- a/util/Setup/DockerComposeBuilder.cs +++ b/util/Setup/DockerComposeBuilder.cs @@ -1,79 +1,80 @@ -namespace Bit.Setup; - -public class DockerComposeBuilder +namespace Bit.Setup { - private readonly Context _context; - - public DockerComposeBuilder(Context context) + public class DockerComposeBuilder { - _context = context; - } + private readonly Context _context; - public void BuildForInstaller() - { - _context.Config.DatabaseDockerVolume = _context.HostOS == "mac"; - Build(); - } - - public void BuildForUpdater() - { - Build(); - } - - private void Build() - { - Directory.CreateDirectory("/bitwarden/docker/"); - Helpers.WriteLine(_context, "Building docker-compose.yml."); - if (!_context.Config.GenerateComposeConfig) + public DockerComposeBuilder(Context context) { - Helpers.WriteLine(_context, "...skipped"); - return; + _context = context; } - var template = Helpers.ReadTemplate("DockerCompose"); - var model = new TemplateModel(_context); - using (var sw = File.CreateText("/bitwarden/docker/docker-compose.yml")) + public void BuildForInstaller() { - sw.Write(template(model)); + _context.Config.DatabaseDockerVolume = _context.HostOS == "mac"; + Build(); } - } - public class TemplateModel - { - public TemplateModel(Context context) + public void BuildForUpdater() { - if (!string.IsNullOrWhiteSpace(context.Config.ComposeVersion)) + Build(); + } + + private void Build() + { + Directory.CreateDirectory("/bitwarden/docker/"); + Helpers.WriteLine(_context, "Building docker-compose.yml."); + if (!_context.Config.GenerateComposeConfig) { - ComposeVersion = context.Config.ComposeVersion; + Helpers.WriteLine(_context, "...skipped"); + return; } - MssqlDataDockerVolume = context.Config.DatabaseDockerVolume; - EnableKeyConnector = context.Config.EnableKeyConnector; - EnableScim = context.Config.EnableScim; - HttpPort = context.Config.HttpPort; - HttpsPort = context.Config.HttpsPort; - if (!string.IsNullOrWhiteSpace(context.CoreVersion)) + + var template = Helpers.ReadTemplate("DockerCompose"); + var model = new TemplateModel(_context); + using (var sw = File.CreateText("/bitwarden/docker/docker-compose.yml")) { - CoreVersion = context.CoreVersion; - } - if (!string.IsNullOrWhiteSpace(context.WebVersion)) - { - WebVersion = context.WebVersion; - } - if (!string.IsNullOrWhiteSpace(context.KeyConnectorVersion)) - { - KeyConnectorVersion = context.KeyConnectorVersion; + sw.Write(template(model)); } } - public string ComposeVersion { get; set; } = "3"; - public bool MssqlDataDockerVolume { get; set; } - public bool EnableKeyConnector { get; set; } - public bool EnableScim { get; set; } - public string HttpPort { get; set; } - public string HttpsPort { get; set; } - public bool HasPort => !string.IsNullOrWhiteSpace(HttpPort) || !string.IsNullOrWhiteSpace(HttpsPort); - public string CoreVersion { get; set; } = "latest"; - public string WebVersion { get; set; } = "latest"; - public string KeyConnectorVersion { get; set; } = "latest"; + public class TemplateModel + { + public TemplateModel(Context context) + { + if (!string.IsNullOrWhiteSpace(context.Config.ComposeVersion)) + { + ComposeVersion = context.Config.ComposeVersion; + } + MssqlDataDockerVolume = context.Config.DatabaseDockerVolume; + EnableKeyConnector = context.Config.EnableKeyConnector; + EnableScim = context.Config.EnableScim; + HttpPort = context.Config.HttpPort; + HttpsPort = context.Config.HttpsPort; + if (!string.IsNullOrWhiteSpace(context.CoreVersion)) + { + CoreVersion = context.CoreVersion; + } + if (!string.IsNullOrWhiteSpace(context.WebVersion)) + { + WebVersion = context.WebVersion; + } + if (!string.IsNullOrWhiteSpace(context.KeyConnectorVersion)) + { + KeyConnectorVersion = context.KeyConnectorVersion; + } + } + + public string ComposeVersion { get; set; } = "3"; + public bool MssqlDataDockerVolume { get; set; } + public bool EnableKeyConnector { get; set; } + public bool EnableScim { get; set; } + public string HttpPort { get; set; } + public string HttpsPort { get; set; } + public bool HasPort => !string.IsNullOrWhiteSpace(HttpPort) || !string.IsNullOrWhiteSpace(HttpsPort); + public string CoreVersion { get; set; } = "latest"; + public string WebVersion { get; set; } = "latest"; + public string KeyConnectorVersion { get; set; } = "latest"; + } } } diff --git a/util/Setup/EnvironmentFileBuilder.cs b/util/Setup/EnvironmentFileBuilder.cs index 893ca85376..77a94bd064 100644 --- a/util/Setup/EnvironmentFileBuilder.cs +++ b/util/Setup/EnvironmentFileBuilder.cs @@ -1,224 +1,225 @@ using System.Data.SqlClient; -namespace Bit.Setup; - -public class EnvironmentFileBuilder +namespace Bit.Setup { - private readonly Context _context; - - private IDictionary _globalValues; - private IDictionary _mssqlValues; - private IDictionary _globalOverrideValues; - private IDictionary _mssqlOverrideValues; - private IDictionary _keyConnectorOverrideValues; - - public EnvironmentFileBuilder(Context context) + public class EnvironmentFileBuilder { - _context = context; - _globalValues = new Dictionary + private readonly Context _context; + + private IDictionary _globalValues; + private IDictionary _mssqlValues; + private IDictionary _globalOverrideValues; + private IDictionary _mssqlOverrideValues; + private IDictionary _keyConnectorOverrideValues; + + public EnvironmentFileBuilder(Context context) { - ["ASPNETCORE_ENVIRONMENT"] = "Production", - ["globalSettings__selfHosted"] = "true", - ["globalSettings__baseServiceUri__vault"] = "http://localhost", - ["globalSettings__pushRelayBaseUri"] = "https://push.bitwarden.com", - }; - _mssqlValues = new Dictionary - { - ["ACCEPT_EULA"] = "Y", - ["MSSQL_PID"] = "Express", - ["SA_PASSWORD"] = "SECRET", - }; - } - - public void BuildForInstaller() - { - Directory.CreateDirectory("/bitwarden/env/"); - Init(); - Build(); - } - - public void BuildForUpdater() - { - Init(); - LoadExistingValues(_globalOverrideValues, "/bitwarden/env/global.override.env"); - LoadExistingValues(_mssqlOverrideValues, "/bitwarden/env/mssql.override.env"); - LoadExistingValues(_keyConnectorOverrideValues, "/bitwarden/env/key-connector.override.env"); - - if (_context.Config.PushNotifications && - _globalOverrideValues.ContainsKey("globalSettings__pushRelayBaseUri") && - _globalOverrideValues["globalSettings__pushRelayBaseUri"] == "REPLACE") - { - _globalOverrideValues.Remove("globalSettings__pushRelayBaseUri"); - } - - Build(); - } - - private void Init() - { - var dbPassword = _context.Stub ? "RANDOM_DATABASE_PASSWORD" : Helpers.SecureRandomString(32); - var dbConnectionString = new SqlConnectionStringBuilder - { - DataSource = "tcp:mssql,1433", - InitialCatalog = _context.Install?.Database ?? "vault", - UserID = "sa", - Password = dbPassword, - MultipleActiveResultSets = false, - Encrypt = true, - ConnectTimeout = 30, - TrustServerCertificate = true, - PersistSecurityInfo = false - }.ConnectionString; - - _globalOverrideValues = new Dictionary - { - ["globalSettings__baseServiceUri__vault"] = _context.Config.Url, - ["globalSettings__sqlServer__connectionString"] = $"\"{dbConnectionString.Replace("\"", "\\\"")}\"", - ["globalSettings__identityServer__certificatePassword"] = _context.Install?.IdentityCertPassword, - ["globalSettings__internalIdentityKey"] = _context.Stub ? "RANDOM_IDENTITY_KEY" : - Helpers.SecureRandomString(64, alpha: true, numeric: true), - ["globalSettings__oidcIdentityClientKey"] = _context.Stub ? "RANDOM_IDENTITY_KEY" : - Helpers.SecureRandomString(64, alpha: true, numeric: true), - ["globalSettings__duo__aKey"] = _context.Stub ? "RANDOM_DUO_AKEY" : - Helpers.SecureRandomString(64, alpha: true, numeric: true), - ["globalSettings__installation__id"] = _context.Install?.InstallationId.ToString(), - ["globalSettings__installation__key"] = _context.Install?.InstallationKey, - ["globalSettings__yubico__clientId"] = "REPLACE", - ["globalSettings__yubico__key"] = "REPLACE", - ["globalSettings__mail__replyToEmail"] = $"no-reply@{_context.Config.Domain}", - ["globalSettings__mail__smtp__host"] = "REPLACE", - ["globalSettings__mail__smtp__port"] = "587", - ["globalSettings__mail__smtp__ssl"] = "false", - ["globalSettings__mail__smtp__username"] = "REPLACE", - ["globalSettings__mail__smtp__password"] = "REPLACE", - ["globalSettings__disableUserRegistration"] = "false", - ["globalSettings__hibpApiKey"] = "REPLACE", - ["adminSettings__admins"] = string.Empty, - }; - - if (!_context.Config.PushNotifications) - { - _globalOverrideValues.Add("globalSettings__pushRelayBaseUri", "REPLACE"); - } - - _mssqlOverrideValues = new Dictionary - { - ["SA_PASSWORD"] = dbPassword, - }; - - _keyConnectorOverrideValues = new Dictionary - { - ["keyConnectorSettings__webVaultUri"] = _context.Config.Url, - ["keyConnectorSettings__identityServerUri"] = "http://identity:5000", - ["keyConnectorSettings__database__provider"] = "json", - ["keyConnectorSettings__database__jsonFilePath"] = "/etc/bitwarden/key-connector/data.json", - ["keyConnectorSettings__rsaKey__provider"] = "certificate", - ["keyConnectorSettings__certificate__provider"] = "filesystem", - ["keyConnectorSettings__certificate__filesystemPath"] = "/etc/bitwarden/key-connector/bwkc.pfx", - ["keyConnectorSettings__certificate__filesystemPassword"] = Helpers.SecureRandomString(32, alpha: true, numeric: true), - }; - } - - private void LoadExistingValues(IDictionary _values, string file) - { - if (!File.Exists(file)) - { - return; - } - - var fileLines = File.ReadAllLines(file); - foreach (var line in fileLines) - { - if (!line.Contains("=")) + _context = context; + _globalValues = new Dictionary { - continue; + ["ASPNETCORE_ENVIRONMENT"] = "Production", + ["globalSettings__selfHosted"] = "true", + ["globalSettings__baseServiceUri__vault"] = "http://localhost", + ["globalSettings__pushRelayBaseUri"] = "https://push.bitwarden.com", + }; + _mssqlValues = new Dictionary + { + ["ACCEPT_EULA"] = "Y", + ["MSSQL_PID"] = "Express", + ["SA_PASSWORD"] = "SECRET", + }; + } + + public void BuildForInstaller() + { + Directory.CreateDirectory("/bitwarden/env/"); + Init(); + Build(); + } + + public void BuildForUpdater() + { + Init(); + LoadExistingValues(_globalOverrideValues, "/bitwarden/env/global.override.env"); + LoadExistingValues(_mssqlOverrideValues, "/bitwarden/env/mssql.override.env"); + LoadExistingValues(_keyConnectorOverrideValues, "/bitwarden/env/key-connector.override.env"); + + if (_context.Config.PushNotifications && + _globalOverrideValues.ContainsKey("globalSettings__pushRelayBaseUri") && + _globalOverrideValues["globalSettings__pushRelayBaseUri"] == "REPLACE") + { + _globalOverrideValues.Remove("globalSettings__pushRelayBaseUri"); } - var value = string.Empty; - var lineParts = line.Split("=", 2); - if (lineParts.Length < 1) + Build(); + } + + private void Init() + { + var dbPassword = _context.Stub ? "RANDOM_DATABASE_PASSWORD" : Helpers.SecureRandomString(32); + var dbConnectionString = new SqlConnectionStringBuilder { - continue; + DataSource = "tcp:mssql,1433", + InitialCatalog = _context.Install?.Database ?? "vault", + UserID = "sa", + Password = dbPassword, + MultipleActiveResultSets = false, + Encrypt = true, + ConnectTimeout = 30, + TrustServerCertificate = true, + PersistSecurityInfo = false + }.ConnectionString; + + _globalOverrideValues = new Dictionary + { + ["globalSettings__baseServiceUri__vault"] = _context.Config.Url, + ["globalSettings__sqlServer__connectionString"] = $"\"{dbConnectionString.Replace("\"", "\\\"")}\"", + ["globalSettings__identityServer__certificatePassword"] = _context.Install?.IdentityCertPassword, + ["globalSettings__internalIdentityKey"] = _context.Stub ? "RANDOM_IDENTITY_KEY" : + Helpers.SecureRandomString(64, alpha: true, numeric: true), + ["globalSettings__oidcIdentityClientKey"] = _context.Stub ? "RANDOM_IDENTITY_KEY" : + Helpers.SecureRandomString(64, alpha: true, numeric: true), + ["globalSettings__duo__aKey"] = _context.Stub ? "RANDOM_DUO_AKEY" : + Helpers.SecureRandomString(64, alpha: true, numeric: true), + ["globalSettings__installation__id"] = _context.Install?.InstallationId.ToString(), + ["globalSettings__installation__key"] = _context.Install?.InstallationKey, + ["globalSettings__yubico__clientId"] = "REPLACE", + ["globalSettings__yubico__key"] = "REPLACE", + ["globalSettings__mail__replyToEmail"] = $"no-reply@{_context.Config.Domain}", + ["globalSettings__mail__smtp__host"] = "REPLACE", + ["globalSettings__mail__smtp__port"] = "587", + ["globalSettings__mail__smtp__ssl"] = "false", + ["globalSettings__mail__smtp__username"] = "REPLACE", + ["globalSettings__mail__smtp__password"] = "REPLACE", + ["globalSettings__disableUserRegistration"] = "false", + ["globalSettings__hibpApiKey"] = "REPLACE", + ["adminSettings__admins"] = string.Empty, + }; + + if (!_context.Config.PushNotifications) + { + _globalOverrideValues.Add("globalSettings__pushRelayBaseUri", "REPLACE"); } - if (lineParts.Length > 1) + _mssqlOverrideValues = new Dictionary { - value = lineParts[1]; + ["SA_PASSWORD"] = dbPassword, + }; + + _keyConnectorOverrideValues = new Dictionary + { + ["keyConnectorSettings__webVaultUri"] = _context.Config.Url, + ["keyConnectorSettings__identityServerUri"] = "http://identity:5000", + ["keyConnectorSettings__database__provider"] = "json", + ["keyConnectorSettings__database__jsonFilePath"] = "/etc/bitwarden/key-connector/data.json", + ["keyConnectorSettings__rsaKey__provider"] = "certificate", + ["keyConnectorSettings__certificate__provider"] = "filesystem", + ["keyConnectorSettings__certificate__filesystemPath"] = "/etc/bitwarden/key-connector/bwkc.pfx", + ["keyConnectorSettings__certificate__filesystemPassword"] = Helpers.SecureRandomString(32, alpha: true, numeric: true), + }; + } + + private void LoadExistingValues(IDictionary _values, string file) + { + if (!File.Exists(file)) + { + return; } - if (_values.ContainsKey(lineParts[0])) + var fileLines = File.ReadAllLines(file); + foreach (var line in fileLines) { - _values[lineParts[0]] = value; - } - else - { - _values.Add(lineParts[0], value.Replace("\\\"", "\"")); + if (!line.Contains("=")) + { + continue; + } + + var value = string.Empty; + var lineParts = line.Split("=", 2); + if (lineParts.Length < 1) + { + continue; + } + + if (lineParts.Length > 1) + { + value = lineParts[1]; + } + + if (_values.ContainsKey(lineParts[0])) + { + _values[lineParts[0]] = value; + } + else + { + _values.Add(lineParts[0], value.Replace("\\\"", "\"")); + } } } - } - private void Build() - { - var template = Helpers.ReadTemplate("EnvironmentFile"); - - Helpers.WriteLine(_context, "Building docker environment files."); - Directory.CreateDirectory("/bitwarden/docker/"); - using (var sw = File.CreateText("/bitwarden/docker/global.env")) + private void Build() { - sw.Write(template(new TemplateModel(_globalValues))); - } - Helpers.Exec("chmod 600 /bitwarden/docker/global.env"); + var template = Helpers.ReadTemplate("EnvironmentFile"); - using (var sw = File.CreateText("/bitwarden/docker/mssql.env")) - { - sw.Write(template(new TemplateModel(_mssqlValues))); - } - Helpers.Exec("chmod 600 /bitwarden/docker/mssql.env"); - - Helpers.WriteLine(_context, "Building docker environment override files."); - Directory.CreateDirectory("/bitwarden/env/"); - using (var sw = File.CreateText("/bitwarden/env/global.override.env")) - { - sw.Write(template(new TemplateModel(_globalOverrideValues))); - } - Helpers.Exec("chmod 600 /bitwarden/env/global.override.env"); - - using (var sw = File.CreateText("/bitwarden/env/mssql.override.env")) - { - sw.Write(template(new TemplateModel(_mssqlOverrideValues))); - } - Helpers.Exec("chmod 600 /bitwarden/env/mssql.override.env"); - - if (_context.Config.EnableKeyConnector) - { - using (var sw = File.CreateText("/bitwarden/env/key-connector.override.env")) + Helpers.WriteLine(_context, "Building docker environment files."); + Directory.CreateDirectory("/bitwarden/docker/"); + using (var sw = File.CreateText("/bitwarden/docker/global.env")) { - sw.Write(template(new TemplateModel(_keyConnectorOverrideValues))); + sw.Write(template(new TemplateModel(_globalValues))); + } + Helpers.Exec("chmod 600 /bitwarden/docker/global.env"); + + using (var sw = File.CreateText("/bitwarden/docker/mssql.env")) + { + sw.Write(template(new TemplateModel(_mssqlValues))); + } + Helpers.Exec("chmod 600 /bitwarden/docker/mssql.env"); + + Helpers.WriteLine(_context, "Building docker environment override files."); + Directory.CreateDirectory("/bitwarden/env/"); + using (var sw = File.CreateText("/bitwarden/env/global.override.env")) + { + sw.Write(template(new TemplateModel(_globalOverrideValues))); + } + Helpers.Exec("chmod 600 /bitwarden/env/global.override.env"); + + using (var sw = File.CreateText("/bitwarden/env/mssql.override.env")) + { + sw.Write(template(new TemplateModel(_mssqlOverrideValues))); + } + Helpers.Exec("chmod 600 /bitwarden/env/mssql.override.env"); + + if (_context.Config.EnableKeyConnector) + { + using (var sw = File.CreateText("/bitwarden/env/key-connector.override.env")) + { + sw.Write(template(new TemplateModel(_keyConnectorOverrideValues))); + } + + Helpers.Exec("chmod 600 /bitwarden/env/key-connector.override.env"); } - Helpers.Exec("chmod 600 /bitwarden/env/key-connector.override.env"); + // Empty uid env file. Only used on Linux hosts. + if (!File.Exists("/bitwarden/env/uid.env")) + { + using (var sw = File.CreateText("/bitwarden/env/uid.env")) { } + } } - // Empty uid env file. Only used on Linux hosts. - if (!File.Exists("/bitwarden/env/uid.env")) + public class TemplateModel { - using (var sw = File.CreateText("/bitwarden/env/uid.env")) { } - } - } + public TemplateModel(IEnumerable> variables) + { + Variables = variables.Select(v => new Kvp { Key = v.Key, Value = v.Value }); + } - public class TemplateModel - { - public TemplateModel(IEnumerable> variables) - { - Variables = variables.Select(v => new Kvp { Key = v.Key, Value = v.Value }); - } + public IEnumerable Variables { get; set; } - public IEnumerable Variables { get; set; } - - public class Kvp - { - public string Key { get; set; } - public string Value { get; set; } + public class Kvp + { + public string Key { get; set; } + public string Value { get; set; } + } } } } diff --git a/util/Setup/Helpers.cs b/util/Setup/Helpers.cs index ea7351b98f..06c48f2fe6 100644 --- a/util/Setup/Helpers.cs +++ b/util/Setup/Helpers.cs @@ -4,222 +4,223 @@ using System.Runtime.InteropServices; using System.Security.Cryptography; using System.Text; -namespace Bit.Setup; - -public static class Helpers +namespace Bit.Setup { - public static string SecureRandomString(int length, bool alpha = true, bool upper = true, bool lower = true, - bool numeric = true, bool special = false) + public static class Helpers { - return SecureRandomString(length, RandomStringCharacters(alpha, upper, lower, numeric, special)); - } - - // ref https://stackoverflow.com/a/8996788/1090359 with modifications - public static string SecureRandomString(int length, string characters) - { - if (length < 0) + public static string SecureRandomString(int length, bool alpha = true, bool upper = true, bool lower = true, + bool numeric = true, bool special = false) { - throw new ArgumentOutOfRangeException(nameof(length), "length cannot be less than zero."); + return SecureRandomString(length, RandomStringCharacters(alpha, upper, lower, numeric, special)); } - if ((characters?.Length ?? 0) == 0) + // ref https://stackoverflow.com/a/8996788/1090359 with modifications + public static string SecureRandomString(int length, string characters) { - throw new ArgumentOutOfRangeException(nameof(characters), "characters invalid."); - } - - const int byteSize = 0x100; - if (byteSize < characters.Length) - { - throw new ArgumentException( - string.Format("{0} may contain no more than {1} characters.", nameof(characters), byteSize), - nameof(characters)); - } - - var outOfRangeStart = byteSize - (byteSize % characters.Length); - using (var rng = RandomNumberGenerator.Create()) - { - var sb = new StringBuilder(); - var buffer = new byte[128]; - while (sb.Length < length) + if (length < 0) { - rng.GetBytes(buffer); - for (var i = 0; i < buffer.Length && sb.Length < length; ++i) - { - // Divide the byte into charSet-sized groups. If the random value falls into the last group and the - // last group is too small to choose from the entire allowedCharSet, ignore the value in order to - // avoid biasing the result. - if (outOfRangeStart <= buffer[i]) - { - continue; - } + throw new ArgumentOutOfRangeException(nameof(length), "length cannot be less than zero."); + } - sb.Append(characters[buffer[i] % characters.Length]); + if ((characters?.Length ?? 0) == 0) + { + throw new ArgumentOutOfRangeException(nameof(characters), "characters invalid."); + } + + const int byteSize = 0x100; + if (byteSize < characters.Length) + { + throw new ArgumentException( + string.Format("{0} may contain no more than {1} characters.", nameof(characters), byteSize), + nameof(characters)); + } + + var outOfRangeStart = byteSize - (byteSize % characters.Length); + using (var rng = RandomNumberGenerator.Create()) + { + var sb = new StringBuilder(); + var buffer = new byte[128]; + while (sb.Length < length) + { + rng.GetBytes(buffer); + for (var i = 0; i < buffer.Length && sb.Length < length; ++i) + { + // Divide the byte into charSet-sized groups. If the random value falls into the last group and the + // last group is too small to choose from the entire allowedCharSet, ignore the value in order to + // avoid biasing the result. + if (outOfRangeStart <= buffer[i]) + { + continue; + } + + sb.Append(characters[buffer[i] % characters.Length]); + } + } + + return sb.ToString(); + } + } + + private static string RandomStringCharacters(bool alpha, bool upper, bool lower, bool numeric, bool special) + { + var characters = string.Empty; + if (alpha) + { + if (upper) + { + characters += "ABCDEFGHIJKLMNOPQRSTUVWXYZ"; + } + + if (lower) + { + characters += "abcdefghijklmnopqrstuvwxyz"; } } - return sb.ToString(); - } - } - - private static string RandomStringCharacters(bool alpha, bool upper, bool lower, bool numeric, bool special) - { - var characters = string.Empty; - if (alpha) - { - if (upper) + if (numeric) { - characters += "ABCDEFGHIJKLMNOPQRSTUVWXYZ"; + characters += "0123456789"; } - if (lower) + if (special) { - characters += "abcdefghijklmnopqrstuvwxyz"; + characters += "!@#$%^*&"; } + + return characters; } - if (numeric) + public static string GetValueFromEnvFile(string envFile, string key) { - characters += "0123456789"; - } + if (!File.Exists($"/bitwarden/env/{envFile}.override.env")) + { + return null; + } - if (special) - { - characters += "!@#$%^*&"; - } + var lines = File.ReadAllLines($"/bitwarden/env/{envFile}.override.env"); + foreach (var line in lines) + { + if (line.StartsWith($"{key}=")) + { + return line.Split(new char[] { '=' }, 2)[1].Trim('"').Replace("\\\"", "\""); + } + } - return characters; - } - - public static string GetValueFromEnvFile(string envFile, string key) - { - if (!File.Exists($"/bitwarden/env/{envFile}.override.env")) - { return null; } - var lines = File.ReadAllLines($"/bitwarden/env/{envFile}.override.env"); - foreach (var line in lines) + public static string Exec(string cmd, bool returnStdout = false) { - if (line.StartsWith($"{key}=")) + var process = new Process { - return line.Split(new char[] { '=' }, 2)[1].Trim('"').Replace("\\\"", "\""); - } - } + StartInfo = new ProcessStartInfo + { + RedirectStandardOutput = true, + UseShellExecute = false, + CreateNoWindow = true, + WindowStyle = ProcessWindowStyle.Hidden + } + }; - return null; - } - - public static string Exec(string cmd, bool returnStdout = false) - { - var process = new Process - { - StartInfo = new ProcessStartInfo + if (!RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) { - RedirectStandardOutput = true, - UseShellExecute = false, - CreateNoWindow = true, - WindowStyle = ProcessWindowStyle.Hidden + var escapedArgs = cmd.Replace("\"", "\\\""); + process.StartInfo.FileName = "/bin/bash"; + process.StartInfo.Arguments = $"-c \"{escapedArgs}\""; + } + else + { + process.StartInfo.FileName = "powershell"; + process.StartInfo.Arguments = cmd; } - }; - if (!RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) - { - var escapedArgs = cmd.Replace("\"", "\\\""); - process.StartInfo.FileName = "/bin/bash"; - process.StartInfo.Arguments = $"-c \"{escapedArgs}\""; - } - else - { - process.StartInfo.FileName = "powershell"; - process.StartInfo.Arguments = cmd; + process.Start(); + var result = returnStdout ? process.StandardOutput.ReadToEnd() : null; + process.WaitForExit(); + return result; } - process.Start(); - var result = returnStdout ? process.StandardOutput.ReadToEnd() : null; - process.WaitForExit(); - return result; - } - - public static string ReadInput(string prompt) - { - Console.ForegroundColor = ConsoleColor.Cyan; - Console.Write("(!) "); - Console.ResetColor(); - Console.Write(prompt); - if (prompt.EndsWith("?")) - { - Console.Write(" (y/n)"); - } - Console.Write(": "); - var input = Console.ReadLine(); - Console.WriteLine(); - return input; - } - - public static bool ReadQuestion(string prompt) - { - var input = ReadInput(prompt).ToLowerInvariant().Trim(); - return input == "y" || input == "yes"; - } - - public static void ShowBanner(Context context, string title, string message, ConsoleColor? color = null) - { - if (!context.PrintToScreen()) - { - return; - } - if (color != null) - { - Console.ForegroundColor = color.Value; - } - Console.WriteLine($"!!!!!!!!!! {title} !!!!!!!!!!"); - Console.WriteLine(message); - Console.WriteLine(); - Console.ResetColor(); - } - - public static HandlebarsDotNet.HandlebarsTemplate ReadTemplate(string templateName) - { - var assembly = typeof(Helpers).GetTypeInfo().Assembly; - var fullTemplateName = $"Bit.Setup.Templates.{templateName}.hbs"; - if (!assembly.GetManifestResourceNames().Any(f => f == fullTemplateName)) - { - return null; - } - using (var s = assembly.GetManifestResourceStream(fullTemplateName)) - using (var sr = new StreamReader(s)) - { - var templateText = sr.ReadToEnd(); - return HandlebarsDotNet.Handlebars.Compile(templateText); - } - } - - public static void WriteLine(Context context, string format = null, object arg0 = null, object arg1 = null, - object arg2 = null) - { - if (!context.PrintToScreen()) - { - return; - } - if (format != null && arg0 != null && arg1 != null && arg2 != null) - { - Console.WriteLine(format, arg0, arg1, arg2); - } - else if (format != null && arg0 != null && arg1 != null) - { - Console.WriteLine(format, arg0, arg1); - } - else if (format != null && arg0 != null) - { - Console.WriteLine(format, arg0); - } - else if (format != null) - { - Console.WriteLine(format); - } - else + public static string ReadInput(string prompt) { + Console.ForegroundColor = ConsoleColor.Cyan; + Console.Write("(!) "); + Console.ResetColor(); + Console.Write(prompt); + if (prompt.EndsWith("?")) + { + Console.Write(" (y/n)"); + } + Console.Write(": "); + var input = Console.ReadLine(); Console.WriteLine(); + return input; + } + + public static bool ReadQuestion(string prompt) + { + var input = ReadInput(prompt).ToLowerInvariant().Trim(); + return input == "y" || input == "yes"; + } + + public static void ShowBanner(Context context, string title, string message, ConsoleColor? color = null) + { + if (!context.PrintToScreen()) + { + return; + } + if (color != null) + { + Console.ForegroundColor = color.Value; + } + Console.WriteLine($"!!!!!!!!!! {title} !!!!!!!!!!"); + Console.WriteLine(message); + Console.WriteLine(); + Console.ResetColor(); + } + + public static HandlebarsDotNet.HandlebarsTemplate ReadTemplate(string templateName) + { + var assembly = typeof(Helpers).GetTypeInfo().Assembly; + var fullTemplateName = $"Bit.Setup.Templates.{templateName}.hbs"; + if (!assembly.GetManifestResourceNames().Any(f => f == fullTemplateName)) + { + return null; + } + using (var s = assembly.GetManifestResourceStream(fullTemplateName)) + using (var sr = new StreamReader(s)) + { + var templateText = sr.ReadToEnd(); + return HandlebarsDotNet.Handlebars.Compile(templateText); + } + } + + public static void WriteLine(Context context, string format = null, object arg0 = null, object arg1 = null, + object arg2 = null) + { + if (!context.PrintToScreen()) + { + return; + } + if (format != null && arg0 != null && arg1 != null && arg2 != null) + { + Console.WriteLine(format, arg0, arg1, arg2); + } + else if (format != null && arg0 != null && arg1 != null) + { + Console.WriteLine(format, arg0, arg1); + } + else if (format != null && arg0 != null) + { + Console.WriteLine(format, arg0); + } + else if (format != null) + { + Console.WriteLine(format); + } + else + { + Console.WriteLine(); + } } } } diff --git a/util/Setup/NginxConfigBuilder.cs b/util/Setup/NginxConfigBuilder.cs index 420793cef7..f2ad08ced2 100644 --- a/util/Setup/NginxConfigBuilder.cs +++ b/util/Setup/NginxConfigBuilder.cs @@ -1,132 +1,133 @@ -namespace Bit.Setup; - -public class NginxConfigBuilder +namespace Bit.Setup { - private const string ConfFile = "/bitwarden/nginx/default.conf"; - - private readonly Context _context; - - public NginxConfigBuilder(Context context) + public class NginxConfigBuilder { - _context = context; - } + private const string ConfFile = "/bitwarden/nginx/default.conf"; - public void BuildForInstaller() - { - var model = new TemplateModel(_context); - if (model.Ssl && !_context.Config.SslManagedLetsEncrypt) + private readonly Context _context; + + public NginxConfigBuilder(Context context) { - var sslPath = _context.Install.SelfSignedCert ? - $"/etc/ssl/self/{model.Domain}" : $"/etc/ssl/{model.Domain}"; - _context.Config.SslCertificatePath = model.CertificatePath = - string.Concat(sslPath, "/", "certificate.crt"); - _context.Config.SslKeyPath = model.KeyPath = - string.Concat(sslPath, "/", "private.key"); - if (_context.Install.Trusted) - { - _context.Config.SslCaPath = model.CaPath = - string.Concat(sslPath, "/", "ca.crt"); - } - if (_context.Install.DiffieHellman) - { - _context.Config.SslDiffieHellmanPath = model.DiffieHellmanPath = - string.Concat(sslPath, "/", "dhparam.pem"); - } - } - Build(model); - } - - public void BuildForUpdater() - { - var model = new TemplateModel(_context); - Build(model); - } - - private void Build(TemplateModel model) - { - Directory.CreateDirectory("/bitwarden/nginx/"); - Helpers.WriteLine(_context, "Building nginx config."); - if (!_context.Config.GenerateNginxConfig) - { - Helpers.WriteLine(_context, "...skipped"); - return; + _context = context; } - var template = Helpers.ReadTemplate("NginxConfig"); - using (var sw = File.CreateText(ConfFile)) + public void BuildForInstaller() { - sw.WriteLine(template(model)); - } - } - - public class TemplateModel - { - public TemplateModel() { } - - public TemplateModel(Context context) - { - Captcha = context.Config.Captcha; - Ssl = context.Config.Ssl; - EnableKeyConnector = context.Config.EnableKeyConnector; - EnableScim = context.Config.EnableScim; - Domain = context.Config.Domain; - Url = context.Config.Url; - RealIps = context.Config.RealIps; - ContentSecurityPolicy = string.Format(context.Config.NginxHeaderContentSecurityPolicy, Domain); - - if (Ssl) + var model = new TemplateModel(_context); + if (model.Ssl && !_context.Config.SslManagedLetsEncrypt) { - if (context.Config.SslManagedLetsEncrypt) + var sslPath = _context.Install.SelfSignedCert ? + $"/etc/ssl/self/{model.Domain}" : $"/etc/ssl/{model.Domain}"; + _context.Config.SslCertificatePath = model.CertificatePath = + string.Concat(sslPath, "/", "certificate.crt"); + _context.Config.SslKeyPath = model.KeyPath = + string.Concat(sslPath, "/", "private.key"); + if (_context.Install.Trusted) { - var sslPath = $"/etc/letsencrypt/live/{Domain}"; - CertificatePath = CaPath = string.Concat(sslPath, "/", "fullchain.pem"); - KeyPath = string.Concat(sslPath, "/", "privkey.pem"); - DiffieHellmanPath = string.Concat(sslPath, "/", "dhparam.pem"); + _context.Config.SslCaPath = model.CaPath = + string.Concat(sslPath, "/", "ca.crt"); + } + if (_context.Install.DiffieHellman) + { + _context.Config.SslDiffieHellmanPath = model.DiffieHellmanPath = + string.Concat(sslPath, "/", "dhparam.pem"); + } + } + Build(model); + } + + public void BuildForUpdater() + { + var model = new TemplateModel(_context); + Build(model); + } + + private void Build(TemplateModel model) + { + Directory.CreateDirectory("/bitwarden/nginx/"); + Helpers.WriteLine(_context, "Building nginx config."); + if (!_context.Config.GenerateNginxConfig) + { + Helpers.WriteLine(_context, "...skipped"); + return; + } + + var template = Helpers.ReadTemplate("NginxConfig"); + using (var sw = File.CreateText(ConfFile)) + { + sw.WriteLine(template(model)); + } + } + + public class TemplateModel + { + public TemplateModel() { } + + public TemplateModel(Context context) + { + Captcha = context.Config.Captcha; + Ssl = context.Config.Ssl; + EnableKeyConnector = context.Config.EnableKeyConnector; + EnableScim = context.Config.EnableScim; + Domain = context.Config.Domain; + Url = context.Config.Url; + RealIps = context.Config.RealIps; + ContentSecurityPolicy = string.Format(context.Config.NginxHeaderContentSecurityPolicy, Domain); + + if (Ssl) + { + if (context.Config.SslManagedLetsEncrypt) + { + var sslPath = $"/etc/letsencrypt/live/{Domain}"; + CertificatePath = CaPath = string.Concat(sslPath, "/", "fullchain.pem"); + KeyPath = string.Concat(sslPath, "/", "privkey.pem"); + DiffieHellmanPath = string.Concat(sslPath, "/", "dhparam.pem"); + } + else + { + CertificatePath = context.Config.SslCertificatePath; + KeyPath = context.Config.SslKeyPath; + CaPath = context.Config.SslCaPath; + DiffieHellmanPath = context.Config.SslDiffieHellmanPath; + } + } + + if (!string.IsNullOrWhiteSpace(context.Config.SslCiphersuites)) + { + SslCiphers = context.Config.SslCiphersuites; } else { - CertificatePath = context.Config.SslCertificatePath; - KeyPath = context.Config.SslKeyPath; - CaPath = context.Config.SslCaPath; - DiffieHellmanPath = context.Config.SslDiffieHellmanPath; + SslCiphers = "ECDHE-ECDSA-AES256-GCM-SHA384:ECDHE-RSA-AES256-GCM-SHA384:" + + "ECDHE-ECDSA-CHACHA20-POLY1305:ECDHE-RSA-CHACHA20-POLY1305:ECDHE-ECDSA-AES128-GCM-SHA256:" + + "ECDHE-RSA-AES128-GCM-SHA256:ECDHE-ECDSA-AES256-SHA384:ECDHE-RSA-AES256-SHA384:" + + "ECDHE-ECDSA-AES128-SHA256:ECDHE-RSA-AES128-SHA256"; + } + + if (!string.IsNullOrWhiteSpace(context.Config.SslVersions)) + { + SslProtocols = context.Config.SslVersions; + } + else + { + SslProtocols = "TLSv1.2"; } } - if (!string.IsNullOrWhiteSpace(context.Config.SslCiphersuites)) - { - SslCiphers = context.Config.SslCiphersuites; - } - else - { - SslCiphers = "ECDHE-ECDSA-AES256-GCM-SHA384:ECDHE-RSA-AES256-GCM-SHA384:" + - "ECDHE-ECDSA-CHACHA20-POLY1305:ECDHE-RSA-CHACHA20-POLY1305:ECDHE-ECDSA-AES128-GCM-SHA256:" + - "ECDHE-RSA-AES128-GCM-SHA256:ECDHE-ECDSA-AES256-SHA384:ECDHE-RSA-AES256-SHA384:" + - "ECDHE-ECDSA-AES128-SHA256:ECDHE-RSA-AES128-SHA256"; - } - - if (!string.IsNullOrWhiteSpace(context.Config.SslVersions)) - { - SslProtocols = context.Config.SslVersions; - } - else - { - SslProtocols = "TLSv1.2"; - } + public bool Captcha { get; set; } + public bool Ssl { get; set; } + public bool EnableKeyConnector { get; set; } + public bool EnableScim { get; set; } + public string Domain { get; set; } + public string Url { get; set; } + public string CertificatePath { get; set; } + public string KeyPath { get; set; } + public string CaPath { get; set; } + public string DiffieHellmanPath { get; set; } + public string SslCiphers { get; set; } + public string SslProtocols { get; set; } + public string ContentSecurityPolicy { get; set; } + public List RealIps { get; set; } } - - public bool Captcha { get; set; } - public bool Ssl { get; set; } - public bool EnableKeyConnector { get; set; } - public bool EnableScim { get; set; } - public string Domain { get; set; } - public string Url { get; set; } - public string CertificatePath { get; set; } - public string KeyPath { get; set; } - public string CaPath { get; set; } - public string DiffieHellmanPath { get; set; } - public string SslCiphers { get; set; } - public string SslProtocols { get; set; } - public string ContentSecurityPolicy { get; set; } - public List RealIps { get; set; } } } diff --git a/util/Setup/Program.cs b/util/Setup/Program.cs index 507b329b2b..8eb6474fd6 100644 --- a/util/Setup/Program.cs +++ b/util/Setup/Program.cs @@ -3,327 +3,328 @@ using System.Globalization; using System.Net.Http.Json; using Bit.Migrator; -namespace Bit.Setup; - -public class Program +namespace Bit.Setup { - private static Context _context; - - public static void Main(string[] args) + public class Program { - CultureInfo.DefaultThreadCurrentCulture = new CultureInfo("en-US"); + private static Context _context; - _context = new Context + public static void Main(string[] args) { - Args = args - }; - ParseParameters(); + CultureInfo.DefaultThreadCurrentCulture = new CultureInfo("en-US"); - if (_context.Parameters.ContainsKey("q")) - { - _context.Quiet = _context.Parameters["q"] == "true" || _context.Parameters["q"] == "1"; - } - if (_context.Parameters.ContainsKey("os")) - { - _context.HostOS = _context.Parameters["os"]; - } - if (_context.Parameters.ContainsKey("corev")) - { - _context.CoreVersion = _context.Parameters["corev"]; - } - if (_context.Parameters.ContainsKey("webv")) - { - _context.WebVersion = _context.Parameters["webv"]; - } - if (_context.Parameters.ContainsKey("keyconnectorv")) - { - _context.KeyConnectorVersion = _context.Parameters["keyconnectorv"]; - } - if (_context.Parameters.ContainsKey("stub")) - { - _context.Stub = _context.Parameters["stub"] == "true" || - _context.Parameters["stub"] == "1"; - } - - Helpers.WriteLine(_context); - - if (_context.Parameters.ContainsKey("install")) - { - Install(); - } - else if (_context.Parameters.ContainsKey("update")) - { - Update(); - } - else if (_context.Parameters.ContainsKey("printenv")) - { - PrintEnvironment(); - } - else - { - Helpers.WriteLine(_context, "No top-level command detected. Exiting..."); - } - } - - private static void Install() - { - if (_context.Parameters.ContainsKey("letsencrypt")) - { - _context.Config.SslManagedLetsEncrypt = - _context.Parameters["letsencrypt"].ToLowerInvariant() == "y"; - } - if (_context.Parameters.ContainsKey("domain")) - { - _context.Install.Domain = _context.Parameters["domain"].ToLowerInvariant(); - } - if (_context.Parameters.ContainsKey("dbname")) - { - _context.Install.Database = _context.Parameters["dbname"]; - } - - if (_context.Stub) - { - _context.Install.InstallationId = Guid.Empty; - _context.Install.InstallationKey = "SECRET_INSTALLATION_KEY"; - } - else if (!ValidateInstallation()) - { - return; - } - - var certBuilder = new CertBuilder(_context); - certBuilder.BuildForInstall(); - - // Set the URL - _context.Config.Url = string.Format("http{0}://{1}", - _context.Config.Ssl ? "s" : string.Empty, _context.Install.Domain); - - var nginxBuilder = new NginxConfigBuilder(_context); - nginxBuilder.BuildForInstaller(); - - var environmentFileBuilder = new EnvironmentFileBuilder(_context); - environmentFileBuilder.BuildForInstaller(); - - var appIdBuilder = new AppIdBuilder(_context); - appIdBuilder.Build(); - - var dockerComposeBuilder = new DockerComposeBuilder(_context); - dockerComposeBuilder.BuildForInstaller(); - - _context.SaveConfiguration(); - - Console.WriteLine("\nInstallation complete"); - - Console.WriteLine("\nIf you need to make additional configuration changes, you can modify\n" + - "the settings in `{0}` and then run:\n{1}", - _context.HostOS == "win" ? ".\\bwdata\\config.yml" : "./bwdata/config.yml", - _context.HostOS == "win" ? "`.\\bitwarden.ps1 -rebuild` or `.\\bitwarden.ps1 -update`" : - "`./bitwarden.sh rebuild` or `./bitwarden.sh update`"); - - Console.WriteLine("\nNext steps, run:"); - if (_context.HostOS == "win") - { - Console.WriteLine("`.\\bitwarden.ps1 -start`"); - } - else - { - Console.WriteLine("`./bitwarden.sh start`"); - } - Console.WriteLine(string.Empty); - } - - private static void Update() - { - // This portion of code checks for multiple certs in the Identity.pfx PKCS12 bag. If found, it generates - // a new cert and bag to replace the old Identity.pfx. This fixes an issue that came up as a result of - // moving the project to .NET 5. - _context.Install.IdentityCertPassword = Helpers.GetValueFromEnvFile("global", "globalSettings__identityServer__certificatePassword"); - var certCountString = Helpers.Exec("openssl pkcs12 -nokeys -info -in /bitwarden/identity/identity.pfx " + - $"-passin pass:{_context.Install.IdentityCertPassword} 2> /dev/null | grep -c \"\\-----BEGIN CERTIFICATE----\"", true); - if (int.TryParse(certCountString, out var certCount) && certCount > 1) - { - // Extract key from identity.pfx - Helpers.Exec("openssl pkcs12 -in /bitwarden/identity/identity.pfx -nocerts -nodes -out identity.key " + - $"-passin pass:{_context.Install.IdentityCertPassword} > /dev/null 2>&1"); - // Extract certificate from identity.pfx - Helpers.Exec("openssl pkcs12 -in /bitwarden/identity/identity.pfx -clcerts -nokeys -out identity.crt " + - $"-passin pass:{_context.Install.IdentityCertPassword} > /dev/null 2>&1"); - // Create new PKCS12 bag with certificate and key - Helpers.Exec("openssl pkcs12 -export -out /bitwarden/identity/identity.pfx -inkey identity.key " + - $"-in identity.crt -passout pass:{_context.Install.IdentityCertPassword} > /dev/null 2>&1"); - } - - if (_context.Parameters.ContainsKey("db")) - { - MigrateDatabase(); - } - else - { - RebuildConfigs(); - } - } - - private static void PrintEnvironment() - { - _context.LoadConfiguration(); - if (!_context.PrintToScreen()) - { - return; - } - Console.WriteLine("\nBitwarden is up and running!"); - Console.WriteLine("==================================================="); - Console.WriteLine("\nvisit {0}", _context.Config.Url); - Console.Write("to update, run "); - if (_context.HostOS == "win") - { - Console.Write("`.\\bitwarden.ps1 -updateself` and then `.\\bitwarden.ps1 -update`"); - } - else - { - Console.Write("`./bitwarden.sh updateself` and then `./bitwarden.sh update`"); - } - Console.WriteLine("\n"); - } - - private static void MigrateDatabase(int attempt = 1) - { - try - { - Helpers.WriteLine(_context, "Migrating database."); - var vaultConnectionString = Helpers.GetValueFromEnvFile("global", - "globalSettings__sqlServer__connectionString"); - var migrator = new DbMigrator(vaultConnectionString, null); - var success = migrator.MigrateMsSqlDatabase(false); - if (success) + _context = new Context { - Helpers.WriteLine(_context, "Migration successful."); + Args = args + }; + ParseParameters(); + + if (_context.Parameters.ContainsKey("q")) + { + _context.Quiet = _context.Parameters["q"] == "true" || _context.Parameters["q"] == "1"; + } + if (_context.Parameters.ContainsKey("os")) + { + _context.HostOS = _context.Parameters["os"]; + } + if (_context.Parameters.ContainsKey("corev")) + { + _context.CoreVersion = _context.Parameters["corev"]; + } + if (_context.Parameters.ContainsKey("webv")) + { + _context.WebVersion = _context.Parameters["webv"]; + } + if (_context.Parameters.ContainsKey("keyconnectorv")) + { + _context.KeyConnectorVersion = _context.Parameters["keyconnectorv"]; + } + if (_context.Parameters.ContainsKey("stub")) + { + _context.Stub = _context.Parameters["stub"] == "true" || + _context.Parameters["stub"] == "1"; + } + + Helpers.WriteLine(_context); + + if (_context.Parameters.ContainsKey("install")) + { + Install(); + } + else if (_context.Parameters.ContainsKey("update")) + { + Update(); + } + else if (_context.Parameters.ContainsKey("printenv")) + { + PrintEnvironment(); } else { - Helpers.WriteLine(_context, "Migration failed."); + Helpers.WriteLine(_context, "No top-level command detected. Exiting..."); } } - catch (SqlException e) + + private static void Install() { - if (e.Message.Contains("Server is in script upgrade mode") && attempt < 10) + if (_context.Parameters.ContainsKey("letsencrypt")) + { + _context.Config.SslManagedLetsEncrypt = + _context.Parameters["letsencrypt"].ToLowerInvariant() == "y"; + } + if (_context.Parameters.ContainsKey("domain")) + { + _context.Install.Domain = _context.Parameters["domain"].ToLowerInvariant(); + } + if (_context.Parameters.ContainsKey("dbname")) + { + _context.Install.Database = _context.Parameters["dbname"]; + } + + if (_context.Stub) + { + _context.Install.InstallationId = Guid.Empty; + _context.Install.InstallationKey = "SECRET_INSTALLATION_KEY"; + } + else if (!ValidateInstallation()) { - var nextAttempt = attempt + 1; - Helpers.WriteLine(_context, "Database is in script upgrade mode. " + - "Trying again (attempt #{0})...", nextAttempt); - System.Threading.Thread.Sleep(20000); - MigrateDatabase(nextAttempt); return; } - throw; - } - } - private static bool ValidateInstallation() - { - var installationId = string.Empty; - var installationKey = string.Empty; + var certBuilder = new CertBuilder(_context); + certBuilder.BuildForInstall(); - if (_context.Parameters.ContainsKey("install-id")) - { - installationId = _context.Parameters["install-id"].ToLowerInvariant(); - } - else - { - installationId = Helpers.ReadInput("Enter your installation id (get at https://bitwarden.com/host)"); - } + // Set the URL + _context.Config.Url = string.Format("http{0}://{1}", + _context.Config.Ssl ? "s" : string.Empty, _context.Install.Domain); - if (!Guid.TryParse(installationId.Trim(), out var installationidGuid)) - { - Console.WriteLine("Invalid installation id."); - return false; - } + var nginxBuilder = new NginxConfigBuilder(_context); + nginxBuilder.BuildForInstaller(); - if (_context.Parameters.ContainsKey("install-key")) - { - installationKey = _context.Parameters["install-key"]; - } - else - { - installationKey = Helpers.ReadInput("Enter your installation key"); - } + var environmentFileBuilder = new EnvironmentFileBuilder(_context); + environmentFileBuilder.BuildForInstaller(); - _context.Install.InstallationId = installationidGuid; - _context.Install.InstallationKey = installationKey; + var appIdBuilder = new AppIdBuilder(_context); + appIdBuilder.Build(); - try - { - var response = new HttpClient().GetAsync("https://api.bitwarden.com/installations/" + - _context.Install.InstallationId).GetAwaiter().GetResult(); + var dockerComposeBuilder = new DockerComposeBuilder(_context); + dockerComposeBuilder.BuildForInstaller(); - if (!response.IsSuccessStatusCode) + _context.SaveConfiguration(); + + Console.WriteLine("\nInstallation complete"); + + Console.WriteLine("\nIf you need to make additional configuration changes, you can modify\n" + + "the settings in `{0}` and then run:\n{1}", + _context.HostOS == "win" ? ".\\bwdata\\config.yml" : "./bwdata/config.yml", + _context.HostOS == "win" ? "`.\\bitwarden.ps1 -rebuild` or `.\\bitwarden.ps1 -update`" : + "`./bitwarden.sh rebuild` or `./bitwarden.sh update`"); + + Console.WriteLine("\nNext steps, run:"); + if (_context.HostOS == "win") { - if (response.StatusCode == System.Net.HttpStatusCode.NotFound) + Console.WriteLine("`.\\bitwarden.ps1 -start`"); + } + else + { + Console.WriteLine("`./bitwarden.sh start`"); + } + Console.WriteLine(string.Empty); + } + + private static void Update() + { + // This portion of code checks for multiple certs in the Identity.pfx PKCS12 bag. If found, it generates + // a new cert and bag to replace the old Identity.pfx. This fixes an issue that came up as a result of + // moving the project to .NET 5. + _context.Install.IdentityCertPassword = Helpers.GetValueFromEnvFile("global", "globalSettings__identityServer__certificatePassword"); + var certCountString = Helpers.Exec("openssl pkcs12 -nokeys -info -in /bitwarden/identity/identity.pfx " + + $"-passin pass:{_context.Install.IdentityCertPassword} 2> /dev/null | grep -c \"\\-----BEGIN CERTIFICATE----\"", true); + if (int.TryParse(certCountString, out var certCount) && certCount > 1) + { + // Extract key from identity.pfx + Helpers.Exec("openssl pkcs12 -in /bitwarden/identity/identity.pfx -nocerts -nodes -out identity.key " + + $"-passin pass:{_context.Install.IdentityCertPassword} > /dev/null 2>&1"); + // Extract certificate from identity.pfx + Helpers.Exec("openssl pkcs12 -in /bitwarden/identity/identity.pfx -clcerts -nokeys -out identity.crt " + + $"-passin pass:{_context.Install.IdentityCertPassword} > /dev/null 2>&1"); + // Create new PKCS12 bag with certificate and key + Helpers.Exec("openssl pkcs12 -export -out /bitwarden/identity/identity.pfx -inkey identity.key " + + $"-in identity.crt -passout pass:{_context.Install.IdentityCertPassword} > /dev/null 2>&1"); + } + + if (_context.Parameters.ContainsKey("db")) + { + MigrateDatabase(); + } + else + { + RebuildConfigs(); + } + } + + private static void PrintEnvironment() + { + _context.LoadConfiguration(); + if (!_context.PrintToScreen()) + { + return; + } + Console.WriteLine("\nBitwarden is up and running!"); + Console.WriteLine("==================================================="); + Console.WriteLine("\nvisit {0}", _context.Config.Url); + Console.Write("to update, run "); + if (_context.HostOS == "win") + { + Console.Write("`.\\bitwarden.ps1 -updateself` and then `.\\bitwarden.ps1 -update`"); + } + else + { + Console.Write("`./bitwarden.sh updateself` and then `./bitwarden.sh update`"); + } + Console.WriteLine("\n"); + } + + private static void MigrateDatabase(int attempt = 1) + { + try + { + Helpers.WriteLine(_context, "Migrating database."); + var vaultConnectionString = Helpers.GetValueFromEnvFile("global", + "globalSettings__sqlServer__connectionString"); + var migrator = new DbMigrator(vaultConnectionString, null); + var success = migrator.MigrateMsSqlDatabase(false); + if (success) { - Console.WriteLine("Invalid installation id."); + Helpers.WriteLine(_context, "Migration successful."); } else { - Console.WriteLine("Unable to validate installation id."); + Helpers.WriteLine(_context, "Migration failed."); + } + } + catch (SqlException e) + { + if (e.Message.Contains("Server is in script upgrade mode") && attempt < 10) + { + var nextAttempt = attempt + 1; + Helpers.WriteLine(_context, "Database is in script upgrade mode. " + + "Trying again (attempt #{0})...", nextAttempt); + System.Threading.Thread.Sleep(20000); + MigrateDatabase(nextAttempt); + return; + } + throw; + } + } + + private static bool ValidateInstallation() + { + var installationId = string.Empty; + var installationKey = string.Empty; + + if (_context.Parameters.ContainsKey("install-id")) + { + installationId = _context.Parameters["install-id"].ToLowerInvariant(); + } + else + { + installationId = Helpers.ReadInput("Enter your installation id (get at https://bitwarden.com/host)"); + } + + if (!Guid.TryParse(installationId.Trim(), out var installationidGuid)) + { + Console.WriteLine("Invalid installation id."); + return false; + } + + if (_context.Parameters.ContainsKey("install-key")) + { + installationKey = _context.Parameters["install-key"]; + } + else + { + installationKey = Helpers.ReadInput("Enter your installation key"); + } + + _context.Install.InstallationId = installationidGuid; + _context.Install.InstallationKey = installationKey; + + try + { + var response = new HttpClient().GetAsync("https://api.bitwarden.com/installations/" + + _context.Install.InstallationId).GetAwaiter().GetResult(); + + if (!response.IsSuccessStatusCode) + { + if (response.StatusCode == System.Net.HttpStatusCode.NotFound) + { + Console.WriteLine("Invalid installation id."); + } + else + { + Console.WriteLine("Unable to validate installation id."); + } + + return false; } + var result = response.Content.ReadFromJsonAsync().GetAwaiter().GetResult(); + if (!result.Enabled) + { + Console.WriteLine("Installation id has been disabled."); + return false; + } + + return true; + } + catch + { + Console.WriteLine("Unable to validate installation id. Problem contacting Bitwarden server."); return false; } - - var result = response.Content.ReadFromJsonAsync().GetAwaiter().GetResult(); - if (!result.Enabled) - { - Console.WriteLine("Installation id has been disabled."); - return false; - } - - return true; } - catch + + private static void RebuildConfigs() { - Console.WriteLine("Unable to validate installation id. Problem contacting Bitwarden server."); - return false; + _context.LoadConfiguration(); + + var environmentFileBuilder = new EnvironmentFileBuilder(_context); + environmentFileBuilder.BuildForUpdater(); + + var certBuilder = new CertBuilder(_context); + certBuilder.BuildForUpdater(); + + var nginxBuilder = new NginxConfigBuilder(_context); + nginxBuilder.BuildForUpdater(); + + var appIdBuilder = new AppIdBuilder(_context); + appIdBuilder.Build(); + + var dockerComposeBuilder = new DockerComposeBuilder(_context); + dockerComposeBuilder.BuildForUpdater(); + + _context.SaveConfiguration(); + Console.WriteLine(string.Empty); } - } - private static void RebuildConfigs() - { - _context.LoadConfiguration(); - - var environmentFileBuilder = new EnvironmentFileBuilder(_context); - environmentFileBuilder.BuildForUpdater(); - - var certBuilder = new CertBuilder(_context); - certBuilder.BuildForUpdater(); - - var nginxBuilder = new NginxConfigBuilder(_context); - nginxBuilder.BuildForUpdater(); - - var appIdBuilder = new AppIdBuilder(_context); - appIdBuilder.Build(); - - var dockerComposeBuilder = new DockerComposeBuilder(_context); - dockerComposeBuilder.BuildForUpdater(); - - _context.SaveConfiguration(); - Console.WriteLine(string.Empty); - } - - private static void ParseParameters() - { - _context.Parameters = new Dictionary(); - for (var i = 0; i < _context.Args.Length; i = i + 2) + private static void ParseParameters() { - if (!_context.Args[i].StartsWith("-")) + _context.Parameters = new Dictionary(); + for (var i = 0; i < _context.Args.Length; i = i + 2) { - continue; + if (!_context.Args[i].StartsWith("-")) + { + continue; + } + + _context.Parameters.Add(_context.Args[i].Substring(1), _context.Args[i + 1]); } - - _context.Parameters.Add(_context.Args[i].Substring(1), _context.Args[i + 1]); } - } - class InstallationValidationResponseModel - { - public bool Enabled { get; init; } + class InstallationValidationResponseModel + { + public bool Enabled { get; init; } + } } } diff --git a/util/Setup/YamlComments.cs b/util/Setup/YamlComments.cs index 5bdb6fddf9..32b935d502 100644 --- a/util/Setup/YamlComments.cs +++ b/util/Setup/YamlComments.cs @@ -7,101 +7,102 @@ using YamlDotNet.Serialization.TypeInspectors; // ref: https://github.com/aaubry/YamlDotNet/issues/152#issuecomment-349034754 -namespace Bit.Setup; - -public class CommentGatheringTypeInspector : TypeInspectorSkeleton +namespace Bit.Setup { - private readonly ITypeInspector _innerTypeDescriptor; - - public CommentGatheringTypeInspector(ITypeInspector innerTypeDescriptor) + public class CommentGatheringTypeInspector : TypeInspectorSkeleton { - _innerTypeDescriptor = innerTypeDescriptor ?? throw new ArgumentNullException(nameof(innerTypeDescriptor)); - } + private readonly ITypeInspector _innerTypeDescriptor; - public override IEnumerable GetProperties(Type type, object container) - { - return _innerTypeDescriptor.GetProperties(type, container).Select(d => new CommentsPropertyDescriptor(d)); - } - - private sealed class CommentsPropertyDescriptor : IPropertyDescriptor - { - private readonly IPropertyDescriptor _baseDescriptor; - - public CommentsPropertyDescriptor(IPropertyDescriptor baseDescriptor) + public CommentGatheringTypeInspector(ITypeInspector innerTypeDescriptor) { - _baseDescriptor = baseDescriptor; - Name = baseDescriptor.Name; + _innerTypeDescriptor = innerTypeDescriptor ?? throw new ArgumentNullException(nameof(innerTypeDescriptor)); } - public string Name { get; set; } - public int Order { get; set; } - public Type Type => _baseDescriptor.Type; - public bool CanWrite => _baseDescriptor.CanWrite; - - public Type TypeOverride + public override IEnumerable GetProperties(Type type, object container) { - get { return _baseDescriptor.TypeOverride; } - set { _baseDescriptor.TypeOverride = value; } + return _innerTypeDescriptor.GetProperties(type, container).Select(d => new CommentsPropertyDescriptor(d)); } - public ScalarStyle ScalarStyle + private sealed class CommentsPropertyDescriptor : IPropertyDescriptor { - get { return _baseDescriptor.ScalarStyle; } - set { _baseDescriptor.ScalarStyle = value; } - } + private readonly IPropertyDescriptor _baseDescriptor; - public void Write(object target, object value) - { - _baseDescriptor.Write(target, value); - } - - public T GetCustomAttribute() where T : Attribute - { - return _baseDescriptor.GetCustomAttribute(); - } - - public IObjectDescriptor Read(object target) - { - var description = _baseDescriptor.GetCustomAttribute(); - return description != null ? - new CommentsObjectDescriptor(_baseDescriptor.Read(target), description.Description) : - _baseDescriptor.Read(target); - } - } -} - -public sealed class CommentsObjectDescriptor : IObjectDescriptor -{ - private readonly IObjectDescriptor _innerDescriptor; - - public CommentsObjectDescriptor(IObjectDescriptor innerDescriptor, string comment) - { - _innerDescriptor = innerDescriptor; - Comment = comment; - } - - public string Comment { get; private set; } - public object Value => _innerDescriptor.Value; - public Type Type => _innerDescriptor.Type; - public Type StaticType => _innerDescriptor.StaticType; - public ScalarStyle ScalarStyle => _innerDescriptor.ScalarStyle; -} - -public class CommentsObjectGraphVisitor : ChainedObjectGraphVisitor -{ - public CommentsObjectGraphVisitor(IObjectGraphVisitor nextVisitor) - : base(nextVisitor) { } - - public override bool EnterMapping(IPropertyDescriptor key, IObjectDescriptor value, IEmitter context) - { - if (value is CommentsObjectDescriptor commentsDescriptor && commentsDescriptor.Comment != null) - { - context.Emit(new Comment(string.Empty, false)); - foreach (var comment in commentsDescriptor.Comment.Split(Environment.NewLine)) + public CommentsPropertyDescriptor(IPropertyDescriptor baseDescriptor) { - context.Emit(new Comment(comment, false)); + _baseDescriptor = baseDescriptor; + Name = baseDescriptor.Name; + } + + public string Name { get; set; } + public int Order { get; set; } + public Type Type => _baseDescriptor.Type; + public bool CanWrite => _baseDescriptor.CanWrite; + + public Type TypeOverride + { + get { return _baseDescriptor.TypeOverride; } + set { _baseDescriptor.TypeOverride = value; } + } + + public ScalarStyle ScalarStyle + { + get { return _baseDescriptor.ScalarStyle; } + set { _baseDescriptor.ScalarStyle = value; } + } + + public void Write(object target, object value) + { + _baseDescriptor.Write(target, value); + } + + public T GetCustomAttribute() where T : Attribute + { + return _baseDescriptor.GetCustomAttribute(); + } + + public IObjectDescriptor Read(object target) + { + var description = _baseDescriptor.GetCustomAttribute(); + return description != null ? + new CommentsObjectDescriptor(_baseDescriptor.Read(target), description.Description) : + _baseDescriptor.Read(target); } } - return base.EnterMapping(key, value, context); + } + + public sealed class CommentsObjectDescriptor : IObjectDescriptor + { + private readonly IObjectDescriptor _innerDescriptor; + + public CommentsObjectDescriptor(IObjectDescriptor innerDescriptor, string comment) + { + _innerDescriptor = innerDescriptor; + Comment = comment; + } + + public string Comment { get; private set; } + public object Value => _innerDescriptor.Value; + public Type Type => _innerDescriptor.Type; + public Type StaticType => _innerDescriptor.StaticType; + public ScalarStyle ScalarStyle => _innerDescriptor.ScalarStyle; + } + + public class CommentsObjectGraphVisitor : ChainedObjectGraphVisitor + { + public CommentsObjectGraphVisitor(IObjectGraphVisitor nextVisitor) + : base(nextVisitor) { } + + public override bool EnterMapping(IPropertyDescriptor key, IObjectDescriptor value, IEmitter context) + { + if (value is CommentsObjectDescriptor commentsDescriptor && commentsDescriptor.Comment != null) + { + context.Emit(new Comment(string.Empty, false)); + foreach (var comment in commentsDescriptor.Comment.Split(Environment.NewLine)) + { + context.Emit(new Comment(comment, false)); + } + } + return base.EnterMapping(key, value, context); + } } }