From 79fe4cf0151728d387b324e8aae033be848b663d Mon Sep 17 00:00:00 2001 From: Rui Tome Date: Mon, 12 Jan 2026 14:53:11 +0000 Subject: [PATCH] Enhance InitPendingOrganizationCommand with policy validation and feature flag support Updated the ValidatePoliciesAsync method to enforce the Automatic User Confirmation Policy when the feature flag is enabled. Added new unit tests to cover scenarios for automatic user confirmation and single organization policy violations, ensuring comprehensive validation during organization initialization. This improves error handling and maintains compliance with organizational policies. --- .../InitPendingOrganizationCommand.cs | 15 +- .../InitPendingOrganizationCommandTests.cs | 213 ++++++++++++++++-- 2 files changed, 203 insertions(+), 25 deletions(-) diff --git a/src/Core/AdminConsole/OrganizationFeatures/Organizations/InitPendingOrganizationCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/Organizations/InitPendingOrganizationCommand.cs index 073fb1e22d..aa6b8240f5 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/Organizations/InitPendingOrganizationCommand.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/Organizations/InitPendingOrganizationCommand.cs @@ -268,8 +268,19 @@ public class InitPendingOrganizationCommand : IInitPendingOrganizationCommand private async Task ValidatePoliciesAsync(User user, Guid organizationId, Organization org, OrganizationUser orgUser) { - var autoConfirmReq = await _policyRequirementQuery.GetAsync(user.Id); - if (autoConfirmReq.CannotCreateNewOrganization() || autoConfirmReq.IsEnabledForOrganizationsOtherThan(organizationId)) + // Enforce Automatic User Confirmation Policy (when feature flag is enabled) + if (_featureService.IsEnabled(FeatureFlagKeys.AutomaticConfirmUsers)) + { + var autoConfirmReq = await _policyRequirementQuery.GetAsync(user.Id); + if (autoConfirmReq.CannotCreateNewOrganization()) + { + return new SingleOrgPolicyViolationError(); + } + } + + // Enforce Single Organization Policy + var anySingleOrgPolicies = await _policyService.AnyPoliciesApplicableToUserAsync(user.Id, PolicyType.SingleOrg); + if (anySingleOrgPolicies) { return new SingleOrgPolicyViolationError(); } diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/Organizations/InitPendingOrganizationCommandTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/Organizations/InitPendingOrganizationCommandTests.cs index 8ab467fe6a..f295cb8431 100644 --- a/test/Core.Test/AdminConsole/OrganizationFeatures/Organizations/InitPendingOrganizationCommandTests.cs +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/Organizations/InitPendingOrganizationCommandTests.cs @@ -1,8 +1,10 @@ using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Enums; using Bit.Core.AdminConsole.Models.Data.Organizations.Policies; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers; using Bit.Core.AdminConsole.OrganizationFeatures.Policies; using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements; +using Bit.Core.AdminConsole.Services; using Bit.Core.Auth.Models.Business.Tokenables; using Bit.Core.Auth.UserFeatures.TwoFactorAuth.Interfaces; using Bit.Core.Billing.Enums; @@ -198,9 +200,11 @@ public class InitPendingOrganizationCommandTests var autoConfirmReq = new AutomaticUserConfirmationPolicyRequirement(new List()); var twoFactorReq = new RequireTwoFactorPolicyRequirement(new List()); sutProvider.GetDependency() - .GetAsync(user.Id).Returns(autoConfirmReq); + .GetAsync(user.Id) + .Returns(autoConfirmReq); sutProvider.GetDependency() - .GetAsync(user.Id).Returns(twoFactorReq); + .GetAsync(user.Id) + .Returns(twoFactorReq); // Act var result = await sutProvider.Sut.InitPendingOrganizationVNextAsync( @@ -215,19 +219,46 @@ public class InitPendingOrganizationCommandTests Assert.Equal(user.Id, orgUser.UserId); Assert.Equal(userKey, orgUser.Key); Assert.Null(orgUser.Email); - await sutProvider.GetDependency().Received().UpdateAsync(org); - await sutProvider.GetDependency().Received().ReplaceAsync(orgUser); - await sutProvider.GetDependency().Received() + + await sutProvider.GetDependency() + .Received(1) + .UpdateAsync(org); + await sutProvider.GetDependency() + .Received(1) + .ReplaceAsync(orgUser); + await sutProvider.GetDependency() + .Received(1) .LogOrganizationUserEventAsync(orgUser, EventType.OrganizationUser_Confirmed); } + [Theory, BitAutoData] + public async Task InitPendingOrganizationVNextAsync_WithNullOrgUser_ReturnsOrganizationUserNotFoundError( + User user, Guid orgId, Guid orgUserId, string publicKey, string privateKey, string userKey, + SutProvider sutProvider) + { + // Arrange + sutProvider.GetDependency() + .GetByIdAsync(orgUserId) + .Returns((OrganizationUser)null); + + // Act + var result = await sutProvider.Sut.InitPendingOrganizationVNextAsync( + user, orgId, orgUserId, publicKey, privateKey, "", "token", userKey); + + // Assert + Assert.True(result.IsError); + Assert.IsType(result.AsError); + } + [Theory, BitAutoData] public async Task InitPendingOrganizationVNextAsync_WithInvalidToken_ReturnsInvalidTokenError( User user, Guid orgId, Guid orgUserId, string publicKey, string privateKey, string userKey, SutProvider sutProvider, OrganizationUser orgUser) { // Arrange - sutProvider.GetDependency().GetByIdAsync(orgUserId).Returns(orgUser); + sutProvider.GetDependency() + .GetByIdAsync(orgUserId) + .Returns(orgUser); // Act var result = await sutProvider.Sut.InitPendingOrganizationVNextAsync( @@ -249,7 +280,9 @@ public class InitPendingOrganizationCommandTests var token = CreateToken(orgUser, orgUserId, sutProvider); org.Enabled = true; - sutProvider.GetDependency().GetByIdAsync(orgId).Returns(org); + sutProvider.GetDependency() + .GetByIdAsync(orgId) + .Returns(org); // Act var result = await sutProvider.Sut.InitPendingOrganizationVNextAsync( @@ -272,7 +305,9 @@ public class InitPendingOrganizationCommandTests org.Enabled = false; org.Status = OrganizationStatusType.Created; - sutProvider.GetDependency().GetByIdAsync(orgId).Returns(org); + sutProvider.GetDependency() + .GetByIdAsync(orgId) + .Returns(org); // Act var result = await sutProvider.Sut.InitPendingOrganizationVNextAsync( @@ -296,7 +331,9 @@ public class InitPendingOrganizationCommandTests org.Status = OrganizationStatusType.Pending; org.PublicKey = "existing-key"; - sutProvider.GetDependency().GetByIdAsync(orgId).Returns(org); + sutProvider.GetDependency() + .GetByIdAsync(orgId) + .Returns(org); // Act var result = await sutProvider.Sut.InitPendingOrganizationVNextAsync( @@ -307,6 +344,29 @@ public class InitPendingOrganizationCommandTests Assert.IsType(result.AsError); } + [Theory, BitAutoData] + public async Task InitPendingOrganizationVNextAsync_WithNullOrganization_ReturnsOrganizationNotFoundError( + User user, Guid orgId, Guid orgUserId, string publicKey, string privateKey, string userKey, + SutProvider sutProvider, OrganizationUser orgUser) + { + // Arrange + orgUser.Email = user.Email; + orgUser.OrganizationId = orgId; + var token = CreateToken(orgUser, orgUserId, sutProvider); + + sutProvider.GetDependency() + .GetByIdAsync(orgId) + .Returns((Organization)null); + + // Act + var result = await sutProvider.Sut.InitPendingOrganizationVNextAsync( + user, orgId, orgUserId, publicKey, privateKey, "", token, userKey); + + // Assert + Assert.True(result.IsError); + Assert.IsType(result.AsError); + } + [Theory, BitAutoData] public async Task InitPendingOrganizationVNextAsync_WithEmailMismatch_ReturnsEmailMismatchError( User user, Guid orgId, Guid orgUserId, string publicKey, string privateKey, string userKey, @@ -317,7 +377,9 @@ public class InitPendingOrganizationCommandTests orgUser.OrganizationId = orgId; var token = CreateToken(orgUser, orgUserId, sutProvider); - sutProvider.GetDependency().GetByIdAsync(orgId).Returns(org); + sutProvider.GetDependency() + .GetByIdAsync(orgId) + .Returns(org); // Act var result = await sutProvider.Sut.InitPendingOrganizationVNextAsync( @@ -344,15 +406,21 @@ public class InitPendingOrganizationCommandTests org.PrivateKey = null; org.PublicKey = null; - sutProvider.GetDependency().GetByIdAsync(orgId).Returns(org); - sutProvider.GetDependency().TwoFactorIsEnabledAsync(user).Returns(true); + sutProvider.GetDependency() + .GetByIdAsync(orgId) + .Returns(org); + sutProvider.GetDependency() + .TwoFactorIsEnabledAsync(user) + .Returns(true); var autoConfirmReq = new AutomaticUserConfirmationPolicyRequirement(new List()); var twoFactorReq = new RequireTwoFactorPolicyRequirement(new List()); sutProvider.GetDependency() - .GetAsync(user.Id).Returns(autoConfirmReq); + .GetAsync(user.Id) + .Returns(autoConfirmReq); sutProvider.GetDependency() - .GetAsync(user.Id).Returns(twoFactorReq); + .GetAsync(user.Id) + .Returns(twoFactorReq); // Act var result = await sutProvider.Sut.InitPendingOrganizationVNextAsync( @@ -380,17 +448,23 @@ public class InitPendingOrganizationCommandTests org.PrivateKey = null; org.PublicKey = null; - sutProvider.GetDependency().GetByIdAsync(orgId).Returns(org); - sutProvider.GetDependency().TwoFactorIsEnabledAsync(user).Returns(false); + sutProvider.GetDependency() + .GetByIdAsync(orgId) + .Returns(org); + sutProvider.GetDependency() + .TwoFactorIsEnabledAsync(user) + .Returns(false); var autoConfirmReq = new AutomaticUserConfirmationPolicyRequirement(new List()); var twoFactorReq = new RequireTwoFactorPolicyRequirement( new List { new PolicyDetails { OrganizationId = orgId } }); sutProvider.GetDependency() - .GetAsync(user.Id).Returns(autoConfirmReq); + .GetAsync(user.Id) + .Returns(autoConfirmReq); sutProvider.GetDependency() - .GetAsync(user.Id).Returns(twoFactorReq); + .GetAsync(user.Id) + .Returns(twoFactorReq); // Act var result = await sutProvider.Sut.InitPendingOrganizationVNextAsync( @@ -417,16 +491,21 @@ public class InitPendingOrganizationCommandTests org.PublicKey = null; org.PlanType = PlanType.Free; - sutProvider.GetDependency().GetByIdAsync(orgId).Returns(org); + sutProvider.GetDependency() + .GetByIdAsync(orgId) + .Returns(org); sutProvider.GetDependency() - .GetCountByFreeOrganizationAdminUserAsync(user.Id).Returns(1); + .GetCountByFreeOrganizationAdminUserAsync(user.Id) + .Returns(1); var autoConfirmReq = new AutomaticUserConfirmationPolicyRequirement(new List()); var twoFactorReq = new RequireTwoFactorPolicyRequirement(new List()); sutProvider.GetDependency() - .GetAsync(user.Id).Returns(autoConfirmReq); + .GetAsync(user.Id) + .Returns(autoConfirmReq); sutProvider.GetDependency() - .GetAsync(user.Id).Returns(twoFactorReq); + .GetAsync(user.Id) + .Returns(twoFactorReq); // Act var result = await sutProvider.Sut.InitPendingOrganizationVNextAsync( @@ -437,6 +516,92 @@ public class InitPendingOrganizationCommandTests Assert.IsType(result.AsError); } + [Theory, BitAutoData] + public async Task InitPendingOrganizationVNextAsync_WithAutomaticUserConfirmationPolicy_ReturnsSingleOrgPolicyViolationError( + User user, Guid orgId, Guid orgUserId, Guid otherOrgId, string publicKey, string privateKey, string userKey, + SutProvider sutProvider, Organization org, OrganizationUser orgUser) + { + // Arrange + orgUser.Email = user.Email; + orgUser.OrganizationId = orgId; + var token = CreateToken(orgUser, orgUserId, sutProvider); + org.Enabled = false; + org.Status = OrganizationStatusType.Pending; + org.PrivateKey = null; + org.PublicKey = null; + + sutProvider.GetDependency().GetByIdAsync(orgId).Returns(org); + + // Enable AutomaticConfirmUsers feature flag + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.AutomaticConfirmUsers) + .Returns(true); + + // User is subject to AutomaticUserConfirmation policy from another organization + var autoConfirmReq = new AutomaticUserConfirmationPolicyRequirement( + new List { new PolicyDetails { OrganizationId = otherOrgId } }); + var twoFactorReq = new RequireTwoFactorPolicyRequirement(new List()); + + sutProvider.GetDependency() + .GetAsync(user.Id) + .Returns(autoConfirmReq); + sutProvider.GetDependency() + .GetAsync(user.Id) + .Returns(twoFactorReq); + + // No legacy SingleOrg policy + sutProvider.GetDependency() + .AnyPoliciesApplicableToUserAsync(user.Id, PolicyType.SingleOrg).Returns(false); + + // Act + var result = await sutProvider.Sut.InitPendingOrganizationVNextAsync( + user, orgId, orgUserId, publicKey, privateKey, "", token, userKey); + + // Assert + Assert.True(result.IsError); + Assert.IsType(result.AsError); + } + + [Theory, BitAutoData] + public async Task InitPendingOrganizationVNextAsync_WithSingleOrgPolicy_ReturnsSingleOrgPolicyViolationError( + User user, Guid orgId, Guid orgUserId, string publicKey, string privateKey, string userKey, + SutProvider sutProvider, Organization org, OrganizationUser orgUser) + { + // Arrange + orgUser.Email = user.Email; + orgUser.OrganizationId = orgId; + var token = CreateToken(orgUser, orgUserId, sutProvider); + org.Enabled = false; + org.Status = OrganizationStatusType.Pending; + org.PrivateKey = null; + org.PublicKey = null; + + sutProvider.GetDependency().GetByIdAsync(orgId).Returns(org); + + var autoConfirmReq = new AutomaticUserConfirmationPolicyRequirement(new List()); + var twoFactorReq = new RequireTwoFactorPolicyRequirement(new List()); + + sutProvider.GetDependency() + .GetAsync(user.Id) + .Returns(autoConfirmReq); + sutProvider.GetDependency() + .GetAsync(user.Id) + .Returns(twoFactorReq); + + // User is subject to SingleOrg policy from another organization + sutProvider.GetDependency() + .AnyPoliciesApplicableToUserAsync(user.Id, PolicyType.SingleOrg) + .Returns(true); + + // Act + var result = await sutProvider.Sut.InitPendingOrganizationVNextAsync( + user, orgId, orgUserId, publicKey, privateKey, "", token, userKey); + + // Assert + Assert.True(result.IsError); + Assert.IsType(result.AsError); + } + [Theory, BitAutoData] public async Task InitPendingOrganizationVNextAsync_WithMismatchedOrganizationId_ReturnsOrganizationMismatchError( User user, Guid orgId, Guid differentOrgId, Guid orgUserId, string publicKey, string privateKey, string userKey, @@ -453,7 +618,9 @@ public class InitPendingOrganizationCommandTests org.PrivateKey = null; org.PublicKey = null; - sutProvider.GetDependency().GetByIdAsync(orgId).Returns(org); + sutProvider.GetDependency() + .GetByIdAsync(orgId) + .Returns(org); // Act var result = await sutProvider.Sut.InitPendingOrganizationVNextAsync(