Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
335 changes: 321 additions & 14 deletions src/bedrock_agentcore/runtime/agent_core_runtime_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,28 @@
import logging
import secrets
import uuid
from typing import Dict, Optional, Tuple
from typing import Any, Dict, Optional, Tuple
from urllib.parse import quote, urlencode, urlparse

import boto3
from botocore.auth import SigV4Auth, SigV4QueryAuth
from botocore.awsrequest import AWSRequest

from .._utils.endpoints import get_data_plane_endpoint
from botocore.config import Config
from botocore.exceptions import ClientError

from .._utils.config import WaitConfig
from .._utils.endpoints import get_data_plane_endpoint, validate_region
from .._utils.polling import wait_until, wait_until_deleted
from .._utils.snake_case import accept_snake_case_kwargs, convert_kwargs
from .._utils.user_agent import build_user_agent_suffix
from .utils import is_valid_partition

DEFAULT_PRESIGNED_URL_TIMEOUT = 300
MAX_PRESIGNED_URL_TIMEOUT = 300

_RUNTIME_FAILED_STATUSES = {"CREATE_FAILED", "UPDATE_FAILED"}
_ENDPOINT_FAILED_STATUSES = {"CREATE_FAILED", "UPDATE_FAILED", "DELETE_FAILED"}


class AgentCoreRuntimeClient:
"""Client for generating WebSocket authentication for AgentCore Runtime.
Expand All @@ -35,24 +44,86 @@ class AgentCoreRuntimeClient:
session (boto3.Session): The boto3 session for AWS credentials.
"""

def __init__(self, region: str, session: Optional[boto3.Session] = None) -> None:
_ALLOWED_DP_METHODS = {
"invoke_agent_runtime",
"stop_runtime_session",
}

_ALLOWED_CP_METHODS = {
"create_agent_runtime",
"update_agent_runtime",
"get_agent_runtime",
"delete_agent_runtime",
"list_agent_runtimes",
"create_agent_runtime_endpoint",
"get_agent_runtime_endpoint",
"update_agent_runtime_endpoint",
"delete_agent_runtime_endpoint",
"list_agent_runtime_endpoints",
"list_agent_runtime_versions",
"delete_agent_runtime_version",
}

def __init__(
self,
region: Optional[str] = None,
session: Optional[boto3.Session] = None,
integration_source: Optional[str] = None,
) -> None:
"""Initialize an AgentCoreRuntime client for the specified AWS region.

Args:
region (str): The AWS region to use for the AgentCore Runtime service.
session (Optional[boto3.Session]): Optional boto3 session. If not provided,
a new session will be created using default credentials.
region: AWS region name. If not provided, uses the session's
region or "us-west-2".
session: Optional boto3 Session to use. If not provided, a
default session is created.
integration_source: Optional integration source for user-agent
telemetry.
"""
from .._utils.endpoints import validate_region

validate_region(region)
self.region = region
session = session if session else boto3.Session()
self.region = validate_region(region or session.region_name or "us-west-2")
self.session = session
self.integration_source = integration_source
self.logger = logging.getLogger(__name__)

if session is None:
session = boto3.Session()
user_agent_extra = build_user_agent_suffix(integration_source)
client_config = Config(user_agent_extra=user_agent_extra)

self.session = session
self.cp_client = session.client(
"bedrock-agentcore-control",
region_name=self.region,
config=client_config,
)
self.dp_client = session.client(
"bedrock-agentcore",
region_name=self.region,
config=client_config,
)
self.logger.info(
"Initialized AgentCoreRuntimeClient for region: %s",
self.region,
)

# Pass-through
# -------------------------------------------------------------------------
def __getattr__(self, name: str):
"""Dynamically forward allowlisted method calls to the appropriate boto3 client."""
if name in self._ALLOWED_DP_METHODS and hasattr(self.dp_client, name):
method = getattr(self.dp_client, name)
self.logger.debug("Forwarding method '%s' to dp_client", name)
return accept_snake_case_kwargs(method)

