using System.Net.Mail;
using AutoMapper;
using Bit.Core.Enums;
using Bit.Core.Models.Data.Organizations;
using Bit.Core.Repositories;
using Bit.Infrastructure.EntityFramework.Models;
using Microsoft.EntityFrameworkCore;
using Microsoft.Extensions.DependencyInjection;

namespace Bit.Infrastructure.EntityFramework.Repositories;

/// <summary>
/// Entity Framework implementation of <see cref="IOrganizationDomainRepository"/>.
/// Each method creates its own service scope and database context; read queries run with
/// <c>AsNoTracking</c> and results are mapped to core entities through AutoMapper.
/// </summary>
public class OrganizationDomainRepository
    : Repository<Core.Entities.OrganizationDomain, Models.OrganizationDomain, Guid>, IOrganizationDomainRepository
{
    // Domains whose JobRunCount equals this value are excluded from the "due for a run" query.
    private const int MaxJobRunCount = 3;

    // An unverified domain whose NextRunDate lies more than this many hours before the requested
    // run time is treated as a run that was missed/failed and is returned again.
    private const int MissedRunThresholdHours = 36;

    // Unverified domains created at least this many whole days ago are reported as expired.
    private const int UnverifiedExpirationDays = 4;

    public OrganizationDomainRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper)
        : base(serviceScopeFactory, mapper, (DatabaseContext context) => context.OrganizationDomains)
    {
    }

    /// <summary>
    /// Returns every verified (claimed) domain record whose name equals <paramref name="domainName"/>.
    /// </summary>
    /// <param name="domainName">Domain name to match exactly (comparison follows the database collation).</param>
    /// <returns>All matching records with a non-null <c>VerifiedDate</c>; empty when none match.</returns>
    public async Task<ICollection<Core.Entities.OrganizationDomain>> GetClaimedDomainsByDomainNameAsync(
        string domainName)
    {
        using var scope = ServiceScopeFactory.CreateScope();
        var dbContext = GetDatabaseContext(scope);

        var claimedDomains = await dbContext.OrganizationDomains
            .Where(x => x.DomainName == domainName && x.VerifiedDate != null)
            .AsNoTracking()
            .ToListAsync();

        return Mapper.Map<List<Core.Entities.OrganizationDomain>>(claimedDomains);
    }

    /// <summary>
    /// Returns all domain records (verified or not) belonging to the organization <paramref name="orgId"/>.
    /// </summary>
    public async Task<ICollection<Core.Entities.OrganizationDomain>> GetDomainsByOrganizationIdAsync(Guid orgId)
    {
        using var scope = ServiceScopeFactory.CreateScope();
        var dbContext = GetDatabaseContext(scope);

        var domains = await dbContext.OrganizationDomains
            .Where(x => x.OrganizationId == orgId)
            .AsNoTracking()
            .ToListAsync();

        return Mapper.Map<List<Core.Entities.OrganizationDomain>>(domains);
    }

    /// <summary>
    /// Returns unverified domains that should be processed for the run at <paramref name="date"/>.
    /// A domain qualifies when its <c>VerifiedDate</c> is null, its <c>JobRunCount</c> is not 3, and either
    /// its <c>NextRunDate</c> falls in the same calendar hour as <paramref name="date"/>, or its
    /// <c>NextRunDate</c> is more than 36 hours before <paramref name="date"/> (a run that was missed).
    /// </summary>
    /// <param name="date">The run time being processed.</param>
    public async Task<ICollection<Core.Entities.OrganizationDomain>> GetManyByNextRunDateAsync(DateTime date)
    {
        using var scope = ServiceScopeFactory.CreateScope();
        var dbContext = GetDatabaseContext(scope);

        // (date - NextRunDate).TotalHours > 36  <=>  NextRunDate < date - 36h.
        // Computing the cutoff up front keeps the whole filter translatable to SQL, instead of the
        // previous AsEnumerable() call that pulled the entire table into memory.
        var missedRunCutoff = date.AddHours(-MissedRunThresholdHours);

        // The "same hour" and "missed run" sets are disjoint, so one OR-ed query returns exactly the
        // rows the previous two-query Union did, without duplicates.
        var domains = await dbContext.OrganizationDomains
            .Where(x => x.VerifiedDate == null
                        && x.JobRunCount != MaxJobRunCount
                        && ((x.NextRunDate.Year == date.Year
                             && x.NextRunDate.Month == date.Month
                             && x.NextRunDate.Day == date.Day
                             && x.NextRunDate.Hour == date.Hour)
                            || x.NextRunDate < missedRunCutoff))
            .AsNoTracking()
            .ToListAsync();

        return Mapper.Map<List<Core.Entities.OrganizationDomain>>(domains);
    }

    /// <summary>
    /// Looks up SSO details for the domain part of <paramref name="email"/>.
    /// Only enabled organizations that own a domain record with that name AND have a
    /// <see cref="PolicyType.RequireSso"/> policy row are considered; <c>SsoRequired</c> reflects that
    /// policy's <c>Enabled</c> flag. Verified and unverified domain records are both eligible.
    /// </summary>
    /// <param name="email">An email address; only its host part is used.</param>
    /// <returns>The single matching details row, or <c>null</c> when none match.</returns>
    /// <exception cref="FormatException"><paramref name="email"/> is not a valid address.</exception>
    /// <exception cref="InvalidOperationException">More than one row matches (SingleOrDefault semantics).</exception>
    public async Task<OrganizationDomainSsoDetailsData> GetOrganizationDomainSsoDetailsAsync(string email)
    {
        var domainName = new MailAddress(email).Host;

        using var scope = ServiceScopeFactory.CreateScope();
        var dbContext = GetDatabaseContext(scope);

        // NOTE(review): if several enabled organizations hold a record for the same domain name (e.g. an
        // unverified duplicate) and each has a RequireSso policy, SingleOrDefaultAsync throws.
        var ssoDetails = await dbContext.Organizations
            .Join(dbContext.OrganizationDomains,
                o => o.Id,
                od => od.OrganizationId,
                (organization, domain) => new { Organization = organization, Domain = domain })
            .Join(dbContext.Policies,
                od => od.Organization.Id,
                p => p.OrganizationId,
                (orgDomain, policy) => new { orgDomain.Organization, orgDomain.Domain, Policy = policy })
            .Where(x => x.Domain.DomainName == domainName
                        && x.Organization.Enabled == true
                        && x.Policy.Type == PolicyType.RequireSso)
            .Select(x => new OrganizationDomainSsoDetailsData
            {
                OrganizationId = x.Organization.Id,
                OrganizationName = x.Organization.Name,
                SsoAvailable = x.Organization.UseSso,
                OrganizationIdentifier = x.Organization.Identifier,
                SsoRequired = x.Policy.Enabled,
                VerifiedDate = x.Domain.VerifiedDate,
                PolicyType = x.Policy.Type,
                DomainName = x.Domain.DomainName,
                OrganizationEnabled = x.Organization.Enabled
            })
            .AsNoTracking()
            .SingleOrDefaultAsync();

        return ssoDetails;
    }

    /// <summary>
    /// Returns the first domain record of organization <paramref name="orgId"/> named
    /// <paramref name="domainName"/>, or <c>null</c> when there is none.
    /// </summary>
    public async Task<Core.Entities.OrganizationDomain> GetDomainByOrgIdAndDomainNameAsync(Guid orgId, string domainName)
    {
        using var scope = ServiceScopeFactory.CreateScope();
        var dbContext = GetDatabaseContext(scope);

        var domain = await dbContext.OrganizationDomains
            .Where(x => x.OrganizationId == orgId && x.DomainName == domainName)
            .AsNoTracking()
            .FirstOrDefaultAsync();

        return Mapper.Map<Core.Entities.OrganizationDomain>(domain);
    }

    /// <summary>
    /// Returns unverified domains created at least four whole days (96 hours) ago.
    /// </summary>
    public async Task<ICollection<Core.Entities.OrganizationDomain>> GetExpiredOrganizationDomainsAsync()
    {
        using var scope = ServiceScopeFactory.CreateScope();
        var dbContext = GetDatabaseContext(scope);

        // TimeSpan.Days >= 4  <=>  (now - CreationDate) >= 4 days  <=>  CreationDate <= now - 4 days.
        // Filtering in SQL avoids loading the whole table (previous AsEnumerable()) and lets the
        // method actually await the database instead of running synchronously.
        // NOTE(review): an earlier comment here said "after 72 hours"; the check has always been 4 days.
        var expirationCutoff = DateTime.UtcNow.AddDays(-UnverifiedExpirationDays);

        var domains = await dbContext.OrganizationDomains
            .Where(x => x.VerifiedDate == null && x.CreationDate <= expirationCutoff)
            .AsNoTracking()
            .ToListAsync();

        return Mapper.Map<List<Core.Entities.OrganizationDomain>>(domains);
    }

    /// <summary>
    /// Deletes unverified domains whose <c>LastCheckedDate</c> is older than
    /// <paramref name="expirationPeriod"/> days.
    /// </summary>
    /// <param name="expirationPeriod">Age in days after which an unverified domain is removed.</param>
    /// <returns><c>true</c> if at least one row was deleted.</returns>
    public async Task<bool> DeleteExpiredAsync(int expirationPeriod)
    {
        using var scope = ServiceScopeFactory.CreateScope();
        var dbContext = GetDatabaseContext(scope);

        var cutoff = DateTime.UtcNow.AddDays(-expirationPeriod);

        // Verified domains are excluded: they are no longer picked up by the verification job
        // (see the VerifiedDate == null filter in GetManyByNextRunDateAsync), so their LastCheckedDate
        // stops advancing and they would otherwise be deleted once it aged past the cutoff.
        var expiredDomains = await dbContext.OrganizationDomains
            .Where(x => x.VerifiedDate == null && x.LastCheckedDate < cutoff)
            .ToListAsync();

        dbContext.OrganizationDomains.RemoveRange(expiredDomains);
        return await dbContext.SaveChangesAsync() > 0;
    }
}