Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 101 additions & 0 deletions examples/agents-temporal/activities.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
"""Temporal activities — all real I/O lives here.

Both examples (direct composition and provider) share these activities.
Each activity is a plain async function that does real I/O.
"""

from __future__ import annotations

import dataclasses
from typing import Any

import temporalio.activity

import vercel_ai_sdk as ai

# ── Tool activities ──────────────────────────────────────────────


@temporalio.activity.defn(name="get_weather")
async def get_weather_activity(city: str) -> str:
return f"Sunny, 72F in {city}"


@temporalio.activity.defn(name="get_population")
async def get_population_activity(city: str) -> int:
return {"new york": 8_336_817, "los angeles": 3_979_576}.get(
city.lower(), 1_000_000
)


# ── Generic tool dispatch activity ──────────────────────────────
#
# Used by the provider example: the provider routes tool calls here
# instead of executing them inside the workflow.

_TOOL_REGISTRY: dict[str, Any] = {
"get_weather": get_weather_activity,
"get_population": get_population_activity,
}


@dataclasses.dataclass
class ToolDispatchParams:
tool_name: str
tool_args: str # JSON string


@dataclasses.dataclass
class ToolDispatchResult:
result: Any
is_error: bool = False


@temporalio.activity.defn(name="tool_dispatch")
async def tool_dispatch_activity(params: ToolDispatchParams) -> ToolDispatchResult:
"""Dispatch a tool call by name. Runs the real tool function."""
import json

fn = _TOOL_REGISTRY.get(params.tool_name)
if fn is None:
return ToolDispatchResult(
result=f"Unknown tool: {params.tool_name}", is_error=True
)

try:
kwargs = json.loads(params.tool_args) if params.tool_args else {}
result = await fn(**kwargs)
return ToolDispatchResult(result=result)
except Exception as exc:
return ToolDispatchResult(result=str(exc), is_error=True)


# ── LLM activity ────────────────────────────────────────────────


@dataclasses.dataclass
class LLMCallParams:
messages: list[dict[str, Any]]
tool_schemas: list[dict[str, Any]]


@dataclasses.dataclass
class LLMCallResult:
message: dict[str, Any] # serialized ai.Message


@temporalio.activity.defn(name="llm_call")
async def llm_call_activity(params: LLMCallParams) -> LLMCallResult:
"""Call the LLM, drain the stream, return the final message."""
model = ai.Model(
id="anthropic/claude-sonnet-4-20250514",
adapter="ai-gateway-v3",
provider="ai-gateway",
)

messages = [ai.Message.model_validate(m) for m in params.messages]
tools = [ai.ToolSchema(return_type=None, **t) for t in params.tool_schemas]

s = await ai.models.stream(model, messages, tools=tools)
result = await ai.models.buffer(s)
return LLMCallResult(message=result.model_dump())
129 changes: 129 additions & 0 deletions examples/agents-temporal/direct.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
"""Direct composition: a custom loop where every I/O call is a Temporal activity.

No DurabilityProvider is involved. The user writes a plain ``@agent.loop``
that replaces ``models.stream()`` and tool execution with Temporal
``execute_activity()`` calls. Temporal's event history provides durability.

This is the "lego bricks" approach: the framework gives you ``Agent``,
``Context``, ``@tool`` (for schema extraction), and the message types.
You compose them with Temporal yourself.
"""

from __future__ import annotations

import asyncio
import datetime
from collections.abc import AsyncGenerator
from typing import Any

import temporalio.common
import temporalio.workflow

with temporalio.workflow.unsafe.imports_passed_through():
import activities

import vercel_ai_sdk as ai
from vercel_ai_sdk.agents import Context, agent, tool


# ── Tools ────────────────────────────────────────────────────────
#
# Defined with @tool so the agent can extract JSON schemas for the
# LLM. The bodies are never called inside the workflow — execution
# goes through Temporal activities instead.


@tool
async def get_weather(city: str) -> str:
"""Get current weather for a city."""
raise RuntimeError("should not be called inside workflow")


@tool
async def get_population(city: str) -> int:
"""Get population of a city."""
raise RuntimeError("should not be called inside workflow")


# ── Agent with custom loop ───────────────────────────────────────

weather_agent = agent(tools=[get_weather, get_population])