if name in self._ALLOWED_CP_METHODS and hasattr(self.cp_client, name):
method = getattr(self.cp_client, name)
self.logger.debug("Forwarding method '%s' to cp_client", name)
return accept_snake_case_kwargs(method)

raise AttributeError(
f"'{self.__class__.__name__}' object has no attribute '{name}'. "
f"Method not found on dp_client or cp_client. "
f"Available methods can be found in the boto3 documentation for "
f"'bedrock-agentcore' and 'bedrock-agentcore-control' services."
)

def _parse_runtime_arn(self, runtime_arn: str) -> Dict[str, str]:
"""Parse runtime ARN and extract components.
Expand Down Expand Up @@ -401,3 +472,239 @@ def generate_ws_connection_oauth(
self.logger.debug("Bearer token length: %d characters", len(bearer_token))

return ws_url, headers

# *_and_wait methods
# -------------------------------------------------------------------------
def create_agent_runtime_and_wait(
self,
wait_config: Optional[WaitConfig] = None,
**kwargs,
) -> Dict[str, Any]:
"""Create an agent runtime and wait for it to reach READY status.

Args:
wait_config: Optional WaitConfig for polling behavior.
**kwargs: Arguments forwarded to the create_agent_runtime API.

Returns:
Runtime details when READY.

Raises:
RuntimeError: If the runtime reaches a failed state.
TimeoutError: If the runtime doesn't become READY within max_wait.
"""
response = self.cp_client.create_agent_runtime(**convert_kwargs(kwargs))
rid = response["agentRuntimeId"]
return wait_until(
lambda: self.cp_client.get_agent_runtime(agentRuntimeId=rid),
"READY",
_RUNTIME_FAILED_STATUSES,
wait_config,
error_field="failureReason",
)

def update_agent_runtime_and_wait(
self,
wait_config: Optional[WaitConfig] = None,
**kwargs,
) -> Dict[str, Any]:
"""Update an agent runtime and wait for it to reach READY status.

Args:
wait_config: Optional WaitConfig for polling behavior.
**kwargs: Arguments forwarded to the update_agent_runtime API.

Returns:
Runtime details when READY.

Raises:
RuntimeError: If the runtime reaches a failed state.
TimeoutError: If the runtime doesn't become READY within max_wait.
"""
response = self.cp_client.update_agent_runtime(**convert_kwargs(kwargs))
rid = response["agentRuntimeId"]
return wait_until(
lambda: self.cp_client.get_agent_runtime(agentRuntimeId=rid),
"READY",
_RUNTIME_FAILED_STATUSES,
wait_config,
error_field="failureReason",
)

def delete_agent_runtime_and_wait(
self,
wait_config: Optional[WaitConfig] = None,
**kwargs,
) -> None:
"""Delete an agent runtime and wait for deletion to complete.

Args:
wait_config: Optional WaitConfig for polling behavior.
**kwargs: Arguments forwarded to the delete_agent_runtime API.

Raises:
TimeoutError: If the runtime isn't deleted within max_wait.
"""
response = self.cp_client.delete_agent_runtime(**convert_kwargs(kwargs))
rid = response["agentRuntimeId"]
wait_until_deleted(
lambda: self.cp_client.get_agent_runtime(agentRuntimeId=rid),
wait_config=wait_config,
)

def create_agent_runtime_endpoint_and_wait(
self,
wait_config: Optional[WaitConfig] = None,
**kwargs,
) -> Dict[str, Any]:
"""Create an agent runtime endpoint and wait for it to reach READY.

Args:
wait_config: Optional WaitConfig for polling behavior.
**kwargs: Arguments forwarded to the
create_agent_runtime_endpoint API.

Returns:
Endpoint details when READY.

