Follow-up from v0.6 scope-trimming. Deferred from the initial BatchRunner implementation (v0.6 v1 uses stateless map_batches(fn, ...) with a module-level backend cache inside worker processes).
What v0.6 v1 does
BatchRunner runs backend inference as a stateless Ray Data function. A module-level _BACKEND_CACHE inside each worker process means load() fires once per worker rather than once per batch — good enough for small/medium models.
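For reference, a minimal sketch of that stateless path, under assumptions: the stub backend, the JSONL path, and the backend id are illustrative stand-ins, while load(), batch_predict(), and _BACKEND_CACHE are the names used above.

```python
import ray


class _StubBackend:
    """Stand-in for a real backend exposing load() / batch_predict() (illustrative)."""

    def load(self) -> None:
        pass  # a real backend pays the expensive weight load here

    def batch_predict(self, batch: dict) -> dict:
        return batch


# Module-level: lives for the lifetime of the worker process, so load()
# fires once per worker rather than once per batch.
_BACKEND_CACHE: dict = {}


def _infer_batch(batch: dict, *, backend_id: str) -> dict:
    backend = _BACKEND_CACHE.get(backend_id)
    if backend is None:
        backend = _StubBackend()  # real code would resolve backend_id to a backend
        backend.load()
        _BACKEND_CACHE[backend_id] = backend
    return backend.batch_predict(batch)


ds = ray.data.read_json("inputs.jsonl")  # illustrative input
results = ds.map_batches(_infer_batch, fn_kwargs={"backend_id": "sdxl"}, batch_size=8)
```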
What stateless misses
For models with a very expensive load() (FLUX ≈ 30-60s, GraphCast ≈ similar, SDXL ≈ 20-40s), every new Ray worker process pays the full load cost. Ray Data can spin up many short-lived workers under some schedules, so load time ends up dominating the run.
Proposal: opt-in actor-pool mode
Add a BatchSpec.compute: Literal["tasks", "actors"] = "tasks" field (or similar).
When compute="actors":
- BatchRunner.run() switches to ds.map_batches(_BackendActor, compute=ActorPoolStrategy(size=num_actors), ...)
- _BackendActor.__init__ calls backend.load() once per actor
- __call__ runs batch_predict per batch (see the sketch below)
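A sketch of that path, reusing the stub backend from the earlier snippet; the kwargs and pool size are illustrative, while the class-callable map_batches form and ActorPoolStrategy come from the Ray Data API linked under References.

```python
import ray
from ray.data import ActorPoolStrategy


class _BackendActor:
    """One long-lived backend per actor; load() is paid once per actor, not per batch."""

    def __init__(self, backend_id: str):
        self.backend = _StubBackend()  # real code would resolve backend_id (see above)
        self.backend.load()

    def __call__(self, batch: dict) -> dict:
        return self.backend.batch_predict(batch)


ds = ray.data.read_json("inputs.jsonl")
results = ds.map_batches(
    _BackendActor,
    fn_constructor_kwargs={"backend_id": "flux"},
    compute=ActorPoolStrategy(size=4),  # size = BatchSpec.num_actors
    num_gpus=1,                         # per-actor reservation
    batch_size=8,
)
```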
New BatchSpec fields:
- num_actors: int — pool size (independent of ResourceConfig.replicas, which is serving semantics)
- Existing num_gpus / num_cpus → num_gpus_per_actor / num_cpus_per_actor (see the spec sketch below)
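Rough shape of the extended spec, assuming a dataclass-style BatchSpec; the existing fields shown are placeholders, not the real schema.

```python
from dataclasses import dataclass
from typing import Literal


@dataclass
class BatchSpec:
    backend_id: str                                # assumed existing field
    batch_size: int = 8                            # assumed existing field
    compute: Literal["tasks", "actors"] = "tasks"  # new: opt into the actor pool
    num_actors: int = 1                            # new: pool size (actors mode only)
    num_gpus_per_actor: float = 0.0                # renamed from num_gpus
    num_cpus_per_actor: float = 1.0                # renamed from num_cpus
```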
Lifecycle wrinkles to handle
- Cold start: first batch per actor blocks on load(). For FLUX that's minutes. Document this; users can pre-warm with a dummy batch if needed.
- GPU reservation: each actor pins its num_gpus slot for its lifetime. Pool sizing needs clear docs (num_actors * num_gpus_per_actor <= cluster GPUs); a sizing-guard sketch follows this list.
- Shutdown hygiene: Ray Data kills actors at the end of map_batches, but driver crashes can leave actors lingering. Handle with a finally-block cleanup or a periodic-sweep GC.
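For the pool-sizing rule, a cheap guard BatchRunner.run() could apply before building the pool. Sketch only: the function name is hypothetical, the field names follow the spec sketch above, and ray.cluster_resources() reports the cluster's total logical resources.

```python
import ray


def _check_pool_fits(num_actors: int, num_gpus_per_actor: float) -> None:
    # Fail fast instead of letting actors sit pending on an undersized cluster.
    cluster_gpus = ray.cluster_resources().get("GPU", 0.0)
    requested = num_actors * num_gpus_per_actor
    if requested > cluster_gpus:
        raise ValueError(
            f"Actor pool requests {requested} GPUs "
            f"({num_actors} x {num_gpus_per_actor}) but the cluster has {cluster_gpus}."
        )
```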
When to ship
After the v1 JSONL pipeline is proven end-to-end and we've added a GPU-heavy backend to the batch test matrix (likely FLUX or SDXL smoke via Modal).
References
Dataset.map_batches (class-callable form): https://docs.ray.io/en/latest/data/api/doc/ray.data.Dataset.map_batches.html