diff --git a/src/bedrock_agentcore/_utils/config.py b/src/bedrock_agentcore/_utils/config.py new file mode 100644 index 00000000..a7a6e025 --- /dev/null +++ b/src/bedrock_agentcore/_utils/config.py @@ -0,0 +1,22 @@ +"""Shared configuration dataclasses for SDK clients.""" + +from dataclasses import dataclass + + +@dataclass +class WaitConfig: + """Configuration for *_and_wait polling methods. + + Args: + max_wait: Maximum seconds to wait. Default: 300. Must be >= 1. + poll_interval: Seconds between status checks. Default: 10. Must be >= 1. + """ + + max_wait: int = 300 + poll_interval: int = 10 + + def __post_init__(self): + if self.max_wait < 1: + raise ValueError("max_wait must be at least 1") + if self.poll_interval < 1: + raise ValueError("poll_interval must be at least 1") diff --git a/src/bedrock_agentcore/_utils/polling.py b/src/bedrock_agentcore/_utils/polling.py new file mode 100644 index 00000000..7ff10435 --- /dev/null +++ b/src/bedrock_agentcore/_utils/polling.py @@ -0,0 +1,92 @@ +"""Shared polling helpers for SDK clients.""" + +import logging +import time +from typing import Any, Callable, Dict, Optional, Set + +from .config import WaitConfig + +logger = logging.getLogger(__name__) + + +def wait_until( + poll_fn: Callable[[], Dict[str, Any]], + target: str, + failed: Set[str], + wait_config: Optional[WaitConfig] = None, + error_field: str = "statusReasons", +) -> Dict[str, Any]: + """Poll until a resource reaches the target status. + + Args: + poll_fn: Zero-arg callable that returns the resource's current state. + target: The status to wait for (e.g. "ACTIVE", "READY"). + failed: Statuses that indicate terminal failure. + wait_config: Optional WaitConfig for polling behavior. + error_field: Response field containing error details. + + Returns: + Full response when target status is reached. + + Raises: + RuntimeError: If the resource reaches a failed status. + TimeoutError: If target status is not reached within max_wait. + """ + wait = wait_config or WaitConfig() + start_time = time.time() + while True: + resp = poll_fn() + status = resp.get("status") + if status is None: + logger.warning("Response missing 'status' field: %s", resp) + if status == target: + return resp + if status in failed: + reason = resp.get(error_field, "Unknown") + raise RuntimeError("Reached %s: %s" % (status, reason)) + if time.time() - start_time >= wait.max_wait: + break + time.sleep(wait.poll_interval) + raise TimeoutError("Did not reach %s within %d seconds" % (target, wait.max_wait)) + + +def wait_until_deleted( + poll_fn: Callable[[], Dict[str, Any]], + not_found_code: str = "ResourceNotFoundException", + failed: Optional[Set[str]] = None, + wait_config: Optional[WaitConfig] = None, + error_field: str = "statusReasons", +) -> None: + """Poll until a resource is deleted (raises not-found exception). + + Args: + poll_fn: Zero-arg callable that calls the get API. + not_found_code: The error code indicating the resource is gone. + failed: Optional set of statuses that indicate deletion failed. + wait_config: Optional WaitConfig for polling behavior. + error_field: Response field containing error details. + + Raises: + RuntimeError: If the resource reaches a failed status. + TimeoutError: If the resource is not deleted within max_wait. + """ + from botocore.exceptions import ClientError + + wait = wait_config or WaitConfig() + start_time = time.time() + while True: + try: + resp = poll_fn() + if failed: + status = resp.get("status") + if status in failed: + reason = resp.get(error_field, "Unknown") + raise RuntimeError("Reached %s: %s" % (status, reason)) + except ClientError as e: + if e.response["Error"]["Code"] == not_found_code: + return + raise + if time.time() - start_time >= wait.max_wait: + break + time.sleep(wait.poll_interval) + raise TimeoutError("Resource was not deleted within %d seconds" % wait.max_wait) diff --git a/src/bedrock_agentcore/_utils/snake_case.py b/src/bedrock_agentcore/_utils/snake_case.py index d6be4b77..79598eb8 100644 --- a/src/bedrock_agentcore/_utils/snake_case.py +++ b/src/bedrock_agentcore/_utils/snake_case.py @@ -44,3 +44,8 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: return method(*args, **converted) return wrapper + + +def convert_kwargs(kwargs: Dict[str, Any]) -> Dict[str, Any]: + """Convert snake_case kwargs to camelCase for direct boto3 calls.""" + return {snake_to_camel(k): v for k, v in kwargs.items()} diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py new file mode 100644 index 00000000..2b40f8b1 --- /dev/null +++ b/tests/unit/test_utils.py @@ -0,0 +1,114 @@ +"""Tests for shared _utils: pagination and polling.""" + +from unittest.mock import Mock, patch + +import pytest +from botocore.exceptions import ClientError + +from bedrock_agentcore._utils.polling import wait_until, wait_until_deleted + + +class TestWaitUntil: + def test_immediate_success(self): + poll_fn = Mock(return_value={"status": "ACTIVE"}) + result = wait_until(poll_fn, "ACTIVE", {"FAILED"}) + assert result["status"] == "ACTIVE" + poll_fn.assert_called_once() + + @patch("bedrock_agentcore._utils.polling.time.sleep") + @patch( + "bedrock_agentcore._utils.polling.time.time", + side_effect=[0, 0, 0, 1, 1], + ) + def test_polls_until_target(self, _mock_time, _mock_sleep): + poll_fn = Mock( + side_effect=[{"status": "CREATING"}, {"status": "ACTIVE"}], + ) + result = wait_until(poll_fn, "ACTIVE", {"FAILED"}) + assert result["status"] == "ACTIVE" + assert poll_fn.call_count == 2 + + def test_raises_on_failed_status(self): + poll_fn = Mock( + return_value={"status": "FAILED", "statusReasons": ["broke"]}, + ) + with pytest.raises(RuntimeError, match="FAILED"): + wait_until(poll_fn, "ACTIVE", {"FAILED"}) + + def test_custom_error_field(self): + poll_fn = Mock( + return_value={ + "status": "CREATE_FAILED", + "failureReason": "bad config", + }, + ) + with pytest.raises(RuntimeError, match="bad config"): + wait_until( + poll_fn, + "ACTIVE", + {"CREATE_FAILED"}, + error_field="failureReason", + ) + + @patch("bedrock_agentcore._utils.polling.time.sleep") + @patch( + "bedrock_agentcore._utils.polling.time.time", + side_effect=[0, 0, 0, 301], + ) + def test_timeout(self, _mock_time, _mock_sleep): + poll_fn = Mock(return_value={"status": "CREATING"}) + with pytest.raises(TimeoutError): + wait_until(poll_fn, "ACTIVE", {"FAILED"}) + + +class TestWaitUntilDeleted: + def test_immediate_not_found(self): + poll_fn = Mock( + side_effect=ClientError( + {"Error": {"Code": "ResourceNotFoundException", "Message": ""}}, + "Get", + ), + ) + wait_until_deleted(poll_fn) + poll_fn.assert_called_once() + + @patch("bedrock_agentcore._utils.polling.time.sleep") + @patch( + "bedrock_agentcore._utils.polling.time.time", + side_effect=[0, 0, 0, 1, 1], + ) + def test_polls_then_deleted(self, _mock_time, _mock_sleep): + poll_fn = Mock( + side_effect=[ + {"status": "DELETING"}, + ClientError( + {"Error": {"Code": "ResourceNotFoundException", "Message": ""}}, + "Get", + ), + ], + ) + wait_until_deleted(poll_fn) + assert poll_fn.call_count == 2 + + def test_raises_on_failed_status(self): + poll_fn = Mock( + return_value={ + "status": "DELETE_FAILED", + "statusReasons": ["stuck"], + }, + ) + with pytest.raises(RuntimeError, match="DELETE_FAILED"): + wait_until_deleted( + poll_fn, + failed={"DELETE_FAILED"}, + ) + + @patch("bedrock_agentcore._utils.polling.time.sleep") + @patch( + "bedrock_agentcore._utils.polling.time.time", + side_effect=[0, 0, 0, 301], + ) + def test_timeout(self, _mock_time, _mock_sleep): + poll_fn = Mock(return_value={"status": "DELETING"}) + with pytest.raises(TimeoutError): + wait_until_deleted(poll_fn)