using System.Net.Mail;
using AutoMapper;
using Bit.Core.Models.Data.Organizations;
using Bit.Core.Repositories;
using Bit.Infrastructure.EntityFramework.Models;
using Microsoft.EntityFrameworkCore;
using Microsoft.Extensions.DependencyInjection;

#nullable enable

namespace Bit.Infrastructure.EntityFramework.Repositories;

// NOTE(review): the generic type arguments in this file (base Repository<...>, Task<...>,
// Mapper.Map<...>, IEnumerable<...>) were reconstructed from how the values are used.
// The domain entity is assumed to live in Bit.Core.Entities — confirm the exact types
// against IOrganizationDomainRepository before merging.

/// <summary>
/// Entity Framework implementation of <see cref="IOrganizationDomainRepository"/>.
/// Every method creates its own service scope and <c>DatabaseContext</c>, runs a single
/// query (or delete) against <c>OrganizationDomains</c> / <c>Organizations</c>, and maps
/// EF models to domain entities with AutoMapper where applicable.
/// </summary>
public class OrganizationDomainRepository
    : Repository<Core.Entities.OrganizationDomain, OrganizationDomain, Guid>, IOrganizationDomainRepository
{
    public OrganizationDomainRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper)
        : base(serviceScopeFactory, mapper, (DatabaseContext context) => context.OrganizationDomains)
    {
    }

    /// <summary>
    /// Gets every domain record whose name equals <paramref name="domainName"/> and that has a
    /// non-null <c>VerifiedDate</c>.
    /// </summary>
    /// <param name="domainName">Domain name to match exactly.</param>
    /// <returns>The matching domains; empty when there are none.</returns>
    public async Task<ICollection<Core.Entities.OrganizationDomain>> GetClaimedDomainsByDomainNameAsync(
        string domainName)
    {
        using var scope = ServiceScopeFactory.CreateScope();
        var dbContext = GetDatabaseContext(scope);

        var claimedDomains = await dbContext.OrganizationDomains
            .Where(x => x.DomainName == domainName && x.VerifiedDate != null)
            .AsNoTracking()
            .ToListAsync();

        return Mapper.Map<List<Core.Entities.OrganizationDomain>>(claimedDomains);
    }

    /// <summary>
    /// Gets all domain records (verified or not) that belong to the given organization.
    /// </summary>
    /// <param name="orgId">Organization id to filter on.</param>
    /// <returns>The organization's domains; empty when there are none.</returns>
    public async Task<ICollection<Core.Entities.OrganizationDomain>> GetDomainsByOrganizationIdAsync(Guid orgId)
    {
        using var scope = ServiceScopeFactory.CreateScope();
        var dbContext = GetDatabaseContext(scope);

        var domains = await dbContext.OrganizationDomains
            .Where(x => x.OrganizationId == orgId)
            .AsNoTracking()
            .ToListAsync();

        return Mapper.Map<List<Core.Entities.OrganizationDomain>>(domains);
    }

    /// <summary>
    /// Gets unverified domains whose <c>NextRunDate</c> falls in the 36 hours up to and including
    /// <paramref name="date"/>, excluding records whose <c>JobRunCount</c> equals 3.
    /// </summary>
    /// <param name="date">Inclusive upper bound of the 36-hour window.</param>
    /// <returns>The matching domains; empty when there are none.</returns>
    public async Task<ICollection<Core.Entities.OrganizationDomain>> GetManyByNextRunDateAsync(DateTime date)
    {
        using var scope = ServiceScopeFactory.CreateScope();
        var dbContext = GetDatabaseContext(scope);

        var windowStart = date.AddHours(-36);

        // NOTE(review): 3 presumably is the maximum number of verification attempts — confirm
        // and consider sharing a named constant with the job that increments JobRunCount.
        var pastDomains = await dbContext.OrganizationDomains
            .Where(x => x.NextRunDate >= windowStart
                        && x.NextRunDate <= date
                        && x.VerifiedDate == null
                        && x.JobRunCount != 3)
            // Read-only query: the entities are only mapped and returned, like every other getter here.
            .AsNoTracking()
            .ToListAsync();

        return Mapper.Map<List<Core.Entities.OrganizationDomain>>(pastDomains);
    }

    /// <summary>
    /// Looks up SSO details for the enabled organization that has a domain record matching the
    /// host part of <paramref name="email"/>. The domain does not have to be verified;
    /// <c>VerifiedDate</c> is returned so the caller can decide.
    /// </summary>
    /// <param name="email">Email address whose host is used as the domain name.</param>
    /// <returns>The details, or <c>null</c> when no enabled organization has that domain.</returns>
    /// <exception cref="FormatException"><paramref name="email"/> is not a valid address.</exception>
    /// <exception cref="InvalidOperationException">
    /// More than one organization/domain pair matches (raised by <c>SingleOrDefaultAsync</c>).
    /// </exception>
    public async Task<OrganizationDomainSsoDetailsData?> GetOrganizationDomainSsoDetailsAsync(string email)
    {
        var domainName = new MailAddress(email).Host;

        using var scope = ServiceScopeFactory.CreateScope();
        var dbContext = GetDatabaseContext(scope);

        // No join on SsoConfigs here: SSO availability is computed through o.SsoConfigs.Any(...),
        // and joining would yield one row per SSO config, making SingleOrDefaultAsync throw for an
        // organization with several configs.
        var ssoDetails = await (from o in dbContext.Organizations
                                from od in o.Domains
                                where od.DomainName == domainName && o.Enabled
                                select new OrganizationDomainSsoDetailsData
                                {
                                    OrganizationId = o.Id,
                                    OrganizationName = o.Name,
                                    SsoAvailable = o.SsoConfigs.Any(sc => sc.Enabled),
                                    OrganizationIdentifier = o.Identifier,
                                    VerifiedDate = od.VerifiedDate,
                                    DomainName = od.DomainName
                                })
            .AsNoTracking()
            .SingleOrDefaultAsync();

        return ssoDetails;
    }

    /// <summary>
    /// Gets one entry per (enabled organization, verified domain, enabled SSO config) combination
    /// where the domain matches the host part of <paramref name="email"/>.
    /// </summary>
    /// <param name="email">Email address whose host is used as the domain name.</param>
    /// <returns>The matching details; empty when there are none.</returns>
    /// <exception cref="FormatException"><paramref name="email"/> is not a valid address.</exception>
    public async Task<IEnumerable<VerifiedOrganizationDomainSsoDetail>> GetVerifiedOrganizationDomainSsoDetailsAsync(
        string email)
    {
        var domainName = new MailAddress(email).Host;

        using var scope = ServiceScopeFactory.CreateScope();
        var dbContext = GetDatabaseContext(scope);

        // Inner join: rows without an SSO config were already discarded by the s.Enabled filter,
        // so an outer join followed by that filter only added a nullable dereference.
        return await (from o in dbContext.Organizations
                      from od in o.Domains
                      join s in dbContext.SsoConfigs on o.Id equals s.OrganizationId
                      where od.DomainName == domainName
                            && o.Enabled
                            && s.Enabled
                            && od.VerifiedDate != null
                      select new VerifiedOrganizationDomainSsoDetail(
                          o.Id,
                          o.Name,
                          od.DomainName,
                          o.Identifier))
            .AsNoTracking()
            .ToListAsync();
    }

    /// <summary>
    /// Gets the domain record with the given id, provided it belongs to the given organization.
    /// </summary>
    /// <param name="id">Domain record id.</param>
    /// <param name="orgId">Organization the record must belong to.</param>
    /// <returns>The domain, or <c>null</c> when no record matches both ids.</returns>
    public async Task<Core.Entities.OrganizationDomain?> GetDomainByIdOrganizationIdAsync(Guid id, Guid orgId)
    {
        using var scope = ServiceScopeFactory.CreateScope();
        var dbContext = GetDatabaseContext(scope);

        var domain = await dbContext.OrganizationDomains
            .Where(x => x.Id == id && x.OrganizationId == orgId)
            .AsNoTracking()
            .FirstOrDefaultAsync();

        return Mapper.Map<Core.Entities.OrganizationDomain>(domain);
    }

    /// <summary>
    /// Gets the first domain record of the given organization whose name equals
    /// <paramref name="domainName"/>.
    /// </summary>
    /// <param name="orgId">Organization id to filter on.</param>
    /// <param name="domainName">Domain name to match exactly.</param>
    /// <returns>The domain, or <c>null</c> when there is no match.</returns>
    public async Task<Core.Entities.OrganizationDomain?> GetDomainByOrgIdAndDomainNameAsync(
        Guid orgId,
        string domainName)
    {
        using var scope = ServiceScopeFactory.CreateScope();
        var dbContext = GetDatabaseContext(scope);

        var domain = await dbContext.OrganizationDomains
            .Where(x => x.OrganizationId == orgId && x.DomainName == domainName)
            .AsNoTracking()
            .FirstOrDefaultAsync();

        return Mapper.Map<Core.Entities.OrganizationDomain>(domain);
    }

    /// <summary>
    /// Gets unverified domains whose <c>CreationDate</c> falls on the calendar day four days
    /// before the current UTC date.
    /// </summary>
    /// <returns>The matching domains; empty when there are none.</returns>
    public async Task<ICollection<Core.Entities.OrganizationDomain>> GetExpiredOrganizationDomainsAsync()
    {
        using var scope = ServiceScopeFactory.CreateScope();
        var dbContext = GetDatabaseContext(scope);

        // NOTE(review): this local used to be called "threeDaysOld..." although the filter
        // subtracts four days — confirm which of the two is the intended expiry age.
        var expiredUnverifiedDomains = await dbContext.OrganizationDomains
            .Where(x => x.CreationDate.Date == DateTime.UtcNow.AddDays(-4).Date
                        && x.VerifiedDate == null)
            .AsNoTracking()
            .ToListAsync();

        return Mapper.Map<List<Core.Entities.OrganizationDomain>>(expiredUnverifiedDomains);
    }

    /// <summary>
    /// Deletes every domain record whose <c>LastCheckedDate</c> is earlier than the current UTC
    /// time minus <paramref name="expirationPeriod"/> days.
    /// </summary>
    /// <param name="expirationPeriod">Age threshold in days.</param>
    /// <returns><c>true</c> when saving reported at least one affected row; otherwise <c>false</c>.</returns>
    public async Task<bool> DeleteExpiredAsync(int expirationPeriod)
    {
        using var scope = ServiceScopeFactory.CreateScope();
        var dbContext = GetDatabaseContext(scope);

        // NOTE(review): the filter does not look at VerifiedDate, so verified domains with an old
        // LastCheckedDate are removed as well — confirm that this is intended.
        var expiredDomains = await dbContext.OrganizationDomains
            .Where(x => x.LastCheckedDate < DateTime.UtcNow.AddDays(-expirationPeriod))
            .ToListAsync();

        dbContext.OrganizationDomains.RemoveRange(expiredDomains);

        return await dbContext.SaveChangesAsync() > 0;
    }

    /// <summary>
    /// Gets the verified domains of the given organizations. Only <c>OrganizationId</c> and
    /// <c>DomainName</c> are selected from the database; all other properties of the returned
    /// entities keep their default values.
    /// </summary>
    /// <param name="organizationIds">Organization ids to include.</param>
    /// <returns>The verified domains; empty when there are none.</returns>
    public async Task<IEnumerable<Core.Entities.OrganizationDomain>> GetVerifiedDomainsByOrganizationIdsAsync(
        IEnumerable<Guid> organizationIds)
    {
        using var scope = ServiceScopeFactory.CreateScope();
        var dbContext = GetDatabaseContext(scope);

        var verifiedDomains = await (from d in dbContext.OrganizationDomains
                                     where organizationIds.Contains(d.OrganizationId)
                                           && d.VerifiedDate != null
                                     select new OrganizationDomain
                                     {
                                         OrganizationId = d.OrganizationId,
                                         DomainName = d.DomainName
                                     })
            .AsNoTracking()
            .ToListAsync();

        return Mapper.Map<List<Core.Entities.OrganizationDomain>>(verifiedDomains);
    }
}