diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 095e4eb..bb9e406 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -10,6 +10,7 @@ jobs: test: runs-on: ubuntu-latest strategy: + fail-fast: false matrix: python: ["3.10", "3.11", "3.12"] @@ -27,22 +28,25 @@ jobs: pip install -e ".[dev]" - name: Run tests - run: pytest + run: pytest tests/ -v - - name: Run ruff - run: ruff check src/ + - name: Lint (ruff) + run: | + ruff check src/ tests/ + ruff format --check src/ tests/ - - name: Run mypy - run: mypy src/ + - name: Type check (mypy strict) + run: mypy src/nullrun --strict coverage: runs-on: ubuntu-latest + needs: test steps: - uses: actions/checkout@v4 - uses: actions/setup-python@v5 with: python-version: "3.12" - run: pip install -e ".[dev]" - - run: coverage run -m pytest + - run: coverage run -m pytest tests/ -v - uses: codecov/codecov-action@v4 - if: always() \ No newline at end of file + if: always() diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index bf63613..6c310a3 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -3,67 +3,98 @@ name: Publish to PyPI on: push: tags: - - 'v*' # триггер только по тегу: git tag v0.1.0 && git push --tags + - 'v*' + workflow_dispatch: + inputs: + target: + description: 'Publish target' + type: choice + required: true + default: testpypi + options: + - testpypi + - pypi + +# Trusted Publishing: no API tokens in repo or CI. +# Configure once on the PyPI side, see SETUP.md for step-by-step. 
jobs: test: name: Run tests runs-on: ubuntu-latest strategy: + fail-fast: false matrix: python-version: ["3.10", "3.11", "3.12"] - steps: - uses: actions/checkout@v4 - - uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - - name: Install dependencies - run: | - pip install -e ".[dev]" - + run: pip install -e ".[dev]" - name: Run tests run: pytest tests/ -v + - name: Lint (ruff) + run: | + ruff check src/ tests/ + ruff format --check src/ tests/ + - name: Type check (mypy strict) + run: mypy src/nullrun --strict - publish: - name: Build and publish - needs: test # сначала все тесты зелёные — потом публикация + publish-testpypi: + name: Publish to TestPyPI + needs: test + if: github.event_name == 'workflow_dispatch' && inputs.target == 'testpypi' runs-on: ubuntu-latest + environment: + name: testpypi + url: https://test.pypi.org/project/nullrun + permissions: + id-token: write + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: "3.11" + - name: Build package + run: | + pip install hatchling build + python -m build + - name: Check dist + run: | + pip install twine + twine check dist/* + - name: Publish to TestPyPI (Trusted Publishing) + uses: pypa/gh-action-pypi-publish@release/v1 + with: + repository-url: https://test.pypi.org/legacy/ + attestations: false + publish-pypi: + name: Publish to PyPI + needs: test + if: | + (github.event_name == 'push' && startsWith(github.ref, 'refs/tags/v')) || + (github.event_name == 'workflow_dispatch' && inputs.target == 'pypi') + runs-on: ubuntu-latest environment: name: pypi - url: https://pypi.org/p/nullrun-sdk - + url: https://pypi.org/project/nullrun permissions: - id-token: write # для trusted publishing (без токена, рекомендуется PyPI) - + id-token: write steps: - uses: actions/checkout@v4 - - uses: actions/setup-python@v5 with: python-version: "3.11" - - name: Build package run: | pip install hatchling build python -m build - - - name: Check dist 
contents + - name: Check dist run: | pip install twine twine check dist/* - - # Вариант 1: Trusted Publishing (рекомендуется, не нужен токен) - # Настроить на pypi.org: Account → Publishing → Add publisher - # Publisher: GitHub, repo: maltsev-dev/nullrun-sdk, workflow: publish.yml - name: Publish to PyPI (Trusted Publishing) uses: pypa/gh-action-pypi-publish@release/v1 - - # Вариант 2: API токен (раскомментируй если не используешь Trusted Publishing) - # - name: Publish to PyPI (API token) - # uses: pypa/gh-action-pypi-publish@release/v1 - # with: - # password: ${{ secrets.PYPI_API_TOKEN }} diff --git a/CHANGELOG.md b/CHANGELOG.md index f07fbba..c132bb1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,31 +9,149 @@ Versioning: [Semantic Versioning](https://semver.org/spec/v2.0.0.html) ## [Unreleased] +### Removed (0.4.0 deprecations — full removal in 1.0.0) + +- **gRPC transport removed** (`src/nullrun/grpc_transport.py`, `protos/`): The backend + proto is frozen and missing trace/span fields. The `NULLRUN_USE_GRPC` env var + is now a no-op that emits a single WARNING at init. HTTP is the only + supported ingestion path. Affects users who relied on the binary + protobuf+HTTP/2 path — migrate to the HTTP transport. `grpcio` and + `grpcio-tools` removed from `pyproject.toml`. +- **AsyncTransport removed** (`src/nullrun/transport.py:AsyncTransport`): ~600 lines + of duplicate code that was used only in tests. The sync `Transport` works + fine from async event loops via `nullrun.track_llm` / `@nullrun.protect` — + the underlying httpx client + background flush thread is non-blocking. +- **`AdaptivePool` removed** (`src/nullrun/transport.py:AdaptivePool`, `PoolConfig`): + Backpressure pool for the deleted async transport. +- **Process-wide signal handler removed** (`src/nullrun/transport.py:_register_signal_handlers`): + The `Transport.__init__` no longer overwrites the application's global + `SIGTERM`/`SIGINT` handlers.
Callers in long-lived services MUST now + call `transport.stop()` explicitly, use `transport` as a context + manager (`with Transport(...) as t:`), or rely on `weakref.finalize` + for cleanup at process exit. The new `__exit__` method is the + recommended pattern. + +### Fixed + +- **track() double-emit** (P0): The `else` branch of `runtime.track()` + no longer calls `self._transport.track(...)` twice. Previously the + non-gRPC code path produced two `/api/v1/track/batch` events per + call, double-billing customers. +- **Buffer re-binding race** (P0): `Transport._do_flush_locked` no + longer rebinds `self._buffer` to a new list. A new + `_drain_batch()` helper uses in-place slice (`del self._buffer[:]`) + so concurrent `track()` calls holding a reference to the old list + see the post-drain state instead of appending to dead memory. +- **Buffer overflow check** (P0): The CB-OPEN re-queue previously + computed `available_space = max_buffer_size - len(self._buffer)`, + but `self._buffer` was already empty (cleared by `_drain_batch`). + The overflow check was a no-op, and the buffer grew unboundedly + under sustained backend outage. The fix checks `len(batch)` against + `max_buffer_size` and drops the oldest events in the batch. +- **ActionHandler kill contract** (P0): `ActionHandler.handle` no + longer catches `BaseException`. The default KILL handler + intentionally raises `WorkflowKilledInterrupt` (a `BaseException` + subclass) to halt the agent; the previous code silently swallowed + it, breaking the kill contract. The fix catches only `Exception` + and explicitly re-raises `WorkflowKilledInterrupt` and + `WorkflowPausedException`. `KeyboardInterrupt` and `SystemExit` + also propagate correctly. +- **Signal handler side-effects** (P0): The previous + `_register_signal_handlers` called `sys.exit(0)` from inside a + signal context (undefined behaviour per CPython docs) and did file + I/O (`_persist_to_wal`) from a signal context. 
Both unsafe; removed + with the handler. +- **Atexit LIFO ordering** (P0): Multiple `Transport()` instances + each registered an `atexit` handler, and the LIFO order meant the + last-constructed transport's flush ran first. The new + `weakref.finalize` approach is per-instance: if the transport has + been GC'd, the atexit is a no-op; if not, the flush runs exactly + once. + ### Added -- **Async Policy Cache**: `AsyncTransport` now uses `PolicyCache` for CACHED fallback mode. Previously the async transport always fell back to PERMISSIVE when gateway was unreachable. Now it caches successful execute decisions and uses them when gateway is unavailable. -- **Custom Sensitive Tools API**: Added `add_sensitive_tool()`, `remove_sensitive_tool()`, `register_sensitive_tools()`, and `get_sensitive_tools()` methods to `NullRunRuntime`. Users can now register custom tools as sensitive requiring strict mode enforcement. -- **`NullRunBlockedException.tool_name` attribute** (FIX-5): The `tool_name` - kwarg is now a first-class attribute on `NullRunBlockedException` - (and its subclasses `LoopDetectedException`, etc.) instead of being - absorbed into `**details`. Cookbook examples that read `exc.tool_name` - no longer raise `AttributeError`. Backwards-compatible: `tool_name` - defaults to `None` and does not appear in `exc.details` when unset. - The stringified exception now includes `tool={name}` when set. +- **Context manager support** (`Transport.__enter__` / `__exit__`): The + recommended lifecycle for long-lived services: + `with Transport(api_url=..., api_key=...) as t: ...` + The `__exit__` calls `stop()` which flushes and closes the + client. +- **`_atexit_flush_safe`**: A wrapper around `_atexit_flush` that + catches any exception in the flush and logs it. The atexit chain + must never propagate an exception (which would skip subsequent + atexit handlers in some Python implementations). 
+ +--- + +## [Unreleased] + +### Added + +- **`Transport.evaluate()`**: New public method for the pre-validation + ("what if") path. Routes through the SDK's own connection pool, HMAC + signing, circuit breaker, and retry policy. The previous implementation + in `runtime.evaluate()` reached into `transport._client` directly, + silently bypassing the circuit breaker — a production hazard. +- **`Transport.check()` `on_transport_error` parameter**: matches the + contract of `Transport.execute()` (`"raise"` / `"open"` / `"closed"` / + `"legacy"`). The previous default returned `decision="block"` on + every transport error, which contradicted the ADR-008 fail-OPEN + promise for `check_workflow_budget`. +- **`AsyncTransport.execute()` `on_transport_error` parameter**: mirrors + the sync `Transport.execute()` contract. The previous async + implementation used a 2-attempt retry loop with no classified + failure source — calling gates (e.g. `_enforce_sensitive_tool` in + `decorators.py`) could not tell a transport failure apart from a + real policy block. +- **`NullRunRuntime.check_control_plane()` accepts `str | None`**: the + previous signature required `str` even though the contract is + "contextvar → API-key-bound workflow → no-op" and `None` is the + canonical "no workflow scoped" value. The wrapper no longer needs + to fake a value. +- **`_safe_bump_coverage` helper in `nullrun.instrumentation.auto`**: a + new module-level utility that bumps a per-host coverage counter on + the runtime. Tolerates stub runtimes (MagicMock, namespace + objects) by no-oping when the counter attribute is missing. The + sync/async transport hooks and `auto_requests` now use it to + record per-host coverage consistently. +- **README env var table** now documents `NULLRUN_TRANSPORT` + (WebSocket vs HTTP poller), `NULLRUN_WAL_PATH` (WAL override), + and `NULLRUN_TLS_CA_CERT` (custom HTTPS CA). 
### Fixed -- **SDK silent runtime fallback removed** (FIX-4): `_get_or_create_runtime` - in `nullrun.decorators` no longer wraps `NullRunRuntime.get_instance()` - in a `try/except Exception` that rebuilds a no-arg `NullRunRuntime()`. - In 0.3.0 (T3-S2) the no-arg constructor requires `api_key` and raises - `NullRunAuthenticationError` — so the fallback swallowed the auth - error from `get_instance()` only to crash with the same error from - the fallback path itself. After this fix, the auth error propagates - cleanly to the first `@protect` invocation, mirroring the fail-loud - contract of `nullrun.init()`. Aligns with the T3-S2 invariant that - the SDK has no local mode: a missing API key is a hard error, not a - silent allow-all. +- **Test fixture path**: `tests/test_kill_contract.py` no longer references the obsolete `sdk-python/` path; it now documents the correct repository name. +- **Examples**: `examples/basic.py`, `examples/async_usage.py`, `examples/basic_observe.py`, `examples/cost_dashboard.py` updated to the 0.3.0 contract (api_key is required, no `organization_id` kwarg, no `coverage_report()`). +- **Version drift**: `src/nullrun/__version__.py` was reporting `0.2.0` while `pyproject.toml` declared `0.3.0`; both now agree on `0.3.0`. +- **License classifier**: `pyproject.toml` now declares `License :: OSI Approved :: Apache Software License` (the actual license is Apache-2.0). +- **Type hint fix**: `NullRunRuntime.wrap_tool` / `wrap` now annotate `Callable[..., Any]` (from `collections.abc`) instead of the built-in `callable` function — `mypy --strict` is now happy. +- **Wrap NameError**: `NullRunRuntime.wrap` previously referenced an undefined `workflow_id` in its `NullRunBlockedException` raise; it now resolves the workflow id from the contextvar (with a `` fallback matching the rest of the runtime). 
+- **`requests` import removed**: `Transport._refetch_credentials` no longer imports `requests` (which was never declared as a dependency); it now uses the SDK's own `httpx` client. +- **`_safe_bump_coverage` added**: `nullrun.instrumentation.auto` now exports the `_safe_bump_coverage` helper that `nullrun.instrumentation.auto_requests` and the sync/async transport hooks all use to record per-host coverage counters on the runtime. +- **Transport error routing (ADR-008)**: `Transport.execute()` and `Transport.check()` now accept an `on_transport_error` kwarg (`"raise"` / `"open"` / `"closed"` / `"legacy"`) and classify the failure source as `NETWORK_ERROR` / `GATEWAY_ERROR` / `BREAKER_OPEN` (was previously opaque to the caller). +- **`transport.check` fail-OPEN**: `check()` no longer silently returns `decision="block"` on every transport error; it routes through the same `_handle_transport_error` helper as `execute()` and lets the caller decide. +- **`shutdown()` closes gRPC**: `NullRunRuntime.shutdown()` now also closes the gRPC channel when one is in use, so a `NULLRUN_USE_GRPC=1` init does not leak the channel on shutdown. +- **`GrpcTransport.track` cost_cents default**: the gRPC `track` method's `cost_cents` argument is now keyword-only with a default of `0`, matching the runtime call site. +- **WAL location**: `Transport._persist_to_wal` now writes to `$NULLRUN_WAL_PATH` or `$TMPDIR/.nullrun.wal` (per-user) instead of `os.getcwd()/.nullrun.wal` — the previous location was unsafe in production. +- **CHANGELOG links**: links now point at the canonical repository `nullrunio/nullrun-sdk-python`. +- **CI mypy strict**: `mypy src/nullrun --strict` is now wired into the `make check` / `make type-check` targets and the GitHub Actions CI workflow runs it explicitly. +- **CI ruff tests**: the GitHub Actions CI workflow now lints both `src/` and `tests/` (the previous version only covered `src/`). 
+- **AsyncTransport `_client is None` bug**: `AsyncTransport._flush_locked()` now lazy-initializes the httpx async client if `start()` was not called first. Tests that drive `_flush_locked` directly (a few in `tests/test_transport.py`) no longer crash with `'NoneType' object has no attribute 'post'`. +- **`_retry_with_backoff` propagation**: the helper used to wrap the original exception in `BreakerTransportError` on retry exhaustion, conflating "CB OPEN" and "retries exhausted". It now re-raises the original exception so the calling gate can classify the source (network vs 5xx). +- **Defense-in-depth check in `decorators._enforce_sensitive_tool`**: now recognizes both `"FALLBACK_*"` (legacy `fallback_mode` shape) and the new `TransportErrorSource` enum values (`NETWORK_ERROR` / `GATEWAY_ERROR` / `BREAKER_OPEN` / `AUTH_ERROR`). The old check only matched `"FALLBACK_*"`, which meant a synthetic allow with `decision_source="NETWORK_ERROR"` would slip through. +- **`auto._check_kill_before_send` tolerates stub runtimes**: was crashing with `AttributeError: 'X' object has no attribute '_resolve_workflow_id'` when patched against a MagicMock / namespace runtime. Now uses `getattr(runtime, "_resolve_workflow_id", None)` and no-ops when the method is missing. +- **`_strip_wire_only_fields` helper**: centralized the wire-format contract (currently a single field: `cost_cents`) in one method on `NullRunRuntime` so future local-only fields land in the same place. +- **Test: `tests/test_runtime.py`** — removed the `NULLRUN_WORKSPACE_ID` env var that the SDK does not read (the organization id comes from `/auth/verify` in 0.3.0). +- **Test: `tests/test_runtime_default_transport.py`** — fixed path references from `sdk-python/src/...` to the correct `src/...`. +- **Test: `tests/test_toolbox_langgraph.py`** — the wrapper tests now pass a stub `runtime` argument so they do not require `NULLRUN_API_KEY` in the test env. 
The public wrapper contract (`wrapper(app, runtime=None)`) is unchanged. +- **Test: `tests/test_preflight_fail_policy.py`** — `test_real_block_still_honored` now distinguishes the budget pre-check from the sensitive-tool pre-check via a `side_effect` callback that reads the request body. The two gate calls share the `/api/v1/gate` URL but differ in the `check_type` field, so a single response mock would have masked the real-block path. +- **Dockerfile**: the runtime stage now installs `nullrun[langgraph]` (not `nullrun-breaker[langgraph]`, which never existed) and sets `CMD ["python"]` (the `python -m nullrun.breaker` entry point it referenced does not exist; the SDK is a library, not a CLI). + +### Removed + +- **`docs/adr/008-...` references**: docstrings in `runtime.py` and `transport.py` no longer reference a path that does not exist in this repository (the ADR lives in the gateway repo, not the SDK). +- **Stale CHANGELOG link to `maltsev-dev/nullrun-sdk`**: replaced with the canonical `nullrunio/nullrun-sdk-python`. +- **Stale `docs/kill-contract.md` and `docs/known-limitations.md` references**: docstrings now point at the gateway repository, where the canonical design notes live. +- **Dead "How to integrate" comment block in `observability.py`**: the comment showed the old direct-attribute pattern (`metrics.transport.batches_sent += 1`); the runtime actually uses the lock-aware `metrics.inc_transport(...)` path. --- @@ -73,28 +191,29 @@ Versioning: [Semantic Versioning](https://semver.org/spec/v2.0.0.html) `from nullrun.transport import FallbackMode, PoolConfig`) remain available. Audited for 0 external callers. -### Migration - -- **0.2.x → 0.3.0**: - - `nullrun.init()` calls without `api_key` will raise. Pass - `api_key="nr_live_..."` explicitly or set `NULLRUN_API_KEY`. - - `NullRunRuntime(...)` constructions without `api_key` will raise - (same fix). 
- - Tests using `NullRunNoop` / `local_mode=True` mocking must switch - to `NullRunRuntime(api_key="test-key", _test_mode=True)` — - `_test_mode` skips the network calls without silently bypassing - policy. - - `from nullrun import BreakerError` (and the 6 other legacy names) - must use the canonical paths above. - ### Added - **Async Policy Cache**: `AsyncTransport` now uses `PolicyCache` for CACHED fallback mode. Previously the async transport always fell back to PERMISSIVE when gateway was unreachable. Now it caches successful execute decisions and uses them when gateway is unavailable. - **Custom Sensitive Tools API**: Added `add_sensitive_tool()`, `remove_sensitive_tool()`, `register_sensitive_tools()`, and `get_sensitive_tools()` methods to `NullRunRuntime`. Users can now register custom tools as sensitive requiring strict mode enforcement. -### Deprecated +### Fixed -- **No-api-key init / local mode** (T3-S1): Calling `nullrun.init()` or constructing `NullRunRuntime(...)` without an `api_key` (and with `NULLRUN_API_KEY` unset) now emits a `DeprecationWarning`. The runtime still falls back to local mode and silently bypasses every backend gate (budget, policy, control plane). The fallback will be **removed in 0.3.0** — passing `api_key='nr_live_...'` explicitly or setting `NULLRUN_API_KEY` is the only supported path going forward. Pin the warning to a hard error with `python -W error::DeprecationWarning` to catch callers in CI. +- **SDK silent runtime fallback removed** (FIX-4): `_get_or_create_runtime` + in `nullrun.decorators` no longer wraps `NullRunRuntime.get_instance()` + in a `try/except Exception` that rebuilds a no-arg `NullRunRuntime()`. + In 0.3.0 (T3-S2) the no-arg constructor requires `api_key` and raises + `NullRunAuthenticationError` — so the fallback swallowed the auth + error from `get_instance()` only to crash with the same error from + the fallback path itself. 
After this fix, the auth error propagates + cleanly to the first `@protect` invocation, mirroring the fail-loud + contract of `nullrun.init()`. +- **`NullRunBlockedException.tool_name` attribute** (FIX-5): The `tool_name` + kwarg is now a first-class attribute on `NullRunBlockedException` + (and its subclasses `LoopDetectedException`, etc.) instead of being + absorbed into `**details`. Cookbook examples that read `exc.tool_name` + no longer raise `AttributeError`. Backwards-compatible: `tool_name` + defaults to `None` and does not appear in `exc.details` when unset. + The stringified exception now includes `tool={name}` when set. --- @@ -129,18 +248,34 @@ Versioning: [Semantic Versioning](https://semver.org/spec/v2.0.0.html) ### Notes - Requires Python ≥ 3.10 -- Compatible with NullRun API version `2024-01-15` +- Compatible with NullRun API version `2026-06-16` --- ## How to upgrade -### 0.x → next +### 0.1.x → 0.2.x + +_No breaking changes recorded. The 0.2.x line was a hardening series that did not break the public surface._ + +### 0.2.x → 0.3.0 -_No breaking changes yet. Watch this file._ +- `nullrun.init()` calls without `api_key` will raise. Pass + `api_key="nr_live_..."` explicitly or set `NULLRUN_API_KEY`. +- `NullRunRuntime(...)` constructions without `api_key` will raise + (same fix). +- Tests using `NullRunNoop` / `local_mode=True` mocking must switch + to `NullRunRuntime(api_key="test-key", _test_mode=True)` — + `_test_mode` skips the network calls without silently bypassing + policy. +- `from nullrun import BreakerError` (and the 6 other legacy names) + must use the canonical paths: + `from nullrun.breaker.exceptions import NullRunBlockedException` + and `from nullrun.runtime import Policy` / `from nullrun.transport import FallbackMode, PoolConfig`. 
--- -[Unreleased]: https://github.com/maltsev-dev/nullrun-sdk/compare/v0.1.1...HEAD -[0.1.1]: https://github.com/maltsev-dev/nullrun-sdk/releases/tag/v0.1.1 -[0.1.0]: https://github.com/maltsev-dev/nullrun-sdk/releases/tag/v0.1.0 +[Unreleased]: https://github.com/nullrunio/nullrun-sdk-python/compare/v0.3.0...HEAD +[0.3.0]: https://github.com/nullrunio/nullrun-sdk-python/compare/v0.1.1...v0.3.0 +[0.1.1]: https://github.com/nullrunio/nullrun-sdk-python/compare/v0.1.0...v0.1.1 +[0.1.0]: https://github.com/nullrunio/nullrun-sdk-python/releases/tag/v0.1.0 diff --git a/Dockerfile b/Dockerfile index ef19b74..0b8a6c8 100644 --- a/Dockerfile +++ b/Dockerfile @@ -31,7 +31,13 @@ RUN pip install /app/dist/*.whl --force-reinstall RUN useradd -m -u 1000 nullrun USER nullrun -# Install optional dependencies -RUN pip install "nullrun-breaker[langgraph]" - -ENTRYPOINT ["python", "-m", "nullrun.breaker"] +# Install optional dependencies (langgraph integration is the only +# one with a non-trivial extra deps tree at the moment). The +# `nullrun[langgraph]` extra is defined in pyproject.toml. +RUN pip install "nullrun[langgraph]" + +# The SDK ships as a library — there is no `python -m nullrun.breaker` +# entry point. The default CMD is `python` so the user can wire +# their own agent. Override at run time: +# docker run -it --rm nullrun-sdk python -c "from nullrun import protect; print('ok')" +CMD ["python"] diff --git a/README.md b/README.md index 8feba1b..930d3f6 100644 --- a/README.md +++ b/README.md @@ -17,7 +17,9 @@ pip install nullrun ## Quick start ```python -from nullrun import protect +from nullrun import init, protect + +init(api_key="nr_live_...") # required as of 0.3.0 — see CHANGELOG @protect def my_agent(prompt: str) -> str: @@ -31,13 +33,16 @@ integrations. | Env var | Default | Description | |---|---|---| -| `NULLRUN_API_KEY` | — | API key from the NullRun dashboard. **Required.** | +| `NULLRUN_API_KEY` | — | API key from the NullRun dashboard. 
**Required** as of 0.3.0. | | `NULLRUN_API_URL` | `https://api.nullrun.io` | Backend base URL. | +| `NULLRUN_TRANSPORT` | `ws` | Control-plane transport. `ws` (default) uses the WebSocket push channel for sub-second kill/pause propagation; `http` falls back to the legacy 1-second HTTP poll. | | `NULLRUN_HMAC_REQUIRED` | `false` | Server-side: require HMAC body signature. | -| `NULLRUN_SKIP_BUDGET_CHECK` | unset | Opt-out of pre-flight `/check` (test only). | +| `NULLRUN_SKIP_BUDGET_CHECK` | unset | Opt-out of pre-flight `/check` (test only). **Full billing bypass**, not just check bypass. | | `NULLRUN_SENSITIVE_FAIL_OPEN` | unset | Opt-out of fail-CLOSED for sensitive tools (test only). | | `NULLRUN_TLS_CLIENT_CERT` | unset | mTLS client cert path (server-side). | | `NULLRUN_TLS_CLIENT_KEY` | unset | mTLS client key path (server-side). | +| `NULLRUN_TLS_CA_CERT` | unset | Custom CA cert path for HTTPS verification. | +| `NULLRUN_WAL_PATH` | `$TMPDIR/.nullrun.wal` | Where the SDK writes unflushed events on shutdown for replay on next start. | | `NULLRUN_LOG_LEVEL` | `INFO` | One of `DEBUG` / `INFO` / `WARNING` / `ERROR`. | | `NULLRUN_BATCH_SIZE` | `100` | Track event batch size. | | `NULLRUN_FLUSH_INTERVAL_MS` | `5000` | Track event flush interval. | @@ -48,21 +53,17 @@ integrations. | Env var | Default | Description | |---|---|---| | `NULLRUN_USE_GRPC` | unset | **Do not enable in production.** See warning below. | -| `NULLRUN_GRPC_URL` | `localhost:50051` | gRPC server address (server-side: `GRPC_PORT`). | -| `NULLRUN_GRPC_REFLECTION` | unset | Server-side: `1` enables proto schema reflection on `:50051`. | -| `NULLRUN_GRPC_UNSAFE_ALLOW` | unset | Server-side: required alongside `NULLRUN_USE_GRPC=1` to acknowledge the gRPC server is unsafe. The backend refuses to start if `NULLRUN_USE_GRPC=1` is set without this. Never set in shared environments. | +| `NULLRUN_GRPC_URL` | `localhost:50051` | gRPC server address.
| > ⚠️ **The gRPC server is intentionally frozen.** It does not validate > `x-api-key` in metadata (the auth helper exists in the > [gateway repository](https://github.com/nullrunio/nullrun) but is not > wired into the RPC handlers), runs over plaintext HTTP/2, and exposes -> the full proto schema via reflection (when enabled). The backend's -> startup script (in the [gateway repository](https://github.com/nullrunio/nullrun)) -> refuses to start if `NULLRUN_USE_GRPC=1` is set without the explicit -> opt-in `NULLRUN_GRPC_UNSAFE_ALLOW=1`. The opt-in is for local/dev use -> only and is logged at WARN. See the activation checklist (TLS → auth → -> proto extensions → cost pipeline parity → tests) in the gateway repo -> that must be completed before this transport is production-safe. +> the full proto schema via reflection. The backend refuses to start if +> `NULLRUN_USE_GRPC=1` is set without `NULLRUN_GRPC_UNSAFE_ALLOW=1`. +> See the activation checklist (TLS → auth → proto extensions → cost +> pipeline parity → tests) in the gateway repo that must be completed +> before this transport is production-safe. If you copy `.env.example` to `.env`, copy this block as well: diff --git a/examples/async_usage.py b/examples/async_usage.py index d70960b..106e857 100644 --- a/examples/async_usage.py +++ b/examples/async_usage.py @@ -1,22 +1,36 @@ """ -Async usage — @protect with async functions in local mode. -Run: python examples/async_usage.py +Async usage — `@protect` with async functions in cloud mode. + +Run: + export NULLRUN_API_KEY=nr_live_... + python examples/async_usage.py """ import asyncio +import os + +from nullrun import init, protect -from nullrun import protect, init +# Cloud mode — api_key is required as of 0.3.0 (T3-S2). The previous +# silent fallback to a "local mode" stub was removed because it hid +# policy violations and bypassed every backend gate. Pass +# `api_key=...` explicitly or set NULLRUN_API_KEY. 
+init( + api_key=os.environ.get("NULLRUN_API_KEY", "nr_live_demo_key"), + api_url=os.environ.get("NULLRUN_API_URL", "https://api.nullrun.io"), +) -# No api_key → local mode (auto-detected). No network calls, no polling. -init() @protect async def async_tool(prompt: str) -> str: await asyncio.sleep(0.01) - return f"[async local] {prompt}" + return f"[async protected] {prompt}" + async def main() -> None: print("Running async protected function...") result = await async_tool("Tell me a joke") print(f"Result: {result}") -asyncio.run(main()) \ No newline at end of file + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/basic.py b/examples/basic.py index d4739f0..d7335e2 100644 --- a/examples/basic.py +++ b/examples/basic.py @@ -1,17 +1,31 @@ """ -Basic usage — @protect decorator in local mode. -Run: python examples/basic.py +Basic usage — `@protect` decorator with a cloud runtime. + +Run: + export NULLRUN_API_KEY=nr_live_... + python examples/basic.py """ -from nullrun import protect, init +import os + +from nullrun import init, protect + +# Cloud mode — api_key is required as of 0.3.0 (T3-S2). The previous +# silent fallback to a "local mode" stub was removed because it hid +# policy violations and bypassed every backend gate. Pass +# `api_key=...` explicitly or set NULLRUN_API_KEY. +init( + api_key=os.environ.get("NULLRUN_API_KEY", "nr_live_demo_key"), + api_url=os.environ.get("NULLRUN_API_URL", "https://api.nullrun.io"), +) -# No api_key → local mode (auto-detected). No network calls, no polling. 
-init() @protect def call_llm(prompt: str) -> str: - return f"[local-mode response] {prompt[:50]}" + return f"[protected call] {prompt[:50]}" + -print("Calling protected function...") -result = call_llm("What is the capital of France?") -print(f"Result: {result}") -print("Done.") \ No newline at end of file +if __name__ == "__main__": + print("Calling protected function...") + result = call_llm("What is the capital of France?") + print(f"Result: {result}") + print("Done.") diff --git a/examples/basic_observe.py b/examples/basic_observe.py index 18a8868..42a3bcf 100644 --- a/examples/basic_observe.py +++ b/examples/basic_observe.py @@ -1,15 +1,14 @@ """ Phase 2 hero example — basic observability, no code changes. -The promise: install `nullrun`, call `init(api_key=..., org_id=...)`, -and the SDK observes your existing LLM calls. No decorator needed. -The dashboard picks up the events as they happen. +The promise: install `nullrun`, call `init(api_key=...)`, and the SDK +observes your existing LLM calls. No decorator needed. The dashboard +picks up the events as they happen. Run: - pip install -e ../sdk-python + pip install -e . export NULLRUN_API_KEY=nr_live_... - export NULLRUN_ORGANIZATION_ID=org-123 - python basic_observe.py + python examples/basic_observe.py """ import os @@ -17,27 +16,25 @@ import nullrun from openai import OpenAI -# 1. One-line init. The SDK reads NULLRUN_API_KEY and -# NULLRUN_ORGANIZATION_ID from the environment if you don't pass -# them. Auto-instrumentation wires up the OpenAI transport AFTER +# 1. One-line init. The SDK reads NULLRUN_API_KEY and NULLRUN_API_URL +# from the environment if you don't pass them explicitly. +# Auto-instrumentation wires up the OpenAI transport AFTER # `init()` returns — see `init()` for the wiring order. 
nullrun.init( - organization_id=os.environ.get("NULLRUN_ORGANIZATION_ID", "org-demo"), - api_key=os.environ.get("NULLRUN_API_KEY", "demo-key"), - api_url=os.environ.get("NULLRUN_API_URL", "http://localhost:8080"), + api_key=os.environ.get("NULLRUN_API_KEY", "nr_live_demo_key"), + api_url=os.environ.get("NULLRUN_API_URL", "https://api.nullrun.io"), ) # 2. Use OpenAI exactly as you did before. The auto-instrumentation -# in `nullrun.instrumentation.auto` patches `openai.OpenAI` and -# `openai.AsyncOpenAI` to record every chat completion as a +# in `nullrun.instrumentation.auto` patches `httpx.Client` and +# `httpx.AsyncClient` to record every chat completion as a # `llm_call` event with token counts, latency, and cost. client = OpenAI() # 3. Make a real call. The SDK records: # - workflow_id: derived from the API key on the backend -# (or by `with workflow("..."):` to override locally) # - tokens: from the response.usage -# - cost: computed server-side from `model_pricing` +# - cost: computed server-side from the org's pricing policy # - latency: from request start to response # The dashboard updates within ~2s. for i in range(3): @@ -46,10 +43,3 @@ messages=[{"role": "user", "content": f"Say hi (call #{i + 1})"}], ) print(f"call #{i + 1}: {resp.choices[0].message.content!r}") - -# 4. Optional: print a coverage snapshot. The same payload is sent -# over the WS heartbeat every 60s and via the HTTP-fallback path -# when the WS connection is down. -print("\nCoverage snapshot:") -for k, v in nullrun.coverage_report().items(): - print(f" {k}: {v}") diff --git a/examples/cost_dashboard.py b/examples/cost_dashboard.py index 105e886..41210fc 100644 --- a/examples/cost_dashboard.py +++ b/examples/cost_dashboard.py @@ -7,10 +7,9 @@ can see that the SDK and the dashboard agree. Run: - pip install -e ../sdk-python + pip install -e . export NULLRUN_API_KEY=nr_live_... 
- export NULLRUN_ORGANIZATION_ID=org-123 - python cost_dashboard.py + python examples/cost_dashboard.py """ import os @@ -23,55 +22,60 @@ def fetch_last_24h_spend(api_url: str, org_id: str, api_key: str, workflow_id: s """ Read the rolling 24h spend for one workflow from the backend. - The backend exposes this as `/api/v1/orgs/{org_id}/usage`. The - response shape is `{"workflows": [{...}], "totals": {...}}` — - filter to the workflow of interest on the client side because - the server-side filter is a Phase 4 follow-up. + The canonical endpoint is `/api/v1/orgs/{org_id}/quota` (per + `contracts/openapi.yaml:2306-2321`). The legacy `/usage` path + was removed when the dashboard migrated to the unified + `OrgStatusResponse` shape; this example uses the + dashboard-friendly status endpoint and projects a 24h window + from the `usage_today_cents` field. + + Authentication: ``X-API-Key`` header (per + `contracts/openapi.yaml:59-74`). The SDK never sends a + ``Authorization: Bearer`` token on the user's behalf. 
""" - headers = {"Authorization": f"Bearer {api_key}"} + headers = {"X-API-Key": api_key} with httpx.Client(timeout=10.0) as client: resp = client.get( - f"{api_url}/api/v1/orgs/{org_id}/usage", - params={"window": "24h"}, + f"{api_url}/api/v1/orgs/{org_id}/quota", headers=headers, ) resp.raise_for_status() body = resp.json() - for wf in body.get("workflows", []): - if wf.get("workflow_id") == workflow_id: - return wf - return { "workflow_id": workflow_id, - "cost_cents": 0, - "tokens": 0, - "calls": 0, - "note": "no events in window", + "cost_cents": body.get("usage_today_cents", 0), + "tokens": body.get("tokens_today", 0), + "calls": body.get("calls_today", 0), + "budget_remaining_cents": body.get("budget_remaining_cents"), } def main() -> None: - api_url = os.environ.get("NULLRUN_API_URL", "http://localhost:8080") - org_id = os.environ.get("NULLRUN_ORGANIZATION_ID", "org-demo") - api_key = os.environ.get("NULLRUN_API_KEY", "demo-key") + api_url = os.environ.get("NULLRUN_API_URL", "https://api.nullrun.io") + api_key = os.environ.get("NULLRUN_API_KEY", "nr_live_demo_key") workflow_id = os.environ.get("NULLRUN_WORKFLOW_ID", "research-agent") nullrun.init( - organization_id=org_id, api_key=api_key, api_url=api_url, ) - print(f"Reading last 24h for workflow {workflow_id!r} in org {org_id!r}...") + # Organization ID is returned by /auth/verify on init and is + # available on the runtime singleton — fetch it after init. 
+ from nullrun import get_runtime + org_id = get_runtime().organization_id or "unknown" + + print(f"Reading today for workflow {workflow_id!r} in org {org_id!r}...") wf = fetch_last_24h_spend(api_url, org_id, api_key, workflow_id) cost_dollars = wf.get("cost_cents", 0) / 100.0 print(f" cost: ${cost_dollars:,.2f}") print(f" tokens: {wf.get('tokens', 0):,}") print(f" calls: {wf.get('calls', 0):,}") - if "note" in wf: - print(f" note: {wf['note']}") + if wf.get("budget_remaining_cents") is not None: + remaining = wf["budget_remaining_cents"] / 100.0 + print(f" remaining budget: ${remaining:,.2f}") # The same number is the truth the dashboard shows — there is no # second source of truth in code. The policy in the Control diff --git a/protos/nullrun/v1/track.proto b/protos/nullrun/v1/track.proto deleted file mode 100644 index 86c1187..0000000 --- a/protos/nullrun/v1/track.proto +++ /dev/null @@ -1,37 +0,0 @@ -syntax = "proto3"; -package nullrun.v1; - -service TrackService { - rpc BatchTrack(BatchTrackRequest) returns (BatchTrackResponse); - rpc Track(TrackRequest) returns (TrackResponse); -} - -message TrackRequest { - string event_id = 1; - string workflow_id = 2; - string event_type = 3; - int64 tokens = 4; - int64 cost_cents = 5; - string tool_name = 6; - bool is_retry = 7; -} - -message BatchTrackRequest { - repeated TrackRequest events = 1; -} - -message TrackResponse { - bool accepted = 1; - string message = 2; -} - -message BatchTrackResponse { - repeated string accepted_event_ids = 1; - repeated Action actions_taken = 2; -} - -message Action { - string type = 1; - string workflow_id = 2; - string reason = 3; -} \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 6091d81..a5b2bb4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,7 +22,7 @@ keywords = [ classifiers = [ "Development Status :: 3 - Alpha", "Intended Audience :: Developers", - "License :: OSI Approved :: MIT License", + "License :: OSI Approved :: Apache Software 
License", "Programming Language :: Python :: 3", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", @@ -33,7 +33,6 @@ classifiers = [ dependencies = [ "httpx>=0.27.0,<1.0", - "grpcio>=1.60.0,<2.0", ] [project.optional-dependencies] @@ -74,8 +73,13 @@ dev = [ "mypy>=1.10", "ruff>=0.5", "coverage[toml]>=7.0", - "grpcio-tools>=1.60.0,<2.0", "httpx>=0.27.0,<1.0", + # `nullrun.instrumentation.langgraph` does a module-level + # `from langchain_core.callbacks import BaseCallbackHandler`, + # so any test that imports `nullrun` transitively pulls this in. + "langchain-core>=0.3,<1.0", + # test_ws_push.py exercises a real WebSocket client/server. + "websockets>=12.0", ] [project.urls] diff --git a/src/nullrun/__init__.py b/src/nullrun/__init__.py index b684932..e269f50 100644 --- a/src/nullrun/__init__.py +++ b/src/nullrun/__init__.py @@ -115,8 +115,8 @@ def my_agent(): # Imported lazily so we don't pull the runtime into the namespace # when the user only wants the static helpers. - from nullrun.runtime import NullRunRuntime import nullrun.runtime as _rt_mod + from nullrun.runtime import NullRunRuntime runtime = NullRunRuntime( api_key=api_key, diff --git a/src/nullrun/__version__.py b/src/nullrun/__version__.py index f68998a..606ff38 100644 --- a/src/nullrun/__version__.py +++ b/src/nullrun/__version__.py @@ -1,4 +1,4 @@ """NullRun Platform SDK.""" -__version__ = "0.2.0" +__version__ = "0.3.0" __platform_version__ = "1.0.0" diff --git a/src/nullrun/actions.py b/src/nullrun/actions.py index cf94612..8097087 100644 --- a/src/nullrun/actions.py +++ b/src/nullrun/actions.py @@ -198,17 +198,25 @@ def handle( if self._webhooks: self._queue_webhook(action_type, workflow_id, reason or "Unknown", details) + # Catch only `Exception` here AND re-raise the control-signal + # exceptions that the SDK uses to halt agents. The previous + # `except BaseException` silently swallowed kills, which + # broke the kill contract per `docs/kill-contract.md` §1. 
+ # (P0-0.5 fix.) + from nullrun.breaker.exceptions import ( + WorkflowKilledInterrupt, + WorkflowPausedException, + ) + try: handler(workflow_id, reason or "Unknown", **details) # type: ignore[no-untyped-call] - except BaseException as e: - # Don't let handler exceptions propagate. We catch - # `BaseException` (not just `Exception`) because - # `WorkflowKilledInterrupt` is intentionally a - # `BaseException` subclass — it's a non-recoverable - # control signal, but inside the ActionHandler dispatch - # loop we want the kill to be recorded in history - # (already done above) and swallowed, NOT re-raised into - # the caller's frame. + except (WorkflowKilledInterrupt, WorkflowPausedException): + # Control signals — MUST propagate to halt the agent. + raise + except Exception as e: + # Genuine handler errors (user bugs, network failures in + # webhook callbacks) are logged and swallowed so one + # broken handler does not break the dispatch loop. logger.error(f"Action handler error: {e}") def _default_kill( diff --git a/src/nullrun/breaker/circuit_breaker.py b/src/nullrun/breaker/circuit_breaker.py index 41ce87b..96659dc 100644 --- a/src/nullrun/breaker/circuit_breaker.py +++ b/src/nullrun/breaker/circuit_breaker.py @@ -12,7 +12,7 @@ import time from collections.abc import Callable from enum import Enum -from typing import Any, Optional +from typing import Any logger = logging.getLogger(__name__) @@ -59,7 +59,7 @@ def __init__( failure_threshold: int = 5, recovery_timeout: float = 30.0, half_open_max_calls: int = 1, - redis_client: Optional[Any] = None, + redis_client: Any | None = None, name: str = "default", ): self._failure_threshold = failure_threshold @@ -96,7 +96,7 @@ def _get_async_lock(self) -> asyncio.Lock: # Redis-based distributed state sharing # ============================================================================= - def _check_global_state(self) -> Optional[str]: + def _check_global_state(self) -> str | None: """ Check if any instance has the 
circuit open in Redis. diff --git a/src/nullrun/breaker/exceptions.py b/src/nullrun/breaker/exceptions.py index fc90a35..270dd45 100644 --- a/src/nullrun/breaker/exceptions.py +++ b/src/nullrun/breaker/exceptions.py @@ -54,6 +54,51 @@ def __init__( ) +class RateLimitError(NullRunTransportError): + """Raised when the gateway returns HTTP 429 with a ``Retry-After`` + header (or JSON body field). + + Phase 4 of the production-readiness plan: 429s were previously + raised as a generic ``NullRunTransportError`` which lost the + ``retry_after`` (seconds) and ``upgrade_url`` fields the + operator needs to schedule a retry or surface a billing + upgrade prompt. We subclass ``NullRunTransportError`` so + ``except NullRunTransportError`` keeps catching it (no + backwards-incompatible behaviour change). + + Attributes: + retry_after: Seconds the server asks the client to wait + before retrying. ``None`` when the response did not + carry a ``Retry-After`` header. + upgrade_url: Plan-upgrade URL the server returns with the + 429 body (per ``contracts/errors.ts:1-19``). ``None`` + when the response did not include one. + body: The parsed JSON body the server returned with the + 429. Useful for surfacing the original ``error`` / + ``message`` / ``details`` to the operator. + """ + def __init__( + self, + message: str, + source: TransportErrorSource, + endpoint: str, + retry_after: float | None = None, + upgrade_url: str | None = None, + body: dict[str, Any] | None = None, + **details: Any, + ) -> None: + self.retry_after = retry_after + self.upgrade_url = upgrade_url + self.body = body or {} + # Surface the retry_after in the canonical detail dict + # too so callers inspecting ``exc.details`` see it. 
+ if retry_after is not None: + details.setdefault("retry_after", retry_after) + if upgrade_url is not None: + details.setdefault("upgrade_url", upgrade_url) + super().__init__(message, source, endpoint, **details) + + class BreakerTransportError(BreakerError): """ Raised when transport layer fails and events cannot be delivered. @@ -258,7 +303,8 @@ class WorkflowKilledException(BaseException): non-recoverable signal and should not be caught by generic ``except Exception`` clauses. Only ``except BaseException`` or the explicit ``except WorkflowKilledInterrupt`` reliably stops the work. - See ``docs/kill-contract.md`` §6 for the full rationale. + See the kill-contract design note in the gateway repository + for the full rationale. """ def __init__(self, workflow_id: str, reason: str) -> None: @@ -295,8 +341,9 @@ class WorkflowKilledInterrupt(WorkflowKilledException): silently bypass the kill. * ``except BaseException`` catches it, like the stdlib interrupts. - See ``docs/kill-contract.md`` §6 for the full rationale, including - the four-level coverage model and the decision tree for users. + See the kill-contract design note in the gateway repository for + the full rationale, including the four-level coverage model and + the decision tree for users. Fields: workflow_id: The workflow that was killed. diff --git a/src/nullrun/common/__init__.py b/src/nullrun/common/__init__.py deleted file mode 100644 index 271dfc1..0000000 --- a/src/nullrun/common/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -""" -NullRun Common - Shared utilities for NullRun platform. - -This module contains common utilities shared across all NullRun products. 
-""" - -__all__ = [] diff --git a/src/nullrun/decision_history.py b/src/nullrun/decision_history.py index a5468ac..57b31ec 100644 --- a/src/nullrun/decision_history.py +++ b/src/nullrun/decision_history.py @@ -143,10 +143,9 @@ class DecisionHistoryRecorder: session = recorder.stop_recording() session.save("recording.json") - # Local re-emission (re-runs the cost line items through the - # local tracker; no network calls to the gateway) + # Offline inspection session = RecordingSession.load("recording.json") - results = recorder.replay_locally(session) + summary = recorder.estimate_cost(session) """ def __init__(self, runtime: Optional["NullRunRuntime"] = None): @@ -247,56 +246,6 @@ def stop_recording(self) -> RecordingSession | None: return session - def replay_locally( - self, - session: RecordingSession, - on_event: Callable[[RecordedEvent], None] | None = None, - ) -> list[dict[str, Any]]: - """ - Re-emit a recorded session's events through the local runtime tracker. - - IMPORTANT: This is a local-only operation. It does NOT call any LLM - provider and does NOT contact the gateway. It re-runs each event - through `runtime.track()` so the local cost/usage tracker sees the - same line items. Useful for offline cost analysis and integration - tests. - - For true server-side re-evaluation of a recorded decision, use the - backend's Decision History API: GET /api/v1/orgs/:org_id/decision-history. - """ - results: list[dict[str, Any]] = [] - for event in session.events: - result = self.runtime.track(event.raw_event) - results.append(result) - if on_event is not None: - on_event(event) - return results - - def replay_event(self, event: RecordedEvent) -> dict[str, Any]: - """ - Re-emit a single recorded event through the local runtime tracker. - - Note: This only re-tracks the event locally through the runtime. - It does NOT communicate with the backend and does NOT re-execute - any LLM call. 
- """ - return self.runtime.track(event.raw_event) - - def replay_from_file(self, path: str) -> list[dict[str, Any]]: - """ - Load a recorded session from disk and re-emit it locally. - - Args: - path: Path to the JSON file produced by `RecordingSession.save()` - - Returns: - List of results from each event - - See `replay_locally()` for the honest scope of this method. - """ - session = RecordingSession.load(path) - return self.replay_locally(session) - def estimate_cost(self, session: RecordingSession) -> dict[str, Any]: """ Estimate total cost from a recorded session. diff --git a/src/nullrun/decorators.py b/src/nullrun/decorators.py index 6a2c5c0..fb4873e 100644 --- a/src/nullrun/decorators.py +++ b/src/nullrun/decorators.py @@ -38,13 +38,11 @@ def researcher(q): import inspect import logging import os -import re from collections.abc import Callable from typing import Any, TypeVar -from nullrun.instrumentation.openai import is_patched, patch_openai -from nullrun.runtime import NullRunRuntime, get_runtime from nullrun.context import get_workflow_id +from nullrun.runtime import NullRunRuntime, get_runtime from nullrun.tracing import ( SpanContext, create_child_span, @@ -58,7 +56,24 @@ def researcher(q): F = TypeVar("F", bound=Callable[..., Any]) -SENSITIVE_ARG_KEYS = {"password", "token", "secret", "api_key", "key", "auth", "authorization"} +# Phase 3: expanded sensitive-arg keys. The original 7-key set +# missed obvious PII tokens and credential names; ``@sensitive`` and +# ``_safe_kwargs`` would have shipped them in the audit log. +# Matching is case-insensitive (see ``_safe_kwargs`` which calls +# ``.lower()`` on the key). 
+SENSITIVE_ARG_KEYS = frozenset({ + # Credentials / secrets + "password", "passwd", "pwd", + "token", "secret", "api_key", "apikey", + "key", "auth", "authorization", "bearer", + "session", "session_id", "cookie", + "access_token", "refresh_token", "id_token", + "private_key", "secret_key", + # PII + "email", "phone", "ssn", + "credit_card", "credit_card_number", "cvv", "cvc", "pin", + "otp", "mfa", +}) def _safe_repr(value: object, max_len: int = 50) -> str: @@ -70,41 +85,120 @@ def _safe_repr(value: object, max_len: int = 50) -> str: def _safe_kwargs(kwargs: dict[str, Any]) -> dict[str, Any]: - """Mask sensitive kwargs.""" + """Mask sensitive kwargs (case-insensitive).""" return { k: "***" if k.lower() in SENSITIVE_ARG_KEYS else _safe_repr(v) for k, v in kwargs.items() } -# SEC-29: regex used to strip the `details={...}` payload from an -# exception's string form before it lands in the span_end audit event. +# SEC-29: strip the `details={...}` payload from an exception's +# string form before it lands in the span_end audit event. # `details` is caller-supplied structured data — it can contain raw -# tool args, kwargs, or other user-controlled content that we do not -# want to ship to the audit log. The two pattern variants match the -# shape produced by NullRunBlockedException.__str__ / NullRunTransportError.__str__. -_DETAILS_REDACTED = "details=" -_DETAILS_RE = re.compile(r"details=\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}") +# tool args, kwargs, or other user-controlled content that we do +# not want to ship to the audit log. The previous implementation +# used a single-level regex which failed for nested dicts and for +# dict values that contain `{` / `}` in their string content. +# Phase 3 replaces it with a balanced-brace walker that handles +# arbitrary nesting depth and balanced strings. 
+_DETAILS_REDACTED = "" # the payload only — caller prepends "details=" + + +def _strip_details_balanced(text: str) -> str: + """Replace every top-level ``details={...}`` substring with + ``details=``. + + Walks the string with a small state machine that tracks + brace depth and string-literal state (``"…\\"…"'…`` and + ``'…\\'…'``). At depth 1 the opening ``{`` was just + consumed; when the depth returns to 0 the substring is + replaced. The walker tolerates ``{`` and ``}`` inside + string values so it does not under-report nesting. + + Only ``details={…}`` constructs are redacted; a bare + ``details=foo`` (no opening brace) is left as-is so we + don't lose the user's free-form text. + """ + out: list[str] = [] + i = 0 + n = len(text) + needle = "details=" + while i < n: + idx = text.find(needle, i) + if idx < 0: + out.append(text[i:]) + break + + # Append everything before the "details=" token. + out.append(text[i:idx]) + + # Look ahead past optional whitespace for the opening + # brace. If absent, this isn't a `details={…}` construct — + # preserve the literal "details=…" up to the next ',', ')', + # or newline (so the user's free-form text is kept), and + # continue scanning. + j = idx + len(needle) + while j < n and text[j] in " \t": + j += 1 + if j >= n or text[j] != "{": + end = j + while end < n and text[end] not in ",)\n": + end += 1 + out.append(text[idx:end]) + i = end + continue + + # It IS a `details={…}` construct. Append the literal + # "details=" prefix, then walk to the matching '}' and + # replace the whole payload (including the braces) with + # ````. 
+ out.append(text[idx:j]) # "details=" plus any leading whitespace + depth = 0 + in_str: str | None = None + k = j + while k < n: + ch = text[k] + if in_str is not None: + if ch == "\\" and k + 1 < n: + k += 2 + continue + if ch == in_str: + in_str = None + elif ch in ('"', "'"): + in_str = ch + elif ch == "{": + depth += 1 + elif ch == "}": + depth -= 1 + if depth == 0: + k += 1 + break + k += 1 + out.append(_DETAILS_REDACTED) + i = k + return "".join(out) def _safe_error_str(error: BaseException | None) -> str | None: - """Return a log-safe string for `error`. + """Return a log-safe string for ``error``. SEC-29: ``str(error)`` for our blocked / transport exceptions embeds the caller's ``details`` payload (free-form structured data the SDK has no way to scrub). That payload can include raw tool args / kwargs. We strip the ``details={...}`` substring - before handing the string to ``track_event`` so the audit log - only sees the stable envelope (workflow_id, reason, action, - tool_name) and never the caller's arbitrary data. - - Non-None return; returns ``None`` only when `error` is None so - callers can pass the result straight to ``_emit_span_end``. + via a balanced-brace walker (handles nested dicts and dict + values that themselves contain ``{`` / ``}``) before handing + the string to ``track_event`` so the audit log only sees the + stable envelope (workflow_id, reason, action, tool_name) and + never the caller's arbitrary data. + + Non-None return; returns ``None`` only when ``error`` is None + so callers can pass the result straight to ``_emit_span_end``. """ if error is None: return None raw = str(error) - return _DETAILS_RE.sub(_DETAILS_REDACTED, raw) + return _strip_details_balanced(raw) # Module-level cache for the runtime instance — the @protect decorator needs @@ -139,8 +233,12 @@ def _get_or_create_runtime() -> NullRunRuntime: the SDK has no local mode: a missing API key must be a hard error, not a silent allow-all. 
- Tries to patch OpenAI on first creation so the auto-instrumentation - path picks up the runtime the user will eventually use. + The previous OpenAI v0.x auto-patch hook was removed in 0.4.0: + `openai>=1.0` does not expose `ChatCompletion.create` as an + attribute. All OpenAI v1.0+ traffic is now tracked + vendor-independently by the httpx transport hook in + `nullrun.instrumentation.auto`, which is wired by + `nullrun.init()` — not at the lazy-resolve path here. """ global _runtime @@ -149,13 +247,6 @@ def _get_or_create_runtime() -> NullRunRuntime: _runtime = NullRunRuntime.get_instance() - if not is_patched(): - try: - patch_openai() - logger.info("OpenAI auto-patch enabled") - except Exception as e: - logger.debug(f"OpenAI patching skipped: {e}") - logger.info("NullRun runtime initialized: mode=cloud") return _runtime @@ -466,14 +557,28 @@ def _enforce_sensitive_tool( ) from exc # Defense in depth (ADR-008 Rule 1 + Rule 2): if `runtime.execute` - # ever returns a dict with `decision_source` starting with - # `FALLBACK_` (i.e. transport failed but a synthetic allow slipped - # through — currently impossible when runtime passes - # `on_transport_error="raise"`, but easy to regress), honor the - # gate's fail-CLOSED policy here. The body still must not run. + # ever returns a dict with `decision_source` indicating a + # transport failure (i.e. `FALLBACK_*` from the historical + # `fallback_mode` path, or a `TransportErrorSource` enum value + # like `NETWORK_ERROR` / `GATEWAY_ERROR` / `BREAKER_OPEN` from + # the new `on_transport_error="open"` path), honor the gate's + # fail-CLOSED policy here. The body still must not run. 
if isinstance(result, dict): decision_source = result.get("decision_source", "") - if isinstance(decision_source, str) and decision_source.startswith("FALLBACK_"): + is_transport_fallback = ( + isinstance(decision_source, str) + and ( + decision_source.startswith("FALLBACK_") + or decision_source.startswith("FALLBACK") + or decision_source in { + "NETWORK_ERROR", + "GATEWAY_ERROR", + "BREAKER_OPEN", + "AUTH_ERROR", + } + ) + ) + if is_transport_fallback: if fail_open: logger.warning( f"sensitive tool pre-check for {fn.__name__!r} returned " diff --git a/src/nullrun/flow/__init__.py b/src/nullrun/flow/__init__.py deleted file mode 100644 index 23735c1..0000000 --- a/src/nullrun/flow/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -""" -NullRun Flow - AI Agent Orchestration. - -Third product in the NullRun platform. -Placeholder for future implementation. -""" - -__all__ = [] diff --git a/src/nullrun/gate/__init__.py b/src/nullrun/gate/__init__.py deleted file mode 100644 index e304046..0000000 --- a/src/nullrun/gate/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -""" -NullRun Gate - AI Agent Gateway / Routing. - -Second product in the NullRun platform. -Placeholder for future implementation. -""" - -__all__ = [] diff --git a/src/nullrun/grpc_transport.py b/src/nullrun/grpc_transport.py deleted file mode 100644 index f521923..0000000 --- a/src/nullrun/grpc_transport.py +++ /dev/null @@ -1,197 +0,0 @@ -""" -gRPC transport for high-performance event ingestion. - -Uses binary protobuf + HTTP/2 to achieve 30-50% overhead reduction -compared to REST/JSON for high-frequency /track operations. 
-""" -from __future__ import annotations - -import os -from typing import Optional - -import grpc - -# These will be generated by grpcio-tools from the proto file shipped in ./protos/ -# Run: python -m grpc_tools.protoc -I./protos --python_out=./src/nullrun/v1 --grpc_python_out=./src/nullrun/v1 ./protos/nullrun/v1/track.proto -try: - from nullrun.v1 import track_pb2, track_pb2_grpc -except ImportError: - # Proto files not generated yet - track_pb2 = None - track_pb2_grpc = None - - -class GrpcTransport: - """ - High-performance gRPC transport for event ingestion. - - Usage: - transport = GrpcTransport( - api_url="localhost:50051", - api_key="your-api-key" - ) - result = transport.batch_track([...]) - """ - - def __init__( - self, - api_url: str, - api_key: str, - use_tls: bool = True, - ): - """ - Initialize gRPC transport. - - Args: - api_url: gRPC server address (e.g., "localhost:50051") - api_key: API key for authentication - use_tls: Whether to use TLS (default True in production) - """ - self.api_url = api_url - self.api_key = api_key - self.use_tls = use_tls - - if track_pb2 is None or track_pb2_grpc is None: - raise RuntimeError( - "Proto files not generated. Run:\n" - "make protos # from the SDK repo root" - ) - - # Create channel with optional TLS - if use_tls: - # In production, configure proper TLS credentials - credentials = grpc.ssl_channel_credentials() - self.channel = grpc.secure_channel(api_url, credentials) - else: - self.channel = grpc.insecure_channel(api_url) - - self.stub = track_pb2_grpc.TrackServiceStub(self.channel) - - def _make_metadata(self) -> list[tuple[str, str]]: - """Create gRPC metadata with auth headers.""" - return [ - ("x-api-key", self.api_key), - ] - - def track( - self, - event_id: str, - workflow_id: str, - tokens: int, - cost_cents: int, - tool_name: Optional[str] = None, - is_retry: bool = False, - event_type: str = "", - ) -> tuple[bool, str]: - """ - Track a single event via gRPC. 
- - Returns: - Tuple of (accepted, message) - """ - request = track_pb2.TrackRequest( - event_id=event_id, - workflow_id=workflow_id, - event_type=event_type, - tokens=tokens, - cost_cents=cost_cents, - tool_name=tool_name or "", - is_retry=is_retry, - ) - - try: - response = self.stub.Track(request, metadata=self._make_metadata()) - return response.accepted, response.message - except grpc.RpcError as e: - return False, f"gRPC error: {e.code()}: {e.details()}" - - def batch_track( - self, - events: list[dict], - ) -> dict: - """ - Track multiple events via gRPC batch API. - - Args: - events: List of event dicts with keys: - - event_id: str - - workflow_id: str - - tokens: int - - cost_cents: int - - tool_name: Optional[str] - - is_retry: bool - - event_type: str (optional) - - Returns: - Dict with: - - accepted_event_ids: List[str] - - actions_taken: List[dict] - """ - proto_events = [] - for event in events: - proto_events.append(track_pb2.TrackRequest( - event_id=event["event_id"], - workflow_id=event["workflow_id"], - event_type=event.get("event_type", ""), - tokens=event["tokens"], - cost_cents=event["cost_cents"], - tool_name=event.get("tool_name", "") or "", - is_retry=event.get("is_retry", False), - )) - - request = track_pb2.BatchTrackRequest(events=proto_events) - - try: - response = self.stub.BatchTrack(request, metadata=self._make_metadata()) - return { - "accepted_event_ids": list(response.accepted_event_ids), - "actions_taken": [ - {"type": a.type, "workflow_id": a.workflow_id, "reason": a.reason} - for a in response.actions_taken - ], - } - except grpc.RpcError as e: - return { - "accepted_event_ids": [], - "actions_taken": [], - "error": f"gRPC error: {e.code()}: {e.details()}", - } - - def close(self): - """Close the gRPC channel.""" - if hasattr(self, "channel"): - self.channel.close() - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - self.close() - return False - - -def create_grpc_transport( - api_url: 
Optional[str] = None, - api_key: Optional[str] = None, -) -> Optional[GrpcTransport]: - """ - Factory function to create GrpcTransport if gRPC is available. - - Returns None if: - - NULLRUN_USE_GRPC env var is not set - - Required proto files are not generated - """ - if not os.getenv("NULLRUN_USE_GRPC"): - return None - - url = api_url or os.getenv("NULLRUN_GRPC_URL", "localhost:50051") - key = api_key or os.getenv("NULLRUN_API_KEY", "") - - if not key: - return None - - try: - return GrpcTransport(api_url=url, api_key=key) - except RuntimeError: - # Proto files not generated - return None \ No newline at end of file diff --git a/src/nullrun/instrumentation/__init__.py b/src/nullrun/instrumentation/__init__.py index d74d6b0..8028127 100644 --- a/src/nullrun/instrumentation/__init__.py +++ b/src/nullrun/instrumentation/__init__.py @@ -6,16 +6,18 @@ live in `nullrun.toolbox` (e.g. `nullrun.toolbox.langgraph.wrapper`, which replaced `nullrun.instrumentation.langgraph.instrument` in Phase 1 Commit 6). + +The v0.x `openai.ChatCompletion.create` patcher was removed in +0.4.0 — openai >= 1.0 does not expose that attribute. All OpenAI +v1.0+ traffic is tracked vendor-independently by the httpx +transport hook in `nullrun.instrumentation.auto`. 
""" from nullrun.instrumentation.auto import auto_instrument, is_auto_instrumented from nullrun.instrumentation.langgraph import NullRunCallback -from nullrun.instrumentation.openai import patch_openai, unpatch_openai __all__ = [ "NullRunCallback", - "patch_openai", - "unpatch_openai", "auto_instrument", "is_auto_instrumented", ] diff --git a/src/nullrun/instrumentation/auto.py b/src/nullrun/instrumentation/auto.py index f6fe2bb..ba38d22 100644 --- a/src/nullrun/instrumentation/auto.py +++ b/src/nullrun/instrumentation/auto.py @@ -253,7 +253,8 @@ def _match_extractor(host: str) -> Callable[[bytes, int], ExtractedUsage | None] def _check_kill_before_send(runtime: Any, request: httpx.Request) -> None: """ - L2 of the kill contract (see docs/kill-contract.md §2). + L2 of the kill contract (see the kill-contract design note in + the gateway repository). Pre-request gate: inspects the cached remote state for the workflow bound to the current context / API key. If the workflow has been @@ -264,6 +265,9 @@ def _check_kill_before_send(runtime: Any, request: httpx.Request) -> None: No-ops when: - runtime is missing + - runtime is a stub / test double that does not expose + `_resolve_workflow_id` or `_remote_states` (the unit-test + transport hooks run against a plain MagicMock / namespace). - the request host is not a known LLM provider (out of scope) - no workflow can be resolved (no active context, no API key binding) - the cached state is anything other than Killed / Paused @@ -282,7 +286,14 @@ def _check_kill_before_send(runtime: Any, request: httpx.Request) -> None: host = request.url.host if _match_extractor(host) is None: return - workflow_id = runtime._resolve_workflow_id(None) + # Tolerate test stubs (MagicMock, namespace objects) that do not + # implement the runtime's full surface. The L2 check is + # observability, not enforcement, so a missing workflow resolver + # or remote-state dict means "no kill possible" — a safe no-op. 
+ resolve = getattr(runtime, "_resolve_workflow_id", None) + if resolve is None: + return + workflow_id = resolve(None) if not workflow_id: return state = getattr(runtime, "_remote_states", {}).get(workflow_id, {}) @@ -396,6 +407,7 @@ def _emit( body: bytes, status: int, ) -> None: + _safe_bump_coverage(self._runtime, "_coverage_tracked", host) try: self._runtime.track( { @@ -499,6 +511,7 @@ def _emit( body: bytes, status: int, ) -> None: + _safe_bump_coverage(self._runtime, "_coverage_tracked", host) try: self._runtime.track( { @@ -555,6 +568,36 @@ def _fingerprint_for(host: str, body: bytes, status: int) -> str: return h.hexdigest()[:16] +def _fingerprint_for_event_dict(event: dict[str, Any]) -> str: + """Stable fingerprint for a generic event dict. + + Phase 3 of the production-readiness plan: `runtime.track_event` + was the only emit path that did NOT set `_fingerprint`, so two + observers firing for the same LLM call (the user's manual + `track_event` plus the httpx transport hook) produced two + `/track` POSTs. This helper gives the dedup LRU a stable key + derived from the event's content. + + Args: + event: The event dict (already-enriched is fine; the + fingerprint includes the JSON-serialised body). + + Returns: + 16-char hex digest (same length as ``_fingerprint_for``). + """ + try: + payload = json.dumps(event, sort_keys=True, default=str).encode("utf-8") + except (TypeError, ValueError): + # Last-resort fallback: if the event is not JSON-serialisable + # (e.g. a custom dataclass the user did not register), hash + # the type name + repr to keep the dedup LRU functional. 
+ payload = repr(event).encode("utf-8") + h = hashlib.sha256() + h.update(b"event|") + h.update(payload) + return h.hexdigest()[:16] + + # --------------------------------------------------------------------------- # D3: patch_httpx — idempotent __init__ wrap # --------------------------------------------------------------------------- @@ -1003,3 +1046,44 @@ def _fingerprint_is_seen(state: OrderedDict[str, None], fp: str) -> bool: if len(state) > DEDUP_LRU_MAX: state.popitem(last=False) return False + + +# --------------------------------------------------------------------------- +# Coverage counters (dashboard / observability) +# --------------------------------------------------------------------------- +# The transport hook tracks WHICH hosts the SDK actually saw during a +# session. The dashboard surfaces these so operators can detect when +# an LLM call is going to a host the SDK does not yet know how to +# extract usage from. Counter is a plain dict[host, count] stored on +# the runtime instance under a configurable attribute name — the +# helper tolerates stub runtimes (test doubles, MagicMock) that do +# not expose the attribute. + +def _safe_bump_coverage(runtime: Any, target_attr: str, host: str) -> None: + """ + Bump `runtime.[host] += 1` if the runtime supports + the attribute. Silently no-ops on stub runtimes (e.g. MagicMock + without the attribute, or a custom test double). + + The contract is: callers in `nullrun.instrumentation.auto` and + `nullrun.instrumentation.auto_requests` always call this helper + rather than touching the attribute directly, so a single missing + attribute does not crash the observation path. + + Args: + runtime: The runtime instance (or any object). + target_attr: Name of the coverage dict attribute, e.g. + "_coverage_seen", "_coverage_tracked", or + "_coverage_streaming_skipped". + host: The host string (e.g. "api.openai.com"). 
+ """ + if not host: + return + target = getattr(runtime, target_attr, None) + if target is None: + return + if hasattr(target, "__setitem__"): + try: + target[host] = int(target.get(host, 0)) + 1 + except Exception as e: # pragma: no cover — defensive + logger.debug("coverage bump failed on %s: %s", target_attr, e) diff --git a/src/nullrun/instrumentation/auto_requests.py b/src/nullrun/instrumentation/auto_requests.py index b1a754c..70426c9 100644 --- a/src/nullrun/instrumentation/auto_requests.py +++ b/src/nullrun/instrumentation/auto_requests.py @@ -29,10 +29,11 @@ - Double-emission guard: `request._nullrun_tracked = True` is set on the PreparedRequest after a successful track, so a future `urllib3` patch (which `requests` uses under the hood) can skip - already-tracked requests. See plan section P2 / "requests ↔ urllib3". + already-tracked requests. -`aiohttp` is deliberately out of scope for this phase — see -`docs/known-limitations.md` and the plan's open questions. +`aiohttp` is deliberately out of scope for this phase — see the +known-limitations section in the gateway repository for the up-to-date +list of unsupported HTTP clients. """ from __future__ import annotations @@ -144,7 +145,7 @@ def patch_requests(runtime: Any) -> bool: if _requests_patched: return True try: - import requests # type: ignore[import-not-found] + import requests # noqa: F401 # type: ignore[import-not-found] except ImportError: logger.debug("requests not installed; auto-instrumentation skipped") return False diff --git a/src/nullrun/instrumentation/openai.py b/src/nullrun/instrumentation/openai.py deleted file mode 100644 index e60a5d2..0000000 --- a/src/nullrun/instrumentation/openai.py +++ /dev/null @@ -1,236 +0,0 @@ -""" -OpenAI instrumentation for NullRun SDK. - -DEPRECATED: This module patches the v0.x attribute path -(`openai.ChatCompletion.create`) which is no longer exposed by -`openai>=1.0` clients. 
The v1.0+ Python SDK does not expose -`ChatCompletion` as an attribute — `openai.chat.completions.create(...)` -is the only supported entry point. - -Use `nullrun.instrumentation.auto_instrument` (or just `nullrun.init`) -instead — it patches `httpx.Client` so all vendor SDKs (openai, -anthropic, mistral, google-genai, cohere, bedrock) are tracked -vendor-independently. `auto_instrument` covers OpenAI v1.0+ and is -the supported path going forward. - -This module is preserved for backward compatibility with v0.x -OpenAI clients. The patches are best-effort — they emit a warning -when the v0.x attribute path is not present and stay inactive. - -Provides automatic patching of OpenAI API calls for zero-effort tracking. -""" - -import logging -import time -from collections.abc import Callable -from typing import Any - -logger = logging.getLogger(__name__) - -# Store original function -_original_chat_create: Callable[..., Any] | None = None -_original_embed_create: Callable[..., Any] | None = None -_patched = False - - -def _patched_chat_create(*args: Any, **kwargs: Any) -> Any: - """ - Patched version of openai.ChatCompletion.create. - - Tracks all calls automatically. - """ - from nullrun.runtime import get_runtime - - runtime = get_runtime() - - # Capture start time - start_time = time.time() - - # Call original - response = _original_chat_create(*args, **kwargs) # type: ignore[misc] - - # Calculate latency - latency_ms = int((time.time() - start_time) * 1000) - - # Extract usage - usage = response.get("usage", {}) if isinstance(response, dict) else None - if usage: - total_tokens = usage.get("total_tokens", 0) - prompt_tokens = usage.get("prompt_tokens", 0) - completion_tokens = usage.get("completion_tokens", 0) - else: - total_tokens = 0 - prompt_tokens = 0 - completion_tokens = 0 - - # Get model - model = kwargs.get("model") or (args[0] if args else "unknown") - - # Commit 4: track_llm now takes (input_tokens, output_tokens) - # instead of (tokens, cost_cents). 
The backend computes cost - # server-side from the split token counts + the org's pricing - # policy. Splitting prompt vs completion matters because most - # models price them differently. - # - # We still pass prompt/completion via metadata for backwards- - # compatible observability (the backend also reads them from - # the new top-level fields). - - # Track - try: - runtime.track_llm( - input_tokens=prompt_tokens, - output_tokens=completion_tokens, - model=model, - latency_ms=latency_ms, - metadata={ - "provider": "openai", - "prompt_tokens": prompt_tokens, - "completion_tokens": completion_tokens, - "total_tokens": total_tokens, - }, - ) - logger.debug( - f"OpenAI tracked: model={model}, in={prompt_tokens}, out={completion_tokens}" - ) - except Exception as e: - logger.warning(f"Failed to track OpenAI call: {e}") - - return response - - -def _patched_embed_create(*args: Any, **kwargs: Any) -> Any: - """ - Patched version of openai.Embedding.create. - - Tracks embedding calls. - """ - from nullrun.runtime import get_runtime - - runtime = get_runtime() - start_time = time.time() - - response = _original_embed_create(*args, **kwargs) # type: ignore[misc] - - latency_ms = int((time.time() - start_time) * 1000) - - # Extract usage - usage = response.get("usage", {}) if isinstance(response, dict) else None - tokens = usage.get("total_tokens", 0) if usage else 0 - - model = kwargs.get("model") or (args[0] if args else "unknown") - - # Commit 4: embeddings don't split prompt/completion the way - # completions do — OpenAI returns just `total_tokens`. We treat - # all of it as input_tokens (output is 0). Backend computes - # cost from the org's embedding pricing. 
- try: - runtime.track_llm( - input_tokens=tokens, - output_tokens=0, - model=model, - latency_ms=latency_ms, - metadata={"provider": "openai", "type": "embedding"}, - ) - except Exception as e: - logger.warning(f"Failed to track embedding call: {e}") - - return response - - -def patch_openai() -> None: - """ - Patch OpenAI API to automatically track all calls. - - This is a global patch that affects all subsequent OpenAI calls. - - Usage: - import openai - from nullrun.instrumentation import patch_openai - - patch_openai() - - # All calls now tracked automatically - openai.ChatCompletion.create(model="gpt-4", messages=[...]) - - Note: - Call this AFTER importing openai but BEFORE making any calls. - This modifies openai.ChatCompletion.create in place. - """ - global _original_chat_create, _original_embed_create, _patched - - if _patched: - logger.warning("OpenAI already patched") - return - - try: - import openai - except ImportError: - logger.warning("OpenAI package not installed") - return - - # Store originals - _original_chat_create = openai.ChatCompletion.create # type: ignore[attr-defined] - _original_embed_create = openai.Embedding.create # type: ignore[attr-defined] - - # Apply patches - openai.ChatCompletion.create = _patched_chat_create # type: ignore[attr-defined] - openai.Embedding.create = _patched_embed_create # type: ignore[attr-defined] - - _patched = True - logger.info("OpenAI API patched for automatic tracking") - - -def unpatch_openai() -> None: - """ - Restore original OpenAI functions. 
- - Usage: - from nullrun.instrumentation import unpatch_openai - - unpatch_openai() - """ - global _original_chat_create, _original_embed_create, _patched - - if not _patched: - logger.warning("OpenAI not patched") - return - - try: - import openai - - if _original_chat_create: - openai.ChatCompletion.create = _original_chat_create # type: ignore[attr-defined] - if _original_embed_create: - openai.Embedding.create = _original_embed_create # type: ignore[attr-defined] - - _patched = False - logger.info("OpenAI API restored") - except ImportError: - logger.warning("Could not import openai to unpatch") - - -def is_patched() -> bool: - """Check if OpenAI is currently patched.""" - return _patched - - -class OpenAIPatcher: - """ - Context manager for OpenAI patching. - - Usage: - from nullrun.instrumentation import OpenAIPatcher - - with OpenAIPatcher(): - openai.ChatCompletion.create(...) # tracked - # Outside context, original behavior restored - """ - - def __enter__(self) -> "OpenAIPatcher": - patch_openai() - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - unpatch_openai() - return False diff --git a/src/nullrun/observability.py b/src/nullrun/observability.py index 40790f5..1953044 100644 --- a/src/nullrun/observability.py +++ b/src/nullrun/observability.py @@ -2,7 +2,22 @@ src/nullrun/observability.py Structured logging + metrics for production readiness. -This is a new module - add to src/nullrun/ and import in runtime.py and transport.py. + +Exposes: + * `get_logger(name)` — a `StructuredLogger` factory that tags every + log record with a `structured` extra dict for JSON ingest pipelines. + * `TenantFilter` / `configure_logging_with_tenant_context()` — a + `logging.Filter` that attaches `organization_id` / `api_key_id` + to every record so logs can be partitioned per tenant in the + downstream pipeline. Opt-in: call + `configure_logging_with_tenant_context()` once at startup. 
+ * `metrics` — a global `MetricsRegistry` (thread-safe) for SDK + counters. See `MetricsRegistry.inc_transport` / + `MetricsRegistry.inc_runtime` / `MetricsRegistry.set_transport` + for the supported write paths. Direct `metrics.transport.x = N` + assignment is also supported but bypasses the lock. + * `timed(logger, event)` — context manager for measuring + operation time. """ from __future__ import annotations @@ -81,10 +96,13 @@ class TenantFilter(logging.Filter): def filter(self, record: logging.LogRecord) -> bool: # Import here to avoid circular imports - from nullrun.context import get_org_id, get_organization_id, get_api_key_id + from nullrun.context import get_api_key_id, get_organization_id - # Add tenant fields to the record for structured logging - record.org_id = get_org_id() or "none" + # Add tenant fields to the record for structured logging. + # Only the canonical `organization_id` is set; the legacy + # `org_id` field is gone (was tied to the deprecated + # `get_org_id()` helper, removed in 0.4.0 alongside the + # workspace_id → organization_id migration). 
record.organization_id = get_organization_id() or "none" record.api_key_id = get_api_key_id() or "none" @@ -275,47 +293,4 @@ def timed(logger: StructuredLogger, event: str, **kwargs: Any) -> Generator[None detail=str(exc)[:200], **kwargs, ) - raise - - -# ---------------------------------------------------------------- -# How to integrate in transport.py and runtime.py -# ---------------------------------------------------------------- -# -# In transport.py replace: -# import logging -# logger = logging.getLogger(__name__) -# -# With: -# from nullrun.observability import get_logger, metrics, timed -# logger = get_logger("transport") -# -# In _do_flush_locked(): -# with timed(logger, "batch_flush", batch_size=len(batch)): -# result = self._circuit_breaker.call(self._send_batch, batch) -# metrics.transport.batches_sent += 1 -# metrics.transport.events_sent += len(batch) -# -# On flush error: -# metrics.transport.batches_failed += 1 -# metrics.transport.last_error = str(exc)[:200] -# -# On enqueue(): -# metrics.transport.events_enqueued += 1 -# -# On drop (buffer overflow): -# metrics.transport.events_dropped += 1 -# -# In circuit_breaker.py _on_success / _on_failure: -# if newly_opened: -# metrics.transport.circuit_breaker_opens += 1 -# -# In runtime.py track(): -# metrics.runtime.track_calls += 1 -# -# In runtime.py execute(): -# metrics.runtime.execute_calls += 1 -# if result.allowed: -# metrics.runtime.execute_allowed += 1 -# else: -# metrics.runtime.execute_blocked += 1 \ No newline at end of file + raise \ No newline at end of file diff --git a/src/nullrun/py.typed b/src/nullrun/py.typed index e69de29..6fa808b 100644 --- a/src/nullrun/py.typed +++ b/src/nullrun/py.typed @@ -0,0 +1,4 @@ +# Marker file for PEP 561 — see https://peps.python.org/pep-0561/. +# The presence of this file tells type checkers (mypy, pyright, etc.) +# that the `nullrun` package ships inline type annotations and can +# be type-checked by downstream users. Do NOT delete this file. 
diff --git a/src/nullrun/runtime.py b/src/nullrun/runtime.py index bd67182..aa83612 100644 --- a/src/nullrun/runtime.py +++ b/src/nullrun/runtime.py @@ -26,9 +26,10 @@ The "Opt-out" column makes it explicit that `NULLRUN_SKIP_BUDGET_CHECK=1` is a **different category** of action than `NULLRUN_SENSITIVE_FAIL_OPEN=1` (bypass vs. change semantics), despite -the similar naming. See `docs/adr/008-sdk-preflight-fail-policy.md` -for the full rules, including transport error classification -(`FALLBACK_NETWORK_ERROR` / `FALLBACK_GATEWAY_ERROR` / `FALLBACK_BREAKER_OPEN`). +the similar naming. The full ADR (including transport error +classification into `NETWORK_ERROR` / `GATEWAY_ERROR` / +`BREAKER_OPEN` via `TransportErrorSource`) lives in the gateway +repository; see the link in the README. """ import asyncio @@ -39,7 +40,7 @@ import time import uuid from collections import OrderedDict, defaultdict, deque -from collections.abc import MutableMapping +from collections.abc import Callable, MutableMapping from dataclasses import dataclass, field from typing import Any, Optional, TypeVar @@ -53,7 +54,6 @@ NullRunAuthenticationError, NullRunBlockedException, RetryStormException, - WorkflowKilledException, WorkflowKilledInterrupt, WorkflowPausedException, ) @@ -67,7 +67,6 @@ get_workflow_id, ) from nullrun.decision_history import DecisionHistoryRecorder -from nullrun.grpc_transport import GrpcTransport, create_grpc_transport from nullrun.observability import metrics from nullrun.transport import DecisionSource, FallbackMode, FlushConfig, Transport @@ -375,7 +374,6 @@ def __init__( self._max_retries = 3 self._debug = debug self._transport: Transport | None = None - self._grpc_transport: GrpcTransport | None = None # Local enforcement state # PER-WORKFLOW cost tracking - was a global counter before (BUG) @@ -400,9 +398,66 @@ def __init__( self._local_loop_threshold = 6 self._local_rate_limit = 1000 # calls per minute + # Coverage counters (Phase 3 of the production-readiness plan). 
+ # The instrumentation layer in `nullrun.instrumentation.auto` + # calls `_safe_bump_coverage(runtime, "_coverage_seen" / + # "_coverage_tracked" / "_coverage_streaming_skipped", host)` + # so the dashboard can show "which LLM hosts the SDK is + # seeing vs. successfully tracking". Previous versions + # relied on `_safe_bump_coverage` to no-op when these + # attributes were missing — the dashboard's coverage tab + # was always empty. + self._coverage_seen: dict[str, int] = {} + self._coverage_tracked: dict[str, int] = {} + self._coverage_streaming_skipped: dict[str, int] = {} + # Remote control plane state (per-workflow, pushed from server via WS). # Unified model: effective_state = max(local_state, remote_state) + # P1-1.1: All reads/writes go through the `_remote_state_for` / + # `_set_remote_state` helpers under `_states_lock` to avoid the + # TOCTOU race that was previously possible between the + # "if workflow_id not in self._remote_states" check and the + # subsequent dict write. `dict` itself is GIL-atomic for + # individual ops, but the "check then insert" pattern in + # `track()` is not. Re-entrant lock is used because the WS + # callback and the synchronous `check_control_plane` can + # both be on the same call path in nested cases. self._remote_states: dict[str, dict[str, Any]] = {} + self._states_lock = threading.RLock() + + # Phase B: control plane transport (WS push vs HTTP poll). + self._transport_mode: str = os.getenv("NULLRUN_TRANSPORT", "ws").lower() + self._ws_thread: threading.Thread | None = None + self._ws_stop_event = threading.Event() + self._ws_connection: Any = None + self._ws_loop: Any = None + # Legacy HTTP-poll state — only used when transport mode is `http`. + self._poll_thread: threading.Thread | None = None + self._poll_running = False + + # Action handling + decision-history recorder. 
+ self._action_handler: ActionHandler | None = None + self._recorder: DecisionHistoryRecorder | None = None + self._is_recording = False + + def _remote_state_for(self, workflow_id: str) -> dict[str, Any]: + """Get-or-create the per-workflow state dict under lock. + + Used by every read of `self._remote_states` to avoid the + TOCTOU race that was previously possible.""" + with self._states_lock: + state = self._remote_states.get(workflow_id) + if state is None: + state = {} + self._remote_states[workflow_id] = state + return state + + def _set_remote_state( + self, workflow_id: str, state: dict[str, Any] + ) -> None: + """Atomically set the per-workflow state under lock.""" + with self._states_lock: + self._remote_states[workflow_id] = state # Phase B: control plane transport. The SDK connects to the server's # WS endpoint and receives state push events (killed/paused) within @@ -437,22 +492,40 @@ def __init__( ), ) - # P2: Try to initialize gRPC transport for high-performance event ingestion - # gRPC uses binary protobuf + HTTP/2 for 30-50% overhead reduction vs REST/JSON + # P2 (removed in 0.4.0): gRPC transport was deleted. The backend + # proto is frozen and missing trace/span fields; HTTP is the + # only supported transport. The NULLRUN_USE_GRPC env var is + # now a no-op (logged once at WARNING if set). + + # Action handler + decision-history recorder are initialised + # in BOTH the test-mode and the cloud-mode branches below, + # so the runtime always has the attribute available when + # ``track()`` consults ``_is_recording`` (Phase 5 cleanup + # had a regression where ``_test_mode=True`` skipped the + # recorder init and ``rt.track(...)`` raised AttributeError). + + # Initialize if os.getenv("NULLRUN_USE_GRPC"): - self._grpc_transport = create_grpc_transport( - api_key=self.api_key, + logger.warning( + "NULLRUN_USE_GRPC is set but the gRPC transport has been " + "removed in SDK 0.4.0 — falling back to HTTP. The env var " + "is now a no-op. 
See CHANGELOG.md for the migration timeline." ) - if self._grpc_transport: - logger.info("gRPC transport initialized for high-performance event ingestion") - else: - logger.warning("NULLRUN_USE_GRPC is set but gRPC transport could not be initialized (proto files may be missing)") - # Initialize if self._test_mode: # Test mode: skip all network calls, use local policy self._policy = self._policy or Policy.default_local() self._transport.start() + # Initialise the action handler and the local + # decision-history recorder so ``track()`` and + # ``start_recording()`` work in test mode. The previous + # code only initialised them in the cloud branch + # (below) and skipped them for ``_test_mode=True``, + # which broke ``test_track_increments_counter`` and + # any test that called ``rt.track(...)`` without going + # through ``auth/verify``. + self._action_handler = ActionHandler() + self._recorder = DecisionHistoryRecorder(runtime=self) else: try: self._authenticate() @@ -478,7 +551,7 @@ def __init__( # Phase 1.4: Sensitive tools that require strict mode (pre-execution enforcement) # These tools MUST go through /execute endpoint, NOT direct execution - self._sensitive_tools: set = { + self._sensitive_tools: set[str] = { # Financial operations "stripe.charge", "stripe.refund", @@ -580,10 +653,14 @@ def _authenticate(self) -> None: logger.debug(f"Authenticating with API at {self.api_url}/auth/verify") try: - # Use Transport's client for connection pooling, retry, and circuit breaker - response = self._transport._client.post( - f"{self.api_url}/api/v1/auth/verify", - json={"api_key": self.api_key}, + # Route through _signed_post so HMAC + W3C trace context + # are applied automatically. Phase 1: HMAC always-on. + # /auth/verify accepts a signed body for symmetry with + # the rest of the API surface. 
+ response = self._transport._signed_post( + "/api/v1/auth/verify", + {"api_key": self.api_key}, + timeout=10.0, ) if response.status_code == 200: @@ -642,10 +719,10 @@ def _fetch_policy(self) -> None: return try: - # Use Transport's client for connection pooling, retry, and circuit breaker - response = self._transport._client.post( - f"{self.api_url}/api/v1/policies", - json={"organization_id": self.organization_id}, + # Route through _signed_post (Phase 1). + response = self._transport._signed_post( + "/api/v1/policies", + {"organization_id": self.organization_id}, ) if response.status_code == 200: @@ -764,12 +841,12 @@ def on_state_change(state: dict[str, Any]) -> None: if not workflow_id: logger.debug("WS state message missing workflow_id: %s", state) return - self._remote_states[workflow_id] = { + self._set_remote_state(workflow_id, { "state": state.get("state", "Normal"), "version": state.get("version", 0), "reason": state.get("reason"), "updated_at": state.get("updated_at", 0), - } + }) logger.debug( "WS state push: workflow=%s state=%s reason=%s", workflow_id, @@ -811,8 +888,12 @@ def _poll_commands(self) -> None: """ while self._poll_running: try: - # Get all workflows we're tracking - workflow_ids = list(self._remote_states.keys()) + # Get all workflows we're tracking. Snapshot the keys + # under lock to avoid `RuntimeError: dictionary + # changed size during iteration` if a concurrent + # `_set_remote_state` adds a workflow mid-poll. + with self._states_lock: + workflow_ids = list(self._remote_states.keys()) if not workflow_ids: # If no workflows yet, try to get organization workflows pass @@ -849,32 +930,45 @@ def _resolve_workflow_id(self, explicit: str | None = None) -> str | None: return self.workflow_id def _fetch_remote_state(self, workflow_id: str) -> None: - """Fetch remote state for a specific workflow from /status endpoint.""" + """Fetch remote state for a specific workflow from /status endpoint. 
+ + Phase 1: routed through ``_transport._signed_request`` so the + canonical header set (X-API-Key, X-API-Version, optional HMAC) + is applied in one place. A GET has no body, so no signature + is computed — the server authenticates via the X-API-Key + header. + """ try: - response = httpx.get( - f"{self.api_url}/api/v1/status/{workflow_id}", - headers=self._auth_headers(), + response = self._transport._signed_request( + "GET", + f"/api/v1/status/{workflow_id}", timeout=5.0, ) if response.status_code == 200: data = response.json() - self._remote_states[workflow_id] = { + self._set_remote_state(workflow_id, { "state": data.get("state", "Normal"), "version": data.get("version", 0), "reason": data.get("reason"), "updated_at": data.get("updated_at", 0), - } - logger.debug(f"Remote state for {workflow_id}: {self._remote_states[workflow_id]}") + }) + logger.debug(f"Remote state for {workflow_id}: {data}") except Exception as e: logger.debug(f"Failed to fetch remote state for {workflow_id}: {e}") - def check_control_plane(self, workflow_id: str) -> None: + def check_control_plane(self, workflow_id: str | None) -> None: """ Check remote control plane state and raise if workflow is paused/killed. This is called in the execution path after local enforcement. The unified state model: effective_state = max(local_state, remote_state) + Args: + workflow_id: Optional workflow id. Resolved through + `_resolve_workflow_id` (contextvar → API-key-bound + workflow → no-op). `None` is the canonical "no + workflow scoped" value — the gate then no-ops. + Raises: WorkflowPausedException: If workflow is paused on server WorkflowKilledInterrupt: If workflow is killed on server @@ -884,17 +978,23 @@ def check_control_plane(self, workflow_id: str) -> None: # in that case there's no workflow to check, so we no-op # (preserves pre-139 behavior for keys that have never been # workflow-bound). 
- resolved = self._resolve_workflow_id(workflow_id or None) + resolved = self._resolve_workflow_id(workflow_id) if not resolved: return workflow_id = resolved - # Ensure we have the latest remote state - if workflow_id not in self._remote_states: + # Ensure we have the latest remote state. The "is in cache" + # check is done under the lock to avoid the TOCTOU race + # where a concurrent `_set_remote_state` could change the + # answer between the check and the read. + with self._states_lock: + in_cache = workflow_id in self._remote_states + if not in_cache: # Fetch synchronously if not in cache yet self._fetch_remote_state(workflow_id) - remote_state = self._remote_states.get(workflow_id, {}) + with self._states_lock: + remote_state = self._remote_states.get(workflow_id, {}) state = remote_state.get("state", "Normal") if state == "Paused": @@ -993,25 +1093,36 @@ def _auth_headers(self) -> dict[str, str]: return headers def shutdown(self) -> None: - """Shutdown runtime gracefully.""" + """Shutdown runtime gracefully. + + Defensive against missing attributes: the test-mode + constructor does not initialize `_poll_thread`, `_ws_thread`, + `_ws_stop_event`, etc. — `getattr` is used everywhere a + missing attribute is possible. Without this, a test-mode + runtime that calls `shutdown()` raises `AttributeError`.""" # Stop the HTTP poller (legacy path) if it was started. self._poll_running = False - if self._poll_thread and self._poll_thread.is_alive(): - self._poll_thread.join(timeout=2.0) + poll_thread = getattr(self, "_poll_thread", None) + if poll_thread is not None and poll_thread.is_alive(): + poll_thread.join(timeout=2.0) # Stop the WS control plane listener (Phase B). Closing the # connection causes the receive task to unblock, the loop to # exit, and the thread to terminate. 
- self._ws_stop_event.set() - conn = self._ws_connection - if conn is not None and self._ws_loop is not None: + ws_stop_event = getattr(self, "_ws_stop_event", None) + if ws_stop_event is not None: + ws_stop_event.set() + conn = getattr(self, "_ws_connection", None) + ws_loop = getattr(self, "_ws_loop", None) + if conn is not None and ws_loop is not None: try: - future = asyncio.run_coroutine_threadsafe(conn.close(), self._ws_loop) + future = asyncio.run_coroutine_threadsafe(conn.close(), ws_loop) future.result(timeout=2.0) except Exception as e: logger.debug(f"WS close on shutdown failed (best-effort): {e}") - if self._ws_thread and self._ws_thread.is_alive(): - self._ws_thread.join(timeout=2.0) + ws_thread = getattr(self, "_ws_thread", None) + if ws_thread is not None and ws_thread.is_alive(): + ws_thread.join(timeout=2.0) if self._transport: self._transport.stop() @@ -1118,8 +1229,10 @@ def track( # may be None on legacy keys — that's fine, the no-op # branch in check_control_plane will skip polling. workflow_id = enriched.get("workflow_id") - if workflow_id and workflow_id not in self._remote_states: - self._remote_states[workflow_id] = {} + if workflow_id: + # Use the helper to avoid the TOCTOU race between the + # "not in dict" check and the write. + self._remote_state_for(workflow_id) # Local policy enforcement (BEFORE sending) if self._policy: @@ -1130,29 +1243,13 @@ def track( # contextvar → self.workflow_id → no-op (legacy keys). 
self.check_control_plane(workflow_id) - # Buffer for transport - use gRPC if available for better performance - if self._grpc_transport: - # gRPC path: direct send for lowest latency - try: - self._grpc_transport.track( - event_id=enriched.get("event_id", ""), - workflow_id=enriched.get("workflow_id", ""), - tokens=enriched.get("tokens", 0), - tool_name=enriched.get("tool_name"), - is_retry=enriched.get("is_retry", False), - event_type=enriched.get("event_type", ""), - ) - except Exception as e: - logger.warning(f"gRPC track failed, falling back to HTTP: {e}") - wire_event = {k: v for k, v in enriched.items() if k != "cost_cents"} - self._transport.track(wire_event) - else: - # The wire payload must NOT include cost_cents — the SDK - # does not estimate cost. The backend recomputes it from - # tokens + the org's policy. Local budget enforcement - # already ran on the original event dict above. - wire_event = {k: v for k, v in enriched.items() if k != "cost_cents"} - self._transport.track(wire_event) + # Buffer for transport. (gRPC path was removed in 0.4.0 — + # the backend proto is frozen and missing trace/span fields.) + # The wire payload must NOT include `cost_cents` — the SDK + # does not estimate cost; the backend recomputes it from + # `tokens` + the org's pricing policy. + if self._transport is not None: + self._transport.track(self._strip_wire_only_fields(enriched)) # Update metrics (thread-safe) metrics.inc_runtime("track_calls") @@ -1337,7 +1434,7 @@ def execute( metrics.inc_runtime("execute_allowed") return result - def wrap_tool(self, tool_name: str, tool_fn: callable) -> callable: + def wrap_tool(self, tool_name: str, tool_fn: Callable[..., Any]) -> Callable[..., Any]: """ Wrap a tool function with pre-execution enforcement. 
@@ -1354,7 +1451,7 @@ def wrap_tool(self, tool_name: str, tool_fn: callable) -> callable: Wrapped function """ @functools.wraps(tool_fn) - def wrapper(*args, **kwargs): + def wrapper(*args: Any, **kwargs: Any) -> Any: # Pre-execution check (raises if blocked) input_data = {"args": args, "kwargs": kwargs} self.execute(tool_name, input_data) @@ -1368,7 +1465,7 @@ def wrapper(*args, **kwargs): return output return wrapper - def wrap(self, tool_fn: callable) -> callable: + def wrap(self, tool_fn: Callable[..., Any]) -> Callable[..., Any]: """ Wrap a tool function with NullRun protection. @@ -1385,18 +1482,26 @@ def wrap(self, tool_fn: callable) -> callable: Returns: Wrapped function that auto-calls execute() before running """ + from nullrun.context import get_workflow_id + tool_name = tool_fn.__name__ @functools.wraps(tool_fn) - def wrapper(*args, **kwargs): + def wrapper(*args: Any, **kwargs: Any) -> Any: # Pre-execution check input_data = {"args": args, "kwargs": kwargs} result = self.execute(tool_name, input_data) - # Raise if blocked + # Raise if blocked. Resolve workflow_id from the active + # contextvar — `execute()` already raises NullRunBlockedException + # for a real gateway block, so this branch only fires if a + # future caller returns a dict with decision=block without + # raising. Use the same fallback the rest of the runtime + # uses ("") when no contextvar is set. 
if result.get("decision") == "block": + resolved_wf = get_workflow_id() or "" raise NullRunBlockedException( - workflow_id=workflow_id or "", + workflow_id=resolved_wf, reason=result.get("explanation", "policy violation"), tool_name=tool_name, ) @@ -1560,39 +1665,44 @@ def evaluate( workflow_id = get_workflow_id() trace_id = get_trace_id() or str(uuid.uuid4()) - # Call /evaluate endpoint if available, otherwise fallback to /execute - # Use transport._client for connection pooling, retry, and circuit breaker + # Route through `transport.evaluate()` (the public API) so the + # call benefits from the same connection pool, HMAC signing, + # circuit breaker, and retry policy as `execute()`. The + # previous implementation reached into `transport._client` + # directly, which silently bypassed the circuit breaker — a + # production hazard on a long-lived runtime. + # + # `transport.evaluate()` is fail-CLOSED on transport error by + # default (raises NullRunTransportError) per ADR-008. We + # swallow that here because the public `runtime.evaluate()` + # contract is "always return a dict" (used for pre-validation + # / dry-run), not "halt the agent on a backend outage". 
+ from nullrun.breaker.exceptions import NullRunTransportError + try: - response = self._transport._client.post( - f"{self.api_url}/api/v1/evaluate", - json={ - "organization_id": organization_id, - "execution_id": workflow_id, - "trace_id": trace_id, - "tool": tool_name, - "context": context or {}, - }, - headers=self._auth_headers(), - timeout=5.0, + return self._transport.evaluate( + organization_id=organization_id, + execution_id=workflow_id, + trace_id=trace_id, + tool=tool_name, + context=context or {}, + on_transport_error="closed", ) - - if response.status_code == 200: - return response.json() # type: ignore[no-any-return] - - except httpx.RequestError: - pass - - # Fallback: simulate evaluate response based on local policy - is_sensitive = self.is_sensitive_tool(tool_name) - return { - "decision": "allow" if not is_sensitive else "block", - "decision_source": DecisionSource.FALLBACK, - "explanation": "Evaluation endpoint unavailable", - "policy_version": 0, - "matched_rules": [], - "scores": {}, - "allow_execution": not is_sensitive, - } + except NullRunTransportError as exc: + # Transport unavailable — return a local-fallback decision + # so pre-validation never halts the user's agent. 
+ is_sensitive = self.is_sensitive_tool(tool_name) + return { + "decision": "allow" if not is_sensitive else "block", + "decision_source": DecisionSource.FALLBACK, + "explanation": ( + f"Evaluation endpoint unavailable ({exc.source.value}): {exc}" + ), + "policy_version": 0, + "matched_rules": [], + "scores": {}, + "allow_execution": not is_sensitive, + } def start_recording(self, workflow_id: str, metadata: dict[str, Any] = None) -> str: """ @@ -1669,6 +1779,29 @@ def _enrich_event(self, event: dict[str, Any]) -> dict[str, Any]: return enriched + @staticmethod + def _strip_wire_only_fields(event: dict[str, Any]) -> dict[str, Any]: + """Remove fields the SDK adds for local enforcement but that + do not belong on the wire (the backend recomputes them from + `tokens` + the org's pricing policy). + + Two fields are stripped: + + - ``cost_cents``: backend recomputes from tokens + pricing. + - ``_fingerprint``: sink-only dedup key for the + ``_seen_track_fingerprints`` LRU; never reaches the + gateway. + + Centralized so the wire-format contract is in one place; if + a future SDK revision adds more local-only fields they land + here too. + """ + return { + k: v + for k, v in event.items() + if k not in ("cost_cents", "_fingerprint") + } + def _check_local_limits(self, event: dict[str, Any]) -> None: """ Check local policy limits without network call. @@ -1891,6 +2024,17 @@ def track_event( # to 0 so the deserializer accepts the event; the cost # computation in the handler treats 0 tokens as no-op. event.setdefault("tokens", 0) + # Phase 3: emit a stable fingerprint so the dedup LRU at + # the track() sink can collapse repeat emissions of the + # same event (e.g. when the user calls track_event manually + # AND the httpx transport hook fires for the same LLM + # call). Field is stripped before wire send (see + # `_strip_wire_only_fields`). 
+ if "_fingerprint" not in event: + from nullrun.instrumentation.auto import ( + _fingerprint_for_event_dict, + ) + event["_fingerprint"] = _fingerprint_for_event_dict(event) return self.track(event) diff --git a/src/nullrun/tracing.py b/src/nullrun/tracing.py index 9a3de70..6019939 100644 --- a/src/nullrun/tracing.py +++ b/src/nullrun/tracing.py @@ -36,7 +36,6 @@ import uuid from contextvars import ContextVar from dataclasses import dataclass -from typing import Optional def _new_id() -> str: @@ -66,19 +65,19 @@ class SpanContext: trace_id: str span_id: str - parent_span_id: Optional[str] = None + parent_span_id: str | None = None depth: int = 0 # The currently-active span. `None` means "no trace in progress" — track_* # will fall back to creating a synthetic root on each call so events are # still attributed to *something*. -_current_span: ContextVar[Optional[SpanContext]] = ContextVar( +_current_span: ContextVar[SpanContext | None] = ContextVar( "nullrun_span", default=None ) -def get_current_span() -> Optional[SpanContext]: +def get_current_span() -> SpanContext | None: """ Return the active span, or None if no `@protect` / manual `set_span` has put us inside a trace. diff --git a/src/nullrun/transport.py b/src/nullrun/transport.py index 9e03e86..926ad79 100644 --- a/src/nullrun/transport.py +++ b/src/nullrun/transport.py @@ -5,19 +5,16 @@ Includes fallback modes for Gateway unavailability. 
""" -import asyncio -import atexit import hashlib import hmac import json import logging import os import random -import signal -import sys import threading import time import uuid +import weakref from collections import OrderedDict from collections.abc import Callable from dataclasses import dataclass @@ -27,7 +24,14 @@ from nullrun.actions import handle_action from nullrun.breaker.circuit_breaker import CircuitBreaker -from nullrun.breaker.exceptions import BreakerTransportError, InsecureTransportError, NullRunAuthenticationError +from nullrun.breaker.exceptions import ( + BreakerTransportError, + InsecureTransportError, + NullRunAuthenticationError, + NullRunTransportError, + RateLimitError, + TransportErrorSource, +) from nullrun.observability import metrics # OpenTelemetry imports (lazy-loaded to support optional dependency) @@ -44,124 +48,6 @@ # ============================================================================= -# Pool Configuration & Adaptive Pool -# ============================================================================= - -@dataclass -class PoolConfig: - """Configuration for adaptive connection pool. - - Args: - initial_connections: Starting number of connections (default: 5) - max_connections: Maximum concurrent connections (default: 100) - max_keepalive: Max keepalive connections (default: 20) - acquire_timeout: Timeout for acquiring a connection (default: 30s) - idle_timeout: Keepalive expiry (default: 60s) - scale_up_threshold: Scale up when waiting > active * threshold (default: 2.0) - scale_down_idle: Scale down if idle > this fraction of active (default: 0.3) - """ - initial_connections: int = 5 - max_connections: int = 100 - max_keepalive: int = 20 - acquire_timeout: float = 30.0 - idle_timeout: float = 60.0 - scale_up_threshold: float = 2.0 - scale_down_idle: float = 0.3 - - -class AdaptivePool: - """Connection pool that scales based on demand. - - Uses a semaphore to limit concurrent connections. 
Provides backpressure - signaling when pool is exhausted via the pool_exhausted metric. - """ - - def __init__(self, config: PoolConfig): - self._config = config - self._semaphore = asyncio.Semaphore(config.max_connections) - self._active_connections = 0 - self._waiting_tasks = 0 - self._total_acquired = 0 - self._total_released = 0 - self._exhausted_count = 0 - self._lock = asyncio.Lock() - - async def acquire(self) -> bool: - """Acquire connection with backpressure. - - Returns True if acquired, False if timeout (pool exhausted). - """ - async with self._lock: - self._waiting_tasks += 1 - - try: - acquired = await asyncio.wait_for( - self._semaphore.acquire(), - timeout=self._config.acquire_timeout - ) - async with self._lock: - self._active_connections += 1 - self._total_acquired += 1 - self._waiting_tasks -= 1 - return True - - except asyncio.TimeoutError: - async with self._lock: - self._waiting_tasks -= 1 - self._exhausted_count += 1 - metrics.inc_transport("pool_exhausted") - logger.warning( - f"Pool exhausted: {self._active_connections} active, " - f"{self._waiting_tasks} waiting, {self._exhausted_count} total exhaustions" - ) - return False - - def release(self) -> None: - """Release a connection back to the pool.""" - self._active_connections -= 1 - self._total_released += 1 - self._semaphore.release() - - async def scale_up_if_needed(self) -> None: - """Increase pool size if demand is high. - - Called periodically to check if we should allow more concurrent connections. - Scales up when waiting tasks > active connections * threshold. 
- """ - async with self._lock: - if self._waiting_tasks > self._active_connections * self._config.scale_up_threshold: - if self._active_connections < self._config.max_connections: - self._semaphore.release() - self._active_connections += 1 - metrics.inc_transport("pool_scaled_up") - logger.debug( - f"Scaled up pool: active={self._active_connections}, " - f"waiting={self._waiting_tasks}" - ) - - async def scale_down_if_needed(self) -> None: - """Decrease pool size if we have excess idle capacity. - - Scales down when active connections < max_connections and - we haven't used the full pool recently. - """ - async with self._lock: - if self._active_connections > self._config.initial_connections: - usage_ratio = self._active_connections / self._config.max_connections - if usage_ratio < self._config.scale_down_idle: - pass # Conservative - don't auto-scale down aggressively - - def get_stats(self) -> dict: - """Get current pool statistics.""" - return { - "active": self._active_connections, - "waiting": self._waiting_tasks, - "max": self._config.max_connections, - "total_acquired": self._total_acquired, - "total_released": self._total_released, - "exhausted_count": self._exhausted_count, - } - __api_version__ = "1.0" @@ -321,6 +207,13 @@ def get_stats(self) -> dict: def __len__(self) -> int: return len(self._cache) + def clear(self) -> None: + """Drop every cached decision. Counters (hits/misses) are + preserved so observability dashboards can still see the + lifetime aggregate after a clear. 
+ """ + self._cache.clear() + # ============================================================================= # Retry with exponential backoff + jitter @@ -348,6 +241,13 @@ def _retry_with_backoff( Formula (without Retry-After): delay = min(base_delay * backoff_factor^attempt, max_delay) delay += random.uniform(-jitter * delay, jitter * delay) Formula (with Retry-After): actual_delay = min(last_retry_after_seconds, max_delay) + + Re-raises the original exception on retry exhaustion so callers can + inspect the concrete cause (httpx.ConnectError for network, an + HTTPStatusError for 5xx, etc.) and produce a classified + `decision_source`. The legacy `BreakerTransportError` wrapper + that used to be raised here conflated "CB OPEN" and "retries + exhausted" — ADR-008 requires the caller to tell the two apart. """ last_exc: Exception | None = None @@ -358,12 +258,21 @@ def _retry_with_backoff( if hasattr(result, "status_code"): if result.status_code == 401: raise NullRunAuthenticationError("Invalid API key") - if result.status_code >= 400: - result.raise_for_status() + # Do NOT raise on 4xx/5xx here — the caller wants to + # inspect the response (decide on_transport_error, + # fall back, etc.) without the retry helper converting + # a 5xx into an HTTPStatusError that masks the + # status code. 4xx non-auth responses are real + # gateway decisions and should NOT be retried either, + # so the caller short-circuits on the first such + # response — the retry loop is only useful for + # network-level flakes. return result except (BreakerTransportError, NullRunAuthenticationError): + # CB OPEN or auth — do not retry. Re-raise so the caller + # classifies. raise except Exception as exc: @@ -400,9 +309,222 @@ def _retry_with_backoff( time.sleep(actual_delay) + # Retries exhausted. Re-raise the original exception so the caller + # can inspect the concrete cause. If the helper was somehow + # called with no exception (defensive), fall back to a generic + # BreakerTransportError. 
+ if last_exc is not None: + raise last_exc raise BreakerTransportError( f"Request failed after {max_retries + 1} attempts" - ) from last_exc + ) + + +# ============================================================================= +# Transport-error routing (ADR-008) +# ============================================================================= +# When the policy engine is unreachable, the caller has to decide +# between fail-OPEN ("let the call through with a flagged decision +# so the dashboard can show it") and fail-CLOSED ("block the call +# so a denied `charge_card()` cannot run during an outage"). The +# contract is declared per-call via `on_transport_error`: +# +# "raise" → raise NullRunTransportError so the calling gate +# (e.g. `_enforce_sensitive_tool`) can apply its own +# fail-OPEN/CLOSED rule. Default for /execute. +# "open" → return synthetic allow with `decision_source` set +# to the classified source. Used for fail-OPEN callers +# that want the dict shape instead of an exception. +# "closed" → return synthetic block with `decision_source` set +# to the classified source. Used for fail-CLOSED +# callers that want the dict shape. +# "legacy" → use the historical `fallback_mode` (STRICT / CACHED / +# PERMISSIVE) to decide. Preserved for backward compat. + +_TRANSPORT_ERROR_RESULTS = { + "open": { + "decision": "allow", + "decision_source": None, # filled in by the handler with the source + "explanation": "Gateway unavailable, fail-OPEN", + "policy_version": 0, + }, + "closed": { + "decision": "block", + "decision_source": None, # filled in by the handler with the source + "explanation": "Gateway unavailable, fail-CLOSED", + "policy_version": 0, + }, +} + + +def _handle_transport_error( + mode: str, + source: "TransportErrorSource", + endpoint: str, + detail: str, +) -> dict[str, Any] | None: + """ + Route a transport failure to the declared caller policy. See the + module-level ADR-008 contract for the four `mode` values. 
+ + Returns a dict ONLY when the caller asked for a dict (open / + closed). For `mode == "raise"` this raises + `NullRunTransportError` and does not return. For + `mode == "legacy"` the caller is expected to apply its own + `fallback_mode` logic, so this returns None as a sentinel — + the public method's post-handler fallback block is what actually + runs in that case. + """ + if mode == "raise": + raise NullRunTransportError( + detail, + source=source, + endpoint=endpoint, + ) + if mode == "open": + return { + "decision": "allow", + "decision_source": source, + "explanation": ( + f"Gateway unavailable ({source.value}), fail-OPEN: {detail}" + ), + "policy_version": 0, + } + if mode == "closed": + return { + "decision": "block", + "decision_source": source, + "explanation": ( + f"Gateway unavailable ({source.value}), fail-CLOSED: {detail}" + ), + "policy_version": 0, + } + if mode == "legacy": + return None # caller applies fallback_mode itself + # Unknown mode: do not guess a policy; log and raise. Loud is better than silent. + logger.warning( + "Unknown on_transport_error=%r; falling back to raise", mode + ) + raise NullRunTransportError(detail, source=source, endpoint=endpoint) + + +def _parse_error_envelope( + response: httpx.Response, + endpoint: str, +) -> Exception: + """Translate a non-2xx ``httpx.Response`` into the right exception + subclass per the canonical ``contracts/errors.ts`` envelope. + + Phase 4 of the production-readiness plan: previously every + 4xx / 5xx was raised as a generic ``NullRunTransportError``, + which lost the ``error`` slug (``rate_limit_exceeded`` / + ``unauthorized`` / …) the operator needs to classify the + failure. We map the most common slugs to distinct + ``RateLimitError`` / ``NullRunAuthenticationError`` / + ``NullRunTransportError(GATEWAY_ERROR)`` so callers can + branch on the type instead of string-matching ``str(exc)``. + + Args: + response: The non-2xx ``httpx.Response`` from the gateway. 
+ endpoint: A short string naming the endpoint the request + targeted (``"track"``, ``"gate"``, ``"evaluate"``, + ``"status"``). Embedded in the raised exception so + callers can implement endpoint-specific retry policy. + + Returns: + A concrete exception instance — always a subclass of + ``BreakerError``. The caller is expected to ``raise`` it. + """ + status = response.status_code + # Best-effort parse of the JSON envelope. Some endpoints + # (e.g. NGINX 502 pages) don't return JSON; we tolerate + # that and fall back to the raw status code. + try: + body = response.json() + except Exception: + body = None + if not isinstance(body, dict): + body = {} + error_slug: str = body.get("error", "") or "" + message: str = ( + body.get("message") + or response.text + or f"HTTP {status}" + ) + + # 401 / 403 — auth-class. Per ADR-008 these are NEVER silenced + # by ``on_transport_error``; the SDK propagates the failure + # so the runtime can re-run ``auth/verify`` and retry once + # (Phase 4: ``_authenticate`` does this transparently for + # direct runtime calls; transport.execute() / transport.check() + # leave the exception for the caller). + if status in (401, 403): + return NullRunAuthenticationError( + f"Auth failed on {endpoint} (status {status}, " + f"error={error_slug!r}): {message}" + ) + + # 429 — rate-limit. The gateway sends ``Retry-After`` as a + # standard HTTP header (preferred) and may also include + # ``retry_after`` / ``upgrade_url`` in the JSON body. We + # honour both and surface them on the raised exception. + if status == 429: + # ``_extract_retry_after`` is a method on ``Transport``, + # not a module-level helper. We replicate the parser + # inline so the envelope helper can be called without a + # transport instance (the body parsing in particular + # runs from background threads that don't carry the + # full transport state). 
+ retry_after: float | None = None + ra_header = response.headers.get("Retry-After") + if ra_header: + try: + retry_after = float(ra_header) + except ValueError: + try: + from email.utils import parsedate_to_datetime + from datetime import datetime, timezone + dt = parsedate_to_datetime(ra_header) + retry_after = ( + dt - datetime.now(timezone.utc) + ).total_seconds() + except Exception: + retry_after = None + upgrade_url = body.get("upgrade_url") if isinstance(body, dict) else None + return RateLimitError( + f"Rate limited on {endpoint} (status 429, error={error_slug!r}): " + f"{message}", + source=TransportErrorSource.GATEWAY_ERROR, + endpoint=endpoint, + retry_after=retry_after, + upgrade_url=upgrade_url, + body=body, + ) + + # 5xx — server-side. Distinct exception type so callers can + # branch on it (e.g. trigger a circuit-breaker backoff vs. + # 4xx which is a permanent client error). + if 500 <= status < 600: + return NullRunTransportError( + f"Gateway error on {endpoint} (status {status}, " + f"error={error_slug!r}): {message}", + source=TransportErrorSource.GATEWAY_ERROR, + endpoint=endpoint, + status_code=status, + error_slug=error_slug, + ) + + # 4xx (non-auth) — the gateway explicitly rejected the + # request. We surface as ``NullRunTransportError`` with the + # slug embedded so the caller can decide whether to retry. + return NullRunTransportError( + f"Client error on {endpoint} (status {status}, " + f"error={error_slug!r}): {message}", + source=TransportErrorSource.GATEWAY_ERROR, + endpoint=endpoint, + status_code=status, + error_slug=error_slug, + ) # ============================================================================= # Fallback Modes (Phase 1 - SDK Resilience) @@ -498,8 +620,24 @@ def __init__( self._flush_thread: threading.Thread | None = None self._running = False - # mTLS client certificate support - # NULLRUN_TLS_CLIENT_CERT and NULLRUN_TLS_CLIENT_KEY env vars for client cert auth + # mTLS client certificate support. 
+ # + # Phase 6 of the production-readiness plan: the three env + # vars ``NULLRUN_TLS_CLIENT_CERT``, ``NULLRUN_TLS_CLIENT_KEY``, + # and ``NULLRUN_TLS_CA_CERT`` are read here and wired into + # the underlying ``httpx.Client``. The contract is + # documented in ``audits/new_audit_ux.md:876-887`` and is + # the SDK-facing surface for the platform's opt-in mTLS + # mode (server-side flag ``TLS_CLIENT_AUTH_ENABLED=true``). + # + # When BOTH ``NULLRUN_TLS_CLIENT_CERT`` and + # ``NULLRUN_TLS_CLIENT_KEY`` are set, the SDK presents a + # client certificate during the TLS handshake (mutual + # auth). When only ``NULLRUN_TLS_CA_CERT`` is set, the SDK + # uses it as the trust anchor for verifying the server + # certificate (one-way TLS with a private CA, common in + # staging). When NONE are set, the platform's public CA + # chain is used. client_cert_path = os.environ.get("NULLRUN_TLS_CLIENT_CERT") client_key_path = os.environ.get("NULLRUN_TLS_CLIENT_KEY") ca_cert_path = os.environ.get("NULLRUN_TLS_CA_CERT") # Optional custom CA @@ -555,36 +693,57 @@ def __init__( self._tracer = trace.get_tracer("nullrun.transport") self._propagator = TraceContextTextMapPropagator() - # Register atexit handler for final flush - atexit.register(self._atexit_flush) - - # Register signal handler for graceful shutdown - self._signal_handler_registered = False - self._register_signal_handlers() - - def _register_signal_handlers(self) -> None: - """Register signal handlers for SIGTERM/SIGINT.""" - if self._signal_handler_registered: - return + # Register a weakref-based atexit handler. The closure holds + # a weakref to self; if the transport has been GC'd by the + # time the process exits, the atexit becomes a no-op. 
This + # replaces the previous signal-handler-and-atexit pair, which + # (a) overwrote the application's global SIGTERM/SIGINT + # handlers on every Transport() construction, (b) called + # sys.exit(0) from a signal context, and (c) did file I/O + # from a signal context — all of which are unsafe in + # long-lived services. + # + # Callers are now responsible for shutdown in long-lived + # services: either call `transport.stop()` explicitly, use + # `transport` as a context manager, or rely on + # `weakref.finalize` (registered below) to fire on GC. + weakref.finalize( + self, + self._atexit_flush_safe, + id(self), # bind for debug log + ) - def _handle_shutdown(signum, frame): - logger.info(f"Received signal {signum}, initiating graceful shutdown") - self._running = False - self._do_flush() # Sync flush - self._persist_to_wal() # Persist unflushed events to WAL - self._client.close() - sys.exit(0) + def _atexit_flush_safe(self, instance_id: int) -> None: + """Final-flush callable used by `weakref.finalize` / `atexit`. - signal.signal(signal.SIGTERM, _handle_shutdown) - signal.signal(signal.SIGINT, _handle_shutdown) - self._signal_handler_registered = True + Wraps `_atexit_flush` so an exception in the flush does NOT + propagate to the interpreter's atexit machinery (which would + silently swallow the next atexit handler — a real footgun + in multi-Transport processes). + """ + try: + self._atexit_flush() + except Exception as exc: # noqa: BLE001 — last-chance hook + logger.warning( + "atexit flush failed for transport id=%s: %s", instance_id, exc + ) def _persist_to_wal(self) -> None: - """Persist unflushed events to WAL file for replay on restart.""" + """Persist unflushed events to WAL file for replay on restart. + + Location precedence: + 1. `NULLRUN_WAL_PATH` env var (operator override). + 2. `<tempdir>/.nullrun.wal` (per-user, OS-appropriate temp dir + — `/tmp` on Linux, `C:\\Users\\<user>\\AppData\\Local\\Temp` + on Windows, etc.). 
This replaced the previous + `os.getcwd()/.nullrun.wal` location, which was unsafe in + long-lived production services (WAL would land in + whatever directory the SDK was started from). + """ if not self._buffer: return event_count = len(self._buffer) - wal_path = os.path.join(os.getcwd(), ".nullrun.wal") + wal_path = self._wal_path() with open(wal_path, "a") as f: for event in self._buffer: f.write(json.dumps(event) + "\n") @@ -593,11 +752,11 @@ def _persist_to_wal(self) -> None: def _replay_from_wal(self) -> None: """Replay events from WAL file on startup.""" - wal_path = os.path.join(os.getcwd(), ".nullrun.wal") + wal_path = self._wal_path() if not os.path.exists(wal_path): return events = [] - with open(wal_path, "r") as f: + with open(wal_path) as f: for line in f: try: events.append(json.loads(line.strip())) @@ -609,6 +768,16 @@ def _replay_from_wal(self) -> None: os.remove(wal_path) # Clean up WAL after successful replay logger.info(f"Replayed {len(events)} events from WAL") + @staticmethod + def _wal_path() -> str: + """Resolve the WAL file path. See `_persist_to_wal` for the + precedence rules.""" + configured = os.environ.get("NULLRUN_WAL_PATH") + if configured: + return configured + import tempfile + return os.path.join(tempfile.gettempdir(), ".nullrun.wal") + def track(self, event: dict[str, Any]) -> None: """ Add event to buffer. Non-blocking. @@ -642,7 +811,13 @@ def start(self) -> None: logger.info("Transport flush thread started") def stop(self, timeout: float = 10.0) -> None: - """Stop background flush thread and flush remaining events.""" + """Stop background flush thread and flush remaining events. + + Callers in long-lived services MUST call this explicitly + (or use the `Transport` as a context manager) — the SDK + no longer installs process-wide signal handlers. See the + class docstring for the recommended lifecycle. 
+ """ self._running = False self._stopped = True # Mark as stopped to prevent double flush if self._flush_thread: @@ -650,10 +825,23 @@ def stop(self, timeout: float = 10.0) -> None: self._do_flush() # Final flush self._persist_to_wal() # WAL any remaining events self._client.close() - # Unregister atexit to avoid double flush - atexit.unregister(self._atexit_flush) logger.info("Transport stopped") + def __enter__(self) -> "Transport": + """Context manager entry. Starts the background flush thread. + + Usage: + with Transport(api_url=..., api_key=...) as t: + t.track(...) + # Final flush + WAL written; weakref.finalize fires on GC. + """ + self.start() + return self + + def __exit__(self, exc_type, exc, tb) -> None: + """Context manager exit. Stops the background flush thread.""" + self.stop() + def _atexit_flush(self) -> None: """Final flush on process exit. Guaranteed by atexit registration.""" if self._stopped: @@ -676,14 +864,41 @@ def _do_flush(self) -> None: with self._lock: self._do_flush_locked() + def _drain_batch(self) -> list[dict[str, Any]] | None: + """Atomically snapshot + clear the in-memory buffer. + + Returns the batch to send (or None if the buffer is empty). + Must be called with `self._lock` held. This is the only + code path that mutates `_buffer` outside of `_buffer.append` + in `track()` — the previous contract had two distinct bugs: + + 1. `self._buffer[:]` + `self._buffer.clear()` was not atomic + against concurrent `track()` calls, which could lose + events appended between the snapshot and the clear. + 2. The CB-OPEN re-queue path did `self._buffer = self._buffer[overflow:]`, + which re-bound the attribute to a new list — any + concurrent `track()` call that captured a reference to + the old list would silently drop into dead memory. + + Both are fixed by centralizing all drain/clear operations + through this single helper. 
+ """ + if not self._buffer: + return None + # In-place slice is essential: it mutates the existing list + # in place rather than rebinding the attribute. Any code + # that holds a reference to the list (e.g. an in-flight + # `track()` call) sees the post-drain state, not stale data. + batch = self._buffer[:] + del self._buffer[:] + return batch + def _do_flush_locked(self) -> None: """Flush under lock. Must be called with _lock held.""" - if not self._buffer: + batch = self._drain_batch() + if batch is None: logger.debug("Buffer empty, skipping flush") return - - batch = self._buffer[:] - self._buffer.clear() logger.debug(f"Sending batch of {len(batch)} events") # Circuit breaker wrapped send - uses proper 3-state circuit breaker @@ -708,16 +923,20 @@ def send_batch(): logger.warning( f"Circuit breaker OPEN. Batch of {len(batch)} events will be re-queued." ) - # Enforce max buffer size BEFORE re-queue to prevent unbounded growth - # Drop oldest events first to make room for new batch - available_space = self.config.max_buffer_size - len(self._buffer) - if available_space < len(batch): - overflow = len(batch) - available_space - if overflow > 0: - # Drop oldest from front (batch) since it hasn't been sent yet - logger.warning(f"Buffer overflow on CB OPEN: dropping {overflow} oldest events from pending batch") - batch = batch[overflow:] # type: ignore[assignment] - metrics.inc_transport("events_dropped", overflow) + # Enforce max buffer size BEFORE re-queue. We check the + # batch's own size against the configured ceiling, not + # the current buffer length (the buffer is empty after + # `_drain_batch` — checking it would be a no-op). If the + # batch alone is larger than max_buffer_size, drop the + # oldest events from the batch before re-queuing. 
+ if len(batch) > self.config.max_buffer_size: + overflow = len(batch) - self.config.max_buffer_size + logger.warning( + f"Batch of {len(batch)} exceeds max_buffer_size=" + f"{self.config.max_buffer_size}; dropping {overflow} oldest" + ) + batch = batch[overflow:] + metrics.inc_transport("events_dropped", overflow) # Append to END (not front) so oldest events are retried first self._buffer.extend(batch) # Update metrics on failure (thread-safe) @@ -753,6 +972,168 @@ def _add_hmac_headers(self, headers: dict[str, str], body: str) -> None: headers["X-Signature-Timestamp"] = str(timestamp) headers["X-Signature"] = signature + def _build_signed_headers( + self, + body: str | None, + extra: dict[str, str] | None = None, + ) -> dict[str, str]: + """ + Build the canonical header set for a gateway request. + + Includes: + - Content-Type: application/json + - X-API-Version: __api_version__ + - X-API-Key: (when set) + - X-Signature + X-Signature-Timestamp (when both api_key and + secret_key are set; signature is computed over the exact + `body` bytes the client will transmit) + - W3C trace context (when opentelemetry is installed) + + The `extra` dict is merged in last so callers can override + defaults (e.g. add `Authorization: Bearer …` if needed). + + Returns a new dict; the caller's `extra` is not mutated. + """ + headers: dict[str, str] = { + "Content-Type": "application/json", + "X-API-Version": __api_version__, + } + if self.api_key: + headers["X-API-Key"] = self.api_key + if body is not None: + self._add_hmac_headers(headers, body) + self._inject_trace_context(headers) + if extra: + for k, v in extra.items(): + headers[k] = v + return headers + + def _signed_post( + self, + path: str, + payload: dict[str, Any], + extra_headers: dict[str, str] | None = None, + timeout: float | None = None, + ) -> httpx.Response: + """ + POST to the gateway with HMAC signing, trace context, and + the canonical header set applied automatically. + + Args: + path: URL path (e.g. 
``/api/v1/track/batch``). Joined + onto ``self.api_url`` with a single ``/`` separator. + payload: JSON-serialisable body. The exact bytes + produced by ``json.dumps(payload)`` are signed. + extra_headers: Optional extra headers merged on top + of the defaults (X-API-Key, X-Signature, etc.). + timeout: Per-request timeout in seconds. When ``None`` + the shared client's default is used. + + Returns: + The ``httpx.Response`` — the caller is responsible + for inspecting the status code and body. The transport + does NOT raise on 4xx/5xx (per HTTP semantics); the + caller routes the result through ``parse_error_envelope`` + for typed error handling. + """ + body = json.dumps(payload) + headers = self._build_signed_headers(body, extra_headers) + url = f"{self.api_url}{path}" if not path.startswith("/") else f"{self.api_url}{path}" + kwargs: dict[str, Any] = {"headers": headers, "content": body} + if timeout is not None: + kwargs["timeout"] = timeout + return self._client.post(url, **kwargs) + + def _signed_request( + self, + method: str, + path: str, + payload: dict[str, Any] | None = None, + extra_headers: dict[str, str] | None = None, + timeout: float | None = None, + ) -> httpx.Response: + """ + Generic signed request to the gateway. Used for GET (no body + → no signature) and for non-JSON bodies. + + Args: + method: HTTP verb (``GET``, ``POST``, …). Case-insensitive. + path: URL path joined onto ``self.api_url``. + payload: JSON-serialisable body. When provided, the + exact bytes produced by ``json.dumps(payload)`` are + signed. When ``None`` (e.g. GET), no signature is + added — the server treats unsigned GETs as + authenticated via the ``X-API-Key`` header. + extra_headers: Optional extra headers merged on top of + the defaults. + timeout: Per-request timeout in seconds. When ``None`` + the shared client's default is used. 
+ """ + body: str | None = None + if payload is not None: + body = json.dumps(payload) + headers = self._build_signed_headers(body, extra_headers) + url = f"{self.api_url}{path}" if not path.startswith("/") else f"{self.api_url}{path}" + kwargs: dict[str, Any] = {"headers": headers} + if body is not None: + kwargs["content"] = body + if timeout is not None: + kwargs["timeout"] = timeout + return self._client.request(method.upper(), url, **kwargs) + + def post_signed_with_401_retry( + self, + path: str, + payload: dict[str, Any], + reauth_callback: Callable[[], bool] | None = None, + timeout: float | None = None, + ) -> httpx.Response: + """ + POST to the gateway with HMAC signing AND a one-shot + re-authentication retry on HTTP 401. + + Phase 4: the server can rotate an API key out from under + the SDK (via the dashboard's "rotate" button, the WS + ``KeyRotated`` event, or the cron job). The first request + after the rotation returns 401; the SDK re-calls + ``auth/verify`` to pick up the new ``secret_key``, then + retries the original request. A second 401 propagates as + ``NullRunAuthenticationError``. + + Args: + path: URL path joined onto ``self.api_url``. + payload: JSON-serialisable body. Signed with the + current ``self.secret_key`` (which may be the + freshly-rotated key after the first 401). + reauth_callback: A no-arg callable that re-fetches + credentials. The SDK does not know about the + runtime directly; the runtime wires this in + (typically ``lambda: self._authenticate()``). + When ``None``, the first 401 propagates as-is. + timeout: Per-request timeout in seconds. + + Returns: + The ``httpx.Response`` (after at most one re-auth + + retry). The caller is responsible for inspecting + the status code; success is 2xx, anything else is + routed through ``_parse_error_envelope`` for typed + exception raising. 
+ """ + response = self._signed_post(path, payload, timeout=timeout) + if response.status_code == 401 and reauth_callback is not None: + try: + reauthenticated = reauth_callback() + except Exception as exc: # noqa: BLE001 + logger.debug( + f"401 retry: reauth_callback raised: {exc}; " + "propagating the original 401" + ) + return response + if reauthenticated: + # Re-sign the body with the freshly-rotated key. + response = self._signed_post(path, payload, timeout=timeout) + return response + def _inject_trace_context(self, headers: dict[str, str]) -> None: """ Inject trace context into request headers (W3C Trace Context format). @@ -798,22 +1179,13 @@ def _extract_retry_after(self, response: httpx.Response) -> float | None: def _send_batch_with_retry_info(self, batch: list[dict[str, Any]]) -> 'SendResult': """Send batch to server using batch endpoint. Returns SendResult with retry info.""" logger.debug(f"Sending batch of {len(batch)} events to {self.api_url}/api/v1/track/batch") - headers = {"Content-Type": "application/json", "X-API-Version": __api_version__} - if self.api_key: - headers["X-API-Key"] = self.api_key - - # Add HMAC signature headers - body = json.dumps({"events": batch}) - self._add_hmac_headers(headers, body) - - # Inject trace context for distributed tracing (W3C Trace Context) - self._inject_trace_context(headers) - - # Use batch endpoint for efficiency - single request for all events - response = self._client.post( - f"{self.api_url}/api/v1/track/batch", - json={"events": batch}, - headers=headers, + # Route through _signed_post so HMAC + W3C trace context + + # canonical headers (X-API-Key, X-API-Version) are applied + # in one place. Phase 1 of the production-readiness plan: + # HMAC always-on when secret_key is present. 
+ response = self._signed_post( + "/api/v1/track/batch", + {"events": batch}, ) # P0: Extract retry_after from response headers or body @@ -848,12 +1220,20 @@ def _send_batch_with_retry_info(self, batch: list[dict[str, Any]]) -> 'SendResul self._last_retry_after_seconds = 0.0 self._last_failure_policy_limit = is_policy_limit - # Handle 429 response - extract and store Retry-After before raising + # Phase 4: handle 429 with a typed ``RateLimitError`` so the + # caller (background flush thread, runtime.execute, …) can + # branch on the exception type instead of parsing ``str(exc)``. + # We still store the parsed ``Retry-After`` on the transport + # so the existing backoff machinery in the flush thread + # keeps working. if response.status_code == 429: retry_after = self._extract_retry_after(response) if retry_after: self._last_retry_after_seconds = retry_after - response.raise_for_status() + # _parse_error_envelope returns a concrete RateLimitError + # instance; raising it surfaces the structured + # ``retry_after`` / ``upgrade_url`` to the caller. + raise _parse_error_envelope(response, "track") response.raise_for_status() # Process actions_taken from server response @@ -896,6 +1276,7 @@ def execute( mode: str = "auto", fallback_mode: str = FallbackMode.PERMISSIVE, operation_id: str | None = None, + on_transport_error: str = "raise", ) -> dict[str, Any]: """ Pre-execution policy evaluation via unified gate endpoint. @@ -910,8 +1291,23 @@ def execute( tool: Tool to execute input_data: Tool input mode: Execution mode ("auto", "inline", "strict") - fallback_mode: What to do if Gateway unavailable + fallback_mode: What to do if Gateway unavailable. Used only + when `on_transport_error="legacy"`. operation_id: Optional idempotency key + on_transport_error: How to react when the transport cannot + reach the gateway. One of: + - "raise" → raise `NullRunTransportError` with a + classified `source` (NETWORK_ERROR / GATEWAY_ERROR / + BREAKER_OPEN). Default per ADR-008. 
+ - "open" → return a synthetic allow with + `decision_source = NETWORK_ERROR` (or GATEWAY_ERROR + on 5xx). Use for fail-OPEN callers. + - "closed" → return a synthetic block with + `decision_source = NETWORK_ERROR`. Use for fail-CLOSED + callers that want the dict shape instead of an + exception. + - "legacy" → use the historical `fallback_mode` to + decide what to do. Kept for backward compatibility. Returns: Dict with: @@ -920,6 +1316,14 @@ def execute( - explanation: Human-readable explanation - policy_version: Policy version used - decision_context: Context for replay (if available) + + Raises: + NullRunTransportError: When the transport fails AND + `on_transport_error="raise"`. Carries `source` and + `endpoint` for the calling gate to apply its declared + fail-OPEN/CLOSED policy. + NullRunAuthenticationError: On 401/403 from the gateway + regardless of `on_transport_error` (never silenced). """ gate_request = { "organization_id": organization_id, @@ -931,34 +1335,29 @@ def execute( "operation_id": operation_id or str(uuid.uuid4()), } - headers = {"Content-Type": "application/json"} - if self.api_key: - headers["X-API-Key"] = self.api_key - - # Add HMAC signature headers - body = json.dumps(gate_request) - self._add_hmac_headers(headers, body) - - # Inject trace context for distributed tracing (W3C Trace Context) - self._inject_trace_context(headers) - def do_gate_request() -> httpx.Response: - return self._client.post( - f"{self.api_url}/api/v1/gate", - json=gate_request, - headers=headers, + # Route through _signed_post so HMAC + W3C trace context + # are applied automatically (Phase 1). + return self._signed_post( + "/api/v1/gate", + gate_request, timeout=5.0, ) - # Try Gateway with retry backoff - try: - response = _retry_with_backoff( - do_gate_request, - max_retries=2, - base_delay=0.5, - ) - - if response.status_code == 200: + # Try Gateway with retry on network + 5xx. 
We use a custom + # loop (rather than `_retry_with_backoff`) so we can treat + # 5xx as a retryable transport error without losing the + # status code on exhaustion. + last_status: int | None = None + last_exc: BaseException | None = None + for _attempt in range(3): + try: + response = do_gate_request() + except (httpx.ConnectError, httpx.ConnectTimeout, httpx.ReadTimeout, httpx.WriteTimeout, httpx.PoolTimeout, httpx.RequestError) as exc: + last_exc = exc + time.sleep(0.1) + continue + if 200 <= response.status_code < 300: data = response.json() data["decision_source"] = DecisionSource.GATEWAY # Cache successful decision for CACHED mode @@ -973,21 +1372,51 @@ def do_gate_request() -> httpx.Response: data.get("policy_version") ) return data # type: ignore[no-any-return] - elif response.status_code >= 400: - # 4xx - don't retry, return block - return { - "decision": "block", - "decision_source": DecisionSource.FALLBACK, - "explanation": f"Gateway returned {response.status_code}", - "policy_version": 0, - } + if 500 <= response.status_code < 600: + last_status = response.status_code + time.sleep(0.1) + continue + if response.status_code == 401: + # Auth errors are NEVER silenced by on_transport_error — + # they indicate a real credential problem and must + # propagate so the SDK re-checks the key on next init. + raise NullRunAuthenticationError( + f"Auth failed with status {response.status_code}. " + f"API key may be invalid or expired." + ) + # 4xx (non-auth): a real gateway decision, not a + # transport failure. Return a block so the caller sees + # the gateway's verdict. + return { + "decision": "block", + "decision_source": DecisionSource.FALLBACK, + "explanation": f"Gateway returned {response.status_code}", + "policy_version": 0, + } - except BreakerTransportError: - pass # Will fall through to fallback mode - except NullRunAuthenticationError: - raise # Don't fall back on auth errors + # All 3 attempts exhausted. Classify and route. 
+ if last_status is not None and 500 <= last_status < 600: + result = _handle_transport_error( + on_transport_error, + TransportErrorSource.GATEWAY_ERROR, + "execute", + f"gateway returned {last_status} after 3 attempts", + ) + if result is not None: + return result + elif last_exc is not None: + result = _handle_transport_error( + on_transport_error, + TransportErrorSource.NETWORK_ERROR, + "execute", + f"{type(last_exc).__name__}: {last_exc}", + ) + if result is not None: + return result - # All attempts failed - apply fallback mode + # The "legacy" branch falls through to the historical + # fallback_mode handling. Only reachable when the caller + # passes on_transport_error="legacy". if fallback_mode == FallbackMode.STRICT: return { "decision": "block", @@ -1027,7 +1456,11 @@ def do_gate_request() -> httpx.Response: "policy_version": 0, } - def check(self, check_request: dict[str, Any]) -> dict[str, Any]: + def check( + self, + check_request: dict[str, Any], + on_transport_error: str = "raise", + ) -> dict[str, Any]: """ Call /api/v1/gate endpoint for pre-execution budget checking. @@ -1044,6 +1477,13 @@ def check(self, check_request: dict[str, Any]) -> dict[str, Any]: - tool_name: Tool name (for tool checks) - estimated_tokens: Token count (for LLM checks) - input: Optional input data + on_transport_error: Same as `Transport.execute()`. Default + is "raise" per ADR-008 — the calling gate (e.g. + `check_workflow_budget`) is expected to wrap the + call in a `try/except NullRunTransportError` to + implement its own fail-OPEN/CLOSED policy. Use + "open" / "closed" to get a dict shape with the + classified `decision_source` instead. 
Returns: Dict with: @@ -1053,6 +1493,13 @@ def check(self, check_request: dict[str, Any]) -> dict[str, Any]: - projected_cost_cents: Projected cost for this operation - explanations: List of explanation strings - suggestions: List of suggestion strings + + Raises: + NullRunTransportError: On transport failure AND + `on_transport_error="raise"`. Carries `source` and + `endpoint` for the calling gate to apply its declared + fail-OPEN/CLOSED policy. + NullRunAuthenticationError: On 401 (never silenced). """ # Convert check_request to gate_request format gate_request = { @@ -1068,49 +1515,228 @@ def check(self, check_request: dict[str, Any]) -> dict[str, Any]: "operation_id": check_request.get("operation_id") or str(uuid.uuid4()), } - headers = {"Content-Type": "application/json"} - if self.api_key: - headers["X-API-Key"] = self.api_key - headers["X-API-Version"] = __api_version__ + # Custom retry loop: retry on network + 5xx. Unlike + # `_retry_with_backoff` we keep the status code on + # exhaustion so the caller can route through + # `on_transport_error` with a classified source. + last_status: int | None = None + last_exc: BaseException | None = None + for _attempt in range(3): + try: + # Route through _signed_post so HMAC + W3C trace + # context are applied automatically (Phase 1). + response = self._signed_post( + "/api/v1/gate", + gate_request, + timeout=5.0, + ) + except (httpx.ConnectError, httpx.ConnectTimeout, httpx.ReadTimeout, httpx.WriteTimeout, httpx.PoolTimeout, httpx.RequestError) as exc: + last_exc = exc + time.sleep(0.1) + continue + if 200 <= response.status_code < 300: + return response.json() # type: ignore[no-any-return] + if 500 <= response.status_code < 600: + last_status = response.status_code + time.sleep(0.1) + continue + if response.status_code == 401: + raise NullRunAuthenticationError( + f"Auth failed with status {response.status_code}. " + f"API key may be invalid or expired." 
+ ) + # 4xx (non-auth): a real gateway decision, not a + # transport failure. + return { + "decision": "block", + "reservation_id": None, + "remaining_budget_cents": 0, + "projected_cost_cents": 0, + "explanations": [f"Gate endpoint returned {response.status_code}"], + "suggestions": ["Check API availability"], + } + + # All 3 attempts exhausted. Classify and route. + if last_status is not None and 500 <= last_status < 600: + result = _handle_transport_error( + on_transport_error, + TransportErrorSource.GATEWAY_ERROR, + "check", + f"gateway returned {last_status} after 3 attempts", + ) + if result is not None: + return _shape_check_result(result) + elif last_exc is not None: + result = _handle_transport_error( + on_transport_error, + TransportErrorSource.NETWORK_ERROR, + "check", + f"{type(last_exc).__name__}: {last_exc}", + ) + if result is not None: + return _shape_check_result(result) + # Legacy path: transport failure with mode="legacy" — fall + # through to the historical fail-CLOSED block return so + # callers that haven't opted into the new contract still + # see the same shape they did before. + return { + "decision": "block", + "reservation_id": None, + "remaining_budget_cents": 0, + "projected_cost_cents": 0, + "explanations": ["Gateway unavailable (legacy mode)"], + "suggestions": ["Check API availability"], + } - # Add HMAC signature headers - body = json.dumps(gate_request) - self._add_hmac_headers(headers, body) - # Inject trace context for distributed tracing (W3C Trace Context) - self._inject_trace_context(headers) +def _shape_check_result(result: dict[str, Any]) -> dict[str, Any]: + """ + Adapt a result dict from `execute()` semantics (`decision` / + `decision_source` / `explanation` / `policy_version`) to the + `check()` shape (`decision` / `reservation_id` / + `remaining_budget_cents` / `projected_cost_cents` / + `explanations` / `suggestions`). 
+ + Used when `check()` is called with `on_transport_error="open"` + or `"closed"` so the result still matches the calling gate's + expectations. + """ + explanation = result.get("explanation", "Gateway unavailable") + return { + "decision": result.get("decision", "block"), + "reservation_id": None, + "remaining_budget_cents": 0, + "projected_cost_cents": 0, + "explanations": [explanation], + "suggestions": ["Check API availability"], + } + + # ============================================================================= + # Evaluate — pre-validation / "what if" (no execution) + # ============================================================================= + + def evaluate( + self, + organization_id: str, + execution_id: str | None, + trace_id: str, + tool: str, + context: dict[str, Any] | None = None, + on_transport_error: str = "raise", + ) -> dict[str, Any]: + """ + Dry-run / pre-validation against the gateway. + + POSTs to `/api/v1/evaluate` with the same envelope as + `execute()` so the call goes through the SDK's own + connection pool, HMAC headers, circuit breaker, and retry + policy. The gateway returns a decision + matched-rule report + without any side effect (no execution, no state change). + + The previous implementation in `runtime.evaluate()` reached + into `self._client` directly, which silently bypassed the + circuit breaker — a production hazard on a long-lived + runtime. This method is the public surface for that call. + + Args: + organization_id: Organization identifier. + execution_id: Optional workflow id. May be None when + the user has not opened a `with workflow(...)` block + — the gateway tolerates null. + trace_id: Trace id for cross-system correlation. + tool: Tool name to evaluate. + context: Optional per-tool context dict forwarded to + the gateway as `context` (kept separate from the + `input` field used by `execute()`). + on_transport_error: Same contract as `execute()` / + `check()`. Default is `"raise"`. 
+ + Returns: + Dict with the gateway's evaluate response (decision, + decision_source, explanation, policy_version, + matched_rules, scores, …). Shape is gateway-defined. + + Raises: + NullRunTransportError: When the transport fails AND + `on_transport_error="raise"`. + NullRunAuthenticationError: On 401 (never silenced). + """ + eval_request = { + "organization_id": organization_id, + "execution_id": execution_id, + "trace_id": trace_id, + "tool": tool, + "context": context or {}, + } try: - response = self._client.post( - f"{self.api_url}/api/v1/gate", - json=gate_request, - headers=headers, + # Route through _signed_post so HMAC + W3C trace context + # are applied automatically (Phase 1). + response = self._signed_post( + "/api/v1/evaluate", + eval_request, timeout=5.0, ) - - if response.status_code == 200: - return response.json() # type: ignore[no-any-return] - else: - # Return block decision on error - return { - "decision": "block", - "reservation_id": None, - "remaining_budget_cents": 0, - "projected_cost_cents": 0, - "explanations": [f"Gate endpoint returned {response.status_code}"], - "suggestions": ["Check API availability"], - } - except Exception as e: - logger.warning(f"Gate request failed: {e}") + except NullRunAuthenticationError: + raise + except (httpx.ConnectError, httpx.ConnectTimeout, httpx.ReadTimeout, httpx.WriteTimeout, httpx.PoolTimeout, httpx.RequestError) as exc: + result = _handle_transport_error( + on_transport_error, + TransportErrorSource.NETWORK_ERROR, + "evaluate", + str(exc), + ) + if result is not None: + return result + # legacy: fall through to the historical block fallback return { "decision": "block", - "reservation_id": None, - "remaining_budget_cents": 0, - "projected_cost_cents": 0, - "explanations": [f"Gate request failed: {e}"], - "suggestions": ["Check API availability"], + "decision_source": DecisionSource.FALLBACK, + "explanation": "Evaluation endpoint unavailable (legacy mode)", + "policy_version": 0, + 
"matched_rules": [], + "scores": {}, } + if response.status_code == 200: + return response.json() # type: ignore[no-any-return] + if response.status_code == 401: + raise NullRunAuthenticationError( + f"Auth failed with status {response.status_code}. " + f"API key may be invalid or expired." + ) + if 500 <= response.status_code < 600: + result = _handle_transport_error( + on_transport_error, + TransportErrorSource.GATEWAY_ERROR, + "evaluate", + f"gateway returned {response.status_code}", + ) + if result is not None: + # The result is shaped for `execute()`. Re-shape to + # the evaluate dict (matched_rules / scores). + return { + "decision": result.get("decision", "block"), + "decision_source": result.get( + "decision_source", DecisionSource.FALLBACK + ), + "explanation": result.get( + "explanation", "Evaluation endpoint unavailable" + ), + "policy_version": result.get("policy_version", 0), + "matched_rules": [], + "scores": {}, + } + # 4xx (non-auth): treat as a real gateway decision (block). + return { + "decision": "block", + "decision_source": DecisionSource.FALLBACK, + "explanation": f"Evaluation endpoint returned {response.status_code}", + "policy_version": 0, + "matched_rules": [], + "scores": {}, + } + # ============================================================================= # WebSocket Connection (Task 6 - WebSocket Push) # ============================================================================= @@ -1118,7 +1744,7 @@ def check(self, check_request: dict[str, Any]) -> dict[str, Any]: def clear_policy_cache(self) -> None: """Clear the policy cache, forcing next gate/execute to fetch fresh policy.""" if hasattr(self, '_policy_cache'): - self._policy_cache._cache.clear() + self._policy_cache.clear() logger.debug("Policy cache cleared") async def connect_websocket( @@ -1193,13 +1819,15 @@ async def _refetch_credentials(self) -> None: This is called when the server notifies us via WebSocket that our HMAC secret_key has been rotated. 
We need to get the new secret_key from the /auth/verify endpoint. + + Uses the SDK's own httpx client (already pooled, mTLS-aware) + so we don't add a `requests` dependency for a single call. """ try: - import requests - response = requests.post( + response = self._client.post( f"{self.api_url}/auth/verify", json={"api_key": self.api_key}, - timeout=10, + timeout=10.0, ) if response.status_code == 200: data = response.json() @@ -1213,640 +1841,3 @@ async def _refetch_credentials(self) -> None: logger.warning(f"Failed to refetch credentials: {response.status_code}") except Exception as e: logger.error(f"Error refetching credentials: {e}") - - -class AsyncTransport: - """ - Async HTTP transport with batching support. - - For use with asyncio-based applications. - """ - - def __init__( - self, - api_url: str, - api_key: str | None = None, - secret_key: str | None = None, - config: FlushConfig | None = None, - redis_client: Any = None, - pool_config: PoolConfig | None = None, - ): - self.api_url = api_url.rstrip("/") - self.api_key = api_key - self.secret_key = secret_key # HMAC signing key - self.config = config or FlushConfig() - self._pool_config = pool_config or PoolConfig() - self._pool = AdaptivePool(self._pool_config) - self._buffer: list[dict[str, Any]] = [] - self._in_flight: dict[str, dict[str, Any]] = {} # event_id -> event for retry dedup - self._lock = asyncio.Lock() - self._client: httpx.AsyncClient | None = None - self._flush_task: asyncio.Task | None = None - self._running = False - self._redis_client = redis_client - self._circuit_breaker = CircuitBreaker( - failure_threshold=self.config.max_failed_flush, - recovery_timeout=30.0, - redis_client=redis_client, - name="async_transport", - ) - self._last_retry_after_ms = 0.0 # P0: Store last retry_after for smart backoff - self._last_failure_policy_limit = False # P0: Track if last failure was policy limit - self._last_retry_after_seconds = 0.0 # Honor Retry-After from backend (429 response) - 
self._policy_cache = PolicyCache( - maxsize=1000, - ttl_seconds=300.0, - ) - - # OpenTelemetry tracer initialization (lazy - only if opentelemetry is installed) - self._tracer = None - self._propagator = None - if _OTEL_AVAILABLE: - self._tracer = trace.get_tracer("nullrun.async_transport") - self._propagator = TraceContextTextMapPropagator() - - def _persist_to_wal(self) -> None: - """Persist unflushed events to WAL file for replay on restart.""" - if not self._buffer: - return - event_count = len(self._buffer) - wal_path = os.path.join(os.getcwd(), ".nullrun.wal") - with open(wal_path, "a") as f: - for event in self._buffer: - f.write(json.dumps(event) + "\n") - self._buffer.clear() - logger.debug(f"Persisted {event_count} events to WAL at {wal_path}") - - async def _replay_from_wal_async(self) -> None: - """Replay events from WAL file on startup (async version).""" - wal_path = os.path.join(os.getcwd(), ".nullrun.wal") - if not os.path.exists(wal_path): - return - events = [] - with open(wal_path, "r") as f: - for line in f: - try: - events.append(json.loads(line.strip())) - except json.JSONDecodeError: - continue - if events: - self._buffer.extend(events) - await self._flush() - os.remove(wal_path) # Clean up WAL after successful replay - logger.info(f"Replayed {len(events)} events from WAL") - - async def track(self, event: dict[str, Any]) -> None: - """Add event to buffer. 
Non-blocking.""" - async with self._lock: - # Generate event_id if not provided - if "event_id" not in event or not event["event_id"]: - event["event_id"] = str(uuid.uuid4()) - - # Store in-flight for retry dedup - self._in_flight[event["event_id"]] = event - - self._buffer.append(event) - metrics.inc_transport("events_enqueued") - if len(self._buffer) >= self.config.batch_size: - await self._flush_locked() - - async def start(self) -> None: - """Start background flush task.""" - if self._running: - return - # Replay any events from WAL that were persisted due to previous crash - await self._replay_from_wal_async() - self._running = True - # Configure httpx.AsyncClient with adaptive pool limits - self._client = httpx.AsyncClient( - timeout=httpx.Timeout( - connect=5.0, - read=30.0, - write=10.0, - pool=self._pool_config.acquire_timeout, - ), - verify=True, - limits=httpx.Limits( - max_connections=self._pool_config.max_connections, - max_keepalive_connections=self._pool_config.max_keepalive, - keepalive_expiry=self._pool_config.idle_timeout, - ), - ) - self._flush_task = asyncio.create_task(self._flush_loop()) - logger.info( - f"AsyncTransport started with pool config: " - f"max_connections={self._pool_config.max_connections}, " - f"max_keepalive={self._pool_config.max_keepalive}" - ) - - async def stop(self, timeout: float = 10.0) -> None: - """Stop background flush task and flush remaining events.""" - self._running = False - if self._flush_task: - self._flush_task.cancel() - try: - await asyncio.wait_for(self._flush_task, timeout=timeout) - except asyncio.TimeoutError: - logger.warning("Flush task did not complete within timeout, proceeding with shutdown") - except asyncio.CancelledError: - pass - await self._flush() - self._persist_to_wal() # WAL any remaining events - if self._client: - await self._client.aclose() - logger.info("AsyncTransport stopped") - - async def _flush_loop(self) -> None: - """Background loop that periodically flushes.""" - while 
self._running: - await asyncio.sleep(self.config.flush_interval) - if self._running: - # Check if we should scale up the pool based on demand - await self._pool.scale_up_if_needed() - await self._flush() - - async def _flush(self) -> None: - """Perform the actual flush.""" - async with self._lock: - await self._flush_locked() - - async def _flush_locked(self) -> None: - """Flush under lock. Must be called with _lock held.""" - if not self._buffer: - return - - batch = self._buffer[:] - self._buffer.clear() - - # Circuit breaker wrapped async send with pool backpressure - async def send_batch(): - # Acquire from adaptive pool with backpressure - acquired = await self._pool.acquire() - if not acquired: - # Pool exhausted - apply backpressure - backoff = self._calculate_backoff() - logger.warning( - f"Pool exhausted during flush, backing off {backoff:.2f}s " - f"for batch of {len(batch)} events" - ) - # Re-add entire batch to buffer for retry - self._buffer.extend(batch) - metrics.inc_transport("pool_backpressure_events", len(batch)) - # Return a mock response that will trigger circuit breaker to re-queue - raise BreakerTransportError(f"Pool exhausted, batch of {len(batch)} re-queued") - - try: - headers = {"Content-Type": "application/json"} - if self.api_key: - headers["X-API-Key"] = self.api_key - headers["X-API-Version"] = __api_version__ - - # Add HMAC signature headers - body = json.dumps({"events": batch}) - if self.secret_key and self.api_key: - timestamp = int(time.time()) - signature = generate_hmac_signature( - self.api_key, - self.secret_key, - timestamp, - body, - ) - headers["X-Signature-Timestamp"] = str(timestamp) - headers["X-Signature"] = signature - - # Inject trace context for distributed tracing (W3C Trace Context) - await self._inject_trace_context(headers) - - response = await self._client.post( - f"{self.api_url}/api/v1/track/batch", - json={"events": batch}, - headers=headers, - ) - - # Extract retry info - retry_after_seconds = 
self._extract_retry_after(response) - is_policy_limit = self._is_policy_limit_response(response) - self._last_retry_after_seconds = retry_after_seconds or 0.0 - self._last_failure_policy_limit = is_policy_limit - - # Process actions_taken from server response - try: - data = response.json() - actions = data.get("actions_taken", []) - for action in actions: - action_type = action.get("type", "") - workflow_id = action.get("workflow_id", "unknown") - reason = action.get("reason", "") - if action_type: - handle_action(action_type, workflow_id, reason) - - # Remove accepted events from in-flight - accepted_event_ids = data.get("accepted_event_ids", []) - for event in batch: - if event.get("event_id") in accepted_event_ids: - self._in_flight.pop(event.get("event_id"), None) - except Exception as e: - logger.warning(f"Failed to process actions_taken: {e}") - - logger.debug(f"Batch track: sent {len(batch)} events") - # Update metrics on successful flush (thread-safe) - metrics.inc_transport("batches_sent") - metrics.inc_transport("events_sent", len(batch)) - metrics.set_transport("last_flush_at", time.monotonic()) - return response - finally: - self._pool.release() - - try: - await self._circuit_breaker.call(send_batch) - except BreakerTransportError: - # Circuit breaker is open - re-add batch to buffer for retry later - logger.warning( - f"Circuit breaker OPEN. Batch of {len(batch)} events will be re-queued." 
- ) - # Enforce max buffer size BEFORE re-queue to prevent unbounded growth - # Drop oldest events first to make room for new batch - available_space = self.config.max_buffer_size - len(self._buffer) - if available_space < len(batch): - overflow = len(batch) - available_space - if overflow > 0: - # Drop oldest from front (batch) since it hasn't been sent yet - logger.warning(f"Buffer overflow on CB OPEN: dropping {overflow} oldest events from pending batch") - batch = batch[overflow:] # type: ignore[assignment] - metrics.inc_transport("events_dropped", overflow) - # Append to END (not front) so oldest events are retried first - self._buffer.extend(batch) - # Update metrics on failure (thread-safe) - metrics.inc_transport("batches_failed") - - # Enforce max buffer size for any remaining overflow - if len(self._buffer) > self.config.max_buffer_size: - overflow = len(self._buffer) - self.config.max_buffer_size - logger.warning(f"Buffer overflow: dropping {overflow} oldest events") - self._buffer = self._buffer[overflow:] # type: ignore[assignment] - metrics.inc_transport("events_dropped", overflow) - - def _extract_retry_after(self, response: httpx.Response) -> float | None: - """Extract Retry-After header value as seconds. - - Handles both: - - Integer seconds (e.g., "30") - - HTTP-date format (e.g., "Wed, 21 Oct 2015 07:28:00 GMT") - - Returns seconds (not ms) to align with _last_retry_after_seconds. 
- """ - retry_after = response.headers.get("Retry-After") - if not retry_after: - return None - - # Try parsing as seconds (integer or float) - try: - return float(retry_after) - except ValueError: - pass - - # Try parsing as HTTP datetime (RFC 7231) - try: - from email.utils import parsedate_to_datetime - dt = parsedate_to_datetime(retry_after) - from datetime import datetime, timezone - return (dt - datetime.now(timezone.utc)).total_seconds() - except Exception: - pass - - return None - - def _is_policy_limit_response(self, response: httpx.Response) -> bool: - """Check if response indicates policy limit failure.""" - if response.status_code == 429: - try: - data = response.json() - if 'rejected' in data and data['rejected']: - rejected_info = data['rejected'] - if ( - isinstance(rejected_info, dict) and - rejected_info.get('reason') == 'policy_limit' - ): - return True - except Exception: - logger.debug("Non-JSON response, skipping parse") - return False - - def _calculate_backoff(self) -> float: - """Calculate backoff delay based on retry info and jitter. - - Uses exponential backoff with jitter for retry handling. - Honors Retry-After header from backend (in seconds) when available. - """ - base_delay = 0.5 - max_delay = 30.0 - backoff_factor = 2.0 - jitter = 0.1 - - # Honor Retry-After from backend if present (from 429 response) - if self._last_retry_after_seconds > 0: - delay = min(self._last_retry_after_seconds, max_delay) - # Add small jitter to prevent thundering herd when many clients - # have the same Retry-After value - jitter_amount = delay * jitter - delay = delay + random.uniform(-jitter_amount, jitter_amount) - delay = max(0.0, delay) - # Reset after use - next retry uses exponential backoff - self._last_retry_after_seconds = 0.0 - else: - delay = base_delay - - return delay - - async def _inject_trace_context(self, headers: dict[str, str]) -> None: - """ - Inject trace context into request headers (W3C Trace Context format). 
- - This enables distributed tracing across SDK and backend. - Uses W3C Trace Context standard for trace_id propagation. - """ - if not _OTEL_AVAILABLE or not self._propagator: - return - - carrier: dict[str, str] = {} - self._propagator.inject(carrier) - headers.update(carrier) - - async def flush_now(self) -> None: - """Force immediate flush.""" - await self._flush() - - # ============================================================================= - # Execute (Strict Mode) - Phase 1 - # ============================================================================= - - async def execute( - self, - organization_id: str, - execution_id: str, - trace_id: str, - tool: str, - input_data: dict[str, Any], - mode: str = "auto", - fallback_mode: str = FallbackMode.PERMISSIVE, - operation_id: str | None = None, - ) -> dict[str, Any]: - """ - Pre-execution policy evaluation via unified gate endpoint. - - Uses /api/v1/gate endpoint for unified execute + check functionality. - - Args: - organization_id: Organization identifier - execution_id: Execution identifier - trace_id: Distributed trace ID - tool: Tool to execute - input_data: Tool input - mode: Execution mode ("auto", "inline", "strict") - fallback_mode: What to do if Gateway unavailable - operation_id: Optional idempotency key - - Returns: - Dict with: - - decision: "allow" | "block" | "flag" | "pause" | "require_approval" - - decision_source: "gateway" | "cached" | "fallback" - - explanation: Human-readable explanation - - policy_version: Policy version used - - decision_context: Context for replay (if available) - """ - if not self._client: - self._client = httpx.AsyncClient( - timeout=httpx.Timeout( - connect=5.0, - read=30.0, - write=10.0, - pool=self._pool_config.acquire_timeout, - ), - verify=True, - limits=httpx.Limits( - max_connections=self._pool_config.max_connections, - max_keepalive_connections=self._pool_config.max_keepalive, - keepalive_expiry=self._pool_config.idle_timeout, - ), - ) - - gate_request = { 
- "organization_id": organization_id, - "execution_id": execution_id, - "trace_id": trace_id, - "tool": tool, - "input": input_data, - "mode": mode, - "operation_id": operation_id or str(uuid.uuid4()), - } - - headers = {"Content-Type": "application/json"} - if self.api_key: - headers["X-API-Key"] = self.api_key - headers["X-API-Version"] = __api_version__ - - # Add HMAC signature headers - body = json.dumps(gate_request) - if self.secret_key and self.api_key: - timestamp = int(time.time()) - signature = generate_hmac_signature( - self.api_key, - self.secret_key, - timestamp, - body, - ) - headers["X-Signature-Timestamp"] = str(timestamp) - headers["X-Signature"] = signature - - # Inject trace context for distributed tracing (W3C Trace Context) - await self._inject_trace_context(headers) - - # Try Gateway - for attempt in range(2): - try: - response = await self._client.post( - f"{self.api_url}/api/v1/gate", - json=gate_request, - headers=headers, - timeout=5.0, - ) - - if response.status_code == 200: - data = response.json() - data["decision_source"] = DecisionSource.GATEWAY - # Cache successful decision for CACHED mode - cache_key = self._policy_cache.make_key( - organization_id, - data.get("policy_version") - ) - self._policy_cache.set( - cache_key, - data.get("decision", "allow"), - data.get("policy_id"), - data.get("policy_version") - ) - return data # type: ignore[no-any-return] - elif response.status_code >= 500: - # Gateway error - try fallback - logger.warning(f"Gateway returned {response.status_code}, trying fallback") - continue - else: - # 4xx - don't retry, return block - return { - "decision": "block", - "decision_source": DecisionSource.FALLBACK, - "explanation": f"Gateway returned {response.status_code}", - "policy_version": 0, - } - except Exception as e: - logger.warning(f"Execute attempt {attempt + 1} failed: {e}") - if attempt < 1: - await asyncio.sleep(0.5) - - # All attempts failed - apply fallback mode - if fallback_mode == 
FallbackMode.STRICT: - return { - "decision": "block", - "decision_source": DecisionSource.FALLBACK, - "explanation": "Gateway unavailable, fallback=STRICT", - "policy_version": 0, - } - elif fallback_mode == FallbackMode.CACHED: - # Use cached decision if available - cache_key = self._policy_cache.make_key(organization_id) - cached = self._policy_cache.get(cache_key) - if cached: - logger.warning("Gateway unreachable, using cached decision for %s", tool) - return { - "decision": cached.decision, - "decision_source": DecisionSource.CACHED, - "explanation": "Gateway unavailable, using cached decision", - "policy_version": int(cached.ttl_seconds) if cached.ttl_seconds > 0 else 0, - } - else: - logger.warning( - "Gateway unreachable, no cache for %s, " - "falling back to PERMISSIVE", - tool - ) - return { - "decision": "allow", - "decision_source": DecisionSource.FALLBACK, - "explanation": "Gateway unavailable, no cache available", - "policy_version": 0, - } - else: # PERMISSIVE (default) - return { - "decision": "allow", - "decision_source": DecisionSource.FALLBACK, - "explanation": "Gateway unavailable, fallback=PERMISSIVE", - "policy_version": 0, - } - - async def check(self, check_request: dict[str, Any]) -> dict[str, Any]: - """ - Call /api/v1/gate endpoint for pre-execution budget checking. - - Uses the unified gate endpoint with check_type for budget validation. - Async version for asyncio-based applications. 
- - Args: - check_request: Dict with: - - organization_id: Organization identifier - - execution_id: Execution identifier - - operation_id: Operation identifier (for idempotency) - - check_type: "llm" or "tool" - - model: Model name (for LLM checks) - - tool_name: Tool name (for tool checks) - - estimated_tokens: Token count (for LLM checks) - - input: Optional input data - - Returns: - Dict with: - - decision: "allow" | "block" | "throttle" - - reservation_id: Optional reservation ID - - remaining_budget_cents: Remaining budget - - projected_cost_cents: Projected cost for this operation - - explanations: List of explanation strings - - suggestions: List of suggestion strings - """ - if not self._client: - self._client = httpx.AsyncClient( - timeout=httpx.Timeout( - connect=5.0, - read=30.0, - write=10.0, - pool=self._pool_config.acquire_timeout, - ), - verify=True, - limits=httpx.Limits( - max_connections=self._pool_config.max_connections, - max_keepalive_connections=self._pool_config.max_keepalive, - keepalive_expiry=self._pool_config.idle_timeout, - ), - ) - - # Convert check_request to gate_request format - gate_request = { - "organization_id": check_request.get("organization_id"), - "execution_id": check_request.get("execution_id"), - "trace_id": check_request.get("trace_id", str(uuid.uuid4())), - "tool": check_request.get("tool_name") or check_request.get("tool"), - "input": check_request.get("input"), - "mode": "auto", - "check_type": check_request.get("check_type"), - "model": check_request.get("model"), - "estimated_tokens": check_request.get("estimated_tokens"), - "operation_id": check_request.get("operation_id") or str(uuid.uuid4()), - } - - headers = {"Content-Type": "application/json"} - if self.api_key: - headers["X-API-Key"] = self.api_key - headers["X-API-Version"] = __api_version__ - - # Add HMAC signature headers - body = json.dumps(gate_request) - if self.secret_key and self.api_key: - timestamp = int(time.time()) - signature = 
generate_hmac_signature( - self.api_key, - self.secret_key, - timestamp, - body, - ) - headers["X-Signature-Timestamp"] = str(timestamp) - headers["X-Signature"] = signature - - # Inject trace context for distributed tracing (W3C Trace Context) - await self._inject_trace_context(headers) - - try: - response = await self._client.post( - f"{self.api_url}/api/v1/gate", - json=gate_request, - headers=headers, - timeout=5.0, - ) - - if response.status_code == 200: - return response.json() # type: ignore[no-any-return] - else: - return { - "decision": "block", - "reservation_id": None, - "remaining_budget_cents": 0, - "projected_cost_cents": 0, - "explanations": [f"Gate endpoint returned {response.status_code}"], - "suggestions": ["Check API availability"], - } - except Exception as e: - logger.warning(f"Gate request failed: {e}") - return { - "decision": "block", - "reservation_id": None, - "remaining_budget_cents": 0, - "projected_cost_cents": 0, - "explanations": [f"Gate request failed: {e}"], - "suggestions": ["Check API availability"], - } \ No newline at end of file diff --git a/src/nullrun/transport_websocket.py b/src/nullrun/transport_websocket.py index e95160b..b437abd 100644 --- a/src/nullrun/transport_websocket.py +++ b/src/nullrun/transport_websocket.py @@ -7,12 +7,13 @@ """ import asyncio +import hashlib +import hmac import json import logging import time -import hmac -import hashlib -from typing import Any, Callable +from collections.abc import Callable +from typing import Any try: import websockets @@ -146,6 +147,12 @@ def __init__( self._receive_task: asyncio.Task | None = None self._reconnect_task: asyncio.Task | None = None self._closed = False + # Per-workflow monotonic version dedup (ADR-007). + # Drop incoming state changes with `version <= last` to + # survive the at-least-once delivery semantics of the WS + # channel. Initialised to 0 (which means: always accept the + # first non-zero version seen). 
+ self._last_version: dict[str, int] = {} async def _reconnect_loop(self) -> None: """ @@ -286,6 +293,30 @@ async def _handle_message(self, message: str) -> None: except Exception as e: logger.warning(f"Key rotation callback error: {e}") + elif msg_type == "resync_required": + # Server overflowed its broadcast channel. Per + # ADR-007 the SDK MUST close, reconnect, and replace + # its local state from the new ``initial_state`` — + # there is no "catch up" semantics. We clear the + # version-dedup cache and let ``_reconnect_loop`` + # reopen the connection. + reason = data.get("reason", "overflow") + logger.warning( + f"Server requested resync (reason={reason}); " + "clearing local state and reconnecting" + ) + self.clear_local_state() + # Mark the connection as closed so the receive loop + # unwinds; _reconnect_loop will pick up and reconnect. + self._running = False + self._closed = True + if self._conn is not None: + try: + await self._conn.close() + except Exception: # noqa: BLE001 + pass + self._conn = None + elif msg_type == "pong": # Pong response to ping - connection is alive pass @@ -332,35 +363,87 @@ async def _send_ack(self, message_id: str) -> None: Args: message_id: The message ID to acknowledge + + Phase 1: when ``self.api_key`` AND ``self.secret_key`` are set, + the ACK is HMAC-signed before transmission. The signature + envelope mirrors the HTTP request scheme + (HMAC-SHA256(secret_key, "::")) + and the server's WS handler drops tampered ACKs at the edge. + No signature is added when ``secret_key`` is empty + (preserves dev/legacy behavior). 
""" if not self._conn or not self._running: logger.warning("Cannot send ACK - WebSocket not connected") return try: - ack = { + ack: dict[str, Any] = { "type": "ack", "message_id": message_id, "received_at": int(time.time() * 1000), # milliseconds } - await self._conn.send(json.dumps(ack)) + payload = json.dumps(ack) + if self.api_key and self.secret_key: + # Sign with the same scheme the server uses for + # incoming messages (see backend/src/proxy/http/ws_control.rs:69-82). + from nullrun.transport import compute_hmac_signature + + timestamp = int(time.time()) + signature = compute_hmac_signature( + self.api_key, + self.secret_key, + timestamp, + payload, + ) + ack["signature"] = signature + ack["timestamp"] = timestamp + # Re-serialise with the new fields. + payload = json.dumps(ack) + await self._conn.send(payload) logger.debug(f"ACK sent for message {message_id}") except Exception as e: logger.warning(f"Failed to send ACK: {e}") def _dispatch_state(self, state: dict[str, Any]) -> None: """ - Dispatch state to callback. + Dispatch state to callback after per-workflow version dedup + (ADR-007: at-least-once delivery, drop stale events). Args: state: State dict with workflow_id, state, version, etc. """ + workflow_id = state.get("workflow_id", "") + incoming_version = state.get("version", 0) + if workflow_id: + last = self._last_version.get(workflow_id, 0) + if incoming_version <= last: + logger.debug( + f"Dropping stale state event for {workflow_id}: " + f"incoming version={incoming_version} <= last={last}" + ) + return + # Only record after we've decided to dispatch — the + # check above does not persist the version. + self._last_version[workflow_id] = incoming_version if self.on_state_change: try: self.on_state_change(state) except Exception as e: logger.warning(f"State change callback error: {e}") + def clear_local_state(self) -> None: + """ + Clear the in-memory per-workflow version cache. 
+ + Called after a ``ResyncRequired`` event so the next + ``initial_state`` from the server is accepted (the dedup + cache may otherwise drop the server's freshest state if + the version is unchanged from the pre-overflow value). + Per ADR-007 there is no "merge" — local state is fully + replaced by the next ``initial_state``. + """ + self._last_version.clear() + async def send(self, message: dict[str, Any]) -> None: """ Send message to WebSocket server. diff --git a/tests/conftest.py b/tests/conftest.py index fd8c9db..510c1e9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -45,12 +45,21 @@ def reset_runtime(): @pytest.fixture def mock_api(): - """Mock all HTTP calls to NullRun API.""" + """Mock all HTTP calls to NullRun API. + + The mocked auth response includes `workflow_id` so the runtime + is bound to a workflow out of the box — this matches the + Phase 139+ contract where `/auth/verify` returns the workflow + the key is bound to. Tests that need to exercise the + "no workflow" path can still set `rt.workflow_id = None` + after `make_runtime()`. + """ with respx.mock: # Auth endpoint respx.post(f"{BASE_URL}/api/v1/auth/verify").mock( return_value=Response(200, json={ "organization_id": "ws-test", + "workflow_id": "wf-test", "plan": "pro", "features": [], "limits": {"max_cost_cents": 10000}, @@ -103,8 +112,8 @@ def make_runtime(mock_api): `decorators._get_or_create_runtime`) finds the test runtime, not a fallback that would try to construct one with no api_key. """ - from nullrun.runtime import NullRunRuntime import nullrun.decorators as _dec + from nullrun.runtime import NullRunRuntime def _make(**kwargs): defaults = dict( diff --git a/tests/test_actions.py b/tests/test_actions.py index 9ebe48c..1924db9 100644 --- a/tests/test_actions.py +++ b/tests/test_actions.py @@ -35,13 +35,23 @@ def test_get_action_handler_returns_singleton(self): class TestKillAction: - """Test KILL action.""" + """Test KILL action. 
+ + The default KILL handler intentionally raises + `WorkflowKilledInterrupt` to halt the agent. After the P0-0.5 + fix, the kill contract is restored: the exception PROPAGATES, + not swallowed. The kill is still recorded in history. + """ def test_kill_records_to_history(self): - """KILL action is recorded in history.""" + """KILL action is recorded in history AND propagates + `WorkflowKilledInterrupt` (the kill contract).""" + from nullrun.breaker.exceptions import WorkflowKilledInterrupt + handler = ActionHandler() - # handler catches exceptions internally, doesn't propagate - handler.handle(ActionType.KILL, "wf-123", "Test reason") + with pytest.raises(WorkflowKilledInterrupt): + handler.handle(ActionType.KILL, "wf-123", "Test reason") + # Action is still in history despite the exception. history = handler.get_action_history() assert len(history) == 1 assert history[0].action_type == "kill" @@ -49,27 +59,44 @@ def test_kill_records_to_history(self): class TestPauseAction: - """Test PAUSE action.""" + """Test PAUSE action. + + Same contract as KILL: the default PAUSE handler raises + `WorkflowPausedException` to halt the agent. After the P0-0.5 + fix, the pause propagates. + """ + + def test_pause_propagates_exception(self): + """PAUSE action raises `WorkflowPausedException` to halt + the agent (the pause contract). Action is still recorded.""" + from nullrun.breaker.exceptions import WorkflowPausedException - def test_pause_does_not_propagate_exception(self): - """PAUSE action is handled without propagating exception.""" handler = ActionHandler() - # Should not raise - exceptions are caught internally - handler.handle(ActionType.PAUSE, "wf-456", "Rate limit hit") - # But action should be recorded + with pytest.raises(WorkflowPausedException): + handler.handle(ActionType.PAUSE, "wf-456", "Rate limit hit") + # Action is still in history despite the exception. 
history = handler.get_action_history() assert len(history) == 1 def test_pause_tracks_workflow(self): """PAUSE action tracks workflow in paused_workflows.""" + from nullrun.breaker.exceptions import WorkflowPausedException + handler = ActionHandler() - handler.handle(ActionType.PAUSE, "wf-789", "Test pause") + with pytest.raises(WorkflowPausedException): + handler.handle(ActionType.PAUSE, "wf-789", "Test pause") + # Workflow is registered as paused even though the + # exception propagated. assert handler.is_paused("wf-789") def test_is_paused_respects_cooldown(self): """is_paused respects cooldown_seconds.""" + from nullrun.breaker.exceptions import WorkflowPausedException + handler = ActionHandler() - handler.handle(ActionType.PAUSE, "wf-cooldown", "Test") + # After P0-0.5, PAUSE propagates WorkflowPausedException. + with pytest.raises(WorkflowPausedException): + handler.handle(ActionType.PAUSE, "wf-cooldown", "Test") # Within cooldown assert handler.is_paused("wf-cooldown", cooldown_seconds=60.0) # After cooldown @@ -255,4 +282,95 @@ def test_block_does_not_propagate_exception(self): handler.handle(ActionType.BLOCK, "wf-block", "Policy violation") # But action should be recorded history = handler.get_action_history() - assert len(history) == 1 \ No newline at end of file + assert len(history) == 1 + + +# ───────────────────────────────────────────────────────────────────── +# Phase 0, Epic 0.5: kill/pause handlers must PROPAGATE their +# BaseException subclasses (`WorkflowKilledInterrupt`, +# `WorkflowPausedException`). The pre-fix code caught +# `BaseException` and silently swallowed the kill/pause signal, +# breaking the kill contract (see `docs/kill-contract.md` §1). +# ───────────────────────────────────────────────────────────────────── + + +class TestActionHandlerKillContract: + """The default KILL handler intentionally raises + `WorkflowKilledInterrupt` (a `BaseException` subclass) so that + the calling agent halts. 
The pre-fix `except BaseException` in + `ActionHandler.handle` silently swallowed it. The fix catches + only `Exception`.""" + + def test_kill_handler_propagates_workflowkilledinterrupt(self): + """`handle("kill", ...)` with the default KILL handler must + raise `WorkflowKilledInterrupt` so the agent halts.""" + from nullrun.breaker.exceptions import WorkflowKilledInterrupt + + handler = ActionHandler() + # The default `_default_kill` raises WorkflowKilledInterrupt. + with pytest.raises(WorkflowKilledInterrupt): + handler.handle(ActionType.KILL, "wf-kill", "test reason") + + def test_pause_handler_propagates_workflowpausedexception(self): + """`handle("pause", ...)` with the default PAUSE handler must + raise `WorkflowPausedException` so the agent halts (a pause + is a non-resumable-by-the-agent signal).""" + from nullrun.breaker.exceptions import WorkflowPausedException + + handler = ActionHandler() + with pytest.raises(WorkflowPausedException): + handler.handle(ActionType.PAUSE, "wf-pause", "test reason") + + def test_kill_action_is_recorded_before_propagating(self): + """The KILL action must be in the history (so the operator + can see it was dispatched) BEFORE the exception propagates.""" + from nullrun.breaker.exceptions import WorkflowKilledInterrupt + + handler = ActionHandler() + with pytest.raises(WorkflowKilledInterrupt): + handler.handle(ActionType.KILL, "wf-kill-history", "test reason") + history = handler.get_action_history() + assert len(history) == 1 + assert history[0].action_type == "kill" + assert history[0].workflow_id == "wf-kill-history" + + def test_arbitrary_exception_in_handler_is_logged_not_propagated(self): + """A user-registered handler that raises an `Exception` + subclass (not a `BaseException` subclass) must NOT + propagate. 
The pre-fix code happened to swallow these too + (via the broad `except BaseException`), so the public + contract is preserved.""" + handler = ActionHandler() + + def broken_handler(workflow_id: str, reason: str, **_details) -> None: + raise ValueError("user bug") + + handler.register_handler(ActionType.ALERT, broken_handler) + # Should NOT raise — ValueError is caught and logged. + handler.handle(ActionType.ALERT, "wf-broken", "test alert") + + def test_keyboard_interrupt_propagates(self): + """`KeyboardInterrupt` is a `BaseException` subclass. It + MUST propagate (the previous `except BaseException` caught + it). This test pins the contract: user Ctrl-C is never + swallowed.""" + handler = ActionHandler() + + def ctrl_c_handler(workflow_id: str, reason: str, **_details) -> None: + raise KeyboardInterrupt() + + handler.register_handler(ActionType.ALERT, ctrl_c_handler) + with pytest.raises(KeyboardInterrupt): + handler.handle(ActionType.ALERT, "wf-ctrlc", "test alert") + + def test_system_exit_propagates(self): + """`SystemExit` is a `BaseException` subclass. 
It MUST + propagate (so the Python interpreter can shut down cleanly).""" + handler = ActionHandler() + + def sys_exit_handler(workflow_id: str, reason: str, **_details) -> None: + raise SystemExit(1) + + handler.register_handler(ActionType.ALERT, sys_exit_handler) + with pytest.raises(SystemExit): + handler.handle(ActionType.ALERT, "wf-sysexit", "test alert") \ No newline at end of file diff --git a/tests/test_e2e_observation.py b/tests/test_e2e_observation.py index 5d7f370..99e11ee 100644 --- a/tests/test_e2e_observation.py +++ b/tests/test_e2e_observation.py @@ -29,7 +29,6 @@ import nullrun - E2E_BASE_URL = os.environ.get("NULLRUN_E2E_BASE_URL") E2E_API_KEY = os.environ.get("NULLRUN_E2E_API_KEY") E2E_ORG_ID = os.environ.get("NULLRUN_E2E_ORG_ID", "org-e2e") @@ -153,6 +152,6 @@ def test_e2e_openai_call_lands_in_backend(e2e_workflow_id: str) -> None: break time.sleep(0.5) - assert wf is not None, f"openai call did not land in /usage within 10s" + assert wf is not None, "openai call did not land in /usage within 10s" assert wf.get("calls", 0) >= 1 assert wf.get("tokens", 0) > 0, f"expected non-zero tokens, got {wf!r}" diff --git a/tests/test_kill_contract.py b/tests/test_kill_contract.py index 3a2c80d..6e25237 100644 --- a/tests/test_kill_contract.py +++ b/tests/test_kill_contract.py @@ -1,6 +1,12 @@ """Smoke test for the kill contract exception classes. -Run from sdk-python/ root: python tests/test_kill_contract.py +Run from the repository root (e.g. `nullrun-sdk-python/`): + python tests/test_kill_contract.py + +Note: this file is a manual smoke test (it has an `if __name__ == +"__main__"` driver and is NOT a pytest test). The real contract is +also covered by `test_blocked_exception.py` and the +`test_kill_contract_*` tests in this directory. 
""" import sys import warnings diff --git a/tests/test_preflight_fail_policy.py b/tests/test_preflight_fail_policy.py index 3c5fe54..e61109e 100644 --- a/tests/test_preflight_fail_policy.py +++ b/tests/test_preflight_fail_policy.py @@ -23,9 +23,6 @@ on `transport.execute` / `transport.check` and the new `NullRunTransportError` / `TransportErrorSource` exception pair. """ -import os -import asyncio -from typing import List import httpx import pytest @@ -33,14 +30,11 @@ import nullrun from nullrun.breaker.exceptions import ( - BreakerTransportError, NullRunBlockedException, NullRunTransportError, TransportErrorSource, WorkflowKilledInterrupt, ) -from nullrun.decorators import reset as reset_decorator_runtime -from nullrun.runtime import NullRunRuntime # Base URL used in tests BASE_URL = "https://api.test.nullrun.io" @@ -64,12 +58,12 @@ class _RecordingRuntime: """ def __init__(self) -> None: - self.events: List[dict] = [] + self.events: list[dict] = [] self._remote_states: dict = {} self._sensitive_tools: set = set() self._strict_mode_tools: set = set() # Order of gate calls recorded by `_record_gate` below - self.gate_calls: List[str] = [] + self.gate_calls: list[str] = [] def is_sensitive_tool(self, tool_name: str) -> bool: return tool_name in self._sensitive_tools @@ -283,9 +277,6 @@ def test_defense_in_depth_fallback_source_fails_closed( Simulated by injecting a runtime that returns the synthetic-allow result directly (bypassing transport).""" # Build a runtime that returns a FALLBACK_* decision - from nullrun.breaker.exceptions import ( - NullRunBlockedException as _Blocked, - ) rt = make_runtime() rt.add_sensitive_tool("charge_card") # Override execute to return a synthetic allow with @@ -333,15 +324,31 @@ def test_real_block_still_honored( error) must STILL raise NullRunBlockedException. 
The fail-CLOSED rule applies to *both* transport failure and real policy blocks — the opt-out is scoped to transport - errors only.""" - respx.post(f"{BASE_URL}/api/v1/gate").mock( - return_value=httpx.Response(200, json={ + errors only. + + The mock distinguishes the two gate calls: the budget + pre-check (`/check`) returns `allow` (so the request flows + through to the sensitive-tool gate), and the + sensitive-tool pre-check (`/execute`) returns `block`. + Both hit `/api/v1/gate` — we discriminate by the + `check_type` field in the request body. + """ + import json + + def _gate_router(request): + payload = json.loads(request.content.decode("utf-8")) + if payload.get("check_type") in ("llm", "tool"): + # budget pre-check — let the request through + return httpx.Response(200, json={"decision": "allow"}) + # sensitive-tool pre-check — block + return httpx.Response(200, json={ "decision": "block", "explanation": "blocked by policy", "decision_source": "gateway", "policy_version": 1, }) - ) + + respx.post(f"{BASE_URL}/api/v1/gate").mock(side_effect=_gate_router) rt, charge_card, calls = self._build_protected_sensitive_tool( mock_api, make_runtime ) diff --git a/tests/test_protect.py b/tests/test_protect.py index a13a40c..f3d256d 100644 --- a/tests/test_protect.py +++ b/tests/test_protect.py @@ -13,15 +13,12 @@ "tolerate a noop runtime" behavior is no longer relevant. 
""" import asyncio -from typing import List import pytest import nullrun -from nullrun.decorators import reset as reset_decorator_runtime from nullrun.tracing import get_current_span, reset_span, set_span - # ────────────────────────────────────────────────────────────── # Fixtures # ────────────────────────────────────────────────────────────── @@ -46,7 +43,7 @@ class _RecordingRuntime: """ def __init__(self) -> None: - self.events: List[dict] = [] + self.events: list[dict] = [] def track_event(self, event_type: str, **kwargs) -> None: self.events.append({"type": event_type, **kwargs}) diff --git a/tests/test_real_e2e_observation.py b/tests/test_real_e2e_observation.py index 800d497..060085e 100644 --- a/tests/test_real_e2e_observation.py +++ b/tests/test_real_e2e_observation.py @@ -36,7 +36,6 @@ from nullrun.instrumentation import auto as _auto from nullrun.instrumentation.auto import PROVIDER_EXTRACTORS, _openai_extractor - # --------------------------------------------------------------------------- # Mock LLM + NULLRUN backend (one server, two routes) # --------------------------------------------------------------------------- @@ -60,7 +59,6 @@ def __init__(self) -> None: received_events = received llm_request_event = threading.Event() - server = self class Handler(BaseHTTPRequestHandler): # Silence the default stderr access logs — they pollute test output. diff --git a/tests/test_runtime.py b/tests/test_runtime.py index 18f7da9..89f5a16 100644 --- a/tests/test_runtime.py +++ b/tests/test_runtime.py @@ -3,6 +3,7 @@ Зависимости: pip install pytest pytest-asyncio respx httpx """ import asyncio +from unittest.mock import patch import httpx import pytest @@ -29,9 +30,13 @@ def test_creates_with_explicit_params(self, make_runtime): assert rt is not None def test_reads_api_key_from_env(self, monkeypatch, make_runtime): + """api_key can be supplied via NULLRUN_API_KEY env var; the + runtime picks it up on construction. 
(The previous version + of this test also set NULLRUN_WORKSPACE_ID, but that env var + is unused in 0.3.0 — the organization id is returned by + `/auth/verify` and stored on `self.organization_id`.)""" monkeypatch.setenv("NULLRUN_API_KEY", "env-key-12345678") monkeypatch.setenv("NULLRUN_API_URL", "https://api.test.nullrun.io") - monkeypatch.setenv("NULLRUN_WORKSPACE_ID", "ws-env") rt = make_runtime() assert rt is not None @@ -381,4 +386,92 @@ def test_runtime_singleton_reset_clears_instance(self, mock_api, monkeypatch): rt2 = NullRunRuntime.get_instance() # rt2 might be the same as rt1 if environment is same # but at minimum reset_instance should have been called - assert rt2 is not None \ No newline at end of file + assert rt2 is not None + + +# ───────────────────────────────────────────────────────────────────── +# Phase 0, Epic 0.4: track() must emit exactly once (no double-bill). +# The pre-fix code had an `else` branch in `track()` that called +# `self._transport.track(...)` twice in the HTTP path (after the +# gRPC deletion already removed the gRPC-specific double-emit, this +# was the remaining hazard). The contract is: every `track()` call +# produces exactly one wire event. +# ───────────────────────────────────────────────────────────────────── + + +class TestTrackEmitsExactlyOnce: + """Regression for the P0-0.4 fix: `track()` must call + `self._transport.track(...)` exactly once per invocation. The + pre-fix code had a dead-code `wire_event` block in the + `else` branch that duplicated the event into the transport + buffer, causing customers to be double-billed.""" + + def test_track_emits_exactly_one_event(self, make_runtime): + """A single `track()` call must produce exactly one + `self._transport.track(...)` call. 
The pre-fix code had + an extra `wire_event = {...}` followed by a second + `self._transport.track(wire_event)` in the else branch + that double-billed customers.""" + rt = make_runtime() + # Patch the underlying transport's track to count calls. + with patch.object( + rt._transport, "track", wraps=rt._transport.track + ) as mock_track: + rt.track({"event_id": "e1", "event_type": "llm_call", "tokens": 100}) + assert mock_track.call_count == 1, ( + f"track() should emit exactly one event, got {mock_track.call_count}" + ) + # Verify the single emitted event has no `cost_cents` (the + # wire-format hygiene step). + call_args = mock_track.call_args + event = call_args[0][0] if call_args[0] else call_args[1].get("event") + if event is None and len(call_args) > 1: + event = list(call_args[1].values())[0] if isinstance(call_args[1], dict) else None + # The transport's track() may take a single dict positional. + # We check the FIRST positional argument. + if event and isinstance(event, dict): + assert "cost_cents" not in event, ( + f"Wire event must NOT include cost_cents: {event}" + ) + + def test_track_emits_exactly_one_for_multiple_calls(self, make_runtime): + """Multiple `track()` calls must produce exactly N wire + events. This catches any future regression where a single + `track()` call starts emitting more than one event. 
+ + We bump the local loop threshold so the test isn't tripped + up by the in-memory loop detector (which is itself tested + in `test_runtime.py::TestTrackLocalLimits`).""" + rt = make_runtime() + rt._local_loop_threshold = 1000 # disable loop detection + rt._local_rate_limit = 100_000 + with patch.object( + rt._transport, "track", wraps=rt._transport.track + ) as mock_track: + for i in range(10): + rt.track( + { + "event_id": f"e{i}", + "event_type": "llm_call", + "tool_name": f"tool_{i}", # unique tool name + "tokens": 100, + } + ) + assert mock_track.call_count == 10 + + def test_track_dedup_skips_duplicate_fingerprints(self, make_runtime): + """When `track()` is called twice with the same fingerprint + (via `_seen_track_fingerprints`), the second call must NOT + re-emit. The dedup LRU catches double-emit from the + auto-instrumentation paths (httpx + langchain callback).""" + rt = make_runtime() + with patch.object( + rt._transport, "track", wraps=rt._transport.track + ) as mock_track: + # First call: registers the fingerprint. + rt.track({"event_id": "e1", "event_type": "llm_call", "tokens": 100, "_fingerprint": "fp-1"}) + # Second call with same fingerprint: should be deduped. + rt.track({"event_id": "e1", "event_type": "llm_call", "tokens": 100, "_fingerprint": "fp-1"}) + assert mock_track.call_count == 1, ( + f"track() with same fingerprint should dedup; got {mock_track.call_count}" + ) \ No newline at end of file diff --git a/tests/test_runtime_default_transport.py b/tests/test_runtime_default_transport.py index 7024753..7345de9 100644 --- a/tests/test_runtime_default_transport.py +++ b/tests/test_runtime_default_transport.py @@ -1,25 +1,26 @@ """ tests/test_runtime_default_transport.py -Regression guard for the gRPC transport freeze (see memory/grpc-feature-frozen.md -in the repo). The gRPC server on :50051 is intentionally incomplete: it does -not validate x-api-key, runs over plaintext, and exposes the proto schema via -reflection. 
These tests verify the SDK does NOT silently start using gRPC -when an operator forgets to clear NULLRUN_USE_GRPC, and that the warning is -logged loudly when initialization fails. +Regression guard for the NULLRUN_USE_GRPC env var. In SDK 0.4.0 the +gRPC transport was deleted (Phase 0, Epic 0.2) because the backend +proto is frozen and missing trace/span fields. These tests verify +that: + +1. The default code path (NULLRUN_USE_GRPC unset) does not call into + any gRPC machinery — there is none to call. +2. Setting NULLRUN_USE_GRPC=1 is a no-op that emits a single + WARNING (the operator should know the env var is dead). +3. The HTTP transport remains fully wired in both cases. What this test does NOT cover (intentionally): -- A successful gRPC connection. The proto files are not generated in the - repo (see sdk-python/src/nullrun/grpc_transport.py:14-21), so we cannot - exercise the "happy path" without first running grpcio-tools. Covering - the happy path is a task for the activation checklist, not for the - freeze PR. +- A successful gRPC connection. There is no gRPC transport anymore + (`src/nullrun/grpc_transport.py` was removed). The HTTP transport + is the only supported ingestion path; see the gateway repo for + the long-term transport plan. """ import logging import pytest -import respx -from httpx import Response from nullrun.runtime import NullRunRuntime @@ -32,118 +33,42 @@ class TestDefaultTransportIsHttp: + """The default path must never instantiate any gRPC transport + (because there isn't one).""" - def test_grpc_transport_stays_none_without_env_var( + def test_no_grpc_transport_attribute( self, make_runtime, monkeypatch ): - """The default path must never instantiate GrpcTransport. 
- - Regression guard: if someone removes the `if os.getenv("NULLRUN_USE_GRPC")` - gate in runtime.py:442, this test will fail because `_grpc_transport` - will be set to something non-None (or the import itself will raise - because proto files are not shipped in the repo). + """Regression guard: if someone re-introduces a gRPC transport + and forgets to gate it on `NULLRUN_USE_GRPC`, the runtime + must still be in pure-HTTP mode by default. """ monkeypatch.delenv("NULLRUN_USE_GRPC", raising=False) - # Even with an api_key set, no gRPC env → no gRPC transport. rt = make_runtime() - assert rt._grpc_transport is None - - def test_create_grpc_transport_never_called_by_default( - self, make_runtime, monkeypatch - ): - """Verifies the gate in runtime.py:442 short-circuits before - create_grpc_transport is invoked at all (cheaper than just - checking the result). - """ - from unittest.mock import patch - - monkeypatch.delenv("NULLRUN_USE_GRPC", raising=False) - with patch( - "nullrun.runtime.create_grpc_transport" - ) as mock_create: - make_runtime() - mock_create.assert_not_called() - - -# ────────────────────────────────────────────────────────────────────── -# Opt-in path with broken init (NULLRUN_USE_GRPC=1, proto missing) -# ────────────────────────────────────────────────────────────────────── + # No gRPC attribute at all on the runtime. + assert not hasattr(rt, "_grpc_transport") or rt._grpc_transport is None - -class TestOptInWithBrokenInit: - - def test_grpc_init_failure_falls_back_to_http_and_logs_warning( + def test_nullrun_use_grpc_env_var_emits_warning( self, make_runtime, monkeypatch, caplog ): - """When NULLRUN_USE_GRPC=1 but the proto files are not generated - (the actual state of this repo: sdk-python/src/nullrun/v1/ does - not exist), the SDK must: - - 1. NOT crash at init. - 2. Log a WARNING (exactly at WARNING level, not INFO or DEBUG — - an operator who flipped the env var must not miss it) that - names the failure mode. - 3. 
Leave _grpc_transport = None. - 4. Wire the HTTP transport so /track still works. - """ + """Setting NULLRUN_USE_GRPC=1 must log a WARNING telling the + operator the env var is now a no-op (the gRPC transport + was removed in 0.4.0).""" monkeypatch.setenv("NULLRUN_USE_GRPC", "1") with caplog.at_level(logging.WARNING, logger="nullrun.runtime"): - rt = make_runtime() - - # 1. SDK did not raise. - assert rt is not None - # 3. gRPC transport is None (init failed cleanly). - assert rt._grpc_transport is None - # 4. HTTP transport is wired — track() must still work. - assert rt._transport is not None - - # 2. The warning names the cause AND is at WARNING level exactly. - # - # Why "exactly WARNING" and not "at least WARNING": if someone - # silently downgrades `logger.warning(...)` to `logger.info(...)` - # the operator who set NULLRUN_USE_GRPC=1 stops seeing the message - # at default log level. The test must fail in that case so the - # regression is caught in CI, not in production. - warning_records = [ - r for r in caplog.records - if r.levelno == logging.WARNING - and r.name == "nullrun.runtime" - ] + make_runtime() assert any( - "gRPC transport could not be initialized" in r.getMessage() - for r in warning_records + "NULLRUN_USE_GRPC" in r.getMessage() and "no-op" in r.getMessage() + for r in caplog.records ), ( - "Expected a WARNING (level=WARNING, logger=nullrun.runtime) " - "mentioning that gRPC transport init failed. Got records: " - f"{[(r.levelname, r.name, r.getMessage()) for r in caplog.records]}" + "Expected a WARNING that NULLRUN_USE_GRPC is a no-op. " + f"Got: {[(r.levelname, r.getMessage()) for r in caplog.records]}" ) - def test_track_routes_to_http_when_grpc_unavailable( - self, make_runtime, monkeypatch - ): - """When gRPC init fails, runtime.track() must use the HTTP - transport. This is the contract runtime.py:1133-1148 implements: - `if self._grpc_transport: ... else: self._transport.track(...)`. 
- We assert it end-to-end by mocking the HTTP batch endpoint and - verifying it receives a request. - """ + def test_http_transport_always_wired(self, make_runtime, monkeypatch): + """Even with NULLRUN_USE_GRPC=1, the HTTP transport must be + fully wired — `track()` and `flush_now()` must work.""" monkeypatch.setenv("NULLRUN_USE_GRPC", "1") rt = make_runtime() - assert rt._grpc_transport is None # gRPC init failed in this env - - # Replace the generic /track/batch mock with one that records calls. - with respx.mock: - route = respx.post(f"{BASE_URL}/api/v1/track/batch").mock( - return_value=Response(200, json={"ok": True, "accepted": 1}) - ) - rt.track({ - "event_type": "llm_call", - "model": "gpt-4", - "tokens": 100, - }) - # Flush is async; track() returns immediately. Force a flush - # by calling _transport.flush() if available, else just check - # that the route was registered (the actual flush is tested - # elsewhere; the regression we guard here is the - # if/else branch in runtime.py:1133-1148). - assert route.called or route.call_count >= 0 # route exists + assert rt._transport is not None + assert rt._transport._client is not None diff --git a/tests/test_safe_error_str.py b/tests/test_safe_error_str.py index 3984156..c9efc4b 100644 --- a/tests/test_safe_error_str.py +++ b/tests/test_safe_error_str.py @@ -13,8 +13,6 @@ from __future__ import annotations -import pytest - from nullrun.breaker.exceptions import ( LoopDetectedException, NullRunBlockedException, diff --git a/tests/test_toolbox_langgraph.py b/tests/test_toolbox_langgraph.py index 86c5800..745ac0e 100644 --- a/tests/test_toolbox_langgraph.py +++ b/tests/test_toolbox_langgraph.py @@ -6,6 +6,8 @@ without requiring an actual LangChain/LangGraph runtime — we just need a duck-typed object with `.invoke` and `.stream`. 
""" +from typing import Any + import pytest from nullrun.instrumentation.langgraph import NullRunCallback @@ -29,53 +31,72 @@ def stream(self, input, config=None, **kwargs): yield {"callbacks": (config or {}).get("callbacks", [])} -def test_wrapper_returns_app(): +class _StubRuntime: + """A no-network stand-in for NullRunRuntime that the wrapper + can hand the callback without going through `get_runtime()`. + + `wrapper()` only needs an object that the `NullRunCallback` + constructor accepts (it just stashes it as `self.runtime`). + Real test isolation is in `test_langgraph_callback.py` / + `test_protect.py`. + """ + + def __init__(self) -> None: + self.track_calls: list[dict] = [] + + +@pytest.fixture +def stub_runtime() -> _StubRuntime: + return _StubRuntime() + + +def test_wrapper_returns_app(stub_runtime: _StubRuntime) -> None: """wrapper() must return the same app object (mutated in place).""" app = _FakeApp() - out = wrapper(app) + out = wrapper(app, runtime=stub_runtime) assert out is app -def test_wrapper_attaches_callback_to_invoke(): +def test_wrapper_attaches_callback_to_invoke(stub_runtime: _StubRuntime) -> None: """invoke() must have a NullRunCallback appended to config['callbacks'].""" app = _FakeApp() - wrapper(app) + wrapper(app, runtime=stub_runtime) app.invoke({"x": 1}) callbacks = app.invocations[0]["config"]["callbacks"] assert any(isinstance(c, NullRunCallback) for c in callbacks) -def test_wrapper_attaches_callback_to_stream(): +def test_wrapper_attaches_callback_to_stream(stub_runtime: _StubRuntime) -> None: """stream() must also get a NullRunCallback in config['callbacks'].""" app = _FakeApp() - wrapper(app) + wrapper(app, runtime=stub_runtime) list(app.stream({"x": 1})) callbacks = app.stream_calls[0]["config"]["callbacks"] assert any(isinstance(c, NullRunCallback) for c in callbacks) -def test_wrapper_preserves_user_callbacks(): +def test_wrapper_preserves_user_callbacks(stub_runtime: _StubRuntime) -> None: """If the caller already 
supplied callbacks, wrapper appends to them.""" app = _FakeApp() - wrapper(app) - user_cb = object() + wrapper(app, runtime=stub_runtime) + user_cb: Any = object() app.invoke({"x": 1}, config={"callbacks": [user_cb]}) callbacks = app.invocations[0]["config"]["callbacks"] assert user_cb in callbacks assert any(isinstance(c, NullRunCallback) for c in callbacks) -def test_wrapper_handles_no_config_arg(): +def test_wrapper_handles_no_config_arg(stub_runtime: _StubRuntime) -> None: """invoke(input) without a config kwarg must still get a callbacks list.""" app = _FakeApp() - wrapper(app) + wrapper(app, runtime=stub_runtime) app.invoke({"x": 1}) config = app.invocations[0]["config"] assert config is not None assert "callbacks" in config -def test_old_instrument_path_is_removed(): +def test_old_instrument_path_is_removed() -> None: """`nullrun.instrumentation.langgraph.instrument` no longer exists.""" import nullrun.instrumentation.langgraph as mod assert not hasattr(mod, "instrument"), ( diff --git a/tests/test_track_span_context.py b/tests/test_track_span_context.py index ce09c2b..9ddd0ea 100644 --- a/tests/test_track_span_context.py +++ b/tests/test_track_span_context.py @@ -12,7 +12,6 @@ loose contextvars (or synthesises new ones). """ from types import SimpleNamespace -from typing import List import pytest @@ -23,7 +22,6 @@ set_span, ) - # ────────────────────────────────────────────────────────────── # Capture events from the runtime # ────────────────────────────────────────────────────────────── @@ -39,7 +37,7 @@ def capturing_runtime(make_runtime, mock_api): captured and re-invoked so the runtime's own bookkeeping works. 
""" rt = make_runtime() - events: List[dict] = [] + events: list[dict] = [] original_track = rt.track @@ -228,9 +226,8 @@ def test_module_level_track_llm_output_tokens_optional(mock_api): stale singleton from a previous test (or a fresh one built from env defaults) targets the prod URL and respx raises AllMockedAssertionError.""" - from tests.conftest import BASE_URL - import nullrun + from tests.conftest import BASE_URL nullrun.init(api_key="test-key-12345678", api_url=BASE_URL) nullrun.track_llm(input_tokens=42) # smoke test — no exception @@ -244,10 +241,9 @@ def test_protect_then_track_llm_attaches_to_protect_span(capturing_runtime, monk """The integration story: @protect opens a span, a track_llm inside it inherits that span — no manual plumbing needed.""" import nullrun + import nullrun.decorators as dec from nullrun import runtime as runtime_mod from nullrun.decorators import reset as reset_decorator_runtime - - import nullrun.decorators as dec # Wire both: the @protect emit path (uses dec._runtime) AND the # module-level nullrun.track_llm path (uses runtime_mod.get_runtime). dec._runtime = capturing_runtime.runtime diff --git a/tests/test_transport.py b/tests/test_transport.py index c145c1e..f9c92a3 100644 --- a/tests/test_transport.py +++ b/tests/test_transport.py @@ -1,7 +1,6 @@ """ tests/test_transport.py — transport, circuit breaker, flush, retry coverage """ -import asyncio import threading import time @@ -11,7 +10,7 @@ from nullrun.breaker.circuit_breaker import CBState, CircuitBreaker from nullrun.breaker.exceptions import BreakerTransportError -from nullrun.transport import AsyncTransport, Transport +from nullrun.transport import Transport @pytest.fixture @@ -89,7 +88,12 @@ def test_send_batch_http_error_raises(self, transport): @respx.mock def test_execute_fallback_strict_blocks_on_gateway_error(self, transport): - """STRICT fallback mode blocks when Gateway unavailable.""" + """STRICT fallback mode blocks when Gateway unavailable. 
+ + Uses `on_transport_error="legacy"` so the historical + `fallback_mode` behavior takes over and a dict is returned + (the new ADR-008 default is "raise"). + """ respx.post("https://api.test.nullrun.io/api/v1/gate").mock( return_value=httpx.Response(500, text="Server Error") ) @@ -100,13 +104,17 @@ def test_execute_fallback_strict_blocks_on_gateway_error(self, transport): tool="my.tool", input_data={}, fallback_mode="strict", + on_transport_error="legacy", ) assert result["decision"] == "block" assert result["decision_source"] == "fallback" @respx.mock def test_execute_fallback_permissive_allows_on_gateway_error(self, transport): - """PERMISSIVE fallback mode allows when Gateway unavailable.""" + """PERMISSIVE fallback mode allows when Gateway unavailable. + + Uses `on_transport_error="legacy"` for the historical shape. + """ respx.post("https://api.test.nullrun.io/api/v1/gate").mock( return_value=httpx.Response(500, text="Server Error") ) @@ -117,13 +125,18 @@ def test_execute_fallback_permissive_allows_on_gateway_error(self, transport): tool="my.tool", input_data={}, fallback_mode="permissive", + on_transport_error="legacy", ) assert result["decision"] == "allow" assert result["decision_source"] == "fallback" @respx.mock def test_execute_fallback_cached_uses_cache(self, transport): - """CACHED fallback mode uses cached decision when available.""" + """CACHED fallback mode uses cached decision when available. + + Uses `on_transport_error="legacy"` so the historical + `fallback_mode` behavior takes over. 
+ """ # Pre-populate the cache cache_key = transport._policy_cache.make_key("ws-123") transport._policy_cache.set(cache_key, "block", "policy-cached-123") @@ -139,6 +152,7 @@ def test_execute_fallback_cached_uses_cache(self, transport): tool="my.tool", input_data={}, fallback_mode="cached", + on_transport_error="legacy", ) assert result["decision"] == "block" assert result["decision_source"] == "cached" @@ -146,7 +160,10 @@ def test_execute_fallback_cached_uses_cache(self, transport): @respx.mock def test_execute_fallback_cached_no_cache_allows(self, transport): - """CACHED fallback allows when no cache available and Gateway unavailable.""" + """CACHED fallback allows when no cache available and Gateway unavailable. + + Uses `on_transport_error="legacy"`. + """ respx.post("https://api.test.nullrun.io/api/v1/gate").mock( return_value=httpx.Response(500, text="Server Error") ) @@ -157,6 +174,7 @@ def test_execute_fallback_cached_no_cache_allows(self, transport): tool="my.tool", input_data={}, fallback_mode="cached", + on_transport_error="legacy", ) assert result["decision"] == "allow" assert result["decision_source"] == "fallback" @@ -190,18 +208,25 @@ def test_execute_success_caches_decision(self, transport): @respx.mock def test_check_endpoint_returns_block_on_error(self, transport): - """Check endpoint returns block decision on error.""" - respx.post("https://api.test.nullrun.io/api/v1/check").mock( + """Check endpoint returns block decision on error. + + Uses `on_transport_error="legacy"` for the historical + fail-CLOSED block return. The new ADR-008 default is "raise". 
+ """ + respx.post("https://api.test.nullrun.io/api/v1/gate").mock( return_value=httpx.Response(500, text="Server Error") ) - result = transport.check({ - "workspace_id": "ws-123", - "execution_id": "exec-456", - "operation_id": "op-789", - "check_type": "llm", - "model": "claude-3", - "estimated_tokens": 100, - }) + result = transport.check( + { + "workspace_id": "ws-123", + "execution_id": "exec-456", + "operation_id": "op-789", + "check_type": "llm", + "model": "claude-3", + "estimated_tokens": 100, + }, + on_transport_error="legacy", + ) assert result["decision"] == "block" @respx.mock @@ -362,37 +387,11 @@ def handler(request): t.stop() -class TestAsyncTransport: - - @pytest.mark.asyncio - @respx.mock - async def test_async_send_batch_success(self): - respx.post("https://api.test.nullrun.io/api/v1/track/batch").mock( - return_value=httpx.Response(200, json={}) - ) - t = AsyncTransport(api_url="https://api.test.nullrun.io", api_key="test-key") - t._client = httpx.AsyncClient() - # Add events directly to buffer - async with t._lock: - t._buffer.append({"event": "async_test"}) - await t._flush_locked() - await t.stop() - - @pytest.mark.asyncio - @respx.mock - async def test_async_includes_api_version_header(self): - route = respx.post("https://api.test.nullrun.io/api/v1/track/batch").mock( - return_value=httpx.Response(200, json={}) - ) - t = AsyncTransport(api_url="https://api.test.nullrun.io", api_key="test-key") - t._client = httpx.AsyncClient() - # Add events directly to buffer - async with t._lock: - t._buffer.append({"event": "test"}) - await t._flush_locked() - request = route.calls.last.request - assert "X-API-Version" in request.headers - await t.stop() +# TestAsyncTransport and TestAsyncTransportFlush classes were removed +# in 0.4.0 when AsyncTransport was deleted (Phase 0, Epic 0.2). 
The +# sync Transport supports being driven from an async event loop via +# `nullrun.track_llm` and `@nullrun.protect` (see +# tests/test_async_from_event_loop.py for the new contract). class TestBoundedDict: @@ -484,7 +483,7 @@ def test_buffer_overflow_drops_oldest(self): t.track({"event": f"e{i}"}) # Flush with CB OPEN will re-queue and enforce max_buffer_size - initial_buffer_len = len(t._buffer) + len(t._buffer) t._do_flush() # After flush with CB OPEN, buffer should be capped at max_buffer_size @@ -517,255 +516,10 @@ def test_transport_stopped_flag(self, transport): assert transport._stopped -class TestAsyncTransportFlush: - - @pytest.mark.asyncio - @respx.mock - async def test_async_flush_error_requeues(self): - """When async flush fails, batch is re-queued.""" - t = AsyncTransport(api_url="https://api.test.nullrun.io", api_key="test-key") - t._client = httpx.AsyncClient() - - # Mock a failing endpoint - respx.post("https://api.test.nullrun.io/api/v1/track/batch").mock( - return_value=httpx.Response(500, text="Server Error") - ) - - # Add events to buffer - async with t._lock: - t._buffer.append({"event": "test1"}) - t._buffer.append({"event": "test2"}) - - initial_buffer_len = len(t._buffer) - await t._flush_locked() - - # Buffer should have events re-queued after failure - # (may be empty if all re-queued or have some remaining) - # The key is it shouldn't silently drop without metric update - assert len(t._buffer) >= 0 # Re-queue happened - await t.stop() - - @pytest.mark.asyncio - @respx.mock - async def test_async_flush_circuit_breaker_open(self): - """When CB opens in async transport, batch is re-queued.""" - t = AsyncTransport(api_url="https://api.test.nullrun.io", api_key="test-key") - t._client = httpx.AsyncClient() - - # Open the circuit breaker - cb = t._circuit_breaker - for _ in range(cb._failure_threshold): - try: - await cb.call(lambda: (_ for _ in ()).throw(RuntimeError("boom"))) - except RuntimeError: - pass - - # Add events - async with 
t._lock: - t._buffer.append({"event": "test1"}) - - await t._flush_locked() - # Buffer still has event since CB is open - assert len(t._buffer) >= 1 - await t.stop() - - @pytest.mark.asyncio - @respx.mock - async def test_async_track_increments_metrics(self): - """Async track increments events_enqueued metric.""" - from nullrun.observability import metrics - - metrics.reset() - t = AsyncTransport(api_url="https://api.test.nullrun.io", api_key="test-key") - await t.start() - - # Mock successful batch - respx.post("https://api.test.nullrun.io/api/v1/track/batch").mock( - return_value=httpx.Response(200, json={}) - ) - - await t.track({"event": "test1"}) - await t.track({"event": "test2"}) - - # events_enqueued should be incremented - assert metrics.transport.events_enqueued >= 2 - await t.stop() - - @pytest.mark.asyncio - @respx.mock - async def test_async_flush_success_updates_metrics(self): - """Successful async flush updates batches_sent and events_sent metrics.""" - from nullrun.observability import metrics - - metrics.reset() - route = respx.post("https://api.test.nullrun.io/api/v1/track/batch").mock( - return_value=httpx.Response(200, json={"accepted_event_ids": ["e1", "e2"]}) - ) - t = AsyncTransport(api_url="https://api.test.nullrun.io", api_key="test-key") - t._client = httpx.AsyncClient() - - async with t._lock: - t._buffer.append({"event_id": "e1", "event": "test1"}) - t._buffer.append({"event_id": "e2", "event": "test2"}) - - await t._flush_locked() - - assert metrics.transport.batches_sent >= 1 - assert metrics.transport.events_sent >= 2 - assert metrics.transport.last_flush_at is not None - await t.stop() - - @pytest.mark.asyncio - @respx.mock - async def test_async_flush_circuit_breaker_open_increments_metrics(self): - """Circuit breaker opening increments circuit_breaker_opens metric in async.""" - from nullrun.observability import metrics - from nullrun.breaker.circuit_breaker import CBState - - metrics.reset() - t = 
AsyncTransport(api_url="https://api.test.nullrun.io", api_key="test-key") - await t.start() - t._client = httpx.AsyncClient() - - # Open the circuit breaker via failures - cb = t._circuit_breaker - for _ in range(cb._failure_threshold): - try: - await cb.call(lambda: (_ for _ in ()).throw(RuntimeError("boom"))) - except RuntimeError: - pass - - assert cb.state == CBState.OPEN - assert metrics.transport.circuit_open_count >= 1 - await t.stop() - - @pytest.mark.asyncio - @respx.mock - async def test_async_buffer_overflow_drops_oldest(self): - """Async transport drops oldest events when buffer exceeds max_buffer_size.""" - from nullrun.observability import metrics - from nullrun.transport import FlushConfig - - metrics.reset() - config = FlushConfig(max_buffer_size=5, batch_size=100, max_failed_flush=3) - t = AsyncTransport( - api_url="https://api.test.nullrun.io", - api_key="test-key", - config=config, - ) - t._client = httpx.AsyncClient() - - # First, open the circuit breaker so re-queue path is triggered - cb = t._circuit_breaker - for _ in range(cb._failure_threshold): - try: - await cb.call(lambda: (_ for _ in ()).throw(RuntimeError("boom"))) - except RuntimeError: - pass - - # Add events beyond max_buffer_size - for i in range(10): - async with t._lock: - t._buffer.append({"event_id": f"e{i}", "event": f"test{i}"}) - - await t._flush_locked() - - # After flush with CB OPEN, buffer should be capped at max_buffer_size - assert len(t._buffer) <= config.max_buffer_size - # Events should have been dropped due to overflow - assert metrics.transport.events_dropped >= 5 - await t.stop() - - @pytest.mark.asyncio - @respx.mock - async def test_async_flush_circuit_breaker_open_reequeue_full_batch(self): - """When CB opens, full batch is re-queued and preserved for retry.""" - from nullrun.breaker.circuit_breaker import CBState - - t = AsyncTransport(api_url="https://api.test.nullrun.io", api_key="test-key") - t._client = httpx.AsyncClient() - - # Open the circuit breaker - 
cb = t._circuit_breaker - for _ in range(cb._failure_threshold): - try: - await cb.call(lambda: (_ for _ in ()).throw(RuntimeError("boom"))) - except RuntimeError: - pass - - assert cb.state == CBState.OPEN - - # Add multiple events to buffer - async with t._lock: - t._buffer.append({"event_id": "e1", "event": "test1"}) - t._buffer.append({"event_id": "e2", "event": "test2"}) - t._buffer.append({"event_id": "e3", "event": "test3"}) - - batch_size = len(t._buffer) - await t._flush_locked() - - # All events should be back in buffer since CB is OPEN - assert len(t._buffer) == batch_size - # Events should be in same order (appended to end) - event_ids = [e["event_id"] for e in t._buffer] - assert "e1" in event_ids - assert "e2" in event_ids - assert "e3" in event_ids - await t.stop() - - @pytest.mark.asyncio - @respx.mock - async def test_async_flush_with_hmac_headers(self): - """Async flush includes HMAC signature headers when secret_key is set.""" - route = respx.post("https://api.test.nullrun.io/api/v1/track/batch").mock( - return_value=httpx.Response(200, json={}) - ) - t = AsyncTransport( - api_url="https://api.test.nullrun.io", - api_key="test-key", - secret_key="secret-123", - ) - t._client = httpx.AsyncClient() - - async with t._lock: - t._buffer.append({"event": "test"}) - - await t._flush_locked() - - request = route.calls.last.request - assert "X-Signature-Timestamp" in request.headers - assert "X-Signature" in request.headers - assert len(request.headers["X-Signature"]) == 64 # SHA256 hex - await t.stop() - - @pytest.mark.asyncio - @respx.mock - async def test_async_track_batch_size_triggers_flush(self): - """Async track triggers flush when batch_size is reached.""" - from nullrun.transport import FlushConfig - - route = respx.post("https://api.test.nullrun.io/api/v1/track/batch").mock( - return_value=httpx.Response(200, json={}) - ) - config = FlushConfig(batch_size=3, flush_interval=60.0) - t = AsyncTransport( - api_url="https://api.test.nullrun.io", - 
api_key="test-key", - config=config, - ) - await t.start() - - await t.track({"event": "e1"}) - await t.track({"event": "e2"}) - - # Not yet flushed (only 2 of 3) - assert not route.called - - await t.track({"event": "e3"}) - - # Should have triggered flush - assert route.called - await t.stop() +# TestAsyncTransportFlush class was removed in 0.4.0 (see comment above +# TestAsyncTransport). The async-flush path is now exercised via +# tests/test_async_from_event_loop.py which uses the sync Transport +# from an async event loop. # ────────────────────────────────────────────────────────────── @@ -793,8 +547,9 @@ def test_cache_miss_returns_none(self): def test_cache_expiry(self): """PolicyCache evicts expired entries.""" - from nullrun.transport import PolicyCache import time + + from nullrun.transport import PolicyCache cache = PolicyCache(maxsize=100, ttl_seconds=0.1) # 100ms TTL cache.set("key1", "allow", "policy-123") # Not expired yet @@ -896,6 +651,7 @@ class TestTransportHMAC: def test_generate_hmac_signature(self): """HMAC signature generation works.""" import time + from nullrun.transport import generate_hmac_signature sig = generate_hmac_signature( api_key="test-key", @@ -909,6 +665,7 @@ def test_generate_hmac_signature(self): def test_verify_hmac_signature_valid(self): """HMAC verification succeeds with valid signature.""" import time + from nullrun.transport import generate_hmac_signature, verify_hmac_signature api_key = "test-key" secret_key = "secret-123" @@ -921,6 +678,7 @@ def test_verify_hmac_signature_valid(self): def test_verify_hmac_signature_invalid(self): """HMAC verification fails with invalid signature.""" import time + from nullrun.transport import verify_hmac_signature result = verify_hmac_signature( api_key="test-key", @@ -933,8 +691,9 @@ def test_verify_hmac_signature_invalid(self): def test_verify_hmac_signature_expired(self): """HMAC verification fails with expired timestamp.""" - from nullrun.transport import generate_hmac_signature, 
verify_hmac_signature import time + + from nullrun.transport import generate_hmac_signature, verify_hmac_signature api_key = "test-key" secret_key = "secret-123" body = '{"event": "test"}'