Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/helix/batch_sampler.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
"""HELIX minibatch sampler — GEPA parity.

Line-for-line port of
gepa.strategies.batch_sampler.EpochShuffledBatchSampler
(see /tmp/gepa_eval_spec.md §2).
Port of ``gepa.strategies.batch_sampler.EpochShuffledBatchSampler``
(see ``src/gepa/strategies/batch_sampler.py`` in the public GEPA repo,
``github.com/gepa-ai/gepa``).

Also provides :class:`StratifiedBatchSampler`, which is a HELIX extension
over EpochShuffledBatchSampler that guarantees each minibatch of size K
Expand Down
4 changes: 2 additions & 2 deletions src/helix/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ class EvolutionConfig(BaseModel):
"Max parallel eval workers — bounds both the parent-eval and "
"mutation ThreadPools in the num_parallel_proposals pipeline. "
"GEPA parity: EngineConfig.max_workers "
"(/tmp/gepa-official/src/gepa/optimize_anything.py:485, "
"(src/gepa/optimize_anything.py:485, "
"default os.cpu_count() or 32)."
),
)
Expand Down Expand Up @@ -425,7 +425,7 @@ def model_post_init(self, __context: object) -> None:
# GEPA parity: resolve ``num_parallel_proposals="auto"`` to
# ``max(1, max_workers // minibatch_size)`` once at construction
# time so every downstream consumer sees a plain int. Mirrors
# /tmp/gepa-official/src/gepa/optimize_anything.py:1108-1116.
# GEPA src/gepa/optimize_anything.py:1108-1116.
if self.num_parallel_proposals == "auto":
self.num_parallel_proposals = max(
1, self.max_workers // max(1, self.minibatch_size)
Expand Down
90 changes: 81 additions & 9 deletions src/helix/eval_cache.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,21 @@
"""HELIX evaluation cache — GEPA parity.

Line-for-line port of the cache layer described in
/tmp/gepa_eval_spec.md §3 (originally at gepa/core/state.py:27-130).
Port of the cache layer from GEPA's ``gepa/core/state.py:27-130``
(public source at ``github.com/gepa-ai/gepa``), extended with a
per-(candidate, example) ``side_info`` slot.

The base GEPA ``CachedEvaluation`` stores only ``(output, score,
objective_scores)``. GEPA's ``OptimizeAnythingAdapter._eval_cache``
(``gepa/adapters/optimize_anything_adapter/optimize_anything_adapter.py:92,
200-216``) caches the richer ``(score, output, side_info)`` tuple in a
*second* adapter-level cache, precisely because reflection prompts need
the side_info to survive a cache hit. HELIX has no analogous adapter
cache — the engine-level cache is the only one — so we fold the
side_info slot into ``CachedEvaluation`` here. This preserves the
LIBERO feedback signal (``evaluation_diagnostics``, ``judge_metrics``,
``evaluator_error``, ``video_path``, ...) across cache hits, which is
what reaches the mutator's ``## Diagnostics`` section
(see ``helix.mutator._render_per_example_diagnostics``).
"""
from __future__ import annotations

Expand Down Expand Up @@ -31,18 +45,29 @@ def _candidate_hash(candidate: dict[str, str]) -> CandidateHash:

@dataclass
class CachedEvaluation(Generic[RolloutOutput]):
"""Per-(candidate, example) cached evaluation entry.

