diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/IPolicyRequirementQuery.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/IPolicyRequirementQuery.cs
index 7e599dd1c3..2d6bd94fd1 100644
--- a/src/Core/AdminConsole/OrganizationFeatures/Policies/IPolicyRequirementQuery.cs
+++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/IPolicyRequirementQuery.cs
@@ -21,11 +21,11 @@ public interface IPolicyRequirementQuery
/// The policy requirement represents how one or more policy types should be enforced against the users.
///
///
- /// A list of applicable policy requirements in corresponding order of the submitted user IDs.
+ /// A collection of tuples pairing each user ID with their corresponding policy requirement.
///
/// The users that you need to enforce the policy against.
/// The IPolicyRequirement that corresponds to the policy you want to enforce.
- Task> GetAsync(IEnumerable userIds) where T : IPolicyRequirement;
+ Task> GetAsync(IEnumerable userIds) where T : IPolicyRequirement;
///
/// Get all organization user IDs within an organization that are affected by a given policy type.
diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/Implementations/PolicyRequirementQuery.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/Implementations/PolicyRequirementQuery.cs
index 3550e33020..8090691540 100644
--- a/src/Core/AdminConsole/OrganizationFeatures/Policies/Implementations/PolicyRequirementQuery.cs
+++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/Implementations/PolicyRequirementQuery.cs
@@ -11,9 +11,9 @@ public class PolicyRequirementQuery(
: IPolicyRequirementQuery
{
public async Task GetAsync(Guid userId) where T : IPolicyRequirement
- => (await GetAsync([userId])).Single();
+ => (await GetAsync([userId])).Single().Requirement;
- public async Task> GetAsync(IEnumerable userIds) where T : IPolicyRequirement
+ public async Task> GetAsync(IEnumerable userIds) where T : IPolicyRequirement
{
var factory = factories.OfType>().SingleOrDefault();
if (factory is null)
@@ -27,7 +27,7 @@ public class PolicyRequirementQuery(
.Where(factory.Enforce)
.ToLookup(l => l.UserId);
- var policyRequirements = userIdList.Select(u => factory.Create(policyDetailsByUser[u]));
+ var policyRequirements = userIdList.Select(u => (u, factory.Create(policyDetailsByUser[u])));
return policyRequirements;
}
diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyRequirementQueryTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyRequirementQueryTests.cs
index 9bb8941156..823de89757 100644
--- a/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyRequirementQueryTests.cs
+++ b/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyRequirementQueryTests.cs
@@ -83,10 +83,12 @@ public class PolicyRequirementQueryTests
var requirements = (await sut.GetAsync([userIdA, userIdB])).ToList();
Assert.Equal(2, requirements.Count);
- Assert.Contains(policyA, requirements[0].Policies);
- Assert.DoesNotContain(policyB, requirements[0].Policies);
- Assert.Contains(policyB, requirements[1].Policies);
- Assert.DoesNotContain(policyA, requirements[1].Policies);
+ Assert.Equal(userIdA, requirements[0].UserId);
+ Assert.Equal(userIdB, requirements[1].UserId);
+ Assert.Contains(policyA, requirements[0].Requirement.Policies);
+ Assert.DoesNotContain(policyB, requirements[0].Requirement.Policies);
+ Assert.Contains(policyB, requirements[1].Requirement.Policies);
+ Assert.DoesNotContain(policyA, requirements[1].Requirement.Policies);
}
[Theory, BitAutoData]
@@ -107,8 +109,8 @@ public class PolicyRequirementQueryTests
var requirements = (await sut.GetAsync([userIdA, userIdB])).ToList();
- Assert.Contains(policyA, requirements[0].Policies);
- Assert.Empty(requirements[1].Policies);
+ Assert.Contains(policyA, requirements[0].Requirement.Policies);
+ Assert.Empty(requirements[1].Requirement.Policies);
callback.Received()(Arg.Is(policyA));
callback.Received()(Arg.Is(policyB));
}
@@ -134,9 +136,9 @@ public class PolicyRequirementQueryTests
var requirements = (await sut.GetAsync([userIdA, userIdB])).ToList();
Assert.Equal(2, requirements.Count);
- Assert.Contains(enforcedPolicyA, requirements[0].Policies);
- Assert.DoesNotContain(notEnforcedPolicyA, requirements[0].Policies);
- Assert.Contains(enforcedPolicyB, requirements[1].Policies);
+ Assert.Contains(enforcedPolicyA, requirements[0].Requirement.Policies);
+ Assert.DoesNotContain(notEnforcedPolicyA, requirements[0].Requirement.Policies);
+ Assert.Contains(enforcedPolicyB, requirements[1].Requirement.Policies);
}
[Theory, BitAutoData]
@@ -164,8 +166,10 @@ public class PolicyRequirementQueryTests
var requirements = (await sut.GetAsync([userIdA, userIdB])).ToList();
Assert.Equal(2, requirements.Count);
- Assert.Empty(requirements[0].Policies);
- Assert.Empty(requirements[1].Policies);
+ Assert.Equal(userIdA, requirements[0].UserId);
+ Assert.Equal(userIdB, requirements[1].UserId);
+ Assert.Empty(requirements[0].Requirement.Policies);
+ Assert.Empty(requirements[1].Requirement.Policies);
}
[Theory, BitAutoData]
@@ -185,8 +189,10 @@ public class PolicyRequirementQueryTests
var requirements = (await sut.GetAsync([userIdA, userIdB])).ToList();
Assert.Equal(2, requirements.Count);
- Assert.Contains(policyA, requirements[0].Policies);
- Assert.Empty(requirements[1].Policies);
+ Assert.Equal(userIdA, requirements[0].UserId);
+ Assert.Equal(userIdB, requirements[1].UserId);
+ Assert.Contains(policyA, requirements[0].Requirement.Policies);
+ Assert.Empty(requirements[1].Requirement.Policies);
}
[Theory, BitAutoData]