diff --git a/docs/replay.md b/docs/replay.md new file mode 100644 index 0000000..c7ed378 --- /dev/null +++ b/docs/replay.md @@ -0,0 +1,271 @@ +# Record/replay — capture and inspect agent runs + +`dedalus_labs.lib.replay` lets you save a complete record of any agent run to +a local JSON file. You can open that file, read what the agent did, and replay +the run deterministically for debugging (see "Replaying a trace" below). + +--- + +## Privacy model + +Recording is **opt-in** and **local-first**: + +- Nothing is uploaded anywhere. The trace lands in a file on your machine. +- Recording only happens when you pass `on_tool_event=rec.on_tool` and + `on_model_event=rec.on_model` to `runner.run()`. A run without those + arguments produces no trace. +- What is captured: every model request payload (messages, tools, model name) + and every model response (including raw tool calls), plus each tool + result (name, arguments, return value). +- What is **not** captured: anything the runner never sees — secrets already + in environment variables, TLS-layer bytes, MCP-server internals. + +If trace files will leave your machine (shared with a customer, attached to an +issue), use the built-in redactors before saving: + +```python +from dedalus_labs.lib.replay import Recorder, redact_emails, redact_bearer_tokens + +def redact(event): + event = redact_emails(event) + event = redact_bearer_tokens(event) + return event + +with Recorder("trace.json", redact=redact) as rec: + runner.run(..., on_tool_event=rec.on_tool, on_model_event=rec.on_model) +``` + +--- + +## Quick start + +```python +from dedalus_labs import Dedalus +from dedalus_labs.lib.runner import DedalusRunner +from dedalus_labs.lib.replay import Recorder + +def add(a: int, b: int) -> int: + """Add two numbers.""" + return a + b + +client = Dedalus() +runner = DedalusRunner(client) + +with Recorder("trace.json") as rec: + result = runner.run( + model="openai/gpt-5-nano", + input="What is 3 + 4? 
Use the add tool.", + tools=[add], + on_tool_event=rec.on_tool, + on_model_event=rec.on_model, + ) + +print(result.final_output) +# trace.json now contains the full run +``` + +--- + +## Recorder API + +```python +Recorder(path, *, redact=None, metadata=None) +``` + +| Parameter | Type | Description | +|-----------|------|-------------| +| `path` | `str \| Path` | Where to write the trace file. | +| `redact` | `Callable[[dict], dict] \| None` | Called on each event before it is stored. Return a modified copy; raise to fall back to marking the event `_redaction_failed`. | +| `metadata` | `dict \| None` | Arbitrary key-value pairs written into the trace envelope (ticket IDs, customer names, environment tags, etc.). | + +**Methods:** + +- `rec.on_tool(event)` — pass as `on_tool_event=` in `runner.run()`. +- `rec.on_model(event)` — pass as `on_model_event=` in `runner.run()`. +- `rec.save()` — write the file immediately. Called automatically on `__exit__`. + +Use as a context manager (`with Recorder(...) as rec`) — `save()` is called +even if the run raises. + +--- + +## Trace format (v1.0) + +The file is UTF-8 JSON, pretty-printed with two-space indentation. + +```json +{ + "format_version": "1.0", + "sdk_version": "0.3.0", + "recorded_at": "2026-05-12T15:02:11Z", + "metadata": {}, + "events": [ + { + "kind": "model_request", + "step": 1, + "request": { "model": "openai/gpt-5-nano", "messages": [...], "tools": [...] }, + "ts": 1715526131.04 + }, + { + "kind": "model_response", + "step": 1, + "response": { "id": "chatcmpl-...", "choices": [...] }, + "ts": 1715526132.41 + }, + { + "kind": "tool_end", + "step": 1, + "name": "add", + "tool_call_id": "call_abc123", + "arguments": "{\"a\": 3, \"b\": 4}", + "result": 7, + "ts": 1715526132.63 + }, + { + "kind": "model_request", + "step": 2, + "request": { ... }, + "ts": 1715526132.71 + }, + { + "kind": "model_response", + "step": 2, + "response": { ... 
}, + "ts": 1715526133.92 + } + ] +} +``` + +### Envelope fields + +| Field | Description | +|-------|-------------| +| `format_version` | Schema version. Bump on breaking changes. | +| `sdk_version` | `dedalus_labs.__version__` at record time. | +| `recorded_at` | UTC ISO-8601 timestamp when `save()` was called. | +| `metadata` | User-supplied dict (pass via `Recorder(..., metadata={...})`). | +| `events` | Ordered list of event objects. | + +### Event fields (all events) + +| Field | Description | +|-------|-------------| +| `kind` | `"model_request"`, `"model_response"`, or `"tool_end"`. | +| `step` | Turn counter from the runner (starts at 1). | +| `ts` | Unix timestamp (float) when the event was recorded. | + +### `model_request` extra fields + +| Field | Description | +|-------|-------------| +| `request` | The kwargs passed to `client.chat.completions.create`, serialized. | + +### `model_response` extra fields + +| Field | Description | +|-------|-------------| +| `response` | The `ChatCompletion` object serialized via `model_dump(mode="json")`. | + +### `tool_end` extra fields + +| Field | Description | +|-------|-------------| +| `name` | Tool function name. | +| `tool_call_id` | ID from the model's tool call request. | +| `arguments` | Raw JSON string of arguments the model passed. | +| `result` | Return value of the tool function. | +| `error` | Present only if the tool raised; contains the error message string. 
| + +--- + +## Built-in redactors + +```python +from dedalus_labs.lib.replay import redact_emails, redact_bearer_tokens, redact_api_keys +``` + +Each redactor walks the event dict recursively and replaces matching strings: + +| Redactor | Pattern replaced | Replacement | +|----------|-----------------|-------------| +| `redact_emails` | `user@example.com` style | `[REDACTED]` | +| `redact_bearer_tokens` | `Bearer` followed by a token, in any string value | `Bearer [REDACTED]` | +| `redact_api_keys` | `sk-...`, `dsk-...`, `pk-...`, `api_...` patterns | `[REDACTED]` | + +Compose them: + +```python +def redact(event): + event = redact_emails(event) + event = redact_bearer_tokens(event) + event = redact_api_keys(event) + return event +``` + +--- + +## Replaying a trace + +`Replayer` reads a `trace.json` and re-runs the recorded conversation +through the production `DedalusRunner` - no API calls, no MCP traffic. + +```python +from dedalus_labs.lib.replay import Replayer + +result = Replayer.from_file("trace.json").run() +print(result.final_output) +``` + +Internally, `Replayer` injects a fake client whose `chat.completions.create()` +serves the recorded `ChatCompletion` objects in order, and substitutes each +local tool with a stub that returns the recorded result. The runner walks +its normal step loop; nothing is mocked except the two outward seams. + +### `Replayer.run(...)` parameters + +| Parameter | Type | Description | +|-----------|------|-------------| +| `swap_tool` | `dict[str, Callable] \| None` | Map of tool name to callable. Named tools run your function instead of the recorded stub. Useful for A/B-testing a fix. | +| `swap_client` | `Dedalus \| None` | A live client. Routes model calls to the real API using the recorded messages and tools as context. 
| + +```python +# A/B-test a tool fix against the same recorded conversation +def better_add(a: int, b: int) -> int: + return a + b + +Replayer.from_file("trace.json").run(swap_tool={"add": better_add}) + +# Run the recorded conversation against a real model +from dedalus_labs import Dedalus +Replayer.from_file("trace.json").run(swap_client=Dedalus()) +``` + +### Drift detection + +Replay fails loudly when the recorded behavior and current code paths +diverge: + +- **More model calls than recorded** - the fake client raises a `RuntimeError` + pointing at `swap_client=` as the bridge. +- **More tool calls for a name than recorded** - the synthetic tool raises + pointing at `swap_tool={name: ...}`. +- **Unknown `format_version`** - `from_file` / `from_dict` raises `ValueError` + during construction. + +A drift error usually means the customer's recorded run hits a code path +that no longer exists locally. That is exactly the bug an FDE wants to +surface, not silently swallow. + +--- + +## Out of scope (follow-up issues) + +The following are intentional non-goals for v1. File a new issue if you need one: + +- **Streaming recording / replay** — `_execute_streaming_*` paths are not instrumented. +- **Per-tool start events** — `tool_start` events with timing inside parallel batches. +- **Cloud upload / hosted viewer** — traces are local-only. +- **OpenTelemetry export** — the event format is not OTel-compatible today. +- **Trace diffing** — comparing two trace files for regression testing. +- **Schema migration** — tooling to upgrade `format_version` 1.0 traces to future versions. diff --git a/examples/replay/01_record.py b/examples/replay/01_record.py new file mode 100644 index 0000000..2c9322f --- /dev/null +++ b/examples/replay/01_record.py @@ -0,0 +1,48 @@ +"""Record an agent run to trace.json. + +Usage: + DEDALUS_API_KEY= python examples/replay/01_record.py + +The script asks the model to add two numbers using a local Python tool. 
def main() -> None:
    """Record one tool-calling run to ``trace.json`` and print a summary.

    Builds a ``Dedalus`` client and a ``DedalusRunner``, runs a single prompt
    with the local ``add`` tool while a ``Recorder`` captures tool and model
    events, then reads the written trace back and prints the answer, the
    trace size, and the ordered list of event kinds.
    """
    client = Dedalus()
    runner = DedalusRunner(client)
    trace_path = Path("trace.json")

    # The Recorder writes the file when the ``with`` block exits.
    with Recorder(trace_path) as rec:
        result = runner.run(
            model="openai/gpt-5-nano",
            input="What is 3 + 4? You must call the add tool.",
            tools=[add],
            on_tool_event=rec.on_tool,
            on_model_event=rec.on_model,
        )

    print(f"Answer : {result.final_output}")
    print(f"Trace : {trace_path} ({trace_path.stat().st_size} bytes)")

    # The trace is written as UTF-8 with non-ASCII characters left unescaped,
    # so decode it explicitly instead of relying on the platform's locale
    # encoding (which is not UTF-8 on every system).
    trace = json.loads(trace_path.read_text(encoding="utf-8"))
    kinds = [e["kind"] for e in trace["events"]]
    print(f"Events : {kinds}")
