Skip to content

Commit 93bdef9

Browse files
committed
update
1 parent d6644e3 commit 93bdef9

File tree

3 files changed

+49
-15
lines changed

3 files changed

+49
-15
lines changed

eval_protocol/pytest/evaluation_test_utils.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -371,7 +371,7 @@ async def execute_row_with_backoff_retry(row: EvaluationRow) -> EvaluationRow:
371371
retry_config = replace(config, kwargs={**(config.kwargs or {}), "start_server": False})
372372
retry_tasks = rollout_processor([row], retry_config)
373373
result = await retry_tasks[0]
374-
374+
375375
# Apply post-processing quality checks if configured
376376
# This must be inside the retry function so ResponseQualityError can trigger retries
377377
if config.post_processor is not None:
@@ -380,7 +380,7 @@ async def execute_row_with_backoff_retry(row: EvaluationRow) -> EvaluationRow:
380380
except ResponseQualityError as quality_error:
381381
# Re-raise ResponseQualityError to trigger retry logic
382382
raise quality_error
383-
383+
384384
return result
385385

386386
async def execute_row_with_backoff(task: asyncio.Task[EvaluationRow], row: EvaluationRow) -> EvaluationRow:
@@ -464,7 +464,11 @@ async def execute_row_with_backoff_and_log(
464464
yield result
465465

466466
finally:
467-
rollout_processor.cleanup()
467+
# Prefer async cleanup if available, fall back to sync
468+
if hasattr(rollout_processor, "aclose"):
469+
await getattr(rollout_processor, "aclose")()
470+
else:
471+
rollout_processor.cleanup()
468472

469473

470474
def sanitize_filename(text: str) -> str:

eval_protocol/pytest/remote_rollout_processor.py

Lines changed: 37 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,14 @@ def __init__(
4747
self._poll_interval = poll_interval
4848
self._timeout_seconds = timeout_seconds
4949
self._tracing_adapter = FireworksTracingAdapter(base_url=self._model_base_url)
50+
self._session: Optional[aiohttp.ClientSession] = None
51+
self._session_lock = asyncio.Lock()
52+
53+
async def _get_session(self) -> aiohttp.ClientSession:
54+
async with self._session_lock:
55+
if self._session is None or self._session.closed:
56+
self._session = aiohttp.ClientSession()
57+
return self._session
5058

5159
def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]:
5260
tasks: List[asyncio.Task[EvaluationRow]] = []
@@ -88,16 +96,18 @@ async def _process_row(row: EvaluationRow) -> EvaluationRow:
8896

8997
timeout_init = aiohttp.ClientTimeout(total=300)
9098

91-
async with aiohttp.ClientSession() as session:
92-
try:
93-
async with session.post(init_url, json=init_payload.model_dump(), timeout=timeout_init) as resp:
94-
if resp.status >= 400:
95-
body = await resp.text()
96-
raise RuntimeError(f"Remote /init failed (HTTP {resp.status}): {body}")
97-
except asyncio.TimeoutError:
98-
raise TimeoutError(
99-
f"The /init endpoint tried {init_url} with {init_payload.model_dump()} but timed out after 300 seconds."
100-
)
99+
try:
100+
session = await self._get_session()
101+
async with session.post(init_url, json=init_payload.model_dump(), timeout=timeout_init) as resp:
102+
if resp.status >= 400:
103+
body = await resp.text()
104+
raise RuntimeError(f"Remote /init failed (HTTP {resp.status}): {body}")
105+
resp.raise_for_status()
106+
await resp.read() # Drain the response body and release the connection back to the pool
107+
except asyncio.TimeoutError:
108+
raise TimeoutError(
109+
f"The /init endpoint tried {init_url} with {init_payload.model_dump()} but timed out after 300 seconds."
110+
)
101111

102112
deadline = time.time() + timeout_seconds
103113

@@ -185,5 +195,21 @@ async def _sem_wrapper(r: EvaluationRow) -> EvaluationRow:
185195
tasks = [asyncio.create_task(_sem_wrapper(row)) for row in rows]
186196
return tasks
187197

198+
async def aclose(self) -> None:
199+
"""Async cleanup - preferred when you can await."""
200+
if self._session and not self._session.closed:
201+
await self._session.close()
202+
188203
def cleanup(self) -> None:
189-
return None
204+
"""Sync cleanup - best-effort, schedules close if event loop is running."""
205+
if self._session and not self._session.closed:
206+
try:
207+
loop = asyncio.get_running_loop()
208+
loop.create_task(self._session.close())
209+
except RuntimeError:
210+
# No running event loop - can't safely close the session.
211+
# The session will be garbage collected eventually, but warn about it.
212+
logger.warning(
213+
"RemoteRolloutProcessor.cleanup() called outside of async context. "
214+
"Session may not be properly closed. Use `await processor.aclose()` when possible."
215+
)

eval_protocol/training/gepa_trainer.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -503,7 +503,11 @@ async def evaluate_with_ep(
503503
}
504504

505505
finally:
506-
rollout_processor.cleanup()
506+
# Prefer async cleanup if available, fall back to sync
507+
if hasattr(rollout_processor, "aclose"):
508+
await getattr(rollout_processor, "aclose")()
509+
else:
510+
rollout_processor.cleanup()
507511

508512
def run_ep_evaluation(
509513
self,

0 commit comments

Comments
 (0)