diff --git a/gremlins/executor/gremlin.py b/gremlins/executor/gremlin.py index 16f60f30..8efeebd9 100644 --- a/gremlins/executor/gremlin.py +++ b/gremlins/executor/gremlin.py @@ -101,6 +101,7 @@ async def run_stages( class Gremlin: registry: ArtifactRegistry + state: State | None def __init__( self, @@ -139,17 +140,18 @@ def __init__( self.state_dir = state_dir self.gremlin_id = gremlin_id self.pipeline_data = pipeline_data - self.worktree_dir = worktree_dir + self._worktree_dir = worktree_dir self.worktree_parent = worktree_parent self.resume_from = resume_from self.repo = repo self.state_file = state_file self.project_root = project_root self.base_ref_sha = base_ref_sha - self.base_ref = base_ref + self._base_ref_init = base_ref self.fetch_worktree = fetch_worktree self.pipeline_path = pipeline_path self.pipeline_args = pipeline_args or [] + self.state = None @property def artifact_dir(self) -> pathlib.Path: @@ -163,6 +165,63 @@ def state_data(self) -> StateData: def finished(self) -> bool: return (self.state_dir / "finished").is_file() + @property + def _s(self) -> State: + if self.state is None: + raise RuntimeError("state not yet initialized") + return self.state + + @property + def client(self) -> Client: + return self._s.client + + @property + def artifacts(self) -> ArtifactRegistry: + return self._s.artifacts + + @property + def cwd(self) -> str: + return self._s.cwd + + @property + def worktree(self) -> pathlib.Path | None: + return self._s.worktree + + @property + def base_ref(self) -> str: + return self._s.base_ref + + @property + def loop_iteration(self) -> int: + return self._s.data.loop_iteration + + @property + def attempt(self) -> str: + return self._s.data.attempt + + def framework_subs(self, stage: StageProtocol) -> dict[str, str]: + return self._s.framework_subs(stage) + + def done_for(self, path: str) -> set[str]: + return self._s.done_for(path) + + def mark_done(self, path: str, child_name: str) -> None: + self._s.mark_done(path, child_name) + + def clear_done(self, path: str) -> None: + self._s.clear_done(path) + + def record_bail(self, reason: str, *, kind: str = "other") -> None: + self._s.record_bail(reason, kind=kind) + + def record_stage_progress( + self, name: str, sub_stage: object = None, *, parent_stage: str = "" + ) -> None: + self._s.record_stage_progress(name, sub_stage, parent_stage=parent_stage) + + def record_state_field(self, **fields: Any) -> None: + self._s.record_state_field(**fields) + async def fork( self, state: State, @@ -270,6 +329,14 @@ def validate_resume_target(self) -> None: f"valid: {valid_names}" ) + @property + def _resolved_cwd(self) -> str: + return ( + str(self._worktree_dir) + if self._worktree_dir is not None + else (self.project_root or str(pathlib.Path.cwd())) + ) + def _set_gremlin_recursive(self, stage: StageProtocol) -> None: stage.gremlin = self body = getattr(stage, "body", []) @@ -279,11 +346,6 @@ def _set_gremlin_recursive(self, stage: StageProtocol) -> None: def _collect_stages( self, stages: Sequence[StageProtocol] ) -> list[tuple[str, Callable[[], Awaitable[Any]]]]: - cwd = ( - str(self.worktree_dir) - if self.worktree_dir is not None - else (self.project_root or str(pathlib.Path.cwd())) - ) built: list[tuple[str, Callable[[], Awaitable[Any]]]] = [] for e in stages: self._set_gremlin_recursive(e) @@ -294,11 +356,11 @@ def _collect_stages( artifact_dir=self.artifact_dir, pipeline_data=self.pipeline_data, repo=self.repo, - cwd=cwd, - worktree=self.worktree_dir, + cwd=self._resolved_cwd, + worktree=self._worktree_dir, worktree_parent=self.worktree_parent, artifacts=self.registry, - base_ref=self.base_ref, + base_ref=self._base_ref_init, ) built.append((e.name, stage_state.make_runner(e, scope=stages))) return built @@ -479,7 +541,7 @@ def initialize_with_runtime( worktree_created: str | None = None try: - if self.worktree_dir is None and self.project_root and self.gremlin_id: + if self._worktree_dir is None and self.project_root and self.gremlin_id: workdir = _git_mod.setup_workdir( self.project_root, self.base_ref_sha, @@ -488,7 +550,7 @@ def initialize_with_runtime( worktree_parent=self.worktree_parent, ) worktree_created = workdir - self.worktree_dir = pathlib.Path(workdir) + self._worktree_dir = pathlib.Path(workdir) st = StateData.load(self.gremlin_id) st.patch( workdir=workdir, @@ -496,12 +558,12 @@ def initialize_with_runtime( setup_kind="worktree-detached", ) - if self.worktree_dir is not None: - os.chdir(self.worktree_dir) + if self._worktree_dir is not None: + os.chdir(self._worktree_dir) self.registry = ArtifactRegistry( artifact_dir=self.artifact_dir, - cwd=self.worktree_dir, + cwd=self._worktree_dir, ) for key, value in (stage_inputs or {}).items(): if value is not None and not self.registry.produced(key): @@ -509,9 +571,22 @@ def initialize_with_runtime( if not self.registry.produced("spec"): self.registry.bind("spec", Uri.parse("file://session/spec.md")) if not self.registry.produced("base_sha"): - sha = _git_mod.head_sha(cwd=self.worktree_dir) + sha = _git_mod.head_sha(cwd=self._worktree_dir) if sha: self.registry.bind("base_sha", Uri.parse(f"git://commit/{sha}")) + + self.state = build_state( + data=StateData(gremlin_id=self.gremlin_id, state_file=None), + client=resolved_client or PACKAGE_DEFAULT, + artifact_dir=self.artifact_dir, + pipeline_data=self.pipeline_data, + repo=self.repo, + cwd=self._resolved_cwd, + worktree=self._worktree_dir, + worktree_parent=self.worktree_parent, + artifacts=self.registry, + base_ref=base_ref, + ) except Exception: if worktree_created: _git_mod.remove_worktree(self.project_root, worktree_created) diff --git a/gremlins/executor/run.py b/gremlins/executor/run.py index bc84015f..72de644b 100644 --- a/gremlins/executor/run.py +++ b/gremlins/executor/run.py @@ -261,7 +261,7 @@ async def run_pipeline( _env_file = paths.project_overlay_dir(_project_root) / "env" if _env_file.is_file(): os.environ["GREMLINS_WORKTREE_PATH"] = ( - str(gremlin.worktree_dir) if gremlin.worktree_dir else "" + str(gremlin.worktree) if gremlin.worktree else "" ) os.environ["GREMLINS_ARTIFACT_DIR"] = str(gremlin.artifact_dir) try: diff --git a/tests/test_gremlin_open.py b/tests/test_gremlin_open.py index ee3ac00f..54f67f2b 100644 --- a/tests/test_gremlin_open.py +++ b/tests/test_gremlin_open.py @@ -85,7 +85,6 @@ def test_gremlin_open_valid_state(sandbox, project_dir, pipeline_yaml): assert gremlin.gremlin_id == gremlin_id assert gremlin.project_root == str(project_dir) - assert gremlin.worktree_dir == pathlib.Path("/tmp/worktree") assert gremlin.pipeline_data is not None diff --git a/tests/test_gremlin_smoke.py b/tests/test_gremlin_smoke.py index 55269c89..bccbc108 100644 --- a/tests/test_gremlin_smoke.py +++ b/tests/test_gremlin_smoke.py @@ -5,6 +5,7 @@ import asyncio import json import os +import pathlib import shutil import subprocess @@ -83,7 +84,7 @@ def test_gremlin_run_in_process(project_dir, pipeline_yaml, sandbox): pipeline_ref=str(pipeline_yaml), project_root=str(project_dir), ) - worktree = gremlin.worktree_dir + worktree = gremlin.worktree asyncio.run(gremlin.run()) rc = 0 finally: @@ -141,3 +142,74 @@ def test_resume_unbind_only_affects_exec_stages(tmp_path): gremlin._unbind_stale_exec_artifacts() assert not gremlin.registry.produced("work-out") assert gremlin.registry.produced("non-exec-artifact") + + +def test_gremlin_state_delegates_after_initialize(project_dir, pipeline_yaml, sandbox): + gremlin_id = "state-delegates-abc123" + sd = sandbox.state / gremlin_id + + saved_cwd = os.getcwd() + worktree = None + try: + gremlin = Gremlin.initialize_with_runtime( + gremlin_id=gremlin_id, + state_dir=sd, + project_dir=project_dir, + pipeline_ref=str(pipeline_yaml), + project_root=str(project_dir), + base_ref="main", + repo="test/repo", + ) + worktree = gremlin.worktree + + assert gremlin.state is not None + assert isinstance(gremlin.artifact_dir, pathlib.Path) + assert gremlin.artifact_dir == sd / "artifacts" + assert gremlin.artifacts is not None + assert isinstance(gremlin.artifacts, ArtifactRegistry) + assert gremlin.client is not None + assert isinstance(gremlin.cwd, str) + assert gremlin.base_ref == "main" + assert gremlin.repo == "test/repo" + assert gremlin.loop_iteration == 1 + assert gremlin.attempt == "" + finally: + os.chdir(saved_cwd) + if worktree and worktree.is_dir(): + shutil.rmtree(worktree, ignore_errors=True) + + +def test_gremlin_state_unchanged_after_run(project_dir, pipeline_yaml, sandbox): + gremlin_id = "state-run-abc123" + sd = sandbox.state / gremlin_id + + saved_cwd = os.getcwd() + worktree = None + rc = 1 + try: + gremlin = Gremlin.initialize_with_runtime( + gremlin_id=gremlin_id, + state_dir=sd, + project_dir=project_dir, + pipeline_ref=str(pipeline_yaml), + project_root=str(project_dir), + base_ref="main", + repo="test/repo", + ) + worktree = gremlin.worktree + + initial_client = gremlin.client + initial_base_ref = gremlin.base_ref + initial_repo = gremlin.repo + + asyncio.run(gremlin.run()) + rc = 0 + + assert gremlin.client == initial_client + assert gremlin.base_ref == initial_base_ref + assert gremlin.repo == initial_repo + finally: + os.chdir(saved_cwd) + StateData.load(gremlin_id).write_terminal_state(rc) + if worktree and worktree.is_dir(): + shutil.rmtree(worktree, ignore_errors=True)