def main() -> None:
    """Replay a recorded trace and print its final output.

    The trace path is the first command-line argument, defaulting to
    ``trace.json``. If the file does not exist, a hint is printed and the
    process exits with status 1.
    """
    cli_args = sys.argv[1:]
    trace_file = Path(cli_args[0] if cli_args else "trace.json")

    if trace_file.exists():
        outcome = Replayer.from_file(trace_file).run()
        print(f"Replayed from : {trace_file}")
        print(f"Final output : {outcome.final_output}")
        return

    # Nothing to replay: tell the user how to produce a trace, then fail.
    print(f"Trace not found: {trace_file}")
    print("Run examples/replay/01_record.py first to record a trace.")
    sys.exit(1)
def main() -> None:
    """Record the multi-tool, multi-model run to ``trace_multi.json``.

    Runs ``PROMPT`` with the local ``add`` and ``multiply`` tools against the
    two-model list while a ``Recorder`` captures tool and model events, then
    prints the final output, the trace size, the ordered event kinds, and the
    names of the tools that produced ``tool_end`` events.
    """
    client = Dedalus()
    runner = DedalusRunner(client)
    trace_path = Path("trace_multi.json")

    # The Recorder writes the file when the ``with`` block exits.
    with Recorder(trace_path) as rec:
        result = runner.run(
            model=["openai/gpt-5-nano", "anthropic/claude-sonnet-4-5"],
            input=PROMPT,
            tools=[add, multiply],
            on_tool_event=rec.on_tool,
            on_model_event=rec.on_model,
        )

    print("=" * 60)
    print("FINAL OUTPUT")
    print("=" * 60)
    print(result.final_output)
    print()
    print(f"Trace : {trace_path} ({trace_path.stat().st_size} bytes)")

    # The trace is UTF-8 with non-ASCII characters left unescaped (model
    # output such as a poem routinely contains them), so decode explicitly
    # rather than with the platform's locale encoding.
    trace = json.loads(trace_path.read_text(encoding="utf-8"))
    events = trace["events"]
    kinds = [e["kind"] for e in events]
    print(f"Events: {kinds}")
    print(f"Tools called: {[e['name'] for e in events if e['kind'] == 'tool_end']}")