Raises:
RuntimeError: If the endpoint reaches a failed state.
TimeoutError: If the endpoint doesn't become READY within
max_wait.
"""
converted = convert_kwargs(kwargs)
response = self.cp_client.create_agent_runtime_endpoint(
**converted,
)
rid = converted.get("agentRuntimeId")
ename = response.get("name", kwargs.get("name", "DEFAULT"))
return wait_until(
lambda: self.cp_client.get_agent_runtime_endpoint(
agentRuntimeId=rid,
endpointName=ename,
),
"READY",
_ENDPOINT_FAILED_STATUSES,
wait_config,
error_field="failureReason",
)

def update_agent_runtime_endpoint_and_wait(
self,
wait_config: Optional[WaitConfig] = None,
**kwargs,
) -> Dict[str, Any]:
"""Update an agent runtime endpoint and wait for READY status.

Args:
wait_config: Optional WaitConfig for polling behavior.
**kwargs: Arguments forwarded to the
update_agent_runtime_endpoint API.

Returns:
Endpoint details when READY.

Raises:
RuntimeError: If the endpoint reaches a failed state.
TimeoutError: If the endpoint doesn't become READY within
max_wait.
"""
converted = convert_kwargs(kwargs)
response = self.cp_client.update_agent_runtime_endpoint(
**converted,
)
rid = converted.get("agentRuntimeId")
ename = response.get("name", kwargs.get("endpointName", "DEFAULT"))
return wait_until(
lambda: self.cp_client.get_agent_runtime_endpoint(
agentRuntimeId=rid,
endpointName=ename,
),
"READY",
_ENDPOINT_FAILED_STATUSES,
wait_config,
error_field="failureReason",
)

# Higher-level orchestration methods
# -------------------------------------------------------------------------
def get_aggregated_status(
self,
agent_runtime_id: str,
endpoint_name: str = "DEFAULT",
) -> Dict[str, Any]:
"""Get aggregated status of runtime and endpoint.

Args:
agent_runtime_id: The agent runtime ID.
endpoint_name: Endpoint name (default: "DEFAULT").

Returns:
Dict with 'runtime' and 'endpoint' status details.
"""
result: Dict[str, Any] = {"runtime": None, "endpoint": None}

try:
result["runtime"] = self.cp_client.get_agent_runtime(
agentRuntimeId=agent_runtime_id,
)
except ClientError as e:
if e.response["Error"]["Code"] != "ResourceNotFoundException":
raise
result["runtime"] = {"error": str(e)}

try:
result["endpoint"] = self.cp_client.get_agent_runtime_endpoint(
agentRuntimeId=agent_runtime_id,
endpointName=endpoint_name,
)
except ClientError as e:
if e.response["Error"]["Code"] != "ResourceNotFoundException":
raise
result["endpoint"] = {"error": str(e)}

return result

def teardown_endpoint_and_runtime(
self,
agent_runtime_id: str,
endpoint_name: str = "DEFAULT",
) -> None:
"""Delete endpoint then runtime in correct order.

Silently ignores ResourceNotFoundException for either resource
(already deleted).

Args:
agent_runtime_id: The agent runtime ID.
endpoint_name: Endpoint name (default: "DEFAULT").
"""
try:
self.cp_client.delete_agent_runtime_endpoint(
agentRuntimeId=agent_runtime_id,
endpointName=endpoint_name,
)
self.logger.info(
"Deleted endpoint '%s' for runtime %s",
endpoint_name,
agent_runtime_id,
)
wait_until_deleted(
lambda: self.cp_client.get_agent_runtime_endpoint(
agentRuntimeId=agent_runtime_id,
endpointName=endpoint_name,
),
)
except ClientError as e:
if e.response["Error"]["Code"] != "ResourceNotFoundException":
raise
self.logger.info("Endpoint '%s' not found, skipping", endpoint_name)

try:
self.delete_agent_runtime_and_wait(
agentRuntimeId=agent_runtime_id,
)
except ClientError as e:
if e.response["Error"]["Code"] != "ResourceNotFoundException":
raise
self.logger.info("Runtime %s not found, skipping", agent_runtime_id)
Loading
Loading