@@ -47,6 +47,14 @@ def __init__(
4747 self ._poll_interval = poll_interval
4848 self ._timeout_seconds = timeout_seconds
4949 self ._tracing_adapter = FireworksTracingAdapter (base_url = self ._model_base_url )
50+ self ._session : Optional [aiohttp .ClientSession ] = None
51+ self ._session_lock = asyncio .Lock ()
52+
async def _get_session(self) -> aiohttp.ClientSession:
    """Return the shared HTTP session, creating it on demand.

    The check-and-create step runs while holding ``self._session_lock``.
    A new ``aiohttp.ClientSession`` is built when no session has been
    stored yet or when the stored one reports ``closed``.

    Returns:
        The session currently stored on ``self._session``.
    """
    async with self._session_lock:
        current = self._session
        # Reuse the stored session as long as it is still open.
        if current is not None and not current.closed:
            return current
        fresh = aiohttp.ClientSession()
        self._session = fresh
        return fresh
5058
5159 def __call__ (self , rows : List [EvaluationRow ], config : RolloutProcessorConfig ) -> List [asyncio .Task [EvaluationRow ]]:
5260 tasks : List [asyncio .Task [EvaluationRow ]] = []
@@ -88,16 +96,18 @@ async def _process_row(row: EvaluationRow) -> EvaluationRow:
8896
8997 timeout_init = aiohttp .ClientTimeout (total = 300 )
9098
91- async with aiohttp .ClientSession () as session :
92- try :
93- async with session .post (init_url , json = init_payload .model_dump (), timeout = timeout_init ) as resp :
94- if resp .status >= 400 :
95- body = await resp .text ()
96- raise RuntimeError (f"Remote /init failed (HTTP { resp .status } ): { body } " )
97- except asyncio .TimeoutError :
98- raise TimeoutError (
99- f"The /init endpoint tried { init_url } with { init_payload .model_dump ()} but timed out after 300 seconds."
100- )
99+ try :
100+ session = await self ._get_session ()
101+ async with session .post (init_url , json = init_payload .model_dump (), timeout = timeout_init ) as resp :
102+ if resp .status >= 400 :
103+ body = await resp .text ()
104+ raise RuntimeError (f"Remote /init failed (HTTP { resp .status } ): { body } " )
105+ resp .raise_for_status ()
106+ await resp .read () # Drain the response body and release the connection back to the pool
107+ except asyncio .TimeoutError :
108+ raise TimeoutError (
109+ f"The /init endpoint tried { init_url } with { init_payload .model_dump ()} but timed out after 300 seconds."
110+ )
101111
102112 deadline = time .time () + timeout_seconds
103113
@@ -185,5 +195,21 @@ async def _sem_wrapper(r: EvaluationRow) -> EvaluationRow:
185195 tasks = [asyncio .create_task (_sem_wrapper (row )) for row in rows ]
186196 return tasks
187197
async def aclose(self) -> None:
    """Close the stored HTTP session, if there is an open one.

    Preferred over ``cleanup()`` whenever the caller is able to await.
    Does nothing when no session is stored or it is already closed.
    """
    session = self._session
    if not session or session.closed:
        return
    await session.close()
202+
def cleanup(self) -> None:
    """Best-effort synchronous cleanup of the stored HTTP session.

    When called from inside a running event loop, the session's ``close()``
    coroutine is scheduled as a task on that loop. When no loop is running,
    the session cannot be closed from here, so a warning is logged instead.
    Does nothing when no session is stored or it is already closed.
    """
    session = self._session
    if not session or session.closed:
        return
    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        # No running event loop - can't safely close the session.
        # The session will be garbage collected eventually, but warn about it.
        logger.warning(
            "RemoteRolloutProcessor.cleanup() called outside of async context. "
            "Session may not be properly closed. Use `await processor.aclose()` when possible."
        )
        return
    # The event loop only keeps weak references to tasks, so hold a strong
    # reference on the instance; otherwise the pending close could be
    # garbage-collected before it completes.
    self._close_task = loop.create_task(session.close())
0 commit comments