+``` + +See [`docs/replay.md`](../../docs/replay.md) for the full API reference, +trace format, and privacy model. diff --git a/src/dedalus_labs/lib/replay/__init__.py b/src/dedalus_labs/lib/replay/__init__.py new file mode 100644 index 0000000..31820ce --- /dev/null +++ b/src/dedalus_labs/lib/replay/__init__.py @@ -0,0 +1,48 @@ +"""Record and replay agent runs via a local JSON trace file. + +Record:: + + from dedalus_labs.lib.replay import Recorder + + with Recorder("trace.json") as rec: + runner.run( + model="openai/gpt-5-nano", + input="...", + tools=[add], + on_tool_event=rec.on_tool, + on_model_event=rec.on_model, + ) + +Replay:: + + from dedalus_labs.lib.replay import Replayer + + result = Replayer.from_file("trace.json").run() + +The trace file is local-only. See ``docs/replay.md`` for the privacy model, +the trace format, and how to compose redactors. +""" + +from ._events import ( + TOOL_END, + MODEL_REQUEST, + FORMAT_VERSION, + MODEL_RESPONSE, + build_envelope, +) +from ._redact import redact_emails, redact_api_keys, redact_bearer_tokens +from ._recorder import Recorder +from ._replayer import Replayer + +__all__ = [ + "FORMAT_VERSION", + "MODEL_REQUEST", + "MODEL_RESPONSE", + "TOOL_END", + "Recorder", + "Replayer", + "build_envelope", + "redact_api_keys", + "redact_bearer_tokens", + "redact_emails", +] diff --git a/src/dedalus_labs/lib/replay/_events.py b/src/dedalus_labs/lib/replay/_events.py new file mode 100644 index 0000000..752d438 --- /dev/null +++ b/src/dedalus_labs/lib/replay/_events.py @@ -0,0 +1,43 @@ +"""Trace format constants and envelope construction. + +The trace format is intentionally simple JSON. See ``docs/replay.md`` for the +full schema and the reasoning behind the choices. +""" + +from __future__ import annotations + +from typing import Any, Dict, List +from datetime import datetime, timezone + +from ..._version import __version__ + +# Bump on backwards-incompatible trace format changes. 
+# Readers should reject unknown major versions; future minor/patch versions +# may add fields and remain forward-compatible. +FORMAT_VERSION = "1.0" + +# Event kinds the runner emits. Constants exposed so callers can filter or +# assert on event types without stringly-typed comparisons. +MODEL_REQUEST = "model_request" +MODEL_RESPONSE = "model_response" +TOOL_END = "tool_end" + + +def build_envelope(events: List[Dict[str, Any]], metadata: Dict[str, Any]) -> Dict[str, Any]: + """Wrap a list of events in the v1 trace envelope. + + The envelope is the structure that lands in the trace JSON file. It + carries enough metadata for a reader to know what SDK version produced + the trace and when, plus any user-supplied context (ticket IDs, etc.). + """ + return { + "format_version": FORMAT_VERSION, + "sdk_version": __version__, + "recorded_at": _now_iso(), + "metadata": dict(metadata), + "events": list(events), + } + + +def _now_iso() -> str: + return datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ") diff --git a/src/dedalus_labs/lib/replay/_fake_client.py b/src/dedalus_labs/lib/replay/_fake_client.py new file mode 100644 index 0000000..df391d7 --- /dev/null +++ b/src/dedalus_labs/lib/replay/_fake_client.py @@ -0,0 +1,34 @@ +from __future__ import annotations + +from typing import Any + +from ...types.chat import ChatCompletion + + +class _FakeCompletions: + def __init__(self, responses: list[ChatCompletion]) -> None: + self._queue = list(responses) + + def create(self, **kwargs: Any) -> ChatCompletion: # noqa: ARG002 + if not self._queue: + raise RuntimeError( + "replay drift: runner requested more model responses than the trace " + "recorded. Pass swap_client=Dedalus() to continue with a real model." + ) + return self._queue.pop(0) + + +class _FakeChat: + def __init__(self, responses: list[ChatCompletion]) -> None: + self.completions = _FakeCompletions(responses) + + +class _FakeClient: + """Minimal client shim that serves recorded ChatCompletion objects. 
class Recorder:
    """Capture runner events to a JSON file for later inspection or replay.

    Pass the bound methods ``on_tool`` and ``on_model`` as runner callbacks.
    Use the Recorder as a context manager so the file is written on exit::

        with Recorder("trace.json") as rec:
            runner.run(
                model="openai/gpt-5-nano",
                input="What is 3 + 4?",
                tools=[add],
                on_tool_event=rec.on_tool,
                on_model_event=rec.on_model,
            )

    The file lives only on local disk. Nothing is sent over the network.

    If ``redact`` is provided, it runs on every event before the event is
    stored in memory — raw values never live in the Recorder. The function
    should accept and return an event dict. If ``redact`` raises, the event's
    payload is discarded: only ``kind``, ``step`` and ``ts`` are kept, plus
    ``_redaction_failed: True`` so the failure is visible in the trace.

    ``metadata`` is embedded in the trace envelope verbatim. Use it for
    context like ticket IDs or customer IDs that aren't part of the agent
    interaction itself.
    """

    # Keys kept when redaction fails. They describe the event's position in
    # the run, not its content, so they are safe to store un-redacted.
    _STRUCTURAL_KEYS = ("kind", "step", "ts")

    def __init__(
        self,
        path: Union[str, Path],
        *,
        redact: Optional[RedactFn] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> None:
        self._path = Path(path)
        self._redact = redact
        # Copy so later changes to the caller's dict do not alter the trace.
        self._metadata = dict(metadata or {})
        self._events: List[EventDict] = []
        self._closed = False

    # -- runner callback targets --------------------------------------------

    def on_tool(self, event: EventDict) -> None:
        """Callback for ``runner.run(on_tool_event=...)``."""
        self._record(event)

    def on_model(self, event: EventDict) -> None:
        """Callback for ``runner.run(on_model_event=...)``."""
        self._record(event)

    # -- internals ----------------------------------------------------------

    def _record(self, event: EventDict) -> None:
        """Timestamp ``event``, redact it if configured, and append it."""
        stamped: EventDict = {**event, "ts": time.time()}
        if self._redact is None:
            self._events.append(stamped)
            return
        try:
            stored = self._redact(stamped)
        except Exception:
            # Redaction must never break the run, but a failed redaction must
            # not leak the raw payload into the trace either. Keep only the
            # structural keys and flag the event so the operator can spot
            # the failure by reading the trace.
            stored = {k: stamped[k] for k in self._STRUCTURAL_KEYS if k in stamped}
            stored["_redaction_failed"] = True
        self._events.append(stored)

    # -- output -------------------------------------------------------------

    def save(self) -> None:
        """Write the trace JSON to disk.

        Only the first call writes; later calls are no-ops, so events recorded
        after the first ``save()`` are not persisted.
        """
        if self._closed:
            return
        envelope = build_envelope(self._events, self._metadata)
        # Pretty-print so the trace file is readable by hand. Reuses the
        # SDK's custom encoder for datetime and Pydantic support.
        self._path.write_text(
            json.dumps(envelope, cls=_CustomEncoder, indent=2, ensure_ascii=False, allow_nan=False),
            encoding="utf-8",
        )
        self._closed = True

    @property
    def events(self) -> List[EventDict]:
        """Captured events (a copy of the list, for inspection in tests)."""
        return list(self._events)

    # -- context manager ----------------------------------------------------

    def __enter__(self) -> "Recorder":
        return self

    def __exit__(self, *exc: Any) -> None:
        # Save unconditionally so a trace exists even when the run raised.
        self.save()
+""" + +from __future__ import annotations + +import re +from typing import Any, Callable + +_EMAIL_RE = re.compile(r"[\w.+-]+@[\w-]+\.[\w.-]+") +_BEARER_RE = re.compile(r"\bBearer\s+[\w.\-+/=]+", re.IGNORECASE) +_API_KEY_RE = re.compile(r"\b(?:sk|dsk|pk|api)[-_][\w-]{10,}\b") + +_PLACEHOLDER = "[REDACTED]" + + +def redact_emails(event: dict[str, Any]) -> dict[str, Any]: + """Replace email-like substrings with ``[REDACTED]``.""" + return _walk(event, lambda s: _EMAIL_RE.sub(_PLACEHOLDER, s)) + + +def redact_bearer_tokens(event: dict[str, Any]) -> dict[str, Any]: + """Replace ``Bearer `` substrings with ``Bearer [REDACTED]``.""" + return _walk(event, lambda s: _BEARER_RE.sub(f"Bearer {_PLACEHOLDER}", s)) + + +def redact_api_keys(event: dict[str, Any]) -> dict[str, Any]: + """Replace common API key shapes (``sk-...``, ``dsk-...``, ``pk-...``, ``api_...``).""" + return _walk(event, lambda s: _API_KEY_RE.sub(_PLACEHOLDER, s)) + + +def _walk(obj: Any, fn: Callable[[str], str]) -> Any: + """Recursively apply ``fn`` to every string inside a JSON-ish structure.""" + if isinstance(obj, str): + return fn(obj) + if isinstance(obj, dict): + return {k: _walk(v, fn) for k, v in obj.items()} + if isinstance(obj, list): + return [_walk(v, fn) for v in obj] + if isinstance(obj, tuple): + return tuple(_walk(v, fn) for v in obj) + return obj diff --git a/src/dedalus_labs/lib/replay/_replayer.py b/src/dedalus_labs/lib/replay/_replayer.py new file mode 100644 index 0000000..ef6dbf4 --- /dev/null +++ b/src/dedalus_labs/lib/replay/_replayer.py @@ -0,0 +1,153 @@ +from __future__ import annotations + +import json +from typing import Any, Callable +from pathlib import Path + +from ..runner import DedalusRunner +from ._events import FORMAT_VERSION +from ...types.chat import ChatCompletion +from ._fake_client import _FakeClient + + +class Replayer: + """Re-run an agent conversation from a recorded trace. + + Reuses the production DedalusRunner end-to-end. 
Only the model client + and local tool functions are replaced with recorded-data stubs — policy, + message building, scheduling, and MCP composition all run as in + production. + + Usage:: + + result = Replayer.from_file("trace.json").run() + print(result.final_output) + + # Substitute a fixed tool implementation + result = Replayer.from_file("trace.json").run( + swap_tool={"add": lambda a, b: a + b}, + ) + + # Route through a real model instead + from dedalus_labs import Dedalus + result = Replayer.from_file("trace.json").run( + swap_client=Dedalus(), + ) + """ + + def __init__(self, trace: dict[str, Any]) -> None: + self._trace = trace + self._validate() + + @classmethod + def from_file(cls, path: str | Path) -> "Replayer": + return cls(json.loads(Path(path).read_text(encoding="utf-8"))) + + @classmethod + def from_dict(cls, trace: dict[str, Any]) -> "Replayer": + return cls(trace) + + def _validate(self) -> None: + v = self._trace.get("format_version") + if v != FORMAT_VERSION: + raise ValueError( + f"unsupported trace format_version={v!r} " + f"(this Replayer reads {FORMAT_VERSION!r})" + ) + if not isinstance(self._trace.get("events"), list): + raise ValueError("trace is missing an `events` list") + + def run( + self, + *, + swap_tool: dict[str, Callable[..., Any]] | None = None, + swap_client: Any = None, + ) -> Any: + """Replay the recorded conversation. + + Parameters + ---------- + swap_tool: + Map of tool name -> callable. Named tools run your callable + instead of the recorded stub. Useful for A/B-testing a fix. + swap_client: + A real Dedalus client. Routes model calls to the live API + using the recorded messages and tools as context. 
+ """ + events = self._trace["events"] + + first_req = next( + (e for e in events if e["kind"] == "model_request"), None + ) + if first_req is None: + raise ValueError("trace contains no model_request events") + + req = first_req["request"] + + responses = [ + ChatCompletion.model_validate(e["response"]) + for e in events + if e["kind"] == "model_response" + ] + + recorded_tool_ends: dict[str, list[dict[str, Any]]] = {} + for e in events: + if e["kind"] == "tool_end": + recorded_tool_ends.setdefault(e["name"], []).append(e) + + # Union both sources so MCP-only tools (no tool_end events) are still surfaced + tool_names = set(recorded_tool_ends) | { + t["function"]["name"] + for t in (req.get("tools") or []) + } + + swap_tool = swap_tool or {} + tools: list[Callable[..., Any]] = [] + for name in tool_names: + if name in swap_tool: + fn = swap_tool[name] + if fn.__name__ != name: + # Runner dispatches by __name__; rename without mutating caller's fn + wrapped = lambda *a, _fn=fn, **kw: _fn(*a, **kw) + wrapped.__name__ = name + fn = wrapped + tools.append(fn) + else: + tools.append(_make_replay_tool(name, recorded_tool_ends.get(name, []))) + + client = swap_client or _FakeClient(responses) + runner = DedalusRunner(client) + + messages = req.get("messages") or [] + if not messages: + raise ValueError("trace first model_request has no messages") + + return runner.run( + model=req["model"], + messages=messages, + tools=tools or None, + mcp_servers=req.get("mcp_servers") or None, + ) + + +# --------------------------------------------------------------------------- +# Helpers + + +def _make_replay_tool(name: str, recorded_calls: list[dict[str, Any]]) -> Callable[..., Any]: + """Return a callable that pops the next recorded result for `name`.""" + queue = list(recorded_calls) + + def replay_fn(**kwargs: Any) -> Any: # noqa: ARG001 + if not queue: + raise RuntimeError( + f"replay drift: model called tool {name!r} more times than " + f"recorded. 
Pass swap_tool={{'{name}': }} to bridge." + ) + ev = queue.pop(0) + if ev.get("error"): + raise RuntimeError(ev["error"]) + return ev["result"] + + replay_fn.__name__ = name + return replay_fn diff --git a/src/dedalus_labs/lib/runner/_scheduler.py b/src/dedalus_labs/lib/runner/_scheduler.py index fe94b3a..a51b346 100644 --- a/src/dedalus_labs/lib/runner/_scheduler.py +++ b/src/dedalus_labs/lib/runner/_scheduler.py @@ -210,13 +210,13 @@ async def _run_one_async( try: result = await tool_handler.exec(fn_name, fn_args) - tool_results.append({"name": fn_name, "result": result, "step": step}) + tool_results.append({"name": fn_name, "result": result, "step": step, "tool_call_id": call_id}) tools_called.append(fn_name) messages.append({"role": "tool", "tool_call_id": call_id, "content": str(result)}) if verbose: print(f" Tool {fn_name}: {str(result)[:50]}...") # noqa: T201 except Exception as e: - tool_results.append({"error": str(e), "name": fn_name, "step": step}) + tool_results.append({"error": str(e), "name": fn_name, "step": step, "tool_call_id": call_id}) messages.append({"role": "tool", "tool_call_id": call_id, "content": f"Error: {e}"}) if verbose: print(f" Tool {fn_name} failed: {e}") # noqa: T201 @@ -237,11 +237,11 @@ def _run_one_sync( try: result = tool_handler.exec_sync(fn_name, fn_args) - tool_results.append({"name": fn_name, "result": result, "step": step}) + tool_results.append({"name": fn_name, "result": result, "step": step, "tool_call_id": call_id}) tools_called.append(fn_name) messages.append({"role": "tool", "tool_call_id": call_id, "content": str(result)}) except Exception as e: - tool_results.append({"error": str(e), "name": fn_name, "step": step}) + tool_results.append({"error": str(e), "name": fn_name, "step": step, "tool_call_id": call_id}) messages.append({"role": "tool", "tool_call_id": call_id, "content": f"Error: {e}"}) diff --git a/src/dedalus_labs/lib/runner/core.py b/src/dedalus_labs/lib/runner/core.py index a295f09..741cb0b 100644 --- 
def _emit_tool_ends(
    callback: Callable[[Dict[str, JsonValue]], None] | None,
    tool_calls: list,
    tool_results: list,
    prev_count: int,
    step: int,
) -> None:
    """Emit one `tool_end` event per newly-appended tool result.

    Only entries of ``tool_results`` from index ``prev_count`` onward are
    reported. Each result is paired with its originating tool call by the
    ``tool_call_id`` stored on the result entry; if that is missing, the
    first not-yet-paired call with the same function name is used instead.
    When a call is found, its ``id`` and raw ``arguments`` are copied onto
    the event.

    Args:
        callback: Observer to notify; when ``None`` nothing is emitted.
        tool_calls: Tool-call dicts extracted from the model response.
        tool_results: Accumulated result entries for the whole run.
        prev_count: Length of ``tool_results`` before this step's tools ran.
        step: Step number stamped on every emitted event.
    """
    if callback is None:
        return
    calls_by_id = {c.get("id"): c for c in tool_calls if c.get("id")}
    remaining = list(tool_calls)
    for tr in tool_results[prev_count:]:
        # Treat a non-dict entry as an empty record so this observation hook
        # can never raise into the agent loop.
        entry = tr if isinstance(tr, dict) else {}
        name = entry.get("name")
        tr_id = entry.get("tool_call_id")
        matched = calls_by_id.get(tr_id) if tr_id else None
        if matched is None:
            # Defensive fallback for any path that hasn't stamped tool_call_id.
            matched = next((c for c in remaining if c.get("function", {}).get("name") == name), None)
        if matched is not None and matched in remaining:
            # Consume the call however it was matched, so a later fallback
            # cannot pair another result with a call already reported.
            remaining.remove(matched)
        event: Dict[str, Any] = {"kind": "tool_end", "step": step, "name": name, "result": entry.get("result")}
        if entry.get("error") is not None:
            event["error"] = entry["error"]
        if matched is not None:
            event["tool_call_id"] = matched.get("id")
            event["arguments"] = matched.get("function", {}).get("arguments")
        _emit(callback, event)
@@ -145,6 +224,7 @@ class _ExecutionConfig: verbose: bool = False debug: bool = False on_tool_event: Callable[[Dict[str, JsonValue]], None] | None = None + on_model_event: Callable[[Dict[str, JsonValue]], None] | None = None return_intent: bool = False policy: PolicyInput = None available_models: list[str] = field(default_factory=list) @@ -293,6 +373,7 @@ def run( verbose: bool | None = None, debug: bool | None = None, on_tool_event: Callable[[Dict[str, JsonValue]], None] | None = None, + on_model_event: Callable[[Dict[str, JsonValue]], None] | None = None, return_intent: bool = False, policy: PolicyInput = None, available_models: list[str] | None = None, @@ -440,6 +521,7 @@ def run( verbose=verbose if verbose is not None else self.verbose, debug=debug or False, on_tool_event=on_tool_event, + on_model_event=on_model_event, return_intent=return_intent, policy=policy, available_models=available_models or [], @@ -539,14 +621,17 @@ async def _execute_turns_async( # Make model call current_messages = self._build_messages(messages, policy_result["prepend"], policy_result["append"]) - response = await self.client.chat.completions.create( - model=policy_result["model"], - messages=current_messages, - tools=tool_handler.schemas() or None, - mcp_servers=policy_result["mcp_servers"], - credentials=exec_config.credentials, + request_kwargs = { + "model": policy_result["model"], + "messages": current_messages, + "tools": tool_handler.schemas() or None, + "mcp_servers": policy_result["mcp_servers"], + "credentials": exec_config.credentials, **{**self._mk_kwargs(model_config), **policy_result["model_kwargs"]}, - ) + } + _emit(exec_config.on_model_event, {"kind": "model_request", "step": steps, "request": _to_jsonable(_redact_request_for_event(request_kwargs))}) + response = await self.client.chat.completions.create(**request_kwargs) + _emit(exec_config.on_model_event, {"kind": "model_response", "step": steps, "response": _to_jsonable(response)}) if exec_config.verbose: 
actual_model = policy_result["model"] @@ -602,6 +687,7 @@ async def _execute_turns_async( print(f" Extracted {len(tool_calls)} tool calls") for tc in tool_calls: print(f" - {tc.get('function', {}).get('name', '?')} (id: {tc.get('id', '?')})") + prev_tool_count = len(tool_results) await self._execute_tool_calls( tool_calls, tool_handler, @@ -611,6 +697,7 @@ async def _execute_turns_async( steps, verbose=exec_config.verbose, ) + _emit_tool_ends(exec_config.on_tool_event, tool_calls, tool_results, prev_tool_count, steps) # Extract MCP tool executions from the last response mcp_results = _extract_mcp_results(response) @@ -847,14 +934,17 @@ def _execute_turns_sync( else: print(f" API called with single model: {actual_model}") - response = self.client.chat.completions.create( - model=policy_result["model"], - messages=current_messages, - tools=tool_handler.schemas() or None, - mcp_servers=policy_result["mcp_servers"], - credentials=exec_config.credentials, + request_kwargs = { + "model": policy_result["model"], + "messages": current_messages, + "tools": tool_handler.schemas() or None, + "mcp_servers": policy_result["mcp_servers"], + "credentials": exec_config.credentials, **{**self._mk_kwargs(model_config), **policy_result["model_kwargs"]}, - ) + } + _emit(exec_config.on_model_event, {"kind": "model_request", "step": steps, "request": _to_jsonable(_redact_request_for_event(request_kwargs))}) + response = self.client.chat.completions.create(**request_kwargs) + _emit(exec_config.on_model_event, {"kind": "model_response", "step": steps, "response": _to_jsonable(response)}) if exec_config.verbose: print(f" Response received (server says model: {getattr(response, 'model', 'unknown')})") @@ -885,7 +975,9 @@ def _execute_turns_sync( # Execute tools tool_calls = self._extract_tool_calls(response.choices[0]) + prev_tool_count = len(tool_results) self._execute_tool_calls_sync(tool_calls, tool_handler, messages, tool_results, tools_called, steps) + 
_emit_tool_ends(exec_config.on_tool_event, tool_calls, tool_results, prev_tool_count, steps) # Extract MCP tool executions from the last response mcp_results = _extract_mcp_results(response) diff --git a/tests/lib/test_replay_recorder.py b/tests/lib/test_replay_recorder.py new file mode 100644 index 0000000..b0c7c31 --- /dev/null +++ b/tests/lib/test_replay_recorder.py @@ -0,0 +1,118 @@ +"""Unit tests for `dedalus_labs.lib.replay.Recorder`. + +These tests do not involve the runner — they feed events directly to the +recorder and inspect the resulting JSON file. +""" + +from __future__ import annotations + +import json +from pathlib import Path + +import pytest + +from dedalus_labs.lib.replay import ( + FORMAT_VERSION, + MODEL_REQUEST, + MODEL_RESPONSE, + TOOL_END, + Recorder, +) + + +def _read_trace(path: Path) -> dict: + return json.loads(path.read_text(encoding="utf-8")) + + +def test_recorder_captures_event_order(tmp_path: Path) -> None: + """Events are written in the order they were observed and stamped.""" + path = tmp_path / "trace.json" + with Recorder(path) as rec: + rec.on_model({"kind": MODEL_REQUEST, "step": 1, "request": {"x": 1}}) + rec.on_tool({"kind": TOOL_END, "step": 1, "name": "add", "result": 7}) + rec.on_model({"kind": MODEL_RESPONSE, "step": 1, "response": {"id": "abc"}}) + + trace = _read_trace(path) + assert [e["kind"] for e in trace["events"]] == [ + MODEL_REQUEST, + TOOL_END, + MODEL_RESPONSE, + ] + for event in trace["events"]: + assert isinstance(event["ts"], float) + + +def test_recorder_redacts_before_storing(tmp_path: Path) -> None: + """A redactor that strips a known string must keep raw values out of the trace.""" + path = tmp_path / "trace.json" + + def redact(event: dict) -> dict: + return {**event, "request": {"messages": ""}} + + with Recorder(path, redact=redact) as rec: + rec.on_model( + { + "kind": MODEL_REQUEST, + "step": 1, + "request": {"messages": "SECRET-CUSTOMER-DATA"}, + } + ) + + raw = path.read_text(encoding="utf-8") 
+ assert "SECRET-CUSTOMER-DATA" not in raw + trace = _read_trace(path) + assert trace["events"][0]["request"]["messages"] == "" + + +def test_recorder_metadata_and_envelope_fields(tmp_path: Path) -> None: + """User metadata round-trips and the envelope carries format/sdk version.""" + path = tmp_path / "trace.json" + with Recorder(path, metadata={"ticket": "INC-42", "customer": "acme"}) as rec: + rec.on_model({"kind": MODEL_REQUEST, "step": 1, "request": {}}) + + trace = _read_trace(path) + assert trace["format_version"] == FORMAT_VERSION + assert trace["sdk_version"] # truthy; exact value asserted elsewhere + assert trace["metadata"] == {"ticket": "INC-42", "customer": "acme"} + assert trace["recorded_at"].endswith("Z") + + +def test_recorder_context_manager_writes_on_exit(tmp_path: Path) -> None: + """The trace file is created exactly when the `with` block exits.""" + path = tmp_path / "trace.json" + rec = Recorder(path) + rec.on_model({"kind": MODEL_REQUEST, "step": 1, "request": {}}) + assert not path.exists(), "file should not exist before save()" + + with rec: + pass # __exit__ calls save() + + assert path.exists() + assert _read_trace(path)["events"][0]["kind"] == MODEL_REQUEST + + +def test_recorder_swallows_redactor_failure(tmp_path: Path) -> None: + """A redactor that raises must not break recording. 
The event is kept and marked.""" + path = tmp_path / "trace.json" + + def broken(event: dict) -> dict: + raise RuntimeError("boom") + + with Recorder(path, redact=broken) as rec: + rec.on_model({"kind": MODEL_REQUEST, "step": 1, "request": {"k": "v"}}) + + trace = _read_trace(path) + assert trace["events"][0]["_redaction_failed"] is True + assert trace["events"][0]["request"] == {"k": "v"} + + +def test_recorder_save_is_idempotent(tmp_path: Path) -> None: + """Calling save() twice (or save() then __exit__) writes exactly once.""" + path = tmp_path / "trace.json" + rec = Recorder(path) + rec.on_model({"kind": MODEL_REQUEST, "step": 1, "request": {}}) + rec.save() + first_mtime = path.stat().st_mtime_ns + + rec.save() # idempotent + assert path.stat().st_mtime_ns == first_mtime diff --git a/tests/lib/test_replay_replayer.py b/tests/lib/test_replay_replayer.py new file mode 100644 index 0000000..adfd816 --- /dev/null +++ b/tests/lib/test_replay_replayer.py @@ -0,0 +1,228 @@ +"""Tests for `dedalus_labs.lib.replay.Replayer` and `_fake_client`.""" + +from __future__ import annotations + +import json +from typing import Any, Dict + +import pytest + +from dedalus_labs.lib.replay import FORMAT_VERSION, Replayer +from dedalus_labs.lib.replay._fake_client import _FakeClient +from dedalus_labs.types.chat import ChatCompletion + + +def _make_completion(content: str = "done") -> ChatCompletion: + return ChatCompletion.model_validate( + { + "id": "chatcmpl-test", + "object": "chat.completion", + "created": 0, + "model": "openai/gpt-5-nano", + "choices": [ + { + "index": 0, + "message": {"role": "assistant", "content": content}, + "finish_reason": "stop", + } + ], + "usage": {"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}, + } + ) + + +class TestFakeClient: + def test_create_pops_responses_in_order(self) -> None: + r1 = _make_completion("first") + r2 = _make_completion("second") + client = _FakeClient([r1, r2]) + assert client.chat.completions.create() is r1 + 
assert client.chat.completions.create() is r2 + + def test_create_raises_on_empty_queue(self) -> None: + client = _FakeClient([]) + with pytest.raises(RuntimeError, match="replay drift"): + client.chat.completions.create() + + def test_create_ignores_kwargs(self) -> None: + r = _make_completion() + client = _FakeClient([r]) + result = client.chat.completions.create(model="anything", messages=[]) + assert result is r + + +# --------------------------------------------------------------------------- +# Helpers shared by Replayer tests + + +def _completion_dict(content: str = "done", tool_calls: list[Dict[str, Any]] | None = None) -> Dict[str, Any]: + message: Dict[str, Any] = {"role": "assistant"} + if content: + message["content"] = content + if tool_calls is not None: + message["tool_calls"] = tool_calls + finish_reason = "tool_calls" if tool_calls else "stop" + return { + "id": "chatcmpl-test", + "object": "chat.completion", + "created": 0, + "model": "openai/gpt-5-nano", + "choices": [{"index": 0, "message": message, "finish_reason": finish_reason}], + "usage": {"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}, + } + + +def _minimal_trace(events: list[Dict[str, Any]]) -> Dict[str, Any]: + return {"format_version": FORMAT_VERSION, "sdk_version": "0.3.0", "recorded_at": "2026-05-13T00:00:00Z", "metadata": {}, "events": events} + + +def _req_event(step: int = 1, tool_schemas: list[Dict[str, Any]] | None = None) -> Dict[str, Any]: + return { + "kind": "model_request", + "step": step, + "ts": 0.0, + "request": { + "model": ["openai/gpt-5-nano"], + "messages": [{"role": "user", "content": "What is 3 + 4?"}], + "tools": tool_schemas or [], + "mcp_servers": [], + }, + } + + +def _resp_event(step: int = 1, content: str = "done", tool_calls: list[Dict[str, Any]] | None = None) -> Dict[str, Any]: + return {"kind": "model_response", "step": step, "ts": 0.0, "response": _completion_dict(content=content, tool_calls=tool_calls)} + + +def _tool_end_event(name: str, 
result: Any, step: int = 1, call_id: str = "call_1", error: str | None = None) -> Dict[str, Any]: + ev: Dict[str, Any] = { + "kind": "tool_end", + "step": step, + "ts": 0.0, + "name": name, + "tool_call_id": call_id, + "arguments": json.dumps({"a": 3, "b": 4}), + "result": result, + } + if error is not None: + ev["error"] = error + return ev + + +# --------------------------------------------------------------------------- +# Replayer tests + + +class TestReplayer: + def test_unknown_format_version_rejected(self) -> None: + with pytest.raises(ValueError, match="unsupported trace format_version"): + Replayer.from_dict({"format_version": "9.9", "events": []}) + + def test_missing_events_rejected(self) -> None: + with pytest.raises(ValueError, match="missing an `events` list"): + Replayer.from_dict({"format_version": FORMAT_VERSION}) + + def test_round_trip_identity_no_tools(self) -> None: + """Single turn with no tool calls: final output matches the recorded response.""" + trace = _minimal_trace([ + _req_event(), + _resp_event(content="The answer is 7."), + ]) + result = Replayer.from_dict(trace).run() + assert result.final_output == "The answer is 7." + + def test_round_trip_with_tool_call(self) -> None: + """Two-turn trace: model calls a tool, then gives a final answer.""" + tool_call = {"id": "call_1", "type": "function", "function": {"name": "add", "arguments": json.dumps({"a": 3, "b": 4})}} + trace = _minimal_trace([ + _req_event(step=1), + _resp_event(step=1, content="", tool_calls=[tool_call]), + _tool_end_event("add", result=7, step=1), + _req_event(step=2), + _resp_event(step=2, content="The answer is 7."), + ]) + result = Replayer.from_dict(trace).run() + assert result.final_output == "The answer is 7." 
+ + def test_swap_tool_runs_real_function(self) -> None: + """swap_tool replaces the recorded stub with the provided callable.""" + called_with: list[Dict[str, Any]] = [] + + def real_add(**kwargs: Any) -> int: + called_with.append(kwargs) + return 100 + + tool_call = {"id": "call_1", "type": "function", "function": {"name": "add", "arguments": json.dumps({"a": 3, "b": 4})}} + trace = _minimal_trace([ + _req_event(step=1), + _resp_event(step=1, content="", tool_calls=[tool_call]), + _tool_end_event("add", result=7, step=1), + _req_event(step=2), + _resp_event(step=2, content="done"), + ]) + Replayer.from_dict(trace).run(swap_tool={"add": real_add}) + assert called_with, "real_add should have been called" + + def test_swap_client_receives_full_recorded_messages(self) -> None: + """swap_client must see all recorded messages (system + user), not just messages[0].""" + captured: list[list[Dict[str, Any]]] = [] + + class CapturingCompletions: + def create(self, **kwargs: Any) -> ChatCompletion: + captured.append(kwargs.get("messages") or []) + return _make_completion("ok") + + class CapturingChat: + completions = CapturingCompletions() + + class CapturingClient: + chat = CapturingChat() + + trace = _minimal_trace([_req_event(), _resp_event(content="ok")]) + # Override the default single-user-message setup with a richer history + trace["events"][0]["request"]["messages"] = [ + {"role": "system", "content": "you are a math tutor"}, + {"role": "user", "content": "What is 3 + 4?"}, + ] + + Replayer.from_dict(trace).run(swap_client=CapturingClient()) + + assert captured, "expected swap_client.create to be called" + first_call_messages = captured[0] + # Both the system and user messages from the trace must be present + roles = [m.get("role") for m in first_call_messages] + assert "system" in roles, f"system message lost during replay: {roles}" + assert "user" in roles, f"user message lost during replay: {roles}" + + def test_swap_client_bypasses_fake_client(self) -> None: + 
"""swap_client routes model calls to the provided object instead.""" + call_count = 0 + + class StubCompletions: + def create(self, **kwargs: Any) -> ChatCompletion: + nonlocal call_count + call_count += 1 + return _make_completion("stub answer") + + class StubChat: + completions = StubCompletions() + + class StubClient: + chat = StubChat() + + trace = _minimal_trace([_req_event(), _resp_event(content="original")]) + result = Replayer.from_dict(trace).run(swap_client=StubClient()) + assert call_count >= 1 + assert result.final_output == "stub answer" + + def test_drift_more_model_calls_than_recorded_raises(self) -> None: + """If the runner needs more model responses than recorded, raise clearly.""" + tool_call = {"id": "call_1", "type": "function", "function": {"name": "add", "arguments": "{}"}} + trace = _minimal_trace([ + _req_event(step=1), + _resp_event(step=1, content="", tool_calls=[tool_call]), + _tool_end_event("add", result=7, step=1), + # deliberately omit the second model_response + ]) + with pytest.raises(RuntimeError, match="replay drift"): + Replayer.from_dict(trace).run() diff --git a/tests/lib/test_replay_runner_integration.py b/tests/lib/test_replay_runner_integration.py new file mode 100644 index 0000000..ae3b863 --- /dev/null +++ b/tests/lib/test_replay_runner_integration.py @@ -0,0 +1,247 @@ +"""Integration tests: `DedalusRunner` emits events to `on_model_event` / `on_tool_event`. + +Uses respx to mock the HTTP boundary so these tests are hermetic and fast. 
+""" + +from __future__ import annotations + +import json +from pathlib import Path +from typing import Any, Dict, List + +import httpx +import pytest +from respx import MockRouter + +from dedalus_labs import Dedalus +from dedalus_labs.lib.replay import ( + FORMAT_VERSION, + MODEL_REQUEST, + MODEL_RESPONSE, + TOOL_END, + Recorder, +) +from dedalus_labs.lib.runner import DedalusRunner + +from ..conftest import base_url + + +def _completion( + *, + content: str | None = None, + tool_calls: List[Dict[str, Any]] | None = None, + finish_reason: str = "stop", +) -> Dict[str, Any]: + """Minimal mock chat-completion response shaped like the API.""" + message: Dict[str, Any] = {"role": "assistant"} + if content is not None: + message["content"] = content + if tool_calls is not None: + message["tool_calls"] = tool_calls + return { + "id": "chatcmpl-test", + "object": "chat.completion", + "created": 0, + "model": "openai/gpt-5-nano", + "choices": [{"index": 0, "message": message, "logprobs": None, "finish_reason": finish_reason}], + "usage": {"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}, + } + + +def _add(a: int, b: int) -> int: + """Add two numbers.""" + return a + b + + +@pytest.mark.respx(base_url=base_url) +def test_runner_emits_model_request_then_response(client: Dedalus, respx_mock: MockRouter) -> None: + """A no-tool response produces exactly one (request, response) pair.""" + respx_mock.post("/v1/chat/completions").mock( + return_value=httpx.Response(200, json=_completion(content="Hello there.")) + ) + + events: List[Dict[str, Any]] = [] + runner = DedalusRunner(client) + runner.run( + model="openai/gpt-5-nano", + input="hi", + on_model_event=events.append, + ) + + kinds = [e["kind"] for e in events] + assert kinds == [MODEL_REQUEST, MODEL_RESPONSE] + assert events[0]["step"] == 1 + assert events[1]["step"] == 1 + assert events[1]["response"]["choices"][0]["message"]["content"] == "Hello there." 
@pytest.mark.respx(base_url=base_url)
def test_runner_emits_tool_end_after_local_tool_runs(client: Dedalus, respx_mock: MockRouter) -> None:
    """A tool-calling turn produces a `tool_end` event with name + result + correlated id/args."""
    tool_call_payload = {
        "id": "call_xyz",
        "type": "function",
        "function": {"name": "_add", "arguments": json.dumps({"a": 3, "b": 4})},
    }
    # First response asks for the tool; second response is the final answer.
    respx_mock.post("/v1/chat/completions").mock(
        side_effect=[
            httpx.Response(200, json=_completion(tool_calls=[tool_call_payload], finish_reason="tool_calls")),
            httpx.Response(200, json=_completion(content="The answer is 7.")),
        ]
    )

    tool_events: List[Dict[str, Any]] = []
    runner = DedalusRunner(client)
    runner.run(
        model="openai/gpt-5-nano",
        input="What is 3 + 4?",
        tools=[_add],
        on_tool_event=tool_events.append,
    )

    assert len(tool_events) == 1
    event = tool_events[0]
    assert event["kind"] == TOOL_END
    assert event["name"] == "_add"
    assert event["result"] == 7
    assert event["tool_call_id"] == "call_xyz"
    # arguments are forwarded as the raw JSON string the model produced
    assert json.loads(event["arguments"]) == {"a": 3, "b": 4}


@pytest.mark.respx(base_url=base_url)
def test_runner_does_not_break_when_callback_raises(client: Dedalus, respx_mock: MockRouter) -> None:
    """A misbehaving callback must not propagate — the run completes normally.

    The mocked run includes a tool-calling turn so that BOTH hooks
    (`on_model_event` and `on_tool_event`) are actually invoked; with a
    no-tool response the tool hook would never fire and its failure path
    would go untested.
    """
    tool_call_payload = {
        "id": "call_boom",
        "type": "function",
        "function": {"name": "_add", "arguments": json.dumps({"a": 3, "b": 4})},
    }
    respx_mock.post("/v1/chat/completions").mock(
        side_effect=[
            httpx.Response(200, json=_completion(tool_calls=[tool_call_payload], finish_reason="tool_calls")),
            httpx.Response(200, json=_completion(content="ok")),
        ]
    )

    seen_kinds: List[str] = []

    def boom(event: Dict[str, Any]) -> None:
        # Record the kind first so the test can prove the hook really ran.
        seen_kinds.append(event["kind"])
        raise RuntimeError("callback intentionally broken")

    runner = DedalusRunner(client)
    result = runner.run(
        model="openai/gpt-5-nano",
        input="hi",
        tools=[_add],
        on_model_event=boom,
        on_tool_event=boom,
    )
    assert result.final_output == "ok"
    # Every hook type fired (and raised) without aborting the run.
    assert MODEL_REQUEST in seen_kinds
    assert MODEL_RESPONSE in seen_kinds
    assert TOOL_END in seen_kinds


@pytest.mark.respx(base_url=base_url)
def test_tool_end_events_correlate_by_id_under_concurrency(
    client: Dedalus, respx_mock: MockRouter
) -> None:
    """Two concurrent calls to the same tool must map each tool_end to the right tool_call_id."""
    import asyncio

    # Two parallel calls to the same tool, distinguishable only by their arguments
    tc_fast = {
        "id": "call_fast",
        "type": "function",
        "function": {"name": "_slow_add", "arguments": json.dumps({"a": 1, "b": 1, "delay": 0.0})},
    }
    tc_slow = {
        "id": "call_slow",
        "type": "function",
        "function": {"name": "_slow_add", "arguments": json.dumps({"a": 10, "b": 10, "delay": 0.05})},
    }
    # The slow call is listed first, so completion order differs from call order.
    respx_mock.post("/v1/chat/completions").mock(
        side_effect=[
            httpx.Response(200, json=_completion(tool_calls=[tc_slow, tc_fast], finish_reason="tool_calls")),
            httpx.Response(200, json=_completion(content="done")),
        ]
    )

    async def _slow_add(a: int, b: int, delay: float = 0.0) -> int:
        await asyncio.sleep(delay)
        return a + b

    tool_events: List[Dict[str, Any]] = []
    runner = DedalusRunner(client)
    runner.run(
        model="openai/gpt-5-nano",
        input="parallel add",
        tools=[_slow_add],
        on_tool_event=tool_events.append,
    )

    assert len(tool_events) == 2
    # Each event must carry the ARGUMENTS that match its tool_call_id, not a swap caused by completion ordering
    by_id = {e["tool_call_id"]: e for e in tool_events}
    assert json.loads(by_id["call_fast"]["arguments"]) == {"a": 1, "b": 1, "delay": 0.0}
    assert json.loads(by_id["call_slow"]["arguments"]) == {"a": 10, "b": 10, "delay": 0.05}
    assert by_id["call_fast"]["result"] == 2
    assert by_id["call_slow"]["result"] == 20


@pytest.mark.respx(base_url=base_url)
def test_model_request_event_omits_credentials(client: Dedalus, respx_mock: MockRouter) -> None:
    """`credentials` must never reach the model_request event payload."""
    respx_mock.post("/v1/chat/completions").mock(
        return_value=httpx.Response(200, json=_completion(content="ok"))
    )

    events: List[Dict[str, Any]] = []
    runner = DedalusRunner(client)
    runner.run(
        model="openai/gpt-5-nano",
        input="hi",
        credentials=[{"provider": "openai", "api_key": "sk-secret-do-not-leak"}],
        on_model_event=events.append,
    )

    request_events = [e for e in events if e["kind"] == MODEL_REQUEST]
    assert request_events, "expected at least one model_request event"
    for ev in request_events:
        assert "credentials" not in ev["request"], (
            f"credentials field leaked into trace event: {ev['request']!r}"
        )
        # Defense in depth: the secret value must not appear anywhere serialized
        assert "sk-secret-do-not-leak" not in json.dumps(ev)


@pytest.mark.respx(base_url=base_url)
def test_record_then_load_produces_valid_trace_format(
    client: Dedalus, respx_mock: MockRouter, tmp_path: Path
) -> None:
    """End-to-end: record a full run via Recorder, reload the trace, validate the envelope."""
    tool_call_payload = {
        "id": "call_xyz",
        "type": "function",
        "function": {"name": "_add", "arguments": json.dumps({"a": 3, "b": 4})},
    }
    respx_mock.post("/v1/chat/completions").mock(
        side_effect=[
            httpx.Response(200, json=_completion(tool_calls=[tool_call_payload], finish_reason="tool_calls")),
            httpx.Response(200, json=_completion(content="The answer is 7.")),
        ]
    )

    trace_path = tmp_path / "trace.json"
    runner = DedalusRunner(client)
    # Leaving the `with` block is what writes the trace file.
    with Recorder(trace_path, metadata={"test": True}) as rec:
        runner.run(
            model="openai/gpt-5-nano",
            input="What is 3 + 4?",
            tools=[_add],
            on_model_event=rec.on_model,
            on_tool_event=rec.on_tool,
        )

    trace = json.loads(trace_path.read_text(encoding="utf-8"))

    # Envelope shape
    assert trace["format_version"] == FORMAT_VERSION
    assert trace["sdk_version"]
    assert trace["recorded_at"].endswith("Z")
    assert trace["metadata"] == {"test": True}

    # Every event has the required keys
    for event in trace["events"]:
        assert "kind" in event and "step" in event and "ts" in event

    kinds = [e["kind"] for e in trace["events"]]
    # Two turns: request → response → tool_end → request → response
    assert kinds == [MODEL_REQUEST, MODEL_RESPONSE, TOOL_END, MODEL_REQUEST, MODEL_RESPONSE]