Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
219 changes: 219 additions & 0 deletions axme_sdk/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1619,6 +1619,225 @@ def mcp_call_tool(
should_retry = retryable if retryable is not None else bool(idempotency_key)
return self._mcp_request(payload=payload, trace_id=trace_id, retryable=should_retry)

# ------------------------------------------------------------------
# Session API
# ------------------------------------------------------------------

def create_session(
self,
*,
type: str = "task",
project_id: str | None = None,
parent_session_id: str | None = None,
depends_on: list[str] | None = None,
metadata: dict[str, Any] | None = None,
trace_id: str | None = None,
) -> dict[str, Any]:
body: dict[str, Any] = {"type": type}
if project_id is not None:
body["project_id"] = project_id
if parent_session_id is not None:
body["parent_session_id"] = parent_session_id
if depends_on is not None:
body["depends_on"] = depends_on
if metadata is not None:
body["metadata"] = metadata
return self._request_json("POST", "/v1/sessions", json_body=body, trace_id=trace_id, retryable=False)

def get_session(
self,
session_id: str,
*,
trace_id: str | None = None,
) -> dict[str, Any]:
return self._request_json("GET", f"/v1/sessions/{session_id}", trace_id=trace_id, retryable=True)

def list_sessions(
self,
*,
status: str | None = None,
parent_session_id: str | None = None,
limit: int | None = None,
trace_id: str | None = None,
) -> dict[str, Any]:
params: dict[str, str] = {}
if status is not None:
params["status"] = status
if parent_session_id is not None:
params["parent_session_id"] = parent_session_id
if limit is not None:
params["limit"] = str(limit)
return self._request_json("GET", "/v1/sessions", params=params or None, trace_id=trace_id, retryable=True)

def post_session_message(
self,
session_id: str,
*,
role: str,
content: Any,
content_type: str = "text",
trace_id: str | None = None,
) -> dict[str, Any]:
body: dict[str, Any] = {"role": role, "content": content, "content_type": content_type}
return self._request_json("POST", f"/v1/sessions/{session_id}/messages", json_body=body, trace_id=trace_id, retryable=False)

def list_session_messages(
self,
session_id: str,
*,
since: int = 0,
limit: int | None = None,
trace_id: str | None = None,
) -> dict[str, Any]:
params: dict[str, str] = {}
if since > 0:
params["since"] = str(since)
if limit is not None:
params["limit"] = str(limit)
return self._request_json(
"GET", f"/v1/sessions/{session_id}/messages", params=params or None, trace_id=trace_id, retryable=True,
)

def get_session_feed(
self,
session_id: str,
*,
limit: int | None = None,
trace_id: str | None = None,
) -> dict[str, Any]:
params: dict[str, str] = {}
if limit is not None:
params["limit"] = str(limit)
return self._request_json(
"GET", f"/v1/sessions/{session_id}/feed", params=params or None, trace_id=trace_id, retryable=True,
)

def listen_session(
self,
session_id: str,
*,
since: int = 0,
wait_seconds: int = 30,
poll_interval_seconds: float = 1.0,
timeout_seconds: float | None = None,
trace_id: str | None = None,
) -> Iterator[dict[str, Any]]:
"""Stream session feed events via SSE. Yields dicts for each message/intent event.

Reconnects automatically on stream timeout. Stops on session.completed event
or when timeout_seconds is exceeded.
"""
if since < 0:
raise ValueError("since must be >= 0")
if wait_seconds < 1:
raise ValueError("wait_seconds must be >= 1")
if timeout_seconds is not None and timeout_seconds <= 0:
raise ValueError("timeout_seconds must be > 0 when provided")

deadline = (time.monotonic() + timeout_seconds) if timeout_seconds is not None else None
next_since = since

while True:
if deadline is not None and time.monotonic() >= deadline:
return

stream_wait = wait_seconds
if deadline is not None:
seconds_left = max(0.0, deadline - time.monotonic())
if seconds_left <= 0:
return
stream_wait = max(1, min(wait_seconds, int(seconds_left)))

try:
for event in self._iter_session_feed_stream(
session_id=session_id,
since=next_since,
wait_seconds=stream_wait,
trace_id=trace_id,
):
seq = event.get("seq")
if isinstance(seq, int) and seq > next_since:
next_since = seq
yield event
if event.get("type") == "session.completed":
return
except AxmeHttpError as exc:
if exc.status_code not in {404, 405, 501}:
raise
return

# Stream ended (timeout), reconnect
if deadline is not None and time.monotonic() >= deadline:
return
time.sleep(poll_interval_seconds)

def complete_session(
self,
session_id: str,
*,
result: dict[str, Any] | None = None,
trace_id: str | None = None,
) -> dict[str, Any]:
body: dict[str, Any] = {}
if result is not None:
body["result"] = result
return self._request_json("POST", f"/v1/sessions/{session_id}/complete", json_body=body, trace_id=trace_id, retryable=False)

def _iter_session_feed_stream(
self,
*,
session_id: str,
since: int,
wait_seconds: int,
trace_id: str | None,
) -> Iterator[dict[str, Any]]:
headers: dict[str, str] | None = None
normalized_trace_id = self._normalize_trace_id(trace_id)
if normalized_trace_id is not None:
headers = {"X-Trace-Id": normalized_trace_id}

stream_timeout = httpx.Timeout(
connect=10.0,
read=float(wait_seconds) + 15.0,
write=10.0,
pool=10.0,
)
with self._http.stream(
"GET",
f"/v1/sessions/{session_id}/feed/stream",
params={"since": str(since), "wait_seconds": str(wait_seconds)},
headers=headers,
timeout=stream_timeout,
) as response:
if response.status_code >= 400:
self._raise_http_error(response)

current_event: str | None = None
data_lines: list[str] = []
for line in response.iter_lines():
if line == "":
if current_event == "stream.timeout":
return
if current_event and data_lines:
try:
payload = json.loads("\n".join(data_lines))
except ValueError:
payload = None
if isinstance(payload, dict):
payload["type"] = current_event
yield payload
current_event = None
data_lines = []
continue
if line.startswith(":"):
continue
if line.startswith("event:"):
current_event = line.partition(":")[2].strip()
continue
if line.startswith("data:"):
data_lines.append(line.partition(":")[2].lstrip())
continue

def _request_json(
self,
method: str,
Expand Down
Loading
Loading