diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/IPolicyRequirementQuery.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/IPolicyRequirementQuery.cs index 7e599dd1c3..2d6bd94fd1 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/Policies/IPolicyRequirementQuery.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/IPolicyRequirementQuery.cs @@ -21,11 +21,11 @@ public interface IPolicyRequirementQuery /// The policy requirement represents how one or more policy types should be enforced against the users. /// /// - /// A list of applicable policy requirements in corresponding order of the submitted user IDs. + /// A collection of tuples pairing each user ID with their corresponding policy requirement. /// /// The users that you need to enforce the policy against. /// The IPolicyRequirement that corresponds to the policy you want to enforce. - Task> GetAsync(IEnumerable userIds) where T : IPolicyRequirement; + Task> GetAsync(IEnumerable userIds) where T : IPolicyRequirement; /// /// Get all organization user IDs within an organization that are affected by a given policy type. diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/Implementations/PolicyRequirementQuery.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/Implementations/PolicyRequirementQuery.cs index 3550e33020..8090691540 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/Policies/Implementations/PolicyRequirementQuery.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/Implementations/PolicyRequirementQuery.cs @@ -11,9 +11,9 @@ public class PolicyRequirementQuery( : IPolicyRequirementQuery { public async Task GetAsync(Guid userId) where T : IPolicyRequirement - => (await GetAsync([userId])).Single(); + => (await GetAsync([userId])).Single().Requirement; - public async Task> GetAsync(IEnumerable userIds) where T : IPolicyRequirement + public async Task> GetAsync(IEnumerable userIds) where T : IPolicyRequirement { var factory = factories.OfType>().SingleOrDefault(); if (factory is null) @@ -27,7 +27,7 @@ public class PolicyRequirementQuery( .Where(factory.Enforce) .ToLookup(l => l.UserId); - var policyRequirements = userIdList.Select(u => factory.Create(policyDetailsByUser[u])); + var policyRequirements = userIdList.Select(u => (u, factory.Create(policyDetailsByUser[u]))); return policyRequirements; } diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyRequirementQueryTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyRequirementQueryTests.cs index 9bb8941156..823de89757 100644 --- a/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyRequirementQueryTests.cs +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyRequirementQueryTests.cs @@ -83,10 +83,12 @@ public class PolicyRequirementQueryTests var requirements = (await sut.GetAsync([userIdA, userIdB])).ToList(); Assert.Equal(2, requirements.Count); - Assert.Contains(policyA, requirements[0].Policies); - Assert.DoesNotContain(policyB, requirements[0].Policies); - Assert.Contains(policyB, requirements[1].Policies); - Assert.DoesNotContain(policyA, requirements[1].Policies); + Assert.Equal(userIdA, requirements[0].UserId); + Assert.Equal(userIdB, requirements[1].UserId); + Assert.Contains(policyA, requirements[0].Requirement.Policies); + Assert.DoesNotContain(policyB, requirements[0].Requirement.Policies); + Assert.Contains(policyB, requirements[1].Requirement.Policies); + Assert.DoesNotContain(policyA, requirements[1].Requirement.Policies); } [Theory, BitAutoData] @@ -107,8 +109,8 @@ public class PolicyRequirementQueryTests var requirements = (await sut.GetAsync([userIdA, userIdB])).ToList(); - Assert.Contains(policyA, requirements[0].Policies); - Assert.Empty(requirements[1].Policies); + Assert.Contains(policyA, requirements[0].Requirement.Policies); + Assert.Empty(requirements[1].Requirement.Policies); callback.Received()(Arg.Is(policyA)); callback.Received()(Arg.Is(policyB)); } @@ -134,9 +136,9 @@ public class PolicyRequirementQueryTests var requirements = (await sut.GetAsync([userIdA, userIdB])).ToList(); Assert.Equal(2, requirements.Count); - Assert.Contains(enforcedPolicyA, requirements[0].Policies); - Assert.DoesNotContain(notEnforcedPolicyA, requirements[0].Policies); - Assert.Contains(enforcedPolicyB, requirements[1].Policies); + Assert.Contains(enforcedPolicyA, requirements[0].Requirement.Policies); + Assert.DoesNotContain(notEnforcedPolicyA, requirements[0].Requirement.Policies); + Assert.Contains(enforcedPolicyB, requirements[1].Requirement.Policies); } [Theory, BitAutoData] @@ -164,8 +166,10 @@ public class PolicyRequirementQueryTests var requirements = (await sut.GetAsync([userIdA, userIdB])).ToList(); Assert.Equal(2, requirements.Count); - Assert.Empty(requirements[0].Policies); - Assert.Empty(requirements[1].Policies); + Assert.Equal(userIdA, requirements[0].UserId); + Assert.Equal(userIdB, requirements[1].UserId); + Assert.Empty(requirements[0].Requirement.Policies); + Assert.Empty(requirements[1].Requirement.Policies); } [Theory, BitAutoData] @@ -185,8 +189,10 @@ public class PolicyRequirementQueryTests var requirements = (await sut.GetAsync([userIdA, userIdB])).ToList(); Assert.Equal(2, requirements.Count); - Assert.Contains(policyA, requirements[0].Policies); - Assert.Empty(requirements[1].Policies); + Assert.Equal(userIdA, requirements[0].UserId); + Assert.Equal(userIdB, requirements[1].UserId); + Assert.Contains(policyA, requirements[0].Requirement.Policies); + Assert.Empty(requirements[1].Requirement.Policies); } [Theory, BitAutoData]