diff --git a/src/bedrock_agentcore/policy/__init__.py b/src/bedrock_agentcore/policy/__init__.py new file mode 100644 index 00000000..2cc24518 --- /dev/null +++ b/src/bedrock_agentcore/policy/__init__.py @@ -0,0 +1,5 @@ +"""Bedrock AgentCore Policy Engine client.""" + +from .client import PolicyEngineClient + +__all__ = ["PolicyEngineClient"] diff --git a/src/bedrock_agentcore/policy/client.py b/src/bedrock_agentcore/policy/client.py new file mode 100644 index 00000000..1b463ca7 --- /dev/null +++ b/src/bedrock_agentcore/policy/client.py @@ -0,0 +1,546 @@ +"""AgentCore Policy Engine SDK - Client for Cedar policy engine operations.""" + +import logging +from typing import Any, Dict, Optional + +import boto3 +from botocore.config import Config +from botocore.exceptions import ClientError + +from .._utils.config import WaitConfig +from .._utils.polling import wait_until, wait_until_deleted +from .._utils.snake_case import accept_snake_case_kwargs, convert_kwargs +from .._utils.user_agent import build_user_agent_suffix + +logger = logging.getLogger(__name__) + +_FAILED_STATUSES = {"CREATE_FAILED", "UPDATE_FAILED", "DELETE_FAILED"} + + +class PolicyEngineClient: + """Client for Bedrock AgentCore Policy Engine operations. + + Provides access to policy engine and Cedar policy CRUD operations. + Allowlisted boto3 methods can be called directly on this client. + Parameters accept both camelCase and snake_case (auto-converted). + + Example:: + + client = PolicyEngineClient(region_name="us-west-2") + + # These are forwarded to the underlying boto3 control plane client + engine = client.create_policy_engine(name="my_engine") + client.create_policy( + policy_engine_id=engine["policyEngineId"], + name="my_policy", + definition={"cedar": {"statement": "permit(principal, action, resource);"}}, + ) + """ + + _ALLOWED_CP_METHODS = { + # Policy engine CRUD + "create_policy_engine", + "get_policy_engine", + "list_policy_engines", + "update_policy_engine", + "delete_policy_engine", + # Policy CRUD + "create_policy", + "get_policy", + "list_policies", + "update_policy", + "delete_policy", + # Policy generation + "start_policy_generation", + "get_policy_generation", + "list_policy_generations", + "list_policy_generation_assets", + } + + def __init__( + self, + region_name: Optional[str] = None, + integration_source: Optional[str] = None, + boto3_session: Optional[boto3.Session] = None, + ): + """Initialize the Policy Engine client. + + Args: + region_name: AWS region name. If not provided, uses the session's region or "us-west-2". + integration_source: Optional integration source for user-agent telemetry. + boto3_session: Optional boto3 Session to use. If not provided, a default session + is created. Useful for named profiles or custom credentials. + """ + session = boto3_session if boto3_session else boto3.Session() + self.region_name = region_name or session.region_name or "us-west-2" + self.integration_source = integration_source + + user_agent_extra = build_user_agent_suffix(integration_source) + client_config = Config(user_agent_extra=user_agent_extra) + + self.cp_client = session.client("bedrock-agentcore-control", region_name=self.region_name, config=client_config) + + logger.info("Initialized PolicyEngineClient for region: %s", self.cp_client.meta.region_name) + + # Pass-through + # ------------------------------------------------------------------------- + def __getattr__(self, name: str): + """Dynamically forward allowlisted method calls to the control plane boto3 client.""" + if name in self._ALLOWED_CP_METHODS and hasattr(self.cp_client, name): + method = getattr(self.cp_client, name) + logger.debug("Forwarding method '%s' to cp_client", name) + return accept_snake_case_kwargs(method) + + raise AttributeError( + f"'{self.__class__.__name__}' object has no attribute '{name}'. " + f"Method not found on cp_client. " + f"Available methods can be found in the boto3 documentation for " + f"'bedrock-agentcore-control' service." + ) + + # *_and_wait methods + # ------------------------------------------------------------------------- + def create_policy_engine_and_wait(self, wait_config: Optional[WaitConfig] = None, **kwargs) -> Dict[str, Any]: + """Create a policy engine and wait for it to reach ACTIVE status. + + Args: + wait_config: Optional WaitConfig for polling behavior (default: max_wait=300, poll_interval=10). + **kwargs: Arguments forwarded to the create_policy_engine API. + + Returns: + Policy engine details when ACTIVE. + + Raises: + RuntimeError: If the engine reaches a failed state. + TimeoutError: If the engine doesn't become ACTIVE within max_wait. + """ + response = self.cp_client.create_policy_engine(**convert_kwargs(kwargs)) + engine_id = response["policyEngineId"] + return wait_until( + lambda: self.cp_client.get_policy_engine(policyEngineId=engine_id), + "ACTIVE", + _FAILED_STATUSES, + wait_config, + ) + + def update_policy_engine_and_wait(self, wait_config: Optional[WaitConfig] = None, **kwargs) -> Dict[str, Any]: + """Update a policy engine and wait for it to reach ACTIVE status. + + Args: + wait_config: Optional WaitConfig for polling behavior (default: max_wait=300, poll_interval=10). + **kwargs: Arguments forwarded to the update_policy_engine API. + + Returns: + Policy engine details when ACTIVE. + + Raises: + RuntimeError: If the engine reaches a failed state. + TimeoutError: If the engine doesn't become ACTIVE within max_wait. + """ + response = self.cp_client.update_policy_engine(**convert_kwargs(kwargs)) + engine_id = response["policyEngineId"] + return wait_until( + lambda: self.cp_client.get_policy_engine(policyEngineId=engine_id), + "ACTIVE", + _FAILED_STATUSES, + wait_config, + ) + + def create_policy_and_wait(self, wait_config: Optional[WaitConfig] = None, **kwargs) -> Dict[str, Any]: + """Create a policy and wait for it to reach ACTIVE status. + + Args: + wait_config: Optional WaitConfig for polling behavior (default: max_wait=300, poll_interval=10). + **kwargs: Arguments forwarded to the create_policy API. + Must include policyEngineId. + + Returns: + Policy details when ACTIVE. + + Raises: + RuntimeError: If the policy reaches a failed state. + TimeoutError: If the policy doesn't become ACTIVE within max_wait. + """ + response = self.cp_client.create_policy(**convert_kwargs(kwargs)) + engine_id = response["policyEngineId"] + policy_id = response["policyId"] + return wait_until( + lambda: self.cp_client.get_policy( + policyEngineId=engine_id, + policyId=policy_id, + ), + "ACTIVE", + _FAILED_STATUSES, + wait_config, + ) + + def delete_policy_engine_and_wait( + self, + wait_config: Optional[WaitConfig] = None, + **kwargs, + ) -> None: + """Delete a policy engine and wait for deletion to complete. + + Args: + wait_config: Optional WaitConfig for polling behavior. + **kwargs: Arguments forwarded to the delete_policy_engine API. + + Raises: + RuntimeError: If the engine reaches DELETE_FAILED. + TimeoutError: If the engine isn't deleted within max_wait. + """ + response = self.cp_client.delete_policy_engine(**convert_kwargs(kwargs)) + engine_id = response["policyEngineId"] + wait_until_deleted( + lambda: self.cp_client.get_policy_engine( + policyEngineId=engine_id, + ), + failed=_FAILED_STATUSES, + wait_config=wait_config, + ) + + def delete_policy_and_wait( + self, + wait_config: Optional[WaitConfig] = None, + **kwargs, + ) -> None: + """Delete a policy and wait for deletion to complete. + + Args: + wait_config: Optional WaitConfig for polling behavior. + **kwargs: Arguments forwarded to the delete_policy API. + Must include policyEngineId and policyId. + + Raises: + RuntimeError: If the policy reaches DELETE_FAILED. + TimeoutError: If the policy isn't deleted within max_wait. + """ + response = self.cp_client.delete_policy(**convert_kwargs(kwargs)) + engine_id = response["policyEngineId"] + policy_id = response["policyId"] + wait_until_deleted( + lambda: self.cp_client.get_policy( + policyEngineId=engine_id, + policyId=policy_id, + ), + failed=_FAILED_STATUSES, + wait_config=wait_config, + ) + + def generate_policy_asset_and_wait( + self, + policy_engine_id: str, + name: str, + resource: Dict[str, Any], + content: Dict[str, Any], + client_token: Optional[str] = None, + wait_config: Optional[WaitConfig] = None, + fetch_assets: bool = False, + ) -> Dict[str, Any]: + """Generate Cedar policy assets from natural language and wait for completion. + + Starts policy generation, polls until complete, and optionally fetches + the generated policy assets. + + Args: + policy_engine_id: ID of the policy engine. + name: Name for the generation. + resource: Resource for which policies will be generated (e.g., {"arn": "..."}). + content: Natural language input (e.g., {"rawText": "allow refunds..."}). + client_token: Optional idempotency token. + wait_config: Optional WaitConfig for polling behavior. + fetch_assets: If True, fetch generated assets and include in response. + + Returns: + Generation details. If fetch_assets=True, includes 'generatedPolicies' field. + + Raises: + RuntimeError: If generation fails. + TimeoutError: If generation doesn't complete within max_wait. + """ + request: Dict[str, Any] = { + "policyEngineId": policy_engine_id, + "name": name, + "resource": resource, + "content": content, + } + if client_token is not None: + request["clientToken"] = client_token + + generation = self.cp_client.start_policy_generation(**request) + generation_id = generation["policyGenerationId"] + logger.info("Started policy generation %s, waiting for completion...", generation_id) + + _generation_failed = {"GENERATE_FAILED", "DELETE_FAILED"} + result = wait_until( + lambda: self.cp_client.get_policy_generation( + policyEngineId=policy_engine_id, + policyGenerationId=generation_id, + ), + "GENERATED", + _generation_failed, + wait_config, + ) + + if fetch_assets: + assets = self.cp_client.list_policy_generation_assets( + policyEngineId=policy_engine_id, + policyGenerationId=generation_id, + ) + result["generatedPolicies"] = assets.get("policyGenerationAssets", []) + + return result + + # Higher-level orchestration methods + # ------------------------------------------------------------------------- + def create_policy_from_generation_asset( + self, + policy_engine_id: str, + name: str, + policy_generation_id: str, + policy_generation_asset_id: str, + description: Optional[str] = None, + validation_mode: Optional[str] = None, + client_token: Optional[str] = None, + ) -> Dict[str, Any]: + """Create a policy from a generation asset. + + Args: + policy_engine_id: ID of the policy engine. + name: Name of the policy. + policy_generation_id: ID of the policy generation. + policy_generation_asset_id: ID of the generation asset. + description: Optional description. + validation_mode: Optional validation mode (FAIL_ON_ANY_FINDINGS, IGNORE_ALL_FINDINGS). + client_token: Optional idempotency token. + + Returns: + Policy details including policyId, ARN, and status. + """ + definition = { + "policyGeneration": { + "policyGenerationId": policy_generation_id, + "policyGenerationAssetId": policy_generation_asset_id, + } + } + + request: Dict[str, Any] = { + "policyEngineId": policy_engine_id, + "name": name, + "definition": definition, + } + if description is not None: + request["description"] = description + if validation_mode is not None: + request["validationMode"] = validation_mode + if client_token is not None: + request["clientToken"] = client_token + + return self.cp_client.create_policy(**request) + + def generate_and_create_policy( + self, + policy_engine_id: str, + generation_name: str, + policy_name: str, + resource: Dict[str, Any], + content: Dict[str, Any], + description: Optional[str] = None, + validation_mode: Optional[str] = None, + wait_config: Optional[WaitConfig] = None, + ) -> Dict[str, Any]: + """Generate a Cedar policy from natural language and create it in one step. + + End-to-end flow: starts generation, waits for completion, picks the + first generated asset, and creates a policy from it. + + Args: + policy_engine_id: ID of the policy engine. + generation_name: Name for the generation job. + policy_name: Name for the created policy. + resource: Resource for which policies will be generated. + content: Natural language input (e.g., {"rawText": "allow refunds..."}). + description: Optional description for the created policy. + validation_mode: Optional validation mode for the created policy. + wait_config: Optional WaitConfig for polling behavior. + + Returns: + Policy details including policyId, ARN, and status. + + Raises: + RuntimeError: If generation fails or produces no assets. + TimeoutError: If generation doesn't complete within max_wait. + """ + generation = self.generate_policy_asset_and_wait( + policy_engine_id=policy_engine_id, + name=generation_name, + resource=resource, + content=content, + fetch_assets=True, + wait_config=wait_config, + ) + + assets = generation.get("generatedPolicies", []) + if not assets: + raise RuntimeError( + "Policy generation %s produced no assets" % generation.get("policyGenerationId", "unknown") + ) + + asset = assets[0] + return self.create_policy_from_generation_asset( + policy_engine_id=policy_engine_id, + name=policy_name, + policy_generation_id=generation["policyGenerationId"], + policy_generation_asset_id=asset["policyGenerationAssetId"], + description=description, + validation_mode=validation_mode, + ) + + # Idempotent creates + # ------------------------------------------------------------------------- + def create_or_get_policy_engine( + self, + name: str, + description: Optional[str] = None, + encryption_key_arn: Optional[str] = None, + tags: Optional[Dict[str, str]] = None, + client_token: Optional[str] = None, + wait_config: Optional[WaitConfig] = None, + ) -> Dict[str, Any]: + """Create a policy engine or return existing one with the same name. + + Idempotent — if a ConflictException occurs, finds the existing engine + by name. Waits for ACTIVE status before returning. + + Args: + name: Name of the policy engine. + description: Optional description. + encryption_key_arn: Optional KMS key ARN. + tags: Optional tags. + client_token: Optional idempotency token. + wait_config: Optional WaitConfig for polling behavior. + + Returns: + Policy engine details in ACTIVE status. + """ + try: + request: Dict[str, Any] = {"name": name} + if description is not None: + request["description"] = description + if encryption_key_arn is not None: + request["encryptionKeyArn"] = encryption_key_arn + if tags is not None: + request["tags"] = tags + if client_token is not None: + request["clientToken"] = client_token + + resp = self.cp_client.create_policy_engine(**request) + engine_id = resp["policyEngineId"] + logger.info("Created policy engine %s, waiting for ACTIVE...", engine_id) + except ClientError as e: + if e.response["Error"]["Code"] != "ConflictException": + raise + logger.info("Policy engine '%s' already exists, looking up...", name) + engine_id = self._find_policy_engine_by_name(name) + if not engine_id: + raise RuntimeError(f"Policy engine '{name}' already exists but could not be found by name.") from e + + return wait_until( + lambda: self.cp_client.get_policy_engine(policyEngineId=engine_id), + "ACTIVE", + _FAILED_STATUSES, + wait_config, + ) + + def create_or_get_policy( + self, + policy_engine_id: str, + name: str, + definition: Dict[str, Any], + description: Optional[str] = None, + validation_mode: Optional[str] = None, + client_token: Optional[str] = None, + wait_config: Optional[WaitConfig] = None, + ) -> Dict[str, Any]: + """Create a policy or return existing one with the same name. + + Idempotent — if a ConflictException occurs, finds the existing policy + by name. Waits for ACTIVE status before returning. + + Args: + policy_engine_id: ID of the policy engine. + name: Name of the policy. + definition: Policy definition. + description: Optional description. + validation_mode: Optional validation mode. + client_token: Optional idempotency token. + wait_config: Optional WaitConfig for polling behavior. + + Returns: + Policy details in ACTIVE status. + """ + try: + request: Dict[str, Any] = { + "policyEngineId": policy_engine_id, + "name": name, + "definition": definition, + } + if description is not None: + request["description"] = description + if validation_mode is not None: + request["validationMode"] = validation_mode + if client_token is not None: + request["clientToken"] = client_token + + resp = self.cp_client.create_policy(**request) + policy_id = resp["policyId"] + logger.info("Created policy %s, waiting for ACTIVE...", policy_id) + except ClientError as e: + if e.response["Error"]["Code"] != "ConflictException": + raise + logger.info("Policy '%s' already exists, looking up...", name) + policy_id = self._find_policy_by_name(policy_engine_id, name) + if not policy_id: + raise RuntimeError(f"Policy '{name}' already exists but could not be found by name.") from e + + return wait_until( + lambda: self.cp_client.get_policy( + policyEngineId=policy_engine_id, + policyId=policy_id, + ), + "ACTIVE", + _FAILED_STATUSES, + wait_config, + ) + + # Helper methods + # ------------------------------------------------------------------------- + def _find_policy_engine_by_name(self, name: str) -> Optional[str]: + """Find a policy engine ID by name. Returns None if not found.""" + next_token = None + while True: + params: Dict[str, Any] = {"maxResults": 100} + if next_token: + params["nextToken"] = next_token + resp = self.cp_client.list_policy_engines(**params) + for engine in resp.get("policyEngines", []): + if engine.get("name") == name: + return engine["policyEngineId"] + next_token = resp.get("nextToken") + if not next_token: + return None + + def _find_policy_by_name(self, policy_engine_id: str, name: str) -> Optional[str]: + """Find a policy ID by name within an engine. Returns None if not found.""" + next_token = None + while True: + params: Dict[str, Any] = {"policyEngineId": policy_engine_id, "maxResults": 100} + if next_token: + params["nextToken"] = next_token + resp = self.cp_client.list_policies(**params) + for policy in resp.get("policies", []): + if policy.get("name") == name: + return policy["policyId"] + next_token = resp.get("nextToken") + if not next_token: + return None diff --git a/tests/unit/policy/__init__.py b/tests/unit/policy/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/policy/test_policy_engine_client.py b/tests/unit/policy/test_policy_engine_client.py new file mode 100644 index 00000000..b25ee9c5 --- /dev/null +++ b/tests/unit/policy/test_policy_engine_client.py @@ -0,0 +1,456 @@ +"""Tests for PolicyEngineClient.""" + +from unittest.mock import MagicMock, Mock, patch + +import pytest +from botocore.exceptions import ClientError + +from bedrock_agentcore.policy.client import PolicyEngineClient + + +class TestPolicyEngineClientInit: + """Tests for PolicyEngineClient initialization.""" + + def test_init_with_region(self): + mock_session = MagicMock() + mock_session.region_name = "eu-west-1" + client = PolicyEngineClient(region_name="us-west-2", boto3_session=mock_session) + assert client.region_name == "us-west-2" + + def test_init_default_region_from_session(self): + mock_session = MagicMock() + mock_session.region_name = "eu-west-1" + client = PolicyEngineClient(boto3_session=mock_session) + assert client.region_name == "eu-west-1" + + def test_init_default_region_fallback(self): + mock_session = MagicMock() + mock_session.region_name = None + client = PolicyEngineClient(boto3_session=mock_session) + assert client.region_name == "us-west-2" + + def test_init_creates_cp_client(self): + mock_session = MagicMock() + mock_session.region_name = "us-west-2" + PolicyEngineClient(boto3_session=mock_session) + + mock_session.client.assert_called_once() + call_args = mock_session.client.call_args + assert call_args[0][0] == "bedrock-agentcore-control" + + def test_init_with_integration_source(self): + mock_session = MagicMock() + mock_session.region_name = "us-west-2" + client = PolicyEngineClient(boto3_session=mock_session, integration_source="langchain") + assert client.integration_source == "langchain" + + +class TestPolicyEngineClientPassthrough: + """Tests for __getattr__ passthrough.""" + + def _make_client(self): + mock_session = MagicMock() + mock_session.region_name = "us-west-2" + client = PolicyEngineClient(boto3_session=mock_session) + client.cp_client = Mock() + return client + + def test_policy_engine_crud_forwarded(self): + client = self._make_client() + client.cp_client.create_policy_engine.return_value = {"policyEngineId": "pe-123"} + + result = client.create_policy_engine(name="test-engine") + + client.cp_client.create_policy_engine.assert_called_once_with(name="test-engine") + assert result["policyEngineId"] == "pe-123" + + def test_policy_crud_forwarded(self): + client = self._make_client() + client.cp_client.create_policy.return_value = {"policyId": "pol-123"} + + result = client.create_policy( + policyEngineId="pe-123", + name="test-policy", + definition={"cedar": {"statement": "permit(principal, action, resource);"}}, + ) + + client.cp_client.create_policy.assert_called_once() + assert result["policyId"] == "pol-123" + + def test_snake_case_kwargs_converted(self): + client = self._make_client() + client.cp_client.get_policy_engine.return_value = {"policyEngineId": "pe-123"} + + client.get_policy_engine(policy_engine_id="pe-123") + + client.cp_client.get_policy_engine.assert_called_once_with(policyEngineId="pe-123") + + def test_non_allowlisted_method_raises_attribute_error(self): + client = self._make_client() + + with pytest.raises(AttributeError, match="has no attribute 'not_a_real_method'"): + client.not_a_real_method() + + def test_all_cp_methods_in_allowlist(self): + expected = { + "create_policy_engine", + "get_policy_engine", + "list_policy_engines", + "update_policy_engine", + "delete_policy_engine", + "create_policy", + "get_policy", + "list_policies", + "update_policy", + "delete_policy", + "start_policy_generation", + "get_policy_generation", + "list_policy_generations", + "list_policy_generation_assets", + } + assert expected == PolicyEngineClient._ALLOWED_CP_METHODS + + +class TestGeneratePolicy: + """Tests for generate_policy_asset_and_wait orchestration method.""" + + def _make_client(self): + mock_session = MagicMock() + mock_session.region_name = "us-west-2" + client = PolicyEngineClient(boto3_session=mock_session) + client.cp_client = Mock() + return client + + def test_generate_policy_immediate_success(self): + client = self._make_client() + client.cp_client.start_policy_generation.return_value = {"policyGenerationId": "gen-123"} + client.cp_client.get_policy_generation.return_value = {"status": "GENERATED", "policyGenerationId": "gen-123"} + + result = client.generate_policy_asset_and_wait( + policy_engine_id="pe-123", + name="test-gen", + resource={"arn": "arn:aws:test"}, + content={"rawText": "allow refunds"}, + ) + + assert result["status"] == "GENERATED" + client.cp_client.start_policy_generation.assert_called_once() + + def test_generate_policy_with_fetch_assets(self): + client = self._make_client() + client.cp_client.start_policy_generation.return_value = {"policyGenerationId": "gen-123"} + client.cp_client.get_policy_generation.return_value = {"status": "GENERATED", "policyGenerationId": "gen-123"} + client.cp_client.list_policy_generation_assets.return_value = { + "policyGenerationAssets": [{"assetId": "asset-1"}] + } + + result = client.generate_policy_asset_and_wait( + policy_engine_id="pe-123", + name="test-gen", + resource={"arn": "arn:aws:test"}, + content={"rawText": "allow refunds"}, + fetch_assets=True, + ) + + assert result["generatedPolicies"] == [{"assetId": "asset-1"}] + client.cp_client.list_policy_generation_assets.assert_called_once() + + @patch("bedrock_agentcore._utils.polling.time.sleep") + @patch("bedrock_agentcore._utils.polling.time.time", side_effect=[0, 0, 0, 1, 1]) + def test_generate_policy_polls_through_generating(self, _mock_time, _mock_sleep): + client = self._make_client() + client.cp_client.start_policy_generation.return_value = {"policyGenerationId": "gen-123"} + client.cp_client.get_policy_generation.side_effect = [ + {"status": "GENERATING"}, + {"status": "GENERATED", "policyGenerationId": "gen-123"}, + ] + + result = client.generate_policy_asset_and_wait( + policy_engine_id="pe-123", + name="test-gen", + resource={"arn": "arn:aws:test"}, + content={"rawText": "allow refunds"}, + ) + + assert result["status"] == "GENERATED" + assert client.cp_client.get_policy_generation.call_count == 2 + + def test_generate_policy_failed(self): + client = self._make_client() + client.cp_client.start_policy_generation.return_value = {"policyGenerationId": "gen-123"} + client.cp_client.get_policy_generation.return_value = { + "status": "GENERATE_FAILED", + "statusReasons": ["Invalid input"], + } + + with pytest.raises(RuntimeError, match="GENERATE_FAILED"): + client.generate_policy_asset_and_wait( + policy_engine_id="pe-123", + name="test-gen", + resource={"arn": "arn:aws:test"}, + content={"rawText": "bad input"}, + ) + + @patch("bedrock_agentcore._utils.polling.time.sleep") + @patch("bedrock_agentcore._utils.polling.time.time", side_effect=[0, 0, 0, 301]) + def test_generate_policy_timeout(self, _mock_time, _mock_sleep): + client = self._make_client() + client.cp_client.start_policy_generation.return_value = {"policyGenerationId": "gen-123"} + client.cp_client.get_policy_generation.return_value = {"status": "GENERATING"} + + with pytest.raises(TimeoutError): + client.generate_policy_asset_and_wait( + policy_engine_id="pe-123", + name="test-gen", + resource={"arn": "arn:aws:test"}, + content={"rawText": "allow refunds"}, + ) + + +class TestCreatePolicyFromGenerationAsset: + """Tests for create_policy_from_generation_asset.""" + + def _make_client(self): + mock_session = MagicMock() + mock_session.region_name = "us-west-2" + client = PolicyEngineClient(boto3_session=mock_session) + client.cp_client = Mock() + return client + + def test_creates_policy_with_generation_definition(self): + client = self._make_client() + client.cp_client.create_policy.return_value = {"policyId": "pol-123", "policyArn": "arn:test"} + + result = client.create_policy_from_generation_asset( + policy_engine_id="pe-123", + name="generated-policy", + policy_generation_id="gen-456", + policy_generation_asset_id="asset-789", + ) + + assert result["policyId"] == "pol-123" + call_kwargs = client.cp_client.create_policy.call_args[1] + assert call_kwargs["definition"] == { + "policyGeneration": { + "policyGenerationId": "gen-456", + "policyGenerationAssetId": "asset-789", + } + } + + def test_passes_optional_params(self): + client = self._make_client() + client.cp_client.create_policy.return_value = {"policyId": "pol-123"} + + client.create_policy_from_generation_asset( + policy_engine_id="pe-123", + name="generated-policy", + policy_generation_id="gen-456", + policy_generation_asset_id="asset-789", + description="From NL generation", + validation_mode="IGNORE_ALL_FINDINGS", + client_token="token-abc", + ) + + call_kwargs = client.cp_client.create_policy.call_args[1] + assert call_kwargs["description"] == "From NL generation" + assert call_kwargs["validationMode"] == "IGNORE_ALL_FINDINGS" + assert call_kwargs["clientToken"] == "token-abc" + + +class TestWaitForActive: + """Tests for *_and_wait methods using shared wait_until.""" + + def _make_client(self): + mock_session = MagicMock() + mock_session.region_name = "us-west-2" + client = PolicyEngineClient(boto3_session=mock_session) + client.cp_client = Mock() + return client + + def test_create_engine_and_wait_immediate(self): + client = self._make_client() + client.cp_client.create_policy_engine.return_value = {"policyEngineId": "pe-123"} + client.cp_client.get_policy_engine.return_value = {"status": "ACTIVE", "policyEngineId": "pe-123"} + + result = client.create_policy_engine_and_wait(name="test") + + assert result["status"] == "ACTIVE" + + @patch("bedrock_agentcore._utils.polling.time.sleep") + @patch("bedrock_agentcore._utils.polling.time.time", side_effect=[0, 0, 1, 1]) + def test_create_engine_and_wait_polls(self, _mock_time, _mock_sleep): + client = self._make_client() + client.cp_client.create_policy_engine.return_value = {"policyEngineId": "pe-123"} + client.cp_client.get_policy_engine.side_effect = [ + {"status": "CREATING"}, + {"status": "ACTIVE", "policyEngineId": "pe-123"}, + ] + + result = client.create_policy_engine_and_wait(name="test") + + assert result["status"] == "ACTIVE" + assert client.cp_client.get_policy_engine.call_count == 2 + + def test_create_engine_and_wait_failed_status(self): + client = self._make_client() + client.cp_client.create_policy_engine.return_value = {"policyEngineId": "pe-123"} + client.cp_client.get_policy_engine.return_value = { + "status": "CREATE_FAILED", + "statusReasons": ["something broke"], + } + + with pytest.raises(RuntimeError, match="CREATE_FAILED"): + client.create_policy_engine_and_wait(name="test") + + @patch("bedrock_agentcore._utils.polling.time.sleep") + @patch("bedrock_agentcore._utils.polling.time.time", side_effect=[0, 0, 0, 301]) + def test_create_engine_and_wait_timeout(self, _mock_time, _mock_sleep): + client = self._make_client() + client.cp_client.create_policy_engine.return_value = {"policyEngineId": "pe-123"} + client.cp_client.get_policy_engine.return_value = {"status": "CREATING"} + + with pytest.raises(TimeoutError): + client.create_policy_engine_and_wait(name="test") + + def test_create_policy_and_wait_immediate(self): + client = self._make_client() + client.cp_client.create_policy.return_value = { + "policyEngineId": "pe-123", + "policyId": "pol-123", + } + client.cp_client.get_policy.return_value = {"status": "ACTIVE", "policyId": "pol-123"} + + result = client.create_policy_and_wait( + policyEngineId="pe-123", + name="test", + definition={"cedar": {"statement": "permit(principal, action, resource);"}}, + ) + + assert result["status"] == "ACTIVE" + + @patch("bedrock_agentcore._utils.polling.time.sleep") + @patch("bedrock_agentcore._utils.polling.time.time", side_effect=[0, 0, 1, 1]) + def test_create_policy_and_wait_polls(self, _mock_time, _mock_sleep): + client = self._make_client() + client.cp_client.create_policy.return_value = { + "policyEngineId": "pe-123", + "policyId": "pol-123", + } + client.cp_client.get_policy.side_effect = [ + {"status": "CREATING"}, + {"status": "ACTIVE", "policyId": "pol-123"}, + ] + + result = client.create_policy_and_wait( + policyEngineId="pe-123", + name="test", + definition={"cedar": {"statement": "permit(principal, action, resource);"}}, + ) + + assert result["status"] == "ACTIVE" + assert client.cp_client.get_policy.call_count == 2 + + def test_create_policy_and_wait_failed_status(self): + client = self._make_client() + client.cp_client.create_policy.return_value = { + "policyEngineId": "pe-123", + "policyId": "pol-123", + } + client.cp_client.get_policy.return_value = { + "status": "CREATE_FAILED", + "statusReasons": ["something broke"], + } + + with pytest.raises(RuntimeError, match="CREATE_FAILED"): + client.create_policy_and_wait( + policyEngineId="pe-123", + name="test", + definition={"cedar": {"statement": "permit(principal, action, resource);"}}, + ) + + +class TestCreateOrGet: + """Tests for create_or_get_policy_engine and create_or_get_policy.""" + + def _make_client(self): + mock_session = MagicMock() + mock_session.region_name = "us-west-2" + client = PolicyEngineClient(boto3_session=mock_session) + client.cp_client = Mock() + return client + + def _conflict_error(self, operation="CreatePolicyEngine"): + return ClientError( + {"Error": {"Code": "ConflictException", "Message": "already exists"}}, + operation, + ) + + def test_create_or_get_engine_creates_new(self): + client = self._make_client() + client.cp_client.create_policy_engine.return_value = {"policyEngineId": "pe-123"} + client.cp_client.get_policy_engine.return_value = {"status": "ACTIVE", "policyEngineId": "pe-123"} + + result = client.create_or_get_policy_engine(name="test-engine") + + assert result["status"] == "ACTIVE" + client.cp_client.create_policy_engine.assert_called_once() + + def test_create_or_get_engine_finds_existing(self): + client = self._make_client() + client.cp_client.create_policy_engine.side_effect = self._conflict_error() + client.cp_client.list_policy_engines.return_value = { + "policyEngines": [{"name": "test-engine", "policyEngineId": "pe-existing"}], + } + client.cp_client.get_policy_engine.return_value = {"status": "ACTIVE", "policyEngineId": "pe-existing"} + + result = client.create_or_get_policy_engine(name="test-engine") + + assert result["policyEngineId"] == "pe-existing" + + def test_create_or_get_engine_not_found_reraises(self): + client = self._make_client() + client.cp_client.create_policy_engine.side_effect = self._conflict_error() + client.cp_client.list_policy_engines.return_value = {"policyEngines": []} + + with pytest.raises(RuntimeError, match="already exists but could not be found"): + client.create_or_get_policy_engine(name="ghost-engine") + + def test_create_or_get_policy_creates_new(self): + client = self._make_client() + client.cp_client.create_policy.return_value = {"policyId": "pol-123"} + client.cp_client.get_policy.return_value = {"status": "ACTIVE", "policyId": "pol-123"} + + result = client.create_or_get_policy( + policy_engine_id="pe-123", + name="test-policy", + definition={"cedar": {"statement": "permit(principal, action, resource);"}}, + ) + + assert result["status"] == "ACTIVE" + client.cp_client.create_policy.assert_called_once() + + def test_create_or_get_policy_finds_existing(self): + client = self._make_client() + client.cp_client.create_policy.side_effect = self._conflict_error("CreatePolicy") + client.cp_client.list_policies.return_value = { + "policies": [{"name": "test-policy", "policyId": "pol-existing"}], + } + client.cp_client.get_policy.return_value = {"status": "ACTIVE", "policyId": "pol-existing"} + + result = client.create_or_get_policy( + policy_engine_id="pe-123", + name="test-policy", + definition={"cedar": {"statement": "permit(principal, action, resource);"}}, + ) + + assert result["policyId"] == "pol-existing" + + def test_non_conflict_error_reraises(self): + client = self._make_client() + client.cp_client.create_policy_engine.side_effect = ClientError( + {"Error": {"Code": "ValidationException", "Message": "bad input"}}, + "CreatePolicyEngine", + ) + + with pytest.raises(ClientError, match="ValidationException"): + client.create_or_get_policy_engine(name="bad-engine") diff --git a/tests_integ/policy/__init__.py b/tests_integ/policy/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests_integ/policy/test_policy_engine_client.py b/tests_integ/policy/test_policy_engine_client.py new file mode 100644 index 00000000..2bc1334e --- /dev/null +++ b/tests_integ/policy/test_policy_engine_client.py @@ -0,0 +1,169 @@ +"""Integration tests for PolicyEngineClient.""" + +import os +import time + +import pytest +from botocore.exceptions import ClientError + +from bedrock_agentcore.policy.client import PolicyEngineClient + + +@pytest.mark.integration +class TestPolicyEngineClient: + """Integration tests for PolicyEngineClient CRUD and wait methods.""" + + @classmethod + def setup_class(cls): + cls.region = os.environ.get("BEDROCK_TEST_REGION", "us-west-2") + cls.client = PolicyEngineClient(region_name=cls.region) + cls.test_prefix = f"sdk_integ_{int(time.time())}" + cls.engine_ids = [] + cls.policy_ids = [] + + @classmethod + def teardown_class(cls): + for engine_id, policy_id in cls.policy_ids: + try: + cls.client.delete_policy_and_wait( + policyEngineId=engine_id, + policyId=policy_id, + ) + except Exception as e: + print(f"Failed to delete policy {policy_id}: {e}") + for engine_id in cls.engine_ids: + try: + cls.client.delete_policy_engine_and_wait( + policyEngineId=engine_id, + ) + except Exception as e: + print(f"Failed to delete engine {engine_id}: {e}") + + @pytest.mark.order(1) + def test_create_policy_engine_and_wait(self): + engine = self.client.create_policy_engine_and_wait( + name=f"{self.test_prefix}_engine", + ) + self.__class__.engine_ids.append(engine["policyEngineId"]) + assert engine["status"] == "ACTIVE" + assert engine["name"] == f"{self.test_prefix}_engine" + + @pytest.mark.order(2) + def test_get_policy_engine_passthrough(self): + if not self.engine_ids: + pytest.skip("prerequisite test did not create engine") + engine = self.client.get_policy_engine( + policyEngineId=self.engine_ids[0], + ) + assert engine["status"] == "ACTIVE" + + @pytest.mark.order(3) + def test_get_policy_engine_snake_case(self): + if not self.engine_ids: + pytest.skip("prerequisite test did not create engine") + engine = self.client.get_policy_engine( + policy_engine_id=self.engine_ids[0], + ) + assert engine["status"] == "ACTIVE" + + @pytest.mark.order(4) + def test_update_policy_engine_and_wait(self): + if not self.engine_ids: + pytest.skip("prerequisite test did not create engine") + updated = self.client.update_policy_engine_and_wait( + policyEngineId=self.engine_ids[0], + description={"optionalValue": "updated by integ test"}, + ) + assert updated["status"] == "ACTIVE" + assert updated.get("description") == "updated by integ test" + + @pytest.mark.order(5) + def test_create_policy_and_wait(self): + if not self.engine_ids: + pytest.skip("prerequisite test did not create engine") + policy = self.client.create_policy_and_wait( + policyEngineId=self.engine_ids[0], + name=f"{self.test_prefix}_policy", + definition={"cedar": {"statement": ("permit(principal, action, resource is AgentCore::Gateway);")}}, + validationMode="IGNORE_ALL_FINDINGS", + ) + self.__class__.policy_ids.append( + (self.engine_ids[0], policy["policyId"]), + ) + assert policy["status"] == "ACTIVE" + + @pytest.mark.order(6) + def test_get_policy_passthrough(self): + if not self.policy_ids: + pytest.skip("prerequisite test did not create policy") + engine_id, policy_id = self.policy_ids[0] + policy = self.client.get_policy( + policyEngineId=engine_id, + policyId=policy_id, + ) + assert policy["status"] == "ACTIVE" + + @pytest.mark.order(7) + def test_list_policy_engines_passthrough(self): + engines = self.client.list_policy_engines() + assert "policyEngines" in engines + + @pytest.mark.order(8) + def test_list_policies_passthrough(self): + if not self.engine_ids: + pytest.skip("prerequisite test did not create engine") + policies = self.client.list_policies( + policyEngineId=self.engine_ids[0], + ) + assert "policies" in policies + + @pytest.mark.order(9) + def test_create_or_get_policy_engine(self): + if not self.engine_ids: + pytest.skip("prerequisite test did not create engine") + # Call create_or_get with the same name — should return existing + engine = self.client.create_or_get_policy_engine( + name=f"{self.test_prefix}_engine", + ) + assert engine["policyEngineId"] == self.engine_ids[0] + assert engine["status"] == "ACTIVE" + + @pytest.mark.order(10) + def test_create_or_get_policy(self): + if not self.engine_ids or not self.policy_ids: + pytest.skip("prerequisite tests did not create resources") + engine_id, existing_policy_id = self.policy_ids[0] + # Call create_or_get with the same name — should return existing + policy = self.client.create_or_get_policy( + policy_engine_id=engine_id, + name=f"{self.test_prefix}_policy", + definition={"cedar": {"statement": ("permit(principal, action, resource is AgentCore::Gateway);")}}, + ) + assert policy["policyId"] == existing_policy_id + assert policy["status"] == "ACTIVE" + + @pytest.mark.order(11) + def test_delete_policy_and_wait(self): + if not self.policy_ids: + pytest.skip("prerequisite test did not create policy") + engine_id, policy_id = self.policy_ids.pop(0) + self.client.delete_policy_and_wait( + policyEngineId=engine_id, + policyId=policy_id, + ) + with pytest.raises(ClientError): + self.client.get_policy( + policyEngineId=engine_id, + policyId=policy_id, + ) + + @pytest.mark.order(12) + def test_delete_policy_engine_and_wait(self): + if not self.engine_ids: + pytest.skip("prerequisite test did not create engine") + engine_id = self.engine_ids.pop(0) + self.client.delete_policy_engine_and_wait( + policyEngineId=engine_id, + ) + with pytest.raises(ClientError): + self.client.get_policy_engine(policyEngineId=engine_id)