diff --git a/src/bedrock_agentcore/gateway/__init__.py b/src/bedrock_agentcore/gateway/__init__.py new file mode 100644 index 00000000..82fe9706 --- /dev/null +++ b/src/bedrock_agentcore/gateway/__init__.py @@ -0,0 +1,5 @@ +"""Bedrock AgentCore Gateway client.""" + +from .client import GatewayClient + +__all__ = ["GatewayClient"] diff --git a/src/bedrock_agentcore/gateway/client.py b/src/bedrock_agentcore/gateway/client.py new file mode 100644 index 00000000..dea9a77a --- /dev/null +++ b/src/bedrock_agentcore/gateway/client.py @@ -0,0 +1,302 @@ +"""AgentCore Gateway SDK - Client for MCP gateway and target operations.""" + +import logging +from typing import Any, Dict, Optional + +import boto3 +from botocore.config import Config + +from .._utils.config import WaitConfig +from .._utils.polling import wait_until, wait_until_deleted +from .._utils.snake_case import accept_snake_case_kwargs, convert_kwargs +from .._utils.user_agent import build_user_agent_suffix + +logger = logging.getLogger(__name__) + +_GATEWAY_FAILED_STATUSES = {"FAILED", "UPDATE_UNSUCCESSFUL"} +_TARGET_FAILED_STATUSES = {"FAILED", "UPDATE_UNSUCCESSFUL", "SYNCHRONIZE_UNSUCCESSFUL"} + + +class GatewayClient: + """Client for Bedrock AgentCore Gateway operations. + + Provides access to gateway and gateway target CRUD operations. + Allowlisted boto3 methods can be called directly on this client. + Parameters accept both camelCase and snake_case (auto-converted). + + Example:: + + client = GatewayClient(region_name="us-west-2") + + # Pass-through to boto3 control plane client + gateway = client.create_gateway( + name="my-gateway", + roleArn="arn:aws:iam::123456789:role/gateway-role", + protocolType="MCP", + ) + """ + + _ALLOWED_CP_METHODS = { + # Gateway CRUD + "create_gateway", + "get_gateway", + "list_gateways", + "update_gateway", + "delete_gateway", + # Gateway target CRUD + "create_gateway_target", + "get_gateway_target", + "list_gateway_targets", + "update_gateway_target", + "delete_gateway_target", + } + + def __init__( + self, + region_name: Optional[str] = None, + integration_source: Optional[str] = None, + boto3_session: Optional[boto3.Session] = None, + ): + """Initialize the Gateway client. + + Args: + region_name: AWS region name. If not provided, uses the session's region or "us-west-2". + integration_source: Optional integration source for user-agent telemetry. + boto3_session: Optional boto3 Session to use. If not provided, a default session + is created. Useful for named profiles or custom credentials. + """ + session = boto3_session if boto3_session else boto3.Session() + self.region_name = region_name or session.region_name or "us-west-2" + self.integration_source = integration_source + + user_agent_extra = build_user_agent_suffix(integration_source) + client_config = Config(user_agent_extra=user_agent_extra) + + self.cp_client = session.client("bedrock-agentcore-control", region_name=self.region_name, config=client_config) + + logger.info("Initialized GatewayClient for region: %s", self.cp_client.meta.region_name) + + # Pass-through + # ------------------------------------------------------------------------- + def __getattr__(self, name: str): + """Dynamically forward allowlisted method calls to the control plane boto3 client.""" + if name in self._ALLOWED_CP_METHODS and hasattr(self.cp_client, name): + method = getattr(self.cp_client, name) + logger.debug("Forwarding method '%s' to cp_client", name) + return accept_snake_case_kwargs(method) + + raise AttributeError( + f"'{self.__class__.__name__}' object has no attribute '{name}'. " + f"Method not found on cp_client. " + f"Available methods can be found in the boto3 documentation for " + f"'bedrock-agentcore-control' service." + ) + + # *_and_wait methods + # ------------------------------------------------------------------------- + def create_gateway_and_wait(self, wait_config: Optional[WaitConfig] = None, **kwargs) -> Dict[str, Any]: + """Create a gateway and wait for it to reach READY status. + + Args: + wait_config: Optional WaitConfig for polling behavior (default: max_wait=300, poll_interval=10). + **kwargs: Arguments forwarded to the create_gateway API. + + Returns: + Gateway details when READY. + + Raises: + RuntimeError: If the gateway reaches a failed state. + TimeoutError: If the gateway doesn't become READY within max_wait. + """ + response = self.cp_client.create_gateway(**convert_kwargs(kwargs)) + gw_id = response["gatewayId"] + return wait_until( + lambda: self.cp_client.get_gateway(gatewayIdentifier=gw_id), + "READY", + _GATEWAY_FAILED_STATUSES, + wait_config, + ) + + def update_gateway_and_wait(self, wait_config: Optional[WaitConfig] = None, **kwargs) -> Dict[str, Any]: + """Update a gateway and wait for it to reach READY status. + + Args: + wait_config: Optional WaitConfig for polling behavior (default: max_wait=300, poll_interval=10). + **kwargs: Arguments forwarded to the update_gateway API. + + Returns: + Gateway details when READY. + + Raises: + RuntimeError: If the gateway reaches a failed state. + TimeoutError: If the gateway doesn't become READY within max_wait. + """ + response = self.cp_client.update_gateway(**convert_kwargs(kwargs)) + gw_id = response["gatewayId"] + return wait_until( + lambda: self.cp_client.get_gateway(gatewayIdentifier=gw_id), + "READY", + _GATEWAY_FAILED_STATUSES, + wait_config, + ) + + def create_gateway_target_and_wait(self, wait_config: Optional[WaitConfig] = None, **kwargs) -> Dict[str, Any]: + """Create a gateway target and wait for it to reach READY status. + + Args: + wait_config: Optional WaitConfig for polling behavior (default: max_wait=300, poll_interval=10). + **kwargs: Arguments forwarded to the create_gateway_target API. + Must include gatewayIdentifier. + + Returns: + Gateway target details when READY. + + Raises: + RuntimeError: If the target reaches a failed state. + TimeoutError: If the target doesn't become READY within max_wait. + """ + response = self.cp_client.create_gateway_target(**convert_kwargs(kwargs)) + gw_id = response["gatewayArn"].rsplit("/", 1)[-1] + target_id = response["targetId"] + return wait_until( + lambda: self.cp_client.get_gateway_target( + gatewayIdentifier=gw_id, + targetId=target_id, + ), + "READY", + _TARGET_FAILED_STATUSES, + wait_config, + ) + + def update_gateway_target_and_wait(self, wait_config: Optional[WaitConfig] = None, **kwargs) -> Dict[str, Any]: + """Update a gateway target and wait for it to reach READY status. + + Args: + wait_config: Optional WaitConfig for polling behavior (default: max_wait=300, poll_interval=10). + **kwargs: Arguments forwarded to the update_gateway_target API. + Must include gatewayIdentifier and targetId. + + Returns: + Gateway target details when READY. + + Raises: + RuntimeError: If the target reaches a failed state. + TimeoutError: If the target doesn't become READY within max_wait. + """ + response = self.cp_client.update_gateway_target(**convert_kwargs(kwargs)) + gw_id = response["gatewayArn"].rsplit("/", 1)[-1] + target_id = response["targetId"] + return wait_until( + lambda: self.cp_client.get_gateway_target( + gatewayIdentifier=gw_id, + targetId=target_id, + ), + "READY", + _TARGET_FAILED_STATUSES, + wait_config, + ) + + def delete_gateway_and_wait( + self, + wait_config: Optional[WaitConfig] = None, + **kwargs, + ) -> None: + """Delete a gateway and wait for deletion to complete. + + Args: + wait_config: Optional WaitConfig for polling behavior. + **kwargs: Arguments forwarded to the delete_gateway API. + + Raises: + TimeoutError: If the gateway isn't deleted within max_wait. + """ + response = self.cp_client.delete_gateway(**convert_kwargs(kwargs)) + gw_id = response["gatewayId"] + wait_until_deleted( + lambda: self.cp_client.get_gateway(gatewayIdentifier=gw_id), + wait_config=wait_config, + ) + + def delete_gateway_target_and_wait( + self, + wait_config: Optional[WaitConfig] = None, + **kwargs, + ) -> None: + """Delete a gateway target and wait for deletion to complete. + + Args: + wait_config: Optional WaitConfig for polling behavior. + **kwargs: Arguments forwarded to the delete_gateway_target API. + + Raises: + TimeoutError: If the target isn't deleted within max_wait. + """ + response = self.cp_client.delete_gateway_target(**convert_kwargs(kwargs)) + gw_id = response["gatewayArn"].rsplit("/", 1)[-1] + target_id = response["targetId"] + wait_until_deleted( + lambda: self.cp_client.get_gateway_target( + gatewayIdentifier=gw_id, + targetId=target_id, + ), + wait_config=wait_config, + ) + + # Name-based lookup + # ------------------------------------------------------------------------- + def get_gateway_by_name(self, name: str, **kwargs) -> Optional[Dict[str, Any]]: + """Look up a gateway by name. + + Paginates through gateways and returns the full resource details + for the first match. Short-circuits on first match without fetching + remaining pages. Returns None if no gateway with that name exists. + + Args: + name: The gateway name to search for. + **kwargs: Additional arguments forwarded to the list_gateways API. + + Returns: + Gateway details from get_gateway, or None if not found. + """ + params = convert_kwargs(kwargs) + params.pop("nextToken", None) + while True: + response = self.cp_client.list_gateways(**params) + for gw in response.get("items", []): + if gw.get("name") == name: + return self.cp_client.get_gateway( + gatewayIdentifier=gw["gatewayId"], + ) + if not response.get("nextToken"): + return None + params["nextToken"] = response["nextToken"] + + def get_gateway_target_by_name(self, gateway_identifier: str, name: str, **kwargs) -> Optional[Dict[str, Any]]: + """Look up a gateway target by name. + + Paginates through targets for the given gateway and returns the + full resource details for the first match. Short-circuits on first + match without fetching remaining pages. Returns None if not found. + + Args: + gateway_identifier: Gateway ID or ARN. + name: The target name to search for. + **kwargs: Additional arguments forwarded to the list_gateway_targets API. + + Returns: + Gateway target details from get_gateway_target, or None if not found. + """ + params = convert_kwargs(kwargs) + params.pop("nextToken", None) + params["gatewayIdentifier"] = gateway_identifier + while True: + response = self.cp_client.list_gateway_targets(**params) + for target in response.get("items", []): + if target.get("name") == name: + return self.cp_client.get_gateway_target( + gatewayIdentifier=gateway_identifier, + targetId=target["targetId"], + ) + if not response.get("nextToken"): + return None + params["nextToken"] = response["nextToken"] diff --git a/tests/unit/gateway/__init__.py b/tests/unit/gateway/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/gateway/test_gateway_client.py b/tests/unit/gateway/test_gateway_client.py new file mode 100644 index 00000000..fcee0b96 --- /dev/null +++ b/tests/unit/gateway/test_gateway_client.py @@ -0,0 +1,132 @@ +"""Tests for GatewayClient.""" + +from unittest.mock import MagicMock, Mock, patch + +import pytest +from botocore.exceptions import ClientError + +from bedrock_agentcore.gateway.client import GatewayClient + + +class TestGatewayClientInit: + def test_init_with_region(self): + mock_session = MagicMock() + mock_session.region_name = "eu-west-1" + client = GatewayClient(region_name="us-west-2", boto3_session=mock_session) + assert client.region_name == "us-west-2" + + def test_init_default_region_fallback(self): + mock_session = MagicMock() + mock_session.region_name = None + client = GatewayClient(boto3_session=mock_session) + assert client.region_name == "us-west-2" + + +class TestGatewayClientPassthrough: + def _make_client(self): + mock_session = MagicMock() + mock_session.region_name = "us-west-2" + client = GatewayClient(boto3_session=mock_session) + client.cp_client = Mock() + return client + + def test_cp_method_forwarded(self): + client = self._make_client() + client.cp_client.create_gateway.return_value = {"gatewayId": "gw-123"} + result = client.create_gateway(name="test") + client.cp_client.create_gateway.assert_called_once_with(name="test") + assert result["gatewayId"] == "gw-123" + + def test_snake_case_kwargs_converted(self): + client = self._make_client() + client.cp_client.get_gateway.return_value = {"gatewayId": "gw-123"} + client.get_gateway(gateway_identifier="gw-123") + client.cp_client.get_gateway.assert_called_once_with(gatewayIdentifier="gw-123") + + def test_non_allowlisted_method_raises(self): + client = self._make_client() + with pytest.raises(AttributeError): + client.not_a_real_method() + + +class TestGatewayAndWait: + def _make_client(self): + mock_session = MagicMock() + mock_session.region_name = "us-west-2" + client = GatewayClient(boto3_session=mock_session) + client.cp_client = Mock() + return client + + def test_create_gateway_and_wait(self): + client = self._make_client() + client.cp_client.create_gateway.return_value = {"gatewayId": "gw-123"} + client.cp_client.get_gateway.return_value = {"status": "READY", "gatewayId": "gw-123"} + + result = client.create_gateway_and_wait(name="test") + assert result["status"] == "READY" + + def test_create_gateway_and_wait_failed(self): + client = self._make_client() + client.cp_client.create_gateway.return_value = {"gatewayId": "gw-123"} + client.cp_client.get_gateway.return_value = { + "status": "FAILED", + "statusReasons": ["bad config"], + } + + with pytest.raises(RuntimeError, match="FAILED"): + client.create_gateway_and_wait(name="test") + + def test_create_gateway_target_and_wait(self): + client = self._make_client() + client.cp_client.create_gateway_target.return_value = { + "gatewayArn": "arn:gw", + "targetId": "t-123", + } + client.cp_client.get_gateway_target.return_value = { + "status": "READY", + "targetId": "t-123", + } + + result = client.create_gateway_target_and_wait(gatewayIdentifier="gw-1") + assert result["status"] == "READY" + + @patch("bedrock_agentcore._utils.polling.time.sleep") + def test_delete_gateway_and_wait(self, _mock_sleep): + client = self._make_client() + client.cp_client.delete_gateway.return_value = {"gatewayId": "gw-123"} + client.cp_client.get_gateway.side_effect = ClientError( + {"Error": {"Code": "ResourceNotFoundException", "Message": "gone"}}, + "GetGateway", + ) + + client.delete_gateway_and_wait(gatewayIdentifier="gw-123") + client.cp_client.delete_gateway.assert_called_once() + + +class TestGatewayNameLookup: + def _make_client(self): + mock_session = MagicMock() + mock_session.region_name = "us-west-2" + client = GatewayClient(boto3_session=mock_session) + client.cp_client = Mock() + return client + + def test_get_gateway_by_name_found(self): + client = self._make_client() + client.cp_client.list_gateways.return_value = { + "items": [{"name": "my-gw", "gatewayId": "gw-123"}], + } + client.cp_client.get_gateway.return_value = { + "gatewayId": "gw-123", + "name": "my-gw", + } + + result = client.get_gateway_by_name(name="my-gw") + assert result["gatewayId"] == "gw-123" + + def test_get_gateway_by_name_not_found(self): + client = self._make_client() + client.cp_client.list_gateways.return_value = {"items": []} + + result = client.get_gateway_by_name(name="nonexistent") + assert result is None diff --git a/tests_integ/gateway/__init__.py b/tests_integ/gateway/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests_integ/gateway/test_gateway_client.py b/tests_integ/gateway/test_gateway_client.py new file mode 100644 index 00000000..730af424 --- /dev/null +++ b/tests_integ/gateway/test_gateway_client.py @@ -0,0 +1,271 @@ +"""Integration tests for GatewayClient. + +Requires environment variables: + BEDROCK_TEST_REGION: AWS region (default: us-west-2) + GATEWAY_ROLE_ARN: IAM role ARN with AgentCore gateway trust policy +""" + +import os +import time + +import pytest +from botocore.exceptions import ClientError + +from bedrock_agentcore.gateway.client import GatewayClient + + +@pytest.mark.integration +class TestGatewayClient: + """Integration tests for GatewayClient CRUD and wait methods.""" + + @classmethod + def setup_class(cls): + cls.region = os.environ.get("BEDROCK_TEST_REGION", "us-west-2") + cls.role_arn = os.environ.get("GATEWAY_ROLE_ARN") + if not cls.role_arn: + pytest.fail("GATEWAY_ROLE_ARN must be set") + cls.client = GatewayClient(region_name=cls.region) + cls.test_prefix = f"sdk-integ-{int(time.time())}" + cls.gateway_ids = [] + + @classmethod + def teardown_class(cls): + for gw_id in cls.gateway_ids: + try: + cls.client.delete_gateway(gatewayIdentifier=gw_id) + except Exception as e: + print(f"Failed to delete gateway {gw_id}: {e}") + + @pytest.mark.order(1) + def test_create_gateway_and_wait(self): + gw = self.client.create_gateway_and_wait( + name=f"{self.test_prefix}-gw", + roleArn=self.role_arn, + authorizerType="NONE", + protocolType="MCP", + ) + self.__class__.gateway_ids.append(gw["gatewayId"]) + assert gw["status"] == "READY" + assert gw["name"] == f"{self.test_prefix}-gw" + + @pytest.mark.order(2) + def test_get_gateway_passthrough(self): + if not self.gateway_ids: + pytest.skip("prerequisite test did not create gateway") + gw = self.client.get_gateway( + gatewayIdentifier=self.gateway_ids[0], + ) + assert gw["status"] == "READY" + + @pytest.mark.order(3) + def test_get_gateway_snake_case(self): + if not self.gateway_ids: + pytest.skip("prerequisite test did not create gateway") + gw = self.client.get_gateway( + gateway_identifier=self.gateway_ids[0], + ) + assert gw["status"] == "READY" + + @pytest.mark.order(4) + def test_get_gateway_by_name(self): + if not self.gateway_ids: + pytest.skip("prerequisite test did not create gateway") + gw = self.client.get_gateway_by_name( + name=f"{self.test_prefix}-gw", + ) + assert gw is not None + assert gw["gatewayId"] == self.gateway_ids[0] + + @pytest.mark.order(5) + def test_get_gateway_by_name_not_found(self): + result = self.client.get_gateway_by_name( + name="nonexistent-gateway-name", + ) + assert result is None + + @pytest.mark.order(6) + def test_list_gateways_passthrough(self): + gateways = self.client.list_gateways() + assert "items" in gateways + + @pytest.mark.order(7) + def test_update_gateway_and_wait(self): + if not self.gateway_ids: + pytest.skip("prerequisite test did not create gateway") + updated = self.client.update_gateway_and_wait( + gatewayIdentifier=self.gateway_ids[0], + name=f"{self.test_prefix}-gw", + roleArn=self.role_arn, + authorizerType="NONE", + description="updated by integ test", + ) + assert updated["status"] == "READY" + assert updated.get("description") == "updated by integ test" + + @pytest.mark.order(8) + def test_delete_gateway_and_wait(self): + if not self.gateway_ids: + pytest.skip("prerequisite test did not create gateway") + gw_id = self.gateway_ids.pop(0) + self.client.delete_gateway_and_wait( + gatewayIdentifier=gw_id, + ) + with pytest.raises(ClientError): + self.client.get_gateway(gatewayIdentifier=gw_id) + + +@pytest.mark.integration +class TestGatewayTargetClient: + """Integration tests for gateway target CRUD. + + Requires GATEWAY_LAMBDA_ARN in addition to GATEWAY_ROLE_ARN. + """ + + @classmethod + def setup_class(cls): + cls.region = os.environ.get("BEDROCK_TEST_REGION", "us-west-2") + cls.role_arn = os.environ.get("GATEWAY_ROLE_ARN") + cls.lambda_arn = os.environ.get("GATEWAY_LAMBDA_ARN") + if not cls.role_arn or not cls.lambda_arn: + pytest.fail("GATEWAY_ROLE_ARN and GATEWAY_LAMBDA_ARN must be set") + cls.client = GatewayClient(region_name=cls.region) + cls.test_prefix = f"sdk-integ-tgt-{int(time.time())}" + cls.gateway_id = None + cls.target_ids = [] + + # Create a gateway for target tests + gw = cls.client.create_gateway_and_wait( + name=f"{cls.test_prefix}-gw", + roleArn=cls.role_arn, + authorizerType="NONE", + protocolType="MCP", + ) + cls.gateway_id = gw["gatewayId"] + + @classmethod + def teardown_class(cls): + if cls.gateway_id: + for target_id in cls.target_ids: + try: + cls.client.delete_gateway_target( + gatewayIdentifier=cls.gateway_id, + targetId=target_id, + ) + except Exception as e: + print(f"Failed to delete target {target_id}: {e}") + try: + cls.client.delete_gateway_and_wait( + gatewayIdentifier=cls.gateway_id, + ) + except Exception as e: + print(f"Failed to delete gateway {cls.gateway_id}: {e}") + + @pytest.mark.order(9) + def test_create_gateway_target_and_wait(self): + target = self.client.create_gateway_target_and_wait( + gatewayIdentifier=self.gateway_id, + name=f"{self.test_prefix}-target", + targetConfiguration={ + "mcp": { + "lambda": { + "lambdaArn": self.lambda_arn, + "toolSchema": { + "inlinePayload": [ + { + "name": "test_tool", + "description": "A test tool", + "inputSchema": {"type": "object"}, + } + ] + }, + } + }, + }, + credentialProviderConfigurations=[ + {"credentialProviderType": "GATEWAY_IAM_ROLE"}, + ], + ) + self.__class__.target_ids.append(target["targetId"]) + assert target["status"] == "READY" + + @pytest.mark.order(10) + def test_get_gateway_target_passthrough(self): + if not self.target_ids: + pytest.skip("prerequisite test did not create target") + target = self.client.get_gateway_target( + gatewayIdentifier=self.gateway_id, + targetId=self.target_ids[0], + ) + assert target["status"] == "READY" + + @pytest.mark.order(11) + def test_get_gateway_target_by_name(self): + if not self.target_ids: + pytest.skip("prerequisite test did not create target") + target = self.client.get_gateway_target_by_name( + gateway_identifier=self.gateway_id, + name=f"{self.test_prefix}-target", + ) + assert target is not None + assert target["targetId"] == self.target_ids[0] + + @pytest.mark.order(12) + def test_list_gateway_targets_passthrough(self): + targets = self.client.list_gateway_targets( + gatewayIdentifier=self.gateway_id, + ) + assert "items" in targets + + @pytest.mark.order(13) + def test_update_gateway_target_and_wait(self): + if not self.target_ids: + pytest.skip("prerequisite test did not create target") + updated = self.client.update_gateway_target_and_wait( + gatewayIdentifier=self.gateway_id, + targetId=self.target_ids[0], + name=f"{self.test_prefix}-target", + targetConfiguration={ + "mcp": { + "lambda": { + "lambdaArn": self.lambda_arn, + "toolSchema": { + "inlinePayload": [ + { + "name": "test_tool", + "description": "An updated test tool", + "inputSchema": {"type": "object"}, + } + ] + }, + } + }, + }, + credentialProviderConfigurations=[ + {"credentialProviderType": "GATEWAY_IAM_ROLE"}, + ], + description="updated by integ test", + ) + assert updated["status"] == "READY" + + @pytest.mark.order(14) + def test_get_gateway_target_by_name_not_found(self): + result = self.client.get_gateway_target_by_name( + gateway_identifier=self.gateway_id, + name="nonexistent-target-name", + ) + assert result is None + + @pytest.mark.order(15) + def test_delete_gateway_target_and_wait(self): + if not self.target_ids: + pytest.skip("prerequisite test did not create target") + target_id = self.target_ids.pop(0) + self.client.delete_gateway_target_and_wait( + gatewayIdentifier=self.gateway_id, + targetId=target_id, + ) + with pytest.raises(ClientError): + self.client.get_gateway_target( + gatewayIdentifier=self.gateway_id, + targetId=target_id, + )