scheduler orchestration optimization #381
```python
extra_body["prediction"] = {"type": "content", "content": " ".join(task.history)[:max_tokens]}
cp["extra_body"] = extra_body
row_copy.input_metadata.completion_params = cp
cp["extra_body"]["prediction"] = " ".join(task.history_snapshot)[:max_tokens]
```
Bug: Shallow copy allows concurrent mutation of shared config
The comment says "Deep copy completion_params to avoid mutating shared config", but dict() only performs a shallow copy. If the base config already contains extra_body (which is common in practice), the nested dict is shared by reference. When multiple speculation tasks from different samples run concurrently, they all mutate the same shared extra_body dict, causing race conditions where one task's prediction value gets overwritten by another. The extra_body dict needs to be copied explicitly, or a true deep copy performed.
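A minimal sketch of the shallow-copy hazard and one possible fix. The config keys here (`model`, `extra_body`, `prediction`) mirror the snippet above, but the values are illustrative, not taken from the actual codebase:

```python
# Hypothetical base config shared across speculation tasks.
base_params = {"model": "gpt-4o", "extra_body": {"prediction": None}}

# dict() is a shallow copy: the nested extra_body dict is shared by reference.
cp_shallow = dict(base_params)
cp_shallow["extra_body"]["prediction"] = "task-A history"
assert base_params["extra_body"]["prediction"] == "task-A history"  # shared config mutated!

# Fix: copy the nested dict explicitly (copy.deepcopy(base_params) also works).
base_params = {"model": "gpt-4o", "extra_body": {"prediction": None}}
cp_safe = dict(base_params)
cp_safe["extra_body"] = dict(cp_safe.get("extra_body") or {})
cp_safe["extra_body"]["prediction"] = "task-A history"
assert base_params["extra_body"]["prediction"] is None  # shared config untouched
```

The explicit per-key copy is cheaper than `copy.deepcopy` when only `extra_body` is mutated per task.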
```python
Tracks state for a single dataset sample across multiple runs.

Enables streaming scheduling where each completed run immediately triggers the next.
"""
row: EvaluationRow
```
What's this row for? Do we update it? And is the lock there so that one state isn't consumed by two async tasks?
This is just a reference to keep track of the history and rollout status for one sample; the row is deep-copied during actual usage.
This SampleState only maintains the state for each sample.
```python
tc = time.perf_counter()
# print(f"run_id {row.execution_metadata.run_id} request_params: {json.dumps(request_params)}")
response = await acompletion(**request_params)
print(f"run_id {row.execution_metadata.run_id} time taken: {time.perf_counter() - tc} speculation_enabled: {request_params.get('extra_body', {}).get('prediction', None) is not None}")
```
Bug: Debug print statement left in production code
A debug print statement was left in the production code path. It prints timing information and speculation status for every non-streaming LLM completion call, which will pollute logs and stdout in production. The timing variable tc and the uncommented print at line 103 appear to be debugging/profiling code that was not removed before committing.
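One common remedy, sketched below: route the timing through a debug-level logger instead of print, so it only appears when debug logging is enabled. The logger name and `timed_call` helper are hypothetical, not from the PR:

```python
import logging
import time

logger = logging.getLogger("rollout")  # illustrative logger name

def timed_call(fn, *args, **kwargs):
    tc = time.perf_counter()
    result = fn(*args, **kwargs)
    # Emitted only when the "rollout" logger is configured at DEBUG level,
    # so production stdout stays clean.
    logger.debug("time taken: %.4fs", time.perf_counter() - tc)
    return result

value = timed_call(sum, [1, 2, 3])
assert value == 6
```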
morgendave
left a comment
Generally looks fine; unblocking. We still need to verify the safety of the lock and the queue.
Refactors the PriorityRolloutScheduler from batch-based to streaming scheduling. Each completed run now immediately schedules the next run, rather than waiting for an entire mini-batch to complete. This improves overall throughput by maximizing concurrent execution.
Key Changes

- Streaming Scheduling Architecture
- Concurrency Simplification
- Priority Queue Behavior

Benefits

- Higher throughput: New runs start immediately when a slot opens, with no waiting for batch completion
- Better resource utilization: Maintains in_group_minibatch_size concurrent runs per sample at all times
- Cleaner concurrency model: Single point of concurrency control in the rollout processor
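The streaming pattern described above can be sketched as follows: a semaphore provides the single point of concurrency control, and each finished run immediately launches the sample's next run rather than waiting for a mini-batch barrier. All names and the constants here are illustrative assumptions, not the PR's actual API:

```python
import asyncio

MAX_CONCURRENT = 2      # stand-in for max_concurrent_rollouts
RUNS_PER_SAMPLE = 3     # stand-in for in_group_minibatch_size scheduling

async def run_once(sample: int, run_index: int, sem: asyncio.Semaphore, done: list):
    async with sem:                    # single point of concurrency control
        await asyncio.sleep(0)         # stand-in for the actual rollout call
        done.append((sample, run_index))

async def stream_sample(sample: int, sem: asyncio.Semaphore, done: list):
    # Each completed run immediately triggers the next; no batch barrier.
    for run_index in range(RUNS_PER_SAMPLE):
        await run_once(sample, run_index, sem, done)

async def main():
    sem = asyncio.Semaphore(MAX_CONCURRENT)
    done: list = []
    await asyncio.gather(*(stream_sample(s, sem, done) for s in range(4)))
    return done

completed = asyncio.run(main())
assert len(completed) == 4 * RUNS_PER_SAMPLE
```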
Note
Refactors the rollout scheduler to stream per-run tasks with high/low priorities, adds speculative history injection, and simplifies concurrency to rely on the rollout processor’s semaphore.
- Introduces `SampleState` and redefines `RolloutTask` to a single run with priority `(status, row_index, run_index)`.
- `in_group_minibatch_size` (defaults depend on `ENABLE_SPECULATION`).
- Injects `history_snapshot` into `completion_params.extra_body.prediction` per run.
- Relies on the `rollout_processor`'s semaphore; worker count is set to `max_concurrent_rollouts`.
- Updates the `RolloutTask`/`SampleState` API; adjust expectations for worker scaling, priority ordering, concurrency limits, and groupwise evaluation.
- `default_single_turn_rollout_process.py`.

Written by Cursor Bugbot for commit fb8debe. This will update automatically on new commits.
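The `(status, row_index, run_index)` priority is a plain tuple, so a standard heap orders tasks lexicographically: lower status first, then lower row index, then lower run index. A minimal sketch with illustrative values:

```python
import heapq

# Hypothetical task priorities of the form (status, row_index, run_index).
heap = []
for priority in [(1, 0, 0), (0, 2, 1), (0, 2, 0), (0, 1, 0)]:
    heapq.heappush(heap, priority)

# Pops in lexicographic tuple order: status, then row_index, then run_index.
order = [heapq.heappop(heap) for _ in range(len(heap))]
assert order == [(0, 1, 0), (0, 2, 0), (0, 2, 1), (1, 0, 0)]
```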