``output``, ``score``, and ``objective_scores`` mirror GEPA's
base ``CachedEvaluation`` (``gepa/core/state.py:37-43``). ``side_info``
is the HELIX-specific extension that lets per-example reflection
diagnostics (LIBERO feedback, evaluator error traces, judge metrics,
media references, ...) survive cache hits — see the module docstring
for the GEPA OptimizeAnythingAdapter precedent.
"""

output: RolloutOutput
score: float
objective_scores: dict[str, float] | None = None
side_info: dict[str, Any] | None = None


@dataclass
class EvaluationCache(Generic[RolloutOutput, DataId]):
_cache: dict[CacheKey, CachedEvaluation[RolloutOutput]] = field(
default_factory=dict
)
# Thread safety (audit-mutation §C4 / audit-budget-caching §C1): the
# parent-minibatch eval now runs inside a ThreadPoolExecutor (see
# Thread safety: the parent-minibatch eval runs inside a
# ThreadPoolExecutor (see
# evolution.py parent-eval parallel stage), so concurrent readers/writers
# can race on ``_cache``. A single lock serialises every access.
_lock: threading.Lock = field(default_factory=threading.Lock, repr=False, compare=False)
Expand All @@ -60,10 +85,11 @@ def put(
output: RolloutOutput,
score: float,
objective_scores: dict[str, float] | None = None,
side_info: dict[str, Any] | None = None,
) -> None:
with self._lock:
self._cache[(_candidate_hash(candidate), example_id)] = CachedEvaluation(
output, score, objective_scores
output, score, objective_scores, side_info
)

def get_batch(
Expand Down Expand Up @@ -95,6 +121,7 @@ def put_batch(
outputs: list[RolloutOutput],
scores: list[float],
objective_scores_list: list[dict[str, float]] | None = None,
side_info_list: list[dict[str, Any]] | None = None,
) -> None:
h = _candidate_hash(candidate)
with self._lock:
Expand All @@ -103,6 +130,7 @@ def put_batch(
outputs[i],
scores[i],
objective_scores_list[i] if objective_scores_list else None,
side_info_list[i] if side_info_list else None,
)
TRACE.emit(
EventType.CACHE_PUT,
Expand All @@ -117,14 +145,36 @@ def evaluate_with_cache_full(
fetcher: Callable[[list[DataId]], Any],
evaluator: Callable[
[Any, dict[str, str]],
tuple[list[RolloutOutput], list[float], list[dict[str, float]] | None],
tuple[
list[RolloutOutput],
list[float],
list[dict[str, float]] | None,
list[dict[str, Any]] | None,
],
],
) -> tuple[
dict[DataId, RolloutOutput],
dict[DataId, float],
dict[DataId, dict[str, float]] | None,
dict[DataId, dict[str, Any]] | None,
int,
]:
"""Evaluate ``candidate`` on ``example_ids`` with per-example caching.

Evaluator return shape (4-tuple):

``(outputs, scores, objective_scores, side_infos)``

``objective_scores`` and ``side_infos`` are both optional
(``None`` allowed). When provided they must be positional to
``batch`` (length ``len(batch)``).

Returns ``(outputs_by_id, scores_by_id, objective_by_id,
side_info_by_id, num_actual_evals)``. ``objective_by_id`` and
``side_info_by_id`` are ``None`` when no entry — cached or
fresh — produced that axis; otherwise they cover the union of
cache-hit and fresh-miss ids for which the axis was non-empty.
"""
cached, uncached_ids = self.get_batch(candidate, example_ids)
outputs_by_id: dict[DataId, RolloutOutput] = {
eid: c.output for eid, c in cached.items()
Expand All @@ -133,20 +183,42 @@ def evaluate_with_cache_full(
eid: c.score for eid, c in cached.items()
}
objective_by_id: dict[DataId, dict[str, float]] | None = None
side_info_by_id: dict[DataId, dict[str, Any]] | None = None
for eid, c in cached.items():
if c.objective_scores is not None:
if objective_by_id is None:
objective_by_id = {}
objective_by_id[eid] = c.objective_scores
if c.side_info is not None:
if side_info_by_id is None:
side_info_by_id = {}
side_info_by_id[eid] = c.side_info
if uncached_ids:
batch = fetcher(uncached_ids)
outputs, scores, obj_scores = evaluator(batch, candidate)
outputs, scores, obj_scores, side_infos = evaluator(batch, candidate)
for idx, eid in enumerate(uncached_ids):
outputs_by_id[eid] = outputs[idx]
scores_by_id[eid] = scores[idx]
if obj_scores is not None:
if objective_by_id is None:
objective_by_id = {}
objective_by_id[eid] = obj_scores[idx]
self.put_batch(candidate, uncached_ids, outputs, scores, obj_scores)
return outputs_by_id, scores_by_id, objective_by_id, len(uncached_ids)
if side_infos is not None:
if side_info_by_id is None:
side_info_by_id = {}
side_info_by_id[eid] = side_infos[idx]
self.put_batch(
candidate,
uncached_ids,
outputs,
scores,
obj_scores,
side_infos,
)
return (
outputs_by_id,
scores_by_id,
objective_by_id,
side_info_by_id,
len(uncached_ids),
)
Loading
Loading