diff --git a/CHANGELOG.md b/CHANGELOG.md index af1a1e7c..67965e6f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,11 @@ # Changelog +## [Unreleased] + +### Added +- feat(runtime): add HTTP invocation methods (`invoke`, `invoke_streaming`, `invoke_streaming_async`) and `InvokeAgentRuntimeCommand` shell-exec support (`execute_command`, `execute_command_streaming`) to `AgentCoreRuntimeClient`, plus `stop_runtime_session`. Methods take `runtime_arn` and `bearer_token` per-call, matching `generate_ws_connection_oauth`. The `urllib3.PoolManager` is lazy-initialized so SigV4 URL-generation users pay no cost. +- feat(runtime): add `AgentRuntimeError` exception raised by the HTTP methods on non-2xx responses and in-band SSE error events. + ## [1.6.3] - 2026-04-16 ### Fixed diff --git a/src/bedrock_agentcore/runtime/__init__.py b/src/bedrock_agentcore/runtime/__init__.py index a08bbc93..5763b7ab 100644 --- a/src/bedrock_agentcore/runtime/__init__.py +++ b/src/bedrock_agentcore/runtime/__init__.py @@ -4,15 +4,18 @@ - BedrockAgentCoreApp: Main application class - RequestContext: HTTP request context - BedrockAgentCoreContext: Agent identity context +- AgentCoreRuntimeClient: Authentication + HTTP invocation client for deployed runtimes +- AgentRuntimeError: Exception raised by the HTTP invocation methods """ -from .agent_core_runtime_client import AgentCoreRuntimeClient +from .agent_core_runtime_client import AgentCoreRuntimeClient, AgentRuntimeError from .app import BedrockAgentCoreApp from .context import BedrockAgentCoreContext, RequestContext from .models import PingStatus __all__ = [ "AgentCoreRuntimeClient", + "AgentRuntimeError", "AGUIApp", "BedrockAgentCoreApp", "BedrockCallContextBuilder", @@ -29,7 +32,12 @@ def __getattr__(name: str): """Lazy imports for A2A and AG-UI symbols so optional dependencies are not required at import time.""" - _a2a_exports = {"BedrockCallContextBuilder", "build_a2a_app", "build_runtime_url", "serve_a2a"} + _a2a_exports = { + "BedrockCallContextBuilder", + "build_a2a_app", + "build_runtime_url", + "serve_a2a", + } if name in _a2a_exports: from . import a2a as _a2a_module diff --git a/src/bedrock_agentcore/runtime/agent_core_runtime_client.py b/src/bedrock_agentcore/runtime/agent_core_runtime_client.py index ab91de5a..106df03a 100644 --- a/src/bedrock_agentcore/runtime/agent_core_runtime_client.py +++ b/src/bedrock_agentcore/runtime/agent_core_runtime_client.py @@ -1,34 +1,110 @@ -"""Client for generating WebSocket authentication for AgentCore Runtime. +"""Client for AgentCore Runtime authentication and invocation. -This module provides a client for generating authentication credentials -for WebSocket connections to AgentCore Runtime endpoints. +This module provides a single client for interacting with Bedrock AgentCore +Runtime endpoints: + +- WebSocket URL and header generation (SigV4, SigV4 presigned, OAuth bearer). +- HTTP invocation (blocking, sync streaming, async streaming) with bearer-token + auth. +- ``InvokeAgentRuntimeCommand`` shell-exec support (blocking and streaming). +- Session termination. + +Bearer-token HTTP methods mirror the shape of +:meth:`generate_ws_connection_oauth` — ``runtime_arn`` and ``bearer_token`` +are passed per call. The ``urllib3.PoolManager`` used by HTTP methods is +lazy-initialized; callers who only use URL-generation methods pay no cost. """ +import asyncio import base64 import datetime +import json import logging import secrets +import threading import uuid -from typing import Dict, Optional, Tuple +from typing import ( + TYPE_CHECKING, + Any, + AsyncGenerator, + Dict, + Generator, + Iterator, + Optional, + Tuple, +) from urllib.parse import quote, urlencode, urlparse import boto3 from botocore.auth import SigV4Auth, SigV4QueryAuth from botocore.awsrequest import AWSRequest +from botocore.eventstream import EventStreamBuffer from .._utils.endpoints import get_data_plane_endpoint from .utils import is_valid_partition +if TYPE_CHECKING: + # urllib3 v2 returns ``BaseHTTPResponse`` from ``PoolManager.request``. + # On urllib3 v1.26 the returned object is ``HTTPResponse`` which is + # structurally compatible with the same attributes (``status``, ``headers``, + # ``read``, ``stream``, ``release_conn``), so the annotation is purely for + # type-checking under v2. + import urllib3 + from urllib3 import BaseHTTPResponse as _UrllibResponse + DEFAULT_PRESIGNED_URL_TIMEOUT = 300 MAX_PRESIGNED_URL_TIMEOUT = 300 +DEFAULT_HTTP_TIMEOUT = 300 +DEFAULT_COMMAND_TIMEOUT = 600 + + +class AgentRuntimeError(Exception): + """Raised when an AgentCore runtime returns an error response. + + Used by the HTTP invocation methods of :class:`AgentCoreRuntimeClient` + for both non-2xx HTTP responses and in-band error events embedded in + SSE streams. + + Attributes: + error: Short machine-readable error token (for example the + runtime's ``error`` field, or the HTTP status text). + error_type: Category label. For HTTP failures this is + ``"HTTP "`` (e.g. ``"HTTP 404"``). For runtime error + payloads it is whatever the server set on the ``error_type`` + field. + """ + + def __init__(self, error: str, error_type: str = "", message: str = "") -> None: + """Initialize the exception. + + Args: + error: Short error token (see :attr:`error`). + error_type: Category label (see :attr:`error_type`). Defaults + to the empty string. + message: Human-readable message used as the exception's + string representation. Defaults to ``error`` when empty. + """ + self.error = error + self.error_type = error_type + super().__init__(message or error) class AgentCoreRuntimeClient: - """Client for generating WebSocket authentication for AgentCore Runtime. + """Client for AgentCore Runtime authentication and invocation. + + This client supports four auth/transport modes against AgentCore Runtime: + + - SigV4-signed WebSocket URL + headers (:meth:`generate_ws_connection`). + - SigV4 presigned WebSocket URL (:meth:`generate_presigned_url`). + - OAuth bearer-token WebSocket URL + headers + (:meth:`generate_ws_connection_oauth`). + - OAuth bearer-token HTTP invocation (:meth:`invoke`, + :meth:`invoke_streaming`, :meth:`invoke_streaming_async`, + :meth:`execute_command`, :meth:`execute_command_streaming`, + :meth:`stop_runtime_session`). - This client provides authentication credentials for WebSocket connections - to AgentCore Runtime endpoints, allowing applications to establish - bidirectional streaming connections with agent runtimes. + The ``urllib3.PoolManager`` used by the HTTP methods is lazy-initialized + on first access; callers who only use URL-generation methods pay no cost. Attributes: region (str): The AWS region being used. @@ -53,6 +129,24 @@ def __init__(self, region: str, session: Optional[boto3.Session] = None) -> None session = boto3.Session() self.session = session + self._pool_manager: Optional["urllib3.PoolManager"] = None + + @property + def _http(self) -> "urllib3.PoolManager": + """Return a lazy-initialized ``urllib3.PoolManager`` shared across HTTP calls. + + The pool is not created until an HTTP method (``invoke``, + ``execute_command``, etc.) is actually called. Callers that only use + the SigV4 / OAuth URL-generation methods never trigger this import. + + Returns: + A shared :class:`urllib3.PoolManager` instance. + """ + if self._pool_manager is None: + import urllib3 + + self._pool_manager = urllib3.PoolManager() + return self._pool_manager def _parse_runtime_arn(self, runtime_arn: str) -> Dict[str, str]: """Parse runtime ARN and extract components. @@ -72,7 +166,11 @@ def _parse_runtime_arn(self, runtime_arn: str) -> Dict[str, str]: if len(parts) != 6: raise ValueError(f"Invalid runtime ARN format: {runtime_arn}") - if parts[0] != "arn" or not is_valid_partition(parts[1]) or parts[2] != "bedrock-agentcore": + if ( + parts[0] != "arn" + or not is_valid_partition(parts[1]) + or parts[2] != "bedrock-agentcore" + ): raise ValueError(f"Invalid runtime ARN format: {runtime_arn}") # Parse the resource part (runtime/{runtime_id}) @@ -138,6 +236,88 @@ def _build_websocket_url( return ws_url + def _build_http_url( + self, + runtime_arn: str, + path_suffix: str, + endpoint_name: Optional[str] = None, + ) -> str: + """Build an HTTPS URL for a data-plane API on a runtime. + + Reuses :meth:`_parse_runtime_arn` to extract the region from the ARN + and :func:`get_data_plane_endpoint` to construct the host, matching + the convention established by the WebSocket URL builder. + + Args: + runtime_arn: Full runtime ARN. + path_suffix: Path under ``/runtimes/{arn}/``. Must not start with + ``/``. Examples: ``"invocations"``, ``"commands"``, + ``"stopruntimesession"``. + endpoint_name: Endpoint qualifier sent as the ``qualifier`` query + parameter. Defaults to ``"DEFAULT"`` when not supplied so the + API receives a valid qualifier value. + + Returns: + Absolute URL including the ``?qualifier=...`` query string. + + Raises: + ValueError: If the ARN format is invalid. + """ + parsed = self._parse_runtime_arn(runtime_arn) + base = get_data_plane_endpoint(parsed["region"]) + encoded_arn = quote(runtime_arn, safe="") + query = urlencode({"qualifier": endpoint_name or "DEFAULT"}) + return f"{base}/runtimes/{encoded_arn}/{path_suffix}?{query}" + + def _build_bearer_headers( + self, + bearer_token: str, + session_id: str, + accept: str, + content_type: str, + ) -> Dict[str, str]: + """Build request headers for a bearer-authenticated HTTP call. + + Args: + bearer_token: OAuth/JWT bearer token. + session_id: Runtime session id. + accept: Value for the ``Accept`` header. + content_type: Value for the ``Content-Type`` header. + + Returns: + Header dict including ``Authorization``, ``Content-Type``, + ``Accept``, and ``X-Amzn-Bedrock-AgentCore-Runtime-Session-Id``. + """ + return { + "Authorization": f"Bearer {bearer_token}", + "Content-Type": content_type, + "Accept": accept, + "X-Amzn-Bedrock-AgentCore-Runtime-Session-Id": session_id, + } + + @staticmethod + def _serialize_body(body: Any, content_type: str) -> bytes: + """Serialize a request body to bytes based on the content type. + + JSON content types serialize via :func:`json.dumps`. For other + content types, ``bytes`` pass through unchanged, ``str`` is + UTF-8-encoded, and anything else falls back to JSON serialization. + + Args: + body: The payload to send. + content_type: The MIME type of the request. + + Returns: + UTF-8 encoded request body. + """ + if "json" in content_type: + return json.dumps(body).encode("utf-8") + if isinstance(body, bytes): + return body + if isinstance(body, str): + return body.encode("utf-8") + return json.dumps(body).encode("utf-8") + def generate_ws_connection( self, runtime_arn: str, @@ -200,7 +380,9 @@ def generate_ws_connection( url=https_url, headers={ "host": host, - "x-amz-date": datetime.datetime.now(datetime.timezone.utc).strftime("%Y%m%dT%H%M%SZ"), + "x-amz-date": datetime.datetime.now(datetime.timezone.utc).strftime( + "%Y%m%dT%H%M%SZ" + ), }, ) @@ -225,7 +407,9 @@ def generate_ws_connection( if frozen_credentials.token: headers["X-Amz-Security-Token"] = frozen_credentials.token - self.logger.info("✓ WebSocket connection credentials generated (Session: %s)", session_id) + self.logger.info( + "✓ WebSocket connection credentials generated (Session: %s)", session_id + ) return ws_url, headers def generate_presigned_url( @@ -273,7 +457,9 @@ def generate_presigned_url( # Validate expires parameter if expires > MAX_PRESIGNED_URL_TIMEOUT: - raise ValueError(f"Expiry timeout cannot exceed {MAX_PRESIGNED_URL_TIMEOUT} seconds, got {expires}") + raise ValueError( + f"Expiry timeout cannot exceed {MAX_PRESIGNED_URL_TIMEOUT} seconds, got {expires}" + ) # Validate ARN self._parse_runtime_arn(runtime_arn) @@ -305,7 +491,9 @@ def generate_presigned_url( frozen_credentials = credentials.get_frozen_credentials() # Create the request to sign - request = AWSRequest(method="GET", url=https_url, headers={"host": url.hostname}) + request = AWSRequest( + method="GET", url=https_url, headers={"host": url.hostname} + ) # Sign the request with SigV4QueryAuth signer = SigV4QueryAuth( @@ -322,7 +510,11 @@ def generate_presigned_url( # Convert back to wss:// for WebSocket connection presigned_url = request.url.replace("https://", "wss://") - self.logger.info("✓ Presigned URL generated (expires in %s seconds, Session: %s)", expires, session_id) + self.logger.info( + "✓ Presigned URL generated (expires in %s seconds, Session: %s)", + expires, + session_id, + ) return presigned_url def generate_ws_connection_oauth( @@ -397,7 +589,607 @@ def generate_ws_connection_oauth( "User-Agent": "OAuth-WebSocket-Client/1.0", } - self.logger.info("✓ OAuth WebSocket connection credentials generated (Session: %s)", session_id) + self.logger.info( + "✓ OAuth WebSocket connection credentials generated (Session: %s)", + session_id, + ) self.logger.debug("Bearer token length: %d characters", len(bearer_token)) return ws_url, headers + + # ------------------------------------------------------------------ # + # HTTP invocation (bearer auth) + # ------------------------------------------------------------------ # + + def invoke( + self, + runtime_arn: str, + bearer_token: str, + body: Any, + *, + session_id: Optional[str] = None, + endpoint_name: Optional[str] = None, + headers: Optional[Dict[str, str]] = None, + timeout: int = DEFAULT_HTTP_TIMEOUT, + content_type: str = "application/json", + accept: str = "application/json", + ) -> str: + """Invoke the runtime over HTTP and return the full response body. + + Handles both JSON and Server-Sent Events responses transparently. For + SSE responses the decoded event data is concatenated in arrival order. + + Args: + runtime_arn: Full runtime ARN. + bearer_token: Bearer token for authentication. + body: Request payload (JSON-serializable for JSON content types, + ``bytes`` or ``str`` for others). + session_id: Runtime session id. Auto-generated as a UUID4 when + not supplied, matching :meth:`generate_ws_connection_oauth`. + endpoint_name: Endpoint qualifier. Defaults to ``"DEFAULT"`` on + the wire when not supplied. + headers: Optional extra headers merged over the auth headers. + timeout: HTTP read timeout in seconds. + content_type: Request ``Content-Type``. + accept: Request ``Accept``. + + Returns: + The complete response body as a string. For JSON responses + whose body is a JSON-encoded string, the unwrapped string is + returned. + + Raises: + AgentRuntimeError: On non-2xx responses or in-band SSE error + events. + ValueError: If the ARN format is invalid. + """ + if not session_id: + session_id = str(uuid.uuid4()) + url = self._build_http_url(runtime_arn, "invocations", endpoint_name) + request_headers = self._build_bearer_headers( + bearer_token, session_id, accept, content_type + ) + if headers: + request_headers.update(headers) + + response = self._http.request( + "POST", + url, + headers=request_headers, + body=self._serialize_body(body, content_type), + timeout=timeout, + preload_content=False, + ) + try: + self._check_response(response) + response_content_type = response.headers.get("content-type", "") + if "text/event-stream" not in response_content_type: + return self._read_non_streaming(response) + return "".join(self._iter_sse_decoded(response)) + finally: + response.release_conn() + + def invoke_streaming( + self, + runtime_arn: str, + bearer_token: str, + body: Any, + *, + session_id: Optional[str] = None, + endpoint_name: Optional[str] = None, + headers: Optional[Dict[str, str]] = None, + timeout: int = DEFAULT_HTTP_TIMEOUT, + content_type: str = "application/json", + accept: str = "application/json", + ) -> Generator[str, None, None]: + """Invoke the runtime and yield SSE chunks as they arrive. + + Args: + runtime_arn: Full runtime ARN. + bearer_token: Bearer token for authentication. + body: Request payload. + session_id: Runtime session id. Auto-generated if not supplied. + endpoint_name: Endpoint qualifier. Defaults to ``"DEFAULT"``. + headers: Optional extra headers. + timeout: HTTP read timeout in seconds. + content_type: Request ``Content-Type``. + accept: Request ``Accept``. + + Yields: + Decoded payload strings from the SSE stream. + + Raises: + AgentRuntimeError: If the runtime returns a non-2xx status + or an error event in the stream. + ValueError: If the ARN format is invalid. + """ + if not session_id: + session_id = str(uuid.uuid4()) + url = self._build_http_url(runtime_arn, "invocations", endpoint_name) + request_headers = self._build_bearer_headers( + bearer_token, session_id, accept, content_type + ) + if headers: + request_headers.update(headers) + + response = self._http.request( + "POST", + url, + headers=request_headers, + body=self._serialize_body(body, content_type), + timeout=timeout, + preload_content=False, + ) + try: + self._check_response(response) + yield from self._iter_sse_decoded(response) + finally: + response.release_conn() + + async def invoke_streaming_async( + self, + runtime_arn: str, + bearer_token: str, + body: Any, + *, + session_id: Optional[str] = None, + endpoint_name: Optional[str] = None, + headers: Optional[Dict[str, str]] = None, + timeout: int = DEFAULT_HTTP_TIMEOUT, + content_type: str = "application/json", + accept: str = "application/json", + ) -> AsyncGenerator[str, None]: + """Async generator version of :meth:`invoke_streaming`. + + The underlying HTTP call is blocking; this wrapper runs it on a + background thread and delivers chunks through an :class:`asyncio.Queue`, + so it is safe to ``async for`` over. + + Args: + runtime_arn: Full runtime ARN. + bearer_token: Bearer token for authentication. + body: Request payload. + session_id: Runtime session id. Auto-generated if not supplied. + endpoint_name: Endpoint qualifier. Defaults to ``"DEFAULT"``. + headers: Optional extra headers. + timeout: HTTP read timeout in seconds. + content_type: Request ``Content-Type``. + accept: Request ``Accept``. + + Yields: + Decoded payload strings from the SSE stream. + + Raises: + AgentRuntimeError: If the runtime returns a non-2xx status + or an error event in the stream. + ValueError: If the ARN format is invalid. + """ + chunk_queue: asyncio.Queue = asyncio.Queue() + sentinel = object() + loop = asyncio.get_running_loop() + + def stream_in_thread() -> None: + try: + for decoded in self.invoke_streaming( + runtime_arn, + bearer_token, + body, + session_id=session_id, + endpoint_name=endpoint_name, + headers=headers, + timeout=timeout, + content_type=content_type, + accept=accept, + ): + loop.call_soon_threadsafe(chunk_queue.put_nowait, decoded) + loop.call_soon_threadsafe(chunk_queue.put_nowait, sentinel) + except Exception as exc: # noqa: BLE001 — propagated to caller + loop.call_soon_threadsafe(chunk_queue.put_nowait, exc) + loop.call_soon_threadsafe(chunk_queue.put_nowait, sentinel) + + thread = threading.Thread(target=stream_in_thread, daemon=True) + thread.start() + + while True: + item = await chunk_queue.get() + if item is sentinel: + break + if isinstance(item, Exception): + raise item + yield item + await asyncio.sleep(0) + + # ------------------------------------------------------------------ # + # execute_command / execute_command_streaming + # ------------------------------------------------------------------ # + + def execute_command( + self, + runtime_arn: str, + bearer_token: str, + command: str, + *, + session_id: Optional[str] = None, + endpoint_name: Optional[str] = None, + headers: Optional[Dict[str, str]] = None, + command_timeout: int = DEFAULT_COMMAND_TIMEOUT, + ) -> Dict[str, Any]: + """Run a shell command inside the runtime session and collect the full result. + + Backed by the ``InvokeAgentRuntimeCommand`` API. Blocking; accumulates + all ``stdout`` and ``stderr`` chunks from the EventStream and returns + the final exit status. + + Args: + runtime_arn: Full runtime ARN. + bearer_token: Bearer token for authentication. + command: Shell command to run (1 B – 64 KB per the + ``InvokeAgentRuntimeCommand`` API). + session_id: Runtime session id. Auto-generated if not supplied. + The filesystem inside the container persists across calls in + the same session, but a fresh shell is spawned each time, so + working directory and environment variables do not. + endpoint_name: Endpoint qualifier. Defaults to ``"DEFAULT"``. + headers: Optional extra headers. + command_timeout: Server-side command wall-clock timeout in + seconds (1–3600). The HTTP read timeout is derived internally + as ``command_timeout + 30``. + + Returns: + Dict with keys ``"stdout"`` (str), ``"stderr"`` (str), + ``"exitCode"`` (int, ``-1`` if no ``contentStop`` was received), + and ``"status"`` (``"COMPLETED"``, ``"TIMED_OUT"``, or + ``"UNKNOWN"``). + + Raises: + AgentRuntimeError: If the runtime returns a non-2xx status. + ValueError: If the ARN format is invalid. + """ + stdout_parts: list = [] + stderr_parts: list = [] + exit_code: int = -1 + status: str = "UNKNOWN" + + for event in self.execute_command_streaming( + runtime_arn, + bearer_token, + command, + session_id=session_id, + endpoint_name=endpoint_name, + headers=headers, + command_timeout=command_timeout, + ): + if "contentDelta" in event: + delta = event["contentDelta"] + if "stdout" in delta: + stdout_parts.append(delta["stdout"]) + if "stderr" in delta: + stderr_parts.append(delta["stderr"]) + elif "contentStop" in event: + exit_code = int(event["contentStop"].get("exitCode", -1)) + status = str(event["contentStop"].get("status", "UNKNOWN")) + + return { + "stdout": "".join(stdout_parts), + "stderr": "".join(stderr_parts), + "exitCode": exit_code, + "status": status, + } + + def execute_command_streaming( + self, + runtime_arn: str, + bearer_token: str, + command: str, + *, + session_id: Optional[str] = None, + endpoint_name: Optional[str] = None, + headers: Optional[Dict[str, str]] = None, + command_timeout: int = DEFAULT_COMMAND_TIMEOUT, + ) -> Generator[Dict[str, Any], None, None]: + """Stream AWS EventStream events from ``InvokeAgentRuntimeCommand``. + + Yields the decoded event payloads (the value inside the server's + ``"chunk"`` envelope). Each yielded dict has exactly one of the keys + ``"contentStart"``, ``"contentDelta"``, or ``"contentStop"``. + + Args: + runtime_arn: Full runtime ARN. + bearer_token: Bearer token for authentication. + command: Shell command to run. + session_id: Runtime session id. Auto-generated if not supplied. + endpoint_name: Endpoint qualifier. Defaults to ``"DEFAULT"``. + headers: Optional extra headers. + command_timeout: Server-side wall-clock timeout in seconds + (1–3600). + + Yields: + Parsed event payload dicts. + + Raises: + AgentRuntimeError: If the runtime returns a non-2xx status. + ValueError: If the ARN format is invalid. + """ + if not session_id: + session_id = str(uuid.uuid4()) + + request_headers = self._build_bearer_headers( + bearer_token, + session_id, + accept="application/vnd.amazon.eventstream", + content_type="application/json", + ) + if headers: + request_headers.update(headers) + + response = self._http.request( + "POST", + self._build_http_url(runtime_arn, "commands", endpoint_name), + headers=request_headers, + body=json.dumps({"command": command, "timeout": command_timeout}).encode( + "utf-8" + ), + timeout=command_timeout + 30, + preload_content=False, + ) + try: + self._check_response(response) + + buf = EventStreamBuffer() + for chunk in response.stream(4096): + if not chunk: + continue + buf.add_data(chunk) + for event in buf: + payload = event.payload + if not payload: + continue + try: + decoded = json.loads(payload) + except json.JSONDecodeError: + continue + inner = decoded.get("chunk") if isinstance(decoded, dict) else None + yield inner if isinstance(inner, dict) else decoded + finally: + response.release_conn() + + # ------------------------------------------------------------------ # + # stop_runtime_session + # ------------------------------------------------------------------ # + + def stop_runtime_session( + self, + runtime_arn: str, + bearer_token: str, + *, + session_id: str, + endpoint_name: Optional[str] = None, + client_token: Optional[str] = None, + timeout: int = DEFAULT_HTTP_TIMEOUT, + ) -> Dict[str, Any]: + """Terminate a runtime session via HTTP. + + Args: + runtime_arn: Full runtime ARN. + bearer_token: Bearer token for authentication. + session_id: The session id to stop. + endpoint_name: Endpoint qualifier. Defaults to ``"DEFAULT"``. + client_token: Idempotency token. Auto-generated as a UUID4 + when not supplied. + timeout: HTTP read timeout in seconds. + + Returns: + Parsed JSON body of the response (often an empty dict). + + Raises: + AgentRuntimeError: If the runtime returns a non-2xx status + (for example ``HTTP 404`` for an unknown session). + ValueError: If the ARN format is invalid. + """ + request_headers = { + "Authorization": f"Bearer {bearer_token}", + "Content-Type": "application/json", + "X-Amzn-Bedrock-AgentCore-Runtime-Session-Id": session_id, + } + response = self._http.request( + "POST", + self._build_http_url(runtime_arn, "stopruntimesession", endpoint_name), + headers=request_headers, + body=json.dumps({"clientToken": client_token or str(uuid.uuid4())}).encode( + "utf-8" + ), + timeout=timeout, + preload_content=True, + ) + self._check_response(response) + if not response.data: + return {} + try: + parsed = json.loads(response.data) + except json.JSONDecodeError: + return {} + return parsed if isinstance(parsed, dict) else {"response": parsed} + + # ------------------------------------------------------------------ # + # Response parsing helpers + # ------------------------------------------------------------------ # + + @staticmethod + def _check_response(response: "_UrllibResponse") -> None: + """Raise :class:`AgentRuntimeError` for non-2xx responses. + + Attempts to parse the body as JSON and surface ``error``, + ``error_type``, and ``message`` fields. Falls back to the raw body + text. + + Args: + response: The urllib3 response to inspect. + + Raises: + AgentRuntimeError: If the status code is 400 or above. + """ + if response.status < 400: + return + + body_bytes = response.read() + try: + body: Any = json.loads(body_bytes) + except (json.JSONDecodeError, ValueError): + body = body_bytes.decode("utf-8", errors="replace") + + error_type = f"HTTP {response.status}" + reason = response.reason or error_type + + if isinstance(body, dict): + raise AgentRuntimeError( + error=str(body.get("error", reason)), + error_type=str(body.get("error_type", error_type)), + message=str( + body.get("message", body_bytes.decode("utf-8", errors="replace")) + ), + ) + raise AgentRuntimeError( + error=str(body) or reason, + error_type=error_type, + ) + + @staticmethod + def _read_non_streaming(response: "_UrllibResponse") -> str: + """Read a fully-buffered response body and unwrap JSON strings. + + Args: + response: The urllib3 response, opened with + ``preload_content=False``. + + Returns: + The response body as text. If the body is a JSON-encoded + string (e.g. ``"hello"``), the unwrapped value is returned. + """ + data = response.read() + text = data.decode("utf-8", errors="replace") + try: + parsed = json.loads(text) + except (json.JSONDecodeError, ValueError): + return text + return parsed if isinstance(parsed, str) else text + + @classmethod + def _iter_sse_decoded(cls, response: "_UrllibResponse") -> Iterator[str]: + """Iterate decoded SSE payloads from a streaming response. + + Args: + response: The urllib3 response, opened with + ``preload_content=False``. + + Yields: + Decoded payload strings (one per non-empty ``data:`` line + or JSON line). + """ + for raw_line in cls._iter_lines(response): + if not raw_line: + continue + decoded = cls._decode_sse_line(raw_line.decode("utf-8", errors="replace")) + if decoded: + yield decoded + + @staticmethod + def _iter_lines(response: "_UrllibResponse") -> Iterator[bytes]: + r"""Yield lines from a streaming urllib3 response. + + Splits on ``\n`` and strips a trailing ``\r`` so it handles both + LF and CRLF line endings. Preserves empty lines (the caller + filters them). + + Args: + response: The urllib3 response, opened with + ``preload_content=False``. + + Yields: + Each line as bytes (with no trailing newline). + """ + pending = b"" + for chunk in response.stream(1024): + if not chunk: + continue + pending += chunk + while True: + idx = pending.find(b"\n") + if idx < 0: + break + line, pending = pending[:idx], pending[idx + 1 :] + if line.endswith(b"\r"): + line = line[:-1] + yield line + if pending: + if pending.endswith(b"\r"): + pending = pending[:-1] + yield pending + + @staticmethod + def _decode_sse_line(line: str) -> Optional[str]: + """Decode a single SSE or JSON-Lines line. + + Handles SSE ``data:`` lines, JSON error envelopes, JSON-encoded + strings, and plain text. Error envelopes (with an ``error`` key) + are raised as :class:`AgentRuntimeError`. + + Args: + line: A single line of the streamed response. + + Returns: + The decoded payload, or ``None`` if the line is a comment, + empty, or carries no renderable text. + + Raises: + AgentRuntimeError: If the line carries a JSON error payload. + """ + line = line.strip() + if not line or line.startswith(":"): + return None + + if line.startswith("data:"): + content = line[5:].strip() + + if content.startswith("{"): + try: + data = json.loads(content) + if isinstance(data, dict) and "error" in data: + raise AgentRuntimeError( + error=str(data["error"]), + error_type=str(data.get("error_type", "")), + message=str(data.get("message", data["error"])), + ) + except json.JSONDecodeError: + pass + + if content.startswith('"'): + try: + unwrapped = json.loads(content) + except json.JSONDecodeError: + if content.endswith('"'): + return content[1:-1] + return content + return unwrapped if isinstance(unwrapped, str) else content + return content + + try: + data = json.loads(line) + except json.JSONDecodeError: + return line + + if isinstance(data, dict): + if "error" in data: + raise AgentRuntimeError( + error=str(data["error"]), + error_type=str(data.get("error_type", "")), + message=str(data.get("message", data["error"])), + ) + if "text" in data: + value = data["text"] + return value if isinstance(value, str) else str(value) + if "content" in data: + value = data["content"] + return value if isinstance(value, str) else str(value) + if "data" in data: + return str(data["data"]) + return None diff --git a/tests/bedrock_agentcore/test_init.py b/tests/bedrock_agentcore/test_init.py index f757719c..4c3ea081 100644 --- a/tests/bedrock_agentcore/test_init.py +++ b/tests/bedrock_agentcore/test_init.py @@ -7,7 +7,10 @@ def test_getattr_raises_for_unknown_attribute(): """Test that __getattr__ raises AttributeError for unknown attributes.""" import bedrock_agentcore - with pytest.raises(AttributeError, match="module 'bedrock_agentcore' has no attribute 'UnknownAttribute'"): + with pytest.raises( + AttributeError, + match="module 'bedrock_agentcore' has no attribute 'UnknownAttribute'", + ): _ = bedrock_agentcore.UnknownAttribute diff --git a/tests/unit/runtime/test_agent_core_runtime_client.py b/tests/unit/runtime/test_agent_core_runtime_client.py index 4cb33085..62d86c96 100644 --- a/tests/unit/runtime/test_agent_core_runtime_client.py +++ b/tests/unit/runtime/test_agent_core_runtime_client.py @@ -1,11 +1,16 @@ """Tests for AgentCoreRuntimeClient.""" -from unittest.mock import Mock, patch +import json +from typing import Iterator, Optional +from unittest.mock import MagicMock, Mock, patch from urllib.parse import quote import pytest -from bedrock_agentcore.runtime.agent_core_runtime_client import AgentCoreRuntimeClient +from bedrock_agentcore.runtime.agent_core_runtime_client import ( + AgentCoreRuntimeClient, + AgentRuntimeError, +) class TestAgentCoreRuntimeClientInit: @@ -28,7 +33,9 @@ class TestParseRuntimeArn: def test_parse_valid_arn(self): """Test parsing a valid runtime ARN.""" client = AgentCoreRuntimeClient(region="us-west-2") - arn = "arn:aws:bedrock-agentcore:us-west-2:123456789012:runtime/my-runtime-abc123" + arn = ( + "arn:aws:bedrock-agentcore:us-west-2:123456789012:runtime/my-runtime-abc123" + ) result = client._parse_runtime_arn(arn) @@ -99,7 +106,9 @@ def test_parse_empty_runtime_id_raises_error(self): class TestBuildWebsocketUrl: """Tests for _build_websocket_url helper.""" - @patch("bedrock_agentcore.runtime.agent_core_runtime_client.get_data_plane_endpoint") + @patch( + "bedrock_agentcore.runtime.agent_core_runtime_client.get_data_plane_endpoint" + ) def test_build_basic_url(self, mock_endpoint): """Test building basic WebSocket URL without query params.""" mock_endpoint.return_value = "https://example.aws.dev" @@ -112,7 +121,9 @@ def test_build_basic_url(self, mock_endpoint): encoded_arn = quote(runtime_arn, safe="") assert result == f"wss://example.aws.dev/runtimes/{encoded_arn}/ws" - @patch("bedrock_agentcore.runtime.agent_core_runtime_client.get_data_plane_endpoint") + @patch( + "bedrock_agentcore.runtime.agent_core_runtime_client.get_data_plane_endpoint" + ) def test_build_url_with_endpoint_name(self, mock_endpoint): """Test building URL with endpoint name (qualifier param).""" mock_endpoint.return_value = "https://example.aws.dev" @@ -122,30 +133,41 @@ def test_build_url_with_endpoint_name(self, mock_endpoint): result = client._build_websocket_url(runtime_arn, endpoint_name="DEFAULT") encoded_arn = quote(runtime_arn, safe="") - assert result == f"wss://example.aws.dev/runtimes/{encoded_arn}/ws?qualifier=DEFAULT" - - @patch("bedrock_agentcore.runtime.agent_core_runtime_client.get_data_plane_endpoint") + assert ( + result + == f"wss://example.aws.dev/runtimes/{encoded_arn}/ws?qualifier=DEFAULT" + ) + + @patch( + "bedrock_agentcore.runtime.agent_core_runtime_client.get_data_plane_endpoint" + ) def test_build_url_with_custom_headers(self, mock_endpoint): """Test building URL with custom headers as query params.""" mock_endpoint.return_value = "https://example.aws.dev" client = AgentCoreRuntimeClient(region="us-west-2") runtime_arn = "arn:aws:bedrock-agentcore:us-west-2:123:runtime/my-runtime" - result = client._build_websocket_url(runtime_arn, custom_headers={"abc": "pqr", "foo": "bar"}) + result = client._build_websocket_url( + runtime_arn, custom_headers={"abc": "pqr", "foo": "bar"} + ) encoded_arn = quote(runtime_arn, safe="") assert f"wss://example.aws.dev/runtimes/{encoded_arn}/ws?" in result assert "abc=pqr" in result assert "foo=bar" in result - @patch("bedrock_agentcore.runtime.agent_core_runtime_client.get_data_plane_endpoint") + @patch( + "bedrock_agentcore.runtime.agent_core_runtime_client.get_data_plane_endpoint" + ) def test_build_url_with_all_params(self, mock_endpoint): """Test building URL with endpoint name and custom headers.""" mock_endpoint.return_value = "https://example.aws.dev" client = AgentCoreRuntimeClient(region="us-west-2") runtime_arn = "arn:aws:bedrock-agentcore:us-west-2:123:runtime/my-runtime" - result = client._build_websocket_url(runtime_arn, endpoint_name="DEFAULT", custom_headers={"abc": "pqr"}) + result = client._build_websocket_url( + runtime_arn, endpoint_name="DEFAULT", custom_headers={"abc": "pqr"} + ) encoded_arn = quote(runtime_arn, safe="") assert f"wss://example.aws.dev/runtimes/{encoded_arn}/ws?" in result @@ -157,13 +179,17 @@ class TestGenerateWsConnection: """Tests for generate_ws_connection method.""" @patch("bedrock_agentcore.runtime.agent_core_runtime_client.boto3.Session") - @patch("bedrock_agentcore.runtime.agent_core_runtime_client.get_data_plane_endpoint") + @patch( + "bedrock_agentcore.runtime.agent_core_runtime_client.get_data_plane_endpoint" + ) def test_generate_basic_connection(self, mock_endpoint, mock_session): """Test generating basic WebSocket connection.""" # Setup mocks mock_endpoint.return_value = "https://example.aws.dev" mock_creds = Mock() - mock_creds.get_frozen_credentials.return_value = Mock(access_key="AKIATEST", secret_key="secret", token=None) + mock_creds.get_frozen_credentials.return_value = Mock( + access_key="AKIATEST", secret_key="secret", token=None + ) mock_session.return_value.get_credentials.return_value = mock_creds client = AgentCoreRuntimeClient(region="us-west-2") @@ -185,32 +211,44 @@ def test_generate_basic_connection(self, mock_endpoint, mock_session): assert "Sec-WebSocket-Key" in headers @patch("bedrock_agentcore.runtime.agent_core_runtime_client.boto3.Session") - @patch("bedrock_agentcore.runtime.agent_core_runtime_client.get_data_plane_endpoint") + @patch( + "bedrock_agentcore.runtime.agent_core_runtime_client.get_data_plane_endpoint" + ) def test_generate_connection_with_session_id(self, mock_endpoint, mock_session): """Test generating connection with explicit session ID.""" mock_endpoint.return_value = "https://example.aws.dev" mock_creds = Mock() - mock_creds.get_frozen_credentials.return_value = Mock(access_key="AKIATEST", secret_key="secret", token=None) + mock_creds.get_frozen_credentials.return_value = Mock( + access_key="AKIATEST", secret_key="secret", token=None + ) mock_session.return_value.get_credentials.return_value = mock_creds client = AgentCoreRuntimeClient(region="us-west-2") runtime_arn = "arn:aws:bedrock-agentcore:us-west-2:123:runtime/my-runtime" - ws_url, headers = client.generate_ws_connection(runtime_arn, session_id="test-session-123") + ws_url, headers = client.generate_ws_connection( + runtime_arn, session_id="test-session-123" + ) assert ws_url is not None assert headers is not None # Verify session ID is in headers assert "X-Amzn-Bedrock-AgentCore-Runtime-Session-Id" in headers - assert headers["X-Amzn-Bedrock-AgentCore-Runtime-Session-Id"] == "test-session-123" + assert ( + headers["X-Amzn-Bedrock-AgentCore-Runtime-Session-Id"] == "test-session-123" + ) @patch("bedrock_agentcore.runtime.agent_core_runtime_client.boto3.Session") - @patch("bedrock_agentcore.runtime.agent_core_runtime_client.get_data_plane_endpoint") + @patch( + "bedrock_agentcore.runtime.agent_core_runtime_client.get_data_plane_endpoint" + ) def test_generate_connection_user_agent(self, mock_endpoint, mock_session): """Test that User-Agent header is set correctly.""" mock_endpoint.return_value = "https://example.aws.dev" mock_creds = Mock() - mock_creds.get_frozen_credentials.return_value = Mock(access_key="AKIATEST", secret_key="secret", token=None) + mock_creds.get_frozen_credentials.return_value = Mock( + access_key="AKIATEST", secret_key="secret", token=None + ) mock_session.return_value.get_credentials.return_value = mock_creds client = AgentCoreRuntimeClient(region="us-west-2") @@ -222,18 +260,24 @@ def test_generate_connection_user_agent(self, mock_endpoint, mock_session): assert headers["User-Agent"] == "AgentCoreRuntimeClient/1.0" @patch("bedrock_agentcore.runtime.agent_core_runtime_client.boto3.Session") - @patch("bedrock_agentcore.runtime.agent_core_runtime_client.get_data_plane_endpoint") + @patch( + "bedrock_agentcore.runtime.agent_core_runtime_client.get_data_plane_endpoint" + ) def test_generate_connection_with_endpoint_name(self, mock_endpoint, mock_session): """Test generating connection with endpoint name.""" mock_endpoint.return_value = "https://example.aws.dev" mock_creds = Mock() - mock_creds.get_frozen_credentials.return_value = Mock(access_key="AKIATEST", secret_key="secret", token=None) + mock_creds.get_frozen_credentials.return_value = Mock( + access_key="AKIATEST", secret_key="secret", token=None + ) mock_session.return_value.get_credentials.return_value = mock_creds client = AgentCoreRuntimeClient(region="us-west-2") runtime_arn = "arn:aws:bedrock-agentcore:us-west-2:123:runtime/my-runtime" - ws_url, headers = client.generate_ws_connection(runtime_arn, endpoint_name="DEFAULT") + ws_url, headers = client.generate_ws_connection( + runtime_arn, endpoint_name="DEFAULT" + ) assert "qualifier=DEFAULT" in ws_url @@ -253,12 +297,16 @@ class TestGeneratePresignedUrl: """Tests for generate_presigned_url method.""" @patch("bedrock_agentcore.runtime.agent_core_runtime_client.boto3.Session") - @patch("bedrock_agentcore.runtime.agent_core_runtime_client.get_data_plane_endpoint") + @patch( + "bedrock_agentcore.runtime.agent_core_runtime_client.get_data_plane_endpoint" + ) def test_generate_basic_presigned_url(self, mock_endpoint, mock_session): """Test generating basic presigned URL.""" mock_endpoint.return_value = "https://example.aws.dev" mock_creds = Mock() - mock_creds.get_frozen_credentials.return_value = Mock(access_key="AKIATEST", secret_key="secret", token=None) + mock_creds.get_frozen_credentials.return_value = Mock( + access_key="AKIATEST", secret_key="secret", token=None + ) mock_session.return_value.get_credentials.return_value = mock_creds client = AgentCoreRuntimeClient(region="us-west-2") @@ -279,61 +327,92 @@ def test_generate_basic_presigned_url(self, mock_endpoint, mock_session): assert "X-Amz-Signature" in presigned_url @patch("bedrock_agentcore.runtime.agent_core_runtime_client.boto3.Session") - @patch("bedrock_agentcore.runtime.agent_core_runtime_client.get_data_plane_endpoint") - def test_generate_presigned_url_with_endpoint_name(self, mock_endpoint, mock_session): + @patch( + "bedrock_agentcore.runtime.agent_core_runtime_client.get_data_plane_endpoint" + ) + def test_generate_presigned_url_with_endpoint_name( + self, mock_endpoint, mock_session + ): """Test generating presigned URL with endpoint name.""" mock_endpoint.return_value = "https://example.aws.dev" mock_creds = Mock() - mock_creds.get_frozen_credentials.return_value = Mock(access_key="AKIATEST", secret_key="secret", token=None) + mock_creds.get_frozen_credentials.return_value = Mock( + access_key="AKIATEST", secret_key="secret", token=None + ) mock_session.return_value.get_credentials.return_value = mock_creds client = AgentCoreRuntimeClient(region="us-west-2") runtime_arn = "arn:aws:bedrock-agentcore:us-west-2:123:runtime/my-runtime" - presigned_url = client.generate_presigned_url(runtime_arn, endpoint_name="DEFAULT") + presigned_url = client.generate_presigned_url( + runtime_arn, endpoint_name="DEFAULT" + ) assert "qualifier=DEFAULT" in presigned_url @patch("bedrock_agentcore.runtime.agent_core_runtime_client.boto3.Session") - @patch("bedrock_agentcore.runtime.agent_core_runtime_client.get_data_plane_endpoint") - def test_generate_presigned_url_with_custom_headers(self, mock_endpoint, mock_session): + @patch( + "bedrock_agentcore.runtime.agent_core_runtime_client.get_data_plane_endpoint" + ) + def test_generate_presigned_url_with_custom_headers( + self, mock_endpoint, mock_session + ): """Test generating presigned URL with custom headers.""" mock_endpoint.return_value = "https://example.aws.dev" mock_creds = Mock() - mock_creds.get_frozen_credentials.return_value = Mock(access_key="AKIATEST", secret_key="secret", token=None) + mock_creds.get_frozen_credentials.return_value = Mock( + access_key="AKIATEST", secret_key="secret", token=None + ) mock_session.return_value.get_credentials.return_value = mock_creds client = AgentCoreRuntimeClient(region="us-west-2") runtime_arn = "arn:aws:bedrock-agentcore:us-west-2:123:runtime/my-runtime" - presigned_url = client.generate_presigned_url(runtime_arn, custom_headers={"abc": "pqr"}) + presigned_url = client.generate_presigned_url( + runtime_arn, custom_headers={"abc": "pqr"} + ) assert "abc=pqr" in presigned_url @patch("bedrock_agentcore.runtime.agent_core_runtime_client.boto3.Session") - @patch("bedrock_agentcore.runtime.agent_core_runtime_client.get_data_plane_endpoint") + @patch( + "bedrock_agentcore.runtime.agent_core_runtime_client.get_data_plane_endpoint" + ) def test_generate_presigned_url_with_session_id(self, mock_endpoint, mock_session): """Test generating presigned URL with explicit session ID.""" mock_endpoint.return_value = "https://example.aws.dev" mock_creds = Mock() - mock_creds.get_frozen_credentials.return_value = Mock(access_key="AKIATEST", secret_key="secret", token=None) + mock_creds.get_frozen_credentials.return_value = Mock( + access_key="AKIATEST", secret_key="secret", token=None + ) mock_session.return_value.get_credentials.return_value = mock_creds client = AgentCoreRuntimeClient(region="us-west-2") runtime_arn = "arn:aws:bedrock-agentcore:us-west-2:123:runtime/my-runtime" - presigned_url = client.generate_presigned_url(runtime_arn, session_id="test-session-456") + presigned_url = client.generate_presigned_url( + runtime_arn, session_id="test-session-456" + ) # Verify session ID is in query params - assert "X-Amzn-Bedrock-AgentCore-Runtime-Session-Id=test-session-456" in presigned_url + assert ( + "X-Amzn-Bedrock-AgentCore-Runtime-Session-Id=test-session-456" + in presigned_url + ) @patch("bedrock_agentcore.runtime.agent_core_runtime_client.boto3.Session") - @patch("bedrock_agentcore.runtime.agent_core_runtime_client.get_data_plane_endpoint") - def test_generate_presigned_url_with_custom_expires(self, mock_endpoint, mock_session): + @patch( + "bedrock_agentcore.runtime.agent_core_runtime_client.get_data_plane_endpoint" + ) + def test_generate_presigned_url_with_custom_expires( + self, mock_endpoint, mock_session + ): """Test generating presigned URL with custom expiration.""" mock_endpoint.return_value = "https://example.aws.dev" mock_creds = Mock() - mock_creds.get_frozen_credentials.return_value = Mock(access_key="AKIATEST", secret_key="secret", token=None) + mock_creds.get_frozen_credentials.return_value = Mock( + access_key="AKIATEST", secret_key="secret", token=None + ) mock_session.return_value.get_credentials.return_value = mock_creds client = AgentCoreRuntimeClient(region="us-west-2") @@ -387,7 +466,9 @@ def test_init_without_session_creates_default(self, mock_session_class): assert client.session == mock_session mock_session_class.assert_called_once() - @patch("bedrock_agentcore.runtime.agent_core_runtime_client.get_data_plane_endpoint") + @patch( + "bedrock_agentcore.runtime.agent_core_runtime_client.get_data_plane_endpoint" + ) def test_generate_ws_connection_uses_custom_session(self, mock_endpoint): """Test that generate_ws_connection uses the custom session.""" mock_endpoint.return_value = "https://example.aws.dev" @@ -395,7 +476,9 @@ def test_generate_ws_connection_uses_custom_session(self, mock_endpoint): # Create custom session with credentials custom_session = Mock() mock_creds = Mock() - mock_creds.get_frozen_credentials.return_value = Mock(access_key="AKIATEST", secret_key="secret", token=None) + mock_creds.get_frozen_credentials.return_value = Mock( + access_key="AKIATEST", secret_key="secret", token=None + ) custom_session.get_credentials.return_value = mock_creds client = AgentCoreRuntimeClient(region="us-west-2", session=custom_session) @@ -412,7 +495,9 @@ def test_generate_ws_connection_uses_custom_session(self, mock_endpoint): class TestGenerateWsConnectionOAuth: """Tests for generate_ws_connection_oauth method.""" - @patch("bedrock_agentcore.runtime.agent_core_runtime_client.get_data_plane_endpoint") + @patch( + "bedrock_agentcore.runtime.agent_core_runtime_client.get_data_plane_endpoint" + ) def test_generate_oauth_connection_basic(self, mock_endpoint): """Test generating basic OAuth WebSocket connection.""" mock_endpoint.return_value = "https://example.aws.dev" @@ -434,7 +519,9 @@ def test_generate_oauth_connection_basic(self, mock_endpoint): assert "Sec-WebSocket-Version" in headers assert headers["Sec-WebSocket-Version"] == "13" - @patch("bedrock_agentcore.runtime.agent_core_runtime_client.get_data_plane_endpoint") + @patch( + "bedrock_agentcore.runtime.agent_core_runtime_client.get_data_plane_endpoint" + ) def test_generate_oauth_connection_with_session_id(self, mock_endpoint): """Test generating OAuth connection with explicit session ID.""" mock_endpoint.return_value = "https://example.aws.dev" @@ -443,11 +530,17 @@ def test_generate_oauth_connection_with_session_id(self, mock_endpoint): bearer_token = "test-token" custom_session_id = "custom-oauth-session-123" - ws_url, headers = client.generate_ws_connection_oauth(runtime_arn, bearer_token, session_id=custom_session_id) + ws_url, headers = client.generate_ws_connection_oauth( + runtime_arn, bearer_token, session_id=custom_session_id + ) - assert headers["X-Amzn-Bedrock-AgentCore-Runtime-Session-Id"] == custom_session_id + assert ( + headers["X-Amzn-Bedrock-AgentCore-Runtime-Session-Id"] == custom_session_id + ) - @patch("bedrock_agentcore.runtime.agent_core_runtime_client.get_data_plane_endpoint") + @patch( + "bedrock_agentcore.runtime.agent_core_runtime_client.get_data_plane_endpoint" + ) def test_generate_oauth_connection_with_endpoint_name(self, mock_endpoint): """Test generating OAuth connection with endpoint name.""" mock_endpoint.return_value = "https://example.aws.dev" @@ -455,7 +548,9 @@ def test_generate_oauth_connection_with_endpoint_name(self, mock_endpoint): runtime_arn = "arn:aws:bedrock-agentcore:us-west-2:123:runtime/my-runtime" bearer_token = "test-token" - ws_url, headers = client.generate_ws_connection_oauth(runtime_arn, bearer_token, endpoint_name="DEFAULT") + ws_url, headers = client.generate_ws_connection_oauth( + runtime_arn, bearer_token, endpoint_name="DEFAULT" + ) assert "qualifier=DEFAULT" in ws_url @@ -475,3 +570,842 @@ def test_generate_oauth_connection_invalid_arn_raises_error(self): with pytest.raises(ValueError, match="Invalid runtime ARN format"): client.generate_ws_connection_oauth(invalid_arn, bearer_token) + + +# ===================================================================== # +# HTTP invocation tests (bearer-token auth) +# ===================================================================== # + +ARN = "arn:aws:bedrock-agentcore:us-west-2:123456789012:runtime/my-runtime-abc" +BEARER = "test-token" +SESSION = "a" * 36 + + +class _FakeResponse: + """Minimal stand-in for urllib3's response used by the HTTP methods. + + Exposes the exact surface the client touches: ``status``, ``reason``, + ``headers``, ``data``, ``read``, ``stream``, and ``release_conn``. + """ + + def __init__( + self, + status: int = 200, + headers: Optional[dict] = None, + body: bytes = b"", + chunks: Optional[list] = None, + reason: str = "", + ) -> None: + self.status = status + self.reason = reason + self.headers = headers or {} + self._body = body + self._chunks = chunks + self._consumed = False + self.release_calls = 0 + + @property + def data(self) -> bytes: + return self._body + + def read(self) -> bytes: + if self._consumed: + return b"" + self._consumed = True + return self._body + + def stream(self, amt: int = 1024) -> Iterator[bytes]: + if self._chunks is not None: + for chunk in self._chunks: + yield chunk + return + if self._body: + yield self._body + + def release_conn(self) -> None: + self.release_calls += 1 + + +def _make_client(response: _FakeResponse) -> "tuple[AgentCoreRuntimeClient, MagicMock]": + """Build a client whose lazy PoolManager is bypassed with a mock.""" + pool = MagicMock() + pool.request.return_value = response + client = AgentCoreRuntimeClient(region="us-west-2") + client._pool_manager = pool # bypass the lazy property + return client, pool + + +# --------------------------------------------------------------------- # +# Lazy PoolManager +# --------------------------------------------------------------------- # + + +class TestLazyPoolManager: + """Tests for the @property _http lazy PoolManager.""" + + def test_pool_not_created_at_init(self) -> None: + """Constructor does not create a PoolManager.""" + client = AgentCoreRuntimeClient(region="us-west-2") + assert client._pool_manager is None + + def test_pool_created_on_first_access(self) -> None: + """First access to _http creates the PoolManager.""" + import urllib3 + + client = AgentCoreRuntimeClient(region="us-west-2") + pool = client._http + assert isinstance(pool, urllib3.PoolManager) + assert client._pool_manager is pool + + def test_pool_reused_across_accesses(self) -> None: + """Repeated _http access returns the same instance.""" + client = AgentCoreRuntimeClient(region="us-west-2") + first = client._http + second = client._http + assert first is second + + +# --------------------------------------------------------------------- # +# _build_http_url / _build_bearer_headers / _serialize_body +# --------------------------------------------------------------------- # + + +class TestBuildHttpUrl: + """Tests for _build_http_url.""" + + def test_invocations_url(self) -> None: + """URL embeds region (via endpoint helper), URL-encoded ARN, path, and qualifier.""" + client = AgentCoreRuntimeClient(region="us-west-2") + url = client._build_http_url(ARN, "invocations") + assert url.startswith( + "https://bedrock-agentcore.us-west-2.amazonaws.com/runtimes/" + ) + assert "/invocations?qualifier=DEFAULT" in url + assert "arn%3Aaws%3Abedrock-agentcore" in url + + def test_commands_url(self) -> None: + """Different path suffix yields the commands endpoint.""" + client = AgentCoreRuntimeClient(region="us-west-2") + assert "/commands?qualifier=DEFAULT" in client._build_http_url(ARN, "commands") + + def test_non_default_qualifier(self) -> None: + """Qualifier reflects the endpoint_name argument.""" + client = AgentCoreRuntimeClient(region="us-west-2") + assert "qualifier=DEV" in client._build_http_url( + ARN, "invocations", endpoint_name="DEV" + ) + + def test_region_derived_from_arn(self) -> None: + """URL uses the ARN region, not the client's init region.""" + client = AgentCoreRuntimeClient(region="us-east-1") + arn = "arn:aws:bedrock-agentcore:eu-west-2:123456789012:runtime/other" + assert ( + "https://bedrock-agentcore.eu-west-2.amazonaws.com/" + in client._build_http_url(arn, "invocations") + ) + + def test_invalid_arn_raises(self) -> None: + """Invalid ARN propagates from _parse_runtime_arn.""" + client = AgentCoreRuntimeClient(region="us-west-2") + with pytest.raises(ValueError, match="Invalid runtime ARN format"): + client._build_http_url("not-an-arn", "invocations") + + +class TestBuildBearerHeaders: + """Tests for _build_bearer_headers.""" + + def test_all_headers_populated(self) -> None: + """All four bearer-auth headers are set.""" + client = AgentCoreRuntimeClient(region="us-west-2") + headers = client._build_bearer_headers( + BEARER, SESSION, accept="application/json", content_type="application/json" + ) + assert headers["Authorization"] == f"Bearer {BEARER}" + assert headers["Content-Type"] == "application/json" + assert headers["Accept"] == "application/json" + assert headers["X-Amzn-Bedrock-AgentCore-Runtime-Session-Id"] == SESSION + + def test_eventstream_accept(self) -> None: + """Accept value is passed through verbatim.""" + client = AgentCoreRuntimeClient(region="us-west-2") + headers = client._build_bearer_headers( + BEARER, + SESSION, + accept="application/vnd.amazon.eventstream", + content_type="application/json", + ) + assert headers["Accept"] == "application/vnd.amazon.eventstream" + + +class TestSerializeBody: + """Tests for _serialize_body.""" + + def test_json_dict(self) -> None: + """Dict body is JSON-serialized when content type is JSON.""" + assert ( + AgentCoreRuntimeClient._serialize_body({"a": 1}, "application/json") + == b'{"a": 1}' + ) + + def test_bytes_passthrough_non_json(self) -> None: + """Bytes body is sent verbatim for non-JSON content types.""" + assert ( + AgentCoreRuntimeClient._serialize_body(b"raw", "application/octet-stream") + == b"raw" + ) + + def test_str_utf8_encoded_non_json(self) -> None: + """String body is UTF-8 encoded for non-JSON content types.""" + assert AgentCoreRuntimeClient._serialize_body( + "héllo", "text/plain" + ) == "héllo".encode("utf-8") + + def test_fallback_to_json(self) -> None: + """Non-str, non-bytes body falls back to JSON for non-JSON content types.""" + assert ( + AgentCoreRuntimeClient._serialize_body({"x": 1}, "application/cbor") + == b'{"x": 1}' + ) + + +# --------------------------------------------------------------------- # +# invoke (non-streaming JSON, non-streaming plain, SSE) +# --------------------------------------------------------------------- # + + +class TestInvoke: + """Tests for invoke.""" + + def test_non_streaming_json_string(self) -> None: + """JSON string response is unwrapped.""" + resp = _FakeResponse( + status=200, + headers={"content-type": "application/json"}, + body=b'"hello"', + ) + client, _ = _make_client(resp) + assert ( + client.invoke(ARN, BEARER, {"prompt": "hi"}, session_id=SESSION) == "hello" + ) + assert resp.release_calls == 1 + + def test_non_streaming_json_object_returns_text(self) -> None: + """Non-string JSON response returns the raw body text.""" + resp = _FakeResponse( + status=200, + headers={"content-type": "application/json"}, + body=b'{"answer": 42}', + ) + client, _ = _make_client(resp) + assert ( + client.invoke(ARN, BEARER, {"prompt": "hi"}, session_id=SESSION) + == '{"answer": 42}' + ) + + def test_non_streaming_invalid_json_returns_text(self) -> None: + """If the body isn't valid JSON, the raw text is returned.""" + resp = _FakeResponse( + status=200, + headers={"content-type": "application/json"}, + body=b"plain-text-not-json", + ) + client, _ = _make_client(resp) + assert ( + client.invoke(ARN, BEARER, {"prompt": "hi"}, session_id=SESSION) + == "plain-text-not-json" + ) + + def test_sse_streaming_concatenates(self) -> None: + """Multiple SSE data lines are concatenated in order.""" + body = b"data: hello\ndata: world\n" + resp = _FakeResponse( + status=200, + headers={"content-type": "text/event-stream"}, + chunks=[body], + ) + client, _ = _make_client(resp) + assert ( + client.invoke(ARN, BEARER, {"prompt": "hi"}, session_id=SESSION) + == "helloworld" + ) + + def test_custom_headers_merged(self) -> None: + """Caller-supplied headers are merged into the request.""" + resp = _FakeResponse( + status=200, headers={"content-type": "application/json"}, body=b'"ok"' + ) + client, pool = _make_client(resp) + client.invoke(ARN, BEARER, {}, session_id=SESSION, headers={"X-Test": "1"}) + sent_headers = pool.request.call_args.kwargs["headers"] + assert sent_headers["X-Test"] == "1" + assert sent_headers["Authorization"] == f"Bearer {BEARER}" + + def test_non_ok_raises(self) -> None: + """4xx and 5xx responses become AgentRuntimeError.""" + resp = _FakeResponse( + status=500, + body=b'{"error": "oops", "message": "boom"}', + reason="Server Error", + ) + client, _ = _make_client(resp) + with pytest.raises(AgentRuntimeError) as excinfo: + client.invoke(ARN, BEARER, {}, session_id=SESSION) + assert excinfo.value.error == "oops" + assert "boom" in str(excinfo.value) + + def test_session_id_auto_generated(self) -> None: + """Missing session_id is auto-generated as a UUID.""" + resp = _FakeResponse( + status=200, headers={"content-type": "application/json"}, body=b'"ok"' + ) + client, pool = _make_client(resp) + client.invoke(ARN, BEARER, {}) + sent_headers = pool.request.call_args.kwargs["headers"] + # UUID4 string is 36 chars. + assert len(sent_headers["X-Amzn-Bedrock-AgentCore-Runtime-Session-Id"]) == 36 + + +# --------------------------------------------------------------------- # +# invoke_streaming +# --------------------------------------------------------------------- # + + +class TestInvokeStreaming: + """Tests for invoke_streaming.""" + + def test_yields_chunks(self) -> None: + """Yields each decoded SSE payload in order.""" + body = b"data: first\ndata: second\n" + resp = _FakeResponse( + status=200, headers={"content-type": "text/event-stream"}, chunks=[body] + ) + client, _ = _make_client(resp) + assert list(client.invoke_streaming(ARN, BEARER, {}, session_id=SESSION)) == [ + "first", + "second", + ] + + def test_non_ok_raises(self) -> None: + """Errors are surfaced before the generator yields.""" + resp = _FakeResponse(status=404, body=b"not found", reason="Not Found") + client, _ = _make_client(resp) + with pytest.raises(AgentRuntimeError): + list(client.invoke_streaming(ARN, BEARER, {}, session_id=SESSION)) + + def test_custom_headers_merged(self) -> None: + """Custom headers are merged into the request.""" + resp = _FakeResponse( + status=200, headers={"content-type": "text/event-stream"}, chunks=[b""] + ) + client, pool = _make_client(resp) + list( + client.invoke_streaming( + ARN, BEARER, {}, session_id=SESSION, headers={"X-Custom": "v"} + ) + ) + assert pool.request.call_args.kwargs["headers"]["X-Custom"] == "v" + + +# --------------------------------------------------------------------- # +# invoke_streaming_async +# --------------------------------------------------------------------- # + + +class TestInvokeStreamingAsync: + """Tests for invoke_streaming_async.""" + + @pytest.mark.asyncio + async def test_yields_async_chunks(self) -> None: + """Async generator yields the same chunks as the sync version.""" + body = b"data: a\ndata: b\n" + resp = _FakeResponse( + status=200, headers={"content-type": "text/event-stream"}, chunks=[body] + ) + client, _ = _make_client(resp) + out = [] + async for chunk in client.invoke_streaming_async( + ARN, BEARER, {}, session_id=SESSION + ): + out.append(chunk) + assert out == ["a", "b"] + + @pytest.mark.asyncio + async def test_propagates_exception(self) -> None: + """Exceptions from the background thread surface to the caller.""" + resp = _FakeResponse(status=500, body=b"err", reason="Server Error") + client, _ = _make_client(resp) + with pytest.raises(AgentRuntimeError): + async for _ in client.invoke_streaming_async( + ARN, BEARER, {}, session_id=SESSION + ): + pass + + +# --------------------------------------------------------------------- # +# execute_command / execute_command_streaming +# --------------------------------------------------------------------- # + + +def _encode_eventstream_frame(payload: dict) -> bytes: + """Encode a JSON payload using the EventStream binary framing. + + Builds a valid frame the client's EventStreamBuffer can parse. Format: + total_length(4) + headers_length(4) + prelude_crc(4) + headers + payload + message_crc(4). + """ + import binascii + import struct + + body = json.dumps(payload).encode("utf-8") + headers = b"" + headers_length = len(headers) + total_length = 4 + 4 + 4 + headers_length + len(body) + 4 + prelude = struct.pack(">II", total_length, headers_length) + prelude_crc = struct.pack(">I", binascii.crc32(prelude) & 0xFFFFFFFF) + message_bytes = prelude + prelude_crc + headers + body + message_crc = struct.pack(">I", binascii.crc32(message_bytes) & 0xFFFFFFFF) + return message_bytes + message_crc + + +class TestExecuteCommand: + """Tests for execute_command (blocking).""" + + def test_aggregates_output(self) -> None: + """stdout/stderr/exitCode/status are aggregated across events.""" + events = [ + _encode_eventstream_frame({"chunk": {"contentStart": {}}}), + _encode_eventstream_frame({"chunk": {"contentDelta": {"stdout": "hi "}}}), + _encode_eventstream_frame({"chunk": {"contentDelta": {"stdout": "there"}}}), + _encode_eventstream_frame({"chunk": {"contentDelta": {"stderr": "warn"}}}), + _encode_eventstream_frame( + {"chunk": {"contentStop": {"exitCode": 0, "status": "COMPLETED"}}} + ), + ] + resp = _FakeResponse( + status=200, + headers={"content-type": "application/vnd.amazon.eventstream"}, + chunks=events, + ) + client, _ = _make_client(resp) + result = client.execute_command( + ARN, BEARER, "echo hi there", session_id=SESSION + ) + assert result == { + "stdout": "hi there", + "stderr": "warn", + "exitCode": 0, + "status": "COMPLETED", + } + + def test_missing_stop_has_unknown_status(self) -> None: + """Without a contentStop event, defaults are UNKNOWN / -1.""" + events = [ + _encode_eventstream_frame({"chunk": {"contentDelta": {"stdout": "x"}}}), + ] + resp = _FakeResponse( + status=200, + headers={"content-type": "application/vnd.amazon.eventstream"}, + chunks=events, + ) + client, _ = _make_client(resp) + result = client.execute_command(ARN, BEARER, "echo x", session_id=SESSION) + assert result == { + "stdout": "x", + "stderr": "", + "exitCode": -1, + "status": "UNKNOWN", + } + + +class TestExecuteCommandStreaming: + """Tests for execute_command_streaming.""" + + def test_sends_command_body(self) -> None: + """Request body contains command and timeout.""" + resp = _FakeResponse( + status=200, + headers={"content-type": "application/vnd.amazon.eventstream"}, + chunks=[], + ) + client, pool = _make_client(resp) + list( + client.execute_command_streaming( + ARN, BEARER, "ls", session_id=SESSION, command_timeout=42 + ) + ) + sent = pool.request.call_args + assert sent.kwargs["body"] == b'{"command": "ls", "timeout": 42}' + assert sent.kwargs["timeout"] == 72 # command_timeout + 30 + + def test_defaults_to_constant(self) -> None: + """Omitting command_timeout uses DEFAULT_COMMAND_TIMEOUT.""" + from bedrock_agentcore.runtime.agent_core_runtime_client import ( + DEFAULT_COMMAND_TIMEOUT, + ) + + resp = _FakeResponse( + status=200, + headers={"content-type": "application/vnd.amazon.eventstream"}, + chunks=[], + ) + client, pool = _make_client(resp) + list(client.execute_command_streaming(ARN, BEARER, "ls", session_id=SESSION)) + sent = pool.request.call_args + body = json.loads(sent.kwargs["body"]) + assert body["timeout"] == DEFAULT_COMMAND_TIMEOUT + assert sent.kwargs["timeout"] == DEFAULT_COMMAND_TIMEOUT + 30 + + def test_yields_parsed_events(self) -> None: + """Each EventStream frame is yielded as a dict.""" + events = [ + _encode_eventstream_frame({"chunk": {"contentStart": {}}}), + _encode_eventstream_frame( + {"chunk": {"contentStop": {"exitCode": 2, "status": "COMPLETED"}}} + ), + ] + resp = _FakeResponse( + status=200, + headers={"content-type": "application/vnd.amazon.eventstream"}, + chunks=events, + ) + client, _ = _make_client(resp) + parsed = list( + client.execute_command_streaming(ARN, BEARER, "echo hi", session_id=SESSION) + ) + assert parsed[0] == {"contentStart": {}} + assert parsed[1] == {"contentStop": {"exitCode": 2, "status": "COMPLETED"}} + + def test_skips_event_without_payload(self) -> None: + """Events with empty payload are skipped silently.""" + import binascii + import struct + + headers = b"" + body = b"" + headers_length = 0 + total_length = 4 + 4 + 4 + headers_length + len(body) + 4 + prelude = struct.pack(">II", total_length, headers_length) + prelude_crc = struct.pack(">I", binascii.crc32(prelude) & 0xFFFFFFFF) + message_bytes = prelude + prelude_crc + headers + body + message_crc = struct.pack(">I", binascii.crc32(message_bytes) & 0xFFFFFFFF) + empty_frame = message_bytes + message_crc + + resp = _FakeResponse( + status=200, + headers={"content-type": "application/vnd.amazon.eventstream"}, + chunks=[empty_frame], + ) + client, _ = _make_client(resp) + assert ( + list( + client.execute_command_streaming(ARN, BEARER, "ls", session_id=SESSION) + ) + == [] + ) + + def test_skips_bad_json_payload(self) -> None: + """Events whose payload is not valid JSON are skipped.""" + import binascii + import struct + + headers = b"" + body = b"not-json" + headers_length = 0 + total_length = 4 + 4 + 4 + headers_length + len(body) + 4 + prelude = struct.pack(">II", total_length, headers_length) + prelude_crc = struct.pack(">I", binascii.crc32(prelude) & 0xFFFFFFFF) + message_bytes = prelude + prelude_crc + headers + body + message_crc = struct.pack(">I", binascii.crc32(message_bytes) & 0xFFFFFFFF) + bad_frame = message_bytes + message_crc + + resp = _FakeResponse( + status=200, + headers={"content-type": "application/vnd.amazon.eventstream"}, + chunks=[bad_frame], + ) + client, _ = _make_client(resp) + assert ( + list( + client.execute_command_streaming(ARN, BEARER, "ls", session_id=SESSION) + ) + == [] + ) + + def test_non_ok_raises(self) -> None: + """Non-2xx responses raise AgentRuntimeError.""" + resp = _FakeResponse(status=403, body=b"forbidden", reason="Forbidden") + client, _ = _make_client(resp) + with pytest.raises(AgentRuntimeError): + list( + client.execute_command_streaming(ARN, BEARER, "ls", session_id=SESSION) + ) + + +# --------------------------------------------------------------------- # +# stop_runtime_session +# --------------------------------------------------------------------- # + + +class TestStopRuntimeSession: + """Tests for stop_runtime_session.""" + + def test_sends_client_token(self) -> None: + """Request body contains the provided client token.""" + resp = _FakeResponse(status=200, body=b"{}") + client, pool = _make_client(resp) + client.stop_runtime_session(ARN, BEARER, session_id=SESSION, client_token="abc") + sent = json.loads(pool.request.call_args.kwargs["body"]) + assert sent == {"clientToken": "abc"} + + def test_autogenerates_client_token(self) -> None: + """Missing client_token is replaced with a UUID4.""" + resp = _FakeResponse(status=200, body=b"{}") + client, pool = _make_client(resp) + client.stop_runtime_session(ARN, BEARER, session_id=SESSION) + sent = json.loads(pool.request.call_args.kwargs["body"]) + assert len(sent["clientToken"]) == 36 + + def test_returns_empty_dict_on_blank_body(self) -> None: + """Empty response body yields an empty dict.""" + resp = _FakeResponse(status=200, body=b"") + client, _ = _make_client(resp) + assert client.stop_runtime_session(ARN, BEARER, session_id=SESSION) == {} + + def test_returns_parsed_dict(self) -> None: + """Dict body is returned as-is.""" + resp = _FakeResponse(status=200, body=b'{"sessionId": "x"}') + client, _ = _make_client(resp) + assert client.stop_runtime_session(ARN, BEARER, session_id=SESSION) == { + "sessionId": "x" + } + + def test_non_dict_body_wrapped(self) -> None: + """A JSON body that isn't a dict is wrapped under 'response'.""" + resp = _FakeResponse(status=200, body=b'["a", "b"]') + client, _ = _make_client(resp) + assert client.stop_runtime_session(ARN, BEARER, session_id=SESSION) == { + "response": ["a", "b"] + } + + def test_invalid_json_body_returns_empty(self) -> None: + """Invalid JSON body returns empty dict rather than raising.""" + resp = _FakeResponse(status=200, body=b"not json") + client, _ = _make_client(resp) + assert client.stop_runtime_session(ARN, BEARER, session_id=SESSION) == {} + + def test_404_raises(self) -> None: + """Unknown-session HTTP 404 becomes AgentRuntimeError.""" + resp = _FakeResponse(status=404, body=b"{}", reason="Not Found") + client, _ = _make_client(resp) + with pytest.raises(AgentRuntimeError) as excinfo: + client.stop_runtime_session(ARN, BEARER, session_id=SESSION) + assert excinfo.value.error_type == "HTTP 404" + + +# --------------------------------------------------------------------- # +# _check_response +# --------------------------------------------------------------------- # + + +class TestCheckResponse: + """Tests for _check_response.""" + + def test_noop_on_2xx(self) -> None: + """2xx responses do not raise.""" + AgentCoreRuntimeClient._check_response(_FakeResponse(status=200)) + + def test_parses_json_error_body(self) -> None: + """JSON error bodies populate error/error_type/message.""" + resp = _FakeResponse( + status=400, + body=b'{"error": "bad", "error_type": "Validation", "message": "details"}', + reason="Bad Request", + ) + with pytest.raises(AgentRuntimeError) as excinfo: + AgentCoreRuntimeClient._check_response(resp) + assert excinfo.value.error == "bad" + assert excinfo.value.error_type == "Validation" + + def test_falls_back_to_text_body(self) -> None: + """Non-JSON bodies become the error string.""" + resp = _FakeResponse(status=500, body=b"internal boom", reason="Server Error") + with pytest.raises(AgentRuntimeError) as excinfo: + AgentCoreRuntimeClient._check_response(resp) + assert "internal boom" in excinfo.value.error + assert excinfo.value.error_type == "HTTP 500" + + def test_defaults_error_type_for_dict_body(self) -> None: + """Missing error_type in JSON body defaults to 'HTTP '.""" + resp = _FakeResponse(status=502, body=b'{"message": "bad gateway"}') + with pytest.raises(AgentRuntimeError) as excinfo: + AgentCoreRuntimeClient._check_response(resp) + assert excinfo.value.error_type == "HTTP 502" + + def test_empty_body_uses_reason(self) -> None: + """Empty body falls back to the response reason string.""" + resp = _FakeResponse(status=503, body=b"", reason="Service Unavailable") + with pytest.raises(AgentRuntimeError) as excinfo: + AgentCoreRuntimeClient._check_response(resp) + assert excinfo.value.error == "Service Unavailable" + + +# --------------------------------------------------------------------- # +# _decode_sse_line — all branches +# --------------------------------------------------------------------- # + + +class TestDecodeSseLine: + """Tests for _decode_sse_line.""" + + def test_empty_line(self) -> None: + """Empty / whitespace-only lines return None.""" + assert AgentCoreRuntimeClient._decode_sse_line("") is None + assert AgentCoreRuntimeClient._decode_sse_line(" ") is None + + def test_comment_line(self) -> None: + """SSE comment lines (starting with ':') return None.""" + assert AgentCoreRuntimeClient._decode_sse_line(": keepalive") is None + + def test_data_with_json_encoded_string(self) -> None: + """data: "hello" -> hello (unwrapped).""" + assert AgentCoreRuntimeClient._decode_sse_line('data: "hello"') == "hello" + + def test_data_with_plain_text(self) -> None: + """data: plain -> plain.""" + assert AgentCoreRuntimeClient._decode_sse_line("data: plain") == "plain" + + def test_data_with_malformed_json_string(self) -> None: + """data: "bad" missing closing quote: strip what we can.""" + assert AgentCoreRuntimeClient._decode_sse_line('data: "oops') == '"oops' + + def test_data_with_malformed_quoted_string(self) -> None: + """data: \"bad\\escape\" with invalid escape yields trimmed string.""" + assert ( + AgentCoreRuntimeClient._decode_sse_line('data: "bad\\escape"') + == "bad\\escape" + ) + + def test_data_with_json_error_raises(self) -> None: + """data: {"error": ...} raises AgentRuntimeError.""" + with pytest.raises(AgentRuntimeError) as excinfo: + AgentCoreRuntimeClient._decode_sse_line( + 'data: {"error": "boom", "error_type": "T"}' + ) + assert excinfo.value.error == "boom" + + def test_data_with_json_object_no_error(self) -> None: + """data: {...} without an error key passes through as raw content.""" + assert ( + AgentCoreRuntimeClient._decode_sse_line('data: {"foo": "bar"}') + == '{"foo": "bar"}' + ) + + def test_data_with_broken_json_passes_through(self) -> None: + """data: { broken passes through as the content after 'data:'.""" + assert AgentCoreRuntimeClient._decode_sse_line("data: {not json") == "{not json" + + def test_plain_json_text_key(self) -> None: + """Plain JSON line with 'text' key returns that value.""" + assert AgentCoreRuntimeClient._decode_sse_line('{"text": "hi"}') == "hi" + + def test_plain_json_content_key(self) -> None: + """Plain JSON line with 'content' key returns that value.""" + assert AgentCoreRuntimeClient._decode_sse_line('{"content": "c"}') == "c" + + def test_plain_json_data_key(self) -> None: + """Plain JSON line with 'data' key returns str(that value).""" + assert AgentCoreRuntimeClient._decode_sse_line('{"data": 42}') == "42" + + def test_plain_json_error_raises(self) -> None: + """Plain JSON line with 'error' raises.""" + with pytest.raises(AgentRuntimeError): + AgentCoreRuntimeClient._decode_sse_line('{"error": "bad"}') + + def test_plain_json_unknown_shape_returns_none(self) -> None: + """Plain JSON dict without known keys returns None.""" + assert AgentCoreRuntimeClient._decode_sse_line('{"something": "else"}') is None + + def test_plain_text_passes_through(self) -> None: + """Non-JSON plain text is returned verbatim.""" + assert AgentCoreRuntimeClient._decode_sse_line("hello world") == "hello world" + + +# --------------------------------------------------------------------- # +# _iter_lines +# --------------------------------------------------------------------- # + + +class TestIterLines: + """Tests for _iter_lines chunk splitting.""" + + def test_splits_on_lf(self) -> None: + """Splits on newline and yields each line.""" + resp = _FakeResponse(chunks=[b"one\ntwo\nthree\n"]) + lines = list(AgentCoreRuntimeClient._iter_lines(resp)) + assert lines == [b"one", b"two", b"three"] + + def test_strips_crlf(self) -> None: + r"""Trailing \r is stripped.""" + resp = _FakeResponse(chunks=[b"a\r\nb\r\n"]) + lines = list(AgentCoreRuntimeClient._iter_lines(resp)) + assert lines == [b"a", b"b"] + + def test_trailing_without_newline(self) -> None: + """A trailing fragment without newline is still yielded.""" + resp = _FakeResponse(chunks=[b"done"]) + lines = list(AgentCoreRuntimeClient._iter_lines(resp)) + assert lines == [b"done"] + + def test_trailing_cr_stripped(self) -> None: + """A trailing CR without LF is stripped.""" + resp = _FakeResponse(chunks=[b"done\r"]) + lines = list(AgentCoreRuntimeClient._iter_lines(resp)) + assert lines == [b"done"] + + def test_split_across_chunks(self) -> None: + """Lines spanning chunk boundaries are reassembled.""" + resp = _FakeResponse(chunks=[b"hel", b"lo\nwo", b"rld\n"]) + lines = list(AgentCoreRuntimeClient._iter_lines(resp)) + assert lines == [b"hello", b"world"] + + def test_empty_chunk_skipped(self) -> None: + """Empty chunks in the stream are ignored.""" + resp = _FakeResponse(chunks=[b"a\n", b"", b"b\n"]) + lines = list(AgentCoreRuntimeClient._iter_lines(resp)) + assert lines == [b"a", b"b"] + + +# --------------------------------------------------------------------- # +# AgentRuntimeError +# --------------------------------------------------------------------- # + + +class TestAgentRuntimeError: + """Tests for AgentRuntimeError construction and rendering.""" + + def test_stores_error_and_type(self) -> None: + """error and error_type are accessible as attributes.""" + exc = AgentRuntimeError(error="boom", error_type="T") + assert exc.error == "boom" + assert exc.error_type == "T" + + def test_default_error_type_empty(self) -> None: + """error_type defaults to the empty string.""" + exc = AgentRuntimeError(error="boom") + assert exc.error_type == "" + + def test_str_uses_message(self) -> None: + """str() prefers the message when provided.""" + exc = AgentRuntimeError(error="e", message="human-readable") + assert str(exc) == "human-readable" + + def test_str_falls_back_to_error(self) -> None: + """str() falls back to the error token when no message.""" + exc = AgentRuntimeError(error="only-error") + assert str(exc) == "only-error" + + def test_is_exception(self) -> None: + """Subclasses Exception so it can be raised.""" + try: + raise AgentRuntimeError(error="x") + except Exception as exc: + assert isinstance(exc, AgentRuntimeError)