From b327c8cd8add124165fb03ab79dcc87763b7c532 Mon Sep 17 00:00:00 2001 From: geobelsky Date: Thu, 26 Mar 2026 08:50:41 +0000 Subject: [PATCH] Add Session API methods - create, get, list, messages, feed, SSE listen, complete - 8 new public methods: create_session, get_session, list_sessions, post_session_message, list_session_messages, get_session_feed, listen_session, complete_session - SSE streaming with auto-reconnect on timeout - 13 new tests, 87/87 total green --- axme_sdk/client.py | 219 +++++++++++++++++++++++++++++++++++++++++++ tests/test_client.py | 215 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 434 insertions(+) diff --git a/axme_sdk/client.py b/axme_sdk/client.py index 52ed4b4..33a3b22 100644 --- a/axme_sdk/client.py +++ b/axme_sdk/client.py @@ -1619,6 +1619,225 @@ def mcp_call_tool( should_retry = retryable if retryable is not None else bool(idempotency_key) return self._mcp_request(payload=payload, trace_id=trace_id, retryable=should_retry) + # ------------------------------------------------------------------ + # Session API + # ------------------------------------------------------------------ + + def create_session( + self, + *, + type: str = "task", + project_id: str | None = None, + parent_session_id: str | None = None, + depends_on: list[str] | None = None, + metadata: dict[str, Any] | None = None, + trace_id: str | None = None, + ) -> dict[str, Any]: + body: dict[str, Any] = {"type": type} + if project_id is not None: + body["project_id"] = project_id + if parent_session_id is not None: + body["parent_session_id"] = parent_session_id + if depends_on is not None: + body["depends_on"] = depends_on + if metadata is not None: + body["metadata"] = metadata + return self._request_json("POST", "/v1/sessions", json_body=body, trace_id=trace_id, retryable=False) + + def get_session( + self, + session_id: str, + *, + trace_id: str | None = None, + ) -> dict[str, Any]: + return self._request_json("GET", f"/v1/sessions/{session_id}", trace_id=trace_id, retryable=True) + + def list_sessions( + self, + *, + status: str | None = None, + parent_session_id: str | None = None, + limit: int | None = None, + trace_id: str | None = None, + ) -> dict[str, Any]: + params: dict[str, str] = {} + if status is not None: + params["status"] = status + if parent_session_id is not None: + params["parent_session_id"] = parent_session_id + if limit is not None: + params["limit"] = str(limit) + return self._request_json("GET", "/v1/sessions", params=params or None, trace_id=trace_id, retryable=True) + + def post_session_message( + self, + session_id: str, + *, + role: str, + content: Any, + content_type: str = "text", + trace_id: str | None = None, + ) -> dict[str, Any]: + body: dict[str, Any] = {"role": role, "content": content, "content_type": content_type} + return self._request_json("POST", f"/v1/sessions/{session_id}/messages", json_body=body, trace_id=trace_id, retryable=False) + + def list_session_messages( + self, + session_id: str, + *, + since: int = 0, + limit: int | None = None, + trace_id: str | None = None, + ) -> dict[str, Any]: + params: dict[str, str] = {} + if since > 0: + params["since"] = str(since) + if limit is not None: + params["limit"] = str(limit) + return self._request_json( + "GET", f"/v1/sessions/{session_id}/messages", params=params or None, trace_id=trace_id, retryable=True, + ) + + def get_session_feed( + self, + session_id: str, + *, + limit: int | None = None, + trace_id: str | None = None, + ) -> dict[str, Any]: + params: dict[str, str] = {} + if limit is not None: + params["limit"] = str(limit) + return self._request_json( + "GET", f"/v1/sessions/{session_id}/feed", params=params or None, trace_id=trace_id, retryable=True, + ) + + def listen_session( + self, + session_id: str, + *, + since: int = 0, + wait_seconds: int = 30, + poll_interval_seconds: float = 1.0, + timeout_seconds: float | None = None, + trace_id: str | None = None, + ) -> Iterator[dict[str, Any]]: + """Stream session feed events via SSE. Yields dicts for each message/intent event. + + Reconnects automatically on stream timeout. Stops on session.completed event + or when timeout_seconds is exceeded. + """ + if since < 0: + raise ValueError("since must be >= 0") + if wait_seconds < 1: + raise ValueError("wait_seconds must be >= 1") + if timeout_seconds is not None and timeout_seconds <= 0: + raise ValueError("timeout_seconds must be > 0 when provided") + + deadline = (time.monotonic() + timeout_seconds) if timeout_seconds is not None else None + next_since = since + + while True: + if deadline is not None and time.monotonic() >= deadline: + return + + stream_wait = wait_seconds + if deadline is not None: + seconds_left = max(0.0, deadline - time.monotonic()) + if seconds_left <= 0: + return + stream_wait = max(1, min(wait_seconds, int(seconds_left))) + + try: + for event in self._iter_session_feed_stream( + session_id=session_id, + since=next_since, + wait_seconds=stream_wait, + trace_id=trace_id, + ): + seq = event.get("seq") + if isinstance(seq, int) and seq > next_since: + next_since = seq + yield event + if event.get("type") == "session.completed": + return + except AxmeHttpError as exc: + if exc.status_code not in {404, 405, 501}: + raise + return + + # Stream ended (timeout), reconnect + if deadline is not None and time.monotonic() >= deadline: + return + time.sleep(poll_interval_seconds) + + def complete_session( + self, + session_id: str, + *, + result: dict[str, Any] | None = None, + trace_id: str | None = None, + ) -> dict[str, Any]: + body: dict[str, Any] = {} + if result is not None: + body["result"] = result + return self._request_json("POST", f"/v1/sessions/{session_id}/complete", json_body=body, trace_id=trace_id, retryable=False) + + def _iter_session_feed_stream( + self, + *, + session_id: str, + since: int, + wait_seconds: int, + trace_id: str | None, + ) -> Iterator[dict[str, Any]]: + headers: dict[str, str] | None = None + normalized_trace_id = self._normalize_trace_id(trace_id) + if normalized_trace_id is not None: + headers = {"X-Trace-Id": normalized_trace_id} + + stream_timeout = httpx.Timeout( + connect=10.0, + read=float(wait_seconds) + 15.0, + write=10.0, + pool=10.0, + ) + with self._http.stream( + "GET", + f"/v1/sessions/{session_id}/feed/stream", + params={"since": str(since), "wait_seconds": str(wait_seconds)}, + headers=headers, + timeout=stream_timeout, + ) as response: + if response.status_code >= 400: + self._raise_http_error(response) + + current_event: str | None = None + data_lines: list[str] = [] + for line in response.iter_lines(): + if line == "": + if current_event == "stream.timeout": + return + if current_event and data_lines: + try: + payload = json.loads("\n".join(data_lines)) + except ValueError: + payload = None + if isinstance(payload, dict): + payload["type"] = current_event + yield payload + current_event = None + data_lines = [] + continue + if line.startswith(":"): + continue + if line.startswith("event:"): + current_event = line.partition(":")[2].strip() + continue + if line.startswith("data:"): + data_lines.append(line.partition(":")[2].lstrip()) + continue + def _request_json( self, method: str, diff --git a/tests/test_client.py b/tests/test_client.py index b0f3ed4..0f18381 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1896,3 +1896,218 @@ def handler(request: httpx.Request) -> httpx.Response: client = _client(handler) with pytest.raises(AxmeAuthError): list(client.listen("org/ws/bot")) + + +# ------------------------------------------------------------------ +# Session API tests +# ------------------------------------------------------------------ + + +def test_create_session() -> None: + def handler(request: httpx.Request) -> httpx.Response: + assert request.method == "POST" + assert request.url.path == "/v1/sessions" + body = json.loads(request.content) + assert body["type"] == "task" + assert body["metadata"]["agent"] == "claude-code" + return httpx.Response(200, json={ + "ok": True, + "session_id": "s-123", + "status": "ACTIVE", + "type": "task", + "created_at": "2026-03-26T12:00:00Z", + }) + + client = _client(handler) + result = client.create_session(type="task", metadata={"agent": "claude-code"}) + assert result["ok"] is True + assert result["session_id"] == "s-123" + assert result["status"] == "ACTIVE" + + +def test_create_session_with_depends_on() -> None: + def handler(request: httpx.Request) -> httpx.Response: + body = json.loads(request.content) + assert body["depends_on"] == ["s-dep-1", "s-dep-2"] + return httpx.Response(200, json={"ok": True, "session_id": "s-456", "status": "PAUSED", "type": "task", "created_at": "2026-03-26T12:00:00Z"}) + + client = _client(handler) + result = client.create_session(depends_on=["s-dep-1", "s-dep-2"]) + assert result["status"] == "PAUSED" + + +def test_get_session() -> None: + def handler(request: httpx.Request) -> httpx.Response: + assert request.method == "GET" + assert request.url.path == "/v1/sessions/s-123" + return httpx.Response(200, json={ + "ok": True, + "session": { + "session_id": "s-123", + "type": "task", + "status": "ACTIVE", + "metadata": {"agent": "claude-code"}, + }, + }) + + client = _client(handler) + result = client.get_session("s-123") + assert result["session"]["session_id"] == "s-123" + + +def test_list_sessions() -> None: + def handler(request: httpx.Request) -> httpx.Response: + assert request.method == "GET" + assert request.url.path == "/v1/sessions" + assert "status=ACTIVE" in str(request.url) + return httpx.Response(200, json={"ok": True, "sessions": [{"session_id": "s-1"}, {"session_id": "s-2"}]}) + + client = _client(handler) + result = client.list_sessions(status="ACTIVE") + assert len(result["sessions"]) == 2 + + +def test_post_session_message() -> None: + def handler(request: httpx.Request) -> httpx.Response: + assert request.method == "POST" + assert "/messages" in request.url.path + body = json.loads(request.content) + assert body["role"] == "agent" + assert body["content"] == "Reading file..." + return httpx.Response(200, json={"ok": True, "message_id": "m-1", "seq": 1, "created_at": "2026-03-26T12:00:00Z"}) + + client = _client(handler) + result = client.post_session_message("s-123", role="agent", content="Reading file...") + assert result["ok"] is True + assert result["seq"] == 1 + + +def test_post_session_message_structured_content() -> None: + def handler(request: httpx.Request) -> httpx.Response: + body = json.loads(request.content) + assert body["content_type"] == "tool_use" + assert body["content"]["tool"] == "Read" + return httpx.Response(200, json={"ok": True, "message_id": "m-2", "seq": 2, "created_at": "2026-03-26T12:00:00Z"}) + + client = _client(handler) + result = client.post_session_message( + "s-123", role="agent", content_type="tool_use", + content={"tool": "Read", "input": {"path": "/tmp/foo.py"}}, + ) + assert result["ok"] is True + + +def test_list_session_messages() -> None: + def handler(request: httpx.Request) -> httpx.Response: + assert request.method == "GET" + assert "/messages" in request.url.path + return httpx.Response(200, json={"ok": True, "messages": [ + {"message_id": "m-1", "seq": 1, "role": "agent", "content": "Hello"}, + ]}) + + client = _client(handler) + result = client.list_session_messages("s-123") + assert len(result["messages"]) == 1 + + +def test_list_session_messages_with_since() -> None: + def handler(request: httpx.Request) -> httpx.Response: + assert "since=5" in str(request.url) + return httpx.Response(200, json={"ok": True, "messages": []}) + + client = _client(handler) + result = client.list_session_messages("s-123", since=5) + assert result["messages"] == [] + + +def test_get_session_feed() -> None: + def handler(request: httpx.Request) -> httpx.Response: + assert request.url.path == "/v1/sessions/s-123/feed" + return httpx.Response(200, json={"ok": True, "feed": [ + {"type": "message", "role": "agent", "content": "Working..."}, + {"type": "intent", "intent_id": "i-1", "status": "WAITING"}, + ]}) + + client = _client(handler) + result = client.get_session_feed("s-123") + assert len(result["feed"]) == 2 + assert result["feed"][0]["type"] == "message" + assert result["feed"][1]["type"] == "intent" + + +def test_complete_session() -> None: + def handler(request: httpx.Request) -> httpx.Response: + assert request.method == "POST" + assert request.url.path == "/v1/sessions/s-123/complete" + body = json.loads(request.content) + assert body["result"]["pr_url"] == "https://github.com/org/repo/pull/42" + return httpx.Response(200, json={"ok": True, "session_id": "s-123", "status": "COMPLETED"}) + + client = _client(handler) + result = client.complete_session("s-123", result={"pr_url": "https://github.com/org/repo/pull/42"}) + assert result["status"] == "COMPLETED" + + +def test_complete_session_no_result() -> None: + def handler(request: httpx.Request) -> httpx.Response: + body = json.loads(request.content) + assert "result" not in body + return httpx.Response(200, json={"ok": True, "session_id": "s-123", "status": "COMPLETED"}) + + client = _client(handler) + result = client.complete_session("s-123") + assert result["ok"] is True + + +def test_listen_session_stream() -> None: + sse_body = ( + "event: session.message\n" + 'data: {"message_id": "m-1", "seq": 1, "role": "agent", "content": "Working..."}\n' + "\n" + "event: session.completed\n" + 'data: {"session_id": "s-123", "status": "COMPLETED"}\n' + "\n" + ) + + def handler(request: httpx.Request) -> httpx.Response: + assert "/feed/stream" in request.url.path + return httpx.Response(200, text=sse_body, headers={"content-type": "text/event-stream"}) + + client = _client(handler) + events = list(client.listen_session("s-123", wait_seconds=2, timeout_seconds=5)) + assert len(events) == 2 + assert events[0]["type"] == "session.message" + assert events[0]["content"] == "Working..." + assert events[1]["type"] == "session.completed" + + +def test_listen_session_reconnects_on_timeout() -> None: + call_count = 0 + + def handler(request: httpx.Request) -> httpx.Response: + nonlocal call_count + call_count += 1 + if call_count == 1: + sse_body = ( + "event: session.message\n" + 'data: {"message_id": "m-1", "seq": 1, "role": "agent", "content": "First"}\n' + "\n" + "event: stream.timeout\n" + 'data: {"ok": true, "last_seq": 1}\n' + "\n" + ) + return httpx.Response(200, text=sse_body, headers={"content-type": "text/event-stream"}) + else: + sse_body = ( + "event: session.completed\n" + 'data: {"session_id": "s-123", "status": "COMPLETED"}\n' + "\n" + ) + return httpx.Response(200, text=sse_body, headers={"content-type": "text/event-stream"}) + + client = _client(handler) + events = list(client.listen_session("s-123", wait_seconds=1, poll_interval_seconds=0.01, timeout_seconds=10)) + assert len(events) == 2 + assert events[0]["content"] == "First" + assert events[1]["type"] == "session.completed" + assert call_count == 2