Skip to content
Closed
2 changes: 1 addition & 1 deletion gremlins/cli/resume.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import argparse
import sys

from gremlins.executor.state import validate_gremlin_id
from gremlins.executor.state_utils import validate_gremlin_id
from gremlins.launcher import resume


Expand Down
6 changes: 5 additions & 1 deletion gremlins/executor/gremlin.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ async def run_stages(

class Gremlin:
registry: ArtifactRegistry
state: State | None

def __init__(
self,
Expand Down Expand Up @@ -150,6 +151,7 @@ def __init__(
self.fetch_worktree = fetch_worktree
self.pipeline_path = pipeline_path
self.pipeline_args = pipeline_args or []
self.state = None

@property
def artifact_dir(self) -> pathlib.Path:
Expand Down Expand Up @@ -300,7 +302,9 @@ def _collect_stages(
artifacts=self.registry,
base_ref=self.base_ref,
)
built.append((e.name, stage_state.make_runner(e, scope=stages)))
built.append(
(e.name, stage_state.make_runner(e, scope=stages, gremlin=self))
)
return built

def _unbind_stale_exec_artifacts(self) -> None:
Expand Down
3 changes: 2 additions & 1 deletion gremlins/executor/parallel_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
import pathlib
from typing import Any

from gremlins.executor.state import StateData, resolve_state_file
from gremlins.executor.state import StateData
from gremlins.executor.state_utils import resolve_state_file

logger = logging.getLogger(__name__)

Expand Down
4 changes: 2 additions & 2 deletions gremlins/executor/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
from gremlins.env_file import load_env_file
from gremlins.errors import die
from gremlins.executor.gremlin import Gremlin
from gremlins.executor.state import (
StateData,
from gremlins.executor.state import StateData
from gremlins.executor.state_utils import (
resolve_artifact_dir,
resolve_state_file,
)
Expand Down
79 changes: 36 additions & 43 deletions gremlins/executor/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,58 +10,32 @@
import math
import os
import pathlib
import re
import secrets
from collections.abc import Callable, Sequence
from typing import TYPE_CHECKING, Any, ClassVar, cast

from gremlins import paths as _paths
from gremlins.artifacts.registry import ArtifactRegistry
from gremlins.clients.client import Client
from gremlins.executor.state_utils import (
resolve_state_file,
)
from gremlins.utils.state_file import locked_update

if TYPE_CHECKING:
from gremlins.pipeline import Pipeline

from gremlins.protocols import StageProtocol
from gremlins.protocols import GremlinProtocol, StageProtocol
from gremlins.stages.outcome import Done

logger = logging.getLogger(__name__)

_GREMLIN_ID_RE = re.compile(r"^[A-Za-z0-9_-]+$")

BAIL_CLASS_REVIEWER_REQUESTED_CHANGES = "reviewer_requested_changes"
BAIL_CLASS_SECURITY = "security"
BAIL_CLASS_SECRETS = "secrets"
BAIL_CLASS_OTHER = "other"


def validate_gremlin_id(gremlin_id: str) -> None:
    """Raise ValueError if gremlin_id is not a safe, non-path-traversing identifier.

    Only ASCII letters, digits, ``_`` and ``-`` are accepted, and the id must
    contain at least one character.

    Raises:
        ValueError: if ``gremlin_id`` is empty or contains any other character
            (including a trailing newline).
    """
    # fullmatch() instead of match(): with match() the pattern's trailing "$"
    # also succeeds immediately before a final "\n", so an id such as "abc\n"
    # used to be accepted. fullmatch() requires the whole string to match.
    # The explicit ".." test is redundant with the character class but kept as
    # a cheap defence-in-depth guard.
    if ".." in gremlin_id or not _GREMLIN_ID_RE.fullmatch(gremlin_id):
        raise ValueError(f"gremlin_id contains illegal characters: {gremlin_id!r}")


def resolve_state_file(gremlin_id: str | None) -> pathlib.Path | None:
    """Locate the ``state.json`` file that belongs to *gremlin_id*.

    A missing id (``None`` or an empty string) yields ``None``; otherwise the
    result is ``<state_root>/<gremlin_id>/state.json``.
    """
    if gremlin_id:
        return _paths.state_root() / gremlin_id / "state.json"
    return None


def resolve_artifact_dir(gremlin_id: str | None = None) -> pathlib.Path:
    """Return the artifacts directory for this run, creating it if missing.

    With a gremlin id the directory is ``<state_root>/<gremlin_id>/artifacts``.
    Without one, a fresh ``<state_root>/direct/<timestamp>-<hex>/artifacts``
    directory is used: the timestamp is the local time formatted as
    ``YYYYmmdd-HHMMSS`` and the suffix is six random hex characters.
    """
    root = _paths.state_root()
    if gremlin_id:
        run_dir = root / gremlin_id
    else:
        # No id supplied: derive a unique directory name for this direct run.
        stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
        suffix = secrets.token_hex(3)
        run_dir = root / "direct" / f"{stamp}-{suffix}"
    target = run_dir / "artifacts"
    target.mkdir(parents=True, exist_ok=True)
    return target


def write_state(state_dir: pathlib.Path, data: dict[str, Any]) -> None:
"""Atomically overwrite state.json (no merge)."""
sf = state_dir / "state.json"
Expand All @@ -70,18 +44,6 @@ def write_state(state_dir: pathlib.Path, data: dict[str, Any]) -> None:
os.replace(tmp, sf)


def landable_shape(state: dict[str, Any]) -> str:
    """Classify a state dict by how many ``"pr"`` artifacts it lists.

    Returns ``"empty"`` when no artifact has ``type == "pr"`` (including when
    the ``"artifacts"`` key is missing or falsy), ``"one_pr"`` for exactly
    one, and ``"many_prs"`` for two or more.
    """
    entries = state.get("artifacts") or []
    pr_count = sum(1 for entry in entries if entry.get("type") == "pr")
    if pr_count == 0:
        return "empty"
    return "one_pr" if pr_count == 1 else "many_prs"


def _stage_list() -> list[StageProtocol]:
return []

Expand Down Expand Up @@ -487,6 +449,10 @@ class State:
}
)

@property
def registry(self) -> ArtifactRegistry:
    """Return ``self.artifacts`` under the name ``registry``."""
    # NOTE(review): presumably provided so a State exposes the same
    # ``registry`` attribute that GremlinProtocol declares — confirm.
    return self.artifacts

def framework_subs(self, stage: StageProtocol) -> dict[str, str]:
"""Runtime-owned substitution vars. Stages must not assemble these themselves."""
return {
Expand Down Expand Up @@ -548,6 +514,7 @@ def make_runner(
scope: Sequence[StageProtocol] | None = None,
*,
record_stage: bool = True,
gremlin: GremlinProtocol,
) -> Callable[[], Any]:
base_state = self
gremlin_id = self.data.gremlin_id
Expand Down Expand Up @@ -578,11 +545,37 @@ async def _run_async() -> Any:
entry.skip_if_exists
):
return Done()
return await entry.run(_prepare())
prepared_state = _prepare()
gremlin.state = prepared_state
return await entry.run(gremlin)

return _run_async


class GremlinWrapper:
    """Minimal Gremlin-like wrapper for subprocess contexts.

    Wraps an already-built state object and exposes ``state``, ``registry``
    and ``data`` plus a ``fork`` coroutine. Forking is not available here and
    always raises.
    """

    def __init__(self, state: State) -> None:
        # ``registry`` is taken once from the state's artifact registry at
        # construction time; it is not re-read if ``state`` is later replaced.
        self.registry = state.artifacts
        self.state = state

    @property
    def data(self) -> StateData:
        """The ``data`` attribute of the currently wrapped state."""
        wrapped = self.state
        return wrapped.data

    async def fork(
        self,
        state: State,
        target_id: str,
        *,
        parent_id: str = "",
        group_name: str = "",
        child_key: str = "",
        pipeline: Any | None = None,
    ) -> State:
        """Always raise ``NotImplementedError``; arguments are ignored."""
        message = "fork not supported in subprocess context"
        raise NotImplementedError(message)


def build_state(
data: StateData,
client: Client,
Expand Down
49 changes: 49 additions & 0 deletions gremlins/executor/state_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from __future__ import annotations

import datetime
import pathlib
import re
import secrets
from typing import Any

from gremlins import paths as _paths

_GREMLIN_ID_RE = re.compile(r"^[A-Za-z0-9_-]+$")


def validate_gremlin_id(gremlin_id: str) -> None:
    """Raise ValueError if gremlin_id is not a safe, non-path-traversing identifier.

    Only ASCII letters, digits, ``_`` and ``-`` are accepted, and the id must
    contain at least one character.

    Raises:
        ValueError: if ``gremlin_id`` is empty or contains any other character
            (including a trailing newline).
    """
    # fullmatch() instead of match(): with match() the pattern's trailing "$"
    # also succeeds immediately before a final "\n", so an id such as "abc\n"
    # used to be accepted. fullmatch() requires the whole string to match.
    # The explicit ".." test is redundant with the character class but kept as
    # a cheap defence-in-depth guard.
    if ".." in gremlin_id or not _GREMLIN_ID_RE.fullmatch(gremlin_id):
        raise ValueError(f"gremlin_id contains illegal characters: {gremlin_id!r}")


def resolve_state_file(gremlin_id: str | None) -> pathlib.Path | None:
    """Locate the ``state.json`` file that belongs to *gremlin_id*.

    A missing id (``None`` or an empty string) yields ``None``; otherwise the
    result is ``<state_root>/<gremlin_id>/state.json``.
    """
    if gremlin_id:
        return _paths.state_root() / gremlin_id / "state.json"
    return None


def resolve_artifact_dir(gremlin_id: str | None = None) -> pathlib.Path:
    """Return the artifacts directory for this run, creating it if missing.

    With a gremlin id the directory is ``<state_root>/<gremlin_id>/artifacts``.
    Without one, a fresh ``<state_root>/direct/<timestamp>-<hex>/artifacts``
    directory is used: the timestamp is the local time formatted as
    ``YYYYmmdd-HHMMSS`` and the suffix is six random hex characters.
    """
    root = _paths.state_root()
    if gremlin_id:
        run_dir = root / gremlin_id
    else:
        # No id supplied: derive a unique directory name for this direct run.
        stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
        suffix = secrets.token_hex(3)
        run_dir = root / "direct" / f"{stamp}-{suffix}"
    target = run_dir / "artifacts"
    target.mkdir(parents=True, exist_ok=True)
    return target


def landable_shape(state: dict[str, Any]) -> str:
    """Classify a state dict by how many ``"pr"`` artifacts it lists.

    Returns ``"empty"`` when no artifact has ``type == "pr"`` (including when
    the ``"artifacts"`` key is missing or falsy), ``"one_pr"`` for exactly
    one, and ``"many_prs"`` for two or more.
    """
    entries = state.get("artifacts") or []
    pr_count = sum(1 for entry in entries if entry.get("type") == "pr")
    if pr_count == 0:
        return "empty"
    return "one_pr" if pr_count == 1 else "many_prs"
8 changes: 5 additions & 3 deletions gremlins/fleet/land.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from gremlins import paths
from gremlins.artifacts.registry import ArtifactRegistry, MissingArtifact
from gremlins.artifacts.resolve import resolve_in_map
from gremlins.executor.state import landable_shape, resolve_artifact_dir
from gremlins.executor.state_utils import landable_shape, resolve_artifact_dir
from gremlins.fleet.resolve import resolve_gremlin
from gremlins.fleet.state import (
liveness_of_state_file,
Expand Down Expand Up @@ -871,7 +871,7 @@ def _exec_land_stage(
import asyncio

from gremlins.clients.client import PACKAGE_DEFAULT
from gremlins.executor.state import StateData, build_state
from gremlins.executor.state import GremlinWrapper, StateData, build_state
from gremlins.stages.outcome import Bail

state = build_state(
Expand All @@ -881,8 +881,10 @@ def _exec_land_stage(
cwd=cwd,
artifacts=registry,
)

wrapper = GremlinWrapper(state)
try:
asyncio.run(land_stage.run(state))
asyncio.run(land_stage.run(wrapper))
return True
except Bail as b:
print(f"error: land: {b.reason}")
Expand Down
3 changes: 2 additions & 1 deletion gremlins/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@
from gremlins.artifacts.uri import Uri
from gremlins.clients.client import PACKAGE_DEFAULT
from gremlins.executor.gremlin import Gremlin
from gremlins.executor.state import StateData, validate_gremlin_id
from gremlins.executor.state import StateData
from gremlins.executor.state_utils import validate_gremlin_id
from gremlins.pipeline import Pipeline as _PipelineData
from gremlins.pipeline.discovery import list_pipelines, resolve_pipeline_path
from gremlins.utils import git as _git_mod
Expand Down
5 changes: 3 additions & 2 deletions gremlins/protocols.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ class GremlinProtocol(Protocol):
"""What stages need from a Gremlin."""

registry: Any
state: Any

async def fork(
self,
Expand Down Expand Up @@ -36,6 +37,6 @@ class StageProtocol(Protocol):
client: Any
skip_if_exists: str

async def run(self, state: Any) -> Any:
"""Run this stage with the given execution state."""
async def run(self, gremlin: GremlinProtocol) -> Any:
"""Run this stage with the given gremlin."""
...
6 changes: 4 additions & 2 deletions gremlins/run_child.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@
from gremlins import paths
from gremlins.clients.client import Client
from gremlins.clients.registry import CLIENT_FACTORIES
from gremlins.executor.state import State, StateData, build_state, validate_gremlin_id
from gremlins.executor.state import GremlinWrapper, State, StateData, build_state
from gremlins.executor.state_utils import validate_gremlin_id
from gremlins.permissions.loader import load_policy
from gremlins.permissions.validation import validate_policy_against_registry
from gremlins.pipeline import Pipeline
Expand Down Expand Up @@ -164,8 +165,9 @@ async def _run(spec_path: pathlib.Path) -> int:
if stage.client is None:
stage.client = state.client

gremlin = GremlinWrapper(state)
try:
await stage.run(state)
await stage.run(gremlin)
except Bail as b:
cost = getattr(state.client, "total_cost_usd", 0.0) or 0.0
_write_result(
Expand Down
6 changes: 4 additions & 2 deletions gremlins/spawn/child.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@
from gremlins import paths
from gremlins.clients.client import Client
from gremlins.clients.registry import CLIENT_FACTORIES
from gremlins.executor.state import State, StateData, build_state, validate_gremlin_id
from gremlins.executor.state import GremlinWrapper, State, StateData, build_state
from gremlins.executor.state_utils import validate_gremlin_id
from gremlins.logging_setup import configure_logging
from gremlins.permissions.loader import load_policy
from gremlins.permissions.validation import validate_policy_against_registry
Expand Down Expand Up @@ -182,8 +183,9 @@ async def _run(spec_path: pathlib.Path) -> int:
if stage.client is None:
stage.client = state.client

gremlin = GremlinWrapper(state)
try:
await stage.run(state)
await stage.run(gremlin)
except Bail as b:
cost = getattr(state.client, "total_cost_usd", 0.0) or 0.0
_write_result(
Expand Down
3 changes: 2 additions & 1 deletion gremlins/spawn/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
import sys
import traceback

from gremlins.executor.state import StateData, validate_gremlin_id
from gremlins.executor.state import StateData
from gremlins.executor.state_utils import validate_gremlin_id


def main(argv: list[str] | None = None) -> int:
Expand Down
14 changes: 9 additions & 5 deletions gremlins/stages/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from gremlins.artifacts.resolve import resolve_in_map
from gremlins.artifacts.uri import Uri
from gremlins.executor.state import State
from gremlins.protocols import GremlinProtocol
from gremlins.stages.agent_runner import run_agent
from gremlins.stages.base import Stage, get_client_from_dict
from gremlins.stages.outcome import Bail, Done, Outcome
Comment on lines 7 to 13
Copy link
Copy Markdown
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as #3350576190 — part of larger refactoring to eliminate State imports from stage modules.

Expand Down Expand Up @@ -63,7 +64,8 @@ def with_dict(cls, d: dict[str, Any], depth: int = 0) -> Agent:
stage.client = get_client_from_dict(d)
return stage

async def run(self, state: State) -> Outcome:
async def run(self, gremlin: GremlinProtocol) -> Outcome:
state = gremlin if isinstance(gremlin, State) else gremlin.state
opts = dict(self.options)
raw_model = cast(str | None, opts.pop("model", None))

Expand All @@ -73,8 +75,8 @@ async def run(self, state: State) -> Outcome:
raise Bail(f"agent {self.name}: {exc}") from exc

out_map = {
self.substitute_vars(k, state, resolved): self.substitute_vars(
v, state, resolved
self.substitute_vars(k, gremlin, resolved): self.substitute_vars(
v, gremlin, resolved
)
for k, v in self.out_map.items()
}
Expand All @@ -83,10 +85,12 @@ async def run(self, state: State) -> Outcome:
state.artifacts.bind(key, Uri.parse(uri_str))

template = "\n\n".join(self.prompts).rstrip()
prompt = self.substitute_vars(template, state, resolved)
prompt = self.substitute_vars(template, gremlin, resolved)

raw_path = state.artifact_dir / f"stream-{self.name}.jsonl"
model = self.substitute_vars(raw_model, state, resolved) if raw_model else None
model = (
self.substitute_vars(raw_model, gremlin, resolved) if raw_model else None
)
await run_agent(
state, prompt, label=self.name, raw_path=raw_path, model=model, **opts
)
Expand Down
Loading
Loading