diff --git a/README.md b/README.md index f8edc5f..931e33c 100644 --- a/README.md +++ b/README.md @@ -14,10 +14,12 @@ from axme_sdk import AxmeClient, AxmeClientConfig config = AxmeClientConfig( base_url="https://gateway.example.com", api_key="YOUR_API_KEY", + max_retries=2, + retry_backoff_seconds=0.2, ) with AxmeClient(config) as client: - print(client.health()) + print(client.health(trace_id="trace-quickstart-001")) result = client.create_intent( { "intent_type": "notify.message.v1", @@ -29,7 +31,7 @@ with AxmeClient(config) as client: idempotency_key="create-intent-001", ) print(result) - inbox = client.list_inbox(owner_agent="agent://example/receiver") + inbox = client.list_inbox(owner_agent="agent://example/receiver", trace_id="trace-inbox-001") print(inbox) changes = client.list_inbox_changes(owner_agent="agent://example/receiver", limit=50) print(changes["next_cursor"], changes["has_more"]) diff --git a/axme_sdk/client.py b/axme_sdk/client.py index 2c21ff4..e665975 100644 --- a/axme_sdk/client.py +++ b/axme_sdk/client.py @@ -1,7 +1,9 @@ from __future__ import annotations from dataclasses import dataclass +import time from typing import Any +from uuid import uuid4 import httpx @@ -19,6 +21,9 @@ class AxmeClientConfig: base_url: str api_key: str timeout_seconds: float = 15.0 + max_retries: int = 2 + retry_backoff_seconds: float = 0.2 + auto_trace_id: bool = True class AxmeClient: @@ -44,9 +49,8 @@ def __enter__(self) -> "AxmeClient": def __exit__(self, exc_type: Any, exc: Any, traceback: Any) -> None: self.close() - def health(self) -> dict[str, Any]: - response = self._http.get("/health") - return self._parse_json_response(response) + def health(self, *, trace_id: str | None = None) -> dict[str, Any]: + return self._request_json("GET", "/health", trace_id=trace_id, retryable=True) def create_intent( self, @@ -54,6 +58,7 @@ def create_intent( *, correlation_id: str, idempotency_key: str | None = None, + trace_id: str | None = None, ) -> dict[str, Any]: request_payload = dict(payload) existing_correlation_id = request_payload.get("correlation_id") @@ -61,26 +66,32 @@ def create_intent( raise ValueError("payload correlation_id must match correlation_id argument") request_payload["correlation_id"] = correlation_id - headers: dict[str, str] | None = None - if idempotency_key is not None: - headers = {"Idempotency-Key": idempotency_key} - - response = self._http.post("/v1/intents", json=request_payload, headers=headers) - return self._parse_json_response(response) + return self._request_json( + "POST", + "/v1/intents", + json_body=request_payload, + idempotency_key=idempotency_key, + trace_id=trace_id, + retryable=idempotency_key is not None, + ) - def list_inbox(self, *, owner_agent: str | None = None) -> dict[str, Any]: + def list_inbox(self, *, owner_agent: str | None = None, trace_id: str | None = None) -> dict[str, Any]: params: dict[str, str] | None = None if owner_agent is not None: params = {"owner_agent": owner_agent} - response = self._http.get("/v1/inbox", params=params) - return self._parse_json_response(response) + return self._request_json("GET", "/v1/inbox", params=params, trace_id=trace_id, retryable=True) - def get_inbox_thread(self, thread_id: str, *, owner_agent: str | None = None) -> dict[str, Any]: + def get_inbox_thread(self, thread_id: str, *, owner_agent: str | None = None, trace_id: str | None = None) -> dict[str, Any]: params: dict[str, str] | None = None if owner_agent is not None: params = {"owner_agent": owner_agent} - response = self._http.get(f"/v1/inbox/{thread_id}", params=params) - return self._parse_json_response(response) + return self._request_json( + "GET", + f"/v1/inbox/{thread_id}", + params=params, + trace_id=trace_id, + retryable=True, + ) def list_inbox_changes( self, @@ -88,6 +99,7 @@ def list_inbox_changes( owner_agent: str | None = None, cursor: str | None = None, limit: int | None = None, + trace_id: str | None = None, ) -> dict[str, Any]: params: dict[str, str] = {} if owner_agent is not None: @@ -96,8 +108,13 @@ def list_inbox_changes( params["cursor"] = cursor if limit is not None: params["limit"] = str(limit) - response = self._http.get("/v1/inbox/changes", params=params or None) - return self._parse_json_response(response) + return self._request_json( + "GET", + "/v1/inbox/changes", + params=params or None, + trace_id=trace_id, + retryable=True, + ) def reply_inbox_thread( self, @@ -106,66 +123,167 @@ def reply_inbox_thread( message: str, owner_agent: str | None = None, idempotency_key: str | None = None, + trace_id: str | None = None, ) -> dict[str, Any]: params: dict[str, str] | None = None if owner_agent is not None: params = {"owner_agent": owner_agent} - headers: dict[str, str] | None = None - if idempotency_key is not None: - headers = {"Idempotency-Key": idempotency_key} - response = self._http.post( + return self._request_json( + "POST", f"/v1/inbox/{thread_id}/reply", params=params, - json={"message": message}, - headers=headers, + json_body={"message": message}, + idempotency_key=idempotency_key, + trace_id=trace_id, + retryable=idempotency_key is not None, ) - return self._parse_json_response(response) def upsert_webhook_subscription( self, payload: dict[str, Any], *, idempotency_key: str | None = None, + trace_id: str | None = None, ) -> dict[str, Any]: - headers: dict[str, str] | None = None - if idempotency_key is not None: - headers = {"Idempotency-Key": idempotency_key} - response = self._http.post("/v1/webhooks/subscriptions", json=payload, headers=headers) - return self._parse_json_response(response) + return self._request_json( + "POST", + "/v1/webhooks/subscriptions", + json_body=payload, + idempotency_key=idempotency_key, + trace_id=trace_id, + retryable=idempotency_key is not None, + ) - def list_webhook_subscriptions(self, *, owner_agent: str | None = None) -> dict[str, Any]: + def list_webhook_subscriptions(self, *, owner_agent: str | None = None, trace_id: str | None = None) -> dict[str, Any]: params: dict[str, str] | None = None if owner_agent is not None: params = {"owner_agent": owner_agent} - response = self._http.get("/v1/webhooks/subscriptions", params=params) - return self._parse_json_response(response) + return self._request_json("GET", "/v1/webhooks/subscriptions", params=params, trace_id=trace_id, retryable=True) - def delete_webhook_subscription(self, subscription_id: str, *, owner_agent: str | None = None) -> dict[str, Any]: + def delete_webhook_subscription( + self, + subscription_id: str, + *, + owner_agent: str | None = None, + trace_id: str | None = None, + ) -> dict[str, Any]: params: dict[str, str] | None = None if owner_agent is not None: params = {"owner_agent": owner_agent} - response = self._http.delete(f"/v1/webhooks/subscriptions/{subscription_id}", params=params) - return self._parse_json_response(response) + return self._request_json( + "DELETE", + f"/v1/webhooks/subscriptions/{subscription_id}", + params=params, + trace_id=trace_id, + retryable=True, + ) - def publish_webhook_event(self, payload: dict[str, Any], *, owner_agent: str | None = None) -> dict[str, Any]: + def publish_webhook_event( + self, + payload: dict[str, Any], + *, + owner_agent: str | None = None, + idempotency_key: str | None = None, + trace_id: str | None = None, + ) -> dict[str, Any]: params: dict[str, str] | None = None if owner_agent is not None: params = {"owner_agent": owner_agent} - response = self._http.post("/v1/webhooks/events", params=params, json=payload) - return self._parse_json_response(response) + return self._request_json( + "POST", + "/v1/webhooks/events", + params=params, + json_body=payload, + idempotency_key=idempotency_key, + trace_id=trace_id, + retryable=idempotency_key is not None, + ) - def replay_webhook_event(self, event_id: str, *, owner_agent: str | None = None) -> dict[str, Any]: + def replay_webhook_event( + self, + event_id: str, + *, + owner_agent: str | None = None, + idempotency_key: str | None = None, + trace_id: str | None = None, + ) -> dict[str, Any]: params: dict[str, str] | None = None if owner_agent is not None: params = {"owner_agent": owner_agent} - response = self._http.post(f"/v1/webhooks/events/{event_id}/replay", params=params) - return self._parse_json_response(response) + response = self._request_json( + "POST", + f"/v1/webhooks/events/{event_id}/replay", + params=params, + idempotency_key=idempotency_key, + trace_id=trace_id, + retryable=idempotency_key is not None, + ) + return response + + def _request_json( + self, + method: str, + path: str, + *, + params: dict[str, str] | None = None, + json_body: dict[str, Any] | None = None, + idempotency_key: str | None = None, + trace_id: str | None = None, + retryable: bool, + ) -> dict[str, Any]: + headers: dict[str, str] | None = None + normalized_trace_id = self._normalize_trace_id(trace_id) + if idempotency_key is not None or normalized_trace_id is not None: + headers = {} + if idempotency_key is not None: + headers["Idempotency-Key"] = idempotency_key + if normalized_trace_id is not None: + headers["X-Trace-Id"] = normalized_trace_id + + attempts = 1 + (self._config.max_retries if retryable else 0) + for attempt_idx in range(attempts): + try: + response = self._http.request( + method=method, + url=path, + params=params, + json=json_body, + headers=headers, + ) + except (httpx.TimeoutException, httpx.TransportError): + if attempt_idx >= attempts - 1: + raise + self._sleep_before_retry(attempt_idx, retry_after=None) + continue + + if retryable and attempt_idx < attempts - 1 and _is_retryable_status(response.status_code): + retry_after = _parse_retry_after(response.headers.get("Retry-After")) + self._sleep_before_retry(attempt_idx, retry_after=retry_after) + continue + return self._parse_json_response(response) + + raise RuntimeError("unreachable retry loop state") + + def _sleep_before_retry(self, attempt_idx: int, *, retry_after: int | None) -> None: + if retry_after is not None: + time.sleep(max(0, retry_after)) + return + backoff = self._config.retry_backoff_seconds * (2**attempt_idx) + time.sleep(max(0.0, backoff)) + + def _normalize_trace_id(self, trace_id: str | None) -> str | None: + if trace_id is not None: + return trace_id + if self._config.auto_trace_id: + return str(uuid4()) + return None def _parse_json_response(self, response: httpx.Response) -> dict[str, Any]: if response.status_code >= 400: self._raise_http_error(response) return response.json() + def _raise_http_error(self, response: httpx.Response) -> None: body: Any | None body = None @@ -212,3 +330,7 @@ def _parse_retry_after(value: str | None) -> int | None: return int(value) except ValueError: return None + + +def _is_retryable_status(status_code: int) -> bool: + return status_code == 429 or status_code >= 500 diff --git a/tests/test_client.py b/tests/test_client.py index 72f9fb1..d4c7705 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -13,8 +13,21 @@ def _transport(handler): return httpx.MockTransport(handler) -def _client(handler, api_key: str = "token") -> AxmeClient: - cfg = AxmeClientConfig(base_url="https://api.axme.test", api_key=api_key) +def _client( + handler, + api_key: str = "token", + *, + max_retries: int = 2, + retry_backoff_seconds: float = 0.0, + auto_trace_id: bool = True, +) -> AxmeClient: + cfg = AxmeClientConfig( + base_url="https://api.axme.test", + api_key=api_key, + max_retries=max_retries, + retry_backoff_seconds=retry_backoff_seconds, + auto_trace_id=auto_trace_id, + ) http_client = httpx.Client( transport=_transport(handler), base_url=cfg.base_url, @@ -74,6 +87,15 @@ def handler(request: httpx.Request) -> httpx.Response: assert client.health() == {"ok": True} +def test_health_propagates_trace_id_header() -> None: + def handler(request: httpx.Request) -> httpx.Response: + assert request.headers.get("x-trace-id") == "trace-123" + return httpx.Response(200, json={"ok": True}) + + client = _client(handler, auto_trace_id=False) + assert client.health(trace_id="trace-123") == {"ok": True} + + def test_create_intent_success() -> None: payload = { "intent_type": "notify.message.v1", @@ -229,7 +251,7 @@ def test_client_maps_rate_limit_error_with_retry_after() -> None: def handler(request: httpx.Request) -> httpx.Response: return httpx.Response(429, json={"message": "too many"}, headers={"Retry-After": "30"}) - client = _client(handler) + client = _client(handler, max_retries=0) with pytest.raises(AxmeRateLimitError) as exc_info: client.list_inbox() assert exc_info.value.retry_after == 30 @@ -346,3 +368,57 @@ def handler(request: httpx.Request) -> httpx.Response: client = _client(handler) assert client.replay_webhook_event(event_id, owner_agent="agent://owner")["event_id"] == event_id + + +def test_retries_retryable_get_on_transient_server_error() -> None: + attempts = 0 + + def handler(request: httpx.Request) -> httpx.Response: + nonlocal attempts + attempts += 1 + if attempts == 1: + return httpx.Response(500, json={"error": "temporary"}) + return httpx.Response(200, json={"ok": True, "threads": []}) + + client = _client(handler) + assert client.list_inbox(owner_agent="agent://owner") == {"ok": True, "threads": []} + assert attempts == 2 + + +def test_retries_post_when_idempotency_key_is_present() -> None: + attempts = 0 + + def handler(request: httpx.Request) -> httpx.Response: + nonlocal attempts + attempts += 1 + if attempts == 1: + return httpx.Response(500, json={"error": "temporary"}) + return httpx.Response(200, json={"intent_id": "it_123"}) + + client = _client(handler) + assert ( + client.create_intent( + {"intent_type": "notify.message.v1", "to_agent": "agent://x", "from_agent": "agent://y", "payload": {}}, + correlation_id="11111111-1111-1111-1111-111111111111", + idempotency_key="idem-retry", + ) + == {"intent_id": "it_123"} + ) + assert attempts == 2 + + +def test_does_not_retry_post_without_idempotency_key() -> None: + attempts = 0 + + def handler(request: httpx.Request) -> httpx.Response: + nonlocal attempts + attempts += 1 + return httpx.Response(500, json={"error": "temporary"}) + + client = _client(handler) + with pytest.raises(AxmeHttpError): + client.create_intent( + {"intent_type": "notify.message.v1", "to_agent": "agent://x", "from_agent": "agent://y", "payload": {}}, + correlation_id="11111111-1111-1111-1111-111111111111", + ) + assert attempts == 1