diff --git a/.gitignore b/.gitignore index 5bd9b1b..b6d3e81 100644 --- a/.gitignore +++ b/.gitignore @@ -67,3 +67,4 @@ CLAUDE.md # Project-local working notes (kept on disk, not in VCS) analyze.md +docs/integration-baseline-2026-06-19.md diff --git a/CHANGELOG.md b/CHANGELOG.md index 975fc29..ab5402f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -348,81 +348,6 @@ preservation cases). ### Legacy -- **P0-1 (PCI-DSS / GDPR): positional PII masking.** Sensitive tools - called positionally (e.g. ``charge("4111-1111-1111-1111", 50)``) now - mask positional args the same way kwargs already do, by introspecting - the function signature with ``inspect.signature(fn)`` and applying - ``SENSITIVE_ARG_KEYS`` to the matching parameter name. Pre-fix the - PAN at position 0 was forwarded as-is into ``/execute`` and landed - in the audit log. -- **P0-3 (OOM): streaming response memory cap.** Sync and async - httpx transports now use bounded chunked reads capped at - ``MAX_RESPONSE_BYTES`` (16 MiB by default; ``NULLRUN_MAX_RESPONSE_BYTES`` - env var to override). When the cap is exceeded, tracking is skipped - and ``_coverage_streaming_skipped`` is incremented so the dashboard - sees which hosts are producing oversized responses. Pre-fix - ``response.read()`` / ``await response.aread()`` buffered the entire - response body in memory — a 16+ MB allocation per streaming LLM - call under load. -- **P0-4 (cost-audit): drop-newest on buffer overflow.** The CB-OPEN - re-queue path in ``Transport._do_flush_locked`` now drops the - NEWEST non-critical events instead of the oldest. The oldest - events (start-of-incident, start-of-billing-period) are exactly - what a billing investigator needs to reconstruct — losing them - silently broke monthly rollups. Control-plane events - (``state_change`` / ``kill_received`` / ``policy_invalidated`` / - ``key_rotated``) are preserved regardless of position so the - dashboard's KILL switch continues to land even under sustained - backend outage. -- **P0-6 + P3-3 (security): redact-before-truncate.** ``_safe_repr`` - now runs ``_strip_details_balanced`` on the FULL repr before - truncating to ``max_len=50``. Pre-fix the truncate ran first, and - if ``details={...}`` lived past position 50 in the original repr - (common for httpx.HTTPError with a long URL), the redact pass - saw nothing on the truncated slice and the raw payload leaked - into ``span_end`` audit events. -- **S-8 / P2-4: ``agent_id`` is now a real UUID with dashes.** - ``agent()`` context manager emits ``str(uuid.uuid4())`` (e.g. - ``95ca7c0b-8334-478a-af23-2788803ef3b8``) for auto-generated ids. - Pre-fix the format was ``f"agent-{uuid.uuid4().hex}"`` — 32 hex - chars with no dashes; backend UUID-typed columns silently - dropped these to NULL on insert. User-supplied names are still - preserved verbatim. -- **S-9: LRU cap on ``NullRunCallback._active_runs``** (4096 entries, - FIFO eviction with WARN log). Pre-fix this dict grew unbounded - when ``on_chain_end`` did not fire (errors in the chain body - short-circuited the end hook for some LangChain versions), - leaking memory in long-running services. -- **S-10: WebSocket reconnect max-attempts cap** (10 consecutive - failures). Pre-fix the loop was unbounded (``while not - self._closed:``) and leaked the WS thread forever when the - backend was permanently down. After the cap the SDK falls back - to HTTP-poll for control-plane state delivery. -- **P2-1: ``_coverage_seen`` now bumps in the httpx path.** - Pre-fix the counter was only incremented in the ``requests`` - path (``auto_requests.py:185``), so the dashboard's coverage - view was empty for the dominant httpx traffic (every OpenAI / - Anthropic / Gemini / Mistral / Cohere call). Now both sync and - async httpx ``_emit`` bump the counter. -- **P3-2: webhook delivery uses exponential backoff** (cap 30s). - Pre-fix the schedule was linear (``0.5 * (attempt + 1)``); under - sustained outage this produced a tight retry storm on the dead - endpoint — each KILL/PAUSE spawned its own delivery thread. - Post-fix the schedule is ``0.5 * 2**attempt`` capped at 30s: - 0.5s, 1.0s, 2.0s, 4.0s, 8.0s, 16.0s, 30.0s. - -### Tests - -Added regression tests for every item above (57 new tests across 9 -new test files: ``test_agent_id_uuid.py``, ``test_args_pii_masked.py``, -``test_streaming_oom_cap.py``, ``test_lru_active_runs.py``, -``test_reconnect_cap.py``, ``test_coverage_seen_httpx.py``, -``test_webhook_backoff.py``, ``test_redact.py``; existing -``test_buffer_invariants.py`` extended with drop-newest + critical-event -preservation cases). - -### Legacy - - **SDK silent runtime fallback removed** (FIX-4): `_get_or_create_runtime` in `nullrun.decorators` no longer wraps `NullRunRuntime.get_instance()` in a `try/except Exception` that rebuilds a no-arg `NullRunRuntime()`. diff --git a/pyproject.toml b/pyproject.toml index 4b74fdc..fa1cd79 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -192,12 +192,21 @@ pythonpath = ["."] [tool.coverage.run] source = ["src/nullrun"] omit = ["tests/*"] +# Branch coverage: every if/else, try/except, ternary contributes two +# branches instead of one. Disabled by default in the stdlib config; +# the SDK has too many error / fallback paths to leave these invisible. +branch = true [tool.coverage.report] -fail_under = 70 +fail_under = 82 show_missing = true +# Branch coverage makes the report noisier; precision=2 keeps the +# numbers readable. skip_empty drops files with no statements. +precision = 2 +skip_empty = true exclude_lines = [ "pragma: no cover", "if TYPE_CHECKING:", "raise NotImplementedError", + "if __name__ == .__main__.:", ] \ No newline at end of file diff --git a/tests/test_actions_context_init.py b/tests/test_actions_context_init.py new file mode 100644 index 0000000..d364334 --- /dev/null +++ b/tests/test_actions_context_init.py @@ -0,0 +1,519 @@ +""" +Branch-coverage tests for ``nullrun.actions``, ``nullrun.context``, +``nullrun.__init__``, and the WorkflowKilledException deprecation +warning. Together these close the last 1-2 % lines that no other +test file exercises. +""" +from __future__ import annotations + +import threading +import time +import warnings +from unittest.mock import MagicMock + +import pytest + +import nullrun +from nullrun.actions import ( + ActionEvent, + ActionHandler, + ActionType, + WebhookConfig, + handle_action, + register_action_handler, +) +from nullrun.breaker.exceptions import ( + NullRunBlockedException, + WorkflowKilledException, + WorkflowKilledInterrupt, +) + + +# ─── ActionHandler ────────────────────────────────────────────────── + + +def test_register_handler_replaces_default(): + h = ActionHandler() + sentinel = MagicMock() + h.register_handler(ActionType.KILL, sentinel) + assert h._handlers[ActionType.KILL] is sentinel + + +def test_register_webhook_adds_to_list(): + h = ActionHandler() + cfg = WebhookConfig(url="https://example.com/hook") + h.register_webhook(cfg) + assert cfg in h._webhooks + + +def test_remove_webhook_removes_by_url(): + h = ActionHandler() + h.register_webhook(WebhookConfig(url="https://a")) + h.register_webhook(WebhookConfig(url="https://b")) + h.remove_webhook("https://a") + urls = [w.url for w in h._webhooks] + assert urls == ["https://b"] + + +def test_remove_webhook_unknown_url_no_op(): + h = ActionHandler() + h.remove_webhook("https://never-added") # must not raise + + +def test_get_action_history_returns_slice(): + h = ActionHandler() + for _ in range(5): + h._record_action(ActionType.KILL, "wf", "x", {}) + recent = h.get_action_history(limit=3) + assert len(recent) == 3 + + +def test_clear_history_empties_list(): + h = ActionHandler() + h._record_action(ActionType.KILL, "wf", "x", {}) + h.clear_history() + assert h._action_history == [] + + +def test_handle_unknown_action_does_not_invoke_handler(): + """Sprint 1.5 (B14): unknown action logs ERROR + records BLOCK but + does NOT invoke any handler (fail-open). Pre-fix this degraded + to BLOCK → DoS amplifier. + """ + h = ActionHandler() + handler_mock = MagicMock() + h.register_handler(ActionType.BLOCK, handler_mock) + # ``"weird"`` is not in ActionType — should fail-open. + h.handle("weird", "wf-1", reason="x") + handler_mock.assert_not_called() + + +def test_handle_unknown_action_records_block_event(caplog): + """Unknown action records a BLOCK event for forensic visibility.""" + import logging + + h = ActionHandler() + with caplog.at_level(logging.ERROR, logger="nullrun.actions"): + h.handle("unknown_action_type", "wf-1", reason="x") + history = h.get_action_history() + assert any(e.action_type == "block" for e in history) + + +def test_handle_known_action_invokes_handler(): + h = ActionHandler() + handler_mock = MagicMock() + h.register_handler(ActionType.KILL, handler_mock) + h.handle("kill", "wf-1", reason="budget") + handler_mock.assert_called_once() + + +def test_handle_action_lowercases_input(): + """``handle("KILL", ...)`` matches ActionType.KILL after .lower().""" + h = ActionHandler() + handler_mock = MagicMock() + h.register_handler(ActionType.KILL, handler_mock) + h.handle("KILL", "wf-1", reason="x") + handler_mock.assert_called_once() + + +def test_handle_kill_does_not_propagate_killed_interrupt(): + """``WorkflowKilledInterrupt`` from the handler is SWALLOWED by the + dispatch loop (BaseException caught and logged). The kill signal + has already been recorded in history by the time the dispatch + wraps the handler call — re-raising would lose the audit entry. + """ + h = ActionHandler() + h.handle("kill", "wf-1", reason="x") # no raise + # History still has the kill event. + history = h.get_action_history() + assert any(e.action_type == "kill" for e in history) + + +def test_handle_pause_records_workflow_in_paused_dict(): + """PAUSE handler raises WorkflowPausedException but it is swallowed; + the workflow_id is recorded in ``_paused_workflows`` first.""" + h = ActionHandler() + h.handle("pause", "wf-1", reason="x") + assert "wf-1" in h._paused_workflows + + +def test_handle_block_does_not_propagate_blocked_exception(): + """BLOCK handler raises NullRunBlockedException but it is swallowed.""" + h = ActionHandler() + h.handle("block", "wf-1", reason="x") # no raise + history = h.get_action_history() + assert any(e.action_type == "block" for e in history) + + +def test_handle_handler_exception_swallowed(): + """A buggy custom handler must not crash the dispatch.""" + h = ActionHandler() + boom = MagicMock(side_effect=RuntimeError("oops")) + h.register_handler(ActionType.ALERT, boom) + h.handle("alert", "wf-1", reason="x") # must not raise + + +def test_handle_records_event_with_reason(): + h = ActionHandler() + h.handle("alert", "wf-1", reason="manual escalation") + events = h.get_action_history() + assert len(events) == 1 + assert events[0].reason == "manual escalation" + + +def test_handle_records_event_with_default_reason(): + """``reason=None`` defaults to ``"Unknown"`` for the history record.""" + h = ActionHandler() + h.handle("alert", "wf-1", reason=None) + events = h.get_action_history() + assert events[0].reason == "Unknown" + + +def test_action_history_trimmed_at_max(): + """History longer than ``_max_history`` is trimmed from the front.""" + h = ActionHandler() + h._max_history = 3 + for i in range(5): + h._record_action(ActionType.ALERT, f"wf-{i}", "x", {}) + assert len(h._action_history) == 3 + # Trimmed from the front — the oldest two (``wf-0``, ``wf-1``) are gone. + wf_ids = [e.workflow_id for e in h._action_history] + assert wf_ids == ["wf-2", "wf-3", "wf-4"] + + +def test_action_event_details_default_empty_dict(): + """``ActionEvent.details`` defaults to ``{}`` when not provided.""" + ev = ActionEvent( + timestamp="2026-01-01T00:00:00Z", + action_type="kill", + workflow_id="wf-1", + reason="x", + ) + assert ev.details == {} + + +# ─── is_paused ─────────────────────────────────────────────────────── + + +def test_is_paused_unknown_workflow_returns_false(): + h = ActionHandler() + assert h.is_paused("wf-never-paused") is False + + +def test_is_paused_within_cooldown_returns_true(): + h = ActionHandler() + h._paused_workflows["wf-1"] = time.time() + assert h.is_paused("wf-1", cooldown_seconds=60.0) is True + + +def test_is_paused_past_cooldown_returns_false_and_clears(): + h = ActionHandler() + h._paused_workflows["wf-1"] = time.time() - 100 # 100s ago + assert h.is_paused("wf-1", cooldown_seconds=60.0) is False + # Past-cooldown entry is removed so the next call is also False. + assert "wf-1" not in h._paused_workflows + + +# ─── webhook async delivery ────────────────────────────────────────── + + +def test_queue_webhook_starts_delivery_thread(): + h = ActionHandler() + h.register_webhook(WebhookConfig(url="https://example.com/h")) + h._queue_webhook(ActionType.KILL, "wf-1", "x", {}) + # A delivery thread is started and registered. + assert h._webhook_running is True + assert h._webhook_thread is not None + # Let the thread exit so the test does not hang. + h.stop_webhooks() + + +def test_queue_webhook_overflow_drops_oldest(caplog): + """Webhook queue overflow → oldest dropped (FIFO) + WARNING logged.""" + import logging + + h = ActionHandler() + h._webhook_max_size = 2 + with caplog.at_level(logging.WARNING, logger="nullrun.actions"): + for i in range(4): + h._queue_webhook(ActionType.KILL, f"wf-{i}", "x", {}) + assert len(h._webhook_queue) == 2 + # Newest two kept. + assert h._webhook_queue[-1]["workflow_id"] == "wf-3" + h.stop_webhooks() + + +def test_deliver_webhook_no_httpx_warns(caplog): + """If httpx is unavailable, webhook delivery logs and returns.""" + import logging + + import nullrun.actions as act_mod + + h = ActionHandler() + h.register_webhook(WebhookConfig(url="https://example.com/h")) + # Force the no-httpx branch. + original = act_mod._HAS_HTTPX + act_mod._HAS_HTTPX = False + try: + with caplog.at_level(logging.WARNING, logger="nullrun.actions"): + h._deliver_webhook(h._webhooks[0], {"x": 1}) + assert any("httpx not installed" in r.getMessage() for r in caplog.records) + finally: + act_mod._HAS_HTTPX = original + + +def test_deliver_webhook_success_returns_immediately(monkeypatch): + """A 200 response on the first attempt stops the loop.""" + h = ActionHandler() + h.register_webhook(WebhookConfig(url="https://example.com/h")) + fake_resp = MagicMock() + fake_resp.raise_for_status = MagicMock() + monkeypatch.setattr("nullrun.actions.httpx.post", MagicMock(return_value=fake_resp)) + h._deliver_webhook(h._webhooks[0], {"x": 1}) # no raise + + +def test_deliver_webhook_retries_then_gives_up(monkeypatch): + """All retries exhausted — loop ends without raising.""" + h = ActionHandler() + h.register_webhook(WebhookConfig(url="https://example.com/h", retries=2)) + fake_post = MagicMock(side_effect=RuntimeError("down")) + monkeypatch.setattr("nullrun.actions.httpx.post", fake_post) + # time.sleep is patched to avoid the actual delay. + monkeypatch.setattr("time.sleep", MagicMock()) + h._deliver_webhook(h._webhooks[0], {"x": 1}) # no raise + assert fake_post.call_count == 2 + + +def test_stop_webhooks_joins_thread(): + h = ActionHandler() + h.register_webhook(WebhookConfig(url="https://example.com/h")) + h._queue_webhook(ActionType.KILL, "wf-1", "x", {}) + assert h._webhook_thread is not None + h.stop_webhooks() + assert h._webhook_running is False + + +# ─── Module-level helpers ───────────────────────────────────────────── + + +def test_handle_action_module_helper_dispatches(monkeypatch): + """``handle_action(...)`` delegates to the global ``ActionHandler``.""" + from nullrun import actions as act_mod + + act_mod._action_handler = None # force fresh + h = MagicMock() + monkeypatch.setattr("nullrun.actions.get_action_handler", lambda: h) + handle_action("kill", "wf-1", reason="x") + h.handle.assert_called_once_with("kill", "wf-1", "x") + + +def test_register_action_handler_module_helper(monkeypatch): + from nullrun import actions as act_mod + + h = MagicMock() + monkeypatch.setattr("nullrun.actions.get_action_handler", lambda: h) + fn = MagicMock() + register_action_handler(ActionType.KILL, fn) + h.register_handler.assert_called_once_with(ActionType.KILL, fn) + + +def test_get_action_handler_returns_singleton(): + from nullrun import actions as act_mod + + act_mod._action_handler = None # reset + h1 = act_mod.get_action_handler() + h2 = act_mod.get_action_handler() + assert h1 is h2 + + +# ─── nullrun.context ────────────────────────────────────────────────── + + +def test_generate_trace_id_is_uuid_format(): + from nullrun.context import generate_span_id, generate_trace_id + + tid = generate_trace_id() + assert tid.count("-") == 4 # canonical UUID4 + + +def test_generate_span_id_is_uuid_format(): + from nullrun.context import generate_span_id + + sid = generate_span_id() + assert sid.count("-") == 4 + + +def test_attempt_context_manager_pushes_and_restores(): + from nullrun.context import attempt, get_attempt_index, set_attempt_index + + set_attempt_index(0) + with attempt(3) as idx: + assert idx == 3 + assert get_attempt_index() == 3 + assert get_attempt_index() == 0 + + +def test_attempt_context_manager_nested(): + from nullrun.context import attempt, get_attempt_index + + with attempt(1): + with attempt(5): + assert get_attempt_index() == 5 + assert get_attempt_index() == 1 + + +def test_workflow_context_manager_sets_ids(): + from nullrun.context import get_span_id, get_trace_id, get_workflow_id, workflow + + with workflow("my-flow") as wid: + assert wid == "my-flow" + assert get_workflow_id() == "my-flow" + assert get_trace_id() is not None + assert get_span_id() is not None + assert get_workflow_id() is None + + +def test_workflow_default_name_is_uuid(): + import uuid + + from nullrun.context import get_workflow_id, workflow + + with workflow() as wid: + # 36-char UUID with dashes. + uuid.UUID(wid) + assert get_workflow_id() == wid + + +def test_span_context_manager_restores_on_exit(): + from nullrun.context import get_span_id, span + + with span("outer") as sid: + assert get_span_id() == "outer" + assert get_span_id() is None + + +def test_span_default_name_is_uuid(): + import uuid + + from nullrun.context import get_span_id, span + + with span() as sid: + uuid.UUID(sid) + assert get_span_id() == sid + + +def test_agent_context_manager_sets_agent_id(): + from nullrun.context import agent, get_agent_id + + with agent("agent-1") as aid: + assert aid == "agent-1" + assert get_agent_id() == "agent-1" + assert get_agent_id() is None + + +def test_set_attempt_index_writes_to_contextvar(): + from nullrun.context import get_attempt_index, set_attempt_index + + set_attempt_index(42) + assert get_attempt_index() == 42 + set_attempt_index(0) # cleanup + + +def test_workflow_nested_restores_outer_on_exit(): + from nullrun.context import get_workflow_id, workflow + + with workflow("outer"): + assert get_workflow_id() == "outer" + with workflow("inner"): + assert get_workflow_id() == "inner" + assert get_workflow_id() == "outer" + assert get_workflow_id() is None + + +def test_span_id_in_workflow_resets_to_new_value(): + """§7.2 #16: ``with workflow(...)`` resets ``span_id``, not only + workflow_id / trace_id, so the audit log can correctly nest the + workflow's own span_start under the workflow_id. + """ + from nullrun.context import get_span_id, span, workflow + + with span("outer-span"): + original = get_span_id() + with workflow("wf-x"): + # span_id must have changed (new UUID), not still "outer-span". + new = get_span_id() + assert new != original + assert new is not None + + +# ─── nullrun.__init__ ──────────────────────────────────────────────── + + +def test_init_unknown_attr_raises_attribute_error(): + """``nullrun.something_unknown`` raises AttributeError, not ImportError.""" + with pytest.raises(AttributeError): + nullrun.no_such_attribute # noqa: B018 + + +def test_init_lazy_export_loads_attribute(): + """First access to a lazy export caches it on the module.""" + rt = nullrun.NullRunRuntime + # Subsequent access is the cached object. + assert nullrun.NullRunRuntime is rt + + +def test_dir_lists_only_curated_surface(): + """``dir(nullrun)`` shows only the 6 curated names + __version__.""" + public = dir(nullrun) + # The 6 curated names are explicitly listed. + for name in ("init", "protect", "track_llm", "track_tool", "track_event"): + assert name in public + # Lazy exports are NOT in dir() until first access. + assert "SpanContext" not in public + assert "NullRunRuntime" not in public + + +def test_init_module_has_all_attribute(): + """The ``__all__`` attribute lists the curated surface.""" + assert "init" in nullrun.__all__ + assert "protect" in nullrun.__all__ + + +# ─── WorkflowKilledException deprecation warning ───────────────────── + + +def test_workflow_killed_exception_emits_deprecation_warning(): + """Constructing the deprecated ``WorkflowKilledException`` triggers + a ``DeprecationWarning``. + """ + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + WorkflowKilledException(workflow_id="wf-1", reason="x") + assert any(issubclass(item.category, DeprecationWarning) for item in w) + + +def test_workflow_killed_interrupt_does_not_emit_warning(): + """Constructing the canonical ``WorkflowKilledInterrupt`` does NOT + emit a deprecation warning (the deprecation is on the parent name). + """ + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + WorkflowKilledInterrupt(workflow_id="wf-1", reason="x") + assert not any(issubclass(item.category, DeprecationWarning) for item in w) + + +def test_workflow_killed_interrupt_is_base_exception(): + """``except Exception`` does NOT catch the kill signal.""" + with pytest.raises(WorkflowKilledInterrupt): + try: + raise WorkflowKilledInterrupt(workflow_id="wf-1", reason="x") + except Exception: + pytest.fail("Exception should not catch WorkflowKilledInterrupt") + + +def test_workflow_killed_exception_is_caught_by_except_killed_exception(): + """Legacy ``except WorkflowKilledException`` still catches the new + interrupt (back-compat contract). + """ + with pytest.raises(WorkflowKilledException): + raise WorkflowKilledInterrupt(workflow_id="wf-1", reason="x") \ No newline at end of file diff --git a/tests/test_auto_requests.py b/tests/test_auto_requests.py new file mode 100644 index 0000000..df49ffe --- /dev/null +++ b/tests/test_auto_requests.py @@ -0,0 +1,435 @@ +""" +Regression tests for the ``requests`` auto-instrumentation patch. + +Installs a synthetic ``requests.Session`` into ``sys.modules`` so the +patcher can wrap ``Session.send`` end-to-end without requiring the +real ``requests`` package in CI. +""" +from __future__ import annotations + +import importlib +import sys +from types import ModuleType, SimpleNamespace +from unittest.mock import MagicMock + +import pytest + + +def _install_fake_requests(monkeypatch, *, streaming: bool = False, status: int = 200) -> dict: + """Install a fake ``requests`` module exposing a real ``Session`` + class. The ``Session.send`` we wrap returns a fake response whose + body bytes the test controls. + + Returns a recorder dict. + """ + recorder = {"track": [], "track_event": []} + + class _FakeResponse: + def __init__(self, body: bytes, status_code: int): + self.content = body + self.status_code = status_code + self.headers = {"Content-Type": "application/json"} + + class _FakeSession: + send_count = 0 + _nullrun_patched = False + + @staticmethod + def send(self_or_cls, request, **kwargs): + _FakeSession.send_count += 1 + return _FakeResponse(b'{"usage":{"prompt_tokens":7,"completion_tokens":11,"total_tokens":18},"model":"gpt-4o"}', status) + + # Track which attrs were set on the class for restore-in-place + # assertions. + + fake_mod = ModuleType("requests") + fake_mod.Session = _FakeSession + monkeypatch.setitem(sys.modules, "requests", fake_mod) + return recorder + + +def _fake_runtime(recorder: dict) -> MagicMock: + rt = MagicMock() + rt.track.side_effect = lambda ev: recorder["track"].append(ev) + rt.track_event.side_effect = lambda **kw: recorder["track_event"].append(kw) + return rt + + +@pytest.fixture +def fresh_patch_module(): + if "nullrun.instrumentation.auto_requests" in sys.modules: + importlib.reload(sys.modules["nullrun.instrumentation.auto_requests"]) + else: + importlib.import_module("nullrun.instrumentation.auto_requests") + yield + if "nullrun.instrumentation.auto_requests" in sys.modules: + importlib.reload(sys.modules["nullrun.instrumentation.auto_requests"]) + + +# ─── ImportError / module-missing branches ─────────────────────────── + + +def test_patch_requests_returns_false_when_missing(monkeypatch, fresh_patch_module): + """``requests`` not importable → patch returns False.""" + monkeypatch.setitem(sys.modules, "requests", None) + from nullrun.instrumentation.auto_requests import patch_requests + + assert patch_requests(MagicMock()) is False + + +def test_patch_requests_idempotent(monkeypatch, fresh_patch_module): + """Calling patch_requests twice does not double-wrap Session.send.""" + _install_fake_requests(monkeypatch) + from nullrun.instrumentation.auto_requests import patch_requests + from requests import Session + + assert patch_requests(MagicMock()) is True + wrapped = Session.send + assert patch_requests(MagicMock()) is True + assert Session.send is wrapped + + +def test_patch_requests_skips_when_class_marker_present(monkeypatch, fresh_patch_module): + _install_fake_requests(monkeypatch) + from nullrun.instrumentation.auto_requests import patch_requests + from requests import Session + + Session._nullrun_patched = True + try: + assert patch_requests(MagicMock()) is True + finally: + Session._nullrun_patched = False + + +# ─── Happy path ────────────────────────────────────────────────────── + + +def test_session_send_emits_llm_call_for_openai(monkeypatch, fresh_patch_module): + """When Session.send returns an OpenAI-shaped body, the wrapper + emits a single llm_call event with split prompt/completion/total. + """ + _install_fake_requests(monkeypatch) + recorder = {"track": [], "track_event": []} + rt = _fake_runtime(recorder) + + from nullrun.instrumentation.auto_requests import patch_requests + from requests import Session + + assert patch_requests(rt) is True + + # Build a fake PreparedRequest-like object. + req = SimpleNamespace(url="https://api.openai.com/v1/chat/completions", headers={}, _nullrun_tracked=False) + Session().send(req) + + assert len(recorder["track"]) == 1 + ev = recorder["track"][0] + assert ev["type"] == "llm_call" + assert ev["provider"] == "openai" + assert ev["host"] == "api.openai.com" + assert ev["input_tokens"] == 7 + assert ev["output_tokens"] == 11 + assert ev["tokens"] == 18 + + +def test_session_send_marks_request_as_tracked(monkeypatch, fresh_patch_module): + """After a successful extract, the PreparedRequest is marked + ``_nullrun_tracked=True`` for downstream dedup. + """ + _install_fake_requests(monkeypatch) + rt = _fake_runtime({}) + + from nullrun.instrumentation.auto_requests import patch_requests + from requests import Session + + assert patch_requests(rt) is True + req = SimpleNamespace(url="https://api.openai.com/v1/chat/completions", headers={}) + Session().send(req) + assert getattr(req, "_nullrun_tracked", False) is True + + +def test_session_send_unknown_host_no_track(monkeypatch, fresh_patch_module): + """Host is not a known LLM endpoint — wrapper skips emit.""" + _install_fake_requests(monkeypatch) + recorder = {"track": [], "track_event": []} + rt = _fake_runtime(recorder) + + from nullrun.instrumentation.auto_requests import patch_requests + from requests import Session + + assert patch_requests(rt) is True + req = SimpleNamespace(url="https://example.com/api", headers={}) + Session().send(req) + assert recorder["track"] == [] + + +def test_session_send_already_tracked_returns_unchanged(monkeypatch, fresh_patch_module): + """When ``_nullrun_tracked`` is already set, wrapper delegates + to the original Session.send without re-emitting. + """ + _install_fake_requests(monkeypatch) + recorder = {"track": [], "track_event": []} + rt = _fake_runtime(recorder) + + from nullrun.instrumentation.auto_requests import patch_requests + from requests import Session + + assert patch_requests(rt) is True + req = SimpleNamespace(url="https://api.openai.com/v1/chat/completions", headers={}, _nullrun_tracked=True) + Session().send(req) + assert recorder["track"] == [] + + +def test_session_send_streaming_skips_track(monkeypatch, fresh_patch_module): + """``stream=True`` kwarg triggers the streaming skip branch.""" + _install_fake_requests(monkeypatch, streaming=True) + recorder = {"track": [], "track_event": []} + rt = MagicMock() + rt.track.side_effect = lambda ev: recorder["track"].append(ev) + rt.track_event.side_effect = lambda **kw: recorder["track_event"].append(kw) + # Pretend the runtime has a coverage counters dict so we can + # observe the streaming-skipped bump. + rt._coverage_streaming_skipped = {} + rt._bump_coverage_counter = MagicMock() + + from nullrun.instrumentation.auto_requests import patch_requests + from requests import Session + + assert patch_requests(rt) is True + req = SimpleNamespace(url="https://api.openai.com/v1/chat/completions", headers={}) + Session().send(req, stream=True) + # Track was NOT called (streaming skip). + assert recorder["track"] == [] + # Streaming-skipped counter was bumped. + assert rt._bump_coverage_counter.called + + +def test_session_send_accept_event_stream_header_skips_track(monkeypatch, fresh_patch_module): + """``Accept: text/event-stream`` header triggers the streaming skip branch.""" + _install_fake_requests(monkeypatch) + recorder = {"track": [], "track_event": []} + rt = _fake_runtime(recorder) + + from nullrun.instrumentation.auto_requests import patch_requests + from requests import Session + + assert patch_requests(rt) is True + req = SimpleNamespace(url="https://api.openai.com/v1/chat/completions", headers={"Accept": "text/event-stream"}) + Session().send(req) + assert recorder["track"] == [] + + +def test_session_send_no_extractor_for_host_returns_response(monkeypatch, fresh_patch_module): + """Unknown extractor → no emit, original response returned to caller.""" + _install_fake_requests(monkeypatch) + recorder = {"track": [], "track_event": []} + rt = _fake_runtime(recorder) + + from nullrun.instrumentation.auto_requests import patch_requests + from requests import Session + + assert patch_requests(rt) is True + req = SimpleNamespace(url="https://unknown.host.example/api", headers={}) + resp = Session().send(req) + # Response object passed through. + assert resp.status_code == 200 + assert recorder["track"] == [] + + +def test_session_send_status_400_no_track(monkeypatch, fresh_patch_module): + """Even a known host with 4xx body returns no extraction.""" + _install_fake_requests(monkeypatch, status=400) + recorder = {"track": [], "track_event": []} + rt = _fake_runtime(recorder) + + from nullrun.instrumentation.auto_requests import patch_requests + from requests import Session + + assert patch_requests(rt) is True + req = SimpleNamespace(url="https://api.openai.com/v1/chat/completions", headers={}) + Session().send(req) + assert recorder["track"] == [] + + +def test_session_send_empty_body_no_track(monkeypatch, fresh_patch_module): + """Empty body → no extraction (return early).""" + monkeypatch.setitem(sys.modules, "requests", None) # placeholder + + # Build a session whose send returns an empty body. + class _FakeResponse: + status_code = 200 + content = b"" + headers = {} + + class _FakeSession: + _nullrun_patched = False + send_count = 0 + + @staticmethod + def send(self_or_cls, request, **kwargs): + _FakeSession.send_count += 1 + return _FakeResponse() + + fake_mod = ModuleType("requests") + fake_mod.Session = _FakeSession + monkeypatch.setitem(sys.modules, "requests", fake_mod) + + recorder = {"track": [], "track_event": []} + rt = _fake_runtime(recorder) + + from nullrun.instrumentation.auto_requests import patch_requests + from requests import Session + + assert patch_requests(rt) is True + req = SimpleNamespace(url="https://api.openai.com/v1/chat/completions", headers={}) + Session().send(req) + assert recorder["track"] == [] + + +def test_session_send_track_failure_is_swallowed(monkeypatch, fresh_patch_module): + """If runtime.track raises, the wrapper returns the original response.""" + _install_fake_requests(monkeypatch) + rt = MagicMock() + rt.track.side_effect = RuntimeError("down") + rt.track_event.side_effect = lambda **kw: None + + from nullrun.instrumentation.auto_requests import patch_requests + from requests import Session + + assert patch_requests(rt) is True + req = SimpleNamespace(url="https://api.openai.com/v1/chat/completions", headers={}) + resp = Session().send(req) + assert resp.status_code == 200 + + +def test_session_send_seen_counter_bumped(monkeypatch, fresh_patch_module): + """Every host bumps the ``_coverage_seen`` counter, including + unknown ones (so the dashboard shows visibility into all + outbound traffic, not just tracked vendors). The bump happens + via ``_safe_bump_coverage`` which mutates the dict directly. + """ + _install_fake_requests(monkeypatch) + rt = MagicMock() + rt.track.side_effect = lambda ev: None + rt.track_event.side_effect = lambda **kw: None + rt._coverage_seen = {} + rt._bump_coverage_counter = MagicMock() + + from nullrun.instrumentation.auto_requests import patch_requests + from requests import Session + + assert patch_requests(rt) is True + req = SimpleNamespace(url="https://example.com/api", headers={}) + Session().send(req) + # Direct dict mutation: host is now present with count 1. + assert rt._coverage_seen.get("example.com") == 1 + + +# ─── reset_for_tests ───────────────────────────────────────────────── + + +def test_reset_for_tests_restores_session(monkeypatch, fresh_patch_module): + _install_fake_requests(monkeypatch) + from nullrun.instrumentation.auto_requests import patch_requests, reset_for_tests + from requests import Session + + original_send = Session.send + assert patch_requests(MagicMock()) is True + assert Session.send is not original_send + + reset_for_tests() + assert Session.send is original_send + assert Session._nullrun_patched is False + + +def test_reset_for_tests_when_session_unavailable_is_silent(monkeypatch, fresh_patch_module): + """If ``requests`` was uninstalled between patch and reset, the + reset path must not raise. + """ + _install_fake_requests(monkeypatch) + from nullrun.instrumentation.auto_requests import patch_requests, reset_for_tests + + assert patch_requests(MagicMock()) is True + monkeypatch.delitem(sys.modules, "requests", raising=False) + reset_for_tests() # must not raise + + +# ─── Internal helpers ──────────────────────────────────────────────── + + +def test_is_streaming_request_with_stream_true(): + """``stream=True`` kwarg → True.""" + from nullrun.instrumentation.auto_requests import _is_streaming_request + + req = SimpleNamespace(headers={}) + assert _is_streaming_request(req, {"stream": True}) is True + + +def test_is_streaming_request_with_event_stream_header(): + """``Accept: text/event-stream`` → True.""" + from nullrun.instrumentation.auto_requests import _is_streaming_request + + req = SimpleNamespace(headers={"Accept": "text/event-stream"}) + assert _is_streaming_request(req, {}) is True + + +def test_is_streaming_request_without_any_indicator(): + """Plain request → False.""" + from nullrun.instrumentation.auto_requests import _is_streaming_request + + req = SimpleNamespace(headers={"Accept": "application/json"}) + assert _is_streaming_request(req, {}) is False + + +def test_is_streaming_request_no_headers(): + """No headers at all → False.""" + from nullrun.instrumentation.auto_requests import _is_streaming_request + + req = SimpleNamespace(headers=None) + assert _is_streaming_request(req, {}) is False + + +def test_is_streaming_request_headers_get_raises(): + """Header lookup that raises → False (defensive).""" + from nullrun.instrumentation.auto_requests import _is_streaming_request + + class _BadHeaders: + def get(self, *_args, **_kwargs): + raise RuntimeError("bad") + + req = SimpleNamespace(headers=_BadHeaders()) + assert _is_streaming_request(req, {}) is False + + +def test_bump_streaming_skipped_no_attr(): + """Runtime missing the attribute → silent no-op.""" + from nullrun.instrumentation.auto_requests import _bump_streaming_skipped + + # MagicMock auto-creates attributes, so build a plain object. + class _Runtime: + pass + + rt = _Runtime() # no _coverage_streaming_skipped + _bump_streaming_skipped(rt, "x") # must not raise + + +def test_bump_streaming_skipped_no_bump_method(): + """Runtime missing the bump method → silent no-op.""" + from nullrun.instrumentation.auto_requests import _bump_streaming_skipped + + class _Runtime: + _coverage_streaming_skipped = {} + + rt = _Runtime() # no _bump_coverage_counter + _bump_streaming_skipped(rt, "x") # must not raise + + +def test_bump_streaming_skipped_calls_bump(): + """Happy path: bump is invoked with the target dict and host.""" + from nullrun.instrumentation.auto_requests import _bump_streaming_skipped + + target: dict = {} + rt = MagicMock() + rt._coverage_streaming_skipped = target + rt._bump_coverage_counter = MagicMock() + _bump_streaming_skipped(rt, "api.openai.com") + rt._bump_coverage_counter.assert_called_once_with(target, "api.openai.com") \ No newline at end of file diff --git a/tests/test_autogen_patch.py b/tests/test_autogen_patch.py new file mode 100644 index 0000000..505ef49 --- /dev/null +++ b/tests/test_autogen_patch.py @@ -0,0 +1,358 @@ +""" +Regression tests for the autogen auto-instrumentation patch. + +These tests inject synthetic stand-ins for `autogen_agentchat.agents` +and `autogen_ext.models.openai` via ``sys.modules`` so the patch can +exercise the real wrapper code paths without requiring the (heavy) +optional dependency in CI. + +The pattern mirrors ``tests/test_blocker_fixes.py``: monkeypatch +the vendor module, reload our patch module, then drive the wrapped +class through ``MagicMock``-backed call sites. +""" +from __future__ import annotations + +import importlib +import sys +from types import ModuleType, SimpleNamespace +from unittest.mock import MagicMock + +import pytest + + +def _install_fake_autogen(monkeypatch, *, with_ext: bool = True) -> dict: + """Install fake ``autogen_agentchat`` (+ optional ``autogen_ext``) + modules into ``sys.modules`` and return the call recorder dict. + + The recorder tracks every ``runtime.track_event`` / + ``runtime.track`` invocation so the tests can assert on the + shape of the emitted events without depending on a real + NullRunRuntime. + """ + recorder = {"track_event": [], "track": []} + + # Build BaseChatAgent stand-in: a class whose ``on_messages`` is + # replaceable per test. ``_nullrun_patched`` is consulted by the + # patcher as the idempotency marker. + class _FakeBaseChatAgent: + _nullrun_patched = False + + def on_messages(self, messages, cancellation_token=None): + return SimpleNamespace(content="ok") + + fake_agents_mod = ModuleType("autogen_agentchat.agents") + fake_agents_mod.BaseChatAgent = _FakeBaseChatAgent + monkeypatch.setitem(sys.modules, "autogen_agentchat", ModuleType("autogen_agentchat")) + monkeypatch.setitem(sys.modules, "autogen_agentchat.agents", fake_agents_mod) + + if with_ext: + class _Usage: + prompt_tokens = 12 + completion_tokens = 34 + total_tokens = 46 + + class _Result: + usage = _Usage() + + class _FakeOpenAIChatCompletionClient: + _nullrun_patched = False + model = "gpt-4o-mini" + + @staticmethod + def create(self, *args, **kwargs): + return _Result() + + fake_ext_mod = ModuleType("autogen_ext.models.openai") + fake_ext_mod.OpenAIChatCompletionClient = _FakeOpenAIChatCompletionClient + monkeypatch.setitem(sys.modules, "autogen_ext", ModuleType("autogen_ext")) + monkeypatch.setitem(sys.modules, "autogen_ext.models", ModuleType("autogen_ext.models")) + monkeypatch.setitem(sys.modules, "autogen_ext.models.openai", fake_ext_mod) + else: + # Install the parent package so the inner ``from + # autogen_ext.models.openai import OpenAIChatCompletionClient`` + # raises ImportError cleanly. + monkeypatch.setitem(sys.modules, "autogen_ext", ModuleType("autogen_ext")) + + return recorder + + +def _fake_runtime(recorder: dict) -> MagicMock: + """Build a MagicMock that mimics the runtime surface the patch + consults. ``track_event`` / ``track`` capture into ``recorder``. + """ + + rt = MagicMock() + rt.track_event.side_effect = lambda **kw: recorder["track_event"].append(kw) + rt.track.side_effect = lambda ev: recorder["track"].append(ev) + return rt + + +def _reload_patch_module(): + """Reload ``nullrun.instrumentation.autogen`` so its top-level + ``_autogen_patched`` / ``_orig_on_messages`` globals reset between + tests. Without the reload the idempotency marker would carry + across tests and silently skip the wrap step. + """ + if "nullrun.instrumentation.autogen" in sys.modules: + importlib.reload(sys.modules["nullrun.instrumentation.autogen"]) + else: + importlib.import_module("nullrun.instrumentation.autogen") + + +@pytest.fixture +def fresh_patch_module(): + """Reset the patch module's globals before each test. + + The fixture always reloads so the previous test's installed wrap + does not leak into the next one. + """ + _reload_patch_module() + yield + _reload_patch_module() + + +# ─── ImportError branch ───────────────────────────────────────────── + + +def test_patch_autogen_returns_false_when_missing(monkeypatch, fresh_patch_module): + """When ``autogen_agentchat`` is not importable, patch returns False + without raising — the user sees no instrumentation but no crash. + """ + # Force ImportError on the inner ``from autogen_agentchat.agents import``. + monkeypatch.setitem(sys.modules, "autogen_agentchat", None) + monkeypatch.setitem(sys.modules, "autogen_agentchat.agents", None) + + from nullrun.instrumentation.autogen import patch_autogen + + assert patch_autogen(MagicMock()) is False + + +def test_patch_autogen_without_ext_module(monkeypatch, fresh_patch_module): + """``autogen_ext`` missing is a graceful skip on the usage-capture + branch — the span wrapper still installs. + """ + _install_fake_autogen(monkeypatch, with_ext=False) + from nullrun.instrumentation.autogen import patch_autogen + + rt = MagicMock() + assert patch_autogen(rt) is True + + +# ─── Idempotency ───────────────────────────────────────────────────── + + +def test_patch_autogen_idempotent(monkeypatch, fresh_patch_module): + """Calling ``patch_autogen`` twice does not double-wrap.""" + _install_fake_autogen(monkeypatch) + from nullrun.instrumentation.autogen import patch_autogen + from autogen_agentchat.agents import BaseChatAgent + + first_orig = BaseChatAgent.on_messages + assert patch_autogen(MagicMock()) is True + second_orig = BaseChatAgent.on_messages + assert patch_autogen(MagicMock()) is True + # Second call must NOT have re-stashed the original. + assert second_orig is second_orig + + +def test_patch_autogen_skips_when_class_already_patched(monkeypatch, fresh_patch_module): + """If the class marker is already True (e.g. a parallel test + process installed it), the patch returns True without rewriting. + """ + _install_fake_autogen(monkeypatch) + from nullrun.instrumentation.autogen import patch_autogen + from autogen_agentchat.agents import BaseChatAgent + + BaseChatAgent._nullrun_patched = True + try: + assert patch_autogen(MagicMock()) is True + finally: + BaseChatAgent._nullrun_patched = False + + +# ─── on_messages wrapper ───────────────────────────────────────────── + + +def test_on_messages_success_emits_span_start_and_end(monkeypatch, fresh_patch_module): + """Happy path: wrapped ``on_messages`` emits span_start before + calling the original and span_end after. + """ + _install_fake_autogen(monkeypatch) + recorder = {"track_event": [], "track": []} + rt = _fake_runtime(recorder) + + from nullrun.instrumentation.autogen import patch_autogen + from autogen_agentchat.agents import BaseChatAgent + + assert patch_autogen(rt) is True + result = BaseChatAgent.on_messages(None, ["hello"]) + assert result.content == "ok" + + # span_start (with fn_name + span_kind) then span_end (no kwargs). + kinds = [ev.get("event_type") for ev in recorder["track_event"]] + assert kinds == ["span_start", "span_end"] + # ``getattr(self, "name", "agent") or "agent"`` — fake class has no + # ``.name`` so the default kicks in. + assert recorder["track_event"][0]["fn_name"] == "agent" + assert recorder["track_event"][0]["span_kind"] == "agent" + + +def test_on_messages_exception_emits_span_end_with_error(monkeypatch, fresh_patch_module): + """When the wrapped body raises, the wrapper still emits + span_end with ``error=str(e)`` and re-raises the original. + """ + _install_fake_autogen(monkeypatch) + + from autogen_agentchat.agents import BaseChatAgent + + # Replace the original on_messages with one that raises. + BaseChatAgent.on_messages = MagicMock(side_effect=RuntimeError("boom")) + recorder = {"track_event": [], "track": []} + rt = _fake_runtime(recorder) + + from nullrun.instrumentation.autogen import patch_autogen + assert patch_autogen(rt) is True + + with pytest.raises(RuntimeError, match="boom"): + BaseChatAgent.on_messages(None, ["x"]) + + # span_start + span_end(error=...) + spans = recorder["track_event"] + assert [s["event_type"] for s in spans] == ["span_start", "span_end"] + assert spans[1].get("error") == "boom" + + +def test_on_messages_track_event_failure_is_swallowed(monkeypatch, fresh_patch_module): + """If the runtime's ``track_event`` raises on span_start, the + wrapper must NOT crash — observability is downstream of the + user's work (mirrors the contract in ``_emit_span_start``). + """ + _install_fake_autogen(monkeypatch) + + rt = MagicMock() + rt.track_event.side_effect = [RuntimeError("down"), None] + from nullrun.instrumentation.autogen import patch_autogen + from autogen_agentchat.agents import BaseChatAgent + + assert patch_autogen(rt) is True + # Should NOT raise even though track_event errored. + assert BaseChatAgent.on_messages(None, []).content == "ok" + + +# ─── OpenAIChatCompletionClient.create wrapper ─────────────────────── + + +def test_openai_create_with_usage_emits_llm_call(monkeypatch, fresh_patch_module): + """When the wrapped CreateResult has ``usage`` with non-zero + tokens, the wrapper emits an llm_call event with prompt/ + completion/total split. + """ + _install_fake_autogen(monkeypatch, with_ext=True) + recorder = {"track_event": [], "track": []} + rt = _fake_runtime(recorder) + + from nullrun.instrumentation.autogen import patch_autogen + from autogen_ext.models.openai import OpenAIChatCompletionClient + + assert patch_autogen(rt) is True + + # The wrapper reads ``getattr(self, "model", None)`` — needs an + # instance with a ``.model`` attribute, not a class-level one. + class _Inst: + model = "gpt-4o-mini" + + inst = _Inst() + result = OpenAIChatCompletionClient.create(inst) + # Wrapper returns the original result unchanged. + assert result.usage.prompt_tokens == 12 + + events = recorder["track"] + assert len(events) == 1 + ev = events[0] + assert ev["type"] == "llm_call" + assert ev["provider"] == "autogen" + assert ev["model"] == "gpt-4o-mini" + assert ev["input_tokens"] == 12 + assert ev["output_tokens"] == 34 + assert ev["tokens"] == 46 + + +def test_openai_create_without_usage_no_track(monkeypatch, fresh_patch_module): + """No ``usage`` on the CreateResult — wrapper skips emit.""" + _install_fake_autogen(monkeypatch, with_ext=True) + + from autogen_ext.models.openai import OpenAIChatCompletionClient + + OpenAIChatCompletionClient.create = staticmethod(lambda self, *a, **k: SimpleNamespace()) + recorder = {"track_event": [], "track": []} + rt = _fake_runtime(recorder) + + from nullrun.instrumentation.autogen import patch_autogen + assert patch_autogen(rt) is True + OpenAIChatCompletionClient.create(None) + + assert recorder["track"] == [] + + +def test_openai_create_track_failure_is_swallowed(monkeypatch, fresh_patch_module): + """If ``runtime.track`` raises, the wrapper returns the original + CreateResult and does not propagate the failure. + """ + _install_fake_autogen(monkeypatch, with_ext=True) + rt = MagicMock() + rt.track.side_effect = RuntimeError("down") + rt.track_event.side_effect = lambda **kw: None + + from nullrun.instrumentation.autogen import patch_autogen + from autogen_ext.models.openai import OpenAIChatCompletionClient + + assert patch_autogen(rt) is True + result = OpenAIChatCompletionClient.create(None) + assert result.usage.prompt_tokens == 12 + + +# ─── unpatch ───────────────────────────────────────────────────────── + + +def test_unpatch_restores_original(monkeypatch, fresh_patch_module): + """After ``unpatch_autogen``, the wrapped ``on_messages`` and + ``create`` methods are restored to the originals and the + idempotency markers are cleared. + """ + _install_fake_autogen(monkeypatch, with_ext=True) + from nullrun.instrumentation.autogen import patch_autogen, unpatch_autogen + from autogen_agentchat.agents import BaseChatAgent + from autogen_ext.models.openai import OpenAIChatCompletionClient + + original_on_messages = BaseChatAgent.on_messages + original_create = OpenAIChatCompletionClient.create + + assert patch_autogen(MagicMock()) is True + assert BaseChatAgent.on_messages is not original_on_messages + assert OpenAIChatCompletionClient.create is not original_create + + unpatch_autogen() + assert BaseChatAgent.on_messages is original_on_messages + assert OpenAIChatCompletionClient.create is original_create + assert BaseChatAgent._nullrun_patched is False + assert OpenAIChatCompletionClient._nullrun_patched is False + + +def test_unpatch_when_not_patched_is_noop(monkeypatch, fresh_patch_module): + """``unpatch_autogen`` without a prior patch is a safe no-op.""" + from nullrun.instrumentation.autogen import unpatch_autogen + + unpatch_autogen() # should not raise + + +def test_unpatch_when_module_missing(monkeypatch, fresh_patch_module): + """If the module import disappears between patch and unpatch, + unpatch still resets the local flag instead of crashing. + """ + _install_fake_autogen(monkeypatch) + from nullrun.instrumentation.autogen import patch_autogen, unpatch_autogen + + assert patch_autogen(MagicMock()) is True + # Drop the vendor module to simulate a transient uninstall. + monkeypatch.delitem(sys.modules, "autogen_agentchat.agents", raising=False) + unpatch_autogen() # should not raise \ No newline at end of file diff --git a/tests/test_circuit_breaker_branches.py b/tests/test_circuit_breaker_branches.py new file mode 100644 index 0000000..9b7e2e5 --- /dev/null +++ b/tests/test_circuit_breaker_branches.py @@ -0,0 +1,375 @@ +""" +Additional circuit-breaker branch tests covering the gaps left after +``test_cb_halfopen_publish.py`` and ``test_buffer_invariants.py``. + +Focuses on: + + - ``_call_async`` happy path and exception paths + - ``_maybe_apply_open_jitter_sync`` (no-op when not ready, sleep when ready) + - ``_maybe_apply_open_jitter_async`` + - Redis state branches (``_check_global_state``, ``_publish_open_state``, + ``_publish_half_open_state``, ``_clear_global_state``, + ``_global_state_allows_call``) + - ``get_metrics()`` format + - ``CircuitBreakerMetrics.__init__`` coverage +""" +from __future__ import annotations + +import asyncio +from unittest.mock import MagicMock, patch + +import pytest + +from nullrun.breaker.circuit_breaker import ( + CBState, + CircuitBreaker, + CircuitBreakerMetrics, +) + + +# ─── CircuitBreakerMetrics ─────────────────────────────────────────── + + +def test_metrics_default_initialisation(): + """All counters start at zero.""" + m = CircuitBreakerMetrics() + assert m.circuit_open_count == 0 + assert m.circuit_half_open_count == 0 + assert m.circuit_closed_count == 0 + assert m.total_failure_count == 0 + assert m.total_success_count == 0 + assert m.half_open_duration_sum == 0.0 + assert m.half_open_duration_count == 0 + assert m.fallback_activations == 0 + + +# ─── _maybe_apply_open_jitter_sync ────────────────────────────────── + + +def test_open_jitter_sync_no_op_when_state_closed(): + cb = CircuitBreaker(failure_threshold=1, recovery_timeout=0.0) + # State is CLOSED → no-op (don't even read ``_opened_at``). + with patch("time.sleep") as mock_sleep: + cb._maybe_apply_open_jitter_sync() + mock_sleep.assert_not_called() + + +def test_open_jitter_sync_no_op_when_recovery_not_elapsed(): + cb = CircuitBreaker(failure_threshold=1, recovery_timeout=30.0) + cb._state = CBState.OPEN + cb._opened_at = 0.0 + # State OPEN but recovery_timeout hasn't elapsed → no-op. + with patch("time.monotonic", return_value=1.0): # 1s < 30s + with patch("time.sleep") as mock_sleep: + cb._maybe_apply_open_jitter_sync() + mock_sleep.assert_not_called() + + +def test_open_jitter_sync_sleeps_when_recovery_elapsed(): + """Once recovery_timeout elapsed, sync jitter sleeps up to 5s.""" + cb = CircuitBreaker(failure_threshold=1, recovery_timeout=0.0) + cb._state = CBState.OPEN + cb._opened_at = 0.0 + + with patch("time.sleep") as mock_sleep: + cb._maybe_apply_open_jitter_sync() + mock_sleep.assert_called_once() + # Sleep must be 0 ≤ t ≤ 5.0 (capped per §7.2 #35). + args = mock_sleep.call_args.args + assert 0.0 <= args[0] <= 5.0 + + +@pytest.mark.asyncio +async def test_open_jitter_async_sleeps_when_recovery_elapsed(): + cb = CircuitBreaker(failure_threshold=1, recovery_timeout=0.0) + cb._state = CBState.OPEN + cb._opened_at = 0.0 + + with patch("asyncio.sleep") as mock_sleep: + await cb._maybe_apply_open_jitter_async() + mock_sleep.assert_called_once() + args = mock_sleep.call_args.args + assert 0.0 <= args[0] <= 5.0 + + +@pytest.mark.asyncio +async def test_open_jitter_async_no_op_when_recovery_not_elapsed(): + cb = CircuitBreaker(failure_threshold=1, recovery_timeout=30.0) + cb._state = CBState.OPEN + cb._opened_at = 0.0 + + with patch("time.monotonic", return_value=1.0): + with patch("asyncio.sleep") as mock_sleep: + await cb._maybe_apply_open_jitter_async() + mock_sleep.assert_not_called() + + +# ─── _call_async ──────────────────────────────────────────────────── + + +@pytest.mark.asyncio +async def test_call_async_success(): + cb = CircuitBreaker(failure_threshold=2, recovery_timeout=30.0) + + async def ok(): + return "result" + + result = await cb.call(ok) + assert result == "result" + assert cb.state == CBState.CLOSED + + +@pytest.mark.asyncio +async def test_call_async_failure(): + """Async failure increments failure_count; opens after threshold.""" + cb = CircuitBreaker(failure_threshold=2, recovery_timeout=30.0) + + async def bad(): + raise RuntimeError("nope") + + with pytest.raises(RuntimeError): + await cb.call(bad) + with pytest.raises(RuntimeError): + await cb.call(bad) + # Threshold (2) reached → state transitions to OPEN. + assert cb.state == CBState.OPEN + + +@pytest.mark.asyncio +async def test_call_async_success_in_half_open_closes(): + """After OPEN→HALF_OPEN, a successful async probe closes the CB.""" + cb = CircuitBreaker(failure_threshold=1, recovery_timeout=0.0) + cb._state = CBState.OPEN + cb._opened_at = 0.0 + cb._last_failure_time = 0.0 # recovery timeout check uses _last_failure_time + + async def ok(): + return "fine" + + # Reading ``.state`` triggers OPEN→HALF_OPEN. + assert cb.state == CBState.HALF_OPEN + result = await cb.call(ok) + assert result == "fine" + assert cb.state == CBState.CLOSED + + +# ─── get_metrics ──────────────────────────────────────────────────── + + +def test_get_metrics_format_includes_all_counters(): + cb = CircuitBreaker(failure_threshold=2, recovery_timeout=30.0) + cb._metrics.circuit_open_count = 1 + cb._metrics.circuit_half_open_count = 2 + cb._metrics.circuit_closed_count = 3 + cb.total_failures = 5 + cb.total_opens = 1 + cb.total_successes = 10 + + metrics = cb.get_metrics() + assert metrics["state"] == "closed" + assert metrics["circuit_open_count"] == 1 + assert metrics["circuit_half_open_count"] == 2 + assert metrics["circuit_closed_count"] == 3 + assert metrics["total_failures"] == 5 + assert metrics["total_opens"] == 1 + assert metrics["total_successes"] == 10 + + +def test_get_metrics_avg_half_open_duration_zero_when_no_data(): + cb = CircuitBreaker() + metrics = cb.get_metrics() + assert metrics["avg_half_open_duration"] == 0 + + +def test_get_metrics_avg_half_open_duration_with_data(): + """When half-open has been entered and exited, average is computed.""" + cb = CircuitBreaker() + cb._metrics.half_open_duration_sum = 6.0 + cb._metrics.half_open_duration_count = 3 + metrics = cb.get_metrics() + assert metrics["avg_half_open_duration"] == 2.0 + + +# ─── Redis distributed state ──────────────────────────────────────── + + +def test_check_global_state_no_redis_returns_none(): + cb = CircuitBreaker() + assert cb._check_global_state() is None + + +def test_check_global_state_with_redis_returns_state(): + cb = CircuitBreaker(name="test_cb_r1") + cb._redis_client = MagicMock() + # The SDK reads the value verbatim and compares against string + # literals in ``_global_state_allows_call``; using a str return + # mirrors the production redis client's decode behaviour. + cb._redis_client.get.return_value = "OPEN" + assert cb._check_global_state() == "OPEN" + + +def test_check_global_state_redis_returns_empty_string(): + """Empty string from Redis is treated as no global state.""" + cb = CircuitBreaker(name="test_cb_r2") + cb._redis_client = MagicMock() + cb._redis_client.get.return_value = "" + assert cb._check_global_state() is None + + +def test_check_global_state_redis_error_returns_none(caplog): + """Redis exceptions are logged at WARNING and the breaker falls back + to local state without crashing the user's call.""" + import logging + + cb = CircuitBreaker(name="test_cb_r3") + cb._redis_client = MagicMock() + cb._redis_client.get.side_effect = ConnectionError("redis down") + with caplog.at_level(logging.WARNING, logger="nullrun.breaker.circuit_breaker"): + result = cb._check_global_state() + assert result is None + assert any("Redis state check failed" in r.getMessage() for r in caplog.records) + + +def test_check_global_recovered_returns_true_when_closed_in_redis(): + cb = CircuitBreaker(name="test_cb_r4") + cb._redis_client = MagicMock() + cb._redis_client.get.return_value = "CLOSED" + assert cb._check_global_recovered() is True + + +def test_check_global_recovered_returns_false_when_open_in_redis(): + cb = CircuitBreaker(name="test_cb_r5") + cb._redis_client = MagicMock() + cb._redis_client.get.return_value = "OPEN" + assert cb._check_global_recovered() is False + + +def test_check_global_recovered_no_redis_returns_false(): + cb = CircuitBreaker() + assert cb._check_global_recovered() is False + + +def test_publish_open_state_writes_to_redis(): + cb = CircuitBreaker(name="test_cb_r6") + cb._redis_client = MagicMock() + cb._publish_open_state() + cb._redis_client.setex.assert_called_once() + args = cb._redis_client.setex.call_args.args + assert args[0] == "cb:test_cb_r6:state" + assert args[1] == 60 # _state_ttl + assert args[2] == "OPEN" + + +def test_publish_half_open_state_writes_to_redis(): + cb = CircuitBreaker(name="test_cb_r7") + cb._redis_client = MagicMock() + cb._publish_half_open_state() + cb._redis_client.setex.assert_called_once() + args = cb._redis_client.setex.call_args.args + assert args[2] == "HALF_OPEN" + + +def test_clear_global_state_deletes_redis_key(): + cb = CircuitBreaker(name="test_cb_r8") + cb._redis_client = MagicMock() + cb._clear_global_state() + cb._redis_client.delete.assert_called_once_with("cb:test_cb_r8:state") + + +# ─── _global_state_allows_call ────────────────────────────────────── + + +def test_global_state_allows_call_no_redis_returns_true(): + cb = CircuitBreaker() + assert cb._global_state_allows_call() is True + + +def test_global_state_allows_call_redis_open_returns_false(): + cb = CircuitBreaker(name="test_cb_g1") + cb._redis_client = MagicMock() + cb._redis_client.get.return_value = "OPEN" + assert cb._global_state_allows_call() is False + + +def test_global_state_allows_call_redis_closed_syncs_local(): + """Redis says CLOSED → sync local state to CLOSED, allow.""" + cb = CircuitBreaker(name="test_cb_g2") + cb._redis_client = MagicMock() + cb._redis_client.get.return_value = "CLOSED" + cb._state = CBState.OPEN # local says OPEN + cb._failure_count = 99 + assert cb._global_state_allows_call() is True + assert cb._state == CBState.CLOSED + assert cb._failure_count == 0 + + +def test_global_state_allows_call_redis_half_open_below_cap(): + cb = CircuitBreaker(name="test_cb_g3") + cb._redis_client = MagicMock() + cb._redis_client.get.return_value = "HALF_OPEN" + cb._half_open_calls = 0 + cb._half_open_max_calls = 1 + assert cb._global_state_allows_call() is True + + +def test_global_state_allows_call_redis_half_open_at_cap(): + cb = CircuitBreaker(name="test_cb_g4") + cb._redis_client = MagicMock() + cb._redis_client.get.return_value = "HALF_OPEN" + cb._half_open_calls = 1 + cb._half_open_max_calls = 1 + assert cb._global_state_allows_call() is False + + +# ─── call() routes async coroutines ───────────────────────────────── + + +def test_call_sync_function_via_call_returns_result(): + cb = CircuitBreaker() + + def sync_func(): + return "sync-result" + + result = cb.call(sync_func) + assert result == "sync-result" + + +def test_call_sync_failure_increments_failure_count(): + cb = CircuitBreaker(failure_threshold=5) + + def bad(): + raise ValueError("boom") + + with pytest.raises(ValueError): + cb.call(bad) + assert cb._failure_count == 1 + assert cb.total_failures == 1 + + +def test_call_sync_failure_opens_circuit(): + cb = CircuitBreaker(failure_threshold=2) + + def bad(): + raise ValueError("boom") + + with pytest.raises(ValueError): + cb.call(bad) + with pytest.raises(ValueError): + cb.call(bad) + assert cb.state == CBState.OPEN + + +def test_call_after_open_raises_breaker_transport_error(): + """Once the circuit is OPEN, subsequent calls raise immediately.""" + from nullrun.breaker.exceptions import BreakerTransportError + + cb = CircuitBreaker(failure_threshold=1, recovery_timeout=30.0) + + def bad(): + raise ValueError("boom") + + with pytest.raises(ValueError): + cb.call(bad) + # Now OPEN — next call raises BreakerTransportError before invoking func. + with pytest.raises(BreakerTransportError, match="OPEN"): + cb.call(lambda: "should not run") \ No newline at end of file diff --git a/tests/test_crewai_patch.py b/tests/test_crewai_patch.py new file mode 100644 index 0000000..f8205dd --- /dev/null +++ b/tests/test_crewai_patch.py @@ -0,0 +1,317 @@ +""" +Regression tests for the crewai auto-instrumentation patch. + +Mirrors the autogen tests: inject a fake ``crewai`` module so the +patch can run end-to-end without the (heavy) optional dep. +""" +from __future__ import annotations + +import importlib +import sys +from types import ModuleType, SimpleNamespace +from unittest.mock import MagicMock + +import pytest + + +def _install_fake_crewai(monkeypatch, *, with_async: bool = True) -> dict: + """Install a fake ``crewai`` module exposing ``Crew`` whose + ``kickoff`` / ``kickoff_async`` are MagicMocks. Returns the + recorder dict for runtime emissions. + """ + recorder = {"track": [], "track_event": []} + + class _FakeCrew: + _nullrun_patched = False + usage_metrics: dict = {} + + @staticmethod + def kickoff(self, inputs=None, **kwargs): + return SimpleNamespace(result="ok") + + if with_async: + class _FakeCrewWithAsync(_FakeCrew): + @staticmethod + async def kickoff_async(self, inputs=None, **kwargs): + return SimpleNamespace(result="ok-async") + else: + _FakeCrewWithAsync = _FakeCrew + + fake_mod = ModuleType("crewai") + fake_mod.Crew = _FakeCrewWithAsync + monkeypatch.setitem(sys.modules, "crewai", fake_mod) + + return recorder + + +def _fake_runtime(recorder: dict) -> MagicMock: + rt = MagicMock() + rt.track.side_effect = lambda ev: recorder["track"].append(ev) + rt.track_event.side_effect = lambda **kw: recorder["track_event"].append(kw) + return rt + + +@pytest.fixture +def fresh_patch_module(): + if "nullrun.instrumentation.crewai" in sys.modules: + importlib.reload(sys.modules["nullrun.instrumentation.crewai"]) + else: + importlib.import_module("nullrun.instrumentation.crewai") + yield + if "nullrun.instrumentation.crewai" in sys.modules: + importlib.reload(sys.modules["nullrun.instrumentation.crewai"]) + + +# ─── ImportError / module-missing branches ─────────────────────────── + + +def test_patch_crewai_returns_false_when_missing(monkeypatch, fresh_patch_module): + monkeypatch.setitem(sys.modules, "crewai", None) + from nullrun.instrumentation.crewai import patch_crewai + + assert patch_crewai(MagicMock()) is False + + +def test_patch_crewai_idempotent(monkeypatch, fresh_patch_module): + _install_fake_crewai(monkeypatch) + from nullrun.instrumentation.crewai import patch_crewai + from crewai import Crew + + assert patch_crewai(MagicMock()) is True + wrapped = Crew.kickoff + # Second call must NOT re-wrap. + assert patch_crewai(MagicMock()) is True + assert Crew.kickoff is wrapped + + +def test_patch_crewai_skips_when_class_marker_present(monkeypatch, fresh_patch_module): + _install_fake_crewai(monkeypatch) + from nullrun.instrumentation.crewai import patch_crewai + from crewai import Crew + + Crew._nullrun_patched = True + try: + assert patch_crewai(MagicMock()) is True + finally: + Crew._nullrun_patched = False + + +def test_patch_crewai_without_async_kickoff(monkeypatch, fresh_patch_module): + """Crewai versions without ``kickoff_async`` — patcher still + installs the sync wrap and silently skips the async wrap. + """ + _install_fake_crewai(monkeypatch, with_async=False) + from nullrun.instrumentation.crewai import patch_crewai + from crewai import Crew + + assert patch_crewai(MagicMock()) is True + + +# ─── kickoff wrapper ────────────────────────────────────────────────── + + +def test_kickoff_emits_usage_metrics_per_model(monkeypatch, fresh_patch_module): + """After Crew.kickoff returns, the wrapper reads + ``crew.usage_metrics`` and emits one llm_call per model. + """ + _install_fake_crewai(monkeypatch) + recorder = {"track": [], "track_event": []} + rt = _fake_runtime(recorder) + + from nullrun.instrumentation.crewai import patch_crewai + from crewai import Crew + + assert patch_crewai(rt) is True + + crew = Crew() + crew.usage_metrics = { + "gpt-4o": { + "prompt_tokens": 100, + "completion_tokens": 50, + "total_tokens": 150, + }, + } + result = Crew.kickoff(crew, inputs={"q": "hi"}) + assert result.result == "ok" + + # One llm_call event for gpt-4o. + events = recorder["track"] + assert len(events) == 1 + ev = events[0] + assert ev["type"] == "llm_call" + assert ev["provider"] == "crewai" + assert ev["model"] == "gpt-4o" + assert ev["input_tokens"] == 100 + assert ev["output_tokens"] == 50 + assert ev["tokens"] == 150 + + +def test_kickoff_without_usage_metrics_no_emit(monkeypatch, fresh_patch_module): + """``crew.usage_metrics`` is empty — wrapper skips emit cleanly.""" + _install_fake_crewai(monkeypatch) + recorder = {"track": [], "track_event": []} + rt = _fake_runtime(recorder) + + from nullrun.instrumentation.crewai import patch_crewai + from crewai import Crew + + assert patch_crewai(rt) is True + + crew = Crew() + crew.usage_metrics = {} + Crew.kickoff(crew) + + assert recorder["track"] == [] + + +def test_kickoff_non_dict_usage_metrics(monkeypatch, fresh_patch_module): + """``crew.usage_metrics`` is e.g. an int (weird but possible) — + wrapper must not crash and must not emit.""" + _install_fake_crewai(monkeypatch) + recorder = {"track": [], "track_event": []} + rt = _fake_runtime(recorder) + + from nullrun.instrumentation.crewai import patch_crewai + from crewai import Crew + + assert patch_crewai(rt) is True + + crew = Crew() + crew.usage_metrics = 42 # non-dict + Crew.kickoff(crew) + assert recorder["track"] == [] + + +def test_kickoff_non_dict_metric_value_skipped(monkeypatch, fresh_patch_module): + """A model whose value is e.g. a list — wrapper skips that model.""" + _install_fake_crewai(monkeypatch) + recorder = {"track": [], "track_event": []} + rt = _fake_runtime(recorder) + + from nullrun.instrumentation.crewai import patch_crewai + from crewai import Crew + + assert patch_crewai(rt) is True + + crew = Crew() + crew.usage_metrics = {"gpt-4o": "weird", "claude": {"prompt_tokens": 5, "completion_tokens": 6, "total_tokens": 11}} + Crew.kickoff(crew) + + # Only the well-formed entry emitted. + assert len(recorder["track"]) == 1 + assert recorder["track"][0]["model"] == "claude" + + +def test_kickoff_step_callback_installed_when_missing(monkeypatch, fresh_patch_module): + """When the caller does not pass ``step_callback``, the wrapper + installs one so every step emits a span_start.""" + _install_fake_crewai(monkeypatch) + recorder = {"track": [], "track_event": []} + rt = _fake_runtime(recorder) + + from nullrun.instrumentation.crewai import patch_crewai + from crewai import Crew + + assert patch_crewai(rt) is True + + crew = Crew() + Crew.kickoff(crew, inputs={}) + # The wrapper installed a step_callback under the hood — but the + # underlying kickoff mock didn't actually invoke it. Verify the + # patched call accepts the kwargs without error. + assert recorder["track"] == [] + + +def test_kickoff_preserves_user_step_callback(monkeypatch, fresh_patch_module): + """When the caller already supplies ``step_callback``, the + wrapper must not overwrite it. + """ + _install_fake_crewai(monkeypatch) + rt = _fake_runtime({}) + + from nullrun.instrumentation.crewai import patch_crewai + from crewai import Crew + + sentinel = MagicMock() + assert patch_crewai(rt) is True + crew = Crew() + Crew.kickoff(crew, step_callback=sentinel) + # The user's callback object is passed through unchanged. + # (We don't assert on the wrapper's local replacement here because + # the underlying mock doesn't introspect kwargs — the contract + # is "don't overwrite if present".) + + +# ─── kickoff_async wrapper ──────────────────────────────────────────── + + +@pytest.mark.asyncio +async def test_kickoff_async_emits_usage_metrics(monkeypatch, fresh_patch_module): + _install_fake_crewai(monkeypatch) + recorder = {"track": [], "track_event": []} + rt = _fake_runtime(recorder) + + from nullrun.instrumentation.crewai import patch_crewai + from crewai import Crew + + assert patch_crewai(rt) is True + + crew = Crew() + crew.usage_metrics = { + "gpt-4o-mini": {"prompt_tokens": 7, "completion_tokens": 11, "total_tokens": 18}, + } + result = await Crew.kickoff_async(crew) + assert result.result == "ok-async" + assert len(recorder["track"]) == 1 + assert recorder["track"][0]["tokens"] == 18 + + +# ─── Track failure is swallowed ────────────────────────────────────── + + +def test_kickoff_track_failure_is_swallowed(monkeypatch, fresh_patch_module): + """If runtime.track raises, the wrapped kickoff still returns.""" + _install_fake_crewai(monkeypatch) + rt = MagicMock() + rt.track.side_effect = RuntimeError("down") + rt.track_event.side_effect = lambda **kw: None + + from nullrun.instrumentation.crewai import patch_crewai + from crewai import Crew + + assert patch_crewai(rt) is True + crew = Crew() + crew.usage_metrics = {"m": {"prompt_tokens": 1, "completion_tokens": 2, "total_tokens": 3}} + Crew.kickoff(crew) # does not raise + + +# ─── unpatch ────────────────────────────────────────────────────────── + + +def test_unpatch_restores_original(monkeypatch, fresh_patch_module): + _install_fake_crewai(monkeypatch) + from nullrun.instrumentation.crewai import patch_crewai, unpatch_crewai + from crewai import Crew + + original_kickoff = Crew.kickoff + assert patch_crewai(MagicMock()) is True + assert Crew.kickoff is not original_kickoff + + unpatch_crewai() + assert Crew.kickoff is original_kickoff + assert Crew._nullrun_patched is False + + +def test_unpatch_when_not_patched_is_noop(monkeypatch, fresh_patch_module): + from nullrun.instrumentation.crewai import unpatch_crewai + + unpatch_crewai() # safe no-op + + +def test_unpatch_when_module_missing(monkeypatch, fresh_patch_module): + _install_fake_crewai(monkeypatch) + from nullrun.instrumentation.crewai import patch_crewai, unpatch_crewai + + assert patch_crewai(MagicMock()) is True + monkeypatch.delitem(sys.modules, "crewai", raising=False) + unpatch_crewai() # should not raise \ No newline at end of file diff --git a/tests/test_langgraph_callback.py b/tests/test_langgraph_callback.py new file mode 100644 index 0000000..339efa4 --- /dev/null +++ b/tests/test_langgraph_callback.py @@ -0,0 +1,419 @@ +""" +Regression tests for ``nullrun.instrumentation.langgraph``. + +Covers: + + - ``extract_usage_from_response`` — every branch of the usage-shape + fan-out (dict, object, generations, response_metadata, llm_output, + streaming chunks). + - ``NullRunCallback`` — span emission (start/end) for chains / tools / + agents, nested parent/child via ``parent_run_id``, the + ``_active_runs`` FIFO eviction at 4096 entries, and the LLM-end + track-event with normalised usage. + - ``_extract_node_name`` — every branch (dict / list / str / missing). +""" +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest + +from nullrun.instrumentation.langgraph import ( + NullRunCallback, + _ACTIVE_RUNS_MAX, + _extract_node_name, + extract_usage_from_response, +) + + +# ─── extract_usage_from_response ───────────────────────────────────── + + +def test_extract_usage_metadata_dict_form(): + """OpenAI-via-LangChain style: ``response.usage_metadata`` as a dict.""" + response = SimpleNamespace(usage_metadata={ + "input_tokens": 12, + "output_tokens": 34, + "total_tokens": 46, + }) + usage = extract_usage_from_response(response, provider="openai", model="x") + assert usage["input_tokens"] == 12 + assert usage["output_tokens"] == 34 + assert usage["total_tokens"] == 46 + assert usage["has_usage"] is True + + +def test_extract_usage_metadata_object_form(): + """Object with .input_tokens / .output_tokens / .total_tokens attrs.""" + response = SimpleNamespace(usage_metadata=SimpleNamespace( + input_tokens=7, + output_tokens=11, + total_tokens=18, + )) + usage = extract_usage_from_response(response, provider="openai", model="x") + assert usage["input_tokens"] == 7 + assert usage["output_tokens"] == 11 + assert usage["total_tokens"] == 18 + assert usage["has_usage"] is True + + +def test_extract_usage_from_generations(): + """``response.generations[0][0].message.usage_metadata`` — dict.""" + msg = SimpleNamespace(usage_metadata={"input_tokens": 5, "output_tokens": 6, "total_tokens": 11}) + gen = SimpleNamespace(message=msg) + response = SimpleNamespace(generations=[[gen]]) + usage = extract_usage_from_response(response, provider="openai", model="x") + assert usage["has_usage"] is True + assert usage["input_tokens"] == 5 + + +def test_extract_usage_from_generations_object_form(): + """``response.generations[0][0].message.usage_metadata`` as an object.""" + um = SimpleNamespace(input_tokens=1, output_tokens=2, total_tokens=3) + msg = SimpleNamespace(usage_metadata=um) + gen = SimpleNamespace(message=msg) + response = SimpleNamespace(generations=[[gen]]) + usage = extract_usage_from_response(response, provider="openai", model="x") + assert usage["has_usage"] is True + assert usage["input_tokens"] == 1 + + +def test_extract_usage_from_response_usage_dict(): + """Anthropic / standard OpenAI: ``response.usage`` as a dict.""" + response = SimpleNamespace(usage={"input_tokens": 100, "output_tokens": 200, "total_tokens": 300}) + usage = extract_usage_from_response(response, provider="anthropic", model="x") + assert usage["has_usage"] is True + assert usage["total_tokens"] == 300 + + +def test_extract_usage_from_response_usage_object(): + """``response.usage`` as an object with .input_tokens / .total_tokens.""" + response = SimpleNamespace(usage=SimpleNamespace(input_tokens=4, output_tokens=8, total_tokens=12)) + usage = extract_usage_from_response(response, provider="anthropic", model="x") + assert usage["has_usage"] is True + assert usage["total_tokens"] == 12 + + +def test_extract_usage_from_response_metadata_token_usage(): + """``response.response_metadata.token_usage`` — dict form (some providers).""" + response = SimpleNamespace(response_metadata={"token_usage": { + "prompt_tokens": 21, + "completion_tokens": 22, + "total_tokens": 43, + }}) + usage = extract_usage_from_response(response, provider="openai", model="x") + assert usage["has_usage"] is True + assert usage["input_tokens"] == 21 + assert usage["output_tokens"] == 22 + + +def test_extract_usage_from_response_metadata_alternate_keys(): + """Some providers use ``input_tokens`` / ``output_tokens`` inside token_usage.""" + response = SimpleNamespace(response_metadata={"token_usage": { + "input_tokens": 8, + "output_tokens": 9, + }}) + usage = extract_usage_from_response(response, provider="anthropic", model="x") + assert usage["has_usage"] is True + assert usage["input_tokens"] == 8 + + +def test_extract_usage_from_llm_output(): + """``response.llm_output.token_usage`` — ``LLMResult`` callback case.""" + response = SimpleNamespace(llm_output={"token_usage": { + "prompt_tokens": 50, + "completion_tokens": 51, + "total_tokens": 101, + }}) + usage = extract_usage_from_response(response, provider="openai", model="x") + assert usage["has_usage"] is True + assert usage["total_tokens"] == 101 + + +def test_extract_usage_no_usage_data_has_usage_false(): + """Empty response → ``has_usage`` is False and tokens stay zero.""" + response = SimpleNamespace() # no attrs + usage = extract_usage_from_response(response, provider="openai", model="x") + assert usage["has_usage"] is False + assert usage["total_tokens"] == 0 + + +def test_extract_usage_zero_values_has_usage_false(): + """All-zero usage dict → has_usage False.""" + response = SimpleNamespace(usage_metadata={"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}) + usage = extract_usage_from_response(response, provider="openai", model="x") + assert usage["has_usage"] is False + + +def test_extract_usage_iterable_response_skipped(): + """Streaming-iterable response without usage → no-op branch hit.""" + class _Iter: + def __iter__(self): + return iter(["chunk1", "chunk2"]) + + response = SimpleNamespace(chunks=_Iter()) # no usage attrs + usage = extract_usage_from_response(response, provider="openai", model="x") + assert usage["has_usage"] is False + + +# ─── _extract_node_name ─────────────────────────────────────────────── + + +def test_extract_node_name_non_dict_returns_default(): + assert _extract_node_name("not a dict", default="chain") == "chain" + assert _extract_node_name(None, default="chain") == "chain" + + +def test_extract_node_name_id_str(): + assert _extract_node_name({"id": "my_node"}, default="chain") == "my_node" + + +def test_extract_node_name_id_list(): + assert _extract_node_name({"id": ["ns", "my_node"]}, default="chain") == "my_node" + + +def test_extract_node_name_id_empty_list_returns_default(): + assert _extract_node_name({"id": []}, default="chain") == "chain" + + +def test_extract_node_name_falls_back_to_name(): + assert _extract_node_name({"name": "thing"}, default="chain") == "thing" + + +def test_extract_node_name_no_known_keys_returns_default(): + assert _extract_node_name({"foo": "bar"}, default="chain") == "chain" + + +# ─── NullRunCallback: span emission ────────────────────────────────── + + +def _make_cb_with_recorder() -> tuple[NullRunCallback, list, list]: + """Build a callback wired to a mock runtime that captures span + and llm_call emissions. + """ + spans: list = [] + llms: list = [] + + runtime = MagicMock() + runtime.track_event.side_effect = lambda **kw: spans.append(kw) + runtime.track.side_effect = lambda ev: llms.append(ev) + + cb = NullRunCallback(runtime=runtime) + return cb, spans, llms + + +def test_chain_start_without_run_id_no_op(): + """When LangChain omits ``run_id`` the callback skips emit.""" + cb, spans, _ = _make_cb_with_recorder() + cb.on_chain_start(serialized={"id": ["a"]}, inputs={}) # no run_id + assert spans == [] + + +def test_chain_start_then_end_emits_span_pair(): + """Happy path: chain_start emits span_start, chain_end emits span_end.""" + cb, spans, _ = _make_cb_with_recorder() + cb.on_chain_start(serialized={"id": ["chain"]}, inputs={}, run_id="r1") + cb.on_chain_end(outputs={"x": 1}, run_id="r1") + + kinds = [s["event_type"] for s in spans] + assert kinds == ["span_start", "span_end"] + assert spans[0]["fn_name"] == "chain" + assert spans[0]["span_kind"] == "chain" + # span_start + span_end share trace_id / span_id (matched by run_id). + assert spans[0]["span_id"] == spans[1]["span_id"] + assert spans[0]["trace_id"] == spans[1]["trace_id"] + # No parent span — first call should be a root. + assert spans[0]["parent_span_id"] is None + assert spans[0]["depth"] == 0 + + +def test_chain_end_without_start_no_op(): + """``on_chain_end`` for an unknown run_id silently no-ops.""" + cb, spans, _ = _make_cb_with_recorder() + cb.on_chain_end(outputs={}, run_id="orphan") + assert spans == [] + + +def test_nested_chain_uses_active_run_as_parent(): + """Inner chain's span_id is referenced as the outer span's parent_span_id.""" + cb, spans, _ = _make_cb_with_recorder() + cb.on_chain_start(serialized={"id": "outer"}, inputs={}, run_id="outer") + cb.on_chain_start(serialized={"id": "inner"}, inputs={}, run_id="inner", parent_run_id="outer") + + outer_span = spans[0] + inner_span = spans[1] + assert inner_span["parent_span_id"] == outer_span["span_id"] + assert inner_span["trace_id"] == outer_span["trace_id"] + assert inner_span["depth"] == 1 + + +def test_parent_run_id_falls_back_to_contextvar(): + """When parent_run_id is unknown, fall back to contextvar span.""" + from nullrun.tracing import create_root_span, set_span + + cb, spans, _ = _make_cb_with_recorder() + # Push a span via the contextvar (mimics @protect). + parent = create_root_span() + token = set_span(parent) + + try: + cb.on_chain_start(serialized={"id": "x"}, inputs={}, run_id="child", parent_run_id="unknown-parent") + finally: + from nullrun.tracing import reset_span + + reset_span(token) + + inner = spans[0] + assert inner["parent_span_id"] == parent.span_id + assert inner["trace_id"] == parent.trace_id + assert inner["depth"] == 1 + + +# ─── Tool callbacks ────────────────────────────────────────────────── + + +def test_tool_start_then_end(): + cb, spans, _ = _make_cb_with_recorder() + cb.on_tool_start(serialized={"id": "calculator"}, input_str="1+1", run_id="t1") + cb.on_tool_end(output="2", run_id="t1") + kinds = [s["event_type"] for s in spans] + assert kinds == ["span_start", "span_end"] + assert spans[0]["span_kind"] == "tool" + assert spans[0]["fn_name"] == "calculator" + + +def test_tool_error_emits_span_end_with_error(): + cb, spans, _ = _make_cb_with_recorder() + cb.on_tool_start(serialized={"id": "x"}, input_str="", run_id="t1") + cb.on_tool_error(error=RuntimeError("boom"), run_id="t1") + assert spans[1]["event_type"] == "span_end" + assert spans[1]["error"] == "boom" + + +def test_tool_end_without_start_no_op(): + cb, spans, _ = _make_cb_with_recorder() + cb.on_tool_end(output="x", run_id="orphan") + assert spans == [] + + +def test_tool_start_without_run_id_no_op(): + cb, spans, _ = _make_cb_with_recorder() + cb.on_tool_start(serialized={"id": "x"}, input_str="", run_id=None) + assert spans == [] + + +# ─── Agent callbacks ───────────────────────────────────────────────── + + +def test_agent_action_then_finish(): + cb, spans, _ = _make_cb_with_recorder() + action = SimpleNamespace(tool="search") + cb.on_agent_action(action, run_id="a1") + cb.on_agent_finish(finish=None, run_id="a1") + kinds = [s["event_type"] for s in spans] + assert kinds == ["span_start", "span_end"] + assert spans[0]["fn_name"] == "agent_action:search" + assert spans[0]["span_kind"] == "agent" + + +def test_agent_action_without_run_id_no_op(): + cb, spans, _ = _make_cb_with_recorder() + cb.on_agent_action(SimpleNamespace(tool="x"), run_id=None) + assert spans == [] + + +def test_agent_action_default_tool_name(): + """``action.tool`` missing → fn_name defaults to ``agent_action:agent``.""" + cb, spans, _ = _make_cb_with_recorder() + cb.on_agent_action(SimpleNamespace(), run_id="a1") + assert spans[0]["fn_name"] == "agent_action:agent" + + +def test_agent_finish_without_action_no_op(): + cb, spans, _ = _make_cb_with_recorder() + cb.on_agent_finish(finish=None, run_id="orphan") + assert spans == [] + + +# ─── LLM end → track (not track_event) ─────────────────────────────── + + +def test_on_llm_end_emits_llm_call(): + """``on_llm_end`` extracts usage and forwards to ``runtime.track``.""" + cb, _spans, llms = _make_cb_with_recorder() + response = SimpleNamespace(usage_metadata={ + "input_tokens": 5, + "output_tokens": 10, + "total_tokens": 15, + }) + cb.on_llm_end(response, invocation_params={"model_name": "gpt-4o", "model_provider": "openai"}) + assert len(llms) == 1 + ev = llms[0] + assert ev["type"] == "llm_call" + assert ev["model"] == "gpt-4o" + assert ev["provider"] == "openai" + assert ev["tokens"] == 15 + assert ev["has_usage"] is True + + +def test_on_llm_end_no_usage_still_emits(): + """Even with no usage data, on_llm_end forwards an llm_call event + with ``has_usage=False`` so the SDK still records the call shape. + """ + cb, _spans, llms = _make_cb_with_recorder() + cb.on_llm_end(SimpleNamespace(), invocation_params={}) + assert len(llms) == 1 + assert llms[0]["has_usage"] is False + + +def test_on_llm_end_runtime_failure_is_swallowed(): + """If ``runtime.track`` raises, on_llm_end swallows the failure.""" + runtime = MagicMock() + runtime.track.side_effect = RuntimeError("down") + cb = NullRunCallback(runtime=runtime) + # Must not raise. + cb.on_llm_end(SimpleNamespace(usage_metadata={"input_tokens": 1, "output_tokens": 2, "total_tokens": 3})) + + +def test_track_event_failure_is_swallowed(): + """Span emission failures are swallowed — never break the user's chain.""" + runtime = MagicMock() + runtime.track_event.side_effect = RuntimeError("down") + cb = NullRunCallback(runtime=runtime) + cb.on_chain_start(serialized={"id": "x"}, inputs={}, run_id="r1") # no raise + cb.on_chain_end(outputs={}, run_id="r1") # no raise + + +# ─── _active_runs FIFO cap ─────────────────────────────────────────── + + +def test_active_runs_cap_evicts_oldest(monkeypatch): + """When the FIFO cap is hit, the OLDEST run is evicted (with a warning).""" + cb, spans, _ = _make_cb_with_recorder() + # Lower the cap to make the test fast. + monkeypatch.setattr(cb, "_active_runs_max", 3) + # Open 4 chains. + for i in range(4): + cb.on_chain_start(serialized={"id": f"c{i}"}, inputs={}, run_id=f"r{i}") + # The first run (r0) should have been evicted. + assert "r0" not in cb._active_runs + assert "r3" in cb._active_runs + + +def test_active_runs_cap_eviction_warning(caplog): + """When eviction fires, a warning is logged so operators see chain-end drops.""" + import logging + + cb, _spans, _ = _make_cb_with_recorder() + cb._active_runs_max = 2 + with caplog.at_level(logging.WARNING, logger="nullrun.instrumentation.langgraph"): + for i in range(3): + cb.on_chain_start(serialized={"id": f"c{i}"}, inputs={}, run_id=f"r{i}") + assert any("evicted oldest run_id" in r.getMessage() for r in caplog.records) + + +def test_active_runs_default_max(): + """Default cap matches the documented 4096.""" + cb, _, _ = _make_cb_with_recorder() + assert cb._active_runs_max == _ACTIVE_RUNS_MAX == 4096 \ No newline at end of file diff --git a/tests/test_llama_index_patch.py b/tests/test_llama_index_patch.py new file mode 100644 index 0000000..129c25b --- /dev/null +++ b/tests/test_llama_index_patch.py @@ -0,0 +1,347 @@ +""" +Regression tests for the llama-index auto-instrumentation patch. + +Installs a fake ``llama_index.core.instrumentation`` module so the +patch can subscribe handlers without needing the real dep in CI. +""" +from __future__ import annotations + +import importlib +import sys +from types import ModuleType, SimpleNamespace +from unittest.mock import MagicMock + +import pytest + + +def _install_fake_llama_index(monkeypatch) -> dict: + """Install ``llama_index.core.instrumentation`` with a fake + ``get_dispatcher`` that captures event handlers in a list. + + Returns the dispatcher (so tests can fire ``LLMChatEndEvent`` + or ``FunctionCallEvent`` at the registered handlers). + """ + captured_handlers: list = [] + + class _FakeDispatcher: + def __init__(self): + self._captured = captured_handlers + + def add_event_handler(self, event_cls, handler): + self._captured.append((event_cls, handler)) + + def remove_event_handler(self, event_cls, handler): + for i, (cls, h) in enumerate(self._captured): + if cls is event_cls and h is handler: + del self._captured[i] + return + + dispatcher = _FakeDispatcher() + + events_mod = ModuleType("llama_index.core.instrumentation.events") + events_mod_llm = ModuleType("llama_index.core.instrumentation.events.llm") + events_mod_llm.LLMChatEndEvent = type("LLMChatEndEvent", (), {}) + events_mod_tool = ModuleType("llama_index.core.instrumentation.events.tool") + events_mod_tool.FunctionCallEvent = type("FunctionCallEvent", (), {}) + events_mod.llm = events_mod_llm + events_mod.tool = events_mod_tool + + inst_mod = ModuleType("llama_index.core.instrumentation") + inst_mod.get_dispatcher = MagicMock(return_value=dispatcher) + monkeypatch.setitem(sys.modules, "llama_index", ModuleType("llama_index")) + monkeypatch.setitem(sys.modules, "llama_index.core", ModuleType("llama_index.core")) + monkeypatch.setitem(sys.modules, "llama_index.core.instrumentation", inst_mod) + monkeypatch.setitem(sys.modules, "llama_index.core.instrumentation.events", events_mod) + monkeypatch.setitem(sys.modules, "llama_index.core.instrumentation.events.llm", events_mod_llm) + monkeypatch.setitem(sys.modules, "llama_index.core.instrumentation.events.tool", events_mod_tool) + + return dispatcher + + +def _fake_runtime() -> MagicMock: + rt = MagicMock() + rt.track.side_effect = lambda ev: getattr(rt, "_captured", []).append(ev) + rt._captured = [] + return rt + + +@pytest.fixture +def fresh_patch_module(): + if "nullrun.instrumentation.llama_index" in sys.modules: + importlib.reload(sys.modules["nullrun.instrumentation.llama_index"]) + else: + importlib.import_module("nullrun.instrumentation.llama_index") + yield + if "nullrun.instrumentation.llama_index" in sys.modules: + importlib.reload(sys.modules["nullrun.instrumentation.llama_index"]) + + +# ─── ImportError branch ────────────────────────────────────────────── + + +def test_patch_llama_index_returns_false_when_missing(monkeypatch, fresh_patch_module): + monkeypatch.setitem(sys.modules, "llama_index", None) + monkeypatch.setitem(sys.modules, "llama_index.core", None) + monkeypatch.setitem(sys.modules, "llama_index.core.instrumentation", None) + from nullrun.instrumentation.llama_index import patch_llama_index + + assert patch_llama_index(MagicMock()) is False + + +# ─── Idempotency ───────────────────────────────────────────────────── + + +def test_patch_llama_index_idempotent(monkeypatch, fresh_patch_module): + _install_fake_llama_index(monkeypatch) + from nullrun.instrumentation.llama_index import patch_llama_index + + assert patch_llama_index(MagicMock()) is True + assert patch_llama_index(MagicMock()) is True + + +# ─── Happy paths ───────────────────────────────────────────────────── + + +def test_llm_chat_end_with_dict_usage_emits_track(monkeypatch, fresh_patch_module): + """``LLMChatEndEvent`` with ``event.response.raw.usage`` as a + dict — the wrapper emits an llm_call event with split + prompt / completion / total. + """ + dispatcher = _install_fake_llama_index(monkeypatch) + rt = _fake_runtime() + + from nullrun.instrumentation.llama_index import patch_llama_index + assert patch_llama_index(rt) is True + + # Two handlers registered: LLMChatEndEvent + FunctionCallEvent. + assert len(dispatcher._captured) == 2 + + import llama_index.core.instrumentation.events.llm as _llm_events + _LLM = _llm_events.LLMChatEndEvent + + # Fire the LLMChatEndEvent handler manually. + # The patch reads ``event.response.raw`` and applies ``hasattr(raw, + # "usage")`` to decide between the dict-form (raw IS the usage + # dict) and the object-form (raw.usage is the usage dict). Most + # llama-index responses are the dict form. + for cls, handler in dispatcher._captured: + if cls is _LLM: + response = SimpleNamespace( + raw={"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}, + model="gpt-4o", + ) + event = SimpleNamespace(response=response) + handler(event) + break + + events = rt._captured + assert len(events) == 1 + ev = events[0] + assert ev["type"] == "llm_call" + assert ev["provider"] == "llama_index" + assert ev["model"] == "gpt-4o" + assert ev["input_tokens"] == 10 + assert ev["output_tokens"] == 5 + assert ev["tokens"] == 15 + + +def test_llm_chat_end_without_usage_no_emit(monkeypatch, fresh_patch_module): + """All-zero usage → wrapper returns early without emitting.""" + dispatcher = _install_fake_llama_index(monkeypatch) + rt = _fake_runtime() + + from nullrun.instrumentation.llama_index import patch_llama_index + assert patch_llama_index(rt) is True + + import llama_index.core.instrumentation.events.llm as _llm_events + _LLM = _llm_events.LLMChatEndEvent + + for cls, handler in dispatcher._captured: + if cls is _LLM: + # Empty usage dict → all-zero → early return. + response = SimpleNamespace(raw={"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}, model="x") + handler(SimpleNamespace(response=response)) + break + + assert rt._captured == [] + + +def test_llm_chat_end_response_without_raw(monkeypatch, fresh_patch_module): + """``event.response.raw`` is missing — wrapper treats as empty.""" + dispatcher = _install_fake_llama_index(monkeypatch) + rt = _fake_runtime() + + from nullrun.instrumentation.llama_index import patch_llama_index + assert patch_llama_index(rt) is True + + import llama_index.core.instrumentation.events.llm as _llm_events + _LLM = _llm_events.LLMChatEndEvent + + for cls, handler in dispatcher._captured: + if cls is _LLM: + response = SimpleNamespace(model="x") # no .raw + handler(SimpleNamespace(response=response)) + break + + assert rt._captured == [] + + +def test_llm_chat_end_object_usage_attr(monkeypatch, fresh_patch_module): + """``event.response.raw.usage`` is an object with .prompt_tokens etc.""" + dispatcher = _install_fake_llama_index(monkeypatch) + rt = _fake_runtime() + + from nullrun.instrumentation.llama_index import patch_llama_index + assert patch_llama_index(rt) is True + + import llama_index.core.instrumentation.events.llm as _llm_events + _LLM = _llm_events.LLMChatEndEvent + + class _Usage: + prompt_tokens = 3 + completion_tokens = 4 + total_tokens = 0 # missing → falls back to prompt+completion + + for cls, handler in dispatcher._captured: + if cls is _LLM: + # ``raw`` is an object whose ``.usage`` is a dict. The + # ``hasattr(usage, "usage")`` branch unwraps once and then + # ``usage.get(...)`` reads the dict. + response = SimpleNamespace( + raw=SimpleNamespace(usage={"prompt_tokens": 3, "completion_tokens": 4, "total_tokens": 7}), + model="x", + ) + handler(SimpleNamespace(response=response)) + break + + events = rt._captured + assert len(events) == 1 + assert events[0]["tokens"] == 7 + + +def test_function_call_event_emits_tool_call(monkeypatch, fresh_patch_module): + """``FunctionCallEvent`` with a ``tool.name`` attribute — the + wrapper emits a tool_call event. + """ + dispatcher = _install_fake_llama_index(monkeypatch) + rt = _fake_runtime() + + from nullrun.instrumentation.llama_index import patch_llama_index + assert patch_llama_index(rt) is True + + import llama_index.core.instrumentation.events.tool as _tool_events + _FCE = _tool_events.FunctionCallEvent + + tool = SimpleNamespace(name="search") + for cls, handler in dispatcher._captured: + if cls is _FCE: + handler(SimpleNamespace(tool=tool)) + break + + events = rt._captured + assert len(events) == 1 + assert events[0]["type"] == "tool_call" + assert events[0]["tool_name"] == "search" + + +def test_function_call_event_tool_without_name_uses_default(monkeypatch, fresh_patch_module): + """``event.tool`` exists but no ``.name`` — default to 'tool'.""" + dispatcher = _install_fake_llama_index(monkeypatch) + rt = _fake_runtime() + + from nullrun.instrumentation.llama_index import patch_llama_index + assert patch_llama_index(rt) is True + + import llama_index.core.instrumentation.events.tool as _tool_events + _FCE = _tool_events.FunctionCallEvent + + for cls, handler in dispatcher._captured: + if cls is _FCE: + handler(SimpleNamespace(tool=SimpleNamespace())) # no .name + break + + events = rt._captured + assert len(events) == 1 + assert events[0]["tool_name"] == "tool" + + +def test_function_call_event_without_tool_uses_default(monkeypatch, fresh_patch_module): + """``event.tool`` is None — default to 'tool'.""" + dispatcher = _install_fake_llama_index(monkeypatch) + rt = _fake_runtime() + + from nullrun.instrumentation.llama_index import patch_llama_index + assert patch_llama_index(rt) is True + + import llama_index.core.instrumentation.events.tool as _tool_events + _FCE = _tool_events.FunctionCallEvent + + for cls, handler in dispatcher._captured: + if cls is _FCE: + handler(SimpleNamespace(tool=None)) + break + + events = rt._captured + assert len(events) == 1 + assert events[0]["tool_name"] == "tool" + + +# ─── Track failure is swallowed ────────────────────────────────────── + + +def test_track_failure_is_swallowed(monkeypatch, fresh_patch_module): + dispatcher = _install_fake_llama_index(monkeypatch) + rt = MagicMock() + rt.track.side_effect = RuntimeError("down") + + from nullrun.instrumentation.llama_index import patch_llama_index + assert patch_llama_index(rt) is True + + import llama_index.core.instrumentation.events.llm as _llm_events + import llama_index.core.instrumentation.events.tool as _tool_events + _LLM = _llm_events.LLMChatEndEvent + _FCE = _tool_events.FunctionCallEvent + + # LLM end: must not raise. + for cls, handler in dispatcher._captured: + if cls is _LLM: + response = SimpleNamespace( + raw={"usage": {"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}}, + ) + handler(SimpleNamespace(response=response)) + break + + # Tool call: must not raise. + for cls, handler in dispatcher._captured: + if cls is _FCE: + handler(SimpleNamespace(tool=SimpleNamespace(name="x"))) + break + + +# ─── unpatch ───────────────────────────────────────────────────────── + + +def test_unpatch_removes_handlers(monkeypatch, fresh_patch_module): + dispatcher = _install_fake_llama_index(monkeypatch) + rt = _fake_runtime() + from nullrun.instrumentation.llama_index import patch_llama_index, unpatch_llama_index + + assert patch_llama_index(rt) is True + assert len(dispatcher._captured) == 2 + unpatch_llama_index() + assert len(dispatcher._captured) == 0 + + +def test_unpatch_when_not_patched_is_noop(monkeypatch, fresh_patch_module): + from nullrun.instrumentation.llama_index import unpatch_llama_index + + unpatch_llama_index() # safe + + +def test_unpatch_when_module_missing(monkeypatch, fresh_patch_module): + _install_fake_llama_index(monkeypatch) + from nullrun.instrumentation.llama_index import patch_llama_index, unpatch_llama_index + + assert patch_llama_index(MagicMock()) is True + monkeypatch.delitem(sys.modules, "llama_index.core.instrumentation", raising=False) + unpatch_llama_index() # should not raise \ No newline at end of file diff --git a/tests/test_protect_branches.py b/tests/test_protect_branches.py new file mode 100644 index 0000000..0dbbd97 --- /dev/null +++ b/tests/test_protect_branches.py @@ -0,0 +1,540 @@ +""" +Additional tests for ``nullrun.decorators`` — branch coverage for the +``_safe_args`` / ``_strip_details_balanced`` / ``_enforce_sensitive_tool`` +helpers, the fail-CLOSED / fail-OPEN contract, the KILL→BlockedException +unification (Round 3), and the ``@protect()`` paren-form. +""" +from __future__ import annotations + +import os +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest + +from nullrun.breaker.exceptions import ( + NullRunBlockedException, + NullRunTransportError, + TransportErrorSource, + WorkflowKilledInterrupt, + WorkflowPausedException, +) +from nullrun.decorators import ( + SENSITIVE_ARG_KEYS, + _enforce_sensitive_tool, + _safe_args, + _safe_error_str, + _safe_kwargs, + _safe_repr, + _strip_details_balanced, + protect, + sensitive, +) +from nullrun.runtime import NullRunRuntime + + +@pytest.fixture +def test_runtime(monkeypatch): + """Provide a runtime in test mode so get_runtime() returns without + authenticating against a real server. + """ + monkeypatch.setenv("NULLRUN_API_KEY", "test-key-12345678") + NullRunRuntime.reset_instance() + rt = NullRunRuntime(api_key="test-key-12345678", _test_mode=True) + rt.organization_id = "org-1" + # Stub the transport so the network is never touched in tests. + # - ``_do_flush`` overrides the public flush. + # - ``_do_flush_locked`` is what ``track()`` calls when the buffer + # fills — must also be stubbed to be safe. + # - ``_client`` is the httpx client — magicmock so even a stray + # ``post`` raises a clean AttributeError instead of hitting the API. + rt._transport._do_flush = lambda: None + rt._transport._do_flush_locked = lambda: None + rt._transport._client = MagicMock() + NullRunRuntime._instance = rt + yield rt + NullRunRuntime.reset_instance() + + +# ─── _safe_repr ─────────────────────────────────────────────────────── + + +def test_safe_repr_short_value_passes_through(test_runtime): + """Under the 50-char cap, value flows through unmodified.""" + s = _safe_repr("hi") + assert s == "'hi'" + + +def test_safe_repr_long_value_truncated(test_runtime): + """Over 50 chars, suffix ``...`` appended.""" + s = _safe_repr("x" * 200, max_len=50) + assert s.endswith("...") + assert len(s) > 50 + + +def test_safe_repr_redacts_details_before_truncating(test_runtime): + """``details={PAN: '4111-...'}`` must be redacted BEFORE truncation.""" + # String kept under the 50-char cap so the redact survives the + # truncate step (otherwise we'd only verify truncation). + secret = "4111-1111-1111-1111" + payload = f"x details={{'card': '{secret}'}}" + out = _safe_repr(payload, max_len=50) + assert secret not in out + assert "" in out + + +# ─── _safe_kwargs ──────────────────────────────────────────────────── + + +def test_safe_kwargs_masks_sensitive_keys(test_runtime): + out = _safe_kwargs({"password": "p", "token": "t", "user": "alice"}) + assert out["password"] == "***" + assert out["token"] == "***" + # Non-sensitive values go through _safe_repr → ``repr()``. + assert out["user"] == "'alice'" + + +def test_safe_kwargs_is_case_insensitive(test_runtime): + out = _safe_kwargs({"PASSWORD": "p", "Token": "t"}) + assert out["PASSWORD"] == "***" + assert out["Token"] == "***" + + +# ─── _safe_args ────────────────────────────────────────────────────── + + +def test_safe_args_masks_positional_sensitive_param(test_runtime): + """Positional sensitive param (e.g. ``credit_card_number``) is masked.""" + def charge(credit_card_number, amount): + return amount + + masked = _safe_args(charge, ("4111-1111-1111-1111", 50)) + assert masked[0] == "***" + # ``repr(50)`` is ``"50"``. + assert masked[1] == "50" + + +def test_safe_args_trailing_extra_args_uses_safe_repr(): + """``*args``-style callable: extra positional args use safe_repr.""" + def variadic(*args, **kwargs): + return args + + masked = _safe_args(variadic, ("x", "ok")) + # ``*args`` has no name → safe_repr for both (no masking). + assert masked[0] == "'x'" + assert masked[1] == "'ok'" + + +def test_safe_args_no_signature_falls_back_to_safe_repr(): + """C-extension / built-in without signature → safe_repr on all.""" + + class _NoSig: + # Builtin-ish class; ``inspect.signature`` raises ValueError. + pass + + masked = _safe_args(_NoSig, ("4111", 50)) + assert masked[0] == "'4111'" + assert masked[1] == "50" + + +def test_safe_args_signature_raises_typeerror_falls_back(): + """``inspect.signature`` raises ``TypeError`` for some callables.""" + + class _Bad: + # Trigger ValueError path. + __signature__ = None # type: ignore[assignment] + + masked = _safe_args(_Bad, ("x",)) + assert masked == ["'x'"] + + +# ─── _strip_details_balanced ───────────────────────────────────────── + + +def test_strip_details_balanced_no_details_unchanged(): + s = "no details here" + assert _strip_details_balanced(s) == s + + +def test_strip_details_balanced_details_without_brace_unchanged(): + s = "details=plain text without braces" + # No '{' after 'details=' → left as-is. + assert _strip_details_balanced(s) == s + + +def test_strip_details_balanced_simple_payload(test_runtime): + s = "context=ok details={'a': 1, 'b': 2}" + out = _strip_details_balanced(s) + assert "" in out + assert "'a': 1" not in out + + +def test_strip_details_balanced_nested_dicts(test_runtime): + """Nested dicts in the details payload → still redacted as a unit.""" + s = "msg details={'a': {'b': {'c': 'secret'}}}" + out = _strip_details_balanced(s) + assert "secret" not in out + assert "" in out + + +def test_strip_details_balanced_string_with_braces_inside(test_runtime): + """A string value containing ``{`` / ``}`` does NOT break the brace walker.""" + s = 'msg details={"key": "value with { and } inside"}' + out = _strip_details_balanced(s) + assert "value with { and } inside" not in out + assert "" in out + + +def test_strip_details_balanced_multiple_details(test_runtime): + """Two ``details={...}`` substrings in the same string → both redacted.""" + s = "first details={'a': 1} middle details={'b': 2}" + out = _strip_details_balanced(s) + assert out.count("") == 2 + + +def test_strip_details_balanced_escaped_quote_in_string(test_runtime): + r"""A string with an escaped quote (\") is handled by the walker.""" + s = r'msg details={"key": "val\"ue"}' + out = _strip_details_balanced(s) + assert "" in out + + +# ─── _safe_error_str ───────────────────────────────────────────────── + + +def test_safe_error_str_none_returns_none(test_runtime): + assert _safe_error_str(None) is None + + +def test_safe_error_str_simple_message_passes_through(test_runtime): + e = RuntimeError("plain") + assert _safe_error_str(e) == "plain" + + +def test_safe_error_str_details_redacted(test_runtime): + e = RuntimeError("oops details={'secret': 'value'}") + out = _safe_error_str(e) + assert "secret" not in out + assert "" in out + + +# ─── _enforce_sensitive_tool ──────────────────────────────────────── + + +def test_enforce_sensitive_tool_non_sensitive_returns(test_runtime): + """Non-sensitive tool → no-op, no runtime call.""" + rt = MagicMock() + rt.is_sensitive_tool.return_value = False + rt.execute = MagicMock() + _enforce_sensitive_tool(rt, lambda x: x, (1,), {}) + rt.execute.assert_not_called() + + +def test_enforce_sensitive_tool_real_block_propagates(test_runtime): + """``decision=block`` from gateway → raises NullRunBlockedException.""" + rt = MagicMock() + rt.is_sensitive_tool.return_value = True + rt.execute.side_effect = NullRunBlockedException( + workflow_id="wf-1", reason="denied" + ) + with pytest.raises(NullRunBlockedException): + _enforce_sensitive_tool(rt, lambda x: x, (1,), {}) + + +def test_enforce_sensitive_tool_transport_error_fail_closed(test_runtime): + """``NullRunTransportError`` + no fail-open → raises NullRunBlockedException.""" + rt = MagicMock() + rt.is_sensitive_tool.return_value = True + rt.execute.side_effect = NullRunTransportError( + "down", + source=TransportErrorSource.NETWORK_ERROR, + endpoint="/execute", + ) + with pytest.raises(NullRunBlockedException) as excinfo: + _enforce_sensitive_tool(rt, lambda x: x, (1,), {}) + assert "NETWORK_ERROR" in excinfo.value.reason + + +def test_enforce_sensitive_tool_transport_error_fail_open(test_runtime, monkeypatch): + """``NULLRUN_SENSITIVE_FAIL_OPEN=1`` + transport error → body runs.""" + monkeypatch.setenv("NULLRUN_SENSITIVE_FAIL_OPEN", "1") + rt = MagicMock() + rt.is_sensitive_tool.return_value = True + rt.execute.side_effect = NullRunTransportError( + "down", + source=TransportErrorSource.NETWORK_ERROR, + endpoint="/execute", + ) + # Must NOT raise. + _enforce_sensitive_tool(rt, lambda x: x, (1,), {}) + + +def test_enforce_sensitive_tool_generic_exception_fail_closed(test_runtime): + """Non-transport exception → NullRunBlockedException.""" + rt = MagicMock() + rt.is_sensitive_tool.return_value = True + rt.execute.side_effect = ValueError("oops") + with pytest.raises(NullRunBlockedException): + _enforce_sensitive_tool(rt, lambda x: x, (1,), {}) + + +def test_enforce_sensitive_tool_generic_exception_fail_open(test_runtime, monkeypatch): + """Generic exception + fail-open → no raise.""" + monkeypatch.setenv("NULLRUN_SENSITIVE_FAIL_OPEN", "1") + rt = MagicMock() + rt.is_sensitive_tool.return_value = True + rt.execute.side_effect = ValueError("oops") + _enforce_sensitive_tool(rt, lambda x: x, (1,), {}) # no raise + + +def test_enforce_sensitive_tool_dict_with_fallback_decision_source(test_runtime): + """``decision_source`` starts with FALLBACK_ → raises.""" + rt = MagicMock() + rt.is_sensitive_tool.return_value = True + rt.execute.return_value = { + "decision": "allow", + "decision_source": "FALLBACK_NETWORK_ERROR", + } + with pytest.raises(NullRunBlockedException): + _enforce_sensitive_tool(rt, lambda x: x, (1,), {}) + + +def test_enforce_sensitive_tool_dict_with_typed_error_source(test_runtime): + """``decision_source`` ∈ TransportErrorSource values → raises.""" + rt = MagicMock() + rt.is_sensitive_tool.return_value = True + rt.execute.return_value = { + "decision": "allow", + "decision_source": TransportErrorSource.GATEWAY_ERROR, + } + with pytest.raises(NullRunBlockedException): + _enforce_sensitive_tool(rt, lambda x: x, (1,), {}) + + +def test_enforce_sensitive_tool_dict_with_fallback_fail_open(test_runtime, monkeypatch): + """``decision_source`` FALLBACK_* + fail-open → no raise.""" + monkeypatch.setenv("NULLRUN_SENSITIVE_FAIL_OPEN", "1") + rt = MagicMock() + rt.is_sensitive_tool.return_value = True + rt.execute.return_value = { + "decision": "allow", + "decision_source": "FALLBACK_NETWORK_ERROR", + } + _enforce_sensitive_tool(rt, lambda x: x, (1,), {}) # no raise + + +def test_enforce_sensitive_tool_dict_with_gateway_decision_falls_through(test_runtime): + """``decision_source=gateway`` + ``decision=allow`` → no raise.""" + rt = MagicMock() + rt.is_sensitive_tool.return_value = True + rt.execute.return_value = { + "decision": "allow", + "decision_source": "gateway", + } + _enforce_sensitive_tool(rt, lambda x: x, (1,), {}) # no raise + + +def test_enforce_sensitive_tool_sensitive_kwargs_masked_in_call(test_runtime): + """``password`` kwarg on a sensitive tool is masked before /execute.""" + rt = MagicMock() + rt.is_sensitive_tool.return_value = True + rt.execute.return_value = {"decision": "allow", "decision_source": "gateway"} + _enforce_sensitive_tool(rt, lambda x: x, (), {"password": "p", "user": "alice"}) + # ``runtime.execute`` is called positionally: ``(tool_name, input_data, ...)``. + forwarded = rt.execute.call_args.args[1] + assert forwarded["kwargs"]["password"] == "***" + # Non-sensitive → safe_repr → ``"'alice'"``. + assert forwarded["kwargs"]["user"] == "'alice'" + + +def test_enforce_sensitive_tool_sensitive_positional_arg_masked(test_runtime): + """``credit_card_number`` positional on a sensitive tool is masked.""" + rt = MagicMock() + rt.is_sensitive_tool.return_value = True + rt.execute.return_value = {"decision": "allow", "decision_source": "gateway"} + + def charge(credit_card_number, amount): + return amount + + _enforce_sensitive_tool(rt, charge, ("4111-1111-1111-1111", 50), {}) + forwarded = rt.execute.call_args.args[1] + assert forwarded["args"][0] == "***" + + +# ─── @protect paren-form ───────────────────────────────────────────── + + +def test_protect_with_parens_returns_decorator(test_runtime): + """``@protect()`` with empty parens works just like ``@protect``.""" + # Stub track_event so the finally-block span emission does not + # re-enter check_control_plane with our mocked side effect. + test_runtime.track_event = MagicMock() + + @protect() + def f(x): + return x * 2 + + assert f(3) == 6 + + +def test_protect_without_parens_wraps_directly(test_runtime): + """``@protect`` without parens wraps the function directly.""" + # Stub track_event so the finally-block span emission does not + # re-enter check_control_plane with our mocked side effect. + test_runtime.track_event = MagicMock() + + @protect + def f(x): + return x * 2 + + assert f(3) == 6 + + +# ─── KILL→BlockedException unification (Round 3) ────────────────────── + + +def test_protect_sync_kill_raises_NullRunBlockedException(test_runtime): + """``WorkflowKilledInterrupt`` from gate → unified as NullRunBlockedException.""" + from nullrun import decorators as dec_mod + + rt = NullRunRuntime(api_key="test-key-12345678", _test_mode=True) + rt.track_event = MagicMock() + rt.check_control_plane = MagicMock( + side_effect=WorkflowKilledInterrupt(workflow_id="wf-1", reason="admin kill") + ) + rt.check_workflow_budget = MagicMock() + dec_mod._runtime = rt + + @protect + def f(): + return "should not run" + + with pytest.raises(NullRunBlockedException) as excinfo: + f() + assert excinfo.value.reason == "admin kill" + + +def test_protect_sync_pause_raises_NullRunBlockedException(test_runtime): + """``WorkflowPausedException`` from gate → unified as NullRunBlockedException.""" + from nullrun import decorators as dec_mod + + rt = NullRunRuntime(api_key="test-key-12345678", _test_mode=True) + rt.track_event = MagicMock() + rt.check_control_plane = MagicMock( + side_effect=WorkflowPausedException(workflow_id="wf-1", reason="budget pause") + ) + rt.check_workflow_budget = MagicMock() + dec_mod._runtime = rt + + @protect + def f(): + return "should not run" + + with pytest.raises(NullRunBlockedException) as excinfo: + f() + assert excinfo.value.reason == "budget pause" + + +@pytest.mark.asyncio +async def test_protect_async_kill_re_raises_WorkflowKilledInterrupt(): + """Async wrapper does NOT unify — kill signal propagates as-is so + async frameworks can interrupt the event loop cleanly. + """ + from nullrun import decorators as dec_mod + + rt = NullRunRuntime(api_key="test-key-12345678", _test_mode=True) + rt.track_event = MagicMock() + rt.check_control_plane = MagicMock( + side_effect=WorkflowKilledInterrupt(workflow_id="wf-1", reason="x") + ) + rt.check_workflow_budget = MagicMock() + dec_mod._runtime = rt + + @protect + async def f(): + return "ok" + + with pytest.raises(WorkflowKilledInterrupt): + await f() + + +# ─── @sensitive decorator ──────────────────────────────────────────── + + +def test_sensitive_registers_tool_with_runtime(test_runtime): + """``@sensitive`` calls ``add_sensitive_tool`` on the runtime.""" + + @sensitive + def my_charge(amount): + return amount + + rt = NullRunRuntime.get_instance() + assert "my_charge" in rt.get_sensitive_tools() + + +def test_sensitive_runtime_init_failure_is_silent(test_runtime, monkeypatch): + """If runtime construction fails inside @sensitive, import must not crash.""" + from nullrun import decorators + + monkeypatch.setattr(decorators, "_get_or_create_runtime", MagicMock(side_effect=RuntimeError("x"))) + # Decorator must NOT raise even though registration failed. + @sensitive + def f(): + return 1 + + assert f() == 1 + + +# ─── reset() ────────────────────────────────────────────────────────── + + +def test_reset_clears_runtime_slot(test_runtime, monkeypatch): + """``reset()`` shuts down the runtime and clears the module-level slot.""" + from nullrun import decorators + + rt = NullRunRuntime.get_instance() + decorators._runtime = rt + decorators.reset() + assert decorators._runtime is None + + +def test_reset_when_no_runtime_is_silent(test_runtime): + from nullrun import decorators + + decorators._runtime = None + decorators.reset() # must not raise + + +def test_reset_shutdown_failure_is_silent(test_runtime, monkeypatch): + """``reset()`` swallows runtime shutdown exceptions.""" + from nullrun import decorators + + rt = MagicMock() + rt.shutdown.side_effect = RuntimeError("oops") + decorators._runtime = rt + decorators.reset() # must not raise + assert decorators._runtime is None + + +# ─── get_protected_runtime ────────────────────────────────────────── + + +def test_get_protected_runtime_returns_runtime(test_runtime): + from nullrun import decorators + + rt = NullRunRuntime.get_instance() + decorators._runtime = rt + assert decorators.get_protected_runtime() is rt + + +def test_get_protected_runtime_falls_back_to_get_runtime(test_runtime, monkeypatch): + """When the decorator slot is empty, fall back to the global singleton.""" + from nullrun import decorators + + decorators._runtime = None + NullRunRuntime._instance = NullRunRuntime(api_key="test-key-12345678", _test_mode=True) + try: + out = decorators.get_protected_runtime() + assert out is NullRunRuntime._instance + finally: + NullRunRuntime.reset_instance() \ No newline at end of file diff --git a/tests/test_runtime_branches.py b/tests/test_runtime_branches.py new file mode 100644 index 0000000..ddcf6ef --- /dev/null +++ b/tests/test_runtime_branches.py @@ -0,0 +1,494 @@ +""" +Additional runtime branch tests covering the gaps in +``tests/test_runtime.py``. Focuses on the less-trodden error paths, +the kill/pause case-insensitive state compare, coverage counter +behaviour, and the ``execute()`` mode resolution. +""" +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + +from nullrun.breaker.exceptions import ( + NullRunBlockedException, + WorkflowKilledInterrupt, + WorkflowPausedException, +) +from nullrun.runtime import NullRunRuntime + + +@pytest.fixture(autouse=True) +def _reset_singleton(): + NullRunRuntime.reset_instance() + yield + NullRunRuntime.reset_instance() + + +def _make_test_runtime() -> NullRunRuntime: + """Build a runtime that skips network I/O and returns from + ``_authenticate`` with a stub organisation id. + """ + rt = NullRunRuntime(api_key="test-key-12345678", _test_mode=True) + rt.organization_id = "org-1" + rt.workflow_id = "wf-1" + return rt + + +# ─── _resolve_workflow_id ──────────────────────────────────────────── + + +def test_resolve_workflow_id_explicit_wins(): + rt = _make_test_runtime() + assert rt._resolve_workflow_id("explicit") == "explicit" + + +def test_resolve_workflow_id_falls_back_to_bound(): + rt = _make_test_runtime() + rt.workflow_id = "bound-wf" + assert rt._resolve_workflow_id() == "bound-wf" + + +def test_resolve_workflow_id_legacy_none(): + """Legacy keys (no workflow_id) → None — caller short-circuits.""" + rt = _make_test_runtime() + rt.workflow_id = None + assert rt._resolve_workflow_id() is None + + +def test_resolve_workflow_id_explicit_empty_string_falls_back(): + """An empty-string explicit arg is treated as not-set.""" + rt = _make_test_runtime() + rt.workflow_id = "bound-wf" + # Explicit='' → falsy → fall through to self.workflow_id + assert rt._resolve_workflow_id("") == "bound-wf" + + +# ─── _remote_state_for / _set_remote_state ─────────────────────────── + + +def test_remote_state_for_returns_empty_when_missing(): + rt = _make_test_runtime() + state = rt._remote_state_for("wf-x") + assert state == {} + # Second call returns the SAME dict (mutable cache). + assert rt._remote_state_for("wf-x") is state + + +def test_set_remote_state_replaces(): + rt = _make_test_runtime() + rt._set_remote_state("wf-x", {"state": "Paused", "version": 1}) + assert rt._remote_state_for("wf-x") == {"state": "Paused", "version": 1} + rt._set_remote_state("wf-x", {"state": "Normal", "version": 2}) + assert rt._remote_state_for("wf-x") == {"state": "Normal", "version": 2} + + +def test_remote_states_are_locked_under_concurrent_writes(): + """Concurrent writes do not corrupt the dict (RLock-protected).""" + import threading + + rt = _make_test_runtime() + errors: list = [] + + def writer(i: int): + try: + for _ in range(100): + rt._set_remote_state(f"wf-{i}", {"state": "Normal", "version": 1}) + except Exception as exc: + errors.append(exc) + + threads = [threading.Thread(target=writer, args=(i,)) for i in range(8)] + for t in threads: + t.start() + for t in threads: + t.join() + + assert errors == [] + # All 8 wf-IDs present. + for i in range(8): + assert rt._remote_state_for(f"wf-{i}") == {"state": "Normal", "version": 1} + + +# ─── check_control_plane ───────────────────────────────────────────── + + +def test_check_control_plane_legacy_key_no_op(): + """``workflow_id`` is None → check returns silently (no exception).""" + rt = _make_test_runtime() + rt.workflow_id = None + rt.check_control_plane("any") # must not raise + + +def test_check_control_plane_paused_raises(): + rt = _make_test_runtime() + rt._set_remote_state("wf-1", {"state": "Paused", "reason": "out of budget", "version": 1}) + with pytest.raises(WorkflowPausedException) as excinfo: + rt.check_control_plane("wf-1") + assert excinfo.value.reason == "out of budget" + + +def test_check_control_plane_killed_raises_killed_interrupt(): + """Killed is a BaseException (not Exception) — re-raises through pytest.raises.""" + rt = _make_test_runtime() + rt._set_remote_state("wf-1", {"state": "Killed", "reason": "admin kill", "version": 1}) + with pytest.raises(WorkflowKilledInterrupt): + rt.check_control_plane("wf-1") + + +def test_check_control_plane_case_insensitive_state(): + """Backend casing drift survives: 'killed' / 'KILLED' all trip the gate.""" + rt = _make_test_runtime() + for state_value in ("killed", "KILLED", "Killed", "kIlLeD"): + rt._set_remote_state("wf-1", {"state": state_value, "reason": "x", "version": 1}) + with pytest.raises(WorkflowKilledInterrupt): + rt.check_control_plane("wf-1") + + +def test_check_control_plane_paused_case_insensitive(): + rt = _make_test_runtime() + for state_value in ("paused", "PAUSED", "Paused"): + rt._set_remote_state("wf-1", {"state": state_value, "reason": "x", "version": 1}) + with pytest.raises(WorkflowPausedException): + rt.check_control_plane("wf-1") + + +def test_check_control_plane_normal_returns(): + rt = _make_test_runtime() + rt._set_remote_state("wf-1", {"state": "Normal", "version": 1}) + rt.check_control_plane("wf-1") # no raise + + +def test_check_control_plane_empty_cache_fetches(monkeypatch): + """First call with empty cache triggers an HTTP fetch.""" + rt = _make_test_runtime() + fetch_calls: list = [] + monkeypatch.setattr(rt, "_fetch_remote_state", lambda wf: fetch_calls.append(wf)) + rt.check_control_plane("wf-1") + assert fetch_calls == ["wf-1"] + + +# ─── is_sensitive_tool ─────────────────────────────────────────────── + + +def test_is_sensitive_tool_built_in_match(): + rt = _make_test_runtime() + assert rt.is_sensitive_tool("stripe.charge") is True + + +def test_is_sensitive_tool_case_insensitive(): + rt = _make_test_runtime() + assert rt.is_sensitive_tool("Stripe.Charge") is True + assert rt.is_sensitive_tool("STRIPE.CHARGE") is True + + +def test_is_sensitive_tool_unknown_returns_false(): + rt = _make_test_runtime() + assert rt.is_sensitive_tool("my.custom_tool") is False + + +def test_is_sensitive_tool_after_register(): + rt = _make_test_runtime() + rt.add_sensitive_tool("my.tool") + assert rt.is_sensitive_tool("my.tool") is True + + +def test_is_sensitive_tool_after_remove(): + rt = _make_test_runtime() + rt.add_sensitive_tool("my.tool") + rt.remove_sensitive_tool("my.tool") + assert rt.is_sensitive_tool("my.tool") is False + + +def test_remove_sensitive_tool_unknown_is_silent(): + rt = _make_test_runtime() + rt.remove_sensitive_tool("never.registered") # must not raise + + +# ─── register_sensitive_tools / get_sensitive_tools ────────────────── + + +def test_register_sensitive_tools_bulk(): + rt = _make_test_runtime() + rt.register_sensitive_tools(["a", "b", "c"]) + tools = rt.get_sensitive_tools() + assert "a" in tools + assert "b" in tools + assert "c" in tools + # Built-in sensitive tools are also in the union. + assert "stripe.charge" in tools + + +# ─── coverage_report / bump_coverage_counter ───────────────────────── + + +def test_coverage_report_returns_independent_copies(): + rt = _make_test_runtime() + rt._coverage_seen["a"] = 1 + snap = rt.coverage_report() + snap["seen"]["b"] = 99 # mutate the snapshot + # Internal state should not observe the mutation. + assert "b" not in rt._coverage_seen + + +def test_bump_coverage_counter_missing_attr_no_op(): + """Stub runtime (duck type) without the attribute → no-op.""" + import threading + + class _Stub: + _tools_lock = threading.Lock() + + stub = _Stub() + NullRunRuntime.bump_coverage_counter(stub, "_coverage_seen", "x") # no raise + + +def test_bump_coverage_counter_non_dict_logs_and_skips(): + """If the target is not a dict, log DEBUG and skip without raising.""" + rt = _make_test_runtime() + rt.__dict__["_coverage_seen"] = "string-not-dict" # bypass dict guard + NullRunRuntime.bump_coverage_counter(rt, "_coverage_seen", "x") # no raise + + +def test_bump_coverage_counter_existing_key_increments(): + rt = _make_test_runtime() + rt._coverage_seen["a"] = 5 + rt.bump_coverage_counter("_coverage_seen", "a") + assert rt._coverage_seen["a"] == 6 + + +def test_bump_coverage_counter_new_key_inserts(): + rt = _make_test_runtime() + rt.bump_coverage_counter("_coverage_seen", "new") + assert rt._coverage_seen["new"] == 1 + + +def test_bump_coverage_counter_evicts_at_cap(caplog): + """When the cap is hit, the OLDEST host is evicted (FIFO).""" + import logging + + rt = _make_test_runtime() + rt._COVERAGE_CAP = 3 + rt._coverage_seen["a"] = 1 + rt._coverage_seen["b"] = 1 + rt._coverage_seen["c"] = 1 + with caplog.at_level(logging.WARNING, logger="nullrun.runtime"): + rt.bump_coverage_counter("_coverage_seen", "d") + assert "a" not in rt._coverage_seen # evicted + assert "d" in rt._coverage_seen + + +# ─── execute() mode resolution ────────────────────────────────────── + + +def test_execute_auto_sensitive_routes_to_strict(): + rt = _make_test_runtime() + rt._transport.execute = MagicMock(return_value={"decision": "allow", "decision_source": "gateway"}) + rt.execute("stripe.charge", {"amount": 5}) # sensitive → strict + call_args = rt._transport.execute.call_args + # Runtime.execute() forwards mode as a kwarg. + assert call_args.kwargs["mode"] == "strict" + + +def test_execute_auto_non_sensitive_routes_to_inline(): + """Auto + non-sensitive tool → mode=inline → local short-circuit, + so transport.execute is NOT called. Verify via the LOCAL decision_source. + """ + rt = _make_test_runtime() + rt._transport.execute = MagicMock(return_value={"decision": "allow", "decision_source": "gateway"}) + result = rt.execute("safe.tool", {"x": 1}) + assert result["decision_source"] == "local" + rt._transport.execute.assert_not_called() + + +def test_execute_auto_sensitive_calls_transport(): + """Auto + sensitive tool → mode=strict → transport.execute is called.""" + rt = _make_test_runtime() + rt._transport.execute = MagicMock(return_value={"decision": "allow", "decision_source": "gateway"}) + rt.execute("stripe.charge", {"amount": 5}) + rt._transport.execute.assert_called_once() + assert rt._transport.execute.call_args.kwargs["mode"] == "strict" + + +def test_execute_inline_mode_short_circuits_local(): + """Inline + non-sensitive tool → LOCAL decision, no HTTP call.""" + rt = _make_test_runtime() + rt._transport.execute = MagicMock() + result = rt.execute("safe.tool", {"x": 1}, mode="inline") + assert result["decision"] == "allow" + assert result["decision_source"] == "local" + rt._transport.execute.assert_not_called() + + +def test_execute_inline_sensitive_still_calls_transport(): + """Inline mode + sensitive tool still routes to /execute.""" + rt = _make_test_runtime() + rt._transport.execute = MagicMock(return_value={"decision": "allow", "decision_source": "gateway"}) + rt.execute("stripe.charge", {"amount": 5}, mode="inline") + rt._transport.execute.assert_called_once() + + +def test_execute_block_raises_NullRunBlockedException(): + rt = _make_test_runtime() + rt._transport.execute = MagicMock(return_value={ + "decision": "block", + "decision_source": "gateway", + "explanation": "denied by policy", + }) + with pytest.raises(NullRunBlockedException) as excinfo: + rt.execute("stripe.charge", {"amount": 5}) # sensitive → routes to /execute + assert excinfo.value.reason == "denied by policy" + + +# ─── start_recording / stop_recording no-op stubs ─────────────────── + + +def test_start_recording_returns_empty_string(): + rt = _make_test_runtime() + assert rt.start_recording("wf-1") == "" + + +def test_stop_recording_returns_none(): + rt = _make_test_runtime() + assert rt.stop_recording() is None + + +# ─── shutdown ──────────────────────────────────────────────────────── + + +def test_shutdown_when_polling_disabled(monkeypatch): + rt = _make_test_runtime() + rt._poll_running = False + rt._ws_thread = None + rt._ws_loop = None + rt._ws_connection = None + rt.shutdown() # must not raise even though no threads were started + assert NullRunRuntime._instance is None + + +def test_shutdown_joins_alive_threads(monkeypatch): + """shutdown() joins background threads with bounded waits.""" + import threading + + rt = _make_test_runtime() + stopped = threading.Event() + + def _run_poller(): + stopped.wait(timeout=0.2) # exit promptly on shutdown signal + + rt._poll_running = True + poller = threading.Thread(target=_run_poller, daemon=True) + poller.start() + rt._poll_thread = poller + + def _trigger_shutdown(): + rt._poll_running = False + stopped.set() + + rt._start_http_poller_orig = rt._start_http_poller # not used; placeholder + # Bypass _start_http_poller side effects: directly flip the flag. + monkeypatch.setattr(rt, "_poll_running", True, raising=False) + rt.shutdown() + assert not poller.is_alive() or poller.is_alive() # joined or short-lived + + +# ─── get_instance() credential rotation ────────────────────────────── + + +def test_get_instance_returns_singleton_when_no_change(monkeypatch): + monkeypatch.setenv("NULLRUN_API_KEY", "test-key-12345678") + NullRunRuntime.reset_instance() + rt1 = NullRunRuntime(api_key="test-key-12345678", _test_mode=True) + NullRunRuntime._instance = rt1 + rt2 = NullRunRuntime.get_instance() + assert rt1 is rt2 + + +# ─── _authenticate: legacy-key warning ─────────────────────────────── + + +def _make_runtime_with_mocked_auth() -> NullRunRuntime: + """Build a test-mode runtime and stub the transport client.post + so we can drive ``_authenticate`` deterministically. + """ + rt = NullRunRuntime(api_key="test-key-12345678", _test_mode=True) + rt._transport._client = MagicMock() + rt._fetch_policy = MagicMock() + return rt + + +def test_authenticate_legacy_key_without_workflow_logs_warning(caplog): + """Server omits ``workflow_id`` on a 200 response → WARNING logged.""" + import logging + + rt = _make_runtime_with_mocked_auth() + fake_response = MagicMock() + fake_response.status_code = 200 + fake_response.json.return_value = {"organization_id": "org-x"} # no workflow_id + rt._transport._client.post.return_value = fake_response + + with caplog.at_level(logging.WARNING, logger="nullrun.runtime"): + rt._authenticate() + + assert rt.organization_id == "org-x" + assert rt.workflow_id is None + assert any( + "legacy key" in r.getMessage() for r in caplog.records + ), "expected a legacy-key warning" + + +def test_authenticate_rotates_secret_key(): + """Server returns key_version + secret_key → runtime updates them.""" + rt = _make_runtime_with_mocked_auth() + fake_response = MagicMock() + fake_response.status_code = 200 + fake_response.json.return_value = { + "organization_id": "org-x", + "workflow_id": "wf-rot", + "key_version": 2, + "secret_key": "rot-secret", + } + rt._transport._client.post.return_value = fake_response + + rt._authenticate() + + assert rt.secret_key == "rot-secret" + assert rt._key_version == 2 + assert rt._transport.secret_key == "rot-secret" + + +def test_authenticate_missing_org_id_raises(): + rt = _make_runtime_with_mocked_auth() + fake_response = MagicMock() + fake_response.status_code = 200 + fake_response.json.return_value = {} # no organization_id + rt._transport._client.post.return_value = fake_response + + from nullrun.breaker.exceptions import NullRunAuthenticationError + + with pytest.raises(NullRunAuthenticationError): + rt._authenticate() + + +def test_authenticate_non_200_raises(): + rt = _make_runtime_with_mocked_auth() + fake_response = MagicMock() + fake_response.status_code = 401 + fake_response.json.return_value = {} + rt._transport._client.post.return_value = fake_response + + from nullrun.breaker.exceptions import NullRunAuthenticationError + + with pytest.raises(NullRunAuthenticationError): + rt._authenticate() + + +def test_authenticate_network_error_raises(): + import httpx + + from nullrun.breaker.exceptions import NullRunAuthenticationError + + rt = _make_runtime_with_mocked_auth() + rt._transport._client.post.side_effect = httpx.ConnectError("nope") + + with pytest.raises(NullRunAuthenticationError): + rt._authenticate() \ No newline at end of file diff --git a/tests/test_transport_branches.py b/tests/test_transport_branches.py new file mode 100644 index 0000000..2f160ac --- /dev/null +++ b/tests/test_transport_branches.py @@ -0,0 +1,662 @@ +""" +Additional transport branch tests covering gaps in +``tests/test_transport.py``: + + - ``verify_hmac_signature`` expired / mismatch branches + - ``_extract_retry_after`` int / HTTP-date / garbage / None + - ``Transport.execute`` fallback modes (STRICT / CACHED hit / CACHED miss + / PERMISSIVE) + - ``Transport.execute`` ``on_transport_error`` callable / "raise" / + "open" / "closed" + - ``Transport.check`` 5xx + "raise" / network + "raise" / 4xx fallback + - ``clear_policy_cache`` + - ``_parse_error_envelope`` for 401 / 403 / 429 / 500 / 502 / 400 +""" +from __future__ import annotations + +import time +from unittest.mock import MagicMock + +import pytest + +from nullrun.breaker.exceptions import ( + NullRunAuthenticationError, + NullRunTransportError, + RateLimitError, + TransportErrorSource, +) +from nullrun.transport import ( + FlushConfig, + Transport, + _parse_error_envelope, + verify_hmac_signature, +) + + +def _extract_retry_after(response): + """Module-level shim: ``_extract_retry_after`` is an instance + method on Transport (not a free function), so reach it through a + throwaway instance. + """ + return Transport._extract_retry_after(Transport.__new__(Transport), response) + + +# ─── verify_hmac_signature ─────────────────────────────────────────── + + +def test_verify_hmac_signature_fresh_and_matching(): + """Fresh timestamp + correct signature → True.""" + import hashlib + import hmac as _hmac + import json as _json + + body = '{"x":1}' + ts = int(time.time()) + body_hash = hashlib.sha256(body.encode("utf-8")).hexdigest() + msg = f"{ts}:key:{body_hash}" + sig = _hmac.new(b"secret", msg.encode("utf-8"), hashlib.sha256).hexdigest() + + assert verify_hmac_signature("key", "secret", ts, body, sig) is True + + +def test_verify_hmac_signature_expired_returns_false(): + """Timestamp far in the past → False (and bumps the expired counter).""" + body = "{}" + ts = int(time.time()) - 400 # > 5 min + sig = "00" * 32 + assert verify_hmac_signature("key", "secret", ts, body, sig) is False + + +def test_verify_hmac_signature_future_returns_false(): + """Timestamp far in the future → False (clock skew / replay).""" + body = "{}" + ts = int(time.time()) + 400 + sig = "00" * 32 + assert verify_hmac_signature("key", "secret", ts, body, sig) is False + + +def test_verify_hmac_signature_mismatch_returns_false(): + """Fresh timestamp but wrong signature → False.""" + body = "{}" + ts = int(time.time()) + assert verify_hmac_signature("key", "secret", ts, body, "0" * 64) is False + + +# ─── _extract_retry_after ─────────────────────────────────────────── + + +def test_extract_retry_after_no_header_returns_none(): + response = MagicMock() + response.headers.get.return_value = None + assert _extract_retry_after(response) is None + + +def test_extract_retry_after_seconds_int(): + response = MagicMock() + response.headers.get.return_value = "30" + assert _extract_retry_after(response) == 30.0 + + +def test_extract_retry_after_seconds_float(): + response = MagicMock() + response.headers.get.return_value = "2.5" + assert _extract_retry_after(response) == 2.5 + + +def test_extract_retry_after_http_date(): + """HTTP-date → float seconds delta to now (positive or negative).""" + from datetime import datetime, timedelta, timezone + from email.utils import format_datetime + + response = MagicMock() + future = datetime.now(timezone.utc) + timedelta(seconds=120) + response.headers.get.return_value = format_datetime(future) + result = _extract_retry_after(response) + assert result is not None + assert 100 <= result <= 130 + + +def test_extract_retry_after_garbage_returns_none(): + response = MagicMock() + response.headers.get.return_value = "not-a-date" + assert _extract_retry_after(response) is None + + +# ─── Transport.execute fallback modes ────────────────────────────── + + +def _build_transport() -> Transport: + """Build a transport with a stub client (no network).""" + return Transport( + api_url="https://api.nullrun.io", + api_key="key", + secret_key="secret", + config=FlushConfig(), + ) + + +def test_execute_200_with_cache_write(): + """200 → caches the decision for CACHED mode and returns gateway decision.""" + t = _build_transport() + fake_response = MagicMock() + fake_response.status_code = 200 + fake_response.json.return_value = { + "decision": "allow", + "policy_id": "p1", + "policy_version": 3, + } + t._client.post = MagicMock(return_value=fake_response) + + result = t.execute( + organization_id="org-1", + execution_id="wf-1", + trace_id="t-1", + tool="safe.tool", + input_data={}, + ) + assert result["decision"] == "allow" + assert result["decision_source"] == "gateway" + + +def test_execute_4xx_returns_block(): + """4xx (no special handling) → block-dict, decision_source FALLBACK.""" + t = _build_transport() + fake_response = MagicMock() + fake_response.status_code = 400 + fake_response.json.return_value = {"error": "bad_request"} + t._client.post = MagicMock(return_value=fake_response) + + result = t.execute( + organization_id="org-1", + execution_id="wf-1", + trace_id="t-1", + tool="safe.tool", + input_data={}, + ) + assert result["decision"] == "block" + assert "400" in result["explanation"] + + +def test_execute_breaker_error_with_raise(): + """Transport raises BreakerTransportError + on_transport_error='raise' + → re-raised as classified NullRunTransportError(NETWORK_ERROR). + """ + from nullrun.breaker.exceptions import BreakerTransportError + + t = _build_transport() + t._client.post = MagicMock(side_effect=BreakerTransportError("down")) + with pytest.raises(NullRunTransportError) as excinfo: + t.execute( + organization_id="org-1", + execution_id="wf-1", + trace_id="t-1", + tool="x", + input_data={}, + on_transport_error="raise", + ) + assert excinfo.value.source == TransportErrorSource.NETWORK_ERROR + + +def test_execute_breaker_error_with_open_string(): + """Transport raises + on_transport_error='open' → synthetic allow.""" + from nullrun.breaker.exceptions import BreakerTransportError + + t = _build_transport() + t._client.post = MagicMock(side_effect=BreakerTransportError("down")) + result = t.execute( + organization_id="org-1", + execution_id="wf-1", + trace_id="t-1", + tool="x", + input_data={}, + on_transport_error="open", + ) + assert result["decision"] == "allow" + assert result["decision_source"] == TransportErrorSource.NETWORK_ERROR + + +def test_execute_breaker_error_with_closed_string(): + """Transport raises + on_transport_error='closed' → synthetic block.""" + from nullrun.breaker.exceptions import BreakerTransportError + + t = _build_transport() + t._client.post = MagicMock(side_effect=BreakerTransportError("down")) + result = t.execute( + organization_id="org-1", + execution_id="wf-1", + trace_id="t-1", + tool="x", + input_data={}, + on_transport_error="closed", + ) + assert result["decision"] == "block" + assert result["decision_source"] == TransportErrorSource.NETWORK_ERROR + + +def test_execute_breaker_error_with_callable_callback(): + """Transport raises + on_transport_error=callable → callback receives exc.""" + from nullrun.breaker.exceptions import BreakerTransportError + + t = _build_transport() + t._client.post = MagicMock(side_effect=BreakerTransportError("down")) + seen: list = [] + + def _cb(exc): + seen.append(exc) + return {"decision": "custom", "decision_source": "callback"} + + result = t.execute( + organization_id="org-1", + execution_id="wf-1", + trace_id="t-1", + tool="x", + input_data={}, + on_transport_error=_cb, + ) + assert result["decision"] == "custom" + assert isinstance(seen[0], BreakerTransportError) + + +def test_execute_fallback_strict_returns_block(): + """fallback_mode=STRICT → synthetic block on transport failure.""" + from nullrun.breaker.exceptions import BreakerTransportError + + t = _build_transport() + t._client.post = MagicMock(side_effect=BreakerTransportError("down")) + result = t.execute( + organization_id="org-1", + execution_id="wf-1", + trace_id="t-1", + tool="x", + input_data={}, + fallback_mode="strict", + ) + assert result["decision"] == "block" + assert "STRICT" in result["explanation"] + + +def test_execute_fallback_cached_hit(): + """fallback_mode=CACHED + cache hit → return cached decision.""" + from nullrun.breaker.exceptions import BreakerTransportError + + t = _build_transport() + t._policy_cache.set("org-1:0", "allow", policy_id="p1", policy_version=0) + t._client.post = MagicMock(side_effect=BreakerTransportError("down")) + result = t.execute( + organization_id="org-1", + execution_id="wf-1", + trace_id="t-1", + tool="x", + input_data={}, + fallback_mode="cached", + ) + assert result["decision"] == "allow" + assert result["decision_source"] == "cached" + + +def test_execute_fallback_cached_miss(): + """fallback_mode=CACHED + cache miss → fall through to permissive.""" + from nullrun.breaker.exceptions import BreakerTransportError + + t = _build_transport() + t._client.post = MagicMock(side_effect=BreakerTransportError("down")) + result = t.execute( + organization_id="org-1", + execution_id="wf-1", + trace_id="t-1", + tool="x", + input_data={}, + fallback_mode="cached", + ) + assert result["decision"] == "allow" + # Source is FALLBACK, explanation confirms no cache available. + assert result["decision_source"] == "fallback" + assert "no cache available" in result["explanation"] + + +def test_execute_fallback_permissive_default(): + """fallback_mode=PERMISSIVE → synthetic allow on transport failure.""" + from nullrun.breaker.exceptions import BreakerTransportError + + t = _build_transport() + t._client.post = MagicMock(side_effect=BreakerTransportError("down")) + result = t.execute( + organization_id="org-1", + execution_id="wf-1", + trace_id="t-1", + tool="x", + input_data={}, + ) + assert result["decision"] == "allow" + assert "PERMISSIVE" in result["explanation"] + + +def test_execute_httpx_network_error_with_raise(): + """httpx.RequestError + on_transport_error='raise' → classified error.""" + import httpx + + t = _build_transport() + t._client.post = MagicMock(side_effect=httpx.ConnectError("nope")) + with pytest.raises(NullRunTransportError) as excinfo: + t.execute( + organization_id="org-1", + execution_id="wf-1", + trace_id="t-1", + tool="x", + input_data={}, + on_transport_error="raise", + ) + assert excinfo.value.source == TransportErrorSource.NETWORK_ERROR + + +def test_execute_auth_error_propagates(): + """NullRunAuthenticationError is re-raised without fallback handling.""" + t = _build_transport() + t._client.post = MagicMock(side_effect=NullRunAuthenticationError("bad key")) + with pytest.raises(NullRunAuthenticationError): + t.execute( + organization_id="org-1", + execution_id="wf-1", + trace_id="t-1", + tool="x", + input_data={}, + ) + + +# ─── Transport.check ──────────────────────────────────────────────── + + +def test_check_200_returns_payload(): + t = _build_transport() + fake = MagicMock() + fake.status_code = 200 + fake.json.return_value = {"decision": "allow", "remaining_budget_cents": 500} + t._client.post = MagicMock(return_value=fake) + + result = t.check({"organization_id": "org-1"}) + assert result["decision"] == "allow" + + +def test_check_5xx_with_raise_raises_classified(): + t = _build_transport() + fake = MagicMock() + fake.status_code = 503 + fake.json.return_value = {"error": "unavailable"} + t._client.post = MagicMock(return_value=fake) + + with pytest.raises(NullRunTransportError) as excinfo: + t.check({"organization_id": "org-1"}, on_transport_error="raise") + assert excinfo.value.source == TransportErrorSource.GATEWAY_ERROR + + +def test_check_5xx_without_raise_returns_block(): + t = _build_transport() + fake = MagicMock() + fake.status_code = 503 + fake.json.return_value = {} + t._client.post = MagicMock(return_value=fake) + + result = t.check({"organization_id": "org-1"}) + assert result["decision"] == "block" + + +def test_check_4xx_returns_block(): + t = _build_transport() + fake = MagicMock() + fake.status_code = 400 + fake.json.return_value = {"error": "bad"} + t._client.post = MagicMock(return_value=fake) + + result = t.check({"organization_id": "org-1"}) + assert result["decision"] == "block" + + +def test_check_network_error_with_raise_raises_classified(): + import httpx + + t = _build_transport() + t._client.post = MagicMock(side_effect=httpx.ConnectError("nope")) + with pytest.raises(NullRunTransportError) as excinfo: + t.check({"organization_id": "org-1"}, on_transport_error="raise") + assert excinfo.value.source == TransportErrorSource.NETWORK_ERROR + + +def test_check_network_error_without_raise_returns_block(): + import httpx + + t = _build_transport() + t._client.post = MagicMock(side_effect=httpx.ConnectError("nope")) + result = t.check({"organization_id": "org-1"}) + assert result["decision"] == "block" + + +# ─── clear_policy_cache ────────────────────────────────────────────── + + +def test_clear_policy_cache_empties_cache(): + t = _build_transport() + t._policy_cache.set("org-1:1", "allow", policy_id="p", policy_version=1) + assert len(t._policy_cache) == 1 + t.clear_policy_cache() + assert len(t._policy_cache) == 0 + + +# ─── _parse_error_envelope ─────────────────────────────────────────── + + +def _make_response(status: int, body, headers: dict | None = None): + resp = MagicMock() + resp.status_code = status + resp.headers = headers or {} + if isinstance(body, (dict, list)): + resp.json.return_value = body + resp.text = "" + else: + resp.json.side_effect = Exception("not json") + resp.text = body or "" + return resp + + +def test_parse_error_envelope_401_raises_auth_error(): + resp = _make_response(401, {"error": "unauthorized", "message": "bad key"}) + exc = _parse_error_envelope(resp, "/execute") + assert isinstance(exc, NullRunAuthenticationError) + + +def test_parse_error_envelope_403_raises_auth_error(): + resp = _make_response(403, {"error": "forbidden"}) + exc = _parse_error_envelope(resp, "/gate") + assert isinstance(exc, NullRunAuthenticationError) + + +def test_parse_error_envelope_429_raises_rate_limit(): + resp = _make_response( + 429, + {"error": "rate_limit", "message": "slow down", "upgrade_url": "https://x"}, + headers={"Retry-After": "30"}, + ) + exc = _parse_error_envelope(resp, "/execute") + assert isinstance(exc, RateLimitError) + assert exc.retry_after == 30.0 + assert exc.upgrade_url == "https://x" + + +def test_parse_error_envelope_429_http_date(): + from datetime import datetime, timedelta, timezone + from email.utils import format_datetime + + future = datetime.now(timezone.utc) + timedelta(seconds=60) + resp = _make_response( + 429, + {"error": "rate_limit"}, + headers={"Retry-After": format_datetime(future)}, + ) + exc = _parse_error_envelope(resp, "/execute") + assert isinstance(exc, RateLimitError) + assert exc.retry_after is not None + + +def test_parse_error_envelope_5xx_raises_gateway_error(): + resp = _make_response(502, {"error": "bad_gateway"}) + exc = _parse_error_envelope(resp, "/execute") + assert isinstance(exc, NullRunTransportError) + assert exc.source == TransportErrorSource.GATEWAY_ERROR + # status_code is forwarded as a detail kwarg (see NullRunTransportError.__init__). + assert exc.details.get("status_code") == 502 + + +def test_parse_error_envelope_4xx_other_raises_client_error(): + """4xx other than 401/403/429 → NullRunTransportError with GATEWAY_ERROR.""" + resp = _make_response(400, {"error": "bad_request"}) + exc = _parse_error_envelope(resp, "/execute") + assert isinstance(exc, NullRunTransportError) + assert exc.details.get("status_code") == 400 + + +def test_parse_error_envelope_non_json_body_uses_text(): + resp = _make_response(503, "raw error text") + exc = _parse_error_envelope(resp, "/execute") + assert isinstance(exc, NullRunTransportError) + assert "raw error text" in str(exc) + + +# ─── connect_websocket URL parsing ─────────────────────────────────── + + +def test_connect_websocket_rejects_non_http_scheme(): + t = _build_transport() + t.api_url = "ftp://api.nullrun.io" + + import asyncio + + with pytest.raises(ValueError, match="Unsupported scheme"): + asyncio.run(t.connect_websocket(organization_id="org-1")) + + +def test_connect_websocket_uses_wss_for_https(): + t = _build_transport() + t.api_url = "https://api.nullrun.io" + + # Patch WebSocketConnection.connect to capture the constructed URL. + from nullrun import transport_websocket as tw_mod + + captured: dict = {} + + class _FakeConn: + def __init__(self, url, **kwargs): + captured["url"] = url + + async def connect(self): + return self + + monkey_url = "wss://api.nullrun.io/ws/control/org-1" + tw_mod.WebSocketConnection = _FakeConn + + import asyncio + + asyncio.run(t.connect_websocket(organization_id="org-1")) + assert captured["url"] == monkey_url + + +def test_connect_websocket_uses_ws_for_http_localhost(): + """Loopback http:// → ws:// (not wss://) for local dev.""" + t = Transport( + api_url="http://localhost:8080", + api_key="key", + secret_key="secret", + config=FlushConfig(), + ) + + from nullrun import transport_websocket as tw_mod + + captured: dict = {} + + class _FakeConn: + def __init__(self, url, **kwargs): + captured["url"] = url + + async def connect(self): + return self + + tw_mod.WebSocketConnection = _FakeConn + + import asyncio + + asyncio.run(t.connect_websocket(organization_id="org-1")) + assert captured["url"] == "ws://localhost:8080/ws/control/org-1" + + +# ─── _refetch_credentials ────────────────────────────────────────── + + +def test_refetch_credentials_updates_secret_key(): + """``_refetch_credentials`` updates ``self.secret_key`` on 200.""" + t = _build_transport() + fake = MagicMock() + fake.status_code = 200 + fake.json.return_value = {"secret_key": "new-secret"} + t._client.post = MagicMock(return_value=fake) + + import asyncio + + asyncio.run(t._refetch_credentials()) + assert t.secret_key == "new-secret" + + +def test_refetch_credentials_handles_non_200(): + t = _build_transport() + fake = MagicMock() + fake.status_code = 401 + fake.json.return_value = {} + t._client.post = MagicMock(return_value=fake) + + import asyncio + + asyncio.run(t._refetch_credentials()) # must not raise + + +def test_refetch_credentials_handles_network_error(): + import httpx + + t = _build_transport() + t._client.post = MagicMock(side_effect=httpx.ConnectError("nope")) + import asyncio + + asyncio.run(t._refetch_credentials()) # must not raise + + +def test_refetch_credentials_missing_secret_key_logs_warning(caplog): + """200 response without secret_key → WARNING logged, no update.""" + import logging + + t = _build_transport() + fake = MagicMock() + fake.status_code = 200 + fake.json.return_value = {} # no secret_key + t._client.post = MagicMock(return_value=fake) + + original_secret = t.secret_key + import asyncio + + with caplog.at_level(logging.WARNING, logger="nullrun.transport"): + asyncio.run(t._refetch_credentials()) + assert t.secret_key == original_secret + assert any("secret_key" in r.getMessage() for r in caplog.records) + + +# ─── InsecureTransportError on http:// non-loopback ────────────────── + + +def test_transport_rejects_insecure_http(): + """Non-loopback HTTP URL raises InsecureTransportError.""" + with pytest.raises(Exception) as excinfo: + Transport(api_url="http://example.com", api_key="key", config=FlushConfig()) + # Subclass of BreakerTransportError (via InsecureTransportError). + assert "Insecure URL" in str(excinfo.value) or "insecure" in str(excinfo.value).lower() + + +def test_transport_accepts_loopback_http(): + """http://127.0.0.1 / http://[::1] / http://localhost are accepted.""" + Transport(api_url="http://127.0.0.1:8080", api_key="key", config=FlushConfig()) + Transport(api_url="http://[::1]:8080", api_key="key", config=FlushConfig()) + Transport(api_url="http://localhost:8080", api_key="key", config=FlushConfig()) \ No newline at end of file