Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 33 additions & 5 deletions gremlins/executor/gremlin.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ async def run_stages(

class Gremlin:
registry: ArtifactRegistry
state: State | None

Comment on lines 102 to 105
Copy link
Copy Markdown
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added state: Any to GremlinProtocol to enable type checking for stage code using self.gremlin.state.

def __init__(
self,
Expand Down Expand Up @@ -150,6 +151,7 @@ def __init__(
self.fetch_worktree = fetch_worktree
self.pipeline_path = pipeline_path
self.pipeline_args = pipeline_args or []
self.state = None

@property
def artifact_dir(self) -> pathlib.Path:
Expand Down Expand Up @@ -276,14 +278,17 @@ def _set_gremlin_recursive(self, stage: StageProtocol) -> None:
for nested in body:
self._set_gremlin_recursive(nested)

def _collect_stages(
self, stages: Sequence[StageProtocol]
) -> list[tuple[str, Callable[[], Awaitable[Any]]]]:
cwd = (
def _resolve_cwd(self) -> str:
return (
str(self.worktree_dir)
if self.worktree_dir is not None
else (self.project_root or str(pathlib.Path.cwd()))
)

def _collect_stages(
self, stages: Sequence[StageProtocol]
) -> list[tuple[str, Callable[[], Awaitable[Any]]]]:
cwd = self._resolve_cwd()
built: list[tuple[str, Callable[[], Awaitable[Any]]]] = []
for e in stages:
self._set_gremlin_recursive(e)
Expand All @@ -300,7 +305,16 @@ def _collect_stages(
artifacts=self.registry,
base_ref=self.base_ref,
)
built.append((e.name, stage_state.make_runner(e, scope=stages)))
base_runner = stage_state.make_runner(e, scope=stages)

async def _set_state_and_run(
runner: Callable[[], Awaitable[Any]] = base_runner,
state: State = stage_state,
) -> Any:
self.state = state
return await runner()

built.append((e.name, _set_state_and_run))
Comment on lines +308 to +317
Copy link
Copy Markdown
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed. Wrapped nested stage runners in sequence, loop, and parallel stages to set child.gremlin.state before execution. This ensures nested stages have access to their correct child state.

return built

def _unbind_stale_exec_artifacts(self) -> None:
Expand Down Expand Up @@ -512,6 +526,20 @@ def initialize_with_runtime(
sha = _git_mod.head_sha(cwd=self.worktree_dir)
if sha:
self.registry.bind("base_sha", Uri.parse(f"git://commit/{sha}"))

cwd = self._resolve_cwd()
self.state = build_state(
data=StateData(gremlin_id=self.gremlin_id, state_file=self.state_file),
client=resolved_client or PACKAGE_DEFAULT,
artifact_dir=self.artifact_dir,
pipeline_data=self.pipeline_data,
repo=self.repo,
cwd=cwd,
worktree=self.worktree_dir,
worktree_parent=self.worktree_parent,
artifacts=self.registry,
base_ref=self.base_ref,
)
except Exception:
if worktree_created:
_git_mod.remove_worktree(self.project_root, worktree_created)
Expand Down
1 change: 1 addition & 0 deletions gremlins/protocols.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ class GremlinProtocol(Protocol):
"""What stages need from a Gremlin."""

registry: Any
state: Any

async def fork(
self,
Expand Down
10 changes: 5 additions & 5 deletions gremlins/stages/loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,12 +143,12 @@ def _build_runners(self, state: State) -> list[Callable[[], Awaitable[Outcome]]]
)
name = child.name

async def _tracked(
r: Callable[[], Awaitable[Any]] = base, n: str = name
) -> Outcome:
state.data.patch(active_children=[n])
async def _tracked() -> Outcome:
state.data.patch(active_children=[name])
try:
return cast(Outcome, await r())
if child.gremlin is not None:
child.gremlin.state = cs
return cast(Outcome, await base())
finally:
state.data.patch(_delete=("active_children",))

Expand Down
14 changes: 12 additions & 2 deletions gremlins/stages/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,8 +183,18 @@ async def run(self, state: State) -> Outcome:
cs = _child_state(
group_state, child, fan_out=True, child_id=child_id or None
)
runner = cs.make_runner(child, scope=self.body)
child_runners.append((child.name, cs, runner))
base_runner = cs.make_runner(child, scope=self.body)

async def _set_state_and_run(
runner: Callable[[], Any] = base_runner,
child_stage: Stage = child,
child_state: State = cs,
) -> Any:
if child_stage.gremlin is not None:
child_stage.gremlin.state = child_state
return await runner()

child_runners.append((child.name, cs, _set_state_and_run))
for _, fn in self.build_runtime_stages(
child_runners,
parent_data=state.data,
Expand Down
11 changes: 9 additions & 2 deletions gremlins/stages/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,18 @@ async def run(self, state: State) -> Outcome:
if child.name in done:
continue
state.data.patch(active_children=[child.name])
runner = _child_state(state, child).make_runner(
child_state = _child_state(state, child)
base_runner = child_state.make_runner(
child, scope=self.body, record_stage=False
)

async def _set_state_and_run() -> Any:
if child.gremlin is not None:
child.gremlin.state = child_state
return await base_runner()

try:
await runner()
await _set_state_and_run()
finally:
state.data.patch(_delete=("active_children",))
state.mark_done(key, child.name)
Expand Down
145 changes: 145 additions & 0 deletions tests/test_gremlin_state.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
"""Tests for Gremlin.state attribute."""

from __future__ import annotations

import asyncio
import subprocess

import pytest

from gremlins.executor.gremlin import Gremlin
from gremlins.executor.state import State
from gremlins.protocols import StageProtocol
from gremlins.stages.outcome import Done


def _init_git_repo(path) -> None:
subprocess.run(["git", "init"], cwd=str(path), check=True, capture_output=True)
subprocess.run(
["git", "config", "user.name", "Test"],
cwd=str(path),
check=True,
capture_output=True,
)
subprocess.run(
["git", "config", "user.email", "test@test.com"],
cwd=str(path),
check=True,
capture_output=True,
)
subprocess.run(
["git", "add", "-A"],
cwd=str(path),
check=True,
capture_output=True,
)
subprocess.run(
["git", "commit", "-m", "init"],
cwd=str(path),
check=True,
capture_output=True,
)


@pytest.fixture
def project_dir(tmp_path):
"""Git repository for testing."""
d = tmp_path / "project"
d.mkdir()
(d / "file.txt").write_text("hello")
_init_git_repo(d)
return d


@pytest.fixture
def pipeline_yaml(tmp_path):
p = tmp_path / "trivial.yaml"
p.write_text(
"""\
stages:
- name: test
type: exec
options:
cmds:
- "true"
"""
)
return p


def test_gremlin_state_populated_after_initialize(project_dir, pipeline_yaml, sandbox):
"""Gremlin.state is non-None after initialize_with_runtime."""
gremlin_id = "test-state-init"
state_dir = sandbox.state / gremlin_id

gremlin = Gremlin.initialize_with_runtime(
gremlin_id=gremlin_id,
state_dir=state_dir,
project_dir=project_dir,
pipeline_ref=str(pipeline_yaml),
project_root=str(project_dir),
)

assert gremlin.state is not None


def test_gremlin_state_attributes_accessible(project_dir, pipeline_yaml, sandbox):
"""Gremlin.state attributes are accessible after initialize_with_runtime."""
gremlin_id = "test-state-attrs"
state_dir = sandbox.state / gremlin_id

gremlin = Gremlin.initialize_with_runtime(
gremlin_id=gremlin_id,
state_dir=state_dir,
project_dir=project_dir,
pipeline_ref=str(pipeline_yaml),
project_root=str(project_dir),
)

assert gremlin.state is not None
assert gremlin.state.client is not None
assert gremlin.state.artifacts is not None
assert gremlin.state.artifact_dir is not None


class StateCapturingStage(StageProtocol):
"""Test stage that captures gremlin.state when run."""

def __init__(self):
self.name = "capture"
self.type = "test"
self.path = "capture"
self.gremlin = None
self.client = None
self.captured_state = None
self.skip_if_exists = None

async def run(self, state: State):
"""Capture the gremlin.state value during execution."""
if self.gremlin:
self.captured_state = self.gremlin.state
return Done()


def test_gremlin_state_set_before_stage_run(project_dir, pipeline_yaml, sandbox):
"""Gremlin.state is set before stage.run() is called."""
gremlin_id = "test-state-runner"
state_dir = sandbox.state / gremlin_id

gremlin = Gremlin.initialize_with_runtime(
gremlin_id=gremlin_id,
state_dir=state_dir,
project_dir=project_dir,
pipeline_ref=str(pipeline_yaml),
project_root=str(project_dir),
)

stage = StateCapturingStage()

collected = gremlin._collect_stages([stage])
assert len(collected) > 0
assert collected[0][0] == "capture"

asyncio.run(collected[0][1]())
assert stage.captured_state is not None
assert stage.captured_state == gremlin.state
Loading