From 7820df5940c4c27f0da0a0477c335dfddcba6994 Mon Sep 17 00:00:00 2001 From: Eashan Kaushik <50113394+EashanKaushik@users.noreply.github.com> Date: Mon, 20 Apr 2026 11:11:16 -0400 Subject: [PATCH 1/2] feat(runtime): add AgentCoreRuntimeHttpClient for HTTP invocation and execute_command --- CHANGELOG.md | 6 + src/bedrock_agentcore/runtime/__init__.py | 15 +- .../runtime/agent_core_runtime_http_client.py | 733 ++++++++++++++++ tests/bedrock_agentcore/test_init.py | 5 +- .../test_agent_core_runtime_http_client.py | 823 ++++++++++++++++++ 5 files changed, 1580 insertions(+), 2 deletions(-) create mode 100644 src/bedrock_agentcore/runtime/agent_core_runtime_http_client.py create mode 100644 tests/unit/runtime/test_agent_core_runtime_http_client.py diff --git a/CHANGELOG.md b/CHANGELOG.md index af1a1e7c..c3b918ba 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,11 @@ # Changelog +## [Unreleased] + +### Added +- feat(runtime): add AgentCoreRuntimeHttpClient for bearer-token HTTP invocation, SSE streaming, and the InvokeAgentRuntimeCommand shell-exec API (alongside existing AgentCoreRuntimeClient for WebSocket URLs) +- feat(runtime): add AgentRuntimeError exception raised by the HTTP client on non-2xx responses and in-band SSE error events + ## [1.6.3] - 2026-04-16 ### Fixed diff --git a/src/bedrock_agentcore/runtime/__init__.py b/src/bedrock_agentcore/runtime/__init__.py index a08bbc93..df47df85 100644 --- a/src/bedrock_agentcore/runtime/__init__.py +++ b/src/bedrock_agentcore/runtime/__init__.py @@ -4,15 +4,23 @@ - BedrockAgentCoreApp: Main application class - RequestContext: HTTP request context - BedrockAgentCoreContext: Agent identity context +- AgentCoreRuntimeHttpClient: Bearer-token HTTP client for invoking a deployed runtime +- AgentRuntimeError: Exception raised by the HTTP client """ from .agent_core_runtime_client import AgentCoreRuntimeClient +from .agent_core_runtime_http_client import ( + AgentCoreRuntimeHttpClient, + AgentRuntimeError, +) from .app import BedrockAgentCoreApp from .context import BedrockAgentCoreContext, RequestContext from .models import PingStatus __all__ = [ "AgentCoreRuntimeClient", + "AgentCoreRuntimeHttpClient", + "AgentRuntimeError", "AGUIApp", "BedrockAgentCoreApp", "BedrockCallContextBuilder", @@ -29,7 +37,12 @@ def __getattr__(name: str): """Lazy imports for A2A and AG-UI symbols so optional dependencies are not required at import time.""" - _a2a_exports = {"BedrockCallContextBuilder", "build_a2a_app", "build_runtime_url", "serve_a2a"} + _a2a_exports = { + "BedrockCallContextBuilder", + "build_a2a_app", + "build_runtime_url", + "serve_a2a", + } if name in _a2a_exports: from . import a2a as _a2a_module diff --git a/src/bedrock_agentcore/runtime/agent_core_runtime_http_client.py b/src/bedrock_agentcore/runtime/agent_core_runtime_http_client.py new file mode 100644 index 00000000..45fd7d49 --- /dev/null +++ b/src/bedrock_agentcore/runtime/agent_core_runtime_http_client.py @@ -0,0 +1,733 @@ +"""HTTP client for invoking a deployed Bedrock AgentCore runtime. + +Complements :class:`AgentCoreRuntimeClient` (which only builds WebSocket URLs +and headers) by providing bearer-token HTTP invocation, SSE streaming, and the +``InvokeAgentRuntimeCommand`` API for running shell commands inside a runtime +session. + +Wire formats targeted: + +- ``POST /runtimes/{arn}/invocations`` — agent invocation. Response is either + a JSON document or Server-Sent Events (``text/event-stream``). +- ``POST /runtimes/{arn}/commands`` — ``InvokeAgentRuntimeCommand``. Response + is the AWS EventStream binary framing + (``application/vnd.amazon.eventstream``). Each event's payload is JSON + wrapped under a ``chunk`` key containing one of ``contentStart``, + ``contentDelta {stdout, stderr}``, or + ``contentStop {exitCode, status}``. +- ``POST /runtimes/{arn}/stopruntimesession`` — session termination. +""" + +from __future__ import annotations + +import asyncio +import json +import logging +import threading +import urllib.parse +import uuid +from typing import TYPE_CHECKING, Any, AsyncGenerator, Generator, Iterator, Optional + +import urllib3 +from botocore.eventstream import EventStreamBuffer + +if TYPE_CHECKING: + # urllib3 v2 returns ``BaseHTTPResponse`` from ``PoolManager.request``. + # On urllib3 v1.26 the returned object is ``HTTPResponse`` which is + # structurally compatible with the same attributes (``status``, ``headers``, + # ``read``, ``stream``, ``release_conn``), so the annotation is purely for + # type-checking under v2. + from urllib3 import BaseHTTPResponse as _UrllibResponse + +logger = logging.getLogger(__name__) + + +class AgentRuntimeError(Exception): + """Raised when an AgentCore runtime returns an error response. + + Used by :class:`AgentCoreRuntimeHttpClient` for both non-2xx HTTP + responses and in-band error events embedded in SSE streams. + + Attributes: + error: Short machine-readable error token (for example the + runtime's ``error`` field, or the HTTP status text). + error_type: Category label. For HTTP failures this is + ``"HTTP "`` (e.g. ``"HTTP 404"``). For runtime error + payloads it is whatever the server set on the ``error_type`` + field. + """ + + def __init__(self, error: str, error_type: str = "", message: str = "") -> None: + """Initialize the exception. + + Args: + error: Short error token (see :attr:`error`). + error_type: Category label (see :attr:`error_type`). Defaults + to the empty string. + message: Human-readable message used as the exception's + string representation. Defaults to ``error`` when empty. + """ + self.error = error + self.error_type = error_type + super().__init__(message or error) + + +class AgentCoreRuntimeHttpClient: + """HTTP client for invoking a deployed Bedrock AgentCore runtime. + + Use this client when you need bearer-token authentication (JWT/OAuth) + instead of IAM/SigV4. It supports blocking invocation, synchronous and + asynchronous streaming, shell command execution via + ``InvokeAgentRuntimeCommand``, and session termination. + + Each method takes the ``bearer_token`` per-call so the same client can + be reused with rotating credentials. + + Attributes: + agent_arn: Full ARN of the target agent runtime. + region: AWS region extracted from ``agent_arn``. + endpoint_name: Endpoint qualifier (defaults to ``"DEFAULT"``). + timeout: Default HTTP read timeout in seconds for non-command + methods. ``execute_command`` derives its HTTP timeout from + ``command_timeout`` instead (see that method). + content_type: Request ``Content-Type`` for :meth:`invoke` and + :meth:`invoke_streaming`. + accept: Request ``Accept`` for :meth:`invoke` and + :meth:`invoke_streaming`. + """ + + def __init__( + self, + agent_arn: str, + endpoint_name: str = "DEFAULT", + timeout: int = 300, + content_type: str = "application/json", + accept: str = "application/json", + pool_manager: Optional[urllib3.PoolManager] = None, + ) -> None: + """Initialize the HTTP client. + + Args: + agent_arn: The ARN of the agent runtime to invoke. The AWS + region is extracted from the ARN automatically. + endpoint_name: Endpoint qualifier sent as the ``qualifier`` + query parameter. Defaults to ``"DEFAULT"``. + timeout: Default HTTP read timeout in seconds (used by + :meth:`invoke`, :meth:`invoke_streaming`, + :meth:`invoke_streaming_async`, and + :meth:`stop_runtime_session`). ``execute_command`` derives + its HTTP timeout internally from ``command_timeout``. + content_type: MIME type of the request payload for invocation + calls. + accept: Desired MIME type for invocation responses. + pool_manager: Optional pre-configured + :class:`urllib3.PoolManager`. Primarily useful for tests + and for callers who want to control connection pooling. + A fresh manager is created when not provided. + + Raises: + ValueError: If the ARN does not contain a parseable region + component. + """ + parts = agent_arn.split(":") + if len(parts) < 4 or not parts[3]: + raise ValueError(f"Invalid agent ARN (missing region): {agent_arn}") + + self.agent_arn = agent_arn + self.region = parts[3] + self.endpoint_name = endpoint_name + self.timeout = timeout + self.content_type = content_type + self.accept = accept + self._http: urllib3.PoolManager = pool_manager or urllib3.PoolManager() + + # ------------------------------------------------------------------ # + # URL, headers, body helpers + # ------------------------------------------------------------------ # + + def _build_url(self, path_suffix: str) -> str: + """Build a full runtime URL. + + Args: + path_suffix: Path under ``/runtimes/{arn}/``. Must not start + with ``/``. Examples: ``"invocations"``, ``"commands"``, + ``"stopruntimesession"``. + + Returns: + Absolute URL including the qualifier query string. + """ + escaped_arn = urllib.parse.quote(self.agent_arn, safe="") + base = f"https://bedrock-agentcore.{self.region}.amazonaws.com/runtimes/{escaped_arn}/{path_suffix}" + query = urllib.parse.urlencode({"qualifier": self.endpoint_name}) + return f"{base}?{query}" + + def _build_headers( + self, + bearer_token: str, + session_id: str, + accept: Optional[str] = None, + content_type: Optional[str] = None, + ) -> dict[str, str]: + """Build the base request headers. + + Args: + bearer_token: OAuth/JWT bearer token. + session_id: Runtime session id. + accept: Optional override for the ``Accept`` header. + content_type: Optional override for the ``Content-Type`` + header. + + Returns: + Header dict including ``Authorization``, ``Content-Type``, + ``Accept``, and ``X-Amzn-Bedrock-AgentCore-Runtime-Session-Id``. + """ + return { + "Authorization": f"Bearer {bearer_token}", + "Content-Type": content_type or self.content_type, + "Accept": accept or self.accept, + "X-Amzn-Bedrock-AgentCore-Runtime-Session-Id": session_id, + } + + def _serialize_body(self, body: Any) -> bytes: + """Serialize a request body to bytes. + + JSON content types serialize via :func:`json.dumps`. For other + content types, ``bytes`` pass through unchanged, ``str`` is + UTF-8-encoded, and anything else falls back to JSON serialization. + + Args: + body: The payload to send. + + Returns: + UTF-8 encoded request body. + """ + if "json" in self.content_type: + return json.dumps(body).encode("utf-8") + if isinstance(body, bytes): + return body + if isinstance(body, str): + return body.encode("utf-8") + return json.dumps(body).encode("utf-8") + + # ------------------------------------------------------------------ # + # invoke / invoke_streaming / invoke_streaming_async + # ------------------------------------------------------------------ # + + def invoke( + self, + body: Any, + session_id: str, + bearer_token: str, + headers: Optional[dict[str, str]] = None, + ) -> str: + """Invoke the agent and return the full response body as a string. + + Handles both JSON responses and Server-Sent Events transparently. + For SSE responses, the decoded event data is concatenated in the + order it arrived. + + Args: + body: The request body to send to the agent. + session_id: Session id for conversation continuity. + bearer_token: Bearer token for authentication. + headers: Optional extra headers to include. Overwrite the + defaults on key collision. + + Returns: + The complete response body as a string. For JSON responses + that happen to be a JSON-encoded string, the unwrapped string + value is returned. + + Raises: + AgentRuntimeError: If the runtime returns a non-2xx status + or an error event in the SSE stream. + """ + request_headers = self._build_headers(bearer_token, session_id) + if headers: + request_headers.update(headers) + + response = self._http.request( + "POST", + self._build_url("invocations"), + headers=request_headers, + body=self._serialize_body(body), + timeout=self.timeout, + preload_content=False, + ) + try: + self._check_response(response) + content_type = response.headers.get("content-type", "") + if "text/event-stream" not in content_type: + return self._read_non_streaming(response) + return "".join(self._iter_sse_decoded(response)) + finally: + response.release_conn() + + def invoke_streaming( + self, + body: Any, + session_id: str, + bearer_token: str, + headers: Optional[dict[str, str]] = None, + ) -> Generator[str, None, None]: + """Invoke the agent and yield SSE chunks as they arrive. + + Args: + body: The request body to send to the agent. + session_id: Session id for conversation continuity. + bearer_token: Bearer token for authentication. + headers: Optional extra headers to include. + + Yields: + Decoded payload strings from the SSE stream, one per + non-empty ``data:`` line. + + Raises: + AgentRuntimeError: If the runtime returns a non-2xx status + or an error event in the stream. + """ + request_headers = self._build_headers(bearer_token, session_id) + if headers: + request_headers.update(headers) + + response = self._http.request( + "POST", + self._build_url("invocations"), + headers=request_headers, + body=self._serialize_body(body), + timeout=self.timeout, + preload_content=False, + ) + try: + self._check_response(response) + yield from self._iter_sse_decoded(response) + finally: + response.release_conn() + + async def invoke_streaming_async( + self, + body: Any, + session_id: str, + bearer_token: str, + headers: Optional[dict[str, str]] = None, + ) -> AsyncGenerator[str, None]: + """Async generator version of :meth:`invoke_streaming`. + + The underlying HTTP call is blocking; this wrapper runs it on a + background thread and delivers chunks to the caller through an + :class:`asyncio.Queue`, so it is safe to ``async for`` over. + + Args: + body: The request body to send to the agent. + session_id: Session id for conversation continuity. + bearer_token: Bearer token for authentication. + headers: Optional extra headers to include. + + Yields: + Decoded payload strings from the SSE stream. + + Raises: + AgentRuntimeError: If the runtime returns a non-2xx status + or an error event in the stream. + """ + chunk_queue: asyncio.Queue[Any] = asyncio.Queue() + sentinel = object() + loop = asyncio.get_running_loop() + + def stream_in_thread() -> None: + try: + for decoded in self.invoke_streaming( + body=body, + session_id=session_id, + bearer_token=bearer_token, + headers=headers, + ): + loop.call_soon_threadsafe(chunk_queue.put_nowait, decoded) + loop.call_soon_threadsafe(chunk_queue.put_nowait, sentinel) + except Exception as exc: # noqa: BLE001 — propagated to caller + loop.call_soon_threadsafe(chunk_queue.put_nowait, exc) + loop.call_soon_threadsafe(chunk_queue.put_nowait, sentinel) + + thread = threading.Thread(target=stream_in_thread, daemon=True) + thread.start() + + while True: + item = await chunk_queue.get() + if item is sentinel: + break + if isinstance(item, Exception): + raise item + yield item + await asyncio.sleep(0) + + # ------------------------------------------------------------------ # + # execute_command / execute_command_streaming + # ------------------------------------------------------------------ # + + def execute_command( + self, + command: str, + session_id: str, + bearer_token: str, + command_timeout: Optional[int] = None, + headers: Optional[dict[str, str]] = None, + ) -> dict[str, Any]: + """Run a shell command inside the runtime session and collect the full result. + + Blocking. Accumulates all ``stdout`` and ``stderr`` chunks from + the EventStream and returns the final exit status. + + Args: + command: Shell command to run (1 B – 64 KB per the + ``InvokeAgentRuntimeCommand`` API). + session_id: Runtime session id to target. The filesystem + inside the container persists across calls, but a fresh + shell is spawned each time, so working directory and + environment variables do not. + bearer_token: Bearer token for authentication. + command_timeout: Server-side command wall-clock timeout in + seconds (1–3600). Defaults to :attr:`timeout`. The + HTTP read timeout is derived internally as + ``command_timeout + 30``. + headers: Optional extra headers to include. + + Returns: + Dict with keys ``"stdout"`` (str), ``"stderr"`` (str), + ``"exitCode"`` (int, ``-1`` if no ``contentStop`` was + received), and ``"status"`` (``"COMPLETED"`` or + ``"TIMED_OUT"``, or ``"UNKNOWN"`` if no ``contentStop`` was + received). + + Raises: + AgentRuntimeError: If the runtime returns a non-2xx status. + """ + stdout_parts: list[str] = [] + stderr_parts: list[str] = [] + exit_code: int = -1 + status: str = "UNKNOWN" + + for event in self.execute_command_streaming( + command=command, + session_id=session_id, + bearer_token=bearer_token, + command_timeout=command_timeout, + headers=headers, + ): + if "contentDelta" in event: + delta = event["contentDelta"] + if "stdout" in delta: + stdout_parts.append(delta["stdout"]) + if "stderr" in delta: + stderr_parts.append(delta["stderr"]) + elif "contentStop" in event: + exit_code = int(event["contentStop"].get("exitCode", -1)) + status = str(event["contentStop"].get("status", "UNKNOWN")) + + return { + "stdout": "".join(stdout_parts), + "stderr": "".join(stderr_parts), + "exitCode": exit_code, + "status": status, + } + + def execute_command_streaming( + self, + command: str, + session_id: str, + bearer_token: str, + command_timeout: Optional[int] = None, + headers: Optional[dict[str, str]] = None, + ) -> Generator[dict[str, Any], None, None]: + """Stream AWS EventStream events from ``InvokeAgentRuntimeCommand``. + + Yields the decoded event payloads (the value inside the + server's ``"chunk"`` envelope). Each yielded dict has exactly one + of the keys ``"contentStart"``, ``"contentDelta"``, or + ``"contentStop"``. + + Args: + command: Shell command to run. + session_id: Runtime session id. + bearer_token: Bearer token for authentication. + command_timeout: Server-side wall-clock timeout in seconds + (1–3600). Defaults to :attr:`timeout`. + headers: Optional extra headers to include. + + Yields: + Parsed event payload dicts. + + Raises: + AgentRuntimeError: If the runtime returns a non-2xx status. + """ + effective_timeout = ( + command_timeout if command_timeout is not None else self.timeout + ) + + request_headers = self._build_headers( + bearer_token, + session_id, + accept="application/vnd.amazon.eventstream", + content_type="application/json", + ) + if headers: + request_headers.update(headers) + + response = self._http.request( + "POST", + self._build_url("commands"), + headers=request_headers, + body=json.dumps({"command": command, "timeout": effective_timeout}).encode( + "utf-8" + ), + timeout=effective_timeout + 30, + preload_content=False, + ) + try: + self._check_response(response) + + buf = EventStreamBuffer() + for chunk in response.stream(4096): + if not chunk: + continue + buf.add_data(chunk) + for event in buf: + payload = event.payload + if not payload: + continue + try: + decoded = json.loads(payload) + except json.JSONDecodeError: + continue + inner = decoded.get("chunk") if isinstance(decoded, dict) else None + yield inner if isinstance(inner, dict) else decoded + finally: + response.release_conn() + + # ------------------------------------------------------------------ # + # stop_runtime_session + # ------------------------------------------------------------------ # + + def stop_runtime_session( + self, + session_id: str, + bearer_token: str, + client_token: Optional[str] = None, + ) -> dict[str, Any]: + """Terminate a runtime session. + + Args: + session_id: The session id to stop. + bearer_token: Bearer token for authentication. + client_token: Idempotency token. Auto-generated as a UUID4 + when not supplied. + + Returns: + Parsed JSON body of the response (often an empty dict). + + Raises: + AgentRuntimeError: If the runtime returns a non-2xx status + (for example ``HTTP 404`` for an unknown session). + """ + request_headers = { + "Authorization": f"Bearer {bearer_token}", + "Content-Type": "application/json", + "X-Amzn-Bedrock-AgentCore-Runtime-Session-Id": session_id, + } + response = self._http.request( + "POST", + self._build_url("stopruntimesession"), + headers=request_headers, + body=json.dumps({"clientToken": client_token or str(uuid.uuid4())}).encode( + "utf-8" + ), + timeout=self.timeout, + preload_content=True, + ) + self._check_response(response) + if not response.data: + return {} + try: + parsed = json.loads(response.data) + except json.JSONDecodeError: + return {} + return parsed if isinstance(parsed, dict) else {"response": parsed} + + # ------------------------------------------------------------------ # + # Response parsing + # ------------------------------------------------------------------ # + + def _check_response(self, response: _UrllibResponse) -> None: + """Raise :class:`AgentRuntimeError` for non-2xx responses. + + Attempts to parse the body as JSON and surface ``error``, + ``error_type``, and ``message`` fields. Falls back to the raw + body text. + + Args: + response: The urllib3 response to inspect. + + Raises: + AgentRuntimeError: If the status code is 400 or above. + """ + if response.status < 400: + return + + body_bytes = response.read() + try: + body: Any = json.loads(body_bytes) + except (json.JSONDecodeError, ValueError): + body = body_bytes.decode("utf-8", errors="replace") + + error_type = f"HTTP {response.status}" + reason = response.reason or error_type + + if isinstance(body, dict): + raise AgentRuntimeError( + error=body.get("error", reason), + error_type=body.get("error_type", error_type), + message=body.get( + "message", body_bytes.decode("utf-8", errors="replace") + ), + ) + raise AgentRuntimeError( + error=str(body) or reason, + error_type=error_type, + ) + + def _read_non_streaming(self, response: _UrllibResponse) -> str: + """Read a fully-buffered response body and unwrap JSON strings. + + Args: + response: The urllib3 response, opened with + ``preload_content=False``. + + Returns: + The response body as text. If the body is a JSON-encoded + string (e.g. ``"hello"``), the unwrapped value is returned. + """ + data = response.read() + text = data.decode("utf-8", errors="replace") + try: + parsed = json.loads(text) + except (json.JSONDecodeError, ValueError): + return text + return parsed if isinstance(parsed, str) else text + + def _iter_sse_decoded(self, response: _UrllibResponse) -> Iterator[str]: + """Iterate decoded SSE payloads from a streaming response. + + Args: + response: The urllib3 response, opened with + ``preload_content=False``. + + Yields: + Decoded payload strings (one per non-empty ``data:`` line + or JSON line). + """ + for raw_line in self._iter_lines(response): + if not raw_line: + continue + decoded = self._decode_sse_line(raw_line.decode("utf-8", errors="replace")) + if decoded: + yield decoded + + @staticmethod + def _iter_lines(response: _UrllibResponse) -> Iterator[bytes]: + r"""Yield lines from a streaming urllib3 response. + + Splits on ``\n`` and strips a trailing ``\r`` so it handles + both LF and CRLF line endings. Preserves empty lines (the caller + filters them). + + Args: + response: The urllib3 response, opened with + ``preload_content=False``. + + Yields: + Each line as bytes (with no trailing newline). + """ + pending = b"" + for chunk in response.stream(1024): + if not chunk: + continue + pending += chunk + while True: + idx = pending.find(b"\n") + if idx < 0: + break + line, pending = pending[:idx], pending[idx + 1 :] + if line.endswith(b"\r"): + line = line[:-1] + yield line + if pending: + if pending.endswith(b"\r"): + pending = pending[:-1] + yield pending + + def _decode_sse_line(self, line: str) -> Optional[str]: + """Decode a single SSE or JSON-Lines line. + + Handles SSE ``data:`` lines, JSON error envelopes, JSON-encoded + strings, and plain text. Error envelopes (with an ``error`` key) + are raised as :class:`AgentRuntimeError`. + + Args: + line: A single line of the streamed response. + + Returns: + The decoded payload, or ``None`` if the line is a comment, + empty, or carries no renderable text. + + Raises: + AgentRuntimeError: If the line carries a JSON error payload. + """ + line = line.strip() + if not line or line.startswith(":"): + return None + + if line.startswith("data:"): + content = line[5:].strip() + + if content.startswith("{"): + try: + data = json.loads(content) + if isinstance(data, dict) and "error" in data: + raise AgentRuntimeError( + error=str(data["error"]), + error_type=str(data.get("error_type", "")), + message=str(data.get("message", data["error"])), + ) + except json.JSONDecodeError: + pass + + if content.startswith('"'): + try: + unwrapped = json.loads(content) + except json.JSONDecodeError: + if content.endswith('"'): + return content[1:-1] + return content + return unwrapped if isinstance(unwrapped, str) else content + return content + + try: + data = json.loads(line) + except json.JSONDecodeError: + return line + + if isinstance(data, dict): + if "error" in data: + raise AgentRuntimeError( + error=str(data["error"]), + error_type=str(data.get("error_type", "")), + message=str(data.get("message", data["error"])), + ) + if "text" in data: + value = data["text"] + return value if isinstance(value, str) else str(value) + if "content" in data: + value = data["content"] + return value if isinstance(value, str) else str(value) + if "data" in data: + return str(data["data"]) + return None diff --git a/tests/bedrock_agentcore/test_init.py b/tests/bedrock_agentcore/test_init.py index f757719c..4c3ea081 100644 --- a/tests/bedrock_agentcore/test_init.py +++ b/tests/bedrock_agentcore/test_init.py @@ -7,7 +7,10 @@ def test_getattr_raises_for_unknown_attribute(): """Test that __getattr__ raises AttributeError for unknown attributes.""" import bedrock_agentcore - with pytest.raises(AttributeError, match="module 'bedrock_agentcore' has no attribute 'UnknownAttribute'"): + with pytest.raises( + AttributeError, + match="module 'bedrock_agentcore' has no attribute 'UnknownAttribute'", + ): _ = bedrock_agentcore.UnknownAttribute diff --git a/tests/unit/runtime/test_agent_core_runtime_http_client.py b/tests/unit/runtime/test_agent_core_runtime_http_client.py new file mode 100644 index 00000000..766cc93d --- /dev/null +++ b/tests/unit/runtime/test_agent_core_runtime_http_client.py @@ -0,0 +1,823 @@ +"""Tests for AgentCoreRuntimeHttpClient.""" + +from __future__ import annotations + +import json +from typing import Any, Iterator, Optional +from unittest.mock import MagicMock + +import pytest + +from bedrock_agentcore.runtime.agent_core_runtime_http_client import ( + AgentCoreRuntimeHttpClient, + AgentRuntimeError, +) + +ARN = "arn:aws:bedrock-agentcore:us-west-2:123456789012:runtime/my-runtime-abc" +BEARER = "test-token" +SESSION = "a" * 36 + + +class _FakeResponse: + """Minimal stand-in for urllib3's response used by the HTTP client. + + Exposes the exact surface the client touches: ``status``, ``reason``, + ``headers``, ``data``, ``read``, ``stream``, and ``release_conn``. + """ + + def __init__( + self, + status: int = 200, + headers: Optional[dict[str, str]] = None, + body: bytes = b"", + chunks: Optional[list[bytes]] = None, + reason: str = "", + ) -> None: + self.status = status + self.reason = reason + self.headers = headers or {} + self._body = body + self._chunks = chunks + self._consumed = False + self.release_calls = 0 + + @property + def data(self) -> bytes: + return self._body + + def read(self) -> bytes: + if self._consumed: + return b"" + self._consumed = True + return self._body + + def stream(self, amt: int = 1024) -> Iterator[bytes]: + if self._chunks is not None: + for chunk in self._chunks: + yield chunk + return + # Fall back to emitting the full body as one chunk. + if self._body: + yield self._body + + def release_conn(self) -> None: + self.release_calls += 1 + + +def _make_client( + response: _FakeResponse, +) -> tuple[AgentCoreRuntimeHttpClient, MagicMock]: + """Build a client whose PoolManager returns ``response``.""" + pool = MagicMock() + pool.request.return_value = response + client = AgentCoreRuntimeHttpClient(agent_arn=ARN, pool_manager=pool) + return client, pool + + +# --------------------------------------------------------------------- # +# Initialization +# --------------------------------------------------------------------- # + + +class TestInit: + """Tests for AgentCoreRuntimeHttpClient.__init__.""" + + def test_parses_region_from_arn(self) -> None: + """Region is extracted from the ARN.""" + client = AgentCoreRuntimeHttpClient(agent_arn=ARN) + assert client.region == "us-west-2" + + def test_defaults(self) -> None: + """Default field values.""" + client = AgentCoreRuntimeHttpClient(agent_arn=ARN) + assert client.endpoint_name == "DEFAULT" + assert client.timeout == 300 + assert client.content_type == "application/json" + assert client.accept == "application/json" + + def test_custom_fields(self) -> None: + """Non-default field values are stored.""" + client = AgentCoreRuntimeHttpClient( + agent_arn=ARN, + endpoint_name="DEV", + timeout=42, + content_type="application/cbor", + accept="text/plain", + ) + assert client.endpoint_name == "DEV" + assert client.timeout == 42 + assert client.content_type == "application/cbor" + assert client.accept == "text/plain" + + def test_invalid_arn_missing_region(self) -> None: + """An ARN with no region raises ValueError.""" + with pytest.raises(ValueError, match="Invalid agent ARN"): + AgentCoreRuntimeHttpClient(agent_arn="arn:aws:bedrock-agentcore") + + def test_invalid_arn_empty_region(self) -> None: + """An ARN with empty region raises ValueError.""" + with pytest.raises(ValueError, match="Invalid agent ARN"): + AgentCoreRuntimeHttpClient( + agent_arn="arn:aws:bedrock-agentcore::123:runtime/x" + ) + + def test_pool_manager_injection(self) -> None: + """A caller-provided PoolManager is used verbatim.""" + pool = MagicMock() + client = AgentCoreRuntimeHttpClient(agent_arn=ARN, pool_manager=pool) + assert client._http is pool + + +# --------------------------------------------------------------------- # +# _build_url / _build_headers / _serialize_body +# --------------------------------------------------------------------- # + + +class TestBuildUrl: + """Tests for _build_url.""" + + def test_invocations_url(self) -> None: + """URL embeds region, URL-encoded ARN, path, and qualifier.""" + client = AgentCoreRuntimeHttpClient(agent_arn=ARN, endpoint_name="DEFAULT") + url = client._build_url("invocations") + assert url.startswith( + "https://bedrock-agentcore.us-west-2.amazonaws.com/runtimes/" + ) + assert "/invocations?qualifier=DEFAULT" in url + assert "arn%3Aaws%3Abedrock-agentcore" in url # colons URL-encoded + + def test_commands_url(self) -> None: + """Different path suffix yields the commands endpoint.""" + client = AgentCoreRuntimeHttpClient(agent_arn=ARN) + assert "/commands?qualifier=DEFAULT" in client._build_url("commands") + + def test_non_default_qualifier(self) -> None: + """Qualifier reflects endpoint_name.""" + client = AgentCoreRuntimeHttpClient(agent_arn=ARN, endpoint_name="DEV") + assert "qualifier=DEV" in client._build_url("invocations") + + +class TestBuildHeaders: + """Tests for _build_headers.""" + + def test_default_headers(self) -> None: + """Includes auth, content-type, accept, session-id.""" + client = AgentCoreRuntimeHttpClient(agent_arn=ARN) + headers = client._build_headers(BEARER, SESSION) + assert headers["Authorization"] == f"Bearer {BEARER}" + assert headers["Content-Type"] == "application/json" + assert headers["Accept"] == "application/json" + assert headers["X-Amzn-Bedrock-AgentCore-Runtime-Session-Id"] == SESSION + + def test_override_accept_and_content_type(self) -> None: + """Per-call overrides replace the defaults.""" + client = AgentCoreRuntimeHttpClient(agent_arn=ARN) + headers = client._build_headers( + BEARER, + SESSION, + accept="application/vnd.amazon.eventstream", + content_type="application/xml", + ) + assert headers["Accept"] == "application/vnd.amazon.eventstream" + assert headers["Content-Type"] == "application/xml" + + +class TestSerializeBody: + """Tests for _serialize_body.""" + + def test_json_dict(self) -> None: + """Dict body is JSON-serialized when content type is JSON.""" + client = AgentCoreRuntimeHttpClient(agent_arn=ARN) + assert client._serialize_body({"a": 1}) == b'{"a": 1}' + + def test_bytes_passthrough_non_json(self) -> None: + """Bytes body is sent verbatim for non-JSON content types.""" + client = AgentCoreRuntimeHttpClient( + agent_arn=ARN, content_type="application/octet-stream" + ) + assert client._serialize_body(b"raw") == b"raw" + + def test_str_utf8_encoded_non_json(self) -> None: + """String body is UTF-8 encoded for non-JSON content types.""" + client = AgentCoreRuntimeHttpClient(agent_arn=ARN, content_type="text/plain") + assert client._serialize_body("héllo") == "héllo".encode("utf-8") + + def test_fallback_to_json(self) -> None: + """Non-str, non-bytes body falls back to JSON even with non-JSON content type.""" + client = AgentCoreRuntimeHttpClient( + agent_arn=ARN, content_type="application/cbor" + ) + assert client._serialize_body({"x": 1}) == b'{"x": 1}' + + +# --------------------------------------------------------------------- # +# invoke (non-streaming JSON, non-streaming plain, SSE) +# --------------------------------------------------------------------- # + + +class TestInvoke: + """Tests for invoke.""" + + def test_non_streaming_json_string(self) -> None: + """JSON string response is unwrapped.""" + resp = _FakeResponse( + status=200, + headers={"content-type": "application/json"}, + body=b'"hello"', + ) + client, _ = _make_client(resp) + assert client.invoke({"prompt": "hi"}, SESSION, BEARER) == "hello" + assert resp.release_calls == 1 + + def test_non_streaming_json_object_returns_text(self) -> None: + """Non-string JSON response returns the raw body text.""" + resp = _FakeResponse( + status=200, + headers={"content-type": "application/json"}, + body=b'{"answer": 42}', + ) + client, _ = _make_client(resp) + assert client.invoke({"prompt": "hi"}, SESSION, BEARER) == '{"answer": 42}' + + def test_non_streaming_invalid_json_returns_text(self) -> None: + """If the body isn't valid JSON, the raw text is returned.""" + resp = _FakeResponse( + status=200, + headers={"content-type": "application/json"}, + body=b"plain-text-not-json", + ) + client, _ = _make_client(resp) + assert client.invoke({"prompt": "hi"}, SESSION, BEARER) == "plain-text-not-json" + + def test_sse_streaming_concatenates(self) -> None: + """Multiple SSE data lines are concatenated in order.""" + body = b"data: hello\ndata: world\n" + resp = _FakeResponse( + status=200, + headers={"content-type": "text/event-stream"}, + chunks=[body], + ) + client, _ = _make_client(resp) + assert client.invoke({"prompt": "hi"}, SESSION, BEARER) == "helloworld" + + def test_custom_headers_merged(self) -> None: + """Caller-supplied headers are merged into the request.""" + resp = _FakeResponse( + status=200, headers={"content-type": "application/json"}, body=b'"ok"' + ) + client, pool = _make_client(resp) + client.invoke({}, SESSION, BEARER, headers={"X-Test": "1"}) + sent_headers = pool.request.call_args.kwargs["headers"] + assert sent_headers["X-Test"] == "1" + assert sent_headers["Authorization"] == f"Bearer {BEARER}" + + def test_non_ok_raises(self) -> None: + """4xx and 5xx responses become AgentRuntimeError.""" + resp = _FakeResponse( + status=500, + body=b'{"error": "oops", "message": "boom"}', + reason="Server Error", + ) + client, _ = _make_client(resp) + with pytest.raises(AgentRuntimeError) as excinfo: + client.invoke({}, SESSION, BEARER) + assert excinfo.value.error == "oops" + assert "boom" in str(excinfo.value) + + +# --------------------------------------------------------------------- # +# invoke_streaming +# --------------------------------------------------------------------- # + + +class TestInvokeStreaming: + """Tests for invoke_streaming.""" + + def test_yields_chunks(self) -> None: + """Yields each decoded SSE payload in order.""" + body = b"data: first\ndata: second\n" + resp = _FakeResponse( + status=200, headers={"content-type": "text/event-stream"}, chunks=[body] + ) + client, _ = _make_client(resp) + assert list(client.invoke_streaming({}, SESSION, BEARER)) == ["first", "second"] + + def test_non_ok_raises(self) -> None: + """Errors are surfaced before the generator yields.""" + resp = _FakeResponse(status=404, body=b"not found", reason="Not Found") + client, _ = _make_client(resp) + with pytest.raises(AgentRuntimeError): + list(client.invoke_streaming({}, SESSION, BEARER)) + + def test_custom_headers_merged(self) -> None: + """Custom headers are merged into the request.""" + resp = _FakeResponse( + status=200, headers={"content-type": "text/event-stream"}, chunks=[b""] + ) + client, pool = _make_client(resp) + list(client.invoke_streaming({}, SESSION, BEARER, headers={"X-Custom": "v"})) + assert pool.request.call_args.kwargs["headers"]["X-Custom"] == "v" + + +# --------------------------------------------------------------------- # +# invoke_streaming_async +# --------------------------------------------------------------------- # + + +class TestInvokeStreamingAsync: + """Tests for invoke_streaming_async.""" + + @pytest.mark.asyncio + async def test_yields_async_chunks(self) -> None: + """Async generator yields the same chunks as the sync version.""" + body = b"data: a\ndata: b\n" + resp = _FakeResponse( + status=200, headers={"content-type": "text/event-stream"}, chunks=[body] + ) + client, _ = _make_client(resp) + out: list[str] = [] + async for chunk in client.invoke_streaming_async({}, SESSION, BEARER): + out.append(chunk) + assert out == ["a", "b"] + + @pytest.mark.asyncio + async def test_propagates_exception(self) -> None: + """Exceptions from the background thread surface to the caller.""" + resp = _FakeResponse(status=500, body=b"err", reason="Server Error") + client, _ = _make_client(resp) + with pytest.raises(AgentRuntimeError): + async for _ in client.invoke_streaming_async({}, SESSION, BEARER): + pass + + +# --------------------------------------------------------------------- # +# execute_command / execute_command_streaming +# --------------------------------------------------------------------- # + + +def _encode_eventstream_frame(payload: dict[str, Any]) -> bytes: + """Encode a JSON payload using botocore's EventStream framing. + + Uses the same encoder that the server would use to build a valid frame + the client can parse. + """ + + body = json.dumps(payload).encode("utf-8") + # Build prelude manually. Easier: construct bytes that match the wire format. + # Format: total_length (4), headers_length (4), prelude_crc (4), headers, payload, message_crc (4). + # With no headers: headers_length = 0. + import binascii + import struct + + headers = b"" + headers_length = len(headers) + total_length = 4 + 4 + 4 + headers_length + len(body) + 4 # + message CRC + prelude = struct.pack(">II", total_length, headers_length) + prelude_crc = struct.pack(">I", binascii.crc32(prelude) & 0xFFFFFFFF) + message_bytes = prelude + prelude_crc + headers + body + message_crc = struct.pack(">I", binascii.crc32(message_bytes) & 0xFFFFFFFF) + return message_bytes + message_crc + + +class TestExecuteCommand: + """Tests for execute_command (blocking).""" + + def test_aggregates_output(self) -> None: + """stdout/stderr/exitCode/status are aggregated across events.""" + events = [ + _encode_eventstream_frame({"chunk": {"contentStart": {}}}), + _encode_eventstream_frame({"chunk": {"contentDelta": {"stdout": "hi "}}}), + _encode_eventstream_frame({"chunk": {"contentDelta": {"stdout": "there"}}}), + _encode_eventstream_frame({"chunk": {"contentDelta": {"stderr": "warn"}}}), + _encode_eventstream_frame( + {"chunk": {"contentStop": {"exitCode": 0, "status": "COMPLETED"}}} + ), + ] + resp = _FakeResponse( + status=200, + headers={"content-type": "application/vnd.amazon.eventstream"}, + chunks=events, + ) + client, _ = _make_client(resp) + result = client.execute_command("echo hi there", SESSION, BEARER) + assert result == { + "stdout": "hi there", + "stderr": "warn", + "exitCode": 0, + "status": "COMPLETED", + } + + def test_missing_stop_has_unknown_status(self) -> None: + """Without a contentStop event, defaults are UNKNOWN / -1.""" + events = [ + _encode_eventstream_frame({"chunk": {"contentDelta": {"stdout": "x"}}}), + ] + resp = _FakeResponse( + status=200, + headers={"content-type": "application/vnd.amazon.eventstream"}, + chunks=events, + ) + client, _ = _make_client(resp) + result = client.execute_command("echo x", SESSION, BEARER) + assert result == { + "stdout": "x", + "stderr": "", + "exitCode": -1, + "status": "UNKNOWN", + } + + +class TestExecuteCommandStreaming: + """Tests for execute_command_streaming.""" + + def test_sends_command_body(self) -> None: + """Request body contains command and timeout.""" + resp = _FakeResponse( + status=200, + headers={"content-type": "application/vnd.amazon.eventstream"}, + chunks=[], + ) + client, pool = _make_client(resp) + list( + client.execute_command_streaming("ls", SESSION, BEARER, command_timeout=42) + ) + sent = pool.request.call_args + assert sent.kwargs["body"] == b'{"command": "ls", "timeout": 42}' + # HTTP timeout must be command_timeout + 30. + assert sent.kwargs["timeout"] == 72 + + def test_defaults_to_self_timeout(self) -> None: + """command_timeout=None falls back to self.timeout.""" + resp = _FakeResponse( + status=200, + headers={"content-type": "application/vnd.amazon.eventstream"}, + chunks=[], + ) + client = AgentCoreRuntimeHttpClient(agent_arn=ARN, timeout=120) + client._http = MagicMock() + client._http.request.return_value = resp + list(client.execute_command_streaming("ls", SESSION, BEARER)) + sent = client._http.request.call_args + body = json.loads(sent.kwargs["body"]) + assert body["timeout"] == 120 + assert sent.kwargs["timeout"] == 150 + + def test_yields_parsed_events(self) -> None: + """Each EventStream frame is yielded as a dict.""" + events = [ + _encode_eventstream_frame({"chunk": {"contentStart": {}}}), + _encode_eventstream_frame( + {"chunk": {"contentStop": {"exitCode": 2, "status": "COMPLETED"}}} + ), + ] + resp = _FakeResponse( + status=200, + headers={"content-type": "application/vnd.amazon.eventstream"}, + chunks=events, + ) + client, _ = _make_client(resp) + parsed = list(client.execute_command_streaming("echo hi", SESSION, BEARER)) + assert parsed[0] == {"contentStart": {}} + assert parsed[1] == {"contentStop": {"exitCode": 2, "status": "COMPLETED"}} + + def test_skips_event_without_payload(self) -> None: + """Events with empty payload are skipped silently.""" + # An eventstream frame with empty payload body. + import binascii + import struct + + headers = b"" + body = b"" + headers_length = 0 + total_length = 4 + 4 + 4 + headers_length + len(body) + 4 + prelude = struct.pack(">II", total_length, headers_length) + prelude_crc = struct.pack(">I", binascii.crc32(prelude) & 0xFFFFFFFF) + message_bytes = prelude + prelude_crc + headers + body + message_crc = struct.pack(">I", binascii.crc32(message_bytes) & 0xFFFFFFFF) + empty_frame = message_bytes + message_crc + + resp = _FakeResponse( + status=200, + headers={"content-type": "application/vnd.amazon.eventstream"}, + chunks=[empty_frame], + ) + client, _ = _make_client(resp) + assert list(client.execute_command_streaming("ls", SESSION, BEARER)) == [] + + def test_skips_bad_json_payload(self) -> None: + """Events whose payload is not valid JSON are skipped.""" + # Build a frame whose payload is not valid JSON. + import binascii + import struct + + headers = b"" + body = b"not-json" + headers_length = 0 + total_length = 4 + 4 + 4 + headers_length + len(body) + 4 + prelude = struct.pack(">II", total_length, headers_length) + prelude_crc = struct.pack(">I", binascii.crc32(prelude) & 0xFFFFFFFF) + message_bytes = prelude + prelude_crc + headers + body + message_crc = struct.pack(">I", binascii.crc32(message_bytes) & 0xFFFFFFFF) + bad_frame = message_bytes + message_crc + + resp = _FakeResponse( + status=200, + headers={"content-type": "application/vnd.amazon.eventstream"}, + chunks=[bad_frame], + ) + client, _ = _make_client(resp) + assert list(client.execute_command_streaming("ls", SESSION, BEARER)) == [] + + def test_non_ok_raises(self) -> None: + """Non-2xx responses raise AgentRuntimeError.""" + resp = _FakeResponse(status=403, body=b"forbidden", reason="Forbidden") + client, _ = _make_client(resp) + with pytest.raises(AgentRuntimeError): + list(client.execute_command_streaming("ls", SESSION, BEARER)) + + +# --------------------------------------------------------------------- # +# stop_runtime_session +# --------------------------------------------------------------------- # + + +class TestStopRuntimeSession: + """Tests for stop_runtime_session.""" + + def test_sends_client_token(self) -> None: + """Request body contains the provided client token.""" + resp = _FakeResponse(status=200, body=b"{}") + client, pool = _make_client(resp) + client.stop_runtime_session(SESSION, BEARER, client_token="abc") + sent = json.loads(pool.request.call_args.kwargs["body"]) + assert sent == {"clientToken": "abc"} + + def test_autogenerates_client_token(self) -> None: + """Missing client_token is replaced with a UUID4.""" + resp = _FakeResponse(status=200, body=b"{}") + client, pool = _make_client(resp) + client.stop_runtime_session(SESSION, BEARER) + sent = json.loads(pool.request.call_args.kwargs["body"]) + # UUID4 string is 36 chars (8-4-4-4-12). + assert len(sent["clientToken"]) == 36 + + def test_returns_empty_dict_on_blank_body(self) -> None: + """Empty response body yields an empty dict.""" + resp = _FakeResponse(status=200, body=b"") + client, _ = _make_client(resp) + assert client.stop_runtime_session(SESSION, BEARER) == {} + + def test_returns_parsed_dict(self) -> None: + """Dict body is returned as-is.""" + resp = _FakeResponse(status=200, body=b'{"sessionId": "x"}') + client, _ = _make_client(resp) + assert client.stop_runtime_session(SESSION, BEARER) == {"sessionId": "x"} + + def test_non_dict_body_wrapped(self) -> None: + """A JSON body that isn't a dict is wrapped under 'response'.""" + resp = _FakeResponse(status=200, body=b'["a", "b"]') + client, _ = _make_client(resp) + assert client.stop_runtime_session(SESSION, BEARER) == {"response": ["a", "b"]} + + def test_invalid_json_body_returns_empty(self) -> None: + """Invalid JSON body returns empty dict rather than raising.""" + resp = _FakeResponse(status=200, body=b"not json") + client, _ = _make_client(resp) + assert client.stop_runtime_session(SESSION, BEARER) == {} + + def test_404_raises(self) -> None: + """Unknown-session HTTP 404 becomes AgentRuntimeError.""" + resp = _FakeResponse(status=404, body=b"{}", reason="Not Found") + client, _ = _make_client(resp) + with pytest.raises(AgentRuntimeError) as excinfo: + client.stop_runtime_session(SESSION, BEARER) + assert excinfo.value.error_type == "HTTP 404" + + +# --------------------------------------------------------------------- # +# _check_response +# --------------------------------------------------------------------- # + + +class TestCheckResponse: + """Tests for _check_response.""" + + def test_noop_on_2xx(self) -> None: + """2xx responses do not raise.""" + client = AgentCoreRuntimeHttpClient(agent_arn=ARN) + client._check_response(_FakeResponse(status=200)) + + def test_parses_json_error_body(self) -> None: + """JSON error bodies populate error/error_type/message.""" + client = AgentCoreRuntimeHttpClient(agent_arn=ARN) + resp = _FakeResponse( + status=400, + body=b'{"error": "bad", "error_type": "Validation", "message": "details"}', + reason="Bad Request", + ) + with pytest.raises(AgentRuntimeError) as excinfo: + client._check_response(resp) + assert excinfo.value.error == "bad" + assert excinfo.value.error_type == "Validation" + + def test_falls_back_to_text_body(self) -> None: + """Non-JSON bodies become the error string.""" + client = AgentCoreRuntimeHttpClient(agent_arn=ARN) + resp = _FakeResponse(status=500, body=b"internal boom", reason="Server Error") + with pytest.raises(AgentRuntimeError) as excinfo: + client._check_response(resp) + assert "internal boom" in excinfo.value.error + assert excinfo.value.error_type == "HTTP 500" + + def test_defaults_error_type_for_dict_body(self) -> None: + """Missing error_type in JSON body defaults to 'HTTP '.""" + client = AgentCoreRuntimeHttpClient(agent_arn=ARN) + resp = _FakeResponse(status=502, body=b'{"message": "bad gateway"}') + with pytest.raises(AgentRuntimeError) as excinfo: + client._check_response(resp) + assert excinfo.value.error_type == "HTTP 502" + + def test_empty_body_uses_reason(self) -> None: + """Empty body falls back to the response reason string.""" + client = AgentCoreRuntimeHttpClient(agent_arn=ARN) + resp = _FakeResponse(status=503, body=b"", reason="Service Unavailable") + with pytest.raises(AgentRuntimeError) as excinfo: + client._check_response(resp) + assert excinfo.value.error == "Service Unavailable" + + +# --------------------------------------------------------------------- # +# _decode_sse_line — all branches +# --------------------------------------------------------------------- # + + +class TestDecodeSseLine: + """Tests for _decode_sse_line.""" + + @pytest.fixture + def client(self) -> AgentCoreRuntimeHttpClient: + return AgentCoreRuntimeHttpClient(agent_arn=ARN) + + def test_empty_line(self, client: AgentCoreRuntimeHttpClient) -> None: + """Empty / whitespace-only lines return None.""" + assert client._decode_sse_line("") is None + assert client._decode_sse_line(" ") is None + + def test_comment_line(self, client: AgentCoreRuntimeHttpClient) -> None: + """SSE comment lines (starting with ':') return None.""" + assert client._decode_sse_line(": keepalive") is None + + def test_data_with_json_encoded_string( + self, client: AgentCoreRuntimeHttpClient + ) -> None: + """data: "hello" → hello (unwrapped).""" + assert client._decode_sse_line('data: "hello"') == "hello" + + def test_data_with_plain_text(self, client: AgentCoreRuntimeHttpClient) -> None: + """data: plain → plain.""" + assert client._decode_sse_line("data: plain") == "plain" + + def test_data_with_malformed_json_string( + self, client: AgentCoreRuntimeHttpClient + ) -> None: + """data: "bad" but missing closing quote: strip what we can.""" + assert client._decode_sse_line('data: "oops') == '"oops' + + def test_data_with_malformed_quoted_string( + self, client: AgentCoreRuntimeHttpClient + ) -> None: + """data: \"truncated\\\\\" with trailing quote but bad escape yields trimmed string.""" + # A string that starts and ends with quotes but contains an invalid escape mid-string. + # json.loads fails; the fallback strips the outer quotes. + assert client._decode_sse_line('data: "bad\\escape"') == "bad\\escape" + + def test_data_with_json_error_raises( + self, client: AgentCoreRuntimeHttpClient + ) -> None: + """data: { "error": ... } raises AgentRuntimeError.""" + with pytest.raises(AgentRuntimeError) as excinfo: + client._decode_sse_line('data: {"error": "boom", "error_type": "T"}') + assert excinfo.value.error == "boom" + + def test_data_with_json_object_no_error( + self, client: AgentCoreRuntimeHttpClient + ) -> None: + """data: {...} without an error key passes through as the raw JSON content.""" + assert client._decode_sse_line('data: {"foo": "bar"}') == '{"foo": "bar"}' + + def test_data_with_broken_json_passes_through( + self, client: AgentCoreRuntimeHttpClient + ) -> None: + """data: { broken passes through as the content after 'data:'.""" + assert client._decode_sse_line("data: {not json") == "{not json" + + def test_plain_json_text_key(self, client: AgentCoreRuntimeHttpClient) -> None: + """Plain JSON line with 'text' key returns that value.""" + assert client._decode_sse_line('{"text": "hi"}') == "hi" + + def test_plain_json_content_key(self, client: AgentCoreRuntimeHttpClient) -> None: + """Plain JSON line with 'content' key returns that value.""" + assert client._decode_sse_line('{"content": "c"}') == "c" + + def test_plain_json_data_key(self, client: AgentCoreRuntimeHttpClient) -> None: + """Plain JSON line with 'data' key returns str(that value).""" + assert client._decode_sse_line('{"data": 42}') == "42" + + def test_plain_json_error_raises(self, client: AgentCoreRuntimeHttpClient) -> None: + """Plain JSON line with 'error' raises.""" + with pytest.raises(AgentRuntimeError): + client._decode_sse_line('{"error": "bad"}') + + def test_plain_json_unknown_shape_returns_none( + self, client: AgentCoreRuntimeHttpClient + ) -> None: + """Plain JSON dict without known keys returns None.""" + assert client._decode_sse_line('{"something": "else"}') is None + + def test_plain_text_passes_through( + self, client: AgentCoreRuntimeHttpClient + ) -> None: + """Non-JSON plain text is returned verbatim.""" + assert client._decode_sse_line("hello world") == "hello world" + + +# --------------------------------------------------------------------- # +# _iter_lines +# --------------------------------------------------------------------- # + + +class TestIterLines: + """Tests for _iter_lines chunk splitting.""" + + def test_splits_on_lf(self) -> None: + """Splits on newline and yields each line.""" + resp = _FakeResponse(chunks=[b"one\ntwo\nthree\n"]) + lines = list(AgentCoreRuntimeHttpClient._iter_lines(resp)) + assert lines == [b"one", b"two", b"three"] + + def test_strips_crlf(self) -> None: + """Trailing \\r is stripped.""" + resp = _FakeResponse(chunks=[b"a\r\nb\r\n"]) + lines = list(AgentCoreRuntimeHttpClient._iter_lines(resp)) + assert lines == [b"a", b"b"] + + def test_trailing_without_newline(self) -> None: + """A trailing fragment without newline is still yielded.""" + resp = _FakeResponse(chunks=[b"done"]) + lines = list(AgentCoreRuntimeHttpClient._iter_lines(resp)) + assert lines == [b"done"] + + def test_trailing_cr_stripped(self) -> None: + """A trailing CR without LF is stripped.""" + resp = _FakeResponse(chunks=[b"done\r"]) + lines = list(AgentCoreRuntimeHttpClient._iter_lines(resp)) + assert lines == [b"done"] + + def test_split_across_chunks(self) -> None: + """Lines spanning chunk boundaries are reassembled.""" + resp = _FakeResponse(chunks=[b"hel", b"lo\nwo", b"rld\n"]) + lines = list(AgentCoreRuntimeHttpClient._iter_lines(resp)) + assert lines == [b"hello", b"world"] + + def test_empty_chunk_skipped(self) -> None: + """Empty chunks in the stream are ignored.""" + resp = _FakeResponse(chunks=[b"a\n", b"", b"b\n"]) + lines = list(AgentCoreRuntimeHttpClient._iter_lines(resp)) + assert lines == [b"a", b"b"] + + +# --------------------------------------------------------------------- # +# AgentRuntimeError +# --------------------------------------------------------------------- # + + +class TestAgentRuntimeError: + """Tests for AgentRuntimeError construction and rendering.""" + + def test_stores_error_and_type(self) -> None: + """error and error_type are accessible as attributes.""" + exc = AgentRuntimeError(error="boom", error_type="T") + assert exc.error == "boom" + assert exc.error_type == "T" + + def test_default_error_type_empty(self) -> None: + """error_type defaults to the empty string.""" + exc = AgentRuntimeError(error="boom") + assert exc.error_type == "" + + def test_str_uses_message(self) -> None: + """str() prefers the message when provided.""" + exc = AgentRuntimeError(error="e", message="human-readable") + assert str(exc) == "human-readable" + + def test_str_falls_back_to_error(self) -> None: + """str() falls back to the error token when no message.""" + exc = AgentRuntimeError(error="only-error") + assert str(exc) == "only-error" + + def test_is_exception(self) -> None: + """Subclasses Exception so it can be raised.""" + try: + raise AgentRuntimeError(error="x") + except Exception as exc: + assert isinstance(exc, AgentRuntimeError) From 1fe1dbe198a0d0f74bac99db52913a7c912091d5 Mon Sep 17 00:00:00 2001 From: Eashan Kaushik <50113394+EashanKaushik@users.noreply.github.com> Date: Mon, 20 Apr 2026 17:38:50 -0400 Subject: [PATCH 2/2] feat(runtime): update AgentCoreRuntimeClient for HTTP invocation and execute_command --- CHANGELOG.md | 4 +- src/bedrock_agentcore/runtime/__init__.py | 11 +- .../runtime/agent_core_runtime_client.py | 822 ++++++++++++- .../runtime/agent_core_runtime_http_client.py | 733 ------------ .../runtime/test_agent_core_runtime_client.py | 1028 ++++++++++++++++- .../test_agent_core_runtime_http_client.py | 823 ------------- 6 files changed, 1793 insertions(+), 1628 deletions(-) delete mode 100644 src/bedrock_agentcore/runtime/agent_core_runtime_http_client.py delete mode 100644 tests/unit/runtime/test_agent_core_runtime_http_client.py diff --git a/CHANGELOG.md b/CHANGELOG.md index c3b918ba..67965e6f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,8 +3,8 @@ ## [Unreleased] ### Added -- feat(runtime): add AgentCoreRuntimeHttpClient for bearer-token HTTP invocation, SSE streaming, and the InvokeAgentRuntimeCommand shell-exec API (alongside existing AgentCoreRuntimeClient for WebSocket URLs) -- feat(runtime): add AgentRuntimeError exception raised by the HTTP client on non-2xx responses and in-band SSE error events +- feat(runtime): add HTTP invocation methods (`invoke`, `invoke_streaming`, `invoke_streaming_async`) and `InvokeAgentRuntimeCommand` shell-exec support (`execute_command`, `execute_command_streaming`) to `AgentCoreRuntimeClient`, plus `stop_runtime_session`. Methods take `runtime_arn` and `bearer_token` per-call, matching `generate_ws_connection_oauth`. The `urllib3.PoolManager` is lazy-initialized so SigV4 URL-generation users pay no cost. +- feat(runtime): add `AgentRuntimeError` exception raised by the HTTP methods on non-2xx responses and in-band SSE error events. ## [1.6.3] - 2026-04-16 diff --git a/src/bedrock_agentcore/runtime/__init__.py b/src/bedrock_agentcore/runtime/__init__.py index df47df85..5763b7ab 100644 --- a/src/bedrock_agentcore/runtime/__init__.py +++ b/src/bedrock_agentcore/runtime/__init__.py @@ -4,22 +4,17 @@ - BedrockAgentCoreApp: Main application class - RequestContext: HTTP request context - BedrockAgentCoreContext: Agent identity context -- AgentCoreRuntimeHttpClient: Bearer-token HTTP client for invoking a deployed runtime -- AgentRuntimeError: Exception raised by the HTTP client +- AgentCoreRuntimeClient: Authentication + HTTP invocation client for deployed runtimes +- AgentRuntimeError: Exception raised by the HTTP invocation methods """ -from .agent_core_runtime_client import AgentCoreRuntimeClient -from .agent_core_runtime_http_client import ( - AgentCoreRuntimeHttpClient, - AgentRuntimeError, -) +from .agent_core_runtime_client import AgentCoreRuntimeClient, AgentRuntimeError from .app import BedrockAgentCoreApp from .context import BedrockAgentCoreContext, RequestContext from .models import PingStatus __all__ = [ "AgentCoreRuntimeClient", - "AgentCoreRuntimeHttpClient", "AgentRuntimeError", "AGUIApp", "BedrockAgentCoreApp", diff --git a/src/bedrock_agentcore/runtime/agent_core_runtime_client.py b/src/bedrock_agentcore/runtime/agent_core_runtime_client.py index ab91de5a..106df03a 100644 --- a/src/bedrock_agentcore/runtime/agent_core_runtime_client.py +++ b/src/bedrock_agentcore/runtime/agent_core_runtime_client.py @@ -1,34 +1,110 @@ -"""Client for generating WebSocket authentication for AgentCore Runtime. +"""Client for AgentCore Runtime authentication and invocation. -This module provides a client for generating authentication credentials -for WebSocket connections to AgentCore Runtime endpoints. +This module provides a single client for interacting with Bedrock AgentCore +Runtime endpoints: + +- WebSocket URL and header generation (SigV4, SigV4 presigned, OAuth bearer). +- HTTP invocation (blocking, sync streaming, async streaming) with bearer-token + auth. +- ``InvokeAgentRuntimeCommand`` shell-exec support (blocking and streaming). +- Session termination. + +Bearer-token HTTP methods mirror the shape of +:meth:`generate_ws_connection_oauth` — ``runtime_arn`` and ``bearer_token`` +are passed per call. The ``urllib3.PoolManager`` used by HTTP methods is +lazy-initialized; callers who only use URL-generation methods pay no cost. """ +import asyncio import base64 import datetime +import json import logging import secrets +import threading import uuid -from typing import Dict, Optional, Tuple +from typing import ( + TYPE_CHECKING, + Any, + AsyncGenerator, + Dict, + Generator, + Iterator, + Optional, + Tuple, +) from urllib.parse import quote, urlencode, urlparse import boto3 from botocore.auth import SigV4Auth, SigV4QueryAuth from botocore.awsrequest import AWSRequest +from botocore.eventstream import EventStreamBuffer from .._utils.endpoints import get_data_plane_endpoint from .utils import is_valid_partition +if TYPE_CHECKING: + # urllib3 v2 returns ``BaseHTTPResponse`` from ``PoolManager.request``. + # On urllib3 v1.26 the returned object is ``HTTPResponse`` which is + # structurally compatible with the same attributes (``status``, ``headers``, + # ``read``, ``stream``, ``release_conn``), so the annotation is purely for + # type-checking under v2. + import urllib3 + from urllib3 import BaseHTTPResponse as _UrllibResponse + DEFAULT_PRESIGNED_URL_TIMEOUT = 300 MAX_PRESIGNED_URL_TIMEOUT = 300 +DEFAULT_HTTP_TIMEOUT = 300 +DEFAULT_COMMAND_TIMEOUT = 600 + + +class AgentRuntimeError(Exception): + """Raised when an AgentCore runtime returns an error response. + + Used by the HTTP invocation methods of :class:`AgentCoreRuntimeClient` + for both non-2xx HTTP responses and in-band error events embedded in + SSE streams. + + Attributes: + error: Short machine-readable error token (for example the + runtime's ``error`` field, or the HTTP status text). + error_type: Category label. For HTTP failures this is + ``"HTTP "`` (e.g. ``"HTTP 404"``). For runtime error + payloads it is whatever the server set on the ``error_type`` + field. + """ + + def __init__(self, error: str, error_type: str = "", message: str = "") -> None: + """Initialize the exception. + + Args: + error: Short error token (see :attr:`error`). + error_type: Category label (see :attr:`error_type`). Defaults + to the empty string. + message: Human-readable message used as the exception's + string representation. Defaults to ``error`` when empty. + """ + self.error = error + self.error_type = error_type + super().__init__(message or error) class AgentCoreRuntimeClient: - """Client for generating WebSocket authentication for AgentCore Runtime. + """Client for AgentCore Runtime authentication and invocation. + + This client supports four auth/transport modes against AgentCore Runtime: + + - SigV4-signed WebSocket URL + headers (:meth:`generate_ws_connection`). + - SigV4 presigned WebSocket URL (:meth:`generate_presigned_url`). + - OAuth bearer-token WebSocket URL + headers + (:meth:`generate_ws_connection_oauth`). + - OAuth bearer-token HTTP invocation (:meth:`invoke`, + :meth:`invoke_streaming`, :meth:`invoke_streaming_async`, + :meth:`execute_command`, :meth:`execute_command_streaming`, + :meth:`stop_runtime_session`). - This client provides authentication credentials for WebSocket connections - to AgentCore Runtime endpoints, allowing applications to establish - bidirectional streaming connections with agent runtimes. + The ``urllib3.PoolManager`` used by the HTTP methods is lazy-initialized + on first access; callers who only use URL-generation methods pay no cost. Attributes: region (str): The AWS region being used. @@ -53,6 +129,24 @@ def __init__(self, region: str, session: Optional[boto3.Session] = None) -> None session = boto3.Session() self.session = session + self._pool_manager: Optional["urllib3.PoolManager"] = None + + @property + def _http(self) -> "urllib3.PoolManager": + """Return a lazy-initialized ``urllib3.PoolManager`` shared across HTTP calls. + + The pool is not created until an HTTP method (``invoke``, + ``execute_command``, etc.) is actually called. Callers that only use + the SigV4 / OAuth URL-generation methods never trigger this import. + + Returns: + A shared :class:`urllib3.PoolManager` instance. + """ + if self._pool_manager is None: + import urllib3 + + self._pool_manager = urllib3.PoolManager() + return self._pool_manager def _parse_runtime_arn(self, runtime_arn: str) -> Dict[str, str]: """Parse runtime ARN and extract components. @@ -72,7 +166,11 @@ def _parse_runtime_arn(self, runtime_arn: str) -> Dict[str, str]: if len(parts) != 6: raise ValueError(f"Invalid runtime ARN format: {runtime_arn}") - if parts[0] != "arn" or not is_valid_partition(parts[1]) or parts[2] != "bedrock-agentcore": + if ( + parts[0] != "arn" + or not is_valid_partition(parts[1]) + or parts[2] != "bedrock-agentcore" + ): raise ValueError(f"Invalid runtime ARN format: {runtime_arn}") # Parse the resource part (runtime/{runtime_id}) @@ -138,6 +236,88 @@ def _build_websocket_url( return ws_url + def _build_http_url( + self, + runtime_arn: str, + path_suffix: str, + endpoint_name: Optional[str] = None, + ) -> str: + """Build an HTTPS URL for a data-plane API on a runtime. + + Reuses :meth:`_parse_runtime_arn` to extract the region from the ARN + and :func:`get_data_plane_endpoint` to construct the host, matching + the convention established by the WebSocket URL builder. + + Args: + runtime_arn: Full runtime ARN. + path_suffix: Path under ``/runtimes/{arn}/``. Must not start with + ``/``. Examples: ``"invocations"``, ``"commands"``, + ``"stopruntimesession"``. + endpoint_name: Endpoint qualifier sent as the ``qualifier`` query + parameter. Defaults to ``"DEFAULT"`` when not supplied so the + API receives a valid qualifier value. + + Returns: + Absolute URL including the ``?qualifier=...`` query string. + + Raises: + ValueError: If the ARN format is invalid. + """ + parsed = self._parse_runtime_arn(runtime_arn) + base = get_data_plane_endpoint(parsed["region"]) + encoded_arn = quote(runtime_arn, safe="") + query = urlencode({"qualifier": endpoint_name or "DEFAULT"}) + return f"{base}/runtimes/{encoded_arn}/{path_suffix}?{query}" + + def _build_bearer_headers( + self, + bearer_token: str, + session_id: str, + accept: str, + content_type: str, + ) -> Dict[str, str]: + """Build request headers for a bearer-authenticated HTTP call. + + Args: + bearer_token: OAuth/JWT bearer token. + session_id: Runtime session id. + accept: Value for the ``Accept`` header. + content_type: Value for the ``Content-Type`` header. + + Returns: + Header dict including ``Authorization``, ``Content-Type``, + ``Accept``, and ``X-Amzn-Bedrock-AgentCore-Runtime-Session-Id``. + """ + return { + "Authorization": f"Bearer {bearer_token}", + "Content-Type": content_type, + "Accept": accept, + "X-Amzn-Bedrock-AgentCore-Runtime-Session-Id": session_id, + } + + @staticmethod + def _serialize_body(body: Any, content_type: str) -> bytes: + """Serialize a request body to bytes based on the content type. + + JSON content types serialize via :func:`json.dumps`. For other + content types, ``bytes`` pass through unchanged, ``str`` is + UTF-8-encoded, and anything else falls back to JSON serialization. + + Args: + body: The payload to send. + content_type: The MIME type of the request. + + Returns: + UTF-8 encoded request body. + """ + if "json" in content_type: + return json.dumps(body).encode("utf-8") + if isinstance(body, bytes): + return body + if isinstance(body, str): + return body.encode("utf-8") + return json.dumps(body).encode("utf-8") + def generate_ws_connection( self, runtime_arn: str, @@ -200,7 +380,9 @@ def generate_ws_connection( url=https_url, headers={ "host": host, - "x-amz-date": datetime.datetime.now(datetime.timezone.utc).strftime("%Y%m%dT%H%M%SZ"), + "x-amz-date": datetime.datetime.now(datetime.timezone.utc).strftime( + "%Y%m%dT%H%M%SZ" + ), }, ) @@ -225,7 +407,9 @@ def generate_ws_connection( if frozen_credentials.token: headers["X-Amz-Security-Token"] = frozen_credentials.token - self.logger.info("✓ WebSocket connection credentials generated (Session: %s)", session_id) + self.logger.info( + "✓ WebSocket connection credentials generated (Session: %s)", session_id + ) return ws_url, headers def generate_presigned_url( @@ -273,7 +457,9 @@ def generate_presigned_url( # Validate expires parameter if expires > MAX_PRESIGNED_URL_TIMEOUT: - raise ValueError(f"Expiry timeout cannot exceed {MAX_PRESIGNED_URL_TIMEOUT} seconds, got {expires}") + raise ValueError( + f"Expiry timeout cannot exceed {MAX_PRESIGNED_URL_TIMEOUT} seconds, got {expires}" + ) # Validate ARN self._parse_runtime_arn(runtime_arn) @@ -305,7 +491,9 @@ def generate_presigned_url( frozen_credentials = credentials.get_frozen_credentials() # Create the request to sign - request = AWSRequest(method="GET", url=https_url, headers={"host": url.hostname}) + request = AWSRequest( + method="GET", url=https_url, headers={"host": url.hostname} + ) # Sign the request with SigV4QueryAuth signer = SigV4QueryAuth( @@ -322,7 +510,11 @@ def generate_presigned_url( # Convert back to wss:// for WebSocket connection presigned_url = request.url.replace("https://", "wss://") - self.logger.info("✓ Presigned URL generated (expires in %s seconds, Session: %s)", expires, session_id) + self.logger.info( + "✓ Presigned URL generated (expires in %s seconds, Session: %s)", + expires, + session_id, + ) return presigned_url def generate_ws_connection_oauth( @@ -397,7 +589,607 @@ def generate_ws_connection_oauth( "User-Agent": "OAuth-WebSocket-Client/1.0", } - self.logger.info("✓ OAuth WebSocket connection credentials generated (Session: %s)", session_id) + self.logger.info( + "✓ OAuth WebSocket connection credentials generated (Session: %s)", + session_id, + ) self.logger.debug("Bearer token length: %d characters", len(bearer_token)) return ws_url, headers + + # ------------------------------------------------------------------ # + # HTTP invocation (bearer auth) + # ------------------------------------------------------------------ # + + def invoke( + self, + runtime_arn: str, + bearer_token: str, + body: Any, + *, + session_id: Optional[str] = None, + endpoint_name: Optional[str] = None, + headers: Optional[Dict[str, str]] = None, + timeout: int = DEFAULT_HTTP_TIMEOUT, + content_type: str = "application/json", + accept: str = "application/json", + ) -> str: + """Invoke the runtime over HTTP and return the full response body. + + Handles both JSON and Server-Sent Events responses transparently. For + SSE responses the decoded event data is concatenated in arrival order. + + Args: + runtime_arn: Full runtime ARN. + bearer_token: Bearer token for authentication. + body: Request payload (JSON-serializable for JSON content types, + ``bytes`` or ``str`` for others). + session_id: Runtime session id. Auto-generated as a UUID4 when + not supplied, matching :meth:`generate_ws_connection_oauth`. + endpoint_name: Endpoint qualifier. Defaults to ``"DEFAULT"`` on + the wire when not supplied. + headers: Optional extra headers merged over the auth headers. + timeout: HTTP read timeout in seconds. + content_type: Request ``Content-Type``. + accept: Request ``Accept``. + + Returns: + The complete response body as a string. For JSON responses + whose body is a JSON-encoded string, the unwrapped string is + returned. + + Raises: + AgentRuntimeError: On non-2xx responses or in-band SSE error + events. + ValueError: If the ARN format is invalid. + """ + if not session_id: + session_id = str(uuid.uuid4()) + url = self._build_http_url(runtime_arn, "invocations", endpoint_name) + request_headers = self._build_bearer_headers( + bearer_token, session_id, accept, content_type + ) + if headers: + request_headers.update(headers) + + response = self._http.request( + "POST", + url, + headers=request_headers, + body=self._serialize_body(body, content_type), + timeout=timeout, + preload_content=False, + ) + try: + self._check_response(response) + response_content_type = response.headers.get("content-type", "") + if "text/event-stream" not in response_content_type: + return self._read_non_streaming(response) + return "".join(self._iter_sse_decoded(response)) + finally: + response.release_conn() + + def invoke_streaming( + self, + runtime_arn: str, + bearer_token: str, + body: Any, + *, + session_id: Optional[str] = None, + endpoint_name: Optional[str] = None, + headers: Optional[Dict[str, str]] = None, + timeout: int = DEFAULT_HTTP_TIMEOUT, + content_type: str = "application/json", + accept: str = "application/json", + ) -> Generator[str, None, None]: + """Invoke the runtime and yield SSE chunks as they arrive. + + Args: + runtime_arn: Full runtime ARN. + bearer_token: Bearer token for authentication. + body: Request payload. + session_id: Runtime session id. Auto-generated if not supplied. + endpoint_name: Endpoint qualifier. Defaults to ``"DEFAULT"``. + headers: Optional extra headers. + timeout: HTTP read timeout in seconds. + content_type: Request ``Content-Type``. + accept: Request ``Accept``. + + Yields: + Decoded payload strings from the SSE stream. + + Raises: + AgentRuntimeError: If the runtime returns a non-2xx status + or an error event in the stream. + ValueError: If the ARN format is invalid. + """ + if not session_id: + session_id = str(uuid.uuid4()) + url = self._build_http_url(runtime_arn, "invocations", endpoint_name) + request_headers = self._build_bearer_headers( + bearer_token, session_id, accept, content_type + ) + if headers: + request_headers.update(headers) + + response = self._http.request( + "POST", + url, + headers=request_headers, + body=self._serialize_body(body, content_type), + timeout=timeout, + preload_content=False, + ) + try: + self._check_response(response) + yield from self._iter_sse_decoded(response) + finally: + response.release_conn() + + async def invoke_streaming_async( + self, + runtime_arn: str, + bearer_token: str, + body: Any, + *, + session_id: Optional[str] = None, + endpoint_name: Optional[str] = None, + headers: Optional[Dict[str, str]] = None, + timeout: int = DEFAULT_HTTP_TIMEOUT, + content_type: str = "application/json", + accept: str = "application/json", + ) -> AsyncGenerator[str, None]: + """Async generator version of :meth:`invoke_streaming`. + + The underlying HTTP call is blocking; this wrapper runs it on a + background thread and delivers chunks through an :class:`asyncio.Queue`, + so it is safe to ``async for`` over. + + Args: + runtime_arn: Full runtime ARN. + bearer_token: Bearer token for authentication. + body: Request payload. + session_id: Runtime session id. Auto-generated if not supplied. + endpoint_name: Endpoint qualifier. Defaults to ``"DEFAULT"``. + headers: Optional extra headers. + timeout: HTTP read timeout in seconds. + content_type: Request ``Content-Type``. + accept: Request ``Accept``. + + Yields: + Decoded payload strings from the SSE stream. + + Raises: + AgentRuntimeError: If the runtime returns a non-2xx status + or an error event in the stream. + ValueError: If the ARN format is invalid. + """ + chunk_queue: asyncio.Queue = asyncio.Queue() + sentinel = object() + loop = asyncio.get_running_loop() + + def stream_in_thread() -> None: + try: + for decoded in self.invoke_streaming( + runtime_arn, + bearer_token, + body, + session_id=session_id, + endpoint_name=endpoint_name, + headers=headers, + timeout=timeout, + content_type=content_type, + accept=accept, + ): + loop.call_soon_threadsafe(chunk_queue.put_nowait, decoded) + loop.call_soon_threadsafe(chunk_queue.put_nowait, sentinel) + except Exception as exc: # noqa: BLE001 — propagated to caller + loop.call_soon_threadsafe(chunk_queue.put_nowait, exc) + loop.call_soon_threadsafe(chunk_queue.put_nowait, sentinel) + + thread = threading.Thread(target=stream_in_thread, daemon=True) + thread.start() + + while True: + item = await chunk_queue.get() + if item is sentinel: + break + if isinstance(item, Exception): + raise item + yield item + await asyncio.sleep(0) + + # ------------------------------------------------------------------ # + # execute_command / execute_command_streaming + # ------------------------------------------------------------------ # + + def execute_command( + self, + runtime_arn: str, + bearer_token: str, + command: str, + *, + session_id: Optional[str] = None, + endpoint_name: Optional[str] = None, + headers: Optional[Dict[str, str]] = None, + command_timeout: int = DEFAULT_COMMAND_TIMEOUT, + ) -> Dict[str, Any]: + """Run a shell command inside the runtime session and collect the full result. + + Backed by the ``InvokeAgentRuntimeCommand`` API. Blocking; accumulates + all ``stdout`` and ``stderr`` chunks from the EventStream and returns + the final exit status. + + Args: + runtime_arn: Full runtime ARN. + bearer_token: Bearer token for authentication. + command: Shell command to run (1 B – 64 KB per the + ``InvokeAgentRuntimeCommand`` API). + session_id: Runtime session id. Auto-generated if not supplied. + The filesystem inside the container persists across calls in + the same session, but a fresh shell is spawned each time, so + working directory and environment variables do not. + endpoint_name: Endpoint qualifier. Defaults to ``"DEFAULT"``. + headers: Optional extra headers. + command_timeout: Server-side command wall-clock timeout in + seconds (1–3600). The HTTP read timeout is derived internally + as ``command_timeout + 30``. + + Returns: + Dict with keys ``"stdout"`` (str), ``"stderr"`` (str), + ``"exitCode"`` (int, ``-1`` if no ``contentStop`` was received), + and ``"status"`` (``"COMPLETED"``, ``"TIMED_OUT"``, or + ``"UNKNOWN"``). + + Raises: + AgentRuntimeError: If the runtime returns a non-2xx status. + ValueError: If the ARN format is invalid. + """ + stdout_parts: list = [] + stderr_parts: list = [] + exit_code: int = -1 + status: str = "UNKNOWN" + + for event in self.execute_command_streaming( + runtime_arn, + bearer_token, + command, + session_id=session_id, + endpoint_name=endpoint_name, + headers=headers, + command_timeout=command_timeout, + ): + if "contentDelta" in event: + delta = event["contentDelta"] + if "stdout" in delta: + stdout_parts.append(delta["stdout"]) + if "stderr" in delta: + stderr_parts.append(delta["stderr"]) + elif "contentStop" in event: + exit_code = int(event["contentStop"].get("exitCode", -1)) + status = str(event["contentStop"].get("status", "UNKNOWN")) + + return { + "stdout": "".join(stdout_parts), + "stderr": "".join(stderr_parts), + "exitCode": exit_code, + "status": status, + } + + def execute_command_streaming( + self, + runtime_arn: str, + bearer_token: str, + command: str, + *, + session_id: Optional[str] = None, + endpoint_name: Optional[str] = None, + headers: Optional[Dict[str, str]] = None, + command_timeout: int = DEFAULT_COMMAND_TIMEOUT, + ) -> Generator[Dict[str, Any], None, None]: + """Stream AWS EventStream events from ``InvokeAgentRuntimeCommand``. + + Yields the decoded event payloads (the value inside the server's + ``"chunk"`` envelope). Each yielded dict has exactly one of the keys + ``"contentStart"``, ``"contentDelta"``, or ``"contentStop"``. + + Args: + runtime_arn: Full runtime ARN. + bearer_token: Bearer token for authentication. + command: Shell command to run. + session_id: Runtime session id. Auto-generated if not supplied. + endpoint_name: Endpoint qualifier. Defaults to ``"DEFAULT"``. + headers: Optional extra headers. + command_timeout: Server-side wall-clock timeout in seconds + (1–3600). + + Yields: + Parsed event payload dicts. + + Raises: + AgentRuntimeError: If the runtime returns a non-2xx status. + ValueError: If the ARN format is invalid. + """ + if not session_id: + session_id = str(uuid.uuid4()) + + request_headers = self._build_bearer_headers( + bearer_token, + session_id, + accept="application/vnd.amazon.eventstream", + content_type="application/json", + ) + if headers: + request_headers.update(headers) + + response = self._http.request( + "POST", + self._build_http_url(runtime_arn, "commands", endpoint_name), + headers=request_headers, + body=json.dumps({"command": command, "timeout": command_timeout}).encode( + "utf-8" + ), + timeout=command_timeout + 30, + preload_content=False, + ) + try: + self._check_response(response) + + buf = EventStreamBuffer() + for chunk in response.stream(4096): + if not chunk: + continue + buf.add_data(chunk) + for event in buf: + payload = event.payload + if not payload: + continue + try: + decoded = json.loads(payload) + except json.JSONDecodeError: + continue + inner = decoded.get("chunk") if isinstance(decoded, dict) else None + yield inner if isinstance(inner, dict) else decoded + finally: + response.release_conn() + + # ------------------------------------------------------------------ # + # stop_runtime_session + # ------------------------------------------------------------------ # + + def stop_runtime_session( + self, + runtime_arn: str, + bearer_token: str, + *, + session_id: str, + endpoint_name: Optional[str] = None, + client_token: Optional[str] = None, + timeout: int = DEFAULT_HTTP_TIMEOUT, + ) -> Dict[str, Any]: + """Terminate a runtime session via HTTP. + + Args: + runtime_arn: Full runtime ARN. + bearer_token: Bearer token for authentication. + session_id: The session id to stop. + endpoint_name: Endpoint qualifier. Defaults to ``"DEFAULT"``. + client_token: Idempotency token. Auto-generated as a UUID4 + when not supplied. + timeout: HTTP read timeout in seconds. + + Returns: + Parsed JSON body of the response (often an empty dict). + + Raises: + AgentRuntimeError: If the runtime returns a non-2xx status + (for example ``HTTP 404`` for an unknown session). + ValueError: If the ARN format is invalid. + """ + request_headers = { + "Authorization": f"Bearer {bearer_token}", + "Content-Type": "application/json", + "X-Amzn-Bedrock-AgentCore-Runtime-Session-Id": session_id, + } + response = self._http.request( + "POST", + self._build_http_url(runtime_arn, "stopruntimesession", endpoint_name), + headers=request_headers, + body=json.dumps({"clientToken": client_token or str(uuid.uuid4())}).encode( + "utf-8" + ), + timeout=timeout, + preload_content=True, + ) + self._check_response(response) + if not response.data: + return {} + try: + parsed = json.loads(response.data) + except json.JSONDecodeError: + return {} + return parsed if isinstance(parsed, dict) else {"response": parsed} + + # ------------------------------------------------------------------ # + # Response parsing helpers + # ------------------------------------------------------------------ # + + @staticmethod + def _check_response(response: "_UrllibResponse") -> None: + """Raise :class:`AgentRuntimeError` for non-2xx responses. + + Attempts to parse the body as JSON and surface ``error``, + ``error_type``, and ``message`` fields. Falls back to the raw body + text. + + Args: + response: The urllib3 response to inspect. + + Raises: + AgentRuntimeError: If the status code is 400 or above. + """ + if response.status < 400: + return + + body_bytes = response.read() + try: + body: Any = json.loads(body_bytes) + except (json.JSONDecodeError, ValueError): + body = body_bytes.decode("utf-8", errors="replace") + + error_type = f"HTTP {response.status}" + reason = response.reason or error_type + + if isinstance(body, dict): + raise AgentRuntimeError( + error=str(body.get("error", reason)), + error_type=str(body.get("error_type", error_type)), + message=str( + body.get("message", body_bytes.decode("utf-8", errors="replace")) + ), + ) + raise AgentRuntimeError( + error=str(body) or reason, + error_type=error_type, + ) + + @staticmethod + def _read_non_streaming(response: "_UrllibResponse") -> str: + """Read a fully-buffered response body and unwrap JSON strings. + + Args: + response: The urllib3 response, opened with + ``preload_content=False``. + + Returns: + The response body as text. If the body is a JSON-encoded + string (e.g. ``"hello"``), the unwrapped value is returned. + """ + data = response.read() + text = data.decode("utf-8", errors="replace") + try: + parsed = json.loads(text) + except (json.JSONDecodeError, ValueError): + return text + return parsed if isinstance(parsed, str) else text + + @classmethod + def _iter_sse_decoded(cls, response: "_UrllibResponse") -> Iterator[str]: + """Iterate decoded SSE payloads from a streaming response. + + Args: + response: The urllib3 response, opened with + ``preload_content=False``. + + Yields: + Decoded payload strings (one per non-empty ``data:`` line + or JSON line). + """ + for raw_line in cls._iter_lines(response): + if not raw_line: + continue + decoded = cls._decode_sse_line(raw_line.decode("utf-8", errors="replace")) + if decoded: + yield decoded + + @staticmethod + def _iter_lines(response: "_UrllibResponse") -> Iterator[bytes]: + r"""Yield lines from a streaming urllib3 response. + + Splits on ``\n`` and strips a trailing ``\r`` so it handles both + LF and CRLF line endings. Preserves empty lines (the caller + filters them). + + Args: + response: The urllib3 response, opened with + ``preload_content=False``. + + Yields: + Each line as bytes (with no trailing newline). + """ + pending = b"" + for chunk in response.stream(1024): + if not chunk: + continue + pending += chunk + while True: + idx = pending.find(b"\n") + if idx < 0: + break + line, pending = pending[:idx], pending[idx + 1 :] + if line.endswith(b"\r"): + line = line[:-1] + yield line + if pending: + if pending.endswith(b"\r"): + pending = pending[:-1] + yield pending + + @staticmethod + def _decode_sse_line(line: str) -> Optional[str]: + """Decode a single SSE or JSON-Lines line. + + Handles SSE ``data:`` lines, JSON error envelopes, JSON-encoded + strings, and plain text. Error envelopes (with an ``error`` key) + are raised as :class:`AgentRuntimeError`. + + Args: + line: A single line of the streamed response. + + Returns: + The decoded payload, or ``None`` if the line is a comment, + empty, or carries no renderable text. + + Raises: + AgentRuntimeError: If the line carries a JSON error payload. + """ + line = line.strip() + if not line or line.startswith(":"): + return None + + if line.startswith("data:"): + content = line[5:].strip() + + if content.startswith("{"): + try: + data = json.loads(content) + if isinstance(data, dict) and "error" in data: + raise AgentRuntimeError( + error=str(data["error"]), + error_type=str(data.get("error_type", "")), + message=str(data.get("message", data["error"])), + ) + except json.JSONDecodeError: + pass + + if content.startswith('"'): + try: + unwrapped = json.loads(content) + except json.JSONDecodeError: + if content.endswith('"'): + return content[1:-1] + return content + return unwrapped if isinstance(unwrapped, str) else content + return content + + try: + data = json.loads(line) + except json.JSONDecodeError: + return line + + if isinstance(data, dict): + if "error" in data: + raise AgentRuntimeError( + error=str(data["error"]), + error_type=str(data.get("error_type", "")), + message=str(data.get("message", data["error"])), + ) + if "text" in data: + value = data["text"] + return value if isinstance(value, str) else str(value) + if "content" in data: + value = data["content"] + return value if isinstance(value, str) else str(value) + if "data" in data: + return str(data["data"]) + return None diff --git a/src/bedrock_agentcore/runtime/agent_core_runtime_http_client.py b/src/bedrock_agentcore/runtime/agent_core_runtime_http_client.py deleted file mode 100644 index 45fd7d49..00000000 --- a/src/bedrock_agentcore/runtime/agent_core_runtime_http_client.py +++ /dev/null @@ -1,733 +0,0 @@ -"""HTTP client for invoking a deployed Bedrock AgentCore runtime. - -Complements :class:`AgentCoreRuntimeClient` (which only builds WebSocket URLs -and headers) by providing bearer-token HTTP invocation, SSE streaming, and the -``InvokeAgentRuntimeCommand`` API for running shell commands inside a runtime -session. - -Wire formats targeted: - -- ``POST /runtimes/{arn}/invocations`` — agent invocation. Response is either - a JSON document or Server-Sent Events (``text/event-stream``). -- ``POST /runtimes/{arn}/commands`` — ``InvokeAgentRuntimeCommand``. Response - is the AWS EventStream binary framing - (``application/vnd.amazon.eventstream``). Each event's payload is JSON - wrapped under a ``chunk`` key containing one of ``contentStart``, - ``contentDelta {stdout, stderr}``, or - ``contentStop {exitCode, status}``. -- ``POST /runtimes/{arn}/stopruntimesession`` — session termination. -""" - -from __future__ import annotations - -import asyncio -import json -import logging -import threading -import urllib.parse -import uuid -from typing import TYPE_CHECKING, Any, AsyncGenerator, Generator, Iterator, Optional - -import urllib3 -from botocore.eventstream import EventStreamBuffer - -if TYPE_CHECKING: - # urllib3 v2 returns ``BaseHTTPResponse`` from ``PoolManager.request``. - # On urllib3 v1.26 the returned object is ``HTTPResponse`` which is - # structurally compatible with the same attributes (``status``, ``headers``, - # ``read``, ``stream``, ``release_conn``), so the annotation is purely for - # type-checking under v2. - from urllib3 import BaseHTTPResponse as _UrllibResponse - -logger = logging.getLogger(__name__) - - -class AgentRuntimeError(Exception): - """Raised when an AgentCore runtime returns an error response. - - Used by :class:`AgentCoreRuntimeHttpClient` for both non-2xx HTTP - responses and in-band error events embedded in SSE streams. - - Attributes: - error: Short machine-readable error token (for example the - runtime's ``error`` field, or the HTTP status text). - error_type: Category label. For HTTP failures this is - ``"HTTP "`` (e.g. ``"HTTP 404"``). For runtime error - payloads it is whatever the server set on the ``error_type`` - field. - """ - - def __init__(self, error: str, error_type: str = "", message: str = "") -> None: - """Initialize the exception. - - Args: - error: Short error token (see :attr:`error`). - error_type: Category label (see :attr:`error_type`). Defaults - to the empty string. - message: Human-readable message used as the exception's - string representation. Defaults to ``error`` when empty. - """ - self.error = error - self.error_type = error_type - super().__init__(message or error) - - -class AgentCoreRuntimeHttpClient: - """HTTP client for invoking a deployed Bedrock AgentCore runtime. - - Use this client when you need bearer-token authentication (JWT/OAuth) - instead of IAM/SigV4. It supports blocking invocation, synchronous and - asynchronous streaming, shell command execution via - ``InvokeAgentRuntimeCommand``, and session termination. - - Each method takes the ``bearer_token`` per-call so the same client can - be reused with rotating credentials. - - Attributes: - agent_arn: Full ARN of the target agent runtime. - region: AWS region extracted from ``agent_arn``. - endpoint_name: Endpoint qualifier (defaults to ``"DEFAULT"``). - timeout: Default HTTP read timeout in seconds for non-command - methods. ``execute_command`` derives its HTTP timeout from - ``command_timeout`` instead (see that method). - content_type: Request ``Content-Type`` for :meth:`invoke` and - :meth:`invoke_streaming`. - accept: Request ``Accept`` for :meth:`invoke` and - :meth:`invoke_streaming`. - """ - - def __init__( - self, - agent_arn: str, - endpoint_name: str = "DEFAULT", - timeout: int = 300, - content_type: str = "application/json", - accept: str = "application/json", - pool_manager: Optional[urllib3.PoolManager] = None, - ) -> None: - """Initialize the HTTP client. - - Args: - agent_arn: The ARN of the agent runtime to invoke. The AWS - region is extracted from the ARN automatically. - endpoint_name: Endpoint qualifier sent as the ``qualifier`` - query parameter. Defaults to ``"DEFAULT"``. - timeout: Default HTTP read timeout in seconds (used by - :meth:`invoke`, :meth:`invoke_streaming`, - :meth:`invoke_streaming_async`, and - :meth:`stop_runtime_session`). ``execute_command`` derives - its HTTP timeout internally from ``command_timeout``. - content_type: MIME type of the request payload for invocation - calls. - accept: Desired MIME type for invocation responses. - pool_manager: Optional pre-configured - :class:`urllib3.PoolManager`. Primarily useful for tests - and for callers who want to control connection pooling. - A fresh manager is created when not provided. - - Raises: - ValueError: If the ARN does not contain a parseable region - component. - """ - parts = agent_arn.split(":") - if len(parts) < 4 or not parts[3]: - raise ValueError(f"Invalid agent ARN (missing region): {agent_arn}") - - self.agent_arn = agent_arn - self.region = parts[3] - self.endpoint_name = endpoint_name - self.timeout = timeout - self.content_type = content_type - self.accept = accept - self._http: urllib3.PoolManager = pool_manager or urllib3.PoolManager() - - # ------------------------------------------------------------------ # - # URL, headers, body helpers - # ------------------------------------------------------------------ # - - def _build_url(self, path_suffix: str) -> str: - """Build a full runtime URL. - - Args: - path_suffix: Path under ``/runtimes/{arn}/``. Must not start - with ``/``. Examples: ``"invocations"``, ``"commands"``, - ``"stopruntimesession"``. - - Returns: - Absolute URL including the qualifier query string. - """ - escaped_arn = urllib.parse.quote(self.agent_arn, safe="") - base = f"https://bedrock-agentcore.{self.region}.amazonaws.com/runtimes/{escaped_arn}/{path_suffix}" - query = urllib.parse.urlencode({"qualifier": self.endpoint_name}) - return f"{base}?{query}" - - def _build_headers( - self, - bearer_token: str, - session_id: str, - accept: Optional[str] = None, - content_type: Optional[str] = None, - ) -> dict[str, str]: - """Build the base request headers. - - Args: - bearer_token: OAuth/JWT bearer token. - session_id: Runtime session id. - accept: Optional override for the ``Accept`` header. - content_type: Optional override for the ``Content-Type`` - header. - - Returns: - Header dict including ``Authorization``, ``Content-Type``, - ``Accept``, and ``X-Amzn-Bedrock-AgentCore-Runtime-Session-Id``. - """ - return { - "Authorization": f"Bearer {bearer_token}", - "Content-Type": content_type or self.content_type, - "Accept": accept or self.accept, - "X-Amzn-Bedrock-AgentCore-Runtime-Session-Id": session_id, - } - - def _serialize_body(self, body: Any) -> bytes: - """Serialize a request body to bytes. - - JSON content types serialize via :func:`json.dumps`. For other - content types, ``bytes`` pass through unchanged, ``str`` is - UTF-8-encoded, and anything else falls back to JSON serialization. - - Args: - body: The payload to send. - - Returns: - UTF-8 encoded request body. - """ - if "json" in self.content_type: - return json.dumps(body).encode("utf-8") - if isinstance(body, bytes): - return body - if isinstance(body, str): - return body.encode("utf-8") - return json.dumps(body).encode("utf-8") - - # ------------------------------------------------------------------ # - # invoke / invoke_streaming / invoke_streaming_async - # ------------------------------------------------------------------ # - - def invoke( - self, - body: Any, - session_id: str, - bearer_token: str, - headers: Optional[dict[str, str]] = None, - ) -> str: - """Invoke the agent and return the full response body as a string. - - Handles both JSON responses and Server-Sent Events transparently. - For SSE responses, the decoded event data is concatenated in the - order it arrived. - - Args: - body: The request body to send to the agent. - session_id: Session id for conversation continuity. - bearer_token: Bearer token for authentication. - headers: Optional extra headers to include. Overwrite the - defaults on key collision. - - Returns: - The complete response body as a string. For JSON responses - that happen to be a JSON-encoded string, the unwrapped string - value is returned. - - Raises: - AgentRuntimeError: If the runtime returns a non-2xx status - or an error event in the SSE stream. - """ - request_headers = self._build_headers(bearer_token, session_id) - if headers: - request_headers.update(headers) - - response = self._http.request( - "POST", - self._build_url("invocations"), - headers=request_headers, - body=self._serialize_body(body), - timeout=self.timeout, - preload_content=False, - ) - try: - self._check_response(response) - content_type = response.headers.get("content-type", "") - if "text/event-stream" not in content_type: - return self._read_non_streaming(response) - return "".join(self._iter_sse_decoded(response)) - finally: - response.release_conn() - - def invoke_streaming( - self, - body: Any, - session_id: str, - bearer_token: str, - headers: Optional[dict[str, str]] = None, - ) -> Generator[str, None, None]: - """Invoke the agent and yield SSE chunks as they arrive. - - Args: - body: The request body to send to the agent. - session_id: Session id for conversation continuity. - bearer_token: Bearer token for authentication. - headers: Optional extra headers to include. - - Yields: - Decoded payload strings from the SSE stream, one per - non-empty ``data:`` line. - - Raises: - AgentRuntimeError: If the runtime returns a non-2xx status - or an error event in the stream. - """ - request_headers = self._build_headers(bearer_token, session_id) - if headers: - request_headers.update(headers) - - response = self._http.request( - "POST", - self._build_url("invocations"), - headers=request_headers, - body=self._serialize_body(body), - timeout=self.timeout, - preload_content=False, - ) - try: - self._check_response(response) - yield from self._iter_sse_decoded(response) - finally: - response.release_conn() - - async def invoke_streaming_async( - self, - body: Any, - session_id: str, - bearer_token: str, - headers: Optional[dict[str, str]] = None, - ) -> AsyncGenerator[str, None]: - """Async generator version of :meth:`invoke_streaming`. - - The underlying HTTP call is blocking; this wrapper runs it on a - background thread and delivers chunks to the caller through an - :class:`asyncio.Queue`, so it is safe to ``async for`` over. - - Args: - body: The request body to send to the agent. - session_id: Session id for conversation continuity. - bearer_token: Bearer token for authentication. - headers: Optional extra headers to include. - - Yields: - Decoded payload strings from the SSE stream. - - Raises: - AgentRuntimeError: If the runtime returns a non-2xx status - or an error event in the stream. - """ - chunk_queue: asyncio.Queue[Any] = asyncio.Queue() - sentinel = object() - loop = asyncio.get_running_loop() - - def stream_in_thread() -> None: - try: - for decoded in self.invoke_streaming( - body=body, - session_id=session_id, - bearer_token=bearer_token, - headers=headers, - ): - loop.call_soon_threadsafe(chunk_queue.put_nowait, decoded) - loop.call_soon_threadsafe(chunk_queue.put_nowait, sentinel) - except Exception as exc: # noqa: BLE001 — propagated to caller - loop.call_soon_threadsafe(chunk_queue.put_nowait, exc) - loop.call_soon_threadsafe(chunk_queue.put_nowait, sentinel) - - thread = threading.Thread(target=stream_in_thread, daemon=True) - thread.start() - - while True: - item = await chunk_queue.get() - if item is sentinel: - break - if isinstance(item, Exception): - raise item - yield item - await asyncio.sleep(0) - - # ------------------------------------------------------------------ # - # execute_command / execute_command_streaming - # ------------------------------------------------------------------ # - - def execute_command( - self, - command: str, - session_id: str, - bearer_token: str, - command_timeout: Optional[int] = None, - headers: Optional[dict[str, str]] = None, - ) -> dict[str, Any]: - """Run a shell command inside the runtime session and collect the full result. - - Blocking. Accumulates all ``stdout`` and ``stderr`` chunks from - the EventStream and returns the final exit status. - - Args: - command: Shell command to run (1 B – 64 KB per the - ``InvokeAgentRuntimeCommand`` API). - session_id: Runtime session id to target. The filesystem - inside the container persists across calls, but a fresh - shell is spawned each time, so working directory and - environment variables do not. - bearer_token: Bearer token for authentication. - command_timeout: Server-side command wall-clock timeout in - seconds (1–3600). Defaults to :attr:`timeout`. The - HTTP read timeout is derived internally as - ``command_timeout + 30``. - headers: Optional extra headers to include. - - Returns: - Dict with keys ``"stdout"`` (str), ``"stderr"`` (str), - ``"exitCode"`` (int, ``-1`` if no ``contentStop`` was - received), and ``"status"`` (``"COMPLETED"`` or - ``"TIMED_OUT"``, or ``"UNKNOWN"`` if no ``contentStop`` was - received). - - Raises: - AgentRuntimeError: If the runtime returns a non-2xx status. - """ - stdout_parts: list[str] = [] - stderr_parts: list[str] = [] - exit_code: int = -1 - status: str = "UNKNOWN" - - for event in self.execute_command_streaming( - command=command, - session_id=session_id, - bearer_token=bearer_token, - command_timeout=command_timeout, - headers=headers, - ): - if "contentDelta" in event: - delta = event["contentDelta"] - if "stdout" in delta: - stdout_parts.append(delta["stdout"]) - if "stderr" in delta: - stderr_parts.append(delta["stderr"]) - elif "contentStop" in event: - exit_code = int(event["contentStop"].get("exitCode", -1)) - status = str(event["contentStop"].get("status", "UNKNOWN")) - - return { - "stdout": "".join(stdout_parts), - "stderr": "".join(stderr_parts), - "exitCode": exit_code, - "status": status, - } - - def execute_command_streaming( - self, - command: str, - session_id: str, - bearer_token: str, - command_timeout: Optional[int] = None, - headers: Optional[dict[str, str]] = None, - ) -> Generator[dict[str, Any], None, None]: - """Stream AWS EventStream events from ``InvokeAgentRuntimeCommand``. - - Yields the decoded event payloads (the value inside the - server's ``"chunk"`` envelope). Each yielded dict has exactly one - of the keys ``"contentStart"``, ``"contentDelta"``, or - ``"contentStop"``. - - Args: - command: Shell command to run. - session_id: Runtime session id. - bearer_token: Bearer token for authentication. - command_timeout: Server-side wall-clock timeout in seconds - (1–3600). Defaults to :attr:`timeout`. - headers: Optional extra headers to include. - - Yields: - Parsed event payload dicts. - - Raises: - AgentRuntimeError: If the runtime returns a non-2xx status. - """ - effective_timeout = ( - command_timeout if command_timeout is not None else self.timeout - ) - - request_headers = self._build_headers( - bearer_token, - session_id, - accept="application/vnd.amazon.eventstream", - content_type="application/json", - ) - if headers: - request_headers.update(headers) - - response = self._http.request( - "POST", - self._build_url("commands"), - headers=request_headers, - body=json.dumps({"command": command, "timeout": effective_timeout}).encode( - "utf-8" - ), - timeout=effective_timeout + 30, - preload_content=False, - ) - try: - self._check_response(response) - - buf = EventStreamBuffer() - for chunk in response.stream(4096): - if not chunk: - continue - buf.add_data(chunk) - for event in buf: - payload = event.payload - if not payload: - continue - try: - decoded = json.loads(payload) - except json.JSONDecodeError: - continue - inner = decoded.get("chunk") if isinstance(decoded, dict) else None - yield inner if isinstance(inner, dict) else decoded - finally: - response.release_conn() - - # ------------------------------------------------------------------ # - # stop_runtime_session - # ------------------------------------------------------------------ # - - def stop_runtime_session( - self, - session_id: str, - bearer_token: str, - client_token: Optional[str] = None, - ) -> dict[str, Any]: - """Terminate a runtime session. - - Args: - session_id: The session id to stop. - bearer_token: Bearer token for authentication. - client_token: Idempotency token. Auto-generated as a UUID4 - when not supplied. - - Returns: - Parsed JSON body of the response (often an empty dict). - - Raises: - AgentRuntimeError: If the runtime returns a non-2xx status - (for example ``HTTP 404`` for an unknown session). - """ - request_headers = { - "Authorization": f"Bearer {bearer_token}", - "Content-Type": "application/json", - "X-Amzn-Bedrock-AgentCore-Runtime-Session-Id": session_id, - } - response = self._http.request( - "POST", - self._build_url("stopruntimesession"), - headers=request_headers, - body=json.dumps({"clientToken": client_token or str(uuid.uuid4())}).encode( - "utf-8" - ), - timeout=self.timeout, - preload_content=True, - ) - self._check_response(response) - if not response.data: - return {} - try: - parsed = json.loads(response.data) - except json.JSONDecodeError: - return {} - return parsed if isinstance(parsed, dict) else {"response": parsed} - - # ------------------------------------------------------------------ # - # Response parsing - # ------------------------------------------------------------------ # - - def _check_response(self, response: _UrllibResponse) -> None: - """Raise :class:`AgentRuntimeError` for non-2xx responses. - - Attempts to parse the body as JSON and surface ``error``, - ``error_type``, and ``message`` fields. Falls back to the raw - body text. - - Args: - response: The urllib3 response to inspect. - - Raises: - AgentRuntimeError: If the status code is 400 or above. - """ - if response.status < 400: - return - - body_bytes = response.read() - try: - body: Any = json.loads(body_bytes) - except (json.JSONDecodeError, ValueError): - body = body_bytes.decode("utf-8", errors="replace") - - error_type = f"HTTP {response.status}" - reason = response.reason or error_type - - if isinstance(body, dict): - raise AgentRuntimeError( - error=body.get("error", reason), - error_type=body.get("error_type", error_type), - message=body.get( - "message", body_bytes.decode("utf-8", errors="replace") - ), - ) - raise AgentRuntimeError( - error=str(body) or reason, - error_type=error_type, - ) - - def _read_non_streaming(self, response: _UrllibResponse) -> str: - """Read a fully-buffered response body and unwrap JSON strings. - - Args: - response: The urllib3 response, opened with - ``preload_content=False``. - - Returns: - The response body as text. If the body is a JSON-encoded - string (e.g. ``"hello"``), the unwrapped value is returned. - """ - data = response.read() - text = data.decode("utf-8", errors="replace") - try: - parsed = json.loads(text) - except (json.JSONDecodeError, ValueError): - return text - return parsed if isinstance(parsed, str) else text - - def _iter_sse_decoded(self, response: _UrllibResponse) -> Iterator[str]: - """Iterate decoded SSE payloads from a streaming response. - - Args: - response: The urllib3 response, opened with - ``preload_content=False``. - - Yields: - Decoded payload strings (one per non-empty ``data:`` line - or JSON line). - """ - for raw_line in self._iter_lines(response): - if not raw_line: - continue - decoded = self._decode_sse_line(raw_line.decode("utf-8", errors="replace")) - if decoded: - yield decoded - - @staticmethod - def _iter_lines(response: _UrllibResponse) -> Iterator[bytes]: - r"""Yield lines from a streaming urllib3 response. - - Splits on ``\n`` and strips a trailing ``\r`` so it handles - both LF and CRLF line endings. Preserves empty lines (the caller - filters them). - - Args: - response: The urllib3 response, opened with - ``preload_content=False``. - - Yields: - Each line as bytes (with no trailing newline). - """ - pending = b"" - for chunk in response.stream(1024): - if not chunk: - continue - pending += chunk - while True: - idx = pending.find(b"\n") - if idx < 0: - break - line, pending = pending[:idx], pending[idx + 1 :] - if line.endswith(b"\r"): - line = line[:-1] - yield line - if pending: - if pending.endswith(b"\r"): - pending = pending[:-1] - yield pending - - def _decode_sse_line(self, line: str) -> Optional[str]: - """Decode a single SSE or JSON-Lines line. - - Handles SSE ``data:`` lines, JSON error envelopes, JSON-encoded - strings, and plain text. Error envelopes (with an ``error`` key) - are raised as :class:`AgentRuntimeError`. - - Args: - line: A single line of the streamed response. - - Returns: - The decoded payload, or ``None`` if the line is a comment, - empty, or carries no renderable text. - - Raises: - AgentRuntimeError: If the line carries a JSON error payload. - """ - line = line.strip() - if not line or line.startswith(":"): - return None - - if line.startswith("data:"): - content = line[5:].strip() - - if content.startswith("{"): - try: - data = json.loads(content) - if isinstance(data, dict) and "error" in data: - raise AgentRuntimeError( - error=str(data["error"]), - error_type=str(data.get("error_type", "")), - message=str(data.get("message", data["error"])), - ) - except json.JSONDecodeError: - pass - - if content.startswith('"'): - try: - unwrapped = json.loads(content) - except json.JSONDecodeError: - if content.endswith('"'): - return content[1:-1] - return content - return unwrapped if isinstance(unwrapped, str) else content - return content - - try: - data = json.loads(line) - except json.JSONDecodeError: - return line - - if isinstance(data, dict): - if "error" in data: - raise AgentRuntimeError( - error=str(data["error"]), - error_type=str(data.get("error_type", "")), - message=str(data.get("message", data["error"])), - ) - if "text" in data: - value = data["text"] - return value if isinstance(value, str) else str(value) - if "content" in data: - value = data["content"] - return value if isinstance(value, str) else str(value) - if "data" in data: - return str(data["data"]) - return None diff --git a/tests/unit/runtime/test_agent_core_runtime_client.py b/tests/unit/runtime/test_agent_core_runtime_client.py index 4cb33085..62d86c96 100644 --- a/tests/unit/runtime/test_agent_core_runtime_client.py +++ b/tests/unit/runtime/test_agent_core_runtime_client.py @@ -1,11 +1,16 @@ """Tests for AgentCoreRuntimeClient.""" -from unittest.mock import Mock, patch +import json +from typing import Iterator, Optional +from unittest.mock import MagicMock, Mock, patch from urllib.parse import quote import pytest -from bedrock_agentcore.runtime.agent_core_runtime_client import AgentCoreRuntimeClient +from bedrock_agentcore.runtime.agent_core_runtime_client import ( + AgentCoreRuntimeClient, + AgentRuntimeError, +) class TestAgentCoreRuntimeClientInit: @@ -28,7 +33,9 @@ class TestParseRuntimeArn: def test_parse_valid_arn(self): """Test parsing a valid runtime ARN.""" client = AgentCoreRuntimeClient(region="us-west-2") - arn = "arn:aws:bedrock-agentcore:us-west-2:123456789012:runtime/my-runtime-abc123" + arn = ( + "arn:aws:bedrock-agentcore:us-west-2:123456789012:runtime/my-runtime-abc123" + ) result = client._parse_runtime_arn(arn) @@ -99,7 +106,9 @@ def test_parse_empty_runtime_id_raises_error(self): class TestBuildWebsocketUrl: """Tests for _build_websocket_url helper.""" - @patch("bedrock_agentcore.runtime.agent_core_runtime_client.get_data_plane_endpoint") + @patch( + "bedrock_agentcore.runtime.agent_core_runtime_client.get_data_plane_endpoint" + ) def test_build_basic_url(self, mock_endpoint): """Test building basic WebSocket URL without query params.""" mock_endpoint.return_value = "https://example.aws.dev" @@ -112,7 +121,9 @@ def test_build_basic_url(self, mock_endpoint): encoded_arn = quote(runtime_arn, safe="") assert result == f"wss://example.aws.dev/runtimes/{encoded_arn}/ws" - @patch("bedrock_agentcore.runtime.agent_core_runtime_client.get_data_plane_endpoint") + @patch( + "bedrock_agentcore.runtime.agent_core_runtime_client.get_data_plane_endpoint" + ) def test_build_url_with_endpoint_name(self, mock_endpoint): """Test building URL with endpoint name (qualifier param).""" mock_endpoint.return_value = "https://example.aws.dev" @@ -122,30 +133,41 @@ def test_build_url_with_endpoint_name(self, mock_endpoint): result = client._build_websocket_url(runtime_arn, endpoint_name="DEFAULT") encoded_arn = quote(runtime_arn, safe="") - assert result == f"wss://example.aws.dev/runtimes/{encoded_arn}/ws?qualifier=DEFAULT" - - @patch("bedrock_agentcore.runtime.agent_core_runtime_client.get_data_plane_endpoint") + assert ( + result + == f"wss://example.aws.dev/runtimes/{encoded_arn}/ws?qualifier=DEFAULT" + ) + + @patch( + "bedrock_agentcore.runtime.agent_core_runtime_client.get_data_plane_endpoint" + ) def test_build_url_with_custom_headers(self, mock_endpoint): """Test building URL with custom headers as query params.""" mock_endpoint.return_value = "https://example.aws.dev" client = AgentCoreRuntimeClient(region="us-west-2") runtime_arn = "arn:aws:bedrock-agentcore:us-west-2:123:runtime/my-runtime" - result = client._build_websocket_url(runtime_arn, custom_headers={"abc": "pqr", "foo": "bar"}) + result = client._build_websocket_url( + runtime_arn, custom_headers={"abc": "pqr", "foo": "bar"} + ) encoded_arn = quote(runtime_arn, safe="") assert f"wss://example.aws.dev/runtimes/{encoded_arn}/ws?" in result assert "abc=pqr" in result assert "foo=bar" in result - @patch("bedrock_agentcore.runtime.agent_core_runtime_client.get_data_plane_endpoint") + @patch( + "bedrock_agentcore.runtime.agent_core_runtime_client.get_data_plane_endpoint" + ) def test_build_url_with_all_params(self, mock_endpoint): """Test building URL with endpoint name and custom headers.""" mock_endpoint.return_value = "https://example.aws.dev" client = AgentCoreRuntimeClient(region="us-west-2") runtime_arn = "arn:aws:bedrock-agentcore:us-west-2:123:runtime/my-runtime" - result = client._build_websocket_url(runtime_arn, endpoint_name="DEFAULT", custom_headers={"abc": "pqr"}) + result = client._build_websocket_url( + runtime_arn, endpoint_name="DEFAULT", custom_headers={"abc": "pqr"} + ) encoded_arn = quote(runtime_arn, safe="") assert f"wss://example.aws.dev/runtimes/{encoded_arn}/ws?" in result @@ -157,13 +179,17 @@ class TestGenerateWsConnection: """Tests for generate_ws_connection method.""" @patch("bedrock_agentcore.runtime.agent_core_runtime_client.boto3.Session") - @patch("bedrock_agentcore.runtime.agent_core_runtime_client.get_data_plane_endpoint") + @patch( + "bedrock_agentcore.runtime.agent_core_runtime_client.get_data_plane_endpoint" + ) def test_generate_basic_connection(self, mock_endpoint, mock_session): """Test generating basic WebSocket connection.""" # Setup mocks mock_endpoint.return_value = "https://example.aws.dev" mock_creds = Mock() - mock_creds.get_frozen_credentials.return_value = Mock(access_key="AKIATEST", secret_key="secret", token=None) + mock_creds.get_frozen_credentials.return_value = Mock( + access_key="AKIATEST", secret_key="secret", token=None + ) mock_session.return_value.get_credentials.return_value = mock_creds client = AgentCoreRuntimeClient(region="us-west-2") @@ -185,32 +211,44 @@ def test_generate_basic_connection(self, mock_endpoint, mock_session): assert "Sec-WebSocket-Key" in headers @patch("bedrock_agentcore.runtime.agent_core_runtime_client.boto3.Session") - @patch("bedrock_agentcore.runtime.agent_core_runtime_client.get_data_plane_endpoint") + @patch( + "bedrock_agentcore.runtime.agent_core_runtime_client.get_data_plane_endpoint" + ) def test_generate_connection_with_session_id(self, mock_endpoint, mock_session): """Test generating connection with explicit session ID.""" mock_endpoint.return_value = "https://example.aws.dev" mock_creds = Mock() - mock_creds.get_frozen_credentials.return_value = Mock(access_key="AKIATEST", secret_key="secret", token=None) + mock_creds.get_frozen_credentials.return_value = Mock( + access_key="AKIATEST", secret_key="secret", token=None + ) mock_session.return_value.get_credentials.return_value = mock_creds client = AgentCoreRuntimeClient(region="us-west-2") runtime_arn = "arn:aws:bedrock-agentcore:us-west-2:123:runtime/my-runtime" - ws_url, headers = client.generate_ws_connection(runtime_arn, session_id="test-session-123") + ws_url, headers = client.generate_ws_connection( + runtime_arn, session_id="test-session-123" + ) assert ws_url is not None assert headers is not None # Verify session ID is in headers assert "X-Amzn-Bedrock-AgentCore-Runtime-Session-Id" in headers - assert headers["X-Amzn-Bedrock-AgentCore-Runtime-Session-Id"] == "test-session-123" + assert ( + headers["X-Amzn-Bedrock-AgentCore-Runtime-Session-Id"] == "test-session-123" + ) @patch("bedrock_agentcore.runtime.agent_core_runtime_client.boto3.Session") - @patch("bedrock_agentcore.runtime.agent_core_runtime_client.get_data_plane_endpoint") + @patch( + "bedrock_agentcore.runtime.agent_core_runtime_client.get_data_plane_endpoint" + ) def test_generate_connection_user_agent(self, mock_endpoint, mock_session): """Test that User-Agent header is set correctly.""" mock_endpoint.return_value = "https://example.aws.dev" mock_creds = Mock() - mock_creds.get_frozen_credentials.return_value = Mock(access_key="AKIATEST", secret_key="secret", token=None) + mock_creds.get_frozen_credentials.return_value = Mock( + access_key="AKIATEST", secret_key="secret", token=None + ) mock_session.return_value.get_credentials.return_value = mock_creds client = AgentCoreRuntimeClient(region="us-west-2") @@ -222,18 +260,24 @@ def test_generate_connection_user_agent(self, mock_endpoint, mock_session): assert headers["User-Agent"] == "AgentCoreRuntimeClient/1.0" @patch("bedrock_agentcore.runtime.agent_core_runtime_client.boto3.Session") - @patch("bedrock_agentcore.runtime.agent_core_runtime_client.get_data_plane_endpoint") + @patch( + "bedrock_agentcore.runtime.agent_core_runtime_client.get_data_plane_endpoint" + ) def test_generate_connection_with_endpoint_name(self, mock_endpoint, mock_session): """Test generating connection with endpoint name.""" mock_endpoint.return_value = "https://example.aws.dev" mock_creds = Mock() - mock_creds.get_frozen_credentials.return_value = Mock(access_key="AKIATEST", secret_key="secret", token=None) + mock_creds.get_frozen_credentials.return_value = Mock( + access_key="AKIATEST", secret_key="secret", token=None + ) mock_session.return_value.get_credentials.return_value = mock_creds client = AgentCoreRuntimeClient(region="us-west-2") runtime_arn = "arn:aws:bedrock-agentcore:us-west-2:123:runtime/my-runtime" - ws_url, headers = client.generate_ws_connection(runtime_arn, endpoint_name="DEFAULT") + ws_url, headers = client.generate_ws_connection( + runtime_arn, endpoint_name="DEFAULT" + ) assert "qualifier=DEFAULT" in ws_url @@ -253,12 +297,16 @@ class TestGeneratePresignedUrl: """Tests for generate_presigned_url method.""" @patch("bedrock_agentcore.runtime.agent_core_runtime_client.boto3.Session") - @patch("bedrock_agentcore.runtime.agent_core_runtime_client.get_data_plane_endpoint") + @patch( + "bedrock_agentcore.runtime.agent_core_runtime_client.get_data_plane_endpoint" + ) def test_generate_basic_presigned_url(self, mock_endpoint, mock_session): """Test generating basic presigned URL.""" mock_endpoint.return_value = "https://example.aws.dev" mock_creds = Mock() - mock_creds.get_frozen_credentials.return_value = Mock(access_key="AKIATEST", secret_key="secret", token=None) + mock_creds.get_frozen_credentials.return_value = Mock( + access_key="AKIATEST", secret_key="secret", token=None + ) mock_session.return_value.get_credentials.return_value = mock_creds client = AgentCoreRuntimeClient(region="us-west-2") @@ -279,61 +327,92 @@ def test_generate_basic_presigned_url(self, mock_endpoint, mock_session): assert "X-Amz-Signature" in presigned_url @patch("bedrock_agentcore.runtime.agent_core_runtime_client.boto3.Session") - @patch("bedrock_agentcore.runtime.agent_core_runtime_client.get_data_plane_endpoint") - def test_generate_presigned_url_with_endpoint_name(self, mock_endpoint, mock_session): + @patch( + "bedrock_agentcore.runtime.agent_core_runtime_client.get_data_plane_endpoint" + ) + def test_generate_presigned_url_with_endpoint_name( + self, mock_endpoint, mock_session + ): """Test generating presigned URL with endpoint name.""" mock_endpoint.return_value = "https://example.aws.dev" mock_creds = Mock() - mock_creds.get_frozen_credentials.return_value = Mock(access_key="AKIATEST", secret_key="secret", token=None) + mock_creds.get_frozen_credentials.return_value = Mock( + access_key="AKIATEST", secret_key="secret", token=None + ) mock_session.return_value.get_credentials.return_value = mock_creds client = AgentCoreRuntimeClient(region="us-west-2") runtime_arn = "arn:aws:bedrock-agentcore:us-west-2:123:runtime/my-runtime" - presigned_url = client.generate_presigned_url(runtime_arn, endpoint_name="DEFAULT") + presigned_url = client.generate_presigned_url( + runtime_arn, endpoint_name="DEFAULT" + ) assert "qualifier=DEFAULT" in presigned_url @patch("bedrock_agentcore.runtime.agent_core_runtime_client.boto3.Session") - @patch("bedrock_agentcore.runtime.agent_core_runtime_client.get_data_plane_endpoint") - def test_generate_presigned_url_with_custom_headers(self, mock_endpoint, mock_session): + @patch( + "bedrock_agentcore.runtime.agent_core_runtime_client.get_data_plane_endpoint" + ) + def test_generate_presigned_url_with_custom_headers( + self, mock_endpoint, mock_session + ): """Test generating presigned URL with custom headers.""" mock_endpoint.return_value = "https://example.aws.dev" mock_creds = Mock() - mock_creds.get_frozen_credentials.return_value = Mock(access_key="AKIATEST", secret_key="secret", token=None) + mock_creds.get_frozen_credentials.return_value = Mock( + access_key="AKIATEST", secret_key="secret", token=None + ) mock_session.return_value.get_credentials.return_value = mock_creds client = AgentCoreRuntimeClient(region="us-west-2") runtime_arn = "arn:aws:bedrock-agentcore:us-west-2:123:runtime/my-runtime" - presigned_url = client.generate_presigned_url(runtime_arn, custom_headers={"abc": "pqr"}) + presigned_url = client.generate_presigned_url( + runtime_arn, custom_headers={"abc": "pqr"} + ) assert "abc=pqr" in presigned_url @patch("bedrock_agentcore.runtime.agent_core_runtime_client.boto3.Session") - @patch("bedrock_agentcore.runtime.agent_core_runtime_client.get_data_plane_endpoint") + @patch( + "bedrock_agentcore.runtime.agent_core_runtime_client.get_data_plane_endpoint" + ) def test_generate_presigned_url_with_session_id(self, mock_endpoint, mock_session): """Test generating presigned URL with explicit session ID.""" mock_endpoint.return_value = "https://example.aws.dev" mock_creds = Mock() - mock_creds.get_frozen_credentials.return_value = Mock(access_key="AKIATEST", secret_key="secret", token=None) + mock_creds.get_frozen_credentials.return_value = Mock( + access_key="AKIATEST", secret_key="secret", token=None + ) mock_session.return_value.get_credentials.return_value = mock_creds client = AgentCoreRuntimeClient(region="us-west-2") runtime_arn = "arn:aws:bedrock-agentcore:us-west-2:123:runtime/my-runtime" - presigned_url = client.generate_presigned_url(runtime_arn, session_id="test-session-456") + presigned_url = client.generate_presigned_url( + runtime_arn, session_id="test-session-456" + ) # Verify session ID is in query params - assert "X-Amzn-Bedrock-AgentCore-Runtime-Session-Id=test-session-456" in presigned_url + assert ( + "X-Amzn-Bedrock-AgentCore-Runtime-Session-Id=test-session-456" + in presigned_url + ) @patch("bedrock_agentcore.runtime.agent_core_runtime_client.boto3.Session") - @patch("bedrock_agentcore.runtime.agent_core_runtime_client.get_data_plane_endpoint") - def test_generate_presigned_url_with_custom_expires(self, mock_endpoint, mock_session): + @patch( + "bedrock_agentcore.runtime.agent_core_runtime_client.get_data_plane_endpoint" + ) + def test_generate_presigned_url_with_custom_expires( + self, mock_endpoint, mock_session + ): """Test generating presigned URL with custom expiration.""" mock_endpoint.return_value = "https://example.aws.dev" mock_creds = Mock() - mock_creds.get_frozen_credentials.return_value = Mock(access_key="AKIATEST", secret_key="secret", token=None) + mock_creds.get_frozen_credentials.return_value = Mock( + access_key="AKIATEST", secret_key="secret", token=None + ) mock_session.return_value.get_credentials.return_value = mock_creds client = AgentCoreRuntimeClient(region="us-west-2") @@ -387,7 +466,9 @@ def test_init_without_session_creates_default(self, mock_session_class): assert client.session == mock_session mock_session_class.assert_called_once() - @patch("bedrock_agentcore.runtime.agent_core_runtime_client.get_data_plane_endpoint") + @patch( + "bedrock_agentcore.runtime.agent_core_runtime_client.get_data_plane_endpoint" + ) def test_generate_ws_connection_uses_custom_session(self, mock_endpoint): """Test that generate_ws_connection uses the custom session.""" mock_endpoint.return_value = "https://example.aws.dev" @@ -395,7 +476,9 @@ def test_generate_ws_connection_uses_custom_session(self, mock_endpoint): # Create custom session with credentials custom_session = Mock() mock_creds = Mock() - mock_creds.get_frozen_credentials.return_value = Mock(access_key="AKIATEST", secret_key="secret", token=None) + mock_creds.get_frozen_credentials.return_value = Mock( + access_key="AKIATEST", secret_key="secret", token=None + ) custom_session.get_credentials.return_value = mock_creds client = AgentCoreRuntimeClient(region="us-west-2", session=custom_session) @@ -412,7 +495,9 @@ def test_generate_ws_connection_uses_custom_session(self, mock_endpoint): class TestGenerateWsConnectionOAuth: """Tests for generate_ws_connection_oauth method.""" - @patch("bedrock_agentcore.runtime.agent_core_runtime_client.get_data_plane_endpoint") + @patch( + "bedrock_agentcore.runtime.agent_core_runtime_client.get_data_plane_endpoint" + ) def test_generate_oauth_connection_basic(self, mock_endpoint): """Test generating basic OAuth WebSocket connection.""" mock_endpoint.return_value = "https://example.aws.dev" @@ -434,7 +519,9 @@ def test_generate_oauth_connection_basic(self, mock_endpoint): assert "Sec-WebSocket-Version" in headers assert headers["Sec-WebSocket-Version"] == "13" - @patch("bedrock_agentcore.runtime.agent_core_runtime_client.get_data_plane_endpoint") + @patch( + "bedrock_agentcore.runtime.agent_core_runtime_client.get_data_plane_endpoint" + ) def test_generate_oauth_connection_with_session_id(self, mock_endpoint): """Test generating OAuth connection with explicit session ID.""" mock_endpoint.return_value = "https://example.aws.dev" @@ -443,11 +530,17 @@ def test_generate_oauth_connection_with_session_id(self, mock_endpoint): bearer_token = "test-token" custom_session_id = "custom-oauth-session-123" - ws_url, headers = client.generate_ws_connection_oauth(runtime_arn, bearer_token, session_id=custom_session_id) + ws_url, headers = client.generate_ws_connection_oauth( + runtime_arn, bearer_token, session_id=custom_session_id + ) - assert headers["X-Amzn-Bedrock-AgentCore-Runtime-Session-Id"] == custom_session_id + assert ( + headers["X-Amzn-Bedrock-AgentCore-Runtime-Session-Id"] == custom_session_id + ) - @patch("bedrock_agentcore.runtime.agent_core_runtime_client.get_data_plane_endpoint") + @patch( + "bedrock_agentcore.runtime.agent_core_runtime_client.get_data_plane_endpoint" + ) def test_generate_oauth_connection_with_endpoint_name(self, mock_endpoint): """Test generating OAuth connection with endpoint name.""" mock_endpoint.return_value = "https://example.aws.dev" @@ -455,7 +548,9 @@ def test_generate_oauth_connection_with_endpoint_name(self, mock_endpoint): runtime_arn = "arn:aws:bedrock-agentcore:us-west-2:123:runtime/my-runtime" bearer_token = "test-token" - ws_url, headers = client.generate_ws_connection_oauth(runtime_arn, bearer_token, endpoint_name="DEFAULT") + ws_url, headers = client.generate_ws_connection_oauth( + runtime_arn, bearer_token, endpoint_name="DEFAULT" + ) assert "qualifier=DEFAULT" in ws_url @@ -475,3 +570,842 @@ def test_generate_oauth_connection_invalid_arn_raises_error(self): with pytest.raises(ValueError, match="Invalid runtime ARN format"): client.generate_ws_connection_oauth(invalid_arn, bearer_token) + + +# ===================================================================== # +# HTTP invocation tests (bearer-token auth) +# ===================================================================== # + +ARN = "arn:aws:bedrock-agentcore:us-west-2:123456789012:runtime/my-runtime-abc" +BEARER = "test-token" +SESSION = "a" * 36 + + +class _FakeResponse: + """Minimal stand-in for urllib3's response used by the HTTP methods. + + Exposes the exact surface the client touches: ``status``, ``reason``, + ``headers``, ``data``, ``read``, ``stream``, and ``release_conn``. + """ + + def __init__( + self, + status: int = 200, + headers: Optional[dict] = None, + body: bytes = b"", + chunks: Optional[list] = None, + reason: str = "", + ) -> None: + self.status = status + self.reason = reason + self.headers = headers or {} + self._body = body + self._chunks = chunks + self._consumed = False + self.release_calls = 0 + + @property + def data(self) -> bytes: + return self._body + + def read(self) -> bytes: + if self._consumed: + return b"" + self._consumed = True + return self._body + + def stream(self, amt: int = 1024) -> Iterator[bytes]: + if self._chunks is not None: + for chunk in self._chunks: + yield chunk + return + if self._body: + yield self._body + + def release_conn(self) -> None: + self.release_calls += 1 + + +def _make_client(response: _FakeResponse) -> "tuple[AgentCoreRuntimeClient, MagicMock]": + """Build a client whose lazy PoolManager is bypassed with a mock.""" + pool = MagicMock() + pool.request.return_value = response + client = AgentCoreRuntimeClient(region="us-west-2") + client._pool_manager = pool # bypass the lazy property + return client, pool + + +# --------------------------------------------------------------------- # +# Lazy PoolManager +# --------------------------------------------------------------------- # + + +class TestLazyPoolManager: + """Tests for the @property _http lazy PoolManager.""" + + def test_pool_not_created_at_init(self) -> None: + """Constructor does not create a PoolManager.""" + client = AgentCoreRuntimeClient(region="us-west-2") + assert client._pool_manager is None + + def test_pool_created_on_first_access(self) -> None: + """First access to _http creates the PoolManager.""" + import urllib3 + + client = AgentCoreRuntimeClient(region="us-west-2") + pool = client._http + assert isinstance(pool, urllib3.PoolManager) + assert client._pool_manager is pool + + def test_pool_reused_across_accesses(self) -> None: + """Repeated _http access returns the same instance.""" + client = AgentCoreRuntimeClient(region="us-west-2") + first = client._http + second = client._http + assert first is second + + +# --------------------------------------------------------------------- # +# _build_http_url / _build_bearer_headers / _serialize_body +# --------------------------------------------------------------------- # + + +class TestBuildHttpUrl: + """Tests for _build_http_url.""" + + def test_invocations_url(self) -> None: + """URL embeds region (via endpoint helper), URL-encoded ARN, path, and qualifier.""" + client = AgentCoreRuntimeClient(region="us-west-2") + url = client._build_http_url(ARN, "invocations") + assert url.startswith( + "https://bedrock-agentcore.us-west-2.amazonaws.com/runtimes/" + ) + assert "/invocations?qualifier=DEFAULT" in url + assert "arn%3Aaws%3Abedrock-agentcore" in url + + def test_commands_url(self) -> None: + """Different path suffix yields the commands endpoint.""" + client = AgentCoreRuntimeClient(region="us-west-2") + assert "/commands?qualifier=DEFAULT" in client._build_http_url(ARN, "commands") + + def test_non_default_qualifier(self) -> None: + """Qualifier reflects the endpoint_name argument.""" + client = AgentCoreRuntimeClient(region="us-west-2") + assert "qualifier=DEV" in client._build_http_url( + ARN, "invocations", endpoint_name="DEV" + ) + + def test_region_derived_from_arn(self) -> None: + """URL uses the ARN region, not the client's init region.""" + client = AgentCoreRuntimeClient(region="us-east-1") + arn = "arn:aws:bedrock-agentcore:eu-west-2:123456789012:runtime/other" + assert ( + "https://bedrock-agentcore.eu-west-2.amazonaws.com/" + in client._build_http_url(arn, "invocations") + ) + + def test_invalid_arn_raises(self) -> None: + """Invalid ARN propagates from _parse_runtime_arn.""" + client = AgentCoreRuntimeClient(region="us-west-2") + with pytest.raises(ValueError, match="Invalid runtime ARN format"): + client._build_http_url("not-an-arn", "invocations") + + +class TestBuildBearerHeaders: + """Tests for _build_bearer_headers.""" + + def test_all_headers_populated(self) -> None: + """All four bearer-auth headers are set.""" + client = AgentCoreRuntimeClient(region="us-west-2") + headers = client._build_bearer_headers( + BEARER, SESSION, accept="application/json", content_type="application/json" + ) + assert headers["Authorization"] == f"Bearer {BEARER}" + assert headers["Content-Type"] == "application/json" + assert headers["Accept"] == "application/json" + assert headers["X-Amzn-Bedrock-AgentCore-Runtime-Session-Id"] == SESSION + + def test_eventstream_accept(self) -> None: + """Accept value is passed through verbatim.""" + client = AgentCoreRuntimeClient(region="us-west-2") + headers = client._build_bearer_headers( + BEARER, + SESSION, + accept="application/vnd.amazon.eventstream", + content_type="application/json", + ) + assert headers["Accept"] == "application/vnd.amazon.eventstream" + + +class TestSerializeBody: + """Tests for _serialize_body.""" + + def test_json_dict(self) -> None: + """Dict body is JSON-serialized when content type is JSON.""" + assert ( + AgentCoreRuntimeClient._serialize_body({"a": 1}, "application/json") + == b'{"a": 1}' + ) + + def test_bytes_passthrough_non_json(self) -> None: + """Bytes body is sent verbatim for non-JSON content types.""" + assert ( + AgentCoreRuntimeClient._serialize_body(b"raw", "application/octet-stream") + == b"raw" + ) + + def test_str_utf8_encoded_non_json(self) -> None: + """String body is UTF-8 encoded for non-JSON content types.""" + assert AgentCoreRuntimeClient._serialize_body( + "héllo", "text/plain" + ) == "héllo".encode("utf-8") + + def test_fallback_to_json(self) -> None: + """Non-str, non-bytes body falls back to JSON for non-JSON content types.""" + assert ( + AgentCoreRuntimeClient._serialize_body({"x": 1}, "application/cbor") + == b'{"x": 1}' + ) + + +# --------------------------------------------------------------------- # +# invoke (non-streaming JSON, non-streaming plain, SSE) +# --------------------------------------------------------------------- # + + +class TestInvoke: + """Tests for invoke.""" + + def test_non_streaming_json_string(self) -> None: + """JSON string response is unwrapped.""" + resp = _FakeResponse( + status=200, + headers={"content-type": "application/json"}, + body=b'"hello"', + ) + client, _ = _make_client(resp) + assert ( + client.invoke(ARN, BEARER, {"prompt": "hi"}, session_id=SESSION) == "hello" + ) + assert resp.release_calls == 1 + + def test_non_streaming_json_object_returns_text(self) -> None: + """Non-string JSON response returns the raw body text.""" + resp = _FakeResponse( + status=200, + headers={"content-type": "application/json"}, + body=b'{"answer": 42}', + ) + client, _ = _make_client(resp) + assert ( + client.invoke(ARN, BEARER, {"prompt": "hi"}, session_id=SESSION) + == '{"answer": 42}' + ) + + def test_non_streaming_invalid_json_returns_text(self) -> None: + """If the body isn't valid JSON, the raw text is returned.""" + resp = _FakeResponse( + status=200, + headers={"content-type": "application/json"}, + body=b"plain-text-not-json", + ) + client, _ = _make_client(resp) + assert ( + client.invoke(ARN, BEARER, {"prompt": "hi"}, session_id=SESSION) + == "plain-text-not-json" + ) + + def test_sse_streaming_concatenates(self) -> None: + """Multiple SSE data lines are concatenated in order.""" + body = b"data: hello\ndata: world\n" + resp = _FakeResponse( + status=200, + headers={"content-type": "text/event-stream"}, + chunks=[body], + ) + client, _ = _make_client(resp) + assert ( + client.invoke(ARN, BEARER, {"prompt": "hi"}, session_id=SESSION) + == "helloworld" + ) + + def test_custom_headers_merged(self) -> None: + """Caller-supplied headers are merged into the request.""" + resp = _FakeResponse( + status=200, headers={"content-type": "application/json"}, body=b'"ok"' + ) + client, pool = _make_client(resp) + client.invoke(ARN, BEARER, {}, session_id=SESSION, headers={"X-Test": "1"}) + sent_headers = pool.request.call_args.kwargs["headers"] + assert sent_headers["X-Test"] == "1" + assert sent_headers["Authorization"] == f"Bearer {BEARER}" + + def test_non_ok_raises(self) -> None: + """4xx and 5xx responses become AgentRuntimeError.""" + resp = _FakeResponse( + status=500, + body=b'{"error": "oops", "message": "boom"}', + reason="Server Error", + ) + client, _ = _make_client(resp) + with pytest.raises(AgentRuntimeError) as excinfo: + client.invoke(ARN, BEARER, {}, session_id=SESSION) + assert excinfo.value.error == "oops" + assert "boom" in str(excinfo.value) + + def test_session_id_auto_generated(self) -> None: + """Missing session_id is auto-generated as a UUID.""" + resp = _FakeResponse( + status=200, headers={"content-type": "application/json"}, body=b'"ok"' + ) + client, pool = _make_client(resp) + client.invoke(ARN, BEARER, {}) + sent_headers = pool.request.call_args.kwargs["headers"] + # UUID4 string is 36 chars. + assert len(sent_headers["X-Amzn-Bedrock-AgentCore-Runtime-Session-Id"]) == 36 + + +# --------------------------------------------------------------------- # +# invoke_streaming +# --------------------------------------------------------------------- # + + +class TestInvokeStreaming: + """Tests for invoke_streaming.""" + + def test_yields_chunks(self) -> None: + """Yields each decoded SSE payload in order.""" + body = b"data: first\ndata: second\n" + resp = _FakeResponse( + status=200, headers={"content-type": "text/event-stream"}, chunks=[body] + ) + client, _ = _make_client(resp) + assert list(client.invoke_streaming(ARN, BEARER, {}, session_id=SESSION)) == [ + "first", + "second", + ] + + def test_non_ok_raises(self) -> None: + """Errors are surfaced before the generator yields.""" + resp = _FakeResponse(status=404, body=b"not found", reason="Not Found") + client, _ = _make_client(resp) + with pytest.raises(AgentRuntimeError): + list(client.invoke_streaming(ARN, BEARER, {}, session_id=SESSION)) + + def test_custom_headers_merged(self) -> None: + """Custom headers are merged into the request.""" + resp = _FakeResponse( + status=200, headers={"content-type": "text/event-stream"}, chunks=[b""] + ) + client, pool = _make_client(resp) + list( + client.invoke_streaming( + ARN, BEARER, {}, session_id=SESSION, headers={"X-Custom": "v"} + ) + ) + assert pool.request.call_args.kwargs["headers"]["X-Custom"] == "v" + + +# --------------------------------------------------------------------- # +# invoke_streaming_async +# --------------------------------------------------------------------- # + + +class TestInvokeStreamingAsync: + """Tests for invoke_streaming_async.""" + + @pytest.mark.asyncio + async def test_yields_async_chunks(self) -> None: + """Async generator yields the same chunks as the sync version.""" + body = b"data: a\ndata: b\n" + resp = _FakeResponse( + status=200, headers={"content-type": "text/event-stream"}, chunks=[body] + ) + client, _ = _make_client(resp) + out = [] + async for chunk in client.invoke_streaming_async( + ARN, BEARER, {}, session_id=SESSION + ): + out.append(chunk) + assert out == ["a", "b"] + + @pytest.mark.asyncio + async def test_propagates_exception(self) -> None: + """Exceptions from the background thread surface to the caller.""" + resp = _FakeResponse(status=500, body=b"err", reason="Server Error") + client, _ = _make_client(resp) + with pytest.raises(AgentRuntimeError): + async for _ in client.invoke_streaming_async( + ARN, BEARER, {}, session_id=SESSION + ): + pass + + +# --------------------------------------------------------------------- # +# execute_command / execute_command_streaming +# --------------------------------------------------------------------- # + + +def _encode_eventstream_frame(payload: dict) -> bytes: + """Encode a JSON payload using the EventStream binary framing. + + Builds a valid frame the client's EventStreamBuffer can parse. Format: + total_length(4) + headers_length(4) + prelude_crc(4) + headers + payload + message_crc(4). + """ + import binascii + import struct + + body = json.dumps(payload).encode("utf-8") + headers = b"" + headers_length = len(headers) + total_length = 4 + 4 + 4 + headers_length + len(body) + 4 + prelude = struct.pack(">II", total_length, headers_length) + prelude_crc = struct.pack(">I", binascii.crc32(prelude) & 0xFFFFFFFF) + message_bytes = prelude + prelude_crc + headers + body + message_crc = struct.pack(">I", binascii.crc32(message_bytes) & 0xFFFFFFFF) + return message_bytes + message_crc + + +class TestExecuteCommand: + """Tests for execute_command (blocking).""" + + def test_aggregates_output(self) -> None: + """stdout/stderr/exitCode/status are aggregated across events.""" + events = [ + _encode_eventstream_frame({"chunk": {"contentStart": {}}}), + _encode_eventstream_frame({"chunk": {"contentDelta": {"stdout": "hi "}}}), + _encode_eventstream_frame({"chunk": {"contentDelta": {"stdout": "there"}}}), + _encode_eventstream_frame({"chunk": {"contentDelta": {"stderr": "warn"}}}), + _encode_eventstream_frame( + {"chunk": {"contentStop": {"exitCode": 0, "status": "COMPLETED"}}} + ), + ] + resp = _FakeResponse( + status=200, + headers={"content-type": "application/vnd.amazon.eventstream"}, + chunks=events, + ) + client, _ = _make_client(resp) + result = client.execute_command( + ARN, BEARER, "echo hi there", session_id=SESSION + ) + assert result == { + "stdout": "hi there", + "stderr": "warn", + "exitCode": 0, + "status": "COMPLETED", + } + + def test_missing_stop_has_unknown_status(self) -> None: + """Without a contentStop event, defaults are UNKNOWN / -1.""" + events = [ + _encode_eventstream_frame({"chunk": {"contentDelta": {"stdout": "x"}}}), + ] + resp = _FakeResponse( + status=200, + headers={"content-type": "application/vnd.amazon.eventstream"}, + chunks=events, + ) + client, _ = _make_client(resp) + result = client.execute_command(ARN, BEARER, "echo x", session_id=SESSION) + assert result == { + "stdout": "x", + "stderr": "", + "exitCode": -1, + "status": "UNKNOWN", + } + + +class TestExecuteCommandStreaming: + """Tests for execute_command_streaming.""" + + def test_sends_command_body(self) -> None: + """Request body contains command and timeout.""" + resp = _FakeResponse( + status=200, + headers={"content-type": "application/vnd.amazon.eventstream"}, + chunks=[], + ) + client, pool = _make_client(resp) + list( + client.execute_command_streaming( + ARN, BEARER, "ls", session_id=SESSION, command_timeout=42 + ) + ) + sent = pool.request.call_args + assert sent.kwargs["body"] == b'{"command": "ls", "timeout": 42}' + assert sent.kwargs["timeout"] == 72 # command_timeout + 30 + + def test_defaults_to_constant(self) -> None: + """Omitting command_timeout uses DEFAULT_COMMAND_TIMEOUT.""" + from bedrock_agentcore.runtime.agent_core_runtime_client import ( + DEFAULT_COMMAND_TIMEOUT, + ) + + resp = _FakeResponse( + status=200, + headers={"content-type": "application/vnd.amazon.eventstream"}, + chunks=[], + ) + client, pool = _make_client(resp) + list(client.execute_command_streaming(ARN, BEARER, "ls", session_id=SESSION)) + sent = pool.request.call_args + body = json.loads(sent.kwargs["body"]) + assert body["timeout"] == DEFAULT_COMMAND_TIMEOUT + assert sent.kwargs["timeout"] == DEFAULT_COMMAND_TIMEOUT + 30 + + def test_yields_parsed_events(self) -> None: + """Each EventStream frame is yielded as a dict.""" + events = [ + _encode_eventstream_frame({"chunk": {"contentStart": {}}}), + _encode_eventstream_frame( + {"chunk": {"contentStop": {"exitCode": 2, "status": "COMPLETED"}}} + ), + ] + resp = _FakeResponse( + status=200, + headers={"content-type": "application/vnd.amazon.eventstream"}, + chunks=events, + ) + client, _ = _make_client(resp) + parsed = list( + client.execute_command_streaming(ARN, BEARER, "echo hi", session_id=SESSION) + ) + assert parsed[0] == {"contentStart": {}} + assert parsed[1] == {"contentStop": {"exitCode": 2, "status": "COMPLETED"}} + + def test_skips_event_without_payload(self) -> None: + """Events with empty payload are skipped silently.""" + import binascii + import struct + + headers = b"" + body = b"" + headers_length = 0 + total_length = 4 + 4 + 4 + headers_length + len(body) + 4 + prelude = struct.pack(">II", total_length, headers_length) + prelude_crc = struct.pack(">I", binascii.crc32(prelude) & 0xFFFFFFFF) + message_bytes = prelude + prelude_crc + headers + body + message_crc = struct.pack(">I", binascii.crc32(message_bytes) & 0xFFFFFFFF) + empty_frame = message_bytes + message_crc + + resp = _FakeResponse( + status=200, + headers={"content-type": "application/vnd.amazon.eventstream"}, + chunks=[empty_frame], + ) + client, _ = _make_client(resp) + assert ( + list( + client.execute_command_streaming(ARN, BEARER, "ls", session_id=SESSION) + ) + == [] + ) + + def test_skips_bad_json_payload(self) -> None: + """Events whose payload is not valid JSON are skipped.""" + import binascii + import struct + + headers = b"" + body = b"not-json" + headers_length = 0 + total_length = 4 + 4 + 4 + headers_length + len(body) + 4 + prelude = struct.pack(">II", total_length, headers_length) + prelude_crc = struct.pack(">I", binascii.crc32(prelude) & 0xFFFFFFFF) + message_bytes = prelude + prelude_crc + headers + body + message_crc = struct.pack(">I", binascii.crc32(message_bytes) & 0xFFFFFFFF) + bad_frame = message_bytes + message_crc + + resp = _FakeResponse( + status=200, + headers={"content-type": "application/vnd.amazon.eventstream"}, + chunks=[bad_frame], + ) + client, _ = _make_client(resp) + assert ( + list( + client.execute_command_streaming(ARN, BEARER, "ls", session_id=SESSION) + ) + == [] + ) + + def test_non_ok_raises(self) -> None: + """Non-2xx responses raise AgentRuntimeError.""" + resp = _FakeResponse(status=403, body=b"forbidden", reason="Forbidden") + client, _ = _make_client(resp) + with pytest.raises(AgentRuntimeError): + list( + client.execute_command_streaming(ARN, BEARER, "ls", session_id=SESSION) + ) + + +# --------------------------------------------------------------------- # +# stop_runtime_session +# --------------------------------------------------------------------- # + + +class TestStopRuntimeSession: + """Tests for stop_runtime_session.""" + + def test_sends_client_token(self) -> None: + """Request body contains the provided client token.""" + resp = _FakeResponse(status=200, body=b"{}") + client, pool = _make_client(resp) + client.stop_runtime_session(ARN, BEARER, session_id=SESSION, client_token="abc") + sent = json.loads(pool.request.call_args.kwargs["body"]) + assert sent == {"clientToken": "abc"} + + def test_autogenerates_client_token(self) -> None: + """Missing client_token is replaced with a UUID4.""" + resp = _FakeResponse(status=200, body=b"{}") + client, pool = _make_client(resp) + client.stop_runtime_session(ARN, BEARER, session_id=SESSION) + sent = json.loads(pool.request.call_args.kwargs["body"]) + assert len(sent["clientToken"]) == 36 + + def test_returns_empty_dict_on_blank_body(self) -> None: + """Empty response body yields an empty dict.""" + resp = _FakeResponse(status=200, body=b"") + client, _ = _make_client(resp) + assert client.stop_runtime_session(ARN, BEARER, session_id=SESSION) == {} + + def test_returns_parsed_dict(self) -> None: + """Dict body is returned as-is.""" + resp = _FakeResponse(status=200, body=b'{"sessionId": "x"}') + client, _ = _make_client(resp) + assert client.stop_runtime_session(ARN, BEARER, session_id=SESSION) == { + "sessionId": "x" + } + + def test_non_dict_body_wrapped(self) -> None: + """A JSON body that isn't a dict is wrapped under 'response'.""" + resp = _FakeResponse(status=200, body=b'["a", "b"]') + client, _ = _make_client(resp) + assert client.stop_runtime_session(ARN, BEARER, session_id=SESSION) == { + "response": ["a", "b"] + } + + def test_invalid_json_body_returns_empty(self) -> None: + """Invalid JSON body returns empty dict rather than raising.""" + resp = _FakeResponse(status=200, body=b"not json") + client, _ = _make_client(resp) + assert client.stop_runtime_session(ARN, BEARER, session_id=SESSION) == {} + + def test_404_raises(self) -> None: + """Unknown-session HTTP 404 becomes AgentRuntimeError.""" + resp = _FakeResponse(status=404, body=b"{}", reason="Not Found") + client, _ = _make_client(resp) + with pytest.raises(AgentRuntimeError) as excinfo: + client.stop_runtime_session(ARN, BEARER, session_id=SESSION) + assert excinfo.value.error_type == "HTTP 404" + + +# --------------------------------------------------------------------- # +# _check_response +# --------------------------------------------------------------------- # + + +class TestCheckResponse: + """Tests for _check_response.""" + + def test_noop_on_2xx(self) -> None: + """2xx responses do not raise.""" + AgentCoreRuntimeClient._check_response(_FakeResponse(status=200)) + + def test_parses_json_error_body(self) -> None: + """JSON error bodies populate error/error_type/message.""" + resp = _FakeResponse( + status=400, + body=b'{"error": "bad", "error_type": "Validation", "message": "details"}', + reason="Bad Request", + ) + with pytest.raises(AgentRuntimeError) as excinfo: + AgentCoreRuntimeClient._check_response(resp) + assert excinfo.value.error == "bad" + assert excinfo.value.error_type == "Validation" + + def test_falls_back_to_text_body(self) -> None: + """Non-JSON bodies become the error string.""" + resp = _FakeResponse(status=500, body=b"internal boom", reason="Server Error") + with pytest.raises(AgentRuntimeError) as excinfo: + AgentCoreRuntimeClient._check_response(resp) + assert "internal boom" in excinfo.value.error + assert excinfo.value.error_type == "HTTP 500" + + def test_defaults_error_type_for_dict_body(self) -> None: + """Missing error_type in JSON body defaults to 'HTTP '.""" + resp = _FakeResponse(status=502, body=b'{"message": "bad gateway"}') + with pytest.raises(AgentRuntimeError) as excinfo: + AgentCoreRuntimeClient._check_response(resp) + assert excinfo.value.error_type == "HTTP 502" + + def test_empty_body_uses_reason(self) -> None: + """Empty body falls back to the response reason string.""" + resp = _FakeResponse(status=503, body=b"", reason="Service Unavailable") + with pytest.raises(AgentRuntimeError) as excinfo: + AgentCoreRuntimeClient._check_response(resp) + assert excinfo.value.error == "Service Unavailable" + + +# --------------------------------------------------------------------- # +# _decode_sse_line — all branches +# --------------------------------------------------------------------- # + + +class TestDecodeSseLine: + """Tests for _decode_sse_line.""" + + def test_empty_line(self) -> None: + """Empty / whitespace-only lines return None.""" + assert AgentCoreRuntimeClient._decode_sse_line("") is None + assert AgentCoreRuntimeClient._decode_sse_line(" ") is None + + def test_comment_line(self) -> None: + """SSE comment lines (starting with ':') return None.""" + assert AgentCoreRuntimeClient._decode_sse_line(": keepalive") is None + + def test_data_with_json_encoded_string(self) -> None: + """data: "hello" -> hello (unwrapped).""" + assert AgentCoreRuntimeClient._decode_sse_line('data: "hello"') == "hello" + + def test_data_with_plain_text(self) -> None: + """data: plain -> plain.""" + assert AgentCoreRuntimeClient._decode_sse_line("data: plain") == "plain" + + def test_data_with_malformed_json_string(self) -> None: + """data: "bad" missing closing quote: strip what we can.""" + assert AgentCoreRuntimeClient._decode_sse_line('data: "oops') == '"oops' + + def test_data_with_malformed_quoted_string(self) -> None: + """data: \"bad\\escape\" with invalid escape yields trimmed string.""" + assert ( + AgentCoreRuntimeClient._decode_sse_line('data: "bad\\escape"') + == "bad\\escape" + ) + + def test_data_with_json_error_raises(self) -> None: + """data: {"error": ...} raises AgentRuntimeError.""" + with pytest.raises(AgentRuntimeError) as excinfo: + AgentCoreRuntimeClient._decode_sse_line( + 'data: {"error": "boom", "error_type": "T"}' + ) + assert excinfo.value.error == "boom" + + def test_data_with_json_object_no_error(self) -> None: + """data: {...} without an error key passes through as raw content.""" + assert ( + AgentCoreRuntimeClient._decode_sse_line('data: {"foo": "bar"}') + == '{"foo": "bar"}' + ) + + def test_data_with_broken_json_passes_through(self) -> None: + """data: { broken passes through as the content after 'data:'.""" + assert AgentCoreRuntimeClient._decode_sse_line("data: {not json") == "{not json" + + def test_plain_json_text_key(self) -> None: + """Plain JSON line with 'text' key returns that value.""" + assert AgentCoreRuntimeClient._decode_sse_line('{"text": "hi"}') == "hi" + + def test_plain_json_content_key(self) -> None: + """Plain JSON line with 'content' key returns that value.""" + assert AgentCoreRuntimeClient._decode_sse_line('{"content": "c"}') == "c" + + def test_plain_json_data_key(self) -> None: + """Plain JSON line with 'data' key returns str(that value).""" + assert AgentCoreRuntimeClient._decode_sse_line('{"data": 42}') == "42" + + def test_plain_json_error_raises(self) -> None: + """Plain JSON line with 'error' raises.""" + with pytest.raises(AgentRuntimeError): + AgentCoreRuntimeClient._decode_sse_line('{"error": "bad"}') + + def test_plain_json_unknown_shape_returns_none(self) -> None: + """Plain JSON dict without known keys returns None.""" + assert AgentCoreRuntimeClient._decode_sse_line('{"something": "else"}') is None + + def test_plain_text_passes_through(self) -> None: + """Non-JSON plain text is returned verbatim.""" + assert AgentCoreRuntimeClient._decode_sse_line("hello world") == "hello world" + + +# --------------------------------------------------------------------- # +# _iter_lines +# --------------------------------------------------------------------- # + + +class TestIterLines: + """Tests for _iter_lines chunk splitting.""" + + def test_splits_on_lf(self) -> None: + """Splits on newline and yields each line.""" + resp = _FakeResponse(chunks=[b"one\ntwo\nthree\n"]) + lines = list(AgentCoreRuntimeClient._iter_lines(resp)) + assert lines == [b"one", b"two", b"three"] + + def test_strips_crlf(self) -> None: + r"""Trailing \r is stripped.""" + resp = _FakeResponse(chunks=[b"a\r\nb\r\n"]) + lines = list(AgentCoreRuntimeClient._iter_lines(resp)) + assert lines == [b"a", b"b"] + + def test_trailing_without_newline(self) -> None: + """A trailing fragment without newline is still yielded.""" + resp = _FakeResponse(chunks=[b"done"]) + lines = list(AgentCoreRuntimeClient._iter_lines(resp)) + assert lines == [b"done"] + + def test_trailing_cr_stripped(self) -> None: + """A trailing CR without LF is stripped.""" + resp = _FakeResponse(chunks=[b"done\r"]) + lines = list(AgentCoreRuntimeClient._iter_lines(resp)) + assert lines == [b"done"] + + def test_split_across_chunks(self) -> None: + """Lines spanning chunk boundaries are reassembled.""" + resp = _FakeResponse(chunks=[b"hel", b"lo\nwo", b"rld\n"]) + lines = list(AgentCoreRuntimeClient._iter_lines(resp)) + assert lines == [b"hello", b"world"] + + def test_empty_chunk_skipped(self) -> None: + """Empty chunks in the stream are ignored.""" + resp = _FakeResponse(chunks=[b"a\n", b"", b"b\n"]) + lines = list(AgentCoreRuntimeClient._iter_lines(resp)) + assert lines == [b"a", b"b"] + + +# --------------------------------------------------------------------- # +# AgentRuntimeError +# --------------------------------------------------------------------- # + + +class TestAgentRuntimeError: + """Tests for AgentRuntimeError construction and rendering.""" + + def test_stores_error_and_type(self) -> None: + """error and error_type are accessible as attributes.""" + exc = AgentRuntimeError(error="boom", error_type="T") + assert exc.error == "boom" + assert exc.error_type == "T" + + def test_default_error_type_empty(self) -> None: + """error_type defaults to the empty string.""" + exc = AgentRuntimeError(error="boom") + assert exc.error_type == "" + + def test_str_uses_message(self) -> None: + """str() prefers the message when provided.""" + exc = AgentRuntimeError(error="e", message="human-readable") + assert str(exc) == "human-readable" + + def test_str_falls_back_to_error(self) -> None: + """str() falls back to the error token when no message.""" + exc = AgentRuntimeError(error="only-error") + assert str(exc) == "only-error" + + def test_is_exception(self) -> None: + """Subclasses Exception so it can be raised.""" + try: + raise AgentRuntimeError(error="x") + except Exception as exc: + assert isinstance(exc, AgentRuntimeError) diff --git a/tests/unit/runtime/test_agent_core_runtime_http_client.py b/tests/unit/runtime/test_agent_core_runtime_http_client.py deleted file mode 100644 index 766cc93d..00000000 --- a/tests/unit/runtime/test_agent_core_runtime_http_client.py +++ /dev/null @@ -1,823 +0,0 @@ -"""Tests for AgentCoreRuntimeHttpClient.""" - -from __future__ import annotations - -import json -from typing import Any, Iterator, Optional -from unittest.mock import MagicMock - -import pytest - -from bedrock_agentcore.runtime.agent_core_runtime_http_client import ( - AgentCoreRuntimeHttpClient, - AgentRuntimeError, -) - -ARN = "arn:aws:bedrock-agentcore:us-west-2:123456789012:runtime/my-runtime-abc" -BEARER = "test-token" -SESSION = "a" * 36 - - -class _FakeResponse: - """Minimal stand-in for urllib3's response used by the HTTP client. - - Exposes the exact surface the client touches: ``status``, ``reason``, - ``headers``, ``data``, ``read``, ``stream``, and ``release_conn``. - """ - - def __init__( - self, - status: int = 200, - headers: Optional[dict[str, str]] = None, - body: bytes = b"", - chunks: Optional[list[bytes]] = None, - reason: str = "", - ) -> None: - self.status = status - self.reason = reason - self.headers = headers or {} - self._body = body - self._chunks = chunks - self._consumed = False - self.release_calls = 0 - - @property - def data(self) -> bytes: - return self._body - - def read(self) -> bytes: - if self._consumed: - return b"" - self._consumed = True - return self._body - - def stream(self, amt: int = 1024) -> Iterator[bytes]: - if self._chunks is not None: - for chunk in self._chunks: - yield chunk - return - # Fall back to emitting the full body as one chunk. - if self._body: - yield self._body - - def release_conn(self) -> None: - self.release_calls += 1 - - -def _make_client( - response: _FakeResponse, -) -> tuple[AgentCoreRuntimeHttpClient, MagicMock]: - """Build a client whose PoolManager returns ``response``.""" - pool = MagicMock() - pool.request.return_value = response - client = AgentCoreRuntimeHttpClient(agent_arn=ARN, pool_manager=pool) - return client, pool - - -# --------------------------------------------------------------------- # -# Initialization -# --------------------------------------------------------------------- # - - -class TestInit: - """Tests for AgentCoreRuntimeHttpClient.__init__.""" - - def test_parses_region_from_arn(self) -> None: - """Region is extracted from the ARN.""" - client = AgentCoreRuntimeHttpClient(agent_arn=ARN) - assert client.region == "us-west-2" - - def test_defaults(self) -> None: - """Default field values.""" - client = AgentCoreRuntimeHttpClient(agent_arn=ARN) - assert client.endpoint_name == "DEFAULT" - assert client.timeout == 300 - assert client.content_type == "application/json" - assert client.accept == "application/json" - - def test_custom_fields(self) -> None: - """Non-default field values are stored.""" - client = AgentCoreRuntimeHttpClient( - agent_arn=ARN, - endpoint_name="DEV", - timeout=42, - content_type="application/cbor", - accept="text/plain", - ) - assert client.endpoint_name == "DEV" - assert client.timeout == 42 - assert client.content_type == "application/cbor" - assert client.accept == "text/plain" - - def test_invalid_arn_missing_region(self) -> None: - """An ARN with no region raises ValueError.""" - with pytest.raises(ValueError, match="Invalid agent ARN"): - AgentCoreRuntimeHttpClient(agent_arn="arn:aws:bedrock-agentcore") - - def test_invalid_arn_empty_region(self) -> None: - """An ARN with empty region raises ValueError.""" - with pytest.raises(ValueError, match="Invalid agent ARN"): - AgentCoreRuntimeHttpClient( - agent_arn="arn:aws:bedrock-agentcore::123:runtime/x" - ) - - def test_pool_manager_injection(self) -> None: - """A caller-provided PoolManager is used verbatim.""" - pool = MagicMock() - client = AgentCoreRuntimeHttpClient(agent_arn=ARN, pool_manager=pool) - assert client._http is pool - - -# --------------------------------------------------------------------- # -# _build_url / _build_headers / _serialize_body -# --------------------------------------------------------------------- # - - -class TestBuildUrl: - """Tests for _build_url.""" - - def test_invocations_url(self) -> None: - """URL embeds region, URL-encoded ARN, path, and qualifier.""" - client = AgentCoreRuntimeHttpClient(agent_arn=ARN, endpoint_name="DEFAULT") - url = client._build_url("invocations") - assert url.startswith( - "https://bedrock-agentcore.us-west-2.amazonaws.com/runtimes/" - ) - assert "/invocations?qualifier=DEFAULT" in url - assert "arn%3Aaws%3Abedrock-agentcore" in url # colons URL-encoded - - def test_commands_url(self) -> None: - """Different path suffix yields the commands endpoint.""" - client = AgentCoreRuntimeHttpClient(agent_arn=ARN) - assert "/commands?qualifier=DEFAULT" in client._build_url("commands") - - def test_non_default_qualifier(self) -> None: - """Qualifier reflects endpoint_name.""" - client = AgentCoreRuntimeHttpClient(agent_arn=ARN, endpoint_name="DEV") - assert "qualifier=DEV" in client._build_url("invocations") - - -class TestBuildHeaders: - """Tests for _build_headers.""" - - def test_default_headers(self) -> None: - """Includes auth, content-type, accept, session-id.""" - client = AgentCoreRuntimeHttpClient(agent_arn=ARN) - headers = client._build_headers(BEARER, SESSION) - assert headers["Authorization"] == f"Bearer {BEARER}" - assert headers["Content-Type"] == "application/json" - assert headers["Accept"] == "application/json" - assert headers["X-Amzn-Bedrock-AgentCore-Runtime-Session-Id"] == SESSION - - def test_override_accept_and_content_type(self) -> None: - """Per-call overrides replace the defaults.""" - client = AgentCoreRuntimeHttpClient(agent_arn=ARN) - headers = client._build_headers( - BEARER, - SESSION, - accept="application/vnd.amazon.eventstream", - content_type="application/xml", - ) - assert headers["Accept"] == "application/vnd.amazon.eventstream" - assert headers["Content-Type"] == "application/xml" - - -class TestSerializeBody: - """Tests for _serialize_body.""" - - def test_json_dict(self) -> None: - """Dict body is JSON-serialized when content type is JSON.""" - client = AgentCoreRuntimeHttpClient(agent_arn=ARN) - assert client._serialize_body({"a": 1}) == b'{"a": 1}' - - def test_bytes_passthrough_non_json(self) -> None: - """Bytes body is sent verbatim for non-JSON content types.""" - client = AgentCoreRuntimeHttpClient( - agent_arn=ARN, content_type="application/octet-stream" - ) - assert client._serialize_body(b"raw") == b"raw" - - def test_str_utf8_encoded_non_json(self) -> None: - """String body is UTF-8 encoded for non-JSON content types.""" - client = AgentCoreRuntimeHttpClient(agent_arn=ARN, content_type="text/plain") - assert client._serialize_body("héllo") == "héllo".encode("utf-8") - - def test_fallback_to_json(self) -> None: - """Non-str, non-bytes body falls back to JSON even with non-JSON content type.""" - client = AgentCoreRuntimeHttpClient( - agent_arn=ARN, content_type="application/cbor" - ) - assert client._serialize_body({"x": 1}) == b'{"x": 1}' - - -# --------------------------------------------------------------------- # -# invoke (non-streaming JSON, non-streaming plain, SSE) -# --------------------------------------------------------------------- # - - -class TestInvoke: - """Tests for invoke.""" - - def test_non_streaming_json_string(self) -> None: - """JSON string response is unwrapped.""" - resp = _FakeResponse( - status=200, - headers={"content-type": "application/json"}, - body=b'"hello"', - ) - client, _ = _make_client(resp) - assert client.invoke({"prompt": "hi"}, SESSION, BEARER) == "hello" - assert resp.release_calls == 1 - - def test_non_streaming_json_object_returns_text(self) -> None: - """Non-string JSON response returns the raw body text.""" - resp = _FakeResponse( - status=200, - headers={"content-type": "application/json"}, - body=b'{"answer": 42}', - ) - client, _ = _make_client(resp) - assert client.invoke({"prompt": "hi"}, SESSION, BEARER) == '{"answer": 42}' - - def test_non_streaming_invalid_json_returns_text(self) -> None: - """If the body isn't valid JSON, the raw text is returned.""" - resp = _FakeResponse( - status=200, - headers={"content-type": "application/json"}, - body=b"plain-text-not-json", - ) - client, _ = _make_client(resp) - assert client.invoke({"prompt": "hi"}, SESSION, BEARER) == "plain-text-not-json" - - def test_sse_streaming_concatenates(self) -> None: - """Multiple SSE data lines are concatenated in order.""" - body = b"data: hello\ndata: world\n" - resp = _FakeResponse( - status=200, - headers={"content-type": "text/event-stream"}, - chunks=[body], - ) - client, _ = _make_client(resp) - assert client.invoke({"prompt": "hi"}, SESSION, BEARER) == "helloworld" - - def test_custom_headers_merged(self) -> None: - """Caller-supplied headers are merged into the request.""" - resp = _FakeResponse( - status=200, headers={"content-type": "application/json"}, body=b'"ok"' - ) - client, pool = _make_client(resp) - client.invoke({}, SESSION, BEARER, headers={"X-Test": "1"}) - sent_headers = pool.request.call_args.kwargs["headers"] - assert sent_headers["X-Test"] == "1" - assert sent_headers["Authorization"] == f"Bearer {BEARER}" - - def test_non_ok_raises(self) -> None: - """4xx and 5xx responses become AgentRuntimeError.""" - resp = _FakeResponse( - status=500, - body=b'{"error": "oops", "message": "boom"}', - reason="Server Error", - ) - client, _ = _make_client(resp) - with pytest.raises(AgentRuntimeError) as excinfo: - client.invoke({}, SESSION, BEARER) - assert excinfo.value.error == "oops" - assert "boom" in str(excinfo.value) - - -# --------------------------------------------------------------------- # -# invoke_streaming -# --------------------------------------------------------------------- # - - -class TestInvokeStreaming: - """Tests for invoke_streaming.""" - - def test_yields_chunks(self) -> None: - """Yields each decoded SSE payload in order.""" - body = b"data: first\ndata: second\n" - resp = _FakeResponse( - status=200, headers={"content-type": "text/event-stream"}, chunks=[body] - ) - client, _ = _make_client(resp) - assert list(client.invoke_streaming({}, SESSION, BEARER)) == ["first", "second"] - - def test_non_ok_raises(self) -> None: - """Errors are surfaced before the generator yields.""" - resp = _FakeResponse(status=404, body=b"not found", reason="Not Found") - client, _ = _make_client(resp) - with pytest.raises(AgentRuntimeError): - list(client.invoke_streaming({}, SESSION, BEARER)) - - def test_custom_headers_merged(self) -> None: - """Custom headers are merged into the request.""" - resp = _FakeResponse( - status=200, headers={"content-type": "text/event-stream"}, chunks=[b""] - ) - client, pool = _make_client(resp) - list(client.invoke_streaming({}, SESSION, BEARER, headers={"X-Custom": "v"})) - assert pool.request.call_args.kwargs["headers"]["X-Custom"] == "v" - - -# --------------------------------------------------------------------- # -# invoke_streaming_async -# --------------------------------------------------------------------- # - - -class TestInvokeStreamingAsync: - """Tests for invoke_streaming_async.""" - - @pytest.mark.asyncio - async def test_yields_async_chunks(self) -> None: - """Async generator yields the same chunks as the sync version.""" - body = b"data: a\ndata: b\n" - resp = _FakeResponse( - status=200, headers={"content-type": "text/event-stream"}, chunks=[body] - ) - client, _ = _make_client(resp) - out: list[str] = [] - async for chunk in client.invoke_streaming_async({}, SESSION, BEARER): - out.append(chunk) - assert out == ["a", "b"] - - @pytest.mark.asyncio - async def test_propagates_exception(self) -> None: - """Exceptions from the background thread surface to the caller.""" - resp = _FakeResponse(status=500, body=b"err", reason="Server Error") - client, _ = _make_client(resp) - with pytest.raises(AgentRuntimeError): - async for _ in client.invoke_streaming_async({}, SESSION, BEARER): - pass - - -# --------------------------------------------------------------------- # -# execute_command / execute_command_streaming -# --------------------------------------------------------------------- # - - -def _encode_eventstream_frame(payload: dict[str, Any]) -> bytes: - """Encode a JSON payload using botocore's EventStream framing. - - Uses the same encoder that the server would use to build a valid frame - the client can parse. - """ - - body = json.dumps(payload).encode("utf-8") - # Build prelude manually. Easier: construct bytes that match the wire format. - # Format: total_length (4), headers_length (4), prelude_crc (4), headers, payload, message_crc (4). - # With no headers: headers_length = 0. - import binascii - import struct - - headers = b"" - headers_length = len(headers) - total_length = 4 + 4 + 4 + headers_length + len(body) + 4 # + message CRC - prelude = struct.pack(">II", total_length, headers_length) - prelude_crc = struct.pack(">I", binascii.crc32(prelude) & 0xFFFFFFFF) - message_bytes = prelude + prelude_crc + headers + body - message_crc = struct.pack(">I", binascii.crc32(message_bytes) & 0xFFFFFFFF) - return message_bytes + message_crc - - -class TestExecuteCommand: - """Tests for execute_command (blocking).""" - - def test_aggregates_output(self) -> None: - """stdout/stderr/exitCode/status are aggregated across events.""" - events = [ - _encode_eventstream_frame({"chunk": {"contentStart": {}}}), - _encode_eventstream_frame({"chunk": {"contentDelta": {"stdout": "hi "}}}), - _encode_eventstream_frame({"chunk": {"contentDelta": {"stdout": "there"}}}), - _encode_eventstream_frame({"chunk": {"contentDelta": {"stderr": "warn"}}}), - _encode_eventstream_frame( - {"chunk": {"contentStop": {"exitCode": 0, "status": "COMPLETED"}}} - ), - ] - resp = _FakeResponse( - status=200, - headers={"content-type": "application/vnd.amazon.eventstream"}, - chunks=events, - ) - client, _ = _make_client(resp) - result = client.execute_command("echo hi there", SESSION, BEARER) - assert result == { - "stdout": "hi there", - "stderr": "warn", - "exitCode": 0, - "status": "COMPLETED", - } - - def test_missing_stop_has_unknown_status(self) -> None: - """Without a contentStop event, defaults are UNKNOWN / -1.""" - events = [ - _encode_eventstream_frame({"chunk": {"contentDelta": {"stdout": "x"}}}), - ] - resp = _FakeResponse( - status=200, - headers={"content-type": "application/vnd.amazon.eventstream"}, - chunks=events, - ) - client, _ = _make_client(resp) - result = client.execute_command("echo x", SESSION, BEARER) - assert result == { - "stdout": "x", - "stderr": "", - "exitCode": -1, - "status": "UNKNOWN", - } - - -class TestExecuteCommandStreaming: - """Tests for execute_command_streaming.""" - - def test_sends_command_body(self) -> None: - """Request body contains command and timeout.""" - resp = _FakeResponse( - status=200, - headers={"content-type": "application/vnd.amazon.eventstream"}, - chunks=[], - ) - client, pool = _make_client(resp) - list( - client.execute_command_streaming("ls", SESSION, BEARER, command_timeout=42) - ) - sent = pool.request.call_args - assert sent.kwargs["body"] == b'{"command": "ls", "timeout": 42}' - # HTTP timeout must be command_timeout + 30. - assert sent.kwargs["timeout"] == 72 - - def test_defaults_to_self_timeout(self) -> None: - """command_timeout=None falls back to self.timeout.""" - resp = _FakeResponse( - status=200, - headers={"content-type": "application/vnd.amazon.eventstream"}, - chunks=[], - ) - client = AgentCoreRuntimeHttpClient(agent_arn=ARN, timeout=120) - client._http = MagicMock() - client._http.request.return_value = resp - list(client.execute_command_streaming("ls", SESSION, BEARER)) - sent = client._http.request.call_args - body = json.loads(sent.kwargs["body"]) - assert body["timeout"] == 120 - assert sent.kwargs["timeout"] == 150 - - def test_yields_parsed_events(self) -> None: - """Each EventStream frame is yielded as a dict.""" - events = [ - _encode_eventstream_frame({"chunk": {"contentStart": {}}}), - _encode_eventstream_frame( - {"chunk": {"contentStop": {"exitCode": 2, "status": "COMPLETED"}}} - ), - ] - resp = _FakeResponse( - status=200, - headers={"content-type": "application/vnd.amazon.eventstream"}, - chunks=events, - ) - client, _ = _make_client(resp) - parsed = list(client.execute_command_streaming("echo hi", SESSION, BEARER)) - assert parsed[0] == {"contentStart": {}} - assert parsed[1] == {"contentStop": {"exitCode": 2, "status": "COMPLETED"}} - - def test_skips_event_without_payload(self) -> None: - """Events with empty payload are skipped silently.""" - # An eventstream frame with empty payload body. - import binascii - import struct - - headers = b"" - body = b"" - headers_length = 0 - total_length = 4 + 4 + 4 + headers_length + len(body) + 4 - prelude = struct.pack(">II", total_length, headers_length) - prelude_crc = struct.pack(">I", binascii.crc32(prelude) & 0xFFFFFFFF) - message_bytes = prelude + prelude_crc + headers + body - message_crc = struct.pack(">I", binascii.crc32(message_bytes) & 0xFFFFFFFF) - empty_frame = message_bytes + message_crc - - resp = _FakeResponse( - status=200, - headers={"content-type": "application/vnd.amazon.eventstream"}, - chunks=[empty_frame], - ) - client, _ = _make_client(resp) - assert list(client.execute_command_streaming("ls", SESSION, BEARER)) == [] - - def test_skips_bad_json_payload(self) -> None: - """Events whose payload is not valid JSON are skipped.""" - # Build a frame whose payload is not valid JSON. - import binascii - import struct - - headers = b"" - body = b"not-json" - headers_length = 0 - total_length = 4 + 4 + 4 + headers_length + len(body) + 4 - prelude = struct.pack(">II", total_length, headers_length) - prelude_crc = struct.pack(">I", binascii.crc32(prelude) & 0xFFFFFFFF) - message_bytes = prelude + prelude_crc + headers + body - message_crc = struct.pack(">I", binascii.crc32(message_bytes) & 0xFFFFFFFF) - bad_frame = message_bytes + message_crc - - resp = _FakeResponse( - status=200, - headers={"content-type": "application/vnd.amazon.eventstream"}, - chunks=[bad_frame], - ) - client, _ = _make_client(resp) - assert list(client.execute_command_streaming("ls", SESSION, BEARER)) == [] - - def test_non_ok_raises(self) -> None: - """Non-2xx responses raise AgentRuntimeError.""" - resp = _FakeResponse(status=403, body=b"forbidden", reason="Forbidden") - client, _ = _make_client(resp) - with pytest.raises(AgentRuntimeError): - list(client.execute_command_streaming("ls", SESSION, BEARER)) - - -# --------------------------------------------------------------------- # -# stop_runtime_session -# --------------------------------------------------------------------- # - - -class TestStopRuntimeSession: - """Tests for stop_runtime_session.""" - - def test_sends_client_token(self) -> None: - """Request body contains the provided client token.""" - resp = _FakeResponse(status=200, body=b"{}") - client, pool = _make_client(resp) - client.stop_runtime_session(SESSION, BEARER, client_token="abc") - sent = json.loads(pool.request.call_args.kwargs["body"]) - assert sent == {"clientToken": "abc"} - - def test_autogenerates_client_token(self) -> None: - """Missing client_token is replaced with a UUID4.""" - resp = _FakeResponse(status=200, body=b"{}") - client, pool = _make_client(resp) - client.stop_runtime_session(SESSION, BEARER) - sent = json.loads(pool.request.call_args.kwargs["body"]) - # UUID4 string is 36 chars (8-4-4-4-12). - assert len(sent["clientToken"]) == 36 - - def test_returns_empty_dict_on_blank_body(self) -> None: - """Empty response body yields an empty dict.""" - resp = _FakeResponse(status=200, body=b"") - client, _ = _make_client(resp) - assert client.stop_runtime_session(SESSION, BEARER) == {} - - def test_returns_parsed_dict(self) -> None: - """Dict body is returned as-is.""" - resp = _FakeResponse(status=200, body=b'{"sessionId": "x"}') - client, _ = _make_client(resp) - assert client.stop_runtime_session(SESSION, BEARER) == {"sessionId": "x"} - - def test_non_dict_body_wrapped(self) -> None: - """A JSON body that isn't a dict is wrapped under 'response'.""" - resp = _FakeResponse(status=200, body=b'["a", "b"]') - client, _ = _make_client(resp) - assert client.stop_runtime_session(SESSION, BEARER) == {"response": ["a", "b"]} - - def test_invalid_json_body_returns_empty(self) -> None: - """Invalid JSON body returns empty dict rather than raising.""" - resp = _FakeResponse(status=200, body=b"not json") - client, _ = _make_client(resp) - assert client.stop_runtime_session(SESSION, BEARER) == {} - - def test_404_raises(self) -> None: - """Unknown-session HTTP 404 becomes AgentRuntimeError.""" - resp = _FakeResponse(status=404, body=b"{}", reason="Not Found") - client, _ = _make_client(resp) - with pytest.raises(AgentRuntimeError) as excinfo: - client.stop_runtime_session(SESSION, BEARER) - assert excinfo.value.error_type == "HTTP 404" - - -# --------------------------------------------------------------------- # -# _check_response -# --------------------------------------------------------------------- # - - -class TestCheckResponse: - """Tests for _check_response.""" - - def test_noop_on_2xx(self) -> None: - """2xx responses do not raise.""" - client = AgentCoreRuntimeHttpClient(agent_arn=ARN) - client._check_response(_FakeResponse(status=200)) - - def test_parses_json_error_body(self) -> None: - """JSON error bodies populate error/error_type/message.""" - client = AgentCoreRuntimeHttpClient(agent_arn=ARN) - resp = _FakeResponse( - status=400, - body=b'{"error": "bad", "error_type": "Validation", "message": "details"}', - reason="Bad Request", - ) - with pytest.raises(AgentRuntimeError) as excinfo: - client._check_response(resp) - assert excinfo.value.error == "bad" - assert excinfo.value.error_type == "Validation" - - def test_falls_back_to_text_body(self) -> None: - """Non-JSON bodies become the error string.""" - client = AgentCoreRuntimeHttpClient(agent_arn=ARN) - resp = _FakeResponse(status=500, body=b"internal boom", reason="Server Error") - with pytest.raises(AgentRuntimeError) as excinfo: - client._check_response(resp) - assert "internal boom" in excinfo.value.error - assert excinfo.value.error_type == "HTTP 500" - - def test_defaults_error_type_for_dict_body(self) -> None: - """Missing error_type in JSON body defaults to 'HTTP '.""" - client = AgentCoreRuntimeHttpClient(agent_arn=ARN) - resp = _FakeResponse(status=502, body=b'{"message": "bad gateway"}') - with pytest.raises(AgentRuntimeError) as excinfo: - client._check_response(resp) - assert excinfo.value.error_type == "HTTP 502" - - def test_empty_body_uses_reason(self) -> None: - """Empty body falls back to the response reason string.""" - client = AgentCoreRuntimeHttpClient(agent_arn=ARN) - resp = _FakeResponse(status=503, body=b"", reason="Service Unavailable") - with pytest.raises(AgentRuntimeError) as excinfo: - client._check_response(resp) - assert excinfo.value.error == "Service Unavailable" - - -# --------------------------------------------------------------------- # -# _decode_sse_line — all branches -# --------------------------------------------------------------------- # - - -class TestDecodeSseLine: - """Tests for _decode_sse_line.""" - - @pytest.fixture - def client(self) -> AgentCoreRuntimeHttpClient: - return AgentCoreRuntimeHttpClient(agent_arn=ARN) - - def test_empty_line(self, client: AgentCoreRuntimeHttpClient) -> None: - """Empty / whitespace-only lines return None.""" - assert client._decode_sse_line("") is None - assert client._decode_sse_line(" ") is None - - def test_comment_line(self, client: AgentCoreRuntimeHttpClient) -> None: - """SSE comment lines (starting with ':') return None.""" - assert client._decode_sse_line(": keepalive") is None - - def test_data_with_json_encoded_string( - self, client: AgentCoreRuntimeHttpClient - ) -> None: - """data: "hello" → hello (unwrapped).""" - assert client._decode_sse_line('data: "hello"') == "hello" - - def test_data_with_plain_text(self, client: AgentCoreRuntimeHttpClient) -> None: - """data: plain → plain.""" - assert client._decode_sse_line("data: plain") == "plain" - - def test_data_with_malformed_json_string( - self, client: AgentCoreRuntimeHttpClient - ) -> None: - """data: "bad" but missing closing quote: strip what we can.""" - assert client._decode_sse_line('data: "oops') == '"oops' - - def test_data_with_malformed_quoted_string( - self, client: AgentCoreRuntimeHttpClient - ) -> None: - """data: \"truncated\\\\\" with trailing quote but bad escape yields trimmed string.""" - # A string that starts and ends with quotes but contains an invalid escape mid-string. - # json.loads fails; the fallback strips the outer quotes. - assert client._decode_sse_line('data: "bad\\escape"') == "bad\\escape" - - def test_data_with_json_error_raises( - self, client: AgentCoreRuntimeHttpClient - ) -> None: - """data: { "error": ... } raises AgentRuntimeError.""" - with pytest.raises(AgentRuntimeError) as excinfo: - client._decode_sse_line('data: {"error": "boom", "error_type": "T"}') - assert excinfo.value.error == "boom" - - def test_data_with_json_object_no_error( - self, client: AgentCoreRuntimeHttpClient - ) -> None: - """data: {...} without an error key passes through as the raw JSON content.""" - assert client._decode_sse_line('data: {"foo": "bar"}') == '{"foo": "bar"}' - - def test_data_with_broken_json_passes_through( - self, client: AgentCoreRuntimeHttpClient - ) -> None: - """data: { broken passes through as the content after 'data:'.""" - assert client._decode_sse_line("data: {not json") == "{not json" - - def test_plain_json_text_key(self, client: AgentCoreRuntimeHttpClient) -> None: - """Plain JSON line with 'text' key returns that value.""" - assert client._decode_sse_line('{"text": "hi"}') == "hi" - - def test_plain_json_content_key(self, client: AgentCoreRuntimeHttpClient) -> None: - """Plain JSON line with 'content' key returns that value.""" - assert client._decode_sse_line('{"content": "c"}') == "c" - - def test_plain_json_data_key(self, client: AgentCoreRuntimeHttpClient) -> None: - """Plain JSON line with 'data' key returns str(that value).""" - assert client._decode_sse_line('{"data": 42}') == "42" - - def test_plain_json_error_raises(self, client: AgentCoreRuntimeHttpClient) -> None: - """Plain JSON line with 'error' raises.""" - with pytest.raises(AgentRuntimeError): - client._decode_sse_line('{"error": "bad"}') - - def test_plain_json_unknown_shape_returns_none( - self, client: AgentCoreRuntimeHttpClient - ) -> None: - """Plain JSON dict without known keys returns None.""" - assert client._decode_sse_line('{"something": "else"}') is None - - def test_plain_text_passes_through( - self, client: AgentCoreRuntimeHttpClient - ) -> None: - """Non-JSON plain text is returned verbatim.""" - assert client._decode_sse_line("hello world") == "hello world" - - -# --------------------------------------------------------------------- # -# _iter_lines -# --------------------------------------------------------------------- # - - -class TestIterLines: - """Tests for _iter_lines chunk splitting.""" - - def test_splits_on_lf(self) -> None: - """Splits on newline and yields each line.""" - resp = _FakeResponse(chunks=[b"one\ntwo\nthree\n"]) - lines = list(AgentCoreRuntimeHttpClient._iter_lines(resp)) - assert lines == [b"one", b"two", b"three"] - - def test_strips_crlf(self) -> None: - """Trailing \\r is stripped.""" - resp = _FakeResponse(chunks=[b"a\r\nb\r\n"]) - lines = list(AgentCoreRuntimeHttpClient._iter_lines(resp)) - assert lines == [b"a", b"b"] - - def test_trailing_without_newline(self) -> None: - """A trailing fragment without newline is still yielded.""" - resp = _FakeResponse(chunks=[b"done"]) - lines = list(AgentCoreRuntimeHttpClient._iter_lines(resp)) - assert lines == [b"done"] - - def test_trailing_cr_stripped(self) -> None: - """A trailing CR without LF is stripped.""" - resp = _FakeResponse(chunks=[b"done\r"]) - lines = list(AgentCoreRuntimeHttpClient._iter_lines(resp)) - assert lines == [b"done"] - - def test_split_across_chunks(self) -> None: - """Lines spanning chunk boundaries are reassembled.""" - resp = _FakeResponse(chunks=[b"hel", b"lo\nwo", b"rld\n"]) - lines = list(AgentCoreRuntimeHttpClient._iter_lines(resp)) - assert lines == [b"hello", b"world"] - - def test_empty_chunk_skipped(self) -> None: - """Empty chunks in the stream are ignored.""" - resp = _FakeResponse(chunks=[b"a\n", b"", b"b\n"]) - lines = list(AgentCoreRuntimeHttpClient._iter_lines(resp)) - assert lines == [b"a", b"b"] - - -# --------------------------------------------------------------------- # -# AgentRuntimeError -# --------------------------------------------------------------------- # - - -class TestAgentRuntimeError: - """Tests for AgentRuntimeError construction and rendering.""" - - def test_stores_error_and_type(self) -> None: - """error and error_type are accessible as attributes.""" - exc = AgentRuntimeError(error="boom", error_type="T") - assert exc.error == "boom" - assert exc.error_type == "T" - - def test_default_error_type_empty(self) -> None: - """error_type defaults to the empty string.""" - exc = AgentRuntimeError(error="boom") - assert exc.error_type == "" - - def test_str_uses_message(self) -> None: - """str() prefers the message when provided.""" - exc = AgentRuntimeError(error="e", message="human-readable") - assert str(exc) == "human-readable" - - def test_str_falls_back_to_error(self) -> None: - """str() falls back to the error token when no message.""" - exc = AgentRuntimeError(error="only-error") - assert str(exc) == "only-error" - - def test_is_exception(self) -> None: - """Subclasses Exception so it can be raised.""" - try: - raise AgentRuntimeError(error="x") - except Exception as exc: - assert isinstance(exc, AgentRuntimeError)