@weather_agent.loop
async def temporal_loop(context: Context) -> AsyncGenerator[ai.Message]:
"""Agent loop where every I/O call is a durable Temporal activity."""
tool_schemas = [
{
"name": t.name,
"description": t.description,
"param_schema": t.param_schema,
}
for t in context.tools
]

while True:
# LLM call via activity.
result = await temporalio.workflow.execute_activity(
activities.llm_call_activity,
activities.LLMCallParams(
messages=[m.model_dump() for m in context.messages],
tool_schemas=tool_schemas,
),
start_to_close_timeout=datetime.timedelta(minutes=5),
retry_policy=temporalio.common.RetryPolicy(maximum_attempts=3),
)
msg = ai.Message.model_validate(result.message)
yield msg

if not msg.tool_calls:
break

# Tool calls via activities (parallel).
tool_call_parts = msg.tool_calls

async def _run_tool(tc: Any) -> ai.ToolResultPart:
dispatch_result = await temporalio.workflow.execute_activity(
activities.tool_dispatch_activity,
activities.ToolDispatchParams(
tool_name=tc.tool_name,
tool_args=tc.tool_args,
),
start_to_close_timeout=datetime.timedelta(minutes=2),
)
return ai.ToolResultPart(
tool_call_id=tc.tool_call_id,
tool_name=tc.tool_name,
result=dispatch_result.result,
is_error=dispatch_result.is_error,
)

tasks = [asyncio.ensure_future(_run_tool(tc)) for tc in tool_call_parts]
results = await asyncio.gather(*tasks)
yield ai.tool_message(*results)


# ── Workflow ─────────────────────────────────────────────────────


@temporalio.workflow.defn
class DirectWorkflow:
@temporalio.workflow.run
async def run(self, user_query: str) -> str:
model = ai.Model(
id="anthropic/claude-sonnet-4-20250514",
adapter="ai-gateway-v3",
provider="ai-gateway",
)
messages: list[ai.Message] = [
ai.system_message(
"Answer questions using the weather and population tools."
),
ai.user_message(user_query),
]

final_text = ""
async for msg in weather_agent.run(model, messages):
if msg.text:
final_text = msg.text
return final_text
78 changes: 78 additions & 0 deletions examples/agents-temporal/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
"""Entry point — starts a Temporal worker and executes the agent workflow.

Two examples in one project:
- ``direct`` — custom loop, each I/O call is a Temporal activity
- ``provider`` — default loop, DurabilityProvider routes I/O to activities

Prerequisites:
1. Temporal dev server: temporal server start-dev
2. AI_GATEWAY_API_KEY environment variable set

Usage:
uv run python main.py direct
uv run python main.py provider
uv run python main.py direct "What is the weather in Tokyo?"
uv run python main.py provider "Compare weather in NYC and LA"
"""

from __future__ import annotations

import asyncio
import sys
import uuid

import activities
import direct
import provider
import temporalio.client
import temporalio.worker

TASK_QUEUE = "agents-durable"


async def main(mode: str, user_query: str) -> None:
temporal = await temporalio.client.Client.connect("localhost:7233")

workflows = {
"direct": direct.DirectWorkflow,
"provider": provider.ProviderWorkflow,
}
workflow_cls = workflows[mode]

async with temporalio.worker.Worker(
temporal,
task_queue=TASK_QUEUE,
workflows=[workflow_cls],
activities=[
activities.llm_call_activity,
activities.get_weather_activity,
activities.get_population_activity,
activities.tool_dispatch_activity,
],
):
workflow_id = f"agents-{mode}-{uuid.uuid4().hex[:8]}"
print(f"Mode: {mode}")
print(f"Workflow: {workflow_id}")
print(f"Query: {user_query}\n")

result = await temporal.execute_workflow(
workflow_cls.run,
user_query,
id=workflow_id,
task_queue=TASK_QUEUE,
)
print(result)


if __name__ == "__main__":
if len(sys.argv) < 2 or sys.argv[1] not in ("direct", "provider"):
print("Usage: python main.py <direct|provider> [query]")
sys.exit(1)

mode = sys.argv[1]
query = (
sys.argv[2]
if len(sys.argv) > 2
else "What's the weather and population of New York and Los Angeles?"
)
asyncio.run(main(mode, query))
Loading
Loading