diff --git a/gremlins/cli/resume.py b/gremlins/cli/resume.py index f744f1b8..53da6b30 100644 --- a/gremlins/cli/resume.py +++ b/gremlins/cli/resume.py @@ -3,7 +3,7 @@ import argparse import sys -from gremlins.executor.state import validate_gremlin_id +from gremlins.executor.state_utils import validate_gremlin_id from gremlins.launcher import resume diff --git a/gremlins/executor/gremlin.py b/gremlins/executor/gremlin.py index 16f60f30..7123588b 100644 --- a/gremlins/executor/gremlin.py +++ b/gremlins/executor/gremlin.py @@ -101,6 +101,7 @@ async def run_stages( class Gremlin: registry: ArtifactRegistry + state: State | None def __init__( self, @@ -150,6 +151,7 @@ def __init__( self.fetch_worktree = fetch_worktree self.pipeline_path = pipeline_path self.pipeline_args = pipeline_args or [] + self.state = None @property def artifact_dir(self) -> pathlib.Path: @@ -300,7 +302,9 @@ def _collect_stages( artifacts=self.registry, base_ref=self.base_ref, ) - built.append((e.name, stage_state.make_runner(e, scope=stages))) + built.append( + (e.name, stage_state.make_runner(e, scope=stages, gremlin=self)) + ) return built def _unbind_stale_exec_artifacts(self) -> None: diff --git a/gremlins/executor/parallel_state.py b/gremlins/executor/parallel_state.py index 165d88c6..64cf7528 100644 --- a/gremlins/executor/parallel_state.py +++ b/gremlins/executor/parallel_state.py @@ -8,7 +8,8 @@ import pathlib from typing import Any -from gremlins.executor.state import StateData, resolve_state_file +from gremlins.executor.state import StateData +from gremlins.executor.state_utils import resolve_state_file logger = logging.getLogger(__name__) diff --git a/gremlins/executor/run.py b/gremlins/executor/run.py index bc84015f..0f391d97 100644 --- a/gremlins/executor/run.py +++ b/gremlins/executor/run.py @@ -21,8 +21,8 @@ from gremlins.env_file import load_env_file from gremlins.errors import die from gremlins.executor.gremlin import Gremlin -from gremlins.executor.state import ( - StateData, +from gremlins.executor.state import StateData +from gremlins.executor.state_utils import ( resolve_artifact_dir, resolve_state_file, ) diff --git a/gremlins/executor/state.py b/gremlins/executor/state.py index 2f93c2ea..c466e733 100644 --- a/gremlins/executor/state.py +++ b/gremlins/executor/state.py @@ -10,7 +10,6 @@ import math import os import pathlib -import re import secrets from collections.abc import Callable, Sequence from typing import TYPE_CHECKING, Any, ClassVar, cast @@ -18,50 +17,25 @@ from gremlins import paths as _paths from gremlins.artifacts.registry import ArtifactRegistry from gremlins.clients.client import Client +from gremlins.executor.state_utils import ( + resolve_state_file, +) from gremlins.utils.state_file import locked_update if TYPE_CHECKING: from gremlins.pipeline import Pipeline -from gremlins.protocols import StageProtocol +from gremlins.protocols import GremlinProtocol, StageProtocol from gremlins.stages.outcome import Done logger = logging.getLogger(__name__) -_GREMLIN_ID_RE = re.compile(r"^[A-Za-z0-9_-]+$") - BAIL_CLASS_REVIEWER_REQUESTED_CHANGES = "reviewer_requested_changes" BAIL_CLASS_SECURITY = "security" BAIL_CLASS_SECRETS = "secrets" BAIL_CLASS_OTHER = "other" -def validate_gremlin_id(gremlin_id: str) -> None: - """Raise ValueError if gremlin_id is not a safe, non-path-traversing identifier.""" - if ".." in gremlin_id or not _GREMLIN_ID_RE.match(gremlin_id): - raise ValueError(f"gremlin_id contains illegal characters: {gremlin_id!r}") - - -def resolve_state_file(gremlin_id: str | None) -> pathlib.Path | None: - """Return path to state.json for gremlin_id, or None when gremlin_id is absent.""" - if not gremlin_id: - return None - return _paths.state_root() / gremlin_id / "state.json" - - -def resolve_artifact_dir(gremlin_id: str | None = None) -> pathlib.Path: - """Resolve the artifacts directory for the current run.""" - state_root = _paths.state_root() - if gremlin_id: - artifact_dir = state_root / gremlin_id / "artifacts" - else: - ts = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") - rand = secrets.token_hex(3) - artifact_dir = state_root / "direct" / f"{ts}-{rand}" / "artifacts" - artifact_dir.mkdir(parents=True, exist_ok=True) - return artifact_dir - - def write_state(state_dir: pathlib.Path, data: dict[str, Any]) -> None: """Atomically overwrite state.json (no merge).""" sf = state_dir / "state.json" @@ -70,18 +44,6 @@ def write_state(state_dir: pathlib.Path, data: dict[str, Any]) -> None: os.replace(tmp, sf) -def landable_shape(state: dict[str, Any]) -> str: - """Classify artifact shape for land dispatch.""" - artifacts = list(state.get("artifacts") or []) - prs = [art for art in artifacts if art.get("type") == "pr"] - - if not prs: - return "empty" - if len(prs) == 1: - return "one_pr" - return "many_prs" - - def _stage_list() -> list[StageProtocol]: return [] @@ -487,6 +449,10 @@ class State: } ) + @property + def registry(self) -> ArtifactRegistry: + return self.artifacts + def framework_subs(self, stage: StageProtocol) -> dict[str, str]: """Runtime-owned substitution vars. Stages must not assemble these themselves.""" return { @@ -548,6 +514,7 @@ def make_runner( scope: Sequence[StageProtocol] | None = None, *, record_stage: bool = True, + gremlin: GremlinProtocol, ) -> Callable[[], Any]: base_state = self gremlin_id = self.data.gremlin_id @@ -578,11 +545,37 @@ async def _run_async() -> Any: entry.skip_if_exists ): return Done() - return await entry.run(_prepare()) + prepared_state = _prepare() + gremlin.state = prepared_state + return await entry.run(gremlin) return _run_async +class GremlinWrapper: + """Minimal Gremlin-like wrapper for subprocess contexts.""" + + def __init__(self, state: State) -> None: + self.state = state + self.registry = state.artifacts + + @property + def data(self) -> StateData: + return self.state.data + + async def fork( + self, + state: State, + target_id: str, + *, + parent_id: str = "", + group_name: str = "", + child_key: str = "", + pipeline: Any | None = None, + ) -> State: + raise NotImplementedError("fork not supported in subprocess context") + + def build_state( data: StateData, client: Client, diff --git a/gremlins/executor/state_utils.py b/gremlins/executor/state_utils.py new file mode 100644 index 00000000..7567176f --- /dev/null +++ b/gremlins/executor/state_utils.py @@ -0,0 +1,49 @@ +from __future__ import annotations + +import datetime +import pathlib +import re +import secrets +from typing import Any + +from gremlins import paths as _paths + +_GREMLIN_ID_RE = re.compile(r"^[A-Za-z0-9_-]+$") + + +def validate_gremlin_id(gremlin_id: str) -> None: + """Raise ValueError if gremlin_id is not a safe, non-path-traversing identifier.""" + if ".." in gremlin_id or not _GREMLIN_ID_RE.match(gremlin_id): + raise ValueError(f"gremlin_id contains illegal characters: {gremlin_id!r}") + + +def resolve_state_file(gremlin_id: str | None) -> pathlib.Path | None: + """Return path to state.json for gremlin_id, or None when gremlin_id is absent.""" + if not gremlin_id: + return None + return _paths.state_root() / gremlin_id / "state.json" + + +def resolve_artifact_dir(gremlin_id: str | None = None) -> pathlib.Path: + """Resolve the artifacts directory for the current run.""" + state_root = _paths.state_root() + if gremlin_id: + artifact_dir = state_root / gremlin_id / "artifacts" + else: + ts = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") + rand = secrets.token_hex(3) + artifact_dir = state_root / "direct" / f"{ts}-{rand}" / "artifacts" + artifact_dir.mkdir(parents=True, exist_ok=True) + return artifact_dir + + +def landable_shape(state: dict[str, Any]) -> str: + """Classify artifact shape for land dispatch.""" + artifacts = list(state.get("artifacts") or []) + prs = [art for art in artifacts if art.get("type") == "pr"] + + if not prs: + return "empty" + if len(prs) == 1: + return "one_pr" + return "many_prs" diff --git a/gremlins/fleet/land.py b/gremlins/fleet/land.py index f2de08d6..ff5adbc3 100644 --- a/gremlins/fleet/land.py +++ b/gremlins/fleet/land.py @@ -15,7 +15,7 @@ from gremlins import paths from gremlins.artifacts.registry import ArtifactRegistry, MissingArtifact from gremlins.artifacts.resolve import resolve_in_map -from gremlins.executor.state import landable_shape, resolve_artifact_dir +from gremlins.executor.state_utils import landable_shape, resolve_artifact_dir from gremlins.fleet.resolve import resolve_gremlin from gremlins.fleet.state import ( liveness_of_state_file, @@ -871,7 +871,7 @@ def _exec_land_stage( import asyncio from gremlins.clients.client import PACKAGE_DEFAULT - from gremlins.executor.state import StateData, build_state + from gremlins.executor.state import GremlinWrapper, StateData, build_state from gremlins.stages.outcome import Bail state = build_state( @@ -881,8 +881,10 @@ def _exec_land_stage( cwd=cwd, artifacts=registry, ) + + wrapper = GremlinWrapper(state) try: - asyncio.run(land_stage.run(state)) + asyncio.run(land_stage.run(wrapper)) return True except Bail as b: print(f"error: land: {b.reason}") diff --git a/gremlins/launcher.py b/gremlins/launcher.py index f76b12de..235d2f0f 100644 --- a/gremlins/launcher.py +++ b/gremlins/launcher.py @@ -25,7 +25,8 @@ from gremlins.artifacts.uri import Uri from gremlins.clients.client import PACKAGE_DEFAULT from gremlins.executor.gremlin import Gremlin -from gremlins.executor.state import StateData, validate_gremlin_id +from gremlins.executor.state import StateData +from gremlins.executor.state_utils import validate_gremlin_id from gremlins.pipeline import Pipeline as _PipelineData from gremlins.pipeline.discovery import list_pipelines, resolve_pipeline_path from gremlins.utils import git as _git_mod diff --git a/gremlins/protocols.py b/gremlins/protocols.py index 98f5c85a..8d6b3b11 100644 --- a/gremlins/protocols.py +++ b/gremlins/protocols.py @@ -9,6 +9,7 @@ class GremlinProtocol(Protocol): """What stages need from a Gremlin.""" registry: Any + state: Any async def fork( self, @@ -36,6 +37,6 @@ class StageProtocol(Protocol): client: Any skip_if_exists: str - async def run(self, state: Any) -> Any: - """Run this stage with the given execution state.""" + async def run(self, gremlin: GremlinProtocol) -> Any: + """Run this stage with the given gremlin.""" ... diff --git a/gremlins/run_child.py b/gremlins/run_child.py index 78f8dae0..254e1003 100644 --- a/gremlins/run_child.py +++ b/gremlins/run_child.py @@ -42,7 +42,8 @@ from gremlins import paths from gremlins.clients.client import Client from gremlins.clients.registry import CLIENT_FACTORIES -from gremlins.executor.state import State, StateData, build_state, validate_gremlin_id +from gremlins.executor.state import GremlinWrapper, State, StateData, build_state +from gremlins.executor.state_utils import validate_gremlin_id from gremlins.permissions.loader import load_policy from gremlins.permissions.validation import validate_policy_against_registry from gremlins.pipeline import Pipeline @@ -164,8 +165,9 @@ async def _run(spec_path: pathlib.Path) -> int: if stage.client is None: stage.client = state.client + gremlin = GremlinWrapper(state) try: - await stage.run(state) + await stage.run(gremlin) except Bail as b: cost = getattr(state.client, "total_cost_usd", 0.0) or 0.0 _write_result( diff --git a/gremlins/spawn/child.py b/gremlins/spawn/child.py index 36bf86bf..b7334428 100644 --- a/gremlins/spawn/child.py +++ b/gremlins/spawn/child.py @@ -45,7 +45,8 @@ from gremlins import paths from gremlins.clients.client import Client from gremlins.clients.registry import CLIENT_FACTORIES -from gremlins.executor.state import State, StateData, build_state, validate_gremlin_id +from gremlins.executor.state import GremlinWrapper, State, StateData, build_state +from gremlins.executor.state_utils import validate_gremlin_id from gremlins.logging_setup import configure_logging from gremlins.permissions.loader import load_policy from gremlins.permissions.validation import validate_policy_against_registry @@ -182,8 +183,9 @@ async def _run(spec_path: pathlib.Path) -> int: if stage.client is None: stage.client = state.client + gremlin = GremlinWrapper(state) try: - await stage.run(state) + await stage.run(gremlin) except Bail as b: cost = getattr(state.client, "total_cost_usd", 0.0) or 0.0 _write_result( diff --git a/gremlins/spawn/pipeline.py b/gremlins/spawn/pipeline.py index a04e9d10..57dc5817 100644 --- a/gremlins/spawn/pipeline.py +++ b/gremlins/spawn/pipeline.py @@ -12,7 +12,8 @@ import sys import traceback -from gremlins.executor.state import StateData, validate_gremlin_id +from gremlins.executor.state import StateData +from gremlins.executor.state_utils import validate_gremlin_id def main(argv: list[str] | None = None) -> int: diff --git a/gremlins/stages/agent.py b/gremlins/stages/agent.py index d07aad41..e46ecbf3 100644 --- a/gremlins/stages/agent.py +++ b/gremlins/stages/agent.py @@ -7,6 +7,7 @@ from gremlins.artifacts.resolve import resolve_in_map from gremlins.artifacts.uri import Uri from gremlins.executor.state import State +from gremlins.protocols import GremlinProtocol from gremlins.stages.agent_runner import run_agent from gremlins.stages.base import Stage, get_client_from_dict from gremlins.stages.outcome import Bail, Done, Outcome @@ -63,7 +64,8 @@ def with_dict(cls, d: dict[str, Any], depth: int = 0) -> Agent: stage.client = get_client_from_dict(d) return stage - async def run(self, state: State) -> Outcome: + async def run(self, gremlin: GremlinProtocol) -> Outcome: + state = gremlin if isinstance(gremlin, State) else gremlin.state opts = dict(self.options) raw_model = cast(str | None, opts.pop("model", None)) @@ -73,8 +75,8 @@ async def run(self, state: State) -> Outcome: raise Bail(f"agent {self.name}: {exc}") from exc out_map = { - self.substitute_vars(k, state, resolved): self.substitute_vars( - v, state, resolved + self.substitute_vars(k, gremlin, resolved): self.substitute_vars( + v, gremlin, resolved ) for k, v in self.out_map.items() } @@ -83,10 +85,12 @@ async def run(self, state: State) -> Outcome: state.artifacts.bind(key, Uri.parse(uri_str)) template = "\n\n".join(self.prompts).rstrip() - prompt = self.substitute_vars(template, state, resolved) + prompt = self.substitute_vars(template, gremlin, resolved) raw_path = state.artifact_dir / f"stream-{self.name}.jsonl" - model = self.substitute_vars(raw_model, state, resolved) if raw_model else None + model = ( + self.substitute_vars(raw_model, gremlin, resolved) if raw_model else None + ) await run_agent( state, prompt, label=self.name, raw_path=raw_path, model=model, **opts ) diff --git a/gremlins/stages/base.py b/gremlins/stages/base.py index 01ffd8b9..f6e53a3e 100644 --- a/gremlins/stages/base.py +++ b/gremlins/stages/base.py @@ -4,7 +4,6 @@ from typing import Any, NamedTuple from gremlins.clients.client import Client -from gremlins.executor.state import State from gremlins.protocols import GremlinProtocol from gremlins.stages.outcome import Outcome @@ -49,12 +48,15 @@ def __init__(self, name: str) -> None: self.gremlin = None def substitute_vars( - self, text: str, state: State, extra: dict[str, str] | None = None + self, text: str, gremlin: GremlinProtocol, extra: dict[str, str] | None = None ) -> str: """Replace {var} tokens with framework subs, resolved in: vars, and string options (framework wins on conflict). Unknown tokens and non-word braces (shell ${x}, {read:k}, brace expansion) are left as-is.""" + from gremlins.executor.state import State + string_opts = {k: v for k, v in self.options.items() if isinstance(v, str)} + state = gremlin if isinstance(gremlin, State) else gremlin.state subs = {**string_opts, **(extra or {}), **state.framework_subs(self)} # type: ignore[arg-type] return _VAR_SUB.sub(lambda m: subs.get(m.group(1), m.group(0)), text) @@ -78,5 +80,5 @@ def with_dict(cls, d: dict[str, Any], depth: int = 0) -> Stage: def orchestration_args(cls) -> list[StageInput]: return [] - async def run(self, state: State) -> Outcome: # noqa: ARG002 + async def run(self, gremlin: GremlinProtocol) -> Outcome: # noqa: ARG002 raise NotImplementedError diff --git a/gremlins/stages/composite.py b/gremlins/stages/composite.py index fbbddf70..413e6edb 100644 --- a/gremlins/stages/composite.py +++ b/gremlins/stages/composite.py @@ -11,22 +11,26 @@ def child_state( - parent: State, child: Stage, *, fan_out: bool = False, child_id: str | None = None + parent_state: State, + child: Stage, + *, + fan_out: bool = False, + child_id: str | None = None, ) -> State: """Derive a child State from parent.""" - client = parent.client + client = parent_state.client if child.client is not None and child.client != PACKAGE_DEFAULT: client = child.client if not fan_out: - return dataclasses.replace(parent, client=client) + return dataclasses.replace(parent_state, client=client) if child_id: artifact_dir = _paths.state_root() / child_id / "artifacts" artifact_dir.mkdir(parents=True, exist_ok=True) else: - artifact_dir = parent.artifact_dir / child.name + artifact_dir = parent_state.artifact_dir / child.name artifact_dir.mkdir(parents=True, exist_ok=True) return dataclasses.replace( - parent, + parent_state, client=client, artifact_dir=artifact_dir, child_key=child.name, diff --git a/gremlins/stages/exec.py b/gremlins/stages/exec.py index f91b725e..336a2ecd 100644 --- a/gremlins/stages/exec.py +++ b/gremlins/stages/exec.py @@ -14,6 +14,7 @@ from gremlins.artifacts.schemes import snapshot_head_before from gremlins.artifacts.uri import Uri from gremlins.executor.state import State +from gremlins.protocols import GremlinProtocol from gremlins.stages.base import Stage from gremlins.stages.outcome import Bail, Done, Outcome from gremlins.utils import proc as _proc @@ -73,7 +74,10 @@ def with_dict(cls, d: dict[str, Any], depth: int = 0) -> Exec: out_map=dict(cast(dict[str, str], raw_out)), ) - async def run(self, state: State) -> Outcome: + async def run(self, gremlin: GremlinProtocol) -> Outcome: + from gremlins.executor.state import State + + state = gremlin if isinstance(gremlin, State) else gremlin.state try: extra_env = resolve_in_map(state.artifacts, self.in_map) except ValueError as exc: @@ -84,7 +88,7 @@ async def run(self, state: State) -> Outcome: pre_sha = snapshot_head_before(cwd=pathlib.Path(state.cwd)) cmds = [ - self.substitute_vars(c.rstrip(), state, extra_env) + self.substitute_vars(c.rstrip(), gremlin, extra_env) for c in self.options.get("cmds", []) if c.strip() ] @@ -116,7 +120,7 @@ async def run(self, state: State) -> Outcome: raise Bail(f"exec {self.name}: exited {result.returncode}") for raw_key, raw_uri_str in self.out_map.items(): - key = self.substitute_vars(raw_key, state, extra_env) + key = self.substitute_vars(raw_key, gremlin, extra_env) optional = key.endswith("?") if optional: key = key[:-1] @@ -127,7 +131,7 @@ async def run(self, state: State) -> Outcome: continue try: uri_str = self.substitute_vars( - _sub_reads(raw_uri_str, state.artifacts), state, extra_env + _sub_reads(raw_uri_str, state.artifacts), gremlin, extra_env ) except MissingArtifact: if optional: diff --git a/gremlins/stages/loop.py b/gremlins/stages/loop.py index 963a6886..a166f57d 100644 --- a/gremlins/stages/loop.py +++ b/gremlins/stages/loop.py @@ -10,6 +10,7 @@ from gremlins.artifacts.registry import ArtifactRegistry from gremlins.executor.state import State +from gremlins.protocols import GremlinProtocol from gremlins.stages.base import Stage, get_client_from_dict from gremlins.stages.composite import child_state as _child_state from gremlins.stages.outcome import Bail, Done, Outcome @@ -134,12 +135,14 @@ def with_dict(cls, d: dict[str, Any], depth: int = 0) -> LoopStage: stage.client = get_client_from_dict(d) return stage - def _build_runners(self, state: State) -> list[Callable[[], Awaitable[Outcome]]]: + def _build_runners( + self, state: State, gremlin: GremlinProtocol + ) -> list[Callable[[], Awaitable[Outcome]]]: result: list[Callable[[], Awaitable[Outcome]]] = [] for child in self.body: cs = _child_state(state, child) base: Callable[[], Awaitable[Any]] = cs.make_runner( - child, scope=self.body, record_stage=False + child, scope=self.body, record_stage=False, gremlin=gremlin ) name = child.name @@ -155,7 +158,10 @@ async def _tracked( result.append(cast(Callable[[], Awaitable[Outcome]], _tracked)) return result - async def run(self, state: State) -> Outcome: + async def run(self, gremlin: GremlinProtocol) -> Outcome: + from gremlins.executor.state import State + + state = gremlin if isinstance(gremlin, State) else gremlin.state for iteration in range(1, self._max_iterations + 1): state.record_state_field(loop_iteration=iteration) state.artifacts.unbind(_MARKER_KEY) @@ -167,7 +173,7 @@ async def run(self, state: State) -> Outcome: runners = ( self._body_runners if self._body_runners is not None - else self._build_runners(state) + else self._build_runners(state, gremlin) ) had_failure = await _dispatch_runners( runners, iteration, self._max_iterations, state.artifacts diff --git a/gremlins/stages/parallel.py b/gremlins/stages/parallel.py index b1ce7cd3..ac32473a 100644 --- a/gremlins/stages/parallel.py +++ b/gremlins/stages/parallel.py @@ -21,6 +21,7 @@ from gremlins.artifacts.uri import Uri from gremlins.executor.parallel_state import ParallelGroupState from gremlins.executor.state import State, StateData +from gremlins.protocols import GremlinProtocol from gremlins.stages.base import Stage from gremlins.stages.composite import child_state as _child_state from gremlins.stages.outcome import Bail, Done, Outcome @@ -169,7 +170,8 @@ def build_runtime_stages( child_stages=child_stages, ).runtime_stages() - async def run(self, state: State) -> Outcome: + async def run(self, gremlin: GremlinProtocol) -> Outcome: + state = gremlin if isinstance(gremlin, State) else gremlin.state parent_id = state.data.gremlin_id or "" group_state = dataclasses.replace( state, parent_stage=state.parent_stage or self.name @@ -183,7 +185,7 @@ async def run(self, state: State) -> Outcome: cs = _child_state( group_state, child, fan_out=True, child_id=child_id or None ) - runner = cs.make_runner(child, scope=self.body) + runner = cs.make_runner(child, scope=self.body, gremlin=gremlin) child_runners.append((child.name, cs, runner)) for _, fn in self.build_runtime_stages( child_runners, diff --git a/gremlins/stages/sequence.py b/gremlins/stages/sequence.py index b5d1a36a..b4666722 100644 --- a/gremlins/stages/sequence.py +++ b/gremlins/stages/sequence.py @@ -4,7 +4,7 @@ from typing import Any, cast -from gremlins.executor.state import State +from gremlins.protocols import GremlinProtocol from gremlins.stages.base import Stage, get_client_from_dict from gremlins.stages.composite import child_state as _child_state from gremlins.stages.outcome import Done, Outcome @@ -34,7 +34,10 @@ def with_dict(cls, d: dict[str, Any], depth: int = 0) -> SequenceStage: stage.client = get_client_from_dict(d) return stage - async def run(self, state: State) -> Outcome: + async def run(self, gremlin: GremlinProtocol) -> Outcome: + from gremlins.executor.state import State + + state = gremlin if isinstance(gremlin, State) else gremlin.state key = self.path or self.name done = state.done_for(key) for child in self.body: @@ -42,7 +45,7 @@ async def run(self, state: State) -> Outcome: continue state.data.patch(active_children=[child.name]) runner = _child_state(state, child).make_runner( - child, scope=self.body, record_stage=False + child, scope=self.body, record_stage=False, gremlin=gremlin ) try: await runner() diff --git a/tests/test_active_children.py b/tests/test_active_children.py index d5c6b795..f948bc02 100644 --- a/tests/test_active_children.py +++ b/tests/test_active_children.py @@ -10,9 +10,10 @@ import pytest from gremlins.clients.fake import FakeClaudeClient -from gremlins.executor.state import State, StateData, build_state +from gremlins.executor.state import GremlinWrapper, StateData, build_state from gremlins.fleet.render import build_row from gremlins.fleet.views import _gremlin_to_json # type: ignore[reportPrivateUsage] +from gremlins.protocols import GremlinProtocol from gremlins.stages.base import Stage from gremlins.stages.loop import LoopStage from gremlins.stages.outcome import Done, Outcome @@ -20,14 +21,15 @@ from gremlins.stages.sequence import SequenceStage -def _stateful(tmp_path: pathlib.Path, gid: str = "test-id") -> State: +def _stateful(tmp_path: pathlib.Path, gid: str = "test-id") -> GremlinProtocol: sf = tmp_path / "state.json" sf.write_text(json.dumps({"id": gid}), encoding="utf-8") - return build_state( + state = build_state( data=StateData(gremlin_id=gid, state_file=sf), client=FakeClaudeClient(), artifact_dir=tmp_path, ) + return GremlinWrapper(state) def _read_state(tmp_path: pathlib.Path) -> dict[str, Any]: @@ -40,32 +42,32 @@ def _read_state(tmp_path: pathlib.Path) -> dict[str, Any]: def test_sequence_active_children_cleared_after_run(tmp_path: pathlib.Path) -> None: - state = _stateful(tmp_path) + gremlin = _stateful(tmp_path) class _Spy(Stage): captured: list[str] | None = None - async def run(self, state: State) -> Outcome: + async def run(self, gremlin: GremlinProtocol) -> Outcome: # type: ignore[reportUnusedVariable] _Spy.captured = _read_state(tmp_path).get("active_children") return Done() seq = SequenceStage("seq", body=[_Spy("child-a")]) - asyncio.run(seq.run(state)) + asyncio.run(seq.run(gremlin)) assert _Spy.captured == ["child-a"] assert "active_children" not in _read_state(tmp_path) def test_sequence_active_children_cleared_on_exception(tmp_path: pathlib.Path) -> None: - state = _stateful(tmp_path) + gremlin = _stateful(tmp_path) class _Boom(Stage): - async def run(self, state: State) -> Outcome: + async def run(self, gremlin: GremlinProtocol) -> Outcome: # type: ignore[reportUnusedVariable] raise RuntimeError("boom") seq = SequenceStage("seq", body=[_Boom("child-a")]) with pytest.raises(RuntimeError, match="boom"): - asyncio.run(seq.run(state)) + asyncio.run(seq.run(gremlin)) assert "active_children" not in _read_state(tmp_path) @@ -76,31 +78,31 @@ async def run(self, state: State) -> Outcome: def test_loop_active_children_set_and_cleared(tmp_path: pathlib.Path) -> None: - state = _stateful(tmp_path) + gremlin = _stateful(tmp_path) captured: list[list[str] | None] = [] class _Spy(Stage): - async def run(self, state: State) -> Outcome: + async def run(self, gremlin: GremlinProtocol) -> Outcome: # type: ignore[reportUnusedVariable] captured.append(_read_state(tmp_path).get("active_children")) return Done() loop = LoopStage("lp", body=[_Spy("body-stage")], max_iterations=1) - asyncio.run(loop.run(state)) + asyncio.run(loop.run(gremlin)) assert captured == [["body-stage"]] assert "active_children" not in _read_state(tmp_path) def test_loop_active_children_cleared_on_exception(tmp_path: pathlib.Path) -> None: - state = _stateful(tmp_path) + gremlin = _stateful(tmp_path) class _Boom(Stage): - async def run(self, state: State) -> Outcome: + async def run(self, gremlin: GremlinProtocol) -> Outcome: # type: ignore[reportUnusedVariable] raise RuntimeError("boom") loop = LoopStage("lp", body=[_Boom("body-stage")], max_iterations=1) with pytest.raises(RuntimeError, match="boom"): - asyncio.run(loop.run(state)) + asyncio.run(loop.run(gremlin)) assert "active_children" not in _read_state(tmp_path) diff --git a/tests/test_fleet.py b/tests/test_fleet.py index fcb98434..600652bb 100644 --- a/tests/test_fleet.py +++ b/tests/test_fleet.py @@ -935,7 +935,9 @@ class _OkStage: async def run(self, state): pass - monkeypatch.setattr(_state_mod, "build_state", lambda **_: object()) + mock_state = MagicMock() + mock_state.artifacts = MagicMock() + monkeypatch.setattr(_state_mod, "build_state", lambda **_: mock_state) result = _land._exec_land_stage(_OkStage(), MagicMock(), "", tmp_path) assert result is True @@ -950,7 +952,9 @@ class _BailStage: async def run(self, state): raise Bail("structural") - monkeypatch.setattr(_state_mod, "build_state", lambda **_: object()) + mock_state = MagicMock() + mock_state.artifacts = MagicMock() + monkeypatch.setattr(_state_mod, "build_state", lambda **_: mock_state) result = _land._exec_land_stage(_BailStage(), MagicMock(), "", tmp_path) assert result is False assert "structural" in capsys.readouterr().out diff --git a/tests/test_orchestrator_boss.py b/tests/test_orchestrator_boss.py index e70e08ef..6d80db8c 100644 --- a/tests/test_orchestrator_boss.py +++ b/tests/test_orchestrator_boss.py @@ -10,7 +10,7 @@ import pytest from gremlins.clients.fake import FakeClaudeClient -from gremlins.executor.state import StateData, build_state +from gremlins.executor.state import GremlinWrapper, StateData, build_state from gremlins.pipeline import Pipeline from gremlins.stages.outcome import Bail @@ -61,7 +61,7 @@ def _make_loop(tmp_path: pathlib.Path, worktree: pathlib.Path, signal: dict): artifact_dir=artifact_dir, worktree=worktree, ) - return state, loop_stage + return GremlinWrapper(state), state, loop_stage def test_boss_chain_done_exits_loop(sandbox, tmp_path): @@ -71,8 +71,8 @@ def test_boss_chain_done_exits_loop(sandbox, tmp_path): "reason": None, "operator_followups": [], } - state, loop = _make_loop(tmp_path, sandbox.project, signal) - asyncio.run(loop.run(state)) + gremlin, state, loop = _make_loop(tmp_path, sandbox.project, signal) + asyncio.run(loop.run(gremlin)) assert state.artifacts.read("status") == "pass" @@ -87,9 +87,9 @@ def test_boss_next_plan_needs_fix_and_plan_swap(sandbox, tmp_path): "reason": None, "operator_followups": [], } - state, loop = _make_loop(tmp_path, sandbox.project, signal) + gremlin, state, loop = _make_loop(tmp_path, sandbox.project, signal) with pytest.raises(Bail): - asyncio.run(loop.run(state)) + asyncio.run(loop.run(gremlin)) assert state.artifacts.read("status") == "needs_fix" assert (artifact_dir / "plan.md").read_text(encoding="utf-8") == "# Next\n" @@ -101,7 +101,7 @@ def test_boss_bail_raises_with_reason(sandbox, tmp_path): "reason": "bad state", "operator_followups": [], } - state, loop = _make_loop(tmp_path, sandbox.project, signal) + gremlin, state, loop = _make_loop(tmp_path, sandbox.project, signal) with pytest.raises(Bail, match="bad state"): - asyncio.run(loop.run(state)) + asyncio.run(loop.run(gremlin)) assert state.artifacts.produced("bail") diff --git a/tests/test_parallel_runner.py b/tests/test_parallel_runner.py index 6f2f0b70..06dc6544 100644 --- a/tests/test_parallel_runner.py +++ b/tests/test_parallel_runner.py @@ -339,6 +339,8 @@ async def noop() -> None: def test_parallel_sequence_child_worktree_flows() -> None: """SequenceStage inside a parallel group sees the fanout worktree in all sub-stages.""" + from gremlins.executor.state import GremlinWrapper + from gremlins.protocols import GremlinProtocol from gremlins.stages.base import Stage from gremlins.stages.outcome import Done, Outcome from gremlins.stages.sequence import SequenceStage @@ -349,8 +351,8 @@ class _CaptureStage(Stage): def __init__(self, name: str) -> None: super().__init__(name) - async def run(self, state: State) -> Outcome: - observed.append(state.worktree) + async def run(self, gremlin: GremlinProtocol) -> Outcome: + observed.append(gremlin.state.worktree) return Done() seq_stage = SequenceStage("seq", body=[_CaptureStage("a"), _CaptureStage("b")]) @@ -363,7 +365,8 @@ async def run(self, state: State) -> Outcome: ) async def seq_runner() -> None: - await seq_stage.run(seq_ctx) + gremlin = GremlinWrapper(seq_ctx) + await seq_stage.run(gremlin) project_root = pathlib.Path.cwd() stages = _make_parallel_stages( @@ -401,23 +404,28 @@ async def async_fn() -> None: def test_make_runner_returns_async_for_any_stage() -> None: + from gremlins.executor.state import GremlinWrapper + from gremlins.protocols import GremlinProtocol from gremlins.stages.base import Stage from gremlins.stages.outcome import Done, Outcome class AStage(Stage): type = "a-test" - async def run(self, state: State) -> Outcome: + async def run(self, gremlin: GremlinProtocol) -> Outcome: return Done() state = build_state( data=StateData(), client=FakeClaudeClient(), artifact_dir=pathlib.Path("/tmp") ) - runner = state.make_runner(AStage("a")) + gremlin = GremlinWrapper(state) + runner = state.make_runner(AStage("a"), gremlin=gremlin) assert inspect.iscoroutinefunction(runner) def test_stages_run_in_order_via_make_runner() -> None: + from gremlins.executor.state import GremlinWrapper + from gremlins.protocols import GremlinProtocol from gremlins.stages.base import Stage from gremlins.stages.outcome import Done, Outcome @@ -426,23 +434,24 @@ def test_stages_run_in_order_via_make_runner() -> None: class StageA(Stage): type = "stage-a" - async def run(self, state: State) -> Outcome: + async def run(self, gremlin: GremlinProtocol) -> Outcome: executed.append("a") return Done() class StageB(Stage): type = "stage-b" - async def run(self, state: State) -> Outcome: + async def run(self, gremlin: GremlinProtocol) -> Outcome: executed.append("b") return Done() base_state = build_state( data=StateData(), client=FakeClaudeClient(), artifact_dir=pathlib.Path("/tmp") ) + gremlin = GremlinWrapper(base_state) stages = [ - ("a", base_state.make_runner(StageA("a"))), - ("b", base_state.make_runner(StageB("b"))), + ("a", base_state.make_runner(StageA("a"), gremlin=gremlin)), + ("b", base_state.make_runner(StageB("b"), gremlin=gremlin)), ] asyncio.run(run_stages(stages)) assert executed == ["a", "b"] diff --git a/tests/test_skip_if_exists.py b/tests/test_skip_if_exists.py index d1441422..ab720b51 100644 --- a/tests/test_skip_if_exists.py +++ b/tests/test_skip_if_exists.py @@ -9,8 +9,9 @@ from gremlins.artifacts.registry import ArtifactRegistry from gremlins.artifacts.uri import Uri from gremlins.clients.fake import FakeClaudeClient -from gremlins.executor.state import State, StateData, build_state +from gremlins.executor.state import GremlinWrapper, State, StateData, build_state from gremlins.pipeline import Pipeline +from gremlins.protocols import GremlinProtocol from gremlins.stages.agent import Agent from gremlins.stages.base import Stage from gremlins.stages.outcome import Done, Outcome @@ -29,7 +30,7 @@ def __init__(self, name: str, prompts: list[str], options: dict[str, Any]) -> No super().__init__(name) self.run_count = 0 - async def run(self, state: State) -> Outcome: + async def run(self, gremlin: GremlinProtocol) -> Outcome: self.run_count += 1 return Done() @@ -55,7 +56,8 @@ def test_skip_if_exists_skips_when_key_produced(tmp_path: pathlib.Path) -> None: stage = _CountingStage("s", [], {}) stage.skip_if_exists = "my-artifact" - runner = state.make_runner(stage, record_stage=False) + gremlin = GremlinWrapper(state) + runner = state.make_runner(stage, record_stage=False, gremlin=gremlin) asyncio.run(runner()) assert stage.run_count == 0 @@ -67,7 +69,8 @@ def test_skip_if_exists_runs_when_key_absent(tmp_path: pathlib.Path) -> None: stage = _CountingStage("s", [], {}) stage.skip_if_exists = "my-artifact" - runner = state.make_runner(stage, record_stage=False) + gremlin = GremlinWrapper(state) + runner = state.make_runner(stage, record_stage=False, gremlin=gremlin) asyncio.run(runner()) assert stage.run_count == 1 @@ -80,7 +83,8 @@ def test_no_skip_if_exists_always_runs(tmp_path: pathlib.Path) -> None: stage = _CountingStage("s", [], {}) # skip_if_exists is "" by default — should not skip even when key is produced - runner = state.make_runner(stage, record_stage=False) + gremlin = GremlinWrapper(state) + runner = state.make_runner(stage, record_stage=False, gremlin=gremlin) asyncio.run(runner()) assert stage.run_count == 1 diff --git a/tests/test_stage_loop.py b/tests/test_stage_loop.py index 1586d6c7..491d54b0 100644 --- a/tests/test_stage_loop.py +++ b/tests/test_stage_loop.py @@ -9,8 +9,8 @@ import pytest from gremlins.artifacts.uri import Uri +from gremlins.executor.state import GremlinWrapper, StateData, build_state from gremlins.executor.state import State as RuntimeState -from gremlins.executor.state import StateData, build_state from gremlins.stages.exec import Exec as Cmd from gremlins.stages.loop import LoopStage, head_stable, max_iters from gremlins.stages.outcome import Bail, Done @@ -50,7 +50,8 @@ async def runner() -> Done: return Done() loop = LoopStage("loop", body_runners=[runner], max_iterations=3) - outcome = asyncio.run(loop.run(_loop_state(tmp_path))) + state = _loop_state(tmp_path) + outcome = asyncio.run(loop.run(GremlinWrapper(state))) assert outcome == Done() assert calls == ["run"] @@ -72,7 +73,7 @@ async def fix() -> Done: return Done() loop = LoopStage("loop", body_runners=[check, fix], max_iterations=3) - asyncio.run(loop.run(loop_state)) + asyncio.run(loop.run(GremlinWrapper(loop_state))) assert attempt["attempt"] == 2 assert attempt["fixed"] @@ -90,7 +91,8 @@ async def fix() -> Done: return Done() loop = LoopStage("loop", body_runners=[check, fix], max_iterations=3) - asyncio.run(loop.run(_loop_state(tmp_path))) + state = _loop_state(tmp_path) + asyncio.run(loop.run(GremlinWrapper(state))) assert fix_calls == [] @@ -107,7 +109,7 @@ async def fix() -> Done: loop = LoopStage("loop", body_runners=[check, fix], max_iterations=3) with pytest.raises(Bail): - asyncio.run(loop.run(loop_state)) + asyncio.run(loop.run(GremlinWrapper(loop_state))) def test_loop_fix_skipped_on_final_iteration(tmp_path): @@ -127,7 +129,7 @@ async def fix() -> Done: loop = LoopStage("loop", body_runners=[check, fix], max_iterations=3) with pytest.raises(Bail): - asyncio.run(loop.run(loop_state)) + asyncio.run(loop.run(GremlinWrapper(loop_state))) # fix ran for iterations 1 and 2, NOT 3 assert fix_calls == [1, 2] @@ -140,7 +142,8 @@ async def bail_runner() -> Done: loop = LoopStage("loop", body_runners=[bail_runner], max_iterations=3) with pytest.raises(Bail) as exc_info: - asyncio.run(loop.run(_loop_state(tmp_path))) + state = _loop_state(tmp_path) + asyncio.run(loop.run(GremlinWrapper(state))) assert "bail_class=other" in exc_info.value.reason @@ -169,7 +172,7 @@ async def fix() -> Done: loop = LoopStage("loop", body_runners=[check, fix], max_iterations=2) with pytest.raises(Bail): - asyncio.run(loop.run(loop_state)) + asyncio.run(loop.run(GremlinWrapper(loop_state))) bail_file = state_dir / f"bail_{attempt}.json" assert bail_file.exists() @@ -214,7 +217,8 @@ async def runner() -> Done: max_iterations=5, until=max_iters(2), ) - result = asyncio.run(loop.run(_loop_state(tmp_path))) + state = _loop_state(tmp_path) + result = asyncio.run(loop.run(GremlinWrapper(state))) assert result == Done() @@ -299,7 +303,7 @@ async def runner() -> Done: loop = LoopStage("loop", body_runners=[runner], max_iterations=3) with pytest.raises(Bail): - asyncio.run(loop.run(loop_state)) + asyncio.run(loop.run(GremlinWrapper(loop_state))) assert seen_iterations == [1, 2, 3] @@ -329,7 +333,7 @@ async def binder() -> Done: body_runners=[binder], max_iterations=3, ) - asyncio.run(loop.run(state)) + asyncio.run(loop.run(GremlinWrapper(state))) assert bound_count[0] == 2 @@ -358,7 +362,7 @@ async def runner() -> Done: return Done() loop = LoopStage("loop", body_runners=[runner], max_iterations=3, interval=5.0) - asyncio.run(loop.run(loop_state)) + asyncio.run(loop.run(GremlinWrapper(loop_state))) assert count[0] == 2 assert sleep_calls == [5.0] @@ -378,6 +382,7 @@ async def runner() -> Done: return Done() loop = LoopStage("loop", body_runners=[runner], max_iterations=3) - asyncio.run(loop.run(_loop_state(tmp_path))) + state = _loop_state(tmp_path) + asyncio.run(loop.run(GremlinWrapper(state))) assert sleep_